In huggingface_hub 0.23.0 TextGenerationOutput property details is now optional
vblagoje committed May 2, 2024
1 parent 472ef4f commit 2082532
Showing 3 changed files with 22 additions and 8 deletions.
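
For context, here is a minimal sketch (not Haystack code) of the pattern the commit introduces: since huggingface_hub 0.23.0, TextGenerationOutput.details may be None, so any access to its finish_reason or tokens has to be guarded. The model id and prompt below are placeholders.

# Minimal sketch, assuming a plain huggingface_hub InferenceClient; the model id
# and prompt are placeholders, not values from the commit.
from huggingface_hub import InferenceClient

client = InferenceClient(model="mistralai/Mistral-7B-Instruct-v0.2")
tgr = client.text_generation("What is Haystack?", details=True, max_new_tokens=64)

# Since huggingface_hub 0.23.0, `tgr.details` can be None, so derive the
# metadata defensively instead of dereferencing it unconditionally.
if tgr.details:
    finish_reason = tgr.details.finish_reason
    completion_tokens = len(tgr.details.tokens)
else:
    finish_reason = None
    completion_tokens = 0

print(tgr.generated_text, finish_reason, completion_tokens)
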
13 changes: 10 additions & 3 deletions haystack/components/generators/chat/hugging_face_tgi.py
@@ -294,15 +294,22 @@ def _run_non_streaming(
         for _i in range(num_responses):
             tgr: TextGenerationOutput = self.client.text_generation(prepared_prompt, details=True, **generation_kwargs)
             message = ChatMessage.from_assistant(tgr.generated_text)
+            if tgr.details:
+                completion_tokens = len(tgr.details.tokens)
+                prompt_token_count = prompt_token_count + completion_tokens
+                finish_reason = tgr.details.finish_reason
+            else:
+                finish_reason = None
+                completion_tokens = 0
             message.meta.update(
                 {
-                    "finish_reason": tgr.details.finish_reason,
+                    "finish_reason": finish_reason,
                     "index": _i,
                     "model": self.client.model,
                     "usage": {
-                        "completion_tokens": len(tgr.details.tokens),
+                        "completion_tokens": completion_tokens,
                         "prompt_tokens": prompt_token_count,
-                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
+                        "total_tokens": prompt_token_count + completion_tokens,
                     },
                 }
             )
4 changes: 2 additions & 2 deletions haystack/components/generators/hugging_face_api.py
@@ -208,8 +208,8 @@ def _run_non_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
         meta = [
             {
                 "model": self._client.model,
-                "finish_reason": tgr.details.finish_reason,
-                "usage": {"completion_tokens": len(tgr.details.tokens)},
+                "finish_reason": tgr.details.finish_reason if tgr.details else None,
+                "usage": {"completion_tokens": len(tgr.details.tokens) if tgr.details else 0},
             }
         ]
         return {"replies": [tgr.generated_text], "meta": meta}
13 changes: 10 additions & 3 deletions haystack/components/generators/hugging_face_tgi.py
@@ -255,15 +255,22 @@ def _run_non_streaming(
         all_metadata: List[Dict[str, Any]] = []
         for _i in range(num_responses):
             tgr: TextGenerationOutput = self.client.text_generation(prompt, details=True, **generation_kwargs)
+            if tgr.details:
+                completion_tokens = len(tgr.details.tokens)
+                prompt_token_count = prompt_token_count + completion_tokens
+                finish_reason = tgr.details.finish_reason
+            else:
+                finish_reason = None
+                completion_tokens = 0
             all_metadata.append(
                 {
                     "model": self.client.model,
                     "index": _i,
-                    "finish_reason": tgr.details.finish_reason,
+                    "finish_reason": finish_reason,
                     "usage": {
-                        "completion_tokens": len(tgr.details.tokens),
+                        "completion_tokens": completion_tokens,
                         "prompt_tokens": prompt_token_count,
-                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
+                        "total_tokens": prompt_token_count + completion_tokens,
                     },
                 }
             )
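
As a quick sanity check of the new None-details branch, here is a hypothetical standalone snippet; FakeDetails, FakeOutput, and build_meta are illustrative stand-ins, not part of the commit, and the snippet only mirrors the guard on tgr.details shown in the hunks above.

# Hypothetical stand-ins for TextGenerationOutput and its details; not from the commit.
class FakeDetails:
    def __init__(self, finish_reason, tokens):
        self.finish_reason = finish_reason
        self.tokens = tokens

class FakeOutput:
    def __init__(self, generated_text, details=None):
        self.generated_text = generated_text
        self.details = details

def build_meta(tgr, prompt_token_count, index, model):
    # Same guard on tgr.details as the hunks above.
    if tgr.details:
        completion_tokens = len(tgr.details.tokens)
        finish_reason = tgr.details.finish_reason
    else:
        finish_reason = None
        completion_tokens = 0
    return {
        "model": model,
        "index": index,
        "finish_reason": finish_reason,
        "usage": {
            "completion_tokens": completion_tokens,
            "prompt_tokens": prompt_token_count,
            "total_tokens": prompt_token_count + completion_tokens,
        },
    }

# With details missing, the metadata falls back to None / zero counts instead of raising.
meta = build_meta(FakeOutput("hello"), prompt_token_count=5, index=0, model="some-model")
assert meta["finish_reason"] is None
assert meta["usage"]["total_tokens"] == 5

# With details present, the completion count comes from the returned tokens.
meta = build_meta(FakeOutput("hello", FakeDetails("length", [1, 2, 3])), prompt_token_count=5, index=0, model="some-model")
assert meta["usage"]["completion_tokens"] == 3
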
