diff --git a/docs/pydoc/config/embedders_api.yml b/docs/pydoc/config/embedders_api.yml index 326d98c881..5e11a9f1f9 100644 --- a/docs/pydoc/config/embedders_api.yml +++ b/docs/pydoc/config/embedders_api.yml @@ -7,6 +7,7 @@ loaders: "azure_text_embedder", "hugging_face_tei_document_embedder", "hugging_face_tei_text_embedder", + "hugging_face_api_document_embedder", "hugging_face_api_text_embedder", "openai_document_embedder", "openai_text_embedder", diff --git a/haystack/components/embedders/__init__.py b/haystack/components/embedders/__init__.py index a2e3d15a4f..5e73479769 100644 --- a/haystack/components/embedders/__init__.py +++ b/haystack/components/embedders/__init__.py @@ -1,5 +1,6 @@ from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder +from haystack.components.embedders.hugging_face_api_document_embedder import HuggingFaceAPIDocumentEmbedder from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder @@ -12,6 +13,7 @@ "HuggingFaceTEITextEmbedder", "HuggingFaceTEIDocumentEmbedder", "HuggingFaceAPITextEmbedder", + "HuggingFaceAPIDocumentEmbedder", "SentenceTransformersTextEmbedder", "SentenceTransformersDocumentEmbedder", "OpenAITextEmbedder", diff --git a/haystack/components/embedders/hugging_face_api_document_embedder.py b/haystack/components/embedders/hugging_face_api_document_embedder.py new file mode 100644 index 0000000000..3f8ebfba04 --- /dev/null +++ b/haystack/components/embedders/hugging_face_api_document_embedder.py @@ -0,0 +1,263 @@ +import json +from typing import Any, Dict, List, Optional, Union + +from tqdm import tqdm + +from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import Document +from haystack.lazy_imports import LazyImport +from haystack.utils import Secret, deserialize_secrets_inplace +from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model +from haystack.utils.url_validation import is_valid_http_url + +with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import: + from huggingface_hub import InferenceClient + +logger = logging.getLogger(__name__) + + +@component +class HuggingFaceAPIDocumentEmbedder: + """ + This component can be used to compute Document embeddings using different Hugging Face APIs: + - [Free Serverless Inference API]((https://huggingface.co/inference-api) + - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints) + - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference) + + + Example usage with the free Serverless Inference API: + ```python + from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder + from haystack.utils import Secret + from haystack.dataclasses import Document + + doc = Document(content="I love pizza!") + + doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api", + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("")) + + result = document_embedder.run([doc]) + print(result["documents"][0].embedding) + + # [0.017020374536514282, -0.023255806416273117, ...] + ``` + + Example usage with paid Inference Endpoints: + ```python + from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder + from haystack.utils import Secret + from haystack.dataclasses import Document + + doc = Document(content="I love pizza!") + + doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints", + api_params={"url": ""}, + token=Secret.from_token("")) + + result = document_embedder.run([doc]) + print(result["documents"][0].embedding) + + # [0.017020374536514282, -0.023255806416273117, ...] + ``` + + Example usage with self-hosted Text Embeddings Inference: + ```python + from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder + from haystack.dataclasses import Document + + doc = Document(content="I love pizza!") + + doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference", + api_params={"url": "http://localhost:8080"}) + + result = document_embedder.run([doc]) + print(result["documents"][0].embedding) + + # [0.017020374536514282, -0.023255806416273117, ...] + ``` + """ + + def __init__( + self, + api_type: Union[HFEmbeddingAPIType, str], + api_params: Dict[str, str], + token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False), + prefix: str = "", + suffix: str = "", + truncate: bool = True, + normalize: bool = False, + batch_size: int = 32, + progress_bar: bool = True, + meta_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + ): + """ + Create an HuggingFaceAPITextEmbedder component. + + :param api_type: + The type of Hugging Face API to use. + :param api_params: + A dictionary containing the following keys: + - `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`. + - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_EMBEDDINGS_INFERENCE`. + :param token: The HuggingFace token to use as HTTP bearer authorization. + You can find your HF token in your [account settings](https://huggingface.co/settings/tokens). + :param prefix: + A string to add at the beginning of each text. + :param suffix: + A string to add at the end of each text. + :param truncate: + Truncate input text from the end to the maximum length supported by the model. + This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`. + It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference. + This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `True` and cannot be changed). + :param normalize: + Normalize the embeddings to unit length. + This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`. + It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference. + This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `False` and cannot be changed). + :param batch_size: + Number of Documents to process at once. + :param progress_bar: + If `True` shows a progress bar when running. + :param meta_fields_to_embed: + List of meta fields that will be embedded along with the Document text. + :param embedding_separator: + Separator used to concatenate the meta fields to the Document text. + """ + huggingface_hub_import.check() + + if isinstance(api_type, str): + api_type = HFEmbeddingAPIType.from_str(api_type) + + api_params = api_params or {} + + if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API: + model = api_params.get("model") + if model is None: + raise ValueError( + "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`." + ) + check_valid_model(model, HFModelType.EMBEDDING, token) + model_or_url = model + elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]: + url = api_params.get("url") + if url is None: + raise ValueError( + "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` parameter in `api_params`." + ) + if not is_valid_http_url(url): + raise ValueError(f"Invalid URL: {url}") + model_or_url = url + + self.api_type = api_type + self.api_params = api_params + self.token = token + self.prefix = prefix + self.suffix = suffix + self.truncate = truncate + self.normalize = normalize + self.batch_size = batch_size + self.progress_bar = progress_bar + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes the component to a dictionary. + + :returns: + Dictionary with serialized data. + """ + return default_to_dict( + self, + api_type=self.api_type, + api_params=self.api_params, + prefix=self.prefix, + suffix=self.suffix, + token=self.token.to_dict() if self.token else None, + truncate=self.truncate, + normalize=self.normalize, + batch_size=self.batch_size, + progress_bar=self.progress_bar, + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder": + """ + Deserializes the component from a dictionary. + + :param data: + Dictionary to deserialize from. + :returns: + Deserialized component. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + return default_from_dict(cls, data) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None + ] + + text_to_embed = ( + self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix + ) + + texts_to_embed.append(text_to_embed) + return texts_to_embed + + def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]: + """ + Embed a list of texts in batches. + """ + + all_embeddings = [] + for i in tqdm( + range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings" + ): + batch = texts_to_embed[i : i + batch_size] + response = self._client.post( + json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize}, + task="feature-extraction", + ) + embeddings = json.loads(response.decode()) + all_embeddings.extend(embeddings) + + return all_embeddings + + @component.output_types(documents=List[Document]) + def run(self, documents: List[Document]): + """ + Embed a list of Documents. + + :param documents: + Documents to embed. + + :returns: + A dictionary with the following keys: + - `documents`: Documents with embeddings + """ + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + raise TypeError( + "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input." + " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder." + ) + + texts_to_embed = self._prepare_texts_to_embed(documents=documents) + + embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size) + + for doc, emb in zip(documents, embeddings): + doc.embedding = emb + + return {"documents": documents} diff --git a/haystack/components/embedders/hugging_face_tei_document_embedder.py b/haystack/components/embedders/hugging_face_tei_document_embedder.py index 9a9803e45b..e721395a15 100644 --- a/haystack/components/embedders/hugging_face_tei_document_embedder.py +++ b/haystack/components/embedders/hugging_face_tei_document_embedder.py @@ -1,4 +1,5 @@ import json +import warnings from typing import Any, Dict, List, Optional from urllib.parse import urlparse @@ -91,6 +92,12 @@ def __init__( :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ + warnings.warn( + "`HuggingFaceTEIDocumentEmbedder` is deprecated and will be removed in Haystack 2.3.0." + "Use `HuggingFaceAPIDocumentEmbedder` instead.", + DeprecationWarning, + ) + huggingface_hub_import.check() if url: diff --git a/releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml b/releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml new file mode 100644 index 0000000000..8a5db68979 --- /dev/null +++ b/releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml @@ -0,0 +1,13 @@ +--- +features: + - | + Introduce `HuggingFaceAPIDocumentEmbedder`. + This component can be used to compute Document embeddings using different Hugging Face APIs: + - free Serverless Inference API + - paid Inference Endpoints + - self-hosted Text Embeddings Inference. + This embedder will replace the `HuggingFaceTEIDocumentEmbedder` in the future. +deprecations: + - | + Deprecate `HuggingFaceTEIDocumentEmbedder`. This component will be removed in Haystack 2.3.0. + Use `HuggingFaceAPIDocumentEmbedder` instead. diff --git a/test/components/embedders/test_hugging_face_api_document_embedder.py b/test/components/embedders/test_hugging_face_api_document_embedder.py new file mode 100644 index 0000000000..e083a59bd2 --- /dev/null +++ b/test/components/embedders/test_hugging_face_api_document_embedder.py @@ -0,0 +1,344 @@ +from unittest.mock import MagicMock, patch + +import pytest +from huggingface_hub.utils import RepositoryNotFoundError +from numpy import array, random + +from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder +from haystack.dataclasses import Document +from haystack.utils.auth import Secret +from haystack.utils.hf import HFEmbeddingAPIType + + +@pytest.fixture +def mock_check_valid_model(): + with patch( + "haystack.components.embedders.hugging_face_api_document_embedder.check_valid_model", + MagicMock(return_value=None), + ) as mock: + yield mock + + +def mock_embedding_generation(json, **kwargs): + response = str(array([random.rand(384) for i in range(len(json["inputs"]))]).tolist()).encode() + return response + + +class TestHuggingFaceAPIDocumentEmbedder: + def test_init_invalid_api_type(self): + with pytest.raises(ValueError): + HuggingFaceAPIDocumentEmbedder(api_type="invalid_api_type", api_params={}) + + def test_init_serverless(self, mock_check_valid_model): + model = "BAAI/bge-small-en-v1.5" + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": model} + ) + + assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API + assert embedder.api_params == {"model": model} + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.truncate + assert not embedder.normalize + assert embedder.batch_size == 32 + assert embedder.progress_bar + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + def test_init_serverless_invalid_model(self, mock_check_valid_model): + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"} + ) + + def test_init_serverless_no_model(self): + with pytest.raises(ValueError): + HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"} + ) + + def test_init_tei(self): + url = "https://some_model.com" + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": url} + ) + + assert embedder.api_type == HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE + assert embedder.api_params == {"url": url} + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.truncate + assert not embedder.normalize + assert embedder.batch_size == 32 + assert embedder.progress_bar + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + def test_init_tei_invalid_url(self): + with pytest.raises(ValueError): + HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"url": "invalid_url"} + ) + + def test_init_tei_no_url(self): + with pytest.raises(ValueError): + HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, api_params={"param": "irrelevant"} + ) + + def test_to_dict(self, mock_check_valid_model): + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + prefix="prefix", + suffix="suffix", + truncate=False, + normalize=True, + batch_size=128, + progress_bar=False, + meta_fields_to_embed=["meta_field"], + embedding_separator=" ", + ) + + data = embedder.to_dict() + + assert data == { + "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder", + "init_parameters": { + "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + "api_params": {"model": "BAAI/bge-small-en-v1.5"}, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "prefix": "prefix", + "suffix": "suffix", + "truncate": False, + "normalize": True, + "batch_size": 128, + "progress_bar": False, + "meta_fields_to_embed": ["meta_field"], + "embedding_separator": " ", + }, + } + + def test_from_dict(self, mock_check_valid_model): + data = { + "type": "haystack.components.embedders.hugging_face_api_document_embedder.HuggingFaceAPIDocumentEmbedder", + "init_parameters": { + "api_type": HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + "api_params": {"model": "BAAI/bge-small-en-v1.5"}, + "token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"}, + "prefix": "prefix", + "suffix": "suffix", + "truncate": False, + "normalize": True, + "batch_size": 128, + "progress_bar": False, + "meta_fields_to_embed": ["meta_field"], + "embedding_separator": " ", + }, + } + + embedder = HuggingFaceAPIDocumentEmbedder.from_dict(data) + + assert embedder.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API + assert embedder.api_params == {"model": "BAAI/bge-small-en-v1.5"} + assert embedder.prefix == "prefix" + assert embedder.suffix == "suffix" + assert not embedder.truncate + assert embedder.normalize + assert embedder.batch_size == 128 + assert not embedder.progress_bar + assert embedder.meta_fields_to_embed == ["meta_field"] + assert embedder.embedding_separator == " " + + def test_prepare_texts_to_embed_w_metadata(self): + documents = [ + Document(content=f"document number {i}: content", meta={"meta_field": f"meta_value {i}"}) for i in range(5) + ] + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, + api_params={"url": "https://some_model.com"}, + token=Secret.from_token("fake-api-token"), + meta_fields_to_embed=["meta_field"], + embedding_separator=" | ", + ) + + prepared_texts = embedder._prepare_texts_to_embed(documents) + + assert prepared_texts == [ + "meta_value 0 | document number 0: content", + "meta_value 1 | document number 1: content", + "meta_value 2 | document number 2: content", + "meta_value 3 | document number 3: content", + "meta_value 4 | document number 4: content", + ] + + def test_prepare_texts_to_embed_w_suffix(self, mock_check_valid_model): + documents = [Document(content=f"document number {i}") for i in range(5)] + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE, + api_params={"url": "https://some_model.com"}, + token=Secret.from_token("fake-api-token"), + prefix="my_prefix ", + suffix=" my_suffix", + ) + + prepared_texts = embedder._prepare_texts_to_embed(documents) + + assert prepared_texts == [ + "my_prefix document number 0 my_suffix", + "my_prefix document number 1 my_suffix", + "my_prefix document number 2 my_suffix", + "my_prefix document number 3 my_suffix", + "my_prefix document number 4 my_suffix", + ] + + def test_embed_batch(self, mock_check_valid_model): + texts = ["text 1", "text 2", "text 3", "text 4", "text 5"] + + with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + mock_embedding_patch.side_effect = mock_embedding_generation + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + ) + embeddings = embedder._embed_batch(texts_to_embed=texts, batch_size=2) + + assert mock_embedding_patch.call_count == 3 + + assert isinstance(embeddings, list) + assert len(embeddings) == len(texts) + for embedding in embeddings: + assert isinstance(embedding, list) + assert len(embedding) == 384 + assert all(isinstance(x, float) for x in embedding) + + def test_run_wrong_input_format(self, mock_check_valid_model): + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "BAAI/bge-small-en-v1.5"} + ) + + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError): + embedder.run(text=list_integers_input) + + def test_run_on_empty_list(self, mock_check_valid_model): + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + ) + + empty_list_input = [] + result = embedder.run(documents=empty_list_input) + + assert result["documents"] is not None + assert not result["documents"] # empty list + + def test_run(self, mock_check_valid_model): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + mock_embedding_patch.side_effect = mock_embedding_generation + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + prefix="prefix ", + suffix=" suffix", + meta_fields_to_embed=["topic"], + embedding_separator=" | ", + ) + + result = embedder.run(documents=docs) + + mock_embedding_patch.assert_called_once_with( + json={ + "inputs": [ + "prefix Cuisine | I love cheese suffix", + "prefix ML | A transformer is a deep learning architecture suffix", + ], + "truncate": True, + "normalize": False, + }, + task="feature-extraction", + ) + documents_with_embeddings = result["documents"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 384 + assert all(isinstance(x, float) for x in doc.embedding) + + def test_run_custom_batch_size(self, mock_check_valid_model): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + with patch("huggingface_hub.InferenceClient.post") as mock_embedding_patch: + mock_embedding_patch.side_effect = mock_embedding_generation + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "BAAI/bge-small-en-v1.5"}, + token=Secret.from_token("fake-api-token"), + prefix="prefix ", + suffix=" suffix", + meta_fields_to_embed=["topic"], + embedding_separator=" | ", + batch_size=1, + ) + + result = embedder.run(documents=docs) + + assert mock_embedding_patch.call_count == 2 + + documents_with_embeddings = result["documents"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 384 + assert all(isinstance(x, float) for x in doc.embedding) + + @pytest.mark.flaky(reruns=5, reruns_delay=5) + @pytest.mark.integration + def test_live_run_serverless(self): + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + embedder = HuggingFaceAPIDocumentEmbedder( + api_type=HFEmbeddingAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "sentence-transformers/all-MiniLM-L6-v2"}, + meta_fields_to_embed=["topic"], + embedding_separator=" | ", + ) + result = embedder.run(documents=docs) + documents_with_embeddings = result["documents"] + + assert isinstance(documents_with_embeddings, list) + assert len(documents_with_embeddings) == len(docs) + for doc in documents_with_embeddings: + assert isinstance(doc, Document) + assert isinstance(doc.embedding, list) + assert len(doc.embedding) == 384 + assert all(isinstance(x, float) for x in doc.embedding)