From 286061f005ddefce52faeffad8dbc434ad2bebc3 Mon Sep 17 00:00:00 2001 From: Sebastian Husch Lee Date: Fri, 20 Dec 2024 10:41:44 +0100 Subject: [PATCH 1/2] fix: Move potential nltk download to warm_up (#8646) * Move potential nltk download to warm_up * Update tests * Add release notes * Fix tests * Uncomment * Make mypy happy * Add RuntimeError message * Update release notes --------- Co-authored-by: Julian Risch --- .../preprocessors/document_splitter.py | 27 ++++++++++----- .../preprocessors/nltk_document_splitter.py | 32 ++++++++++++----- ...-download-to-warm-up-f2b22bda3a9ba673.yaml | 4 +++ .../preprocessors/test_document_splitter.py | 34 +++++++++++++++++++ .../test_nltk_document_splitter.py | 10 ++++++ 5 files changed, 90 insertions(+), 17 deletions(-) create mode 100644 releasenotes/notes/move-nltk-download-to-warm-up-f2b22bda3a9ba673.yaml diff --git a/haystack/components/preprocessors/document_splitter.py b/haystack/components/preprocessors/document_splitter.py index b3e99924a7..d03897b4b6 100644 --- a/haystack/components/preprocessors/document_splitter.py +++ b/haystack/components/preprocessors/document_splitter.py @@ -107,15 +107,10 @@ def __init__( # pylint: disable=too-many-positional-arguments splitting_function=splitting_function, respect_sentence_boundary=respect_sentence_boundary, ) - - if split_by == "sentence" or (respect_sentence_boundary and split_by == "word"): + self._use_sentence_splitter = split_by == "sentence" or (respect_sentence_boundary and split_by == "word") + if self._use_sentence_splitter: nltk_imports.check() - self.sentence_splitter = SentenceSplitter( - language=language, - use_split_rules=use_split_rules, - extend_abbreviations=extend_abbreviations, - keep_white_spaces=True, - ) + self.sentence_splitter = None if split_by == "sentence": # ToDo: remove this warning in the next major release @@ -164,6 +159,18 @@ def _init_checks( ) self.respect_sentence_boundary = False + def warm_up(self): + """ + Warm up the DocumentSplitter by loading the sentence tokenizer. + """ + if self._use_sentence_splitter and self.sentence_splitter is None: + self.sentence_splitter = SentenceSplitter( + language=self.language, + use_split_rules=self.use_split_rules, + extend_abbreviations=self.extend_abbreviations, + keep_white_spaces=True, + ) + @component.output_types(documents=List[Document]) def run(self, documents: List[Document]): """ @@ -182,6 +189,10 @@ def run(self, documents: List[Document]): :raises TypeError: if the input is not a list of Documents. :raises ValueError: if the content of a document is None. """ + if self._use_sentence_splitter and self.sentence_splitter is None: + raise RuntimeError( + "The component DocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'." + ) if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): raise TypeError("DocumentSplitter expects a List of Documents as input.") diff --git a/haystack/components/preprocessors/nltk_document_splitter.py b/haystack/components/preprocessors/nltk_document_splitter.py index eb242d9013..ab787d599d 100644 --- a/haystack/components/preprocessors/nltk_document_splitter.py +++ b/haystack/components/preprocessors/nltk_document_splitter.py @@ -77,16 +77,23 @@ def __init__( # pylint: disable=too-many-positional-arguments self.respect_sentence_boundary = respect_sentence_boundary self.use_split_rules = use_split_rules self.extend_abbreviations = extend_abbreviations - self.sentence_splitter = SentenceSplitter( - language=language, - use_split_rules=use_split_rules, - extend_abbreviations=extend_abbreviations, - keep_white_spaces=True, - ) + self.sentence_splitter = None self.language = language + def warm_up(self): + """ + Warm up the NLTKDocumentSplitter by loading the sentence tokenizer. + """ + if self.sentence_splitter is None: + self.sentence_splitter = SentenceSplitter( + language=self.language, + use_split_rules=self.use_split_rules, + extend_abbreviations=self.extend_abbreviations, + keep_white_spaces=True, + ) + def _split_into_units( - self, text: str, split_by: Literal["function", "page", "passage", "sentence", "word", "line"] + self, text: str, split_by: Literal["function", "page", "passage", "period", "sentence", "word", "line"] ) -> List[str]: """ Splits the text into units based on the specified split_by parameter. @@ -106,6 +113,7 @@ def _split_into_units( # whitespace is preserved while splitting text into sentences when using keep_white_spaces=True # so split_at is set to an empty string self.split_at = "" + assert self.sentence_splitter is not None result = self.sentence_splitter.split_sentences(text) units = [sentence["sentence"] for sentence in result] elif split_by == "word": @@ -142,6 +150,11 @@ def run(self, documents: List[Document]) -> Dict[str, List[Document]]: :raises TypeError: if the input is not a list of Documents. :raises ValueError: if the content of a document is None. """ + if self.sentence_splitter is None: + raise RuntimeError( + "The component NLTKDocumentSplitter wasn't warmed up. Run 'warm_up()' before calling 'run()'." + ) + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): raise TypeError("DocumentSplitter expects a List of Documents as input.") @@ -221,8 +234,9 @@ def _number_of_sentences_to_keep(sentences: List[str], split_length: int, split_ break return num_sentences_to_keep + @staticmethod def _concatenate_sentences_based_on_word_amount( - self, sentences: List[str], split_length: int, split_overlap: int + sentences: List[str], split_length: int, split_overlap: int ) -> Tuple[List[str], List[int], List[int]]: """ Groups the sentences into chunks of `split_length` words while respecting sentence boundaries. @@ -258,7 +272,7 @@ def _concatenate_sentences_based_on_word_amount( split_start_indices.append(chunk_start_idx) # Get the number of sentences that overlap with the next chunk - num_sentences_to_keep = self._number_of_sentences_to_keep( + num_sentences_to_keep = NLTKDocumentSplitter._number_of_sentences_to_keep( sentences=current_chunk, split_length=split_length, split_overlap=split_overlap ) # Set up information for the new chunk diff --git a/releasenotes/notes/move-nltk-download-to-warm-up-f2b22bda3a9ba673.yaml b/releasenotes/notes/move-nltk-download-to-warm-up-f2b22bda3a9ba673.yaml new file mode 100644 index 0000000000..fb308f39c8 --- /dev/null +++ b/releasenotes/notes/move-nltk-download-to-warm-up-f2b22bda3a9ba673.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Moved the NLTK download of DocumentSplitter and NLTKDocumentSplitter to warm_up(). This prevents calling to an external api during instantiation. If a DocumentSplitter or NLTKDocumentSplitter is used for sentence splitting outside of a pipeline, warm_up() now needs to be called before running the component. diff --git a/test/components/preprocessors/test_document_splitter.py b/test/components/preprocessors/test_document_splitter.py index 78767dbccd..094c17eeea 100644 --- a/test/components/preprocessors/test_document_splitter.py +++ b/test/components/preprocessors/test_document_splitter.py @@ -44,16 +44,19 @@ def test_non_text_document(self): ValueError, match="DocumentSplitter only works with text documents but content for document ID" ): splitter = DocumentSplitter() + splitter.warm_up() splitter.run(documents=[Document()]) assert "DocumentSplitter only works with text documents but content for document ID" in caplog.text def test_single_doc(self): with pytest.raises(TypeError, match="DocumentSplitter expects a List of Documents as input."): splitter = DocumentSplitter() + splitter.warm_up() splitter.run(documents=Document()) def test_empty_list(self): splitter = DocumentSplitter() + splitter.warm_up() res = splitter.run(documents=[]) assert res == {"documents": []} @@ -76,6 +79,7 @@ def test_unsupported_split_overlap(self): def test_split_by_word(self): splitter = DocumentSplitter(split_by="word", split_length=10) text = "This is a text with some words. There is a second sentence. And there is a third sentence." + splitter.warm_up() result = splitter.run(documents=[Document(content=text)]) docs = result["documents"] assert len(docs) == 2 @@ -88,6 +92,7 @@ def test_split_by_word(self): def test_split_by_word_with_threshold(self): splitter = DocumentSplitter(split_by="word", split_length=15, split_threshold=10) + splitter.warm_up() result = splitter.run( documents=[ Document( @@ -105,6 +110,7 @@ def test_split_by_word_multiple_input_docs(self): splitter = DocumentSplitter(split_by="word", split_length=10) text1 = "This is a text with some words. There is a second sentence. And there is a third sentence." text2 = "This is a different text with some words. There is a second sentence. And there is a third sentence. And there is a fourth sentence." + splitter.warm_up() result = splitter.run(documents=[Document(content=text1), Document(content=text2)]) docs = result["documents"] assert len(docs) == 5 @@ -132,6 +138,7 @@ def test_split_by_word_multiple_input_docs(self): def test_split_by_period(self): splitter = DocumentSplitter(split_by="period", split_length=1) text = "This is a text with some words. There is a second sentence. And there is a third sentence." + splitter.warm_up() result = splitter.run(documents=[Document(content=text)]) docs = result["documents"] assert len(docs) == 3 @@ -148,6 +155,7 @@ def test_split_by_period(self): def test_split_by_passage(self): splitter = DocumentSplitter(split_by="passage", split_length=1) text = "This is a text with some words. There is a second sentence.\n\nAnd there is a third sentence.\n\n And another passage." + splitter.warm_up() result = splitter.run(documents=[Document(content=text)]) docs = result["documents"] assert len(docs) == 3 @@ -164,6 +172,7 @@ def test_split_by_passage(self): def test_split_by_page(self): splitter = DocumentSplitter(split_by="page", split_length=1) text = "This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage." + splitter.warm_up() result = splitter.run(documents=[Document(content=text)]) docs = result["documents"] assert len(docs) == 3 @@ -183,6 +192,7 @@ def test_split_by_page(self): def test_split_by_function(self): splitting_function = lambda s: s.split(".") splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function) + splitter.warm_up() text = "This.Is.A.Test" result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})]) docs = result["documents"] @@ -200,6 +210,7 @@ def test_split_by_function(self): splitting_function = lambda s: re.split(r"[\s]{2,}", s) splitter = DocumentSplitter(split_by="function", splitting_function=splitting_function) text = "This Is\n A Test" + splitter.warm_up() result = splitter.run(documents=[Document(id="1", content=text, meta={"key": "value"})]) docs = result["documents"] assert len(docs) == 4 @@ -215,6 +226,7 @@ def test_split_by_function(self): def test_split_by_word_with_overlap(self): splitter = DocumentSplitter(split_by="word", split_length=10, split_overlap=2) text = "This is a text with some words. There is a second sentence. And there is a third sentence." + splitter.warm_up() result = splitter.run(documents=[Document(content=text)]) docs = result["documents"] assert len(docs) == 2 @@ -234,6 +246,7 @@ def test_split_by_word_with_overlap(self): def test_split_by_line(self): splitter = DocumentSplitter(split_by="line", split_length=1) text = "This is a text with some words.\nThere is a second sentence.\nAnd there is a third sentence." + splitter.warm_up() result = splitter.run(documents=[Document(content=text)]) docs = result["documents"] @@ -252,6 +265,7 @@ def test_source_id_stored_in_metadata(self): splitter = DocumentSplitter(split_by="word", split_length=10) doc1 = Document(content="This is a text with some words.") doc2 = Document(content="This is a different text with some words.") + splitter.warm_up() result = splitter.run(documents=[doc1, doc2]) assert result["documents"][0].meta["source_id"] == doc1.id assert result["documents"][1].meta["source_id"] == doc2.id @@ -262,6 +276,7 @@ def test_copy_metadata(self): Document(content="Text.", meta={"name": "doc 0"}), Document(content="Text.", meta={"name": "doc 1"}), ] + splitter.warm_up() result = splitter.run(documents=documents) assert len(result["documents"]) == 2 assert result["documents"][0].id != result["documents"][1].id @@ -273,6 +288,7 @@ def test_add_page_number_to_metadata_with_no_overlap_word_split(self): splitter = DocumentSplitter(split_by="word", split_length=2) doc1 = Document(content="This is some text.\f This text is on another page.") doc2 = Document(content="This content has two.\f\f page brakes.") + splitter.warm_up() result = splitter.run(documents=[doc1, doc2]) expected_pages = [1, 1, 2, 2, 2, 1, 1, 3] @@ -283,6 +299,7 @@ def test_add_page_number_to_metadata_with_no_overlap_period_split(self): splitter = DocumentSplitter(split_by="period", split_length=1) doc1 = Document(content="This is some text.\f This text is on another page.") doc2 = Document(content="This content has two.\f\f page brakes.") + splitter.warm_up() result = splitter.run(documents=[doc1, doc2]) expected_pages = [1, 1, 1, 1] @@ -294,6 +311,7 @@ def test_add_page_number_to_metadata_with_no_overlap_passage_split(self): doc1 = Document( content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence.\n\nAnd more passages.\n\n\f And another passage." ) + splitter.warm_up() result = splitter.run(documents=[doc1]) expected_pages = [1, 2, 2, 2] @@ -305,6 +323,7 @@ def test_add_page_number_to_metadata_with_no_overlap_page_split(self): doc1 = Document( content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage." ) + splitter.warm_up() result = splitter.run(documents=[doc1]) expected_pages = [1, 2, 3] for doc, p in zip(result["documents"], expected_pages): @@ -314,6 +333,7 @@ def test_add_page_number_to_metadata_with_no_overlap_page_split(self): doc1 = Document( content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage." ) + splitter.warm_up() result = splitter.run(documents=[doc1]) expected_pages = [1, 3] @@ -324,6 +344,7 @@ def test_add_page_number_to_metadata_with_overlap_word_split(self): splitter = DocumentSplitter(split_by="word", split_length=3, split_overlap=1) doc1 = Document(content="This is some text. And\f this text is on another page.") doc2 = Document(content="This content has two.\f\f page brakes.") + splitter.warm_up() result = splitter.run(documents=[doc1, doc2]) expected_pages = [1, 1, 1, 2, 2, 1, 1, 3] @@ -334,6 +355,7 @@ def test_add_page_number_to_metadata_with_overlap_period_split(self): splitter = DocumentSplitter(split_by="period", split_length=2, split_overlap=1) doc1 = Document(content="This is some text. And this is more text.\f This text is on another page. End.") doc2 = Document(content="This content has two.\f\f page brakes. More text.") + splitter.warm_up() result = splitter.run(documents=[doc1, doc2]) expected_pages = [1, 1, 1, 2, 1, 1] @@ -345,6 +367,7 @@ def test_add_page_number_to_metadata_with_overlap_passage_split(self): doc1 = Document( content="This is a text with some words.\f There is a second sentence.\n\nAnd there is a third sentence.\n\nAnd more passages.\n\n\f And another passage." ) + splitter.warm_up() result = splitter.run(documents=[doc1]) expected_pages = [1, 2, 2] @@ -356,6 +379,7 @@ def test_add_page_number_to_metadata_with_overlap_page_split(self): doc1 = Document( content="This is a text with some words. There is a second sentence.\f And there is a third sentence.\f And another passage." ) + splitter.warm_up() result = splitter.run(documents=[doc1]) expected_pages = [1, 2, 3] @@ -366,6 +390,7 @@ def test_add_split_overlap_information(self): splitter = DocumentSplitter(split_length=10, split_overlap=5, split_by="word") text = "This is a text with some words. There is a second sentence. And a third sentence." doc = Document(content="This is a text with some words. There is a second sentence. And a third sentence.") + splitter.warm_up() docs = splitter.run(documents=[doc])["documents"] # check split_overlap is added to all the documents @@ -487,6 +512,7 @@ def test_run_empty_document(self): """ splitter = DocumentSplitter() doc = Document(content="") + splitter.warm_up() results = splitter.run([doc]) assert results["documents"] == [] @@ -496,6 +522,7 @@ def test_run_document_only_whitespaces(self): """ splitter = DocumentSplitter() doc = Document(content=" ") + splitter.warm_up() results = splitter.run([doc]) assert results["documents"][0].content == " " @@ -543,6 +570,7 @@ def test_run_split_by_sentence_1(self) -> None: "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night ... " "The moon was full." ) + document_splitter.warm_up() documents = document_splitter.run(documents=[Document(content=text)])["documents"] assert len(documents) == 2 @@ -568,6 +596,7 @@ def test_run_split_by_sentence_2(self) -> None: "This is another test sentence. (This is a third test sentence.) " "This is the last test sentence." ) + document_splitter.warm_up() documents = document_splitter.run(documents=[Document(content=text)])["documents"] assert len(documents) == 4 @@ -601,6 +630,7 @@ def test_run_split_by_sentence_3(self) -> None: use_split_rules=True, extend_abbreviations=True, ) + document_splitter.warm_up() text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." documents = document_splitter.run(documents=[Document(content=text)])["documents"] @@ -633,6 +663,7 @@ def test_run_split_by_sentence_4(self) -> None: use_split_rules=True, extend_abbreviations=True, ) + document_splitter.warm_up() text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." documents = document_splitter.run(documents=[Document(content=text)])["documents"] @@ -660,6 +691,7 @@ def test_run_split_by_word_respect_sentence_boundary(self) -> None: language="en", respect_sentence_boundary=True, ) + document_splitter.warm_up() text = ( "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night.\f" @@ -692,6 +724,7 @@ def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None: use_split_rules=False, extend_abbreviations=False, ) + document_splitter.warm_up() text = ( "This is a test sentence with many many words that exceeds the split length and should not be repeated. " "This is another test sentence. (This is a third test sentence.) " @@ -717,6 +750,7 @@ def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page extend_abbreviations=True, respect_sentence_boundary=True, ) + document_splitter.warm_up() text = ( "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" diff --git a/test/components/preprocessors/test_nltk_document_splitter.py b/test/components/preprocessors/test_nltk_document_splitter.py index 38952575d1..fe80848c74 100644 --- a/test/components/preprocessors/test_nltk_document_splitter.py +++ b/test/components/preprocessors/test_nltk_document_splitter.py @@ -42,6 +42,7 @@ def test_document_splitter_split_into_units_sentence(self) -> None: document_splitter = NLTKDocumentSplitter( split_by="sentence", split_length=2, split_overlap=0, split_threshold=0, language="en" ) + document_splitter.warm_up() text = "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night." units = document_splitter._split_into_units(text=text, split_by="sentence") @@ -121,11 +122,13 @@ class TestNLTKDocumentSplitterRun: def test_run_type_error(self) -> None: document_splitter = NLTKDocumentSplitter() with pytest.raises(TypeError): + document_splitter.warm_up() document_splitter.run(documents=Document(content="Moonlight shimmered softly.")) # type: ignore def test_run_value_error(self) -> None: document_splitter = NLTKDocumentSplitter() with pytest.raises(ValueError): + document_splitter.warm_up() document_splitter.run(documents=[Document(content=None)]) def test_run_split_by_sentence_1(self) -> None: @@ -138,6 +141,7 @@ def test_run_split_by_sentence_1(self) -> None: use_split_rules=True, extend_abbreviations=True, ) + document_splitter.warm_up() text = ( "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night ... " @@ -168,6 +172,7 @@ def test_run_split_by_sentence_2(self) -> None: "This is another test sentence. (This is a third test sentence.) " "This is the last test sentence." ) + document_splitter.warm_up() documents = document_splitter.run(documents=[Document(content=text)])["documents"] assert len(documents) == 4 @@ -201,6 +206,7 @@ def test_run_split_by_sentence_3(self) -> None: use_split_rules=True, extend_abbreviations=True, ) + document_splitter.warm_up() text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." documents = document_splitter.run(documents=[Document(content=text)])["documents"] @@ -233,6 +239,7 @@ def test_run_split_by_sentence_4(self) -> None: use_split_rules=True, extend_abbreviations=True, ) + document_splitter.warm_up() text = "Sentence on page 1.\fSentence on page 2. \fSentence on page 3. \f\f Sentence on page 5." documents = document_splitter.run(documents=[Document(content=text)])["documents"] @@ -262,6 +269,7 @@ def test_run_split_by_word_respect_sentence_boundary(self) -> None: language="en", respect_sentence_boundary=True, ) + document_splitter.warm_up() text = ( "Moonlight shimmered softly, wolves howled nearby, night enveloped everything. It was a dark night.\f" @@ -294,6 +302,7 @@ def test_run_split_by_word_respect_sentence_boundary_no_repeats(self) -> None: use_split_rules=False, extend_abbreviations=False, ) + document_splitter.warm_up() text = ( "This is a test sentence with many many words that exceeds the split length and should not be repeated. " "This is another test sentence. (This is a third test sentence.) " @@ -319,6 +328,7 @@ def test_run_split_by_word_respect_sentence_boundary_with_split_overlap_and_page extend_abbreviations=True, respect_sentence_boundary=True, ) + document_splitter.warm_up() text = ( "Sentence on page 1. Another on page 1.\fSentence on page 2. Another on page 2.\f" From c192488bf6216b5b727ffae03635a64d0541149a Mon Sep 17 00:00:00 2001 From: Michele Pangrazzi Date: Fri, 20 Dec 2024 11:15:55 +0100 Subject: [PATCH 2/2] Named entity extractor private models (#8658) * add 'token' support to NamedEntityExtractor to enable using private models on HF backend * fix existing error message format * add release note * add HF_API_TOKEN to e2e workflow * add informative comment * Updated to_dict / from_dict to handle 'token' correctly ; Added tests * Fix lint * Revert unwanted change --- .github/workflows/e2e.yml | 1 + e2e/pipelines/test_named_entity_extractor.py | 13 +++++ .../extractors/named_entity_extractor.py | 37 ++++++++++-- ...med-entity-extractor-3124acb1ae297c0e.yaml | 4 ++ .../extractors/test_named_entity_extractor.py | 57 ++++++++++++++++++- 5 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 releasenotes/notes/add-token-to-named-entity-extractor-3124acb1ae297c0e.yaml diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml index 7e817f2039..ad47c7a68a 100644 --- a/.github/workflows/e2e.yml +++ b/.github/workflows/e2e.yml @@ -19,6 +19,7 @@ env: PYTHON_VERSION: "3.9" OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} HATCH_VERSION: "1.13.0" + HF_API_TOKEN: ${{ secrets.HUGGINGFACE_API_KEY }} jobs: run: diff --git a/e2e/pipelines/test_named_entity_extractor.py b/e2e/pipelines/test_named_entity_extractor.py index 2fc15a9a1a..bade774048 100644 --- a/e2e/pipelines/test_named_entity_extractor.py +++ b/e2e/pipelines/test_named_entity_extractor.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 +import os import pytest from haystack import Document, Pipeline @@ -65,6 +66,18 @@ def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size): _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size) +@pytest.mark.parametrize("batch_size", [1, 3]) +@pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", +) +def test_ner_extractor_hf_backend_private_models(raw_texts, hf_annotations, batch_size): + extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER") + extractor.warm_up() + + _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size) + + @pytest.mark.parametrize("batch_size", [1, 3]) def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size): extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf") diff --git a/haystack/components/extractors/named_entity_extractor.py b/haystack/components/extractors/named_entity_extractor.py index 00350e6aed..5651530dbf 100644 --- a/haystack/components/extractors/named_entity_extractor.py +++ b/haystack/components/extractors/named_entity_extractor.py @@ -10,7 +10,9 @@ from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict from haystack.lazy_imports import LazyImport +from haystack.utils.auth import Secret, deserialize_secrets_inplace from haystack.utils.device import ComponentDevice +from haystack.utils.hf import deserialize_hf_model_kwargs, resolve_hf_pipeline_kwargs, serialize_hf_model_kwargs with LazyImport(message="Run 'pip install \"transformers[torch]\"'") as transformers_import: from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline @@ -110,6 +112,7 @@ def __init__( model: str, pipeline_kwargs: Optional[Dict[str, Any]] = None, device: Optional[ComponentDevice] = None, + token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False), ) -> None: """ Create a Named Entity extractor component. @@ -128,6 +131,8 @@ def __init__( device/device map is specified in `pipeline_kwargs`, it overrides this parameter (only applicable to the HuggingFace backend). + :param token: + The API token to download private models from Hugging Face. """ if isinstance(backend, str): @@ -135,9 +140,19 @@ def __init__( self._backend: _NerBackend self._warmed_up: bool = False + self.token = token device = ComponentDevice.resolve_device(device) if backend == NamedEntityExtractorBackend.HUGGING_FACE: + pipeline_kwargs = resolve_hf_pipeline_kwargs( + huggingface_pipeline_kwargs=pipeline_kwargs or {}, + model=model, + task="ner", + supported_tasks=["ner"], + device=device, + token=token, + ) + self._backend = _HfBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs) elif backend == NamedEntityExtractorBackend.SPACY: self._backend = _SpacyBackend(model_name_or_path=model, device=device, pipeline_kwargs=pipeline_kwargs) @@ -159,7 +174,7 @@ def warm_up(self): self._warmed_up = True except Exception as e: raise ComponentError( - f"Named entity extractor with backend '{self._backend.type} failed to initialize." + f"Named entity extractor with backend '{self._backend.type}' failed to initialize." ) from e @component.output_types(documents=List[Document]) @@ -201,14 +216,21 @@ def to_dict(self) -> Dict[str, Any]: :returns: Dictionary with serialized data. """ - return default_to_dict( + serialization_dict = default_to_dict( self, backend=self._backend.type.name, model=self._backend.model_name, device=self._backend.device.to_dict(), pipeline_kwargs=self._backend._pipeline_kwargs, + token=self.token.to_dict() if self.token else None, ) + hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"] + hf_pipeline_kwargs.pop("token", None) + + serialize_hf_model_kwargs(hf_pipeline_kwargs) + return serialization_dict + @classmethod def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor": """ @@ -220,10 +242,14 @@ def from_dict(cls, data: Dict[str, Any]) -> "NamedEntityExtractor": Deserialized component. """ try: - init_params = data["init_parameters"] + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + init_params = data.get("init_parameters", {}) if init_params.get("device") is not None: init_params["device"] = ComponentDevice.from_dict(init_params["device"]) init_params["backend"] = NamedEntityExtractorBackend[init_params["backend"]] + + hf_pipeline_kwargs = init_params.get("pipeline_kwargs", {}) + deserialize_hf_model_kwargs(hf_pipeline_kwargs) return default_from_dict(cls, data) except Exception as e: raise DeserializationError(f"Couldn't deserialize {cls.__name__} instance") from e @@ -352,8 +378,9 @@ def __init__( self.pipeline: Optional[HfPipeline] = None def initialize(self): - self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path) - self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path) + token = self._pipeline_kwargs.get("token", None) + self.tokenizer = AutoTokenizer.from_pretrained(self._model_name_or_path, token=token) + self.model = AutoModelForTokenClassification.from_pretrained(self._model_name_or_path, token=token) pipeline_params = { "task": "ner", diff --git a/releasenotes/notes/add-token-to-named-entity-extractor-3124acb1ae297c0e.yaml b/releasenotes/notes/add-token-to-named-entity-extractor-3124acb1ae297c0e.yaml new file mode 100644 index 0000000000..ee859c0c94 --- /dev/null +++ b/releasenotes/notes/add-token-to-named-entity-extractor-3124acb1ae297c0e.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add `token` argument to `NamedEntityExtractor` to allow usage of private Hugging Face models. diff --git a/test/components/extractors/test_named_entity_extractor.py b/test/components/extractors/test_named_entity_extractor.py index a4826c1e95..04d36399f3 100644 --- a/test/components/extractors/test_named_entity_extractor.py +++ b/test/components/extractors/test_named_entity_extractor.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from haystack.utils.auth import Secret import pytest from haystack import ComponentError, DeserializationError, Pipeline @@ -11,6 +12,9 @@ def test_named_entity_extractor_backend(): _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER") + # private model + _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="deepset/bert-base-NER") + _ = NamedEntityExtractor(backend="hugging_face", model="dslim/bert-base-NER") _ = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_sm") @@ -40,7 +44,58 @@ def test_named_entity_extractor_serde(): _ = NamedEntityExtractor.from_dict(serde_data) -def test_named_entity_extractor_from_dict_no_default_parameters_hf(): +def test_to_dict_default(monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) + + component = NamedEntityExtractor( + backend=NamedEntityExtractorBackend.HUGGING_FACE, + model="dslim/bert-base-NER", + device=ComponentDevice.from_str("mps"), + ) + data = component.to_dict() + + assert data == { + "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor", + "init_parameters": { + "backend": "HUGGING_FACE", + "model": "dslim/bert-base-NER", + "device": {"type": "single", "device": "mps"}, + "pipeline_kwargs": {"model": "dslim/bert-base-NER", "device": "mps", "task": "ner"}, + "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False}, + }, + } + + +def test_to_dict_with_parameters(): + component = NamedEntityExtractor( + backend=NamedEntityExtractorBackend.HUGGING_FACE, + model="dslim/bert-base-NER", + device=ComponentDevice.from_str("mps"), + pipeline_kwargs={"model_kwargs": {"load_in_4bit": True}}, + token=Secret.from_env_var("ENV_VAR", strict=False), + ) + data = component.to_dict() + + assert data == { + "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor", + "init_parameters": { + "backend": "HUGGING_FACE", + "model": "dslim/bert-base-NER", + "device": {"type": "single", "device": "mps"}, + "pipeline_kwargs": { + "model": "dslim/bert-base-NER", + "device": "mps", + "task": "ner", + "model_kwargs": {"load_in_4bit": True}, + }, + "token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, + }, + } + + +def test_named_entity_extractor_from_dict_no_default_parameters_hf(monkeypatch): + monkeypatch.delenv("HF_API_TOKEN", raising=False) + data = { "type": "haystack.components.extractors.named_entity_extractor.NamedEntityExtractor", "init_parameters": {"backend": "HUGGING_FACE", "model": "dslim/bert-base-NER"},