community[patch]: Fix generation_config not setting properly for DeepSparse (#15036)

- **Description:** Tiny but important bugfix that switches to a more stable interface for specifying generation_config parameters for the DeepSparse LLM: instead of handing the whole generation_config object to the pipeline, its entries are now unpacked as keyword arguments (see the sketch below).
mgoin authored Dec 22, 2023
1 parent 2460f97 commit 501cc83
Showing 1 changed file with 5 additions and 9 deletions.
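
For context, here is a minimal, self-contained sketch of the call-pattern change: rather than passing the whole dict as a single generation_config argument, each entry is unpacked into keyword arguments on the pipeline call. The fake_pipeline helper and the max_new_tokens key are illustrative assumptions, not part of the DeepSparse API touched by this diff.

```python
# Sketch of the call-pattern change. `fake_pipeline` and `max_new_tokens`
# are illustrative stand-ins, not the real DeepSparse API.
def fake_pipeline(sequences: str, **kwargs: object) -> str:
    # Stand-in for self.pipeline(...); it just echoes what it received.
    return f"sequences={sequences!r}, kwargs={kwargs}"

generation_config = {"max_new_tokens": 128}

# Old call shape: the whole dict travels as one `generation_config` argument.
print(fake_pipeline(sequences="Hello", generation_config=generation_config))

# New call shape: the dict is unpacked, so the pipeline receives
# max_new_tokens=128 directly as a keyword argument.
print(fake_pipeline(sequences="Hello", **generation_config))
```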
14 changes: 5 additions & 9 deletions libs/community/langchain_community/llms/deepsparse.py
@@ -63,7 +63,7 @@ def validate_environment(cls, values: Dict) -> Dict:
         except ImportError:
             raise ImportError(
                 "Could not import `deepsparse` package. "
-                "Please install it with `pip install deepsparse`"
+                "Please install it with `pip install deepsparse[llm]`"
             )
 
         model_config = values["model_config"] or {}
@@ -103,9 +103,7 @@ def _call(
             text = combined_output
         else:
             text = (
-                self.pipeline(
-                    sequences=prompt, generation_config=self.generation_config
-                )
+                self.pipeline(sequences=prompt, **self.generation_config)
                 .generations[0]
                 .text
             )
@@ -143,9 +141,7 @@ async def _acall(
             text = combined_output
         else:
             text = (
-                self.pipeline(
-                    sequences=prompt, generation_config=self.generation_config
-                )
+                self.pipeline(sequences=prompt, **self.generation_config)
                 .generations[0]
                 .text
             )
@@ -184,7 +180,7 @@ def _stream(
                     print(chunk, end='', flush=True)
         """
         inference = self.pipeline(
-            sequences=prompt, generation_config=self.generation_config, streaming=True
+            sequences=prompt, streaming=True, **self.generation_config
         )
         for token in inference:
             chunk = GenerationChunk(text=token.generations[0].text)
@@ -222,7 +218,7 @@ async def _astream(
                     print(chunk, end='', flush=True)
         """
         inference = self.pipeline(
-            sequences=prompt, generation_config=self.generation_config, streaming=True
+            sequences=prompt, streaming=True, **self.generation_config
         )
         for token in inference:
             chunk = GenerationChunk(text=token.generations[0].text)
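
After this change, generation parameters set on the wrapper flow through to both the batch and streaming paths. Below is a hedged usage sketch; the import path matches this file's package, while the model stub and the max_new_tokens key are assumptions to be replaced with a real SparseZoo stub or local model path.

```python
from langchain_community.llms import DeepSparse

# The model stub below is a placeholder assumption; point it at a real
# SparseZoo stub or a local DeepSparse-compatible model.
llm = DeepSparse(
    model="zoo:some-text-generation-model",
    generation_config={"max_new_tokens": 128},
)

# Non-streaming path (_call/_acall): generation_config entries are now
# forwarded to the pipeline as keyword arguments.
print(llm.invoke("Write a haiku about sparsity."))

# Streaming path (_stream/_astream): the same entries are forwarded
# alongside streaming=True.
for chunk in llm.stream("Write a haiku about sparsity."):
    print(chunk, end="", flush=True)
```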