Merge branch 'main' into add-recursive-chunking
davidsbatista authored Dec 20, 2024
2 parents 3ad73a5 + c192488 commit d292de6
Showing 10 changed files with 196 additions and 23 deletions.
1 change: 1 addition & 0 deletions .github/workflows/e2e.yml
@@ -19,6 +19,7 @@ env:
PYTHON_VERSION: "3.9"
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
HATCH_VERSION: "1.13.0"
HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }}

jobs:
run:
13 changes: 13 additions & 0 deletions e2e/pipelines/test_named_entity_extractor.py
@@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import os
import pytest

from haystack import Document, Pipeline
@@ -65,6 +66,18 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
_extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
@pytest.mark.skipif(
not os.environ.get("HF_API_TOKEN", None),
reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.",
)
def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")
extractor.warm_up()

_extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
37 changes: 32 additions & 5 deletions haystack/components/extractors/named_entity_extractor.py
@@ -10,7 +10,9 @@

from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from haystack.utils.device import ComponentDevice
from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs

with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import:
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
@@ -110,6 +112,7 @@ def __init__(
model: str,
pipeline_kwargs: Optional[Dict[str, Any]] = None,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
) -> None:
"""
Create a Named Entity extractor component.
@@ -128,16 +131,28 @@ def __init__(
device/device map is specified in `pipeline_kwargs`,
it overrides this parameter (only applicable to the
HuggingFace backend).
:param token:
The API token to download private models from Hugging Face.
"""

if isinstance(backend, str):
backend = NamedEntityExtractorBackend.from_str(backend)

self._backend: _NerBackend
self._warmed_up: bool = False
self.token = token
device = ComponentDevice.resolve_device(device)

if backend == NamedEntityExtractorBackend.HUGGING_FACE:
pipeline_kwargs = resolve_hf_pipeline_kwargs(
huggingface_pipeline_kwargs=pipeline_kwargs or {},
model=model,
task="ner",
supported_tasks=["ner"],
device=device,
token=token,
)

self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
elif backend == NamedEntityExtractorBackend.SPACY:
self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs)
Expand All @@ -159,7 +174,7 @@ def warm_up(self):
self._warmed_up = True
except Exception as e:
raise ComponentError(
f"Named entity extractor with backend '{self._backend.type} failed to initialize."
f"Named entity extractor with backend '{self._backend.type}' failed to initialize."
) from e

@component.output_types(documents=List[Document])
@@ -201,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]:
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
serialization_dict = default_to_dict(
self,
backend=self._backend.type.name,
model=self._backend.model_name,
device=self._backend.device.to_dict(),
pipeline_kwargs=self._backend._pipeline_kwargs,
token=self.token.to_dict() if self.token else None,
)

hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
hf_pipeline_kwargs.pop("token", None)

serialize_hf_model_kwargs(hf_pipeline_kwargs)
return serialization_dict

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
"""
@@ -220,10 +242,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor":
Deserialized component.
"""
try:
init_params = data["init_parameters"]
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
init_params = data.get("init_parameters", {})
if init_params.get("device") is not None:
init_params["device"] = ComponentDevice.from_dict(init_params["device"])
init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]]

hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {})
deserialize_hf_model_kwargs(hf_pipeline_kwargs)
return default_from_dict(cls, data)
except Exception as e:
raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e
@@ -352,8 +378,9 @@ def __init__(
self.pipeline: Optional[HfPipeline] = None

def initialize(self):
self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path)
self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path)
token = self._pipeline_kwargs.get("token", None)
self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path, token=token)
self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path, token=token)

pipeline_params = {
"task": "ner",
27 changes: 19 additions & 8 deletions haystack/components/preprocessors/document_splitter.py
@@ -107,15 +107,10 @@ def __init__(  # pylint: disable=too-many-positional-arguments
splitting_function=splitting_function,
respect_sentence_boundary=respect_sentence_boundary,
)

if split_by == "sentence" or (respect_sentence_boundary and split_by == "word"):
self._use_sentence_splitter = split_by == "sentence" or (respect_sentence_boundary and split_by == "word")
if self._use_sentence_splitter:
nltk_imports.check()
self.sentence_splitter = SentenceSplitter(
language=language,
use_split_rules=use_split_rules,
extend_abbreviations=extend_abbreviations,
keep_white_spaces=True,
)
self.sentence_splitter = None

if split_by == "sentence":
# ToDo: remove this warning in the next major release
@@ -164,6 +159,18 @@ def _init_checks(
)
self.respect_sentence_boundary = False

def warm_up(self):
"""
Warm up the DocumentSplitter by loading the sentence tokenizer.
"""
if self._use_sentence_splitter and self.sentence_splitter is None:
self.sentence_splitter = SentenceSplitter(
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
keep_white_spaces=True,
)

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
@@ -182,6 +189,10 @@ def run(self, documents: List[Document]):
:raises TypeError: if the input is not a list of Documents.
:raises ValueError: if the content of a document is None.
"""
if self._use_sentence_splitter and self.sentence_splitter is None:
raise RuntimeError(
"The component DocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)

if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
raise TypeError("DocumentSplitter expects a List of Documents as input.")
32 changes: 23 additions & 9 deletions haystack/components/preprocessors/nltk_document_splitter.py
@@ -77,16 +77,23 @@ def __init__(  # pylint: disable=too-many-positional-arguments
self.respect_sentence_boundary = respect_sentence_boundary
self.use_split_rules = use_split_rules
self.extend_abbreviations = extend_abbreviations
self.sentence_splitter = SentenceSplitter(
language=language,
use_split_rules=use_split_rules,
extend_abbreviations=extend_abbreviations,
keep_white_spaces=True,
)
self.sentence_splitter = None
self.language = language

def warm_up(self):
"""
Warm up the NLTKDocumentSplitter by loading the sentence tokenizer.
"""
if self.sentence_splitter is None:
self.sentence_splitter = SentenceSplitter(
language=self.language,
use_split_rules=self.use_split_rules,
extend_abbreviations=self.extend_abbreviations,
keep_white_spaces=True,
)

def _split_into_units(
self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"]
self, text: str, split_by: Literal["function", "page", "passage", "period", "sentence", "word", "line"]
) -> List[str]:
"""
Splits the text into units based on the specified split_by parameter.
@@ -106,6 +113,7 @@ def _split_into_units(
# whitespace is preserved while splitting text into sentences when using keep_white_spaces=True
# so split_at is set to an empty string
self.split_at = ""
assert self.sentence_splitter is not None
result = self.sentence_splitter.split_sentences(text)
units = [sentence["sentence"] for sentence in result]
elif split_by == "word":
@@ -142,6 +150,11 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]:
:raises TypeError: if the input is not a list of Documents.
:raises ValueError: if the content of a document is None.
"""
if self.sentence_splitter is None:
raise RuntimeError(
"The component NLTKDocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'."
)

if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
raise TypeError("DocumentSplitter expects a List of Documents as input.")

@@ -221,8 +234,9 @@ def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_
break
return num_sentences_to_keep

@staticmethod
def _concatenate_sentences_based_on_word_amount(
self, sentences: List[str], split_length: int, split_overlap: int
sentences: List[str], split_length: int, split_overlap: int
) -> Tuple[List[str], List[int], List[int]]:
"""
Groups the sentences into chunks of `split_length` words while respecting sentence boundaries.
@@ -258,7 +272,7 @@ def _concatenate_sentences_based_on_word_amount(
split_start_indices.append(chunk_start_idx)

# Get the number of sentences that overlap with the next chunk
num_sentences_to_keep = self._number_of_sentences_to_keep(
num_sentences_to_keep = NLTKDocumentSplitter._number_of_sentences_to_keep(
sentences=current_chunk, split_length=split_length, split_overlap=split_overlap
)
# Set up information for the new chunk
@@ -0,0 +1,4 @@
---
enhancements:
- |
    Add a `token` argument to `NamedEntityExtractor` to allow the use of private Hugging Face models.
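
For illustration, a minimal sketch of the new argument in use, based on the constructor signature and tests in this diff; the explicit env var name, the `run()` call, and the sample text are assumptions, not part of the commit:

```python
# Sketch: pass an explicit token so the Hugging Face backend can download a
# private/gated model. By default the component already reads HF_API_TOKEN or
# HF_TOKEN from the environment (strict=False), so this is only an override.
from haystack import Document
from haystack.components.extractors.named_entity_extractor import (
    NamedEntityExtractor,
    NamedEntityExtractorBackend,
)
from haystack.utils.auth import Secret

extractor = NamedEntityExtractor(
    backend=NamedEntityExtractorBackend.HUGGING_FACE,
    model="deepset/bert-base-NER",  # private model exercised by the new e2e test
    token=Secret.from_env_var("HF_API_TOKEN", strict=False),
)
extractor.warm_up()  # resolves the token and loads tokenizer + model

# Assumed usage: run() takes a list of Documents and returns the annotated documents.
result = extractor.run(documents=[Document(content="My name is Clara and I live in Berkeley.")])
```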
@@ -0,0 +1,4 @@
---
fixes:
- |
    Moved the NLTK tokenizer download in DocumentSplitter and NLTKDocumentSplitter to warm_up(). This prevents a call to an external API during instantiation. If a DocumentSplitter or NLTKDocumentSplitter is used for sentence splitting outside of a pipeline, warm_up() now needs to be called before running the component.
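
A minimal sketch of the new standalone usage; the import path follows the file path in this diff, and the sample document is an assumption:

```python
# Sketch: when the splitter configuration needs the NLTK sentence tokenizer,
# warm_up() must now be called before run(); otherwise run() raises a
# RuntimeError asking you to warm the component up first.
from haystack import Document
from haystack.components.preprocessors.document_splitter import DocumentSplitter

splitter = DocumentSplitter(split_by="sentence")  # sentence splitting uses the NLTK-based SentenceSplitter
splitter.warm_up()  # loads the tokenizer here instead of in __init__ (no external call at construction)
result = splitter.run(documents=[Document(content="First sentence. Second sentence.")])
split_docs = result["documents"]
```

Per the note above, the extra warm_up() call is only needed outside of a pipeline; a pipeline warms its components up before running them.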
57 changes: 56 additions & 1 deletion test/components/extractors/test_named_entity_extractor.py
@@ -1,6 +1,7 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
from haystack.utils.auth import Secret
import pytest

from haystack import ComponentError, DeserializationError, Pipeline
@@ -11,6 +12,9 @@
def test_named_entity_extractor_backend():
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

# private model
_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER")

_ = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER")

_ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_sm")
@@ -40,7 +44,58 @@ def test_named_entity_extractor_serde():
_ = NamedEntityExtractor.from_dict(serde_data)


def test_named_entity_extractor_from_dict_no_default_parameters_hf():
def test_to_dict_default(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

component = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE,
model="dslim/bert-base-NER",
device=ComponentDevice.from_str("mps"),
)
data = component.to_dict()

assert data == {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {
"backend": "HUGGING_FACE",
"model": "dslim/bert-base-NER",
"device": {"type": "single", "device": "mps"},
"pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "mps", "task": "ner"},
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
},
}


def test_to_dict_with_parameters():
component = NamedEntityExtractor(
backend=NamedEntityExtractorBackend.HUGGING_FACE,
model="dslim/bert-base-NER",
device=ComponentDevice.from_str("mps"),
pipeline_kwargs={"model_kwargs": {"load_in_4bit": True}},
token=Secret.from_env_var("ENV_VAR", strict=False),
)
data = component.to_dict()

assert data == {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {
"backend": "HUGGING_FACE",
"model": "dslim/bert-base-NER",
"device": {"type": "single", "device": "mps"},
"pipeline_kwargs": {
"model": "dslim/bert-base-NER",
"device": "mps",
"task": "ner",
"model_kwargs": {"load_in_4bit": True},
},
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
},
}


def test_named_entity_extractor_from_dict_no_default_parameters_hf(monkeypatch):
monkeypatch.delenv("HF_API_TOKEN", raising=False)

data = {
"type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor",
"init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},