Skip to content

Commit

Permalink
feat: HuggingFaceAPIDocumentEmbedder (#7485)
Browse files Browse the repository at this point in the history
* add HuggingFaceAPITextEmbedder

* add HuggingFaceAPITextEmbedder

* rm unneeded else

* wip

* small fixes

* deprecation; reno

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <[email protected]>

* make params mandatory

* changes requested

* fix test

* fix test

---------

Co-authored-by: Madeesh Kannan <[email protected]>
  • Loading branch information
anakin87 and shadeMe authored Apr 8, 2024
1 parent c91bd49 commit eff53a9
Show file tree
Hide file tree
Showing 6 changed files with 630 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/pydoc/config/embedders_api.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ loaders:
"azure_text_embedder",
"hugging_face_tei_document_embedder",
"hugging_face_tei_text_embedder",
"hugging_face_api_document_embedder",
"hugging_face_api_text_embedder",
"openai_document_embedder",
"openai_text_embedder",
Expand Down
2 changes: 2 additions & 0 deletions haystack/components/embedders/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from haystack.components.embedders.azure_document_embedder import AzureOpenAIDocumentEmbedder
from haystack.components.embedders.azure_text_embedder import AzureOpenAITextEmbedder
from haystack.components.embedders.hugging_face_api_document_embedder import HuggingFaceAPIDocumentEmbedder
from haystack.components.embedders.hugging_face_api_text_embedder import HuggingFaceAPITextEmbedder
from haystack.components.embedders.hugging_face_tei_document_embedder import HuggingFaceTEIDocumentEmbedder
from haystack.components.embedders.hugging_face_tei_text_embedder import HuggingFaceTEITextEmbedder
Expand All @@ -12,6 +13,7 @@
"HuggingFaceTEITextEmbedder",
"HuggingFaceTEIDocumentEmbedder",
"HuggingFaceAPITextEmbedder",
"HuggingFaceAPIDocumentEmbedder",
"SentenceTransformersTextEmbedder",
"SentenceTransformersDocumentEmbedder",
"OpenAITextEmbedder",
Expand Down
263 changes: 263 additions & 0 deletions haystack/components/embedders/hugging_face_api_document_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
import json
from typing import Any, Dict, List, Optional, Union

from tqdm import tqdm

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.dataclasses import Document
from haystack.lazy_imports import LazyImport
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack.utils.hf import HFEmbeddingAPIType, HFModelType, check_valid_model
from haystack.utils.url_validation import is_valid_http_url

with LazyImport(message="Run 'pip install \"huggingface_hub>=0.22.0\"'") as huggingface_hub_import:
from huggingface_hub import InferenceClient

logger = logging.getLogger(__name__)


@component
class HuggingFaceAPIDocumentEmbedder:
"""
This component can be used to compute Document embeddings using different Hugging Face APIs:
- [Free Serverless Inference API]((https://huggingface.co/inference-api)
- [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
- [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)
Example usage with the free Serverless Inference API:
```python
from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
from haystack.utils import Secret
from haystack.dataclasses import Document
doc = Document(content="I love pizza!")
doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api",
api_params={"model": "BAAI/bge-small-en-v1.5"},
token=Secret.from_token("<your-api-key>"))
result = document_embedder.run([doc])
print(result["documents"][0].embedding)
# [0.017020374536514282, -0.023255806416273117, ...]
```
Example usage with paid Inference Endpoints:
```python
from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
from haystack.utils import Secret
from haystack.dataclasses import Document
doc = Document(content="I love pizza!")
doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints",
api_params={"url": "<your-inference-endpoint-url>"},
token=Secret.from_token("<your-api-key>"))
result = document_embedder.run([doc])
print(result["documents"][0].embedding)
# [0.017020374536514282, -0.023255806416273117, ...]
```
Example usage with self-hosted Text Embeddings Inference:
```python
from haystack.components.embedders import HuggingFaceAPIDocumentEmbedder
from haystack.dataclasses import Document
doc = Document(content="I love pizza!")
doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference",
api_params={"url": "http://localhost:8080"})
result = document_embedder.run([doc])
print(result["documents"][0].embedding)
# [0.017020374536514282, -0.023255806416273117, ...]
```
"""

def __init__(
self,
api_type: Union[HFEmbeddingAPIType, str],
api_params: Dict[str, str],
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
prefix: str = "",
suffix: str = "",
truncate: bool = True,
normalize: bool = False,
batch_size: int = 32,
progress_bar: bool = True,
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Create an HuggingFaceAPITextEmbedder component.
:param api_type:
The type of Hugging Face API to use.
:param api_params:
A dictionary containing the following keys:
- `model`: model ID on the Hugging Face Hub. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
- `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or `TEXT_EMBEDDINGS_INFERENCE`.
:param token: The HuggingFace token to use as HTTP bearer authorization.
You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
:param prefix:
A string to add at the beginning of each text.
:param suffix:
A string to add at the end of each text.
:param truncate:
Truncate input text from the end to the maximum length supported by the model.
This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference.
This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `True` and cannot be changed).
:param normalize:
Normalize the embeddings to unit length.
This parameter takes effect when the `api_type` is `TEXT_EMBEDDINGS_INFERENCE`.
It also takes effect when the `api_type` is `INFERENCE_ENDPOINTS` and the backend is based on Text Embeddings Inference.
This parameter is ignored when the `api_type` is `SERVERLESS_INFERENCE_API` (it is always set to `False` and cannot be changed).
:param batch_size:
Number of Documents to process at once.
:param progress_bar:
If `True` shows a progress bar when running.
:param meta_fields_to_embed:
List of meta fields that will be embedded along with the Document text.
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
"""
huggingface_hub_import.check()

if isinstance(api_type, str):
api_type = HFEmbeddingAPIType.from_str(api_type)

api_params = api_params or {}

if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
model = api_params.get("model")
if model is None:
raise ValueError(
"To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
)
check_valid_model(model, HFModelType.EMBEDDING, token)
model_or_url = model
elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
url = api_params.get("url")
if url is None:
raise ValueError(
"To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` parameter in `api_params`."
)
if not is_valid_http_url(url):
raise ValueError(f"Invalid URL: {url}")
model_or_url = url

self.api_type = api_type
self.api_params = api_params
self.token = token
self.prefix = prefix
self.suffix = suffix
self.truncate = truncate
self.normalize = normalize
self.batch_size = batch_size
self.progress_bar = progress_bar
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self._client = InferenceClient(model_or_url, token=token.resolve_value() if token else None)

def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
api_type=self.api_type,
api_params=self.api_params,
prefix=self.prefix,
suffix=self.suffix,
token=self.token.to_dict() if self.token else None,
truncate=self.truncate,
normalize=self.normalize,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder":
"""
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
"""
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
]

text_to_embed = (
self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
)

texts_to_embed.append(text_to_embed)
return texts_to_embed

def _embed_batch(self, texts_to_embed: List[str], batch_size: int) -> List[List[float]]:
"""
Embed a list of texts in batches.
"""

all_embeddings = []
for i in tqdm(
range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
):
batch = texts_to_embed[i : i + batch_size]
response = self._client.post(
json={"inputs": batch, "truncate": self.truncate, "normalize": self.normalize},
task="feature-extraction",
)
embeddings = json.loads(response.decode())
all_embeddings.extend(embeddings)

return all_embeddings

@component.output_types(documents=List[Document])
def run(self, documents: List[Document]):
"""
Embed a list of Documents.
:param documents:
Documents to embed.
:returns:
A dictionary with the following keys:
- `documents`: Documents with embeddings
"""
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
raise TypeError(
"HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
" In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
)

texts_to_embed = self._prepare_texts_to_embed(documents=documents)

embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

for doc, emb in zip(documents, embeddings):
doc.embedding = emb

return {"documents": documents}
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import warnings
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

Expand Down Expand Up @@ -91,6 +92,12 @@ def __init__(
:param embedding_separator:
Separator used to concatenate the meta fields to the Document text.
"""
warnings.warn(
"`HuggingFaceTEIDocumentEmbedder` is deprecated and will be removed in Haystack 2.3.0."
"Use `HuggingFaceAPIDocumentEmbedder` instead.",
DeprecationWarning,
)

huggingface_hub_import.check()

if url:
Expand Down
13 changes: 13 additions & 0 deletions releasenotes/notes/hfapidocembedder-4c3970d002275edb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
---
features:
- |
Introduce `HuggingFaceAPIDocumentEmbedder`.
This component can be used to compute Document embeddings using different Hugging Face APIs:
- free Serverless Inference API
- paid Inference Endpoints
- self-hosted Text Embeddings Inference.
This embedder will replace the `HuggingFaceTEIDocumentEmbedder` in the future.
deprecations:
- |
Deprecate `HuggingFaceTEIDocumentEmbedder`. This component will be removed in Haystack 2.3.0.
Use `HuggingFaceAPIDocumentEmbedder` instead.
Loading

0 comments on commit eff53a9

Please sign in to comment.