chore: Update huggingface_hub classes used after library upgrade (#7631)
* Update huggingface_hub classes used after library upgrade

* Fix chat tests

* Update lazy import guard and other references to huggingface_hub>=0.23.0

* In huggingface_hub 0.23.0, the TextGenerationOutput property `details` is now optional

* More fixes

* Add reno note
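
The recurring fix in this commit is a guard around the now-optional `details` field of text generation results. A minimal sketch of the pattern, with a hypothetical model and prompt (assumes network access to the Hugging Face Inference API):

```python
from huggingface_hub import InferenceClient, TextGenerationOutput

client = InferenceClient(model="mistralai/Mistral-7B-v0.1")  # hypothetical model choice
tgr: TextGenerationOutput = client.text_generation("What is NLP?", details=True)

# Since huggingface_hub 0.23.0, `details` may be None, so guard every access
# instead of dereferencing it unconditionally.
if tgr.details:
    finish_reason = tgr.details.finish_reason
    completion_tokens = len(tgr.details.tokens)
else:
    finish_reason = None
    completion_tokens = 0
print(finish_reason, completion_tokens)
```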
vblagoje authored May 3, 2024
1 parent db87074 commit 5f81337
Showing 15 changed files with 81 additions and 34 deletions.
@@ -10,7 +10,7 @@
 from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient

 logger = logging.getLogger(__name__)
@@ -7,7 +7,7 @@
 from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient

 logger = logging.getLogger(__name__)
@@ -11,7 +11,7 @@
 from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFModelType, check_valid_model

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient

 logger = logging.getLogger(__name__)
@@ -8,7 +8,7 @@
 from haystack.utils import Secret, deserialize_secrets_inplace
 from haystack.utils.hf import HFModelType, check_valid_model

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
     from huggingface_hub import InferenceClient

 logger = logging.getLogger(__name__)
2 changes: 1 addition & 1 deletion haystack/components/generators/chat/hugging_face_api.py
@@ -7,7 +7,7 @@
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url

-with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.22.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import:
     from huggingface_hub import ChatCompletionOutput, ChatCompletionStreamOutput, InferenceClient

21 changes: 14 additions & 7 deletions haystack/components/generators/chat/hugging_face_tgi.py
@@ -9,7 +9,7 @@
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\" transformers'") as transformers_import:
     from huggingface_hub import (
         InferenceClient,
         TextGenerationOutput,
@@ -275,13 +275,13 @@ def _run_streaming(
         message = ChatMessage.from_assistant(chunk.generated_text)
         message.meta.update(
             {
-                "finish_reason": chunk.details.finish_reason,
+                "finish_reason": chunk.details.finish_reason if chunk.details else None,
                 "index": 0,
                 "model": self.client.model,
                 "usage": {
-                    "completion_tokens": chunk.details.generated_tokens,
+                    "completion_tokens": chunk.details.generated_tokens if chunk.details else 0,
                     "prompt_tokens": prompt_token_count,
-                    "total_tokens": prompt_token_count + chunk.details.generated_tokens,
+                    "total_tokens": prompt_token_count + chunk.details.generated_tokens if chunk.details else 0,
                 },
             }
         )
@@ -294,15 +294,22 @@ def _run_non_streaming(
         for _i in range(num_responses):
             tgr: TextGenerationOutput = self.client.text_generation(prepared_prompt, details=True, **generation_kwargs)
             message = ChatMessage.from_assistant(tgr.generated_text)
+            if tgr.details:
+                completion_tokens = len(tgr.details.tokens)
+                prompt_token_count = prompt_token_count + completion_tokens
+                finish_reason = tgr.details.finish_reason
+            else:
+                finish_reason = None
+                completion_tokens = 0
             message.meta.update(
                 {
-                    "finish_reason": tgr.details.finish_reason,
+                    "finish_reason": finish_reason,
                     "index": _i,
                     "model": self.client.model,
                     "usage": {
-                        "completion_tokens": len(tgr.details.tokens),
+                        "completion_tokens": completion_tokens,
                         "prompt_tokens": prompt_token_count,
-                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
+                        "total_tokens": prompt_token_count + completion_tokens,
                     },
                 }
             )
6 changes: 3 additions & 3 deletions haystack/components/generators/hugging_face_api.py
@@ -8,7 +8,7 @@
 from haystack.utils.hf import HFGenerationAPIType, HFModelType, check_valid_model
 from haystack.utils.url_validation import is_valid_http_url

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
     from huggingface_hub import (
         InferenceClient,
         TextGenerationOutput,
@@ -208,8 +208,8 @@ def _run_non_streaming(self, prompt: str, generation_kwargs: Dict[str, Any]):
         meta = [
             {
                 "model": self._client.model,
-                "finish_reason": tgr.details.finish_reason,
-                "usage": {"completion_tokens": len(tgr.details.tokens)},
+                "finish_reason": tgr.details.finish_reason if tgr.details else None,
+                "usage": {"completion_tokens": len(tgr.details.tokens) if tgr.details else 0},
             }
         ]
         return {"replies": [tgr.generated_text], "meta": meta}
17 changes: 12 additions & 5 deletions haystack/components/generators/hugging_face_tgi.py
@@ -9,7 +9,7 @@
 from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
 from haystack.utils.hf import HFModelType, check_generation_params, check_valid_model, list_inference_deployed_models

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\" transformers'") as transformers_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\" transformers'") as transformers_import:
     from huggingface_hub import (
         InferenceClient,
         TextGenerationOutput,
@@ -57,7 +57,7 @@ class HuggingFaceTGIGenerator:
     client = HuggingFaceTGIGenerator(model="mistralai/Mistral-7B-v0.1", token=Secret.from_token("<your-api-key>"))
     client.warm_up()
-    response = client.run("What's Natural Language Processing?", max_new_tokens=120)
+    response = client.run("What's Natural Language Processing?", generation_kwargs={"max_new_tokens": 120})
     print(response)
     ```
@@ -255,15 +255,22 @@ def _run_non_streaming(
         all_metadata: List[Dict[str, Any]] = []
         for _i in range(num_responses):
             tgr: TextGenerationOutput = self.client.text_generation(prompt, details=True, **generation_kwargs)
+            if tgr.details:
+                completion_tokens = len(tgr.details.tokens)
+                prompt_token_count = prompt_token_count + completion_tokens
+                finish_reason = tgr.details.finish_reason
+            else:
+                finish_reason = None
+                completion_tokens = 0
             all_metadata.append(
                 {
                     "model": self.client.model,
                     "index": _i,
-                    "finish_reason": tgr.details.finish_reason,
+                    "finish_reason": finish_reason,
                     "usage": {
-                        "completion_tokens": len(tgr.details.tokens),
+                        "completion_tokens": completion_tokens,
                         "prompt_tokens": prompt_token_count,
-                        "total_tokens": prompt_token_count + len(tgr.details.tokens),
+                        "total_tokens": prompt_token_count + completion_tokens,
                     },
                 }
             )
2 changes: 1 addition & 1 deletion haystack/utils/hf.py
@@ -14,7 +14,7 @@
 with LazyImport(message="Run 'pip install transformers[torch]'") as torch_import:
     import torch

-with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.23.0\"'") as huggingface_hub_import:
     from huggingface_hub import HfApi, InferenceClient, model_info
     from huggingface_hub.utils import RepositoryNotFoundError

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -102,7 +102,7 @@ format-check = "black --check ."
 [tool.hatch.envs.test]
 extra-dependencies = [
   "transformers[torch,sentencepiece]==4.38.2",  # ExtractiveReader, TransformersSimilarityRanker, LocalWhisperTranscriber, HFGenerators...
-  "huggingface_hub>=0.22.0",  # TGI Generators and TEI Embedders
+  "huggingface_hub>=0.23.0",  # TGI Generators and TEI Embedders
   "spacy>=3.7,<3.8",  # NamedEntityExtractor
   "spacy-curated-transformers>=0.2,<=0.3",  # NamedEntityExtractor
   "en-core-web-trf @ https://github.com/explosion/spacy-models/releases/download/en_core_web_trf-3.7.3/en_core_web_trf-3.7.3-py3-none-any.whl",  # NamedEntityExtractor
@@ -0,0 +1,4 @@
+---
+upgrade:
+  - |
+    Upgraded the required version of `huggingface_hub` to `>=0.23.0` across various modules to ensure compatibility and leverage the latest features. This update includes modifications to error handling for token generation details and introduces adjustments in the chat and text generation interfaces to enhance functionality and developer experience. Users are advised to upgrade their `huggingface_hub` dependency.
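
For reference, whether the installed version already satisfies the new floor can be checked at runtime; a small sketch (assumes the packaging library is available, which huggingface_hub itself depends on):

```python
from importlib.metadata import version

from packaging.version import Version  # huggingface_hub already depends on packaging

installed = Version(version("huggingface_hub"))
if installed < Version("0.23.0"):
    # Same remedy the LazyImport guards in this commit suggest
    raise RuntimeError("Run 'pip install \"huggingface_hub>=0.23.0\"'")
print(f"huggingface_hub {installed} is new enough")
```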
23 changes: 17 additions & 6 deletions test/components/generators/chat/test_hugging_face_api.py
@@ -4,10 +4,10 @@
 import pytest
 from huggingface_hub import (
     ChatCompletionOutput,
-    ChatCompletionOutputChoice,
-    ChatCompletionOutputChoiceMessage,
+    ChatCompletionOutputComplete,
+    ChatCompletionOutputMessage,
     ChatCompletionStreamOutput,
     ChatCompletionStreamOutputChoice,
     ChatCompletionStreamOutputDelta,
 )
 from huggingface_hub.utils import RepositoryNotFoundError
@@ -33,14 +33,17 @@ def mock_chat_completion():
     with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion:
         completion = ChatCompletionOutput(
             choices=[
-                ChatCompletionOutputChoice(
+                ChatCompletionOutputComplete(
                     finish_reason="eos_token",
                     index=0,
-                    message=ChatCompletionOutputChoiceMessage(
-                        content="The capital of France is Paris.", role="assistant"
-                    ),
+                    message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"),
                 )
             ],
+            id="some_id",
+            model="some_model",
+            object="some_object",
+            system_fingerprint="some_fingerprint",
+            usage={"completion_tokens": 10, "prompt_tokens": 5, "total_tokens": 15},
             created=1710498360,
         )

@@ -208,6 +211,10 @@ def mock_iter(self):
                     finish_reason=None,
                 )
             ],
+            id="some_id",
+            model="some_model",
+            object="some_object",
+            system_fingerprint="some_fingerprint",
             created=1710498504,
         )

@@ -217,6 +224,10 @@ def mock_iter(self):
                     delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length"
                 )
             ],
+            id="some_id",
+            model="some_model",
+            object="some_object",
+            system_fingerprint="some_fingerprint",
             created=1710498504,
         )
10 changes: 8 additions & 2 deletions test/components/generators/chat/test_hugging_face_tgi.py
@@ -1,7 +1,11 @@
 from unittest.mock import MagicMock, Mock, patch

 import pytest
-from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
+from huggingface_hub import (
+    TextGenerationOutputToken,
+    TextGenerationStreamOutput,
+    TextGenerationStreamOutputStreamDetails,
+)
 from huggingface_hub.utils import RepositoryNotFoundError

 from haystack.components.generators.chat import HuggingFaceTGIChatGenerator
@@ -329,13 +333,15 @@ def streaming_callback_fn(chunk: StreamingChunk):
         # self needed here, don't remove
         def mock_iter(self):
             yield TextGenerationStreamOutput(
+                index=0,
                 generated_text=None,
                 token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
             )
             yield TextGenerationStreamOutput(
+                index=0,
                 generated_text=None,
                 token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
-                details=TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
+                details=TextGenerationStreamOutputStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
             )

         mock_response = Mock(**{"__iter__": mock_iter})
10 changes: 8 additions & 2 deletions test/components/generators/test_hugging_face_api.py
@@ -2,7 +2,11 @@
 from unittest.mock import MagicMock, Mock, patch

 import pytest
-from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
+from huggingface_hub import (
+    TextGenerationOutputToken,
+    TextGenerationStreamOutput,
+    TextGenerationStreamOutputStreamDetails,
+)
 from huggingface_hub.utils import RepositoryNotFoundError

 from haystack.components.generators import HuggingFaceAPIGenerator
@@ -236,13 +240,15 @@ def streaming_callback_fn(chunk: StreamingChunk):
         # Don't remove self
         def mock_iter(self):
             yield TextGenerationStreamOutput(
+                index=0,
                 generated_text=None,
                 token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
             )
             yield TextGenerationStreamOutput(
+                index=1,
                 generated_text=None,
                 token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
-                details=TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
+                details=TextGenerationStreamOutputStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
             )

         mock_response = Mock(**{"__iter__": mock_iter})
10 changes: 8 additions & 2 deletions test/components/generators/test_hugging_face_tgi.py
@@ -1,7 +1,11 @@
 from unittest.mock import MagicMock, Mock, patch

 import pytest
-from huggingface_hub import TextGenerationOutputToken, TextGenerationStreamDetails, TextGenerationStreamOutput
+from huggingface_hub import (
+    TextGenerationOutputToken,
+    TextGenerationStreamOutput,
+    TextGenerationStreamOutputStreamDetails,
+)
 from huggingface_hub.utils import RepositoryNotFoundError

 from haystack.components.generators import HuggingFaceTGIGenerator
@@ -271,13 +275,15 @@ def streaming_callback_fn(chunk: StreamingChunk):
         # Don't remove self
         def mock_iter(self):
             yield TextGenerationStreamOutput(
+                index=0,
                 generated_text=None,
                 token=TextGenerationOutputToken(id=1, text="I'm fine, thanks.", logprob=0.0, special=False),
             )
             yield TextGenerationStreamOutput(
+                index=1,
                 generated_text=None,
                 token=TextGenerationOutputToken(id=1, text="Ok bye", logprob=0.0, special=False),
-                details=TextGenerationStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
+                details=TextGenerationStreamOutputStreamDetails(finish_reason="length", generated_tokens=5, seed=None),
             )

         mock_response = Mock(**{"__iter__": mock_iter})
