Updated to_dict / from_dict to handle 'token' correctly ; Added tests
mpangrazzi committed Dec 19, 2024
1 parent 505b9d4 commit 35edc33
Showing 4 changed files with 76 additions and 12 deletions.
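For orientation, the sketch below shows the serialization round trip these changes target, drawn from the new tests further down. It is illustrative only, not part of the changed files, and assumes transformers is installed.

from haystack.components.extractors import NamedEntityExtractor, NamedEntityExtractorBackend
from haystack.utils.auth import Secret

# The token is an env-var-backed Secret (non-strict), so to_dict() records a
# reference to the env var rather than the token value itself.
extractor = NamedEntityExtractor(
    backend=NamedEntityExtractorBackend.HUGGING_FACE,
    model="dslim/bert-base-NER",
    token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)

data = extractor.to_dict()
# to_dict() now strips "token" from pipeline_kwargs, so no secret value is written out.
assert "token" not in data["init_parameters"]["pipeline_kwargs"]
assert data["init_parameters"]["token"] == {
    "type": "env_var",
    "env_vars": ["HF_API_TOKEN"],
    "strict": False,
}

# from_dict() restores the Secret via deserialize_secrets_inplace before init.
restored = NamedEntityExtractor.from_dict(data)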
5 changes: 5 additions & 0 deletions e2e/pipelines/test_named_entity_extractor.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import pytest
import os

from haystack import Document, Pipeline
from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend
@@ -66,6 +67,10 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):


@pytest.mark.parametrize("batch_size", [1, 3])
@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
extractor.warm_up()
20 changes: 16 additions & 4 deletions haystack/components/extractors/named_entity_extractor.py
@@ -10,9 +10,9 @@

from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from haystack.utils.device import ComponentDevice
from haystack.utils.hf import resolve_hf_pipeline_kwargs
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
@@ -140,6 +140,7 @@ def __init__(

self._backend: _NerBackend
self._warmed_up: bool = False
self.token = token
device = ComponentDevice.resolve_device(device)

if backend == NamedEntityExtractorBackend.HUGGING_FACE:
@@ -215,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
serialization_dict = default_to_dict(
self,
backend=self._backend.type.name,
model=self._backend.model_name,
device=self._backend.device.to_dict(),
pipeline_kwargs=self._backend._pipeline_kwargs,
token=self.token.to_dict() if self.token else None,
)

hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
hf_pipeline_kwargs.pop("token", None)

serialize_hf_model_kwargs(hf_pipeline_kwargs)
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
"""
@@ -234,10 +242,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
Deserialized component.
"""
try:
init_params = data["init_parameters"]
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {})
if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]

hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {})
deserialize_hf_model_kwargs(hf_pipeline_kwargs)
return default_from_dict(cls, data)
except Exception as e:
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
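For reference, a minimal sketch of what deserialize_secrets_inplace does with the serialized token before default_from_dict runs, mirroring the from_dict change above; the dictionary shown uses the env-var form produced by the default token.

from haystack.utils.auth import Secret, deserialize_secrets_inplace

init_params = {
    "backend": "HUGGING_FACE",
    "model": "dslim/bert-base-NER",
    "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
}

# Rewrites the dict stored under "token" into an equivalent Secret, in place;
# a missing or None "token" entry is left untouched.
deserialize_secrets_inplace(init_params, keys=["token"])
assert isinstance(init_params["token"], Secret)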
56 changes: 54 additions & 2 deletions test/components/extractors/test_named_entity_extractor.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from haystack.utils.auth import Secret
import pytest

from haystack import ComponentError, DeserializationError, Pipeline
@@ -11,7 +12,7 @@
def test_named_entity_extractor_backend():
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

# private model, will fail if HF_API_TOKEN is not set
# private model
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")

_ = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")
@@ -43,7 +44,58 @@ def test_named_entity_extractor_serde():
_ = NamedEntityExtractor.from_dict(serde_data)


def test_named_entity_extractor_from_dict_no_default_parameters_hf():
def test_to_dict_default(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

component = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE,
model="dslim/bert-base-NER",
device=ComponentDevice.from_str("mps"),
)
data = component.to_dict()

assert data == {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {
"backend": "HUGGING_FACE",
"model": "dslim/bert-base-NER",
"device": {"type": "single", "device": "mps"},
"pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "mps", "task": "ner"},
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
},
}


def test_to_dict_with_parameters():
component = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE,
model="dslim/bert-base-NER",
device=ComponentDevice.from_str("mps"),
pipeline_kwargs={"model_kwargs": {"load_in_4bit": True}},
token=Secret.from_env_var("ENV_VAR", strict=False),
)
data = component.to_dict()

assert data == {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {
"backend": "HUGGING_FACE",
"model": "dslim/bert-base-NER",
"device": {"type": "single", "device": "mps"},
"pipeline_kwargs": {
"model": "dslim/bert-base-NER",
"device": "mps",
"task": "ner",
"model_kwargs": {"load_in_4bit": True},
},
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
},
}


def test_named_entity_extractor_from_dict_no_default_parameters_hf(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

data = {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},
@@ -100,12 +100,7 @@ def test_init_huggingface_pipeline_kwargs_override_other_parameters(self):
If they are provided, they should override other init parameters.
"""

huggingface_pipeline_kwargs = {
"model": "gpt2",
"task": "text-generation",
"device": "cuda:0",
"token": "another-test-token",
}
huggingface_pipeline_kwargs = {"model": "gpt2", "device": "cuda:0", "token": "another-test-token"}

generator = HuggingFaceLocalGenerator(
model="google/flan-t5-base",
