Named entity extractor private models #8658

Merged · 8 commits · Dec 20, 2024
1 change: 1 addition & 0 deletions .github/workflows/e2e.yml
@@ -19,6 +19,7 @@ env:
  PYTHON_VERSION: "3.9"
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  HATCH_VERSION: "1.13.0"
  HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

jobs:
  run:
13 changes: 13 additions & 0 deletions e2e/pipelines/test_named_entity_extractor.py
@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
import pytest

from haystack import Document, Pipeline
@@ -65,6 +66,18 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
    _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
    extractor.warm_up()

    _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
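For reference, a minimal sketch of what this new e2e test exercises: with `HF_API_TOKEN` exported, the default `token` argument (`Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)`) lets the extractor load a private or gated model. This assumes `transformers[torch]` is installed; the sample document text and the `"named_entities"` metadata key below are illustrative assumptions, not taken from this PR.

```python
import os

from haystack import Document
from haystack.components.extractors.named_entity_extractor import (
    NamedEntityExtractor,
    NamedEntityExtractorBackend,
)

# Assumes HF_API_TOKEN is exported and grants access to the private model.
assert os.environ.get("HF_API_TOKEN"), "export HF_API_TOKEN before running this sketch"

extractor = NamedEntityExtractor(
    backend=NamedEntityExtractorBackend.HUGGING_FACE,
    model="deepset/bert-base-NER",  # private model used by the new e2e test
)
extractor.warm_up()  # downloads the model; raises ComponentError if loading fails

docs = [Document(content="My name is Clara and I live in Berkeley, California.")]
result = extractor.run(documents=docs)

# Annotations are stored in each document's metadata ("named_entities" is the assumed key).
for doc in result["documents"]:
    print(doc.meta.get("named_entities"))
```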
37 changes: 32 additions & 5 deletions haystack/components/extractors/named_entity_extractor.py
@@ -10,7 +10,9 @@

from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from haystack.utils.device import ComponentDevice
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
    from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
@@ -110,6 +112,7 @@ def __init__(
        model: str,
        pipeline_kwargs: Optional[Dict[str, Any]] = None,
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    ) -> None:
        """
        Create a Named Entity extractor component.
@@ -128,16 +131,28 @@
            device/device map is specified in `pipeline_kwargs`,
            it overrides this parameter (only applicable to the
            HuggingFace backend).
        :param token:
            The API token to download private models from Hugging Face.
        """

        if isinstance(backend, str):
            backend = NamedEntityExtractorBackend.from_str(backend)

        self._backend: _NerBackend
        self._warmed_up: bool = False
        self.token = token
        device = ComponentDevice.resolve_device(device)

        if backend == NamedEntityExtractorBackend.HUGGING_FACE:
            pipeline_kwargs = resolve_hf_pipeline_kwargs(
                huggingface_pipeline_kwargs=pipeline_kwargs or {},
                model=model,
                task="ner",
                supported_tasks=["ner"],
                device=device,
                token=token,
            )

            self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
        elif backend == NamedEntityExtractorBackend.SPACY:
            self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
@@ -159,7 +174,7 @@ def warm_up(self):
            self._warmed_up = True
        except Exception as e:
            raise ComponentError(
                f"Named entity extractor with backend '{self._backend.type} failed to initialize."
                f"Named entity extractor with backend '{self._backend.type}' failed to initialize."
            ) from e

    @component.output_types(documents=List[Document])
@@ -201,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]:
        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
        serialization_dict = default_to_dict(
            self,
            backend=self._backend.type.name,
            model=self._backend.model_name,
            device=self._backend.device.to_dict(),
            pipeline_kwargs=self._backend._pipeline_kwargs,
            token=self.token.to_dict() if self.token else None,
        )

        hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
        hf_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(hf_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
        """
@@ -220,10 +242,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
            Deserialized component.
        """
        try:
            init_params = data["init_parameters"]
            deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
            init_params = data.get("init_parameters", {})
            if init_params.get("device") is not None:
                init_params["device"] = ComponentDevice.from_dict(init_params["device"])
            init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]

            hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {})
            deserialize_hf_model_kwargs(hf_pipeline_kwargs)
            return default_from_dict(cls, data)
        except Exception as e:
            raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
@@ -352,8 +378,9 @@ def __init__(
        self.pipeline: Optional[HfPipeline] = None

    def initialize(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path)
        self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path)
        token = self._pipeline_kwargs.get("token", None)
        self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path, token=token)
        self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path, token=token)

        pipeline_params = {
            "task": "ner",
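A rough sketch of how the serialization changes above behave (the exact dict layout mirrors the unit tests further down; treat the printed shape as an assumption, not a contract): the token is stored as an environment-variable reference and popped from `pipeline_kwargs`, so no literal secret ends up in the serialized component. Assumes `transformers[torch]` is installed.

```python
from haystack.components.extractors.named_entity_extractor import (
    NamedEntityExtractor,
    NamedEntityExtractorBackend,
)

extractor = NamedEntityExtractor(
    backend=NamedEntityExtractorBackend.HUGGING_FACE,
    model="dslim/bert-base-NER",
)

data = extractor.to_dict()
# The default token serializes as an env-var reference, never as a literal value,
# e.g. {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False}
print(data["init_parameters"]["token"])
# ...and it is removed from pipeline_kwargs before serialize_hf_model_kwargs runs.
assert "token" not in data["init_parameters"]["pipeline_kwargs"]

# from_dict re-resolves the secret from the environment via deserialize_secrets_inplace.
restored = NamedEntityExtractor.from_dict(data)
```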
@@ -0,0 +1,4 @@
---
enhancements:
  - |
    Add `token` argument to `NamedEntityExtractor` to allow usage of private Hugging Face models.
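As a usage note for the release entry above (a sketch, not code from this PR): the new `token` argument also accepts an explicit `Secret`, for example one bound to a custom environment variable, instead of the default `HF_API_TOKEN` / `HF_TOKEN` lookup. `MY_HF_TOKEN` below is a hypothetical variable name.

```python
from haystack.components.extractors.named_entity_extractor import NamedEntityExtractor
from haystack.utils.auth import Secret

# MY_HF_TOKEN is a hypothetical env var holding a Hugging Face access token;
# it must be set, since Secret.from_env_var is strict by default.
extractor = NamedEntityExtractor(
    backend="hugging_face",  # string backends are accepted, see the unit test below
    model="deepset/bert-base-NER",
    token=Secret.from_env_var("MY_HF_TOKEN"),
)
extractor.warm_up()
```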
57 changes: 56 additions & 1 deletion test/components/extractors/test_named_entity_extractor.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from haystack.utils.auth import Secret
import pytest

from haystack import ComponentError, DeserializationError, Pipeline
@@ -11,6 +12,9 @@
def test_named_entity_extractor_backend():
    _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

    # private model
    _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")

_ = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")

_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_sm")
@@ -40,7 +44,58 @@ def test_named_entity_extractor_serde():
    _ = NamedEntityExtractor.from_dict(serde_data)


def test_named_entity_extractor_from_dict_no_default_parameters_hf():
def test_to_dict_default(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

    component = NamedEntityExtractor(
        backend=NamedEntityExtractorBackend.HUGGING_FACE,
        model="dslim/bert-base-NER",
        device=ComponentDevice.from_str("mps"),
    )
    data = component.to_dict()

    assert data == {
        "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
        "init_parameters": {
            "backend": "HUGGING_FACE",
            "model": "dslim/bert-base-NER",
            "device": {"type": "single", "device": "mps"},
            "pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "mps", "task": "ner"},
            "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
        },
    }


def test_to_dict_with_parameters():
    component = NamedEntityExtractor(
        backend=NamedEntityExtractorBackend.HUGGING_FACE,
        model="dslim/bert-base-NER",
        device=ComponentDevice.from_str("mps"),
        pipeline_kwargs={"model_kwargs": {"load_in_4bit": True}},
        token=Secret.from_env_var("ENV_VAR", strict=False),
    )
    data = component.to_dict()

    assert data == {
        "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
        "init_parameters": {
            "backend": "HUGGING_FACE",
            "model": "dslim/bert-base-NER",
            "device": {"type": "single", "device": "mps"},
            "pipeline_kwargs": {
                "model": "dslim/bert-base-NER",
                "device": "mps",
                "task": "ner",
                "model_kwargs": {"load_in_4bit": True},
            },
            "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
        },
    }


def test_named_entity_extractor_from_dict_no_default_parameters_hf(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

data = {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},
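Since the serde tests above exercise the component in isolation, here is a hedged sketch of the same round trip at pipeline level, assuming the extractor serializes inside a `Pipeline` the same way it does standalone; the component name "ner" is arbitrary.

```python
from haystack import Pipeline
from haystack.components.extractors.named_entity_extractor import (
    NamedEntityExtractor,
    NamedEntityExtractorBackend,
)

pipeline = Pipeline()
pipeline.add_component(
    "ner",
    NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER"),
)

# The serialized pipeline keeps only the env-var reference for the token,
# so the YAML can be shared without leaking a literal secret.
yaml_repr = pipeline.dumps()
restored = Pipeline.loads(yaml_repr)
```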