-
Notifications
You must be signed in to change notification settings - Fork 126
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[cohere] Add text and document embedders (#80)
* Add text and document embedders ------ Co-authored-by: vrunm <[email protected]> * refactoring * linting * add pytest markers * more * fix api url management * fix integration tests * fix integrations tests for good * final cleanup * review feedback --------- Co-authored-by: vrunm <[email protected]>
- Loading branch information
Showing
10 changed files
with
634 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 3 additions & 0 deletions
3
integrations/cohere/src/cohere_haystack/embedders/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 |
153 changes: 153 additions & 0 deletions
153
integrations/cohere/src/cohere_haystack/embedders/document_embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import asyncio | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
from cohere import COHERE_API_URL, AsyncClient, Client | ||
from haystack import Document, component, default_to_dict | ||
|
||
from cohere_haystack.embedders.utils import get_async_response, get_response | ||
|
||
|
||
@component | ||
class CohereDocumentEmbedder: | ||
""" | ||
A component for computing Document embeddings using Cohere models. | ||
The embedding of each Document is stored in the `embedding` field of the Document. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
api_key: Optional[str] = None, | ||
model_name: str = "embed-english-v2.0", | ||
api_base_url: str = COHERE_API_URL, | ||
truncate: str = "END", | ||
use_async_client: bool = False, | ||
max_retries: int = 3, | ||
timeout: int = 120, | ||
batch_size: int = 32, | ||
progress_bar: bool = True, | ||
metadata_fields_to_embed: Optional[List[str]] = None, | ||
embedding_separator: str = "\n", | ||
): | ||
""" | ||
Create a CohereDocumentEmbedder component. | ||
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment | ||
variable COHERE_API_KEY (recommended). | ||
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are | ||
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, | ||
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. | ||
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. | ||
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to | ||
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both | ||
cases, input is discarded until the remaining input is exactly the maximum input token length for the model. | ||
If NONE is selected, when the input exceeds the maximum input token length an error will be returned. | ||
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use | ||
AsyncClient for applications with many concurrent calls. | ||
:param max_retries: maximal number of retries for requests, defaults to `3`. | ||
:param timeout: request timeout in seconds, defaults to `120`. | ||
:param batch_size: Number of Documents to encode at once. | ||
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments | ||
to keep the logs clean. | ||
:param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text. | ||
:param embedding_separator: Separator used to concatenate the meta fields to the Document text. | ||
""" | ||
|
||
if api_key is None: | ||
try: | ||
api_key = os.environ["COHERE_API_KEY"] | ||
except KeyError as error_msg: | ||
msg = ( | ||
"CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment " | ||
"variable COHERE_API_KEY (recommended) or by passing it explicitly." | ||
) | ||
raise ValueError(msg) from error_msg | ||
|
||
self.api_key = api_key | ||
self.model_name = model_name | ||
self.api_base_url = api_base_url | ||
self.truncate = truncate | ||
self.use_async_client = use_async_client | ||
self.max_retries = max_retries | ||
self.timeout = timeout | ||
self.batch_size = batch_size | ||
self.progress_bar = progress_bar | ||
self.metadata_fields_to_embed = metadata_fields_to_embed or [] | ||
self.embedding_separator = embedding_separator | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serialize this component to a dictionary omitting the api_key field. | ||
""" | ||
return default_to_dict( | ||
self, | ||
model_name=self.model_name, | ||
api_base_url=self.api_base_url, | ||
truncate=self.truncate, | ||
use_async_client=self.use_async_client, | ||
max_retries=self.max_retries, | ||
timeout=self.timeout, | ||
batch_size=self.batch_size, | ||
progress_bar=self.progress_bar, | ||
metadata_fields_to_embed=self.metadata_fields_to_embed, | ||
embedding_separator=self.embedding_separator, | ||
) | ||
|
||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: | ||
""" | ||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. | ||
""" | ||
texts_to_embed: List[str] = [] | ||
for doc in documents: | ||
meta_values_to_embed = [ | ||
str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None | ||
] | ||
|
||
text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005 | ||
texts_to_embed.append(text_to_embed) | ||
return texts_to_embed | ||
|
||
@component.output_types(documents=List[Document], metadata=Dict[str, Any]) | ||
def run(self, documents: List[Document]): | ||
""" | ||
Embed a list of Documents. | ||
The embedding of each Document is stored in the `embedding` field of the Document. | ||
:param documents: A list of Documents to embed. | ||
""" | ||
|
||
if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): | ||
msg = ( | ||
"CohereDocumentEmbedder expects a list of Documents as input." | ||
"In case you want to embed a string, please use the CohereTextEmbedder." | ||
) | ||
raise TypeError(msg) | ||
|
||
if not documents: | ||
# return early if we were passed an empty list | ||
return {"documents": [], "metadata": {}} | ||
|
||
texts_to_embed = self._prepare_texts_to_embed(documents) | ||
|
||
if self.use_async_client: | ||
cohere_client = AsyncClient( | ||
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout | ||
) | ||
all_embeddings, metadata = asyncio.run( | ||
get_async_response(cohere_client, texts_to_embed, self.model_name, self.truncate) | ||
) | ||
else: | ||
cohere_client = Client( | ||
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout | ||
) | ||
all_embeddings, metadata = get_response( | ||
cohere_client, texts_to_embed, self.model_name, self.truncate, self.batch_size, self.progress_bar | ||
) | ||
|
||
for doc, embeddings in zip(documents, all_embeddings): | ||
doc.embedding = embeddings | ||
|
||
return {"documents": documents, "metadata": metadata} |
104 changes: 104 additions & 0 deletions
104
integrations/cohere/src/cohere_haystack/embedders/text_embedder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]> | ||
# | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import asyncio | ||
import os | ||
from typing import Any, Dict, List, Optional | ||
|
||
from cohere import COHERE_API_URL, AsyncClient, Client | ||
from haystack import component, default_to_dict | ||
|
||
from cohere_haystack.embedders.utils import get_async_response, get_response | ||
|
||
|
||
@component | ||
class CohereTextEmbedder: | ||
""" | ||
A component for embedding strings using Cohere models. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
api_key: Optional[str] = None, | ||
model_name: str = "embed-english-v2.0", | ||
api_base_url: str = COHERE_API_URL, | ||
truncate: str = "END", | ||
use_async_client: bool = False, | ||
max_retries: int = 3, | ||
timeout: int = 120, | ||
): | ||
""" | ||
Create a CohereTextEmbedder component. | ||
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment | ||
variable COHERE_API_KEY (recommended). | ||
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are | ||
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, | ||
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. | ||
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. | ||
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to | ||
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both | ||
cases, input is discarded until the remaining input is exactly the maximum input token length for the model. | ||
If NONE is selected, when the input exceeds the maximum input token length an error will be returned. | ||
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use | ||
AsyncClient for applications with many concurrent calls. | ||
:param max_retries: Maximum number of retries for requests, defaults to `3`. | ||
:param timeout: Request timeout in seconds, defaults to `120`. | ||
""" | ||
|
||
if api_key is None: | ||
try: | ||
api_key = os.environ["COHERE_API_KEY"] | ||
except KeyError as error_msg: | ||
msg = ( | ||
"CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment " | ||
"variable COHERE_API_KEY (recommended) or by passing it explicitly." | ||
) | ||
raise ValueError(msg) from error_msg | ||
|
||
self.api_key = api_key | ||
self.model_name = model_name | ||
self.api_base_url = api_base_url | ||
self.truncate = truncate | ||
self.use_async_client = use_async_client | ||
self.max_retries = max_retries | ||
self.timeout = timeout | ||
|
||
def to_dict(self) -> Dict[str, Any]: | ||
""" | ||
Serialize this component to a dictionary omitting the api_key field. | ||
""" | ||
return default_to_dict( | ||
self, | ||
model_name=self.model_name, | ||
api_base_url=self.api_base_url, | ||
truncate=self.truncate, | ||
use_async_client=self.use_async_client, | ||
max_retries=self.max_retries, | ||
timeout=self.timeout, | ||
) | ||
|
||
@component.output_types(embedding=List[float], metadata=Dict[str, Any]) | ||
def run(self, text: str): | ||
"""Embed a string.""" | ||
if not isinstance(text, str): | ||
msg = ( | ||
"CohereTextEmbedder expects a string as input." | ||
"In case you want to embed a list of Documents, please use the CohereDocumentEmbedder." | ||
) | ||
raise TypeError(msg) | ||
|
||
# Establish connection to API | ||
|
||
if self.use_async_client: | ||
cohere_client = AsyncClient( | ||
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout | ||
) | ||
embedding, metadata = asyncio.run(get_async_response(cohere_client, [text], self.model_name, self.truncate)) | ||
else: | ||
cohere_client = Client( | ||
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout | ||
) | ||
embedding, metadata = get_response(cohere_client, [text], self.model_name, self.truncate) | ||
|
||
return {"embedding": embedding[0], "metadata": metadata} |
Oops, something went wrong.