From c50a835909e44727519fb51bd2d828c41b98f364 Mon Sep 17 00:00:00 2001
From: Sebastian Husch Lee
Date: Tue, 10 Dec 2024 15:18:58 +0100
Subject: [PATCH] feat: Metadata extractor update (#147)

* Updates
* Fix doc string
* Remove commented out code
* Fix linting
* Fix another lint
* Fix typing
* Fix tests
* Update integration test
* More tests
* Update tests
* Try adding mocks
* Adding more tests
* More tests
* Update haystack_experimental/components/extractors/llm_metadata_extractor.py

Co-authored-by: David S. Batista

* Update haystack_experimental/components/extractors/llm_metadata_extractor.py

Co-authored-by: David S. Batista

* Update haystack_experimental/components/extractors/llm_metadata_extractor.py

Co-authored-by: David S. Batista

* Fix linting

---------

Co-authored-by: David S. Batista
---
 .../extractors/llm_metadata_extractor.py      | 318 ++++++++-----
 pyproject.toml                                |   2 +-
 .../extractors/test_llm_metadata_extractor.py | 432 +++++++++++++-----
 3 files changed, 541 insertions(+), 211 deletions(-)

diff --git a/haystack_experimental/components/extractors/llm_metadata_extractor.py b/haystack_experimental/components/extractors/llm_metadata_extractor.py
index 32547045..8b1fd723 100644
--- a/haystack_experimental/components/extractors/llm_metadata_extractor.py
+++ b/haystack_experimental/components/extractors/llm_metadata_extractor.py
@@ -2,7 +2,9 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
 import json
+from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from typing import Any, Dict, List, Optional, Union
 
@@ -11,15 +13,18 @@
 from haystack.components.generators import AzureOpenAIGenerator, OpenAIGenerator
 from haystack.components.preprocessors import DocumentSplitter
 from haystack.lazy_imports import LazyImport
-from haystack.utils import deserialize_secrets_inplace
+from haystack.utils import deserialize_callable, deserialize_secrets_inplace
+from jinja2 import meta
+from jinja2.sandbox import SandboxedEnvironment
 
 from haystack_experimental.util.utils import expand_page_range
 
-with LazyImport(message="Run 'pip install \"amazon-bedrock-haystack==1.0.2\"'") as amazon_bedrock_generator:
+with LazyImport(message="Run 'pip install \"amazon-bedrock-haystack>=1.0.2\"'") as amazon_bedrock_generator:
     from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator
 
-with LazyImport(message="Run 'pip install \"google-vertex-haystack==2.0.0\"'") as vertex_ai_gemini_generator:
+with LazyImport(message="Run 'pip install \"google-vertex-haystack>=2.0.0\"'") as vertex_ai_gemini_generator:
     from haystack_integrations.components.generators.google_vertex import VertexAIGeminiGenerator
+    from vertexai.generative_models import GenerationConfig
 
 logger = logging.getLogger(__name__)
 
@@ -56,12 +61,21 @@ class LLMMetadataExtractor:
     """
     Extracts metadata from documents using a Large Language Model (LLM) from OpenAI.
 
-    The metadata is extracted by providing a prompt to n LLM that generates the metadata.
+    The metadata is extracted by providing a prompt to an LLM that generates the metadata.
+
+    This component expects as input a list of documents and a prompt. The prompt must have a variable called
+    `document` that points to a single document in the list, so to access the content of a document you can use
+    `{{ document.content }}` in the prompt.
+
+    The component runs the LLM on each document in the list and extracts metadata from it. The metadata is added
+    to the document's metadata field. If the LLM fails to extract metadata from a document, the document is added
+    to the `failed_documents` list. Failed documents will have the keys `metadata_extraction_error` and
+    `metadata_extraction_response` in their metadata. These documents can be re-run with another extractor whose
+    prompt uses the stored `metadata_extraction_response` and `metadata_extraction_error` values.
 
     ```python
     from haystack import Document
-    from haystack.components.generators import OpenAIGenerator
-    from haystack_experimental.components.extractors import LLMMetadataExtractor
+    from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor
 
     NER_PROMPT = '''
     -Goal-
@@ -79,10 +93,17 @@ class LLMMetadataExtractor:
     ######################
     Example 1:
     entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
-    text:
-    Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top 10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer base and high cross-border usage.
-    We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent agreement with Emirates Skywards.
-    And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital issuers are equally
+    text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top
+    10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of
+    our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer
+    base and high cross-border usage.
+    We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership
+    with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global
+    Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the
+    United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent
+    agreement with Emirates Skywards.
+    And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital
+    issuers are equally
    ------------------------
    output:
    {"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
@@ -90,7 +111,7 @@ class LLMMetadataExtractor:
    -Real Data-
    ######################
    entity_types: [company, organization, person, country, product, service]
-    text: {{input_text}}
+    text: {{ document.content }}
    ######################
    output:
    '''
@@ -100,7 +121,23 @@ class LLMMetadataExtractor:
        Document(content="Hugging Face is a company founded in Paris, France and is known for its Transformers library")
    ]
 
-    extractor = LLMMetadataExtractor(prompt=NER_PROMPT, expected_keys=["entities"], generator_api="openai", prompt_variable='input_text')
+    extractor = LLMMetadataExtractor(
+        prompt=NER_PROMPT,
+        generator_api="openai",
+        generator_api_params={
+            "generation_kwargs": {
+                "max_tokens": 500,
+                "temperature": 0.0,
+                "seed": 0,
+                "response_format": {"type": "json_object"},
+            },
+            "max_retries": 1,
+            "timeout": 60.0,
+        },
+        expected_keys=["entities"],
+        raise_on_failure=False,
+    )
+    extractor.warm_up()
    extractor.run(documents=docs)
    >> {'documents': [
    Document(id=.., content: 'deepset was founded in 2018 in Berlin, and is known for its Haystack framework',
@@ -112,6 +149,7 @@ class LLMMetadataExtractor:
    {'entity': 'France', 'entity_type': 'country'},
    {'entity': 'Transformers', 'entity_type': 'product'}
    ]})
    ]
+    'failed_documents': []
    }
    >>
    ```
@@ -120,45 +158,50 @@ class LLMMetadataExtractor:
     def __init__(  # pylint: disable=R0917
         self,
         prompt: str,
-        prompt_variable: str,
-        expected_keys: List[str],
         generator_api: Union[str, LLMProvider],
         generator_api_params: Optional[Dict[str, Any]] = None,
+        expected_keys: Optional[List[str]] = None,
         page_range: Optional[List[Union[str, int]]] = None,
         raise_on_failure: bool = False,
+        max_workers: int = 3,
     ):
         """
         Initializes the LLMMetadataExtractor.
 
         :param prompt: The prompt to be used for the LLM.
-        :param prompt_variable: The variable in the prompt to be processed by the PromptBuilder.
-        :param expected_keys: The keys expected in the JSON output from the LLM.
         :param generator_api: The API provider for the LLM. Currently supported providers are:
             "openai", "openai_azure", "aws_bedrock", "google_vertex"
         :param generator_api_params: The parameters for the LLM generator.
+        :param expected_keys: The keys expected in the JSON output from the LLM.
         :param page_range: A range of pages to extract metadata from. For example, page_range=['1', '3'] will extract
             metadata from the first and third pages of each document. It also accepts printable range strings, e.g.:
            ['1-3', '5', '8', '10-12'] will extract metadata from pages 1, 2, 3, 5, 8, 10, 11, 12.
            If None, metadata will be extracted from the entire document for each document in the documents list.
            This parameter is optional and can be overridden in the `run` method.
-        :param raise_on_failure: Whether to raise an error on failure to validate JSON output.
-        :returns:
-
+        :param raise_on_failure: Whether to raise an error on failure during the execution of the Generator or
+            validation of the JSON output.
+        :param max_workers: The maximum number of workers to use in the thread pool executor.
         """
         self.prompt = prompt
-        self.prompt_variable = prompt_variable
-        self.builder = PromptBuilder(prompt, required_variables=[prompt_variable])
+        ast = SandboxedEnvironment().parse(prompt)
+        template_variables = meta.find_undeclared_variables(ast)
+        variables = list(template_variables)
+        if len(variables) != 1 or variables[0] != "document":
+            raise ValueError(
+                f"Prompt must have exactly one variable called 'document'. Found {','.join(variables)} in the prompt."
+            )
+        self.builder = PromptBuilder(prompt, required_variables=variables)
+
         self.raise_on_failure = raise_on_failure
-        self.expected_keys = expected_keys
-        self.generator_api = generator_api if isinstance(generator_api, LLMProvider)\
+        self.expected_keys = expected_keys or []
+        self.generator_api = generator_api if isinstance(generator_api, LLMProvider) \
             else LLMProvider.from_str(generator_api)
         self.generator_api_params = generator_api_params or {}
         self.llm_provider = self._init_generator(self.generator_api, self.generator_api_params)
-        if self.prompt_variable not in self.prompt:
-            raise ValueError(f"Prompt variable '{self.prompt_variable}' must be in the prompt.")
         self.splitter = DocumentSplitter(split_by="page", split_length=1)
         self.expanded_range = expand_page_range(page_range) if page_range else None
+        self.max_workers = max_workers
 
     @staticmethod
     def _init_generator(
@@ -170,50 +213,23 @@ def _init_generator(
         """
         if generator_api == LLMProvider.OPENAI:
             return OpenAIGenerator(**generator_api_params)
-        if generator_api == LLMProvider.OPENAI_AZURE:
+        elif generator_api == LLMProvider.OPENAI_AZURE:
             return AzureOpenAIGenerator(**generator_api_params)
-        if generator_api == LLMProvider.AWS_BEDROCK:
+        elif generator_api == LLMProvider.AWS_BEDROCK:
             amazon_bedrock_generator.check()
             return AmazonBedrockGenerator(**generator_api_params)
-        if generator_api == LLMProvider.GOOGLE_VERTEX:
+        elif generator_api == LLMProvider.GOOGLE_VERTEX:
             vertex_ai_gemini_generator.check()
             return VertexAIGeminiGenerator(**generator_api_params)
-        raise ValueError(f"Unsupported generator API: {generator_api}")
+        else:
+            raise ValueError(f"Unsupported generator API: {generator_api}")
 
-    def is_valid_json_and_has_expected_keys(self, expected: List[str], received: str) -> bool:
+    def warm_up(self):
         """
-        Output must be a valid JSON with the expected keys.
-
-        :param expected:
-            Names of expected outputs
-        :param received:
-            Names of received outputs
-
-        :raises ValueError:
-            If the output is not a valid JSON with the expected keys:
-            - with `raise_on_failure` set to True a ValueError is raised.
-            - with `raise_on_failure` set to False a warning is issued and False is returned.
-
-        :returns:
-            True if the received output is a valid JSON with the expected keys, False otherwise.
+        Warm up the LLM provider component.
         """
-        try:
-            parsed_output = json.loads(received)
-        except json.JSONDecodeError:
-            msg = "Response from LLM is not a valid JSON."
-            if self.raise_on_failure:
-                raise ValueError(msg)
-            logger.warning(msg)
-            return False
-
-        if not all(output in parsed_output for output in expected):
-            msg = f"Expected response from LLM to be a JSON with keys {expected}, got {received}."
-            if self.raise_on_failure:
-                raise ValueError(msg)
-            logger.warning(msg)
-            return False
-
-        return True
+        if hasattr(self.llm_provider, "warm_up"):
+            self.llm_provider.warm_up()
 
     def to_dict(self) -> Dict[str, Any]:
         """
@@ -228,12 +244,12 @@ def to_dict(self) -> Dict[str, Any]:
         return default_to_dict(
             self,
             prompt=self.prompt,
-            prompt_variable=self.prompt_variable,
-            expected_keys=self.expected_keys,
-            raise_on_failure=self.raise_on_failure,
             generator_api=self.generator_api.value,
             generator_api_params=llm_provider["init_parameters"],
-            page_range=self.expanded_range
+            expected_keys=self.expected_keys,
+            page_range=self.expanded_range,
+            raise_on_failure=self.raise_on_failure,
+            max_workers=self.max_workers,
         )
 
     @classmethod
@@ -248,37 +264,111 @@ def from_dict(cls, data: Dict[str, Any]) -> "LLMMetadataExtractor":
         """
         init_parameters = data.get("init_parameters", {})
+
         if "generator_api" in init_parameters:
             data["init_parameters"]["generator_api"] = LLMProvider.from_str(data["init_parameters"]["generator_api"])
+
         if "generator_api_params" in init_parameters:
-            deserialize_secrets_inplace(data["init_parameters"]["generator_api_params"], keys=["api_key"])
+            # Check all the keys that need to be deserialized
+            azure_openai_keys = ["azure_ad_token"]
+            aws_bedrock_keys = [
+                "aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region_name", "aws_profile_name"
+            ]
+            deserialize_secrets_inplace(
+                data["init_parameters"]["generator_api_params"],
+                keys=["api_key"] + azure_openai_keys + aws_bedrock_keys,
+            )
+
+            # For VertexAI
+            if "generation_config" in init_parameters["generator_api_params"]:
+                data["init_parameters"]["generator_api_params"]["generation_config"] = GenerationConfig.from_dict(
+                    init_parameters["generator_api_params"]["generation_config"]
+                )
+
+            # For all providers
+            serialized_callback_handler = init_parameters["generator_api_params"].get("streaming_callback")
+            if serialized_callback_handler:
+                data["init_parameters"]["generator_api_params"]["streaming_callback"] = deserialize_callable(
+                    serialized_callback_handler
+                )
+
         return default_from_dict(cls, data)
 
-    def _extract_metadata_and_update_doc(self, document: Document, content: str):
-        """
-        Extract metadata from the content and updates the document's metadata with the extracted metadata.
-
-        If the extraction fails, i.e.: no JSON is returned by the LLM API, the error message will be stored in
-        `errors`.
-
-        :param document: Document to be updated with the extracted metadata.
-        :param content: Content to extract metadata from.
-        """
-        prompt_with_doc = self.builder.run(
-            template=self.prompt,
-            template_variables={self.prompt_variable: content}
-        )
-        result = self.llm_provider.run(prompt=prompt_with_doc["prompt"])
-        llm_answer = result["replies"][0]
-        if self.is_valid_json_and_has_expected_keys(expected=self.expected_keys, received=llm_answer):
-            extracted_metadata = json.loads(llm_answer)
-            for k in self.expected_keys:
-                document.meta[k] = extracted_metadata[k]
-
-    @component.output_types(documents=List[Document], errors=Dict[str, Any])
+    def _extract_metadata(self, llm_answer: str) -> Dict[str, Any]:
+        try:
+            parsed_metadata = json.loads(llm_answer)
+        except json.JSONDecodeError as e:
+            logger.warning(
+                "Response from the LLM is not valid JSON. Skipping metadata extraction. Received output: {response}",
+                response=llm_answer
+            )
+            if self.raise_on_failure:
+                raise e
+            return {"error": "Response is not valid JSON. Received JSONDecodeError: " + str(e)}
+
+        if not all(key in parsed_metadata for key in self.expected_keys):
+            logger.warning(
+                "Expected response from LLM to be a JSON with keys {expected_keys}, got {parsed_json}. "
+                "Continuing extraction with received output.",
+                expected_keys=self.expected_keys,
+                parsed_json=parsed_metadata
+            )
+
+        return parsed_metadata
+
+    def _prepare_prompts(
+        self,
+        documents: List[Document],
+        expanded_range: Optional[List[int]] = None
+    ) -> List[Union[str, None]]:
+        all_prompts: List[Union[str, None]] = []
+        for document in documents:
+            if not document.content:
+                logger.warning(
+                    "Document {doc_id} has no content. Skipping metadata extraction.",
+                    doc_id=document.id
+                )
+                all_prompts.append(None)
+                continue
+
+            if expanded_range:
+                doc_copy = copy.deepcopy(document)
+                pages = self.splitter.run(documents=[doc_copy])
+                content = ""
+                for idx, page in enumerate(pages["documents"]):
+                    if idx + 1 in expanded_range:
+                        content += page.content
+                doc_copy.content = content
+            else:
+                doc_copy = document
+
+            prompt_with_doc = self.builder.run(
+                template=self.prompt,
+                template_variables={"document": doc_copy}
+            )
+            all_prompts.append(prompt_with_doc["prompt"])
+        return all_prompts
+
+    def _run_on_thread(self, prompt: Optional[str]) -> Dict[str, Any]:
+        # If prompt is None, return an empty dictionary
+        if prompt is None:
+            return {"replies": ["{}"]}
+
+        try:
+            result = self.llm_provider.run(prompt=prompt)
+        except Exception as e:
+            logger.error(
+                "LLM {class_name} execution failed. Skipping metadata extraction. Failed with exception '{error}'.",
+                class_name=self.llm_provider.__class__.__name__,
+                error=e,
+            )
+            if self.raise_on_failure:
+                raise e
+            result = {"error": "LLM failed with exception: " + str(e)}
+        return result
+
+    @component.output_types(documents=List[Document], failed_documents=List[Document])
     def run(self, documents: List[Document], page_range: Optional[List[Union[str, int]]] = None):
         """
-        Extract metadata from documents using a Language Model.
+        Extract metadata from documents using a Large Language Model.
 
         If `page_range` is provided, the metadata will be extracted from the specified range of pages. This component
         will split the documents into pages and extract metadata from the specified range of pages. The metadata will be
@@ -295,27 +385,47 @@ def run(self, documents: List[Document], page_range: Optional[List[Union[str, in
            documents list.
 
        :returns:
            A dictionary with the keys:
-            - "documents": The original list of documents updated with the extracted metadata.
-            - "errors": A dictionary with document IDs as keys and error messages as values.
+            - "documents": A list of documents that were successfully updated with the extracted metadata.
+            - "failed_documents": A list of documents for which extraction failed. These documents will have
+              "metadata_extraction_error" and "metadata_extraction_response" in their metadata and can be
+              re-run with the extractor to extract metadata.
         """
+        if len(documents) == 0:
+            logger.warning("No documents provided. Skipping metadata extraction.")
+            return {"documents": [], "failed_documents": []}
 
-        errors: Dict[str, Any] = {}
         expanded_range = self.expanded_range
         if page_range:
             expanded_range = expand_page_range(page_range)
 
-        for document in documents:
-            if not document.content:
-                logger.warning(f"Document {document.id} has no content. Skipping metadata extraction.")
-                continue
-            if expanded_range:
-                pages = self.splitter.run(documents=[document])
-                content = ""
-                for idx, page in enumerate(pages["documents"]):
-                    if idx + 1 in expanded_range:
-                        content += page.content + "\f"
-            else:
-                # extract metadata from the entire document
-                content = document.content
-            self._extract_metadata_and_update_doc(document, content)
-        return {"documents": documents, "errors": errors}
+        # Create prompts for each document
+        all_prompts = self._prepare_prompts(documents=documents, expanded_range=expanded_range)
+
+        # Run the LLM on each prompt
+        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
+            results = executor.map(self._run_on_thread, all_prompts)
+
+        successful_documents = []
+        failed_documents = []
+        for document, result in zip(documents, results):
+            if "error" in result:
+                document.meta["metadata_extraction_error"] = result["error"]
+                document.meta["metadata_extraction_response"] = None
+                failed_documents.append(document)
+                continue
+
+            parsed_metadata = self._extract_metadata(result["replies"][0])
+            if "error" in parsed_metadata:
+                document.meta["metadata_extraction_error"] = parsed_metadata["error"]
+                document.meta["metadata_extraction_response"] = result["replies"][0]
+                failed_documents.append(document)
+                continue
+
+            for key in parsed_metadata:
+                document.meta[key] = parsed_metadata[key]
+            # Remove metadata_extraction_error and metadata_extraction_response if present from previous runs
+            document.meta.pop("metadata_extraction_error", None)
+            document.meta.pop("metadata_extraction_response", None)
+            successful_documents.append(document)
+
+        return {"documents": successful_documents, "failed_documents": failed_documents}

diff --git a/pyproject.toml b/pyproject.toml
index c21b26f9..8dd0e8d3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,7 +69,7 @@ extra-dependencies = [
   "opensearch-haystack",
   "opensearch-py[async]",
   # LLMMetadataExtractor dependencies
-  "amazon-bedrock-haystack>=1.0.2",
+  "amazon-bedrock-haystack>=1.1.1",
   "google-vertex-haystack>=2.0.0",
 ]

diff --git a/test/components/extractors/test_llm_metadata_extractor.py b/test/components/extractors/test_llm_metadata_extractor.py
index 094633fb..92d7deeb 100644
--- a/test/components/extractors/test_llm_metadata_extractor.py
+++ b/test/components/extractors/test_llm_metadata_extractor.py
@@ -1,5 +1,7 @@
+import boto3
 import os
 import pytest
+from unittest.mock import MagicMock
 
 from haystack import Pipeline, Document
 from haystack.components.builders import PromptBuilder
@@ -10,138 +12,342 @@
 
 class TestLLMMetadataExtractor:
+    @pytest.fixture
+    def boto3_session_mock(self, monkeypatch: pytest.MonkeyPatch) -> MagicMock:
+        mock = MagicMock()
+        monkeypatch.setattr(boto3, "Session", mock)
+        return mock
 
     def test_init_default(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="prompt {{test}}",
+            prompt="prompt {{document.content}}",
             expected_keys=["key1", "key2"],
             generator_api=LLMProvider.OPENAI,
-            prompt_variable="test"
         )
         assert isinstance(extractor.builder, PromptBuilder)
         assert extractor.generator_api == LLMProvider.OPENAI
         assert extractor.expected_keys == ["key1", "key2"]
         assert extractor.raise_on_failure is False
-        assert extractor.prompt_variable == "test"
 
     def test_init_with_parameters(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="prompt {{test}}",
+            prompt="prompt {{document.content}}",
             expected_keys=["key1", "key2"],
            raise_on_failure=True,
             generator_api=LLMProvider.OPENAI,
             generator_api_params={
-                'model': 'gpt-3.5-turbo',
-                'generation_kwargs': {"temperature": 0.5}
+                "model": "gpt-3.5-turbo",
+                "generation_kwargs": {"temperature": 0.5},
             },
-            prompt_variable="test",
-            page_range=['1-5']
+            page_range=["1-5"],
         )
         assert isinstance(extractor.builder, PromptBuilder)
         assert extractor.expected_keys == ["key1", "key2"]
         assert extractor.raise_on_failure is True
         assert extractor.generator_api == LLMProvider.OPENAI
         assert extractor.generator_api_params == {
-            'model': 'gpt-3.5-turbo',
-            'generation_kwargs': {"temperature": 0.5}
-        }
+            "model": "gpt-3.5-turbo",
+            "generation_kwargs": {"temperature": 0.5},
+        }
         assert extractor.expanded_range == [1, 2, 3, 4, 5]
 
     def test_init_missing_prompt_variable(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         with pytest.raises(ValueError):
             _ = LLMMetadataExtractor(
-                prompt="prompt {{test}}",
+                prompt="prompt {{ wrong_variable }}",
                 expected_keys=["key1", "key2"],
                 generator_api=LLMProvider.OPENAI,
-                prompt_variable="test2"
             )
 
-    def test_to_dict_default_params(self, monkeypatch):
+    def test_to_dict_openai(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="some prompt that was used with the LLM {{test}}",
+            prompt="some prompt that was used with the LLM {{document.content}}",
             expected_keys=["key1", "key2"],
             generator_api=LLMProvider.OPENAI,
-            prompt_variable="test",
-            generator_api_params={'model': 'gpt-4o-mini', 'generation_kwargs': {"temperature": 0.5}},
-            raise_on_failure=True)
-
+            generator_api_params={
+                "model": "gpt-4o-mini",
+                "generation_kwargs": {"temperature": 0.5},
+            },
+            raise_on_failure=True,
+        )
         extractor_dict = extractor.to_dict()
         assert extractor_dict == {
-            'type': 'haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor',
-            'init_parameters': {
-                'prompt': 'some prompt that was used with the LLM {{test}}',
-                'expected_keys': ['key1', 'key2'],
-                'raise_on_failure': True,
-                'prompt_variable': 'test',
-                'generator_api': 'openai',
-                'page_range': None,
-                'generator_api_params': {
-                    'api_base_url': None,
-                    'api_key': {'env_vars': ['OPENAI_API_KEY'], 'strict': True, 'type': 'env_var'},
-                    'generation_kwargs': {"temperature": 0.5},
-                    'model': 'gpt-4o-mini',
-                    'organization': None,
-                    'streaming_callback': None,
-                    'system_prompt': None,
-                },
-            }
+            "type": "haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor",
+            "init_parameters": {
+                "prompt": "some prompt that was used with the LLM {{document.content}}",
+                "expected_keys": ["key1", "key2"],
+                "raise_on_failure": True,
+                "generator_api": "openai",
+                "page_range": None,
+                "generator_api_params": {
+                    "api_base_url": None,
+                    "api_key": {
+                        "env_vars": ["OPENAI_API_KEY"],
+                        "strict": True,
+                        "type": "env_var",
+                    },
+                    "generation_kwargs": {"temperature": 0.5},
+                    "model": "gpt-4o-mini",
+                    "organization": None,
+                    "streaming_callback": None,
+                    "system_prompt": None,
+                },
+                "max_workers": 3,
+            },
         }
 
+    def test_to_dict_aws_bedrock(self, boto3_session_mock):
+        extractor = LLMMetadataExtractor(
+            prompt="some prompt that was used with the LLM {{document.content}}",
+            expected_keys=["key1", "key2"],
+            generator_api=LLMProvider.AWS_BEDROCK,
+            generator_api_params={
+                "model": "meta.llama.test",
+                "max_length": 100,
+                "truncate": False,
+            },
+            raise_on_failure=True,
+        )
+        extractor_dict = extractor.to_dict()
+        assert extractor_dict == {
"haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor", + "init_parameters": { + "prompt": "some prompt that was used with the LLM {{document.content}}", + "generator_api": "aws_bedrock", + "generator_api_params": { + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + }, + "aws_profile_name": { + "type": "env_var", + "env_vars": ["AWS_PROFILE"], + "strict": False, + }, + "model": "meta.llama.test", + "max_length": 100, + "truncate": False, + "streaming_callback": None, + "boto3_config": None, + }, + "expected_keys": ["key1", "key2"], + "page_range": None, + "raise_on_failure": True, + "max_workers": 3, + }, + } + + def test_from_dict_openai(self, monkeypatch): monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") extractor_dict = { - 'type': 'haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor', - 'init_parameters': { - 'prompt': 'some prompt that was used with the LLM {{test}}', - 'expected_keys': ['key1', 'key2'], - 'raise_on_failure': True, - 'prompt_variable': 'test', - 'generator_api': 'openai', - 'generator_api_params': { - 'api_base_url': None, - 'api_key': {'env_vars': ['OPENAI_API_KEY'], 'strict': True, 'type': 'env_var'}, - 'generation_kwargs': {}, - 'model': 'gpt-4o-mini', - 'organization': None, - 'streaming_callback': None, - 'system_prompt': None, - } - } + "type": "haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor", + "init_parameters": { + "prompt": "some prompt that was used with the LLM {{document.content}}", + "expected_keys": ["key1", "key2"], + "raise_on_failure": True, + "generator_api": "openai", + "generator_api_params": { + "api_base_url": None, + "api_key": { + "env_vars": ["OPENAI_API_KEY"], + "strict": True, + "type": "env_var", + }, + "generation_kwargs": {}, + "model": "gpt-4o-mini", + "organization": None, + "streaming_callback": None, + "system_prompt": None, + }, + }, } extractor = LLMMetadataExtractor.from_dict(extractor_dict) assert extractor.raise_on_failure is True assert extractor.expected_keys == ["key1", "key2"] - assert extractor.prompt == "some prompt that was used with the LLM {{test}}" + assert ( + extractor.prompt + == "some prompt that was used with the LLM {{document.content}}" + ) assert extractor.generator_api == LLMProvider.OPENAI - def test_output_invalid_json_raise_on_failure_true(self, monkeypatch): + def test_from_dict_aws_bedrock(self, boto3_session_mock): + extractor_dict = { + "type": "haystack_experimental.components.extractors.llm_metadata_extractor.LLMMetadataExtractor", + "init_parameters": { + "prompt": "some prompt that was used with the LLM {{document.content}}", + "generator_api": "aws_bedrock", + "generator_api_params": { + "aws_access_key_id": { + "type": "env_var", + "env_vars": ["AWS_ACCESS_KEY_ID"], + "strict": False, + }, + "aws_secret_access_key": { + "type": "env_var", + "env_vars": ["AWS_SECRET_ACCESS_KEY"], + "strict": False, + }, + "aws_session_token": { + "type": "env_var", + "env_vars": ["AWS_SESSION_TOKEN"], + "strict": False, + }, + "aws_region_name": { + "type": "env_var", + "env_vars": ["AWS_DEFAULT_REGION"], + "strict": False, + 
                    },
                    "aws_profile_name": {
                        "type": "env_var",
                        "env_vars": ["AWS_PROFILE"],
                        "strict": False,
                    },
                    "model": "meta.llama.test",
                    "max_length": 200,
                    "truncate": False,
                    "streaming_callback": None,
                    "boto3_config": None,
                },
                "expected_keys": ["key1", "key2"],
                "page_range": None,
                "raise_on_failure": True,
                "max_workers": 3,
            },
        }
        extractor = LLMMetadataExtractor.from_dict(extractor_dict)
        assert extractor.raise_on_failure is True
        assert extractor.expected_keys == ["key1", "key2"]
        assert (
            extractor.prompt
            == "some prompt that was used with the LLM {{document.content}}"
        )
        assert extractor.generator_api == LLMProvider.AWS_BEDROCK
        assert extractor.llm_provider.max_length == 200
        assert extractor.llm_provider.truncate is False
        assert extractor.llm_provider.model == "meta.llama.test"
 
-    def test_output_invalid_json_raise_on_failure_true(self, monkeypatch):
+    def test_warm_up(self, monkeypatch):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="prompt {{test}}",
-            expected_keys=["key1", "key2"],
-            generator_api=LLMProvider.OPENAI,
-            prompt_variable="test",
-            raise_on_failure=True
-        )
-        with pytest.raises(ValueError):
-            extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="""{"json": "output"}""")
+            prompt="prompt {{document.content}}",
+            generator_api=LLMProvider.OPENAI,
+        )
+        assert extractor.warm_up() is None
+
+    def test_extract_metadata(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{document.content}}",
+            generator_api=LLMProvider.OPENAI,
+        )
+        result = extractor._extract_metadata(llm_answer='{"output": "valid json"}')
+        assert result == {"output": "valid json"}
+
+    def test_extract_metadata_invalid_json(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{document.content}}",
+            generator_api=LLMProvider.OPENAI,
+            raise_on_failure=True,
+        )
+        with pytest.raises(ValueError):
+            extractor._extract_metadata(llm_answer='{"output: "valid json"}')
 
-    def test_output_valid_json_not_expected_keys(self, monkeypatch):
+    def test_extract_metadata_missing_key(self, monkeypatch, caplog):
         monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
         extractor = LLMMetadataExtractor(
-            prompt="prompt {{test}}",
-            expected_keys=["key1", "key2"],
-            generator_api=LLMProvider.OPENAI,
-            prompt_variable="test",
-            raise_on_failure=True
+            prompt="prompt {{document.content}}",
+            generator_api=LLMProvider.OPENAI,
+            expected_keys=["key1"],
+        )
+        extractor._extract_metadata(llm_answer='{"output": "valid json"}')
+        assert "Expected response from LLM to be a JSON with keys" in caplog.text
+
+    def test_prepare_prompts(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        extractor = LLMMetadataExtractor(
+            prompt="prompt {{document.content}}",
+            generator_api=LLMProvider.OPENAI,
+        )
+        docs = [
+            Document(
+                content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"
+            ),
+            Document(
+                content="Hugging Face is a company founded in Paris, France and is known for its Transformers library"
+            ),
+        ]
+        prompts = extractor._prepare_prompts(docs)
+        assert prompts == [
+            "prompt deepset was founded in 2018 in Berlin, and is known for its Haystack framework",
+            "prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library",
+        ]
+
+    def test_prepare_prompts_empty_document(self, monkeypatch):
+        monkeypatch.setenv("OPENAI_API_KEY", "test-api-key")
+        extractor = LLMMetadataExtractor(
prompt="prompt {{document.content}}", + generator_api=LLMProvider.OPENAI, + ) + docs = [ + Document(content=""), + Document( + content="Hugging Face is a company founded in Paris, France and is known for its Transformers library" + ), + ] + prompts = extractor._prepare_prompts(docs) + assert prompts == [ + None, + "prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library", + ] + + def test_prepare_prompts_expanded_range(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + extractor = LLMMetadataExtractor( + prompt="prompt {{document.content}}", + generator_api=LLMProvider.OPENAI, + page_range=["1-2"], + ) + docs = [ + Document( + content="Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\fPage 3" ) - with pytest.raises(ValueError): - extractor.is_valid_json_and_has_expected_keys(expected=["entities"], received="{'json': 'output'}") + ] + prompts = extractor._prepare_prompts(docs, expanded_range=[1, 2]) + assert prompts == [ + "prompt Hugging Face is a company founded in Paris, France and is known for its Transformers library\fPage 2\f", + ] + + def test_run_no_documents(self, monkeypatch): + monkeypatch.setenv("OPENAI_API_KEY", "test-api-key") + extractor = LLMMetadataExtractor( + prompt="prompt {{document.content}}", + generator_api=LLMProvider.OPENAI, + ) + result = extractor.run(documents=[]) + assert result["documents"] == [] + assert result["failed_documents"] == [] @pytest.mark.integration @pytest.mark.skipif( @@ -150,44 +356,58 @@ def test_output_valid_json_not_expected_keys(self, monkeypatch): ) def test_live_run(self): docs = [ - Document(content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework"), - Document(content="Hugging Face is a company founded in Paris, France and is known for its Transformers library") + Document( + content="deepset was founded in 2018 in Berlin, and is known for its Haystack framework" + ), + Document( + content="Hugging Face is a company founded in Paris, France and is known for its Transformers library" + ), ] - ner_prompt = """ - Given a text and a list of entity types, identify all entities of those types from the text. - - -Steps- - 1. Identify all entities. For each identified entity, extract the following information: - - entity_name: Name of the entity, capitalized - - entity_type: One of the following types: [organization, person, product, service, industry] - Format each entity as {"entity": , "entity_type": } - - 2. Return output in a single list with all the entities identified in steps 1. - - -Examples- - ###################### - Example 1: - entity_types: [organization, product, service, industry, investment strategy, market trend] - text: - Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top 10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer base and high cross-border usage. - We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. 
-        And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital issuers are equally
-        ------------------------
-        output:
-        {"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
-        #############################
-
-        -Real Data-
-        ######################
-        entity_types: [company, organization, person, country, product, service]
-        text: {{input_text}}
-        ######################
-        output:
-        """
+        ner_prompt = """-Goal-
+Given text and a list of entity types, identify all entities of those types from the text.
+
+-Steps-
+1. Identify all entities. For each identified entity, extract the following information:
+- entity_name: Name of the entity, capitalized
+- entity_type: One of the following types: [organization, product, service, industry]
+Format each entity as {"entity": <entity_name>, "entity_type": <entity_type>}
+
+2. Return output in a single list with all the entities identified in steps 1.
+
+-Examples-
+######################
+Example 1:
+entity_types: [organization, person, partnership, financial metric, product, service, industry, investment strategy, market trend]
+text: Another area of strength is our co-brand issuance. Visa is the primary network partner for eight of the top
+10 co-brand partnerships in the US today and we are pleased that Visa has finalized a multi-year extension of
+our successful credit co-branded partnership with Alaska Airlines, a portfolio that benefits from a loyal customer
+base and high cross-border usage.
+We have also had significant co-brand momentum in CEMEA. First, we launched a new co-brand card in partnership
+with Qatar Airways, British Airways and the National Bank of Kuwait. Second, we expanded our strong global
+Marriott relationship to launch Qatar's first hospitality co-branded card with Qatar Islamic Bank. Across the
+United Arab Emirates, we now have exclusive agreements with all the leading airlines marked by a recent
+agreement with Emirates Skywards.
+And we also signed an inaugural Airline co-brand agreement in Morocco with Royal Air Maroc. Now newer digital
+issuers are equally
+------------------------
+output:
+{"entities": [{"entity": "Visa", "entity_type": "company"}, {"entity": "Alaska Airlines", "entity_type": "company"}, {"entity": "Qatar Airways", "entity_type": "company"}, {"entity": "British Airways", "entity_type": "company"}, {"entity": "National Bank of Kuwait", "entity_type": "company"}, {"entity": "Marriott", "entity_type": "company"}, {"entity": "Qatar Islamic Bank", "entity_type": "company"}, {"entity": "Emirates Skywards", "entity_type": "company"}, {"entity": "Royal Air Maroc", "entity_type": "company"}]}
+#############################
+-Real Data-
+######################
+entity_types: [company, organization, person, country, product, service]
+text: {{ document.content }}
+######################
+output:
+"""
         doc_store = InMemoryDocumentStore()
         extractor = LLMMetadataExtractor(
             prompt=ner_prompt,
             expected_keys=["entities"],
             generator_api=LLMProvider.OPENAI,
         )
         writer = DocumentWriter(document_store=doc_store)
         pipeline = Pipeline()
         pipeline.add_component("extractor", extractor)
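
A minimal sketch of the retry flow the new `failed_documents` output enables, assuming `haystack-experimental` at this commit and an `OPENAI_API_KEY` in the environment. The retry prompt is illustrative, not part of the patch; it works because Jinja still sees a single undeclared variable (`document`), so it passes the component's prompt validation:

```python
from haystack import Document
from haystack_experimental.components.extractors.llm_metadata_extractor import LLMMetadataExtractor

docs = [Document(content="deepset was founded in 2018 in Berlin.")]

# First pass: any prompt that references exactly one variable, `document`.
PROMPT = 'Return JSON with the key "entities" listing companies in: {{ document.content }}'

# Hypothetical second-pass prompt: feeds the stored error and raw response back
# to the model so it can repair its own malformed output.
RETRY_PROMPT = (
    "Your previous answer failed with: {{ document.meta.metadata_extraction_error }}\n"
    "Your previous raw answer was: {{ document.meta.metadata_extraction_response }}\n"
    'Return ONLY valid JSON with the key "entities" for: {{ document.content }}'
)

extractor = LLMMetadataExtractor(prompt=PROMPT, generator_api="openai", expected_keys=["entities"])
extractor.warm_up()
result = extractor.run(documents=docs)

if result["failed_documents"]:
    retry = LLMMetadataExtractor(prompt=RETRY_PROMPT, generator_api="openai", expected_keys=["entities"])
    retry.warm_up()
    second = retry.run(documents=result["failed_documents"])
    # run() pops metadata_extraction_error/_response from documents that succeed on retry
    result["documents"] += second["documents"]
```

Because `run()` writes the error and the raw LLM reply into each failed document's metadata, the retry pass needs no extra bookkeeping: the failed documents carry everything the second prompt consumes.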