Commit

fix: HuggingFaceAPIGenerator - use forward references (#8502)
* hf API generator: forward references + refactor

* release note
anakin87 authored Oct 30, 2024
1 parent 8a35e79 commit 700684a
Showing 2 changed files with 20 additions and 12 deletions.
haystack/components/generators/hugging_face_api.py (27 changes: 15 additions & 12 deletions)
@@ -204,37 +204,40 @@ def run(
         # check if streaming_callback is passed
         streaming_callback = streaming_callback or self.streaming_callback

-        stream = streaming_callback is not None
-        response = self._client.text_generation(prompt, details=True, stream=stream, **generation_kwargs)
+        hf_output = self._client.text_generation(
+            prompt, details=True, stream=streaming_callback is not None, **generation_kwargs
+        )

-        output = self._get_stream_response(response, streaming_callback) if stream else self._get_response(response)  # type: ignore
-        return output
+        if streaming_callback is not None:
+            return self._stream_and_build_response(hf_output, streaming_callback)
+
+        return self._build_non_streaming_response(hf_output)

-    def _get_stream_response(
-        self, response: Iterable[TextGenerationStreamOutput], streaming_callback: Callable[[StreamingChunk], None]
+    def _stream_and_build_response(
+        self, hf_output: Iterable["TextGenerationStreamOutput"], streaming_callback: Callable[[StreamingChunk], None]
     ):
         chunks: List[StreamingChunk] = []
-        for chunk in response:
+        for chunk in hf_output:
             token: TextGenerationOutputToken = chunk.token
             if token.special:
                 continue
             chunk_metadata = {**asdict(token), **(asdict(chunk.details) if chunk.details else {})}
             stream_chunk = StreamingChunk(token.text, chunk_metadata)
             chunks.append(stream_chunk)
-            streaming_callback(stream_chunk)  # type: ignore # streaming_callback is not None (verified in the run method)
+            streaming_callback(stream_chunk)
         metadata = {
             "finish_reason": chunks[-1].meta.get("finish_reason", None),
             "model": self._client.model,
             "usage": {"completion_tokens": chunks[-1].meta.get("generated_tokens", 0)},
         }
         return {"replies": ["".join([chunk.content for chunk in chunks])], "meta": [metadata]}

-    def _get_response(self, response: TextGenerationOutput):
+    def _build_non_streaming_response(self, hf_output: "TextGenerationOutput"):
         meta = [
             {
                 "model": self._client.model,
-                "finish_reason": response.details.finish_reason if response.details else None,
-                "usage": {"completion_tokens": len(response.details.tokens) if response.details else 0},
+                "finish_reason": hf_output.details.finish_reason if hf_output.details else None,
+                "usage": {"completion_tokens": len(hf_output.details.tokens) if hf_output.details else 0},
             }
         ]
-        return {"replies": [response.generated_text], "meta": meta}
+        return {"replies": [hf_output.generated_text], "meta": meta}
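For context, a minimal sketch of the forward-reference pattern the quoted annotations rely on. The module layout and helper names below are illustrative, not the exact Haystack implementation, and the top-level huggingface_hub re-export is an assumption.

# Minimal sketch of the forward-reference pattern (illustrative, not the exact
# Haystack module): the optional dependency is imported only for type checkers,
# and annotations are written as strings so Python never evaluates them at runtime.
from typing import TYPE_CHECKING, Any, Dict, Iterable

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright), never executed at runtime,
    # so this module imports cleanly even without huggingface_hub installed.
    # (Top-level re-export of these types is assumed here.)
    from huggingface_hub import TextGenerationOutput, TextGenerationStreamOutput


def build_non_streaming_response(hf_output: "TextGenerationOutput") -> Dict[str, Any]:
    # The quoted annotation is a forward reference: stored as a string, never resolved
    # when this function is defined or called.
    return {"replies": [hf_output.generated_text]}


def stream_and_build_response(hf_output: Iterable["TextGenerationStreamOutput"]) -> Dict[str, Any]:
    # Iteration only touches runtime attributes of each chunk, not the type names.
    return {"replies": ["".join(chunk.token.text for chunk in hf_output)]}

With unquoted parameter annotations, Python evaluates the type names when each function or method is defined, so a missing or lazily imported huggingface_hub could raise at import time; the quoted forms defer that resolution entirely to static analysis.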
releasenotes/notes/hfapigen-forwardref-5c06090282557195.yaml (5 changes: 5 additions & 0 deletions)
@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    Use forward references for Hugging Face Hub types in the `HuggingFaceAPIGenerator` component
+    to prevent import errors.
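For completeness, a hedged usage sketch of the fixed component: the constructor parameters follow the Haystack 2.x documentation and may differ across versions, and the model name and environment variable are placeholders.

# Hedged usage sketch (not part of this commit). Parameter names may vary by
# Haystack version; the model and HF_API_TOKEN env var are placeholders.
from haystack.components.generators import HuggingFaceAPIGenerator
from haystack.utils import Secret

generator = HuggingFaceAPIGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
    token=Secret.from_env_var("HF_API_TOKEN"),
)

result = generator.run(prompt="Briefly explain forward references in Python.")
print(result["replies"][0])  # generated text
print(result["meta"][0])     # model, finish_reason, usage, as built in the diff above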
