diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml
index 7e817f2039..ad47c7a68a 100644
--- a/.github/workflows/e2e.yml
+++ b/.github/workflows/e2e.yml
@@ -19,6 +19,7 @@ env:
   PYTHON_VERSION: "3.9"
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   HATCH_VERSION: "1.13.0"
+  HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}
 
 jobs:
   run:
diff --git a/e2e/pipelines/test_named_entity_extractor.py b/e2e/pipelines/test_named_entity_extractor.py
index 2fc15a9a1a..bade774048 100644
--- a/e2e/pipelines/test_named_entity_extractor.py
+++ b/e2e/pipelines/test_named_entity_extractor.py
@@ -2,6 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 import pytest
 
 from haystack import Document, Pipeline
@@ -65,6 +66,18 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
     _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)
 
 
+@pytest.mark.parametrize("batch_size", [1, 3])
+@pytest.mark.skipif(
+    not os.environ.get("HF_API_TOKEN", None),
+    reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
+)
+def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
+    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
+    extractor.warm_up()
+
+    _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)
+
+
 @pytest.mark.parametrize("batch_size", [1, 3])
 def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
     extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py
index 00350e6aed..5651530dbf 100644
--- a/haystack/components/extractors/named_entity_extractor.py
+++ b/haystack/components/extractors/named_entity_extractor.py
@@ -10,7 +10,9 @@
 
 from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
 from haystack.lazy_imports import LazyImport
+from haystack.utils.auth import Secret, deserialize_secrets_inplace
 from haystack.utils.device import ComponentDevice
+from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs
 
 with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
     from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
@@ -110,6 +112,7 @@ def __init__(
         model: str,
         pipeline_kwargs: Optional[Dict[str, Any]] = None,
         device: Optional[ComponentDevice] = None,
+        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
     ) -> None:
         """
         Create a Named Entity extractor component.
@@ -128,6 +131,8 @@ def __init__(
             device/device map is specified in `pipeline_kwargs`, it overrides this
             parameter (only applicable to the HuggingFace backend).
+        :param token:
+            The API token to download private models from Hugging Face.
""" if isinstance(backend, str): @@ -135,9 +140,19 @@ def __init__( self._backend: _NerBackend self._warmed_up: bool = False + self.token = token device = ComponentDevice.resolve_device(device) if backend == NamedEntityExtractorBackend.HUGGING_FACE: + pipeline_kwargs = resolve_hf_pipeline_kwargs( + huggingface_pipeline_kwargs=pipeline_kwargs or {}, + model=model, + task="ner", + supported_tasks=["ner"], + device=device, + token=token, + ) + self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs) elif backend == NamedEntityExtractorBackend.SPACY: self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs) @@ -159,7 +174,7 @@ def warm_up(self): self._warmed_up = True except Exception as e: raise ComponentError( - f"Named entity extractor with backend '{self._backend.type} failed to initialize." + f"Named entity extractor with backend '{self._backend.type}' failed to initialize." ) from e @component.output_types(documents=List[Document]) @@ -201,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ - return default_to_dict( + serialization_dict = default_to_dict( self, backend=self._backend.type.name, model=self._backend.model_name, device=self._backend.device.to_dict(), pipeline_kwargs=self._backend._pipeline_kwargs, + token=self.token.to_dict() if self.token else None, ) + hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"] + hf_pipeline_kwargs.pop("token", None) + + serialize_hf_model_kwargs(hf_pipeline_kwargs) + return serialization_dict + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor": """ @@ -220,10 +242,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor": Deserialized component. """ try: - init_params = data["init_parameters"] + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + init_params = data.get("init_parameters", {}) if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]] + + hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {}) + deserialize_hf_model_kwargs(hf_pipeline_kwargs) return default_from_dict(cls, data) except Exception as e: raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e @@ -352,8 +378,9 @@ def __init__( self.pipeline: Optional[HfPipeline] = None def initialize(self): - self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path) - self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path) + token = self._pipeline_kwargs.get("token", None) + self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path, token=token) + self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path, token=token) pipeline_params = { "task": "ner", diff --git a/releasenotes/notes/add-token-to-named-entity-extractor-3124acb1ae297c0e.yaml b/releasenotes/notes/add-token-to-named-entity-extractor-3124acb1ae297c0e.yaml new file mode 100644 index 0000000000..ee859c0c94 --- /dev/null +++ b/releasenotes/notes/add-token-to-named-entity-extractor-3124acb1ae297c0e.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add `token` argument to `NamedEntityExtractor` to allow usage of private Hugging Face models. 
diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py
index a4826c1e95..04d36399f3 100644
--- a/test/components/extractors/test_named_entity_extractor.py
+++ b/test/components/extractors/test_named_entity_extractor.py
@@ -1,6 +1,7 @@
 # SPDX-FileCopyrightText: 2022-present deepset GmbH
 #
 # SPDX-License-Identifier: Apache-2.0
+from haystack.utils.auth import Secret
 import pytest
 
 from haystack import ComponentError, DeserializationError, Pipeline
@@ -11,6 +12,9 @@
 def test_named_entity_extractor_backend():
     _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
 
+    # private model
+    _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
+
     _ = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")
 
     _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_sm")
@@ -40,7 +44,58 @@ def test_named_entity_extractor_serde():
     _ = NamedEntityExtractor.from_dict(serde_data)
 
 
-def test_named_entity_extractor_from_dict_no_default_parameters_hf():
+def test_to_dict_default(monkeypatch):
+    monkeypatch.delenv("HF_API_TOKEN", raising=False)
+
+    component = NamedEntityExtractor(
+        backend=NamedEntityExtractorBackend.HUGGING_FACE,
+        model="dslim/bert-base-NER",
+        device=ComponentDevice.from_str("mps"),
+    )
+    data = component.to_dict()
+
+    assert data == {
+        "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
+        "init_parameters": {
+            "backend": "HUGGING_FACE",
+            "model": "dslim/bert-base-NER",
+            "device": {"type": "single", "device": "mps"},
+            "pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "mps", "task": "ner"},
+            "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
+        },
+    }
+
+
+def test_to_dict_with_parameters():
+    component = NamedEntityExtractor(
+        backend=NamedEntityExtractorBackend.HUGGING_FACE,
+        model="dslim/bert-base-NER",
+        device=ComponentDevice.from_str("mps"),
+        pipeline_kwargs={"model_kwargs": {"load_in_4bit": True}},
+        token=Secret.from_env_var("ENV_VAR", strict=False),
+    )
+    data = component.to_dict()
+
+    assert data == {
+        "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
+        "init_parameters": {
+            "backend": "HUGGING_FACE",
+            "model": "dslim/bert-base-NER",
+            "device": {"type": "single", "device": "mps"},
+            "pipeline_kwargs": {
+                "model": "dslim/bert-base-NER",
+                "device": "mps",
+                "task": "ner",
+                "model_kwargs": {"load_in_4bit": True},
+            },
+            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
+        },
+    }
+
+
+def test_named_entity_extractor_from_dict_no_default_parameters_hf(monkeypatch):
+    monkeypatch.delenv("HF_API_TOKEN", raising=False)
+
     data = {
         "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
         "init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},