From eff53a91314b079a1be73e49c6dc8f053683fcff Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci
Date: Mon, 8 Apr 2024 15:06:26 +0200
Subject: [PATCH 1/3] feat: `HuggingFaceAPIDocumentEmbedder` (#7485)

* add HuggingFaceAPITextEmbedder

* add HuggingFaceAPITextEmbedder

* rm unneeded else

* wip

* small fixes

* deprecation; reno

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan

* make params mandatory

* changes requested

* fix test

* fix test

---------

Co-authored-by: Madeesh Kannan
---
 docs/pydoc/config/embedders_api.yml           |   1 +
 haystack/components/embedders/__init__.py     |   2 +
 .../hugging_face_api_document_embedder.py     | 263 +++++++++++++
 .../hugging_face_tei_document_embedder.py     |   7 +
 .../hfapidocembedder-4c3970d002275edb.yaml    |  13 +
 ...test_hugging_face_api_document_embedder.py | 344 ++++++++++++++++++
 6 files changed, 630 insertions(+)
 create mode 100644 haystack/components/embedders/hugging_face_api_document_embedder.py
 create mode 100644 releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml
 create mode 100644 test/components/embedders/test_hugging_face_api_document_embedder.py

diff --git a/docs/pydoc/config/embedders_api.yml b/docs/pydoc/config/embedders_api.yml
index 326d98c881..5e11a9f1f9 100644
--- a/docs/pydoc/config/embedders_api.yml
+++ b/docs/pydoc/config/embedders_api.yml
@@ -7,6 +7,7 @@ loaders:
       "azure_text_embedder",
       "hugging_face_tei_document_embedder",
       "hugging_face_tei_text_embedder",
+      "hugging_face_api_document_embedder",
       "hugging_face_api_text_embedder",
       "openai_document_embedder",
       "openai_text_embedder",
diff --git a/haystack/components/embedders/__init__.py b/haystack/components/embedders/__init__.py
index a2e3d15a4f..5e73479769 100644
--- a/haystack/components/embedders/__init__.py
+++ b/haystack/components/embedders/__init__.py
@@ -1,5 +1,6 @@
 from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder
 from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
+from haystack.components.embedders.hugging_face_api_document_embedder import HuggingFaceAPIDocumentEmbedder
 from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder
 from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
 from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
@@ -12,6 +13,7 @@
     "HuggingFaceTEITextEmbedder",
     "HuggingFaceTEIDocumentEmbedder",
     "HuggingFaceAPITextEmbedder",
+    "HuggingFaceAPIDocumentEmbedder",
     "SentenceTransformersTextEmbedder",
     "SentenceTransformersDocumentEmbedder",
     "OpenAITextEmbedder",
diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py
new file mode 100644
index 0000000000..3f8ebfba04
--- /dev/null
+++ b/haystack/components/embedders/hugging_face_api_document_embedder.py
@@ -0,0 +1,263 @@
+import json
+from typing import Any, Dict, List, Optional, Union
+
+from tqdm import tqdm
+
+from haystack import component, default_from_dict, default_to_dict, logging
+from haystack.dataclasses import Document
+from haystack.lazy_imports import LazyImport
+from haystack.utils import Secret, deserialize_secrets_inplace
+from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
+from haystack.utils.url_validation import is_valid_http_url
+
+with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
+    from huggingface_hub import InferenceClient
+
+logger = logging.getLogger(__name__)
+
+
+@component
+class HuggingFaceAPIDocumentEmbedder:
+    """
+    This component computes Document embeddings using different Hugging Face APIs:
+    - [Free Serverless Inference API](https://huggingface.co/inference-api)
+    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
+    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
+
+    Example usage with the free Serverless Inference API:
+    ```python
+    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
+    from haystack.utils import Secret
+    from haystack.dataclasses import Document
+
+    doc = Document(content="I love pizza!")
+
+    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api",
+                                                  api_params={"model": "BAAI/bge-small-en-v1.5"},
+                                                  token=Secret.from_token("<your-api-key>"))
+
+    result = doc_embedder.run([doc])
+    print(result["documents"][0].embedding)
+
+    # [0.017020374536514282, -0.023255806416273117, ...]
+    ```
+
+    Example usage with paid Inference Endpoints:
+    ```python
+    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
+    from haystack.utils import Secret
+    from haystack.dataclasses import Document
+
+    doc = Document(content="I love pizza!")
+
+    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints",
+                                                  api_params={"url": "<your-endpoint-url>"},
+                                                  token=Secret.from_token("<your-api-key>"))
+
+    result = doc_embedder.run([doc])
+    print(result["documents"][0].embedding)
+
+    # [0.017020374536514282, -0.023255806416273117, ...]
+    ```
+
+    Example usage with self-hosted Text Embeddings Inference:
+    ```python
+    from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
+    from haystack.dataclasses import Document
+
+    doc = Document(content="I love pizza!")
+
+    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference",
+                                                  api_params={"url": "http://localhost:8080"})
+
+    result = doc_embedder.run([doc])
+    print(result["documents"][0].embedding)
+
+    # [0.017020374536514282, -0.023255806416273117, ...]
+    ```
+    """
+
+    def __init__(
+        self,
+        api_type: Union[HFEmbeddingAPIType, str],
+        api_params: Dict[str, str],
+        token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
+        prefix: str = "",
+        suffix: str = "",
+        truncate: bool = True,
+        normalize: bool = False,
+        batch_size: int = 32,
+        progress_bar: bool = True,
+        meta_fields_to_embed: Optional[List[str]] = None,
+        embedding_separator: str = "\n",
+    ):
+        """
+        Create a HuggingFaceAPIDocumentEmbedder component.
+
+        :param api_type:
+            The type of Hugging Face API to use.
+        :param api_params:
+            A dictionary containing the following keys:
+            - `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
+            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
+              `TEXT_EMBEDDINGS_INFERENCE`.
+        :param token:
+            The Hugging Face token to use as HTTP bearer authorization.
+            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
+        :param prefix:
+            A string to add at the beginning of each text.
+        :param suffix:
+            A string to add at the end of each text.
+        :param truncate:
+            Truncate input text from the end to the maximum length supported by the model.
+            This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
+            It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text
+            Embeddings Inference.
+            This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `True`
+            and cannot be changed).
+        :param normalize:
+            Normalize the embeddings to unit length.
+            This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
+            It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text
+            Embeddings Inference.
+            This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `False`
+            and cannot be changed).
+        :param batch_size:
+            Number of Documents to process at once.
+        :param progress_bar:
+            If `True`, shows a progress bar when running.
+        :param meta_fields_to_embed:
+            List of meta fields that will be embedded along with the Document text.
+        :param embedding_separator:
+            Separator used to concatenate the meta fields to the Document text.
+        """
+        huggingface_hub_import.check()
+
+        if isinstance(api_type, str):
+            api_type = HFEmbeddingAPIType.from_str(api_type)
+
+        api_params = api_params or {}
+
+        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
+            model = api_params.get("model")
+            if model is None:
+                raise ValueError(
+                    "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
+                )
+            check_valid_model(model, HFModelType.EMBEDDING, token)
+            model_or_url = model
+        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
+            url = api_params.get("url")
+            if url is None:
+                raise ValueError(
+                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
+                    "parameter in `api_params`."
+                )
+            if not is_valid_http_url(url):
+                raise ValueError(f"Invalid URL: {url}")
+            model_or_url = url
+
+        self.api_type = api_type
+        self.api_params = api_params
+        self.token = token
+        self.prefix = prefix
+        self.suffix = suffix
+        self.truncate = truncate
+        self.normalize = normalize
+        self.batch_size = batch_size
+        self.progress_bar = progress_bar
+        self.meta_fields_to_embed = meta_fields_to_embed or []
+        self.embedding_separator = embedding_separator
+        self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes the component to a dictionary.
+
+        :returns:
+            Dictionary with serialized data.
+        """
+        return default_to_dict(
+            self,
+            api_type=self.api_type,
+            api_params=self.api_params,
+            prefix=self.prefix,
+            suffix=self.suffix,
+            token=self.token.to_dict() if self.token else None,
+            truncate=self.truncate,
+            normalize=self.normalize,
+            batch_size=self.batch_size,
+            progress_bar=self.progress_bar,
+            meta_fields_to_embed=self.meta_fields_to_embed,
+            embedding_separator=self.embedding_separator,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder":
+        """
+        Deserializes the component from a dictionary.
+
+        :param data:
+            Dictionary to deserialize from.
+        :returns:
+            Deserialized component.
+        """
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        return default_from_dict(cls, data)
+
+    def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
+        """
+        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
+        """
+        texts_to_embed = []
+        for doc in documents:
+            meta_values_to_embed = [
+                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
+            ]
+
+            text_to_embed = (
+                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
+            )
+
+            texts_to_embed.append(text_to_embed)
+        return texts_to_embed
+
+    def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
+        """
+        Embed a list of texts in batches.
+        """
+        all_embeddings = []
+        for i in tqdm(
+            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
+        ):
+            batch = texts_to_embed[i : i + batch_size]
+            response = self._client.post(
+                json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
+                task="feature-extraction",
+            )
+            embeddings = json.loads(response.decode())
+            all_embeddings.extend(embeddings)
+
+        return all_embeddings
+
+    @component.output_types(documents=List[Document])
+    def run(self, documents: List[Document]):
+        """
+        Embed a list of Documents.
+
+        :param documents:
+            Documents to embed.
+
+        :returns:
+            A dictionary with the following keys:
+            - `documents`: Documents with embeddings
+        """
+        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
+            raise TypeError(
+                "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
+                " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
+            )
+
+        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
+
+        embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)
+
+        for doc, emb in zip(documents, embeddings):
+            doc.embedding = emb
+
+        return {"documents": documents}
diff --git a/haystack/components/embedders/hugging_face_tei_document_embedder.py b/haystack/components/embedders/hugging_face_tei_document_embedder.py
index 9a9803e45b..e721395a15 100644
--- a/haystack/components/embedders/hugging_face_tei_document_embedder.py
+++ b/haystack/components/embedders/hugging_face_tei_document_embedder.py
@@ -1,4 +1,5 @@
 import json
+import warnings
 from typing import Any, Dict, List, Optional
 from urllib.parse import urlparse
 
@@ -91,6 +92,12 @@ def __init__(
         :param embedding_separator:
             Separator used to concatenate the meta fields to the Document text.
         """
+        warnings.warn(
+            "`HuggingFaceTEIDocumentEmbedder` is deprecated and will be removed in Haystack 2.3.0. "
+            "Use `HuggingFaceAPIDocumentEmbedder` instead.",
+            DeprecationWarning,
+        )
+
         huggingface_hub_import.check()
 
         if url:
diff --git a/releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml b/releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml
new file mode 100644
index 0000000000..8a5db68979
--- /dev/null
+++ b/releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml
@@ -0,0 +1,13 @@
+---
+features:
+  - |
+    Introduce `HuggingFaceAPIDocumentEmbedder`.
+    This component can be used to compute Document embeddings using different Hugging Face APIs:
+    - free Serverless Inference API
+    - paid Inference Endpoints
+    - self-hosted Text Embeddings Inference.
+    This embedder will replace the `HuggingFaceTEIDocumentEmbedder` in the future.
+deprecations:
+  - |
+    Deprecate `HuggingFaceTEIDocumentEmbedder`. This component will be removed in Haystack 2.3.0.
+    Use `HuggingFaceAPIDocumentEmbedder` instead.
diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py
new file mode 100644
index 0000000000..e083a59bd2
--- /dev/null
+++ b/test/components/embedders/test_hugging_face_api_document_embedder.py
@@ -0,0 +1,344 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from huggingface_hub.utils import RepositoryNotFoundError
+from numpy import array, random
+
+from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
+from haystack.dataclasses import Document
+from haystack.utils.auth import Secret
+from haystack.utils.hf import HFEmbeddingAPIType
+
+
+@pytest.fixture
+def mock_check_valid_model():
+    with patch(
+        "haystack.components.embedders.hugging_face_api_document_embedder.check_valid_model",
+        MagicMock(return_value=None),
+    ) as mock:
+        yield mock
+
+
+def mock_embedding_generation(json, **kwargs):
+    # Return one random 384-dimensional embedding per input, encoded like the API response
+    response = str(array([random.rand(384) for _ in range(len(json["inputs"]))]).tolist()).encode()
+    return response
+
+
+class TestHuggingFaceAPIDocumentEmbedder:
+    def test_init_invalid_api_type(self):
+        with pytest.raises(ValueError):
+            HuggingFaceAPIDocumentEmbedder(api_type="invalid_api_type", api_params={})
+
+    def test_init_serverless(self, mock_check_valid_model):
+        model = "BAAI/bge-small-en-v1.5"
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model}
+        )
+
+        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
+        assert embedder.api_params == {"model": model}
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
+        assert embedder.truncate
+        assert not embedder.normalize
+        assert embedder.batch_size == 32
+        assert embedder.progress_bar
+        assert embedder.meta_fields_to_embed == []
+        assert embedder.embedding_separator == "\n"
+
+    def test_init_serverless_invalid_model(self, mock_check_valid_model):
+        mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id")
+        with pytest.raises(RepositoryNotFoundError):
+            HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"}
+            )
+
+    def test_init_serverless_no_model(self):
+        with pytest.raises(ValueError):
+            HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"}
+            )
+
+    def test_init_tei(self):
+        url = "https://some_model.com"
+
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": url}
+        )
+
+        assert embedder.api_type == HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE
+        assert embedder.api_params == {"url": url}
+        assert embedder.prefix == ""
+        assert embedder.suffix == ""
+        assert embedder.truncate
+        assert not embedder.normalize
+        assert embedder.batch_size == 32
+        assert embedder.progress_bar
+        assert embedder.meta_fields_to_embed == []
+        assert embedder.embedding_separator == "\n"
+
+    def test_init_tei_invalid_url(self):
+        with pytest.raises(ValueError):
+            HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": "invalid_url"}
+            )
+
+    def test_init_tei_no_url(self):
+        with pytest.raises(ValueError):
+            HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"param": "irrelevant"}
+            )
+
+    def test_to_dict(self, mock_check_valid_model):
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "BAAI/bge-small-en-v1.5"},
+            prefix="prefix",
+            suffix="suffix",
+            truncate=False,
+            normalize=True,
+            batch_size=128,
+            progress_bar=False,
+            meta_fields_to_embed=["meta_field"],
+            embedding_separator=" ",
+        )
+
+        data = embedder.to_dict()
+
+        assert data == {
+            "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder",
+            "init_parameters": {
+                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "prefix": "prefix",
+                "suffix": "suffix",
+                "truncate": False,
+                "normalize": True,
+                "batch_size": 128,
+                "progress_bar": False,
+                "meta_fields_to_embed": ["meta_field"],
+                "embedding_separator": " ",
+            },
+        }
+
+    def test_from_dict(self, mock_check_valid_model):
+        data = {
+            "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder",
+            "init_parameters": {
+                "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                "api_params": {"model": "BAAI/bge-small-en-v1.5"},
+                "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
+                "prefix": "prefix",
+                "suffix": "suffix",
+                "truncate": False,
+                "normalize": True,
+                "batch_size": 128,
+                "progress_bar": False,
+                "meta_fields_to_embed": ["meta_field"],
+                "embedding_separator": " ",
+            },
+        }
+
+        embedder = HuggingFaceAPIDocumentEmbedder.from_dict(data)
+
+        assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API
+        assert embedder.api_params == {"model": "BAAI/bge-small-en-v1.5"}
+        assert embedder.prefix == "prefix"
+        assert embedder.suffix == "suffix"
+        assert not embedder.truncate
+        assert embedder.normalize
+        assert embedder.batch_size == 128
+        assert not embedder.progress_bar
+        assert embedder.meta_fields_to_embed == ["meta_field"]
+        assert embedder.embedding_separator == " "
+
+    def test_prepare_texts_to_embed_w_metadata(self):
+        documents = [
+            Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5)
+        ]
+
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE,
+            api_params={"url": "https://some_model.com"},
+            token=Secret.from_token("fake-api-token"),
+            meta_fields_to_embed=["meta_field"],
+            embedding_separator=" | ",
+        )
+
+        prepared_texts = embedder._prepare_texts_to_embed(documents)
+
+        assert prepared_texts == [
+            "meta_value 0 | document number 0: content",
+            "meta_value 1 | document number 1: content",
+            "meta_value 2 | document number 2: content",
+            "meta_value 3 | document number 3: content",
+            "meta_value 4 | document number 4: content",
+        ]
+
+    def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model):
+        documents = [Document(content=f"document number {i}") for i in range(5)]
+
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE,
+            api_params={"url": "https://some_model.com"},
+            token=Secret.from_token("fake-api-token"),
+            prefix="my_prefix ",
+            suffix=" my_suffix",
+        )
+
+        prepared_texts = embedder._prepare_texts_to_embed(documents)
+
+        assert prepared_texts == [
+            "my_prefix document number 0 my_suffix",
+            "my_prefix document number 1 my_suffix",
+            "my_prefix document number 2 my_suffix",
+            "my_prefix document number 3 my_suffix",
+            "my_prefix document number 4 my_suffix",
+        ]
+
+    def test_embed_batch(self, mock_check_valid_model):
+        texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]
+
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+            mock_embedding_patch.side_effect = mock_embedding_generation
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+            )
+            embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2)
+
+            assert mock_embedding_patch.call_count == 3
+
+        assert isinstance(embeddings, list)
+        assert len(embeddings) == len(texts)
+        for embedding in embeddings:
+            assert isinstance(embedding, list)
+            assert len(embedding) == 384
+            assert all(isinstance(x, float) for x in embedding)
+
+    def test_run_wrong_input_format(self, mock_check_valid_model):
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"}
+        )
+
+        list_integers_input = [1, 2, 3]
+
+        with pytest.raises(TypeError):
+            embedder.run(documents=list_integers_input)
+
+    def test_run_on_empty_list(self, mock_check_valid_model):
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "BAAI/bge-small-en-v1.5"},
+            token=Secret.from_token("fake-api-token"),
+        )
+
+        empty_list_input = []
+        result = embedder.run(documents=empty_list_input)
+
+        assert result["documents"] is not None
+        assert not result["documents"]  # empty list
+
+    def test_run(self, mock_check_valid_model):
+        docs = [
+            Document(content="I love cheese", meta={"topic": "Cuisine"}),
+            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
+        ]
+
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+            mock_embedding_patch.side_effect = mock_embedding_generation
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+                prefix="prefix ",
+                suffix=" suffix",
+                meta_fields_to_embed=["topic"],
+                embedding_separator=" | ",
+            )
+
+            result = embedder.run(documents=docs)
+
+            mock_embedding_patch.assert_called_once_with(
+                json={
+                    "inputs": [
+                        "prefix Cuisine | I love cheese suffix",
+                        "prefix ML | A transformer is a deep learning architecture suffix",
+                    ],
+                    "truncate": True,
+                    "normalize": False,
+                },
+                task="feature-extraction",
+            )
+        documents_with_embeddings = result["documents"]
+
+        assert isinstance(documents_with_embeddings, list)
+        assert len(documents_with_embeddings) == len(docs)
+        for doc in documents_with_embeddings:
+            assert isinstance(doc, Document)
+            assert isinstance(doc.embedding, list)
+            assert len(doc.embedding) == 384
+            assert all(isinstance(x, float) for x in doc.embedding)
+
+    def test_run_custom_batch_size(self, mock_check_valid_model):
+        docs = [
+            Document(content="I love cheese", meta={"topic": "Cuisine"}),
+            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
+        ]
+
+        with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch:
+            mock_embedding_patch.side_effect = mock_embedding_generation
+
+            embedder = HuggingFaceAPIDocumentEmbedder(
+                api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+                api_params={"model": "BAAI/bge-small-en-v1.5"},
+                token=Secret.from_token("fake-api-token"),
+                prefix="prefix ",
+                suffix=" suffix",
+                meta_fields_to_embed=["topic"],
+                embedding_separator=" | ",
+                batch_size=1,
+            )
+
+            result = embedder.run(documents=docs)
+
+            assert mock_embedding_patch.call_count == 2
+
+        documents_with_embeddings = result["documents"]
+
+        assert isinstance(documents_with_embeddings, list)
+        assert len(documents_with_embeddings) == len(docs)
+        for doc in documents_with_embeddings:
+            assert isinstance(doc, Document)
+            assert isinstance(doc.embedding, list)
+            assert len(doc.embedding) == 384
+            assert all(isinstance(x, float) for x in doc.embedding)
+
+    @pytest.mark.flaky(reruns=5, reruns_delay=5)
+    @pytest.mark.integration
+    def test_live_run_serverless(self):
+        docs = [
+            Document(content="I love cheese", meta={"topic": "Cuisine"}),
+            Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}),
+        ]
+
+        embedder = HuggingFaceAPIDocumentEmbedder(
+            api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API,
+            api_params={"model": "sentence-transformers/all-MiniLM-L6-v2"},
+            meta_fields_to_embed=["topic"],
+            embedding_separator=" | ",
+        )
+        result = embedder.run(documents=docs)
+        documents_with_embeddings = result["documents"]
+
+        assert isinstance(documents_with_embeddings, list)
+        assert len(documents_with_embeddings) == len(docs)
+        for doc in documents_with_embeddings:
+            assert isinstance(doc, Document)
+            assert isinstance(doc.embedding, list)
+            assert len(doc.embedding) == 384
+            assert all(isinstance(x, float) for x in doc.embedding)

From be77e85f432940ab4f0d65cceb67b59595ff9f96 Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Mon, 8 Apr 2024 16:44:45 +0200
Subject: [PATCH 2/3] Fix inaccuracy in PromptBuilder docstring (#7503)

---
 haystack/components/builders/prompt_builder.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/haystack/components/builders/prompt_builder.py b/haystack/components/builders/prompt_builder.py
index f78610b06e..64b85d76a4 100644
--- a/haystack/components/builders/prompt_builder.py
+++ b/haystack/components/builders/prompt_builder.py
@@ -9,7 +9,9 @@ class PromptBuilder:
     """
     PromptBuilder is a component that renders a prompt from a template string using Jinja2 templates.
-    The template variables found in the template string are used as input types for the component and are all required.
+
+    The template variables found in the template string are used as input types for the component and are all optional.
+    If a template variable is not provided as an input, it will be replaced with an empty string in the rendered prompt.
 
     Usage example:
     ```python
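To make the docstring fix above concrete, here is a small usage sketch, not part of the patch, showing the optional-variable behavior it describes; it assumes the public `PromptBuilder` import path from Haystack 2.x.

```python
# Illustrative sketch only, not part of this patch: a template variable that is
# not passed to run() renders as an empty string instead of being required.
from haystack.components.builders import PromptBuilder

builder = PromptBuilder(template="Answer {{ question }} using {{ context }}")
result = builder.run(question="What is the capital of France?")  # context omitted
print(result["prompt"])
# -> "Answer What is the capital of France? using "
```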
Batista" Date: Mon, 8 Apr 2024 18:37:48 +0200 Subject: [PATCH 3/3] proposal: rag evaluation results presentation (#7462) * adding files * adding proposal in md * renaming proposal number * removing stuff * cleaning up * adding PR number and issue * updating proposal * updating proposal * Update proposals/text/7462-rag-evaluation.md Co-authored-by: Madeesh Kannan * changing name * PR comments * changing output to table format * adding user stories * Update proposals/text/7462-rag-evaluation.md Co-authored-by: Madeesh Kannan * adding user stories --------- Co-authored-by: Madeesh Kannan --- proposals/text/7462-rag-evaluation.md | 223 ++++++++++++++++++++++++++ 1 file changed, 223 insertions(+) create mode 100644 proposals/text/7462-rag-evaluation.md diff --git a/proposals/text/7462-rag-evaluation.md b/proposals/text/7462-rag-evaluation.md new file mode 100644 index 0000000000..21935ce88f --- /dev/null +++ b/proposals/text/7462-rag-evaluation.md @@ -0,0 +1,223 @@ +- Title: Proposal for presentation of evaluation results +- Decision driver: David S. Batista +- Start Date: 2024-04-03 +- Proposal PR: #7462 +- Github Issue or Discussion: https://github.com/deepset-ai/haystack/issues/7398 + +# Summary + +Add a new component to Haystack allowing users interact with the results of evaluating the performance of a RAG model. + + +# Motivation + +RAG models are one of them most popular use cases for Haystack. We are adding support for evaluations metrics, but there is no way to present the results of the evaluation. + + +# Detailed design + +The output results of an evaluation pipeline composed of `evaluator` components are passed to a `EvaluationResults` +(this is a placeholder name) which stores them internally and acts as an interface to access and present the results. + +The examples below are just for illustrative purposes and are subject to change. + +Example of the data structure that the `EvaluationResults` class will receive for initialization: + +```python + +data = { + "inputs": { + "query_id": ["53c3b3e6", "225f87f7"], + "question": ["What is the capital of France?", "What is the capital of Spain?"], + "contexts": ["wiki_France", "wiki_Spain"], + "answer": ["Paris", "Madrid"], + "predicted_answer": ["Paris", "Madrid"] + }, + "metrics": + [ + {"name": "reciprocal_rank", "scores": [0.378064, 0.534964, 0.216058, 0.778642]}, + {"name": "single_hit", "scores": [1, 1, 0, 1]}, + {"name": "multi_hit", "scores": [0.706125, 0.454976, 0.445512, 0.250522]}, + {"name": "context_relevance", "scores": [0.805466, 0.410251, 0.750070, 0.361332]}, + {"name": "faithfulness", "scores": [0.135581, 0.695974, 0.749861, 0.041999]}, + {"name": "semantic_answer_similarity", "scores": [0.971241, 0.159320, 0.019722, 1]} + ], + }, + +``` + +The `EvaluationResults` class provides the following methods to different types of users: + +Basic users: +- `individual_aggregate_score_report()` +- `comparative_aggregate_score_report()` + +Intermediate users: +- `individual_detailed_score_report()` +- `comparative_detailed_score_report()` + +Advanced users: +- `find_thresholds()` +- `find_inputs_below_threshold()` + + +### Methods description +An evaluation report that provides a summary of the performance of the model across all queries, showing the +aggregated scores for all available metrics. 
+ +```python +def individual_aggregate_score_report(): +``` + +Example output + +```bash +{'Reciprocal Rank': 0.448, + 'Single Hit': 0.5, + 'Multi Hit': 0.540, + 'Context Relevance': 0.537, + 'Faithfulness': 0.452, + 'Semantic Answer Similarity': 0.478 + } + ``` + +A detailed evaluation report that provides the scores of all available metrics for all queries or a subset of queries. + +```python +def individual_detailed_score_report(queries: Union[List[str], str] = "all"): +``` + +Example output + +```bash +| question | context | answer | predicted_answer | reciprocal_rank | single_hit | multi_hit | context_relevance | faithfulness | semantic_answer_similarity | +|----------|---------|--------|------------------|-----------------|------------|-----------|-------------------|-------------|----------------------------| +| What is the capital of France? | wiki_France | Paris | Paris | 0.378064 | 1 | 0.706125 | 0.805466 | 0.135581 | 0.971241 | +| What is the capital of Spain? | wiki_Spain | Madrid | Madrid | 0.534964 | 1 | 0.454976 | 0.410251 | 0.695974 | 0.159320 | +``` + +### Comparative Evaluation Report + +A comparative summary that compares the performance of the model with another model based on the aggregated scores +for all available metrics. + +```python +def comparative_aggregate_score_report(self, other: "EvaluationResults"): +``` + +```bash +{ + "model_1": { + 'Reciprocal Rank': 0.448, + 'Single Hit': 0.5, + 'Multi Hit': 0.540, + 'Context Relevance': 0.537, + 'Faithfulness': 0.452, + 'Semantic Answer Similarity': 0.478 + }, + "model_2": { + 'Reciprocal Rank': 0.448, + 'Single Hit': 0.5, + 'Multi Hit': 0.540, + 'Context Relevance': 0.537, + 'Faithfulness': 0.452, + 'Semantic Answer Similarity': 0.478 + } +} + +``` + +A detailed comparative summary that compares the performance of the model with another model based on the scores of all +available metrics for all queries. + + +```python +def comparative_detailed_score_report(self, other: "EvaluationResults"): +``` + +```bash +| question | context | answer | predicted_answer_model_1 | predicted_answer_model_2 | reciprocal_rank_model_1 | reciprocal_rank_model_2 | single_hit_model_1 | single_hit_model_2 | multi_hit_model_1 | multi_hit_model_2 | context_relevance_model_1 | context_relevance_model_2 | faithfulness_model_1 | faithfulness_model_2 | semantic_answer_similarity_model_1 | semantic_answer_similarity_model_2 | +|----------|---------|--------|--------------------------|--------------------------|-------------------------|-------------------------|--------------------|--------------------|-------------------|-------------------|---------------------------|---------------------------|----------------------|----------------------|------------------------------------|------------------------------------| +| What is the capital of France? | wiki_France | Paris | Paris | Paris | 0.378064 | 0.378064 | 1 | 1 | 0.706125 | 0.706125 | 0.805466 | 0.805466 | 0.135581 | 0.135581 | 0.971241 | 0.971241 | +| What is the capital of Spain? | wiki_Spain | Madrid | Madrid | Madrid | 0.534964 | 0.534964 | 1 | 1 | 0.454976 | 0.454976 | 0.410251 | 0.410251 | 0.695974 | 0.695974 | 0.159320 | 0.159320 | +```` + + +Have a method to find interesting scores thresholds, typically used for error analysis, for all metrics available. +Some potentially interesting thresholds to find are: the 25th percentile, the 75th percentile, the mean , the median. 
+ +```python +def find_thresholds(self, metrics: List[str]) -> Dict[str, float]: +``` + +```bash +data = { + "thresholds": ["25th percentile", "75th percentile", "median", "average"], + "reciprocal_rank": [0.378064, 0.534964, 0.216058, 0.778642], + "context_relevance": [0.805466, 0.410251, 0.750070, 0.361332], + "faithfulness": [0.135581, 0.695974, 0.749861, 0.041999], + "semantic_answer_similarity": [0.971241, 0.159320, 0.019722, 1], +} +```` + +Then have another method that + +```python +def find_inputs_below_threshold(self, metric: str, threshold: float): + """Get the all the queries with a score below a certain threshold for a given metric""" +``` + +# Drawbacks + +- Having the output in a format table may not be flexible enough, and maybe too verbose for datasets with a large number of queries. +- Maybe the option to export to a .csv file would be better than having the output in a table format. +- Maybe a JSON format would be better with the option for advanced users to do further analysis and visualization. + + +# Adoption strategy + +- Doesn't introduce any breaking change, it is a new feature that can be adopted by users as they see fit for their use cases. + +# How we teach this + +- A tutorial would be the best approach to teach users how to use this feature. +- Adding a new entry to the documentation. + +# User stories + +### 1. I would like to get a single summary score for my RAG pipeline so I can compare several pipeline configurations. + +Run `individual_aggregate_score_report()` and get the following output: + +```bash +{'Reciprocal Rank': 0.448, + 'Single Hit': 0.5, + 'Multi Hit': 0.540, + 'Context Relevance': 0.537, + 'Faithfulness': 0.452, + 'Semantic Answer Similarity': 0.478 + } + ``` + +### 2. I am not sure what evaluation metrics work best for my RAG pipeline, specially when using the more novel LLM-based + +Use `context relevance` or `faithfulness` + +### 3. My RAG pipeline has a low aggregate score, so I would like to see examples of specific inputs where the score was low to be able to diagnose what the issue could be. + +Let's say it's a low score in `reciprocal_rank` and one already has an idea of what "low" is a query/question, then simply run: + + find_inputs_below_threshold("reciprocal_rank", ) + +If the low score is in `reciprocal_rank` one can first get thresholds for this metric using: + + `find_thresholds(["reciprocal_rank"])` + +this will give: + +- 25th percentile: (Q1) the value below which 25% of the data falls. +- median percentile: (Q2) the value below which 50% of the data falls. +- 75th percentile: (Q3) the value below which 75% of the data falls. + +this can help to decide what is considered a low score, and then get, for instance, queries with a score below +the Q2 threshold using `find_inputs_below_threshold("context_relevance", threshold)`
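To make the advanced-user workflow in the proposal concrete, here is a minimal, illustrative sketch of how the two threshold methods could be implemented. It is not part of the patch: the class body, the `_scores` helper, and the use of the standard-library `statistics` module are assumptions of this sketch, and the nested-dict return shape of `find_thresholds` is a simplification of the tabular output shown above. It follows the `data` layout from the initialization example.

```python
# Illustrative sketch only: one possible implementation of the two
# advanced-user methods, assuming the `data` layout from the proposal.
from statistics import mean, quantiles
from typing import Dict, List


class EvaluationResults:  # placeholder name, as in the proposal
    def __init__(self, data: dict):
        self.data = data

    def _scores(self, metric: str) -> List[float]:
        # Look up the per-query scores stored for a given metric
        for entry in self.data["metrics"]:
            if entry["name"] == metric:
                return entry["scores"]
        raise ValueError(f"Unknown metric: {metric}")

    def find_thresholds(self, metrics: List[str]) -> Dict[str, Dict[str, float]]:
        """Compute the 25th/75th percentiles, median, and average for each metric."""
        thresholds = {}
        for metric in metrics:
            scores = self._scores(metric)
            q1, q2, q3 = quantiles(scores, n=4)  # quartile cut points
            thresholds[metric] = {
                "25th percentile": q1,
                "median": q2,
                "75th percentile": q3,
                "average": mean(scores),
            }
        return thresholds

    def find_inputs_below_threshold(self, metric: str, threshold: float) -> List[str]:
        """Get all the queries with a score below a certain threshold for a given metric."""
        scores = self._scores(metric)
        questions = self.data["inputs"]["question"]
        return [q for q, s in zip(questions, scores) if s < threshold]
```

Under these assumptions, the error-analysis flow of user story 3 becomes two calls: `thresholds = results.find_thresholds(["reciprocal_rank"])` followed by `results.find_inputs_below_threshold("reciprocal_rank", thresholds["reciprocal_rank"]["median"])`.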