diff --git a/haystack/preview/components/embedders/__init__.py b/haystack/preview/components/embedders/__init__.py index a0840d7e0a..bb5e4b9619 100644 --- a/haystack/preview/components/embedders/__init__.py +++ b/haystack/preview/components/embedders/__init__.py @@ -1,3 +1,5 @@ +from haystack.preview.components.embedders.cohere_text_embedder import CohereTextEmbedder +from haystack.preview.components.embedders.cohere_document_embedder import CohereDocumentEmbedder from haystack.preview.components.embedders.sentence_transformers_text_embedder import SentenceTransformersTextEmbedder from haystack.preview.components.embedders.sentence_transformers_document_embedder import ( SentenceTransformersDocumentEmbedder, @@ -6,6 +8,8 @@ from haystack.preview.components.embedders.openai_text_embedder import OpenAITextEmbedder __all__ = [ + "CohereTextEmbedder", + "CohereDocumentEmbedder", "SentenceTransformersTextEmbedder", "SentenceTransformersDocumentEmbedder", "OpenAITextEmbedder", diff --git a/haystack/preview/components/embedders/cohere_document_embedder.py b/haystack/preview/components/embedders/cohere_document_embedder.py new file mode 100644 index 0000000000..34b52ccffd --- /dev/null +++ b/haystack/preview/components/embedders/cohere_document_embedder.py @@ -0,0 +1,174 @@ +from typing import List, Optional, Dict, Any +import os +from tqdm import tqdm + +from haystack.preview import component, Document, default_to_dict, default_from_dict +from haystack.preview.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install cohere'") as cohere_import: + from cohere import Client, AsyncClient, CohereError + +API_BASE_URL = "https://api.cohere.ai/v1/embed" + + +@component +class CohereDocumentEmbedder: + """ + A component for computing Document embeddings using Cohere models. + The embedding of each Document is stored in the `embedding` field of the Document. + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "embed-english-v2.0", + api_base_url: str = API_BASE_URL, + truncate: str = "END", + use_async_client: bool = False, + max_retries: Optional[int] = 3, + timeout: Optional[int] = 120, + batch_size: int = 32, + progress_bar: bool = True, + metadata_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + ): + """ + Create a CohereDocumentEmbedder component. + + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls. + :param max_retries: maximal number of retries for requests, defaults to `3`. + :param timeout: request timeout in seconds, defaults to `120`. + :param batch_size: Number of Documents to encode at once. + :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments + to keep the logs clean. + :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text. + :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + """ + + if api_key is None: + try: + api_key = os.environ["COHERE_API_KEY"] + except KeyError as error_msg: + raise ValueError( + "CohereDocumentEmbedder expects an Cohere API key. " + "Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) from error_msg + + self.api_key = api_key + self.model_name = model_name + self.api_base_url = api_base_url + self.truncate = truncate + self.use_async_client = use_async_client + self.max_retries = max_retries + self.timeout = timeout + self.batch_size = batch_size + self.progress_bar = progress_bar + self.metadata_fields_to_embed = metadata_fields_to_embed or [] + self.embedding_separator = embedding_separator + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + return default_to_dict( + self, + model_name=self.model_name, + api_base_url=self.api_base_url, + truncate=self.truncate, + use_async_client=self.use_async_client, + max_retries=self.max_retries, + timeout=self.timeout, + batch_size=self.batch_size, + progress_bar=self.progress_bar, + metadata_fields_to_embed=self.metadata_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder": + """ + Deserialize this component from a dictionary. + """ + return default_from_dict(cls, data) + + async def _get_async_response(self, cohere_async_client: AsyncClient, documents: List[Document]): + try: + response = await cohere_async_client.embed(texts=[documents], model=self.model_name, truncate=self.truncate) + metadata = response.meta + embedding = [list(map(float, emb)) for emb in response.embeddings][0] + + except CohereError as error_response: + print(error_response.message) + + return embedding, metadata + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None + ] + + text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + texts_to_embed.append(text_to_embed) + return texts_to_embed + + @component.output_types(documents=List[Document], metadata=Dict[str, Any]) + def run(self, documents: List[Document]): + """ + Embed a list of Documents. + The embedding of each Document is stored in the `embedding` field of the Document. + + :param documents: A list of Documents to embed. + """ + + if not isinstance(documents, list) or not isinstance(documents[0], Document): + raise TypeError( + "CohereDocumentEmbedder expects a list of Documents as input." + "In case you want to embed a string, please use the CohereTextEmbedder." + ) + + # Establish connection to API + + if self.use_async_client == True: + cohere_client = AsyncClient( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + texts_to_embed = self._prepare_texts_to_embed(cohere_client, documents) + + else: + cohere_client = Client( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + + try: + all_embeddings = [] + metadata = {} + for i in tqdm( + range(0, len(texts_to_embed), self.batch_size), + disable=not self.progress_bar, + desc="Calculating embeddings", + ): + batch = texts_to_embed[i : i + self.batch_size] + response = cohere_client.embed(batch) + embeddings = [list(map(float, emb)) for emb in response.embeddings] + all_embeddings.extend(embeddings) + + metadata = response.meta + + documents_with_embeddings = [] + for doc, emb in zip(documents, all_embeddings): + doc_as_dict = doc.to_dict() + doc_as_dict["embedding"] = emb + documents_with_embeddings.append(Document.from_dict(doc_as_dict)) + except CohereError as error_response: + print(error_response.message) + + return {"documents": documents_with_embeddings, "metadata": metadata} diff --git a/haystack/preview/components/embedders/cohere_text_embedder.py b/haystack/preview/components/embedders/cohere_text_embedder.py new file mode 100644 index 0000000000..b34a09efea --- /dev/null +++ b/haystack/preview/components/embedders/cohere_text_embedder.py @@ -0,0 +1,121 @@ +from typing import List, Optional, Dict, Any +import os + +from haystack.preview import component, default_to_dict, default_from_dict +from haystack.preview.lazy_imports import LazyImport + +with LazyImport(message="Run 'pip install cohere'") as cohere_import: + from cohere import Client, AsyncClient, CohereError + + +API_BASE_URL = "https://api.cohere.ai/v1/embed" + + +@component +class CohereTextEmbedder: + """ + A component for embedding strings using Cohere models. + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "embed-english-v2.0", + api_base_url: str = API_BASE_URL, + truncate: str = "END", + use_async_client: bool = False, + max_retries: Optional[int] = 3, + timeout: Optional[int] = 120, + ): + """ + Create a CohereTextEmbedder component. + + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both cases, input is discarded until the remaining input is exactly the maximum input token length for the model. If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use AsyncClient for applications with many concurrent calls. + :param max_retries: Maximum number of retries for requests, defaults to `3`. + :param timeout: Request timeout in seconds, defaults to `120`. + """ + + if api_key is None: + try: + api_key = os.environ["COHERE_API_KEY"] + except KeyError as error_msg: + raise ValueError( + "CohereTextEmbedder expects an Cohere API key. " + "Please provide one by setting the environment variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) from error_msg + + self.api_key = api_key + self.model_name = model_name + self.api_base_url = api_base_url + self.truncate = truncate + self.use_async_client = use_async_client + self.max_retries = max_retries + self.timeout = timeout + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + return default_to_dict( + self, + model_name=self.model_name, + api_base_url=self.api_base_url, + truncate=self.truncate, + use_async_client=self.use_async_client, + max_retries=self.max_retries, + timeout=self.timeout, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder": + """ + Deserialize this component from a dictionary. + """ + return default_from_dict(cls, data) + + async def _get_async_response(self, cohere_async_client: AsyncClient, text: str): + try: + response = await cohere_async_client.embed(texts=[text], model=self.model_name, truncate=self.truncate) + metadata = response.meta + embedding = [list(map(float, emb)) for emb in response.embeddings][0] + + except CohereError as error_response: + print(error_response.message) + + return embedding, metadata + + @component.output_types(embedding=List[float], metadata=Dict[str, Any]) + def run(self, text: str): + """Embed a string.""" + if not isinstance(text, str): + raise TypeError( + "CohereTextEmbedder expects a string as input." + "In case you want to embed a list of Documents, please use the CohereDocumentEmbedder." + ) + + # Establish connection to API + + if self.use_async_client == True: + cohere_client = AsyncClient( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + embedding, metadata = self._get_async_response(cohere_client, text) + + else: + cohere_client = Client( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + + try: + response = cohere_client.embed(texts=[text], model=self.model_name, truncate=self.truncate) + metadata = response.meta + embedding = [list(map(float, emb)) for emb in response.embeddings][0] + + except CohereError as error_response: + print(error_response.message) + + return {"embedding": embedding, "metadata": metadata} diff --git a/releasenotes/notes/add-CohereTextEmbedder-a429ecf033f36631.yaml b/releasenotes/notes/add-CohereTextEmbedder-a429ecf033f36631.yaml new file mode 100644 index 0000000000..54e2a96612 --- /dev/null +++ b/releasenotes/notes/add-CohereTextEmbedder-a429ecf033f36631.yaml @@ -0,0 +1,6 @@ +--- +preview: + - | + Add `CohereTextEmbedder`, a component that uses Cohere embedding models to embed strings into vectors. + - | + Add `CohereDocumentEmbedder`, a component that uses Cohere embedding models to embeds a list of Documents. diff --git a/test/preview/components/embedders/test_cohere_document_embedder.py b/test/preview/components/embedders/test_cohere_document_embedder.py new file mode 100644 index 0000000000..3dda63776e --- /dev/null +++ b/test/preview/components/embedders/test_cohere_document_embedder.py @@ -0,0 +1,165 @@ +from unittest.mock import patch, MagicMock +import pytest +from cohere.responses.embeddings import Embeddings +import numpy as np +from haystack.preview import Document +from haystack.preview.components.embedders.cohere_document_embedder import CohereDocumentEmbedder + + +class TestCohereDocumentEmbedder: + @pytest.mark.unit + def test_init_default(self): + embedder = CohereDocumentEmbedder(api_key="test-api-key") + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == "https://api.cohere.ai/v1/embed" + assert embedder.truncate == "END" + assert embedder.use_async_client == False + assert embedder.max_retries == 3 + assert embedder.timeout == 120 + assert embedder.batch_size == 32 + assert embedder.progress_bar == True + assert embedder.metadata_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + @pytest.mark.unit + def test_init_with_parameters(self): + embedder = CohereDocumentEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["test_field"], + embedding_separator="-", + ) + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.truncate == "START" + assert embedder.use_async_client == True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + assert embedder.batch_size == 64 + assert embedder.progress_bar == False + assert embedder.metadata_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == "-" + + @pytest.mark.unit + def test_to_dict(self): + embedder_component = CohereDocumentEmbedder(api_key="test-api-key") + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "CohereDocumentEmbedder", + "init_parameters": { + "model_name": "embed-english-v2.0", + "api_base_url": "https://api.cohere.ai/v1/embed", + "truncate": "END", + "use_async_client": False, + "max_retries": 3, + "timeout": 120, + "batch_size": 32, + "progress_bar": True, + "metadata_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + @pytest.mark.unit + def test_to_dict_with_custom_init_parameters(self): + embedder_component = CohereDocumentEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["text_field"], + embedding_separator="-", + ) + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "CohereDocumentEmbedder", + "init_parameters": { + "model_name": "embed-multilingual-v2.0", + "api_base_url": "https://custom-api-base-url.com", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + "batch_size": 64, + "progress_bar": False, + "metadata_fields_to_embed": ["text_field"], + "embedding_separator": "-", + }, + } + + @pytest.mark.unit + def test_from_dict(self): + embedder_component_dict = { + "type": "CohereDocumentEmbedder", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "embed-english-v2.0", + "api_base_url": "https://api.cohere.ai/v1/embed", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + "batch_size": 32, + "progress_bar": False, + "metadata_fields_to_embed": ["test_field"], + "embedding_separator": "-", + }, + } + embedder = CohereDocumentEmbedder.from_dict(embedder_component_dict) + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == "https://api.cohere.ai/v1/embed" + assert embedder.truncate == "START" + assert embedder.use_async_client == True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + assert embedder.batch_size == 32 + assert embedder.progress_bar == False + assert embedder.metadata_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == "-" + + @pytest.mark.unit + def test_run(self): + embedder = CohereDocumentEmbedder(api_key="test-api-key") + embedder = MagicMock() + embedder.run = lambda x, **kwargs: np.random.rand(len(x), 2).tolist() + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + result = embedder.run(docs) + embeddings = result["documents"] + + assert isinstance(embeddings, list) + assert len(embeddings) == len(docs) + for embedding in embeddings: + assert isinstance(embedding, list) + assert isinstance(embedding[0], float) + + @pytest.mark.unit + def test_run_wrong_input_format(self): + embedder = CohereDocumentEmbedder(api_key="test-api-key") + + string_input = "text" + list_integers_input = [1, 2, 3] + + with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=string_input) + with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=list_integers_input) diff --git a/test/preview/components/embedders/test_cohere_text_embedder.py b/test/preview/components/embedders/test_cohere_text_embedder.py new file mode 100644 index 0000000000..bb0ed0694a --- /dev/null +++ b/test/preview/components/embedders/test_cohere_text_embedder.py @@ -0,0 +1,165 @@ +from unittest.mock import patch, MagicMock +import pytest +from cohere.responses.embeddings import Embeddings +from haystack.preview.components.embedders.cohere_text_embedder import CohereTextEmbedder + + +class TestCohereTextEmbedder: + @pytest.mark.unit + def test_init_default(self): + """ + Test default initialization parameters for CohereTextEmbedder. + """ + embedder = CohereTextEmbedder(api_key="test-api-key") + + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == "https://api.cohere.ai/v1/embed" + assert embedder.truncate == "END" + assert embedder.use_async_client == False + assert embedder.max_retries == 3 + assert embedder.timeout == 120 + + @pytest.mark.unit + def test_init_with_parameters(self): + """ + Test custom initialization parameters for CohereTextEmbedder. + """ + embedder = CohereTextEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + ) + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.truncate == "START" + assert embedder.use_async_client == True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + + @pytest.mark.unit + def test_to_dict(self): + """ + Test serialization of this component to a dictionary, using default initialization parameters. + """ + embedder_component = CohereTextEmbedder(api_key="test-api-key") + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "CohereTextEmbedder", + "init_parameters": { + "model_name": "embed-english-v2.0", + "api_base_url": "https://api.cohere.ai/v1/embed", + "truncate": "END", + "use_async_client": False, + "max_retries": 3, + "timeout": 120, + }, + } + + @pytest.mark.unit + def test_to_dict_with_custom_init_parameters(self): + """ + Test serialization of this component to a dictionary, using custom initialization parameters. + """ + embedder_component = CohereTextEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + ) + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "CohereTextEmbedder", + "init_parameters": { + "model_name": "embed-multilingual-v2.0", + "api_base_url": "https://custom-api-base-url.com", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + }, + } + + @pytest.mark.unit + def test_from_dict(self): + """ + Test deserialization of this component from a dictionary, using default initialization parameters. + """ + embedder_component_dict = { + "type": "CohereTextEmbedder", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "embed-english-v2.0", + "api_base_url": "https://api.cohere.ai/v1/embed", + "truncate": "END", + "use_async_client": False, + "max_retries": 3, + "timeout": 120, + }, + } + embedder = CohereTextEmbedder.from_dict(embedder_component_dict) + + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == "https://api.cohere.ai/v1/embed" + assert embedder.truncate == "END" + assert embedder.use_async_client == False + assert embedder.max_retries == 3 + assert embedder.timeout == 120 + + @pytest.mark.unit + def test_from_dict_with_custom_init_parameters(self): + """ + Test deserialization of this component from a dictionary, using custom initialization parameters. + """ + embedder_component_dict = { + "type": "CohereTextEmbedder", + "init_parameters": { + "api_key": "test-api-key", + "model_name": "embed-multilingual-v2.0", + "api_base_url": "https://custom-api-base-url.com", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + }, + } + embedder = CohereTextEmbedder.from_dict(embedder_component_dict) + + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.truncate == "START" + assert embedder.use_async_client == True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + + @pytest.mark.unit + def test_run_wrong_input_format(self): + """ + Test for checking incorrect input when creating embedding. + """ + embedder = CohereTextEmbedder(api_key="test-api-key") + + list_integers_input = ["text_snippet_1", "text_snippet_2"] + + with pytest.raises(TypeError, match="CohereTextEmbedder expects a string as input"): + embedder.run(text=list_integers_input) + + @pytest.mark.integration + def test_run(self): + embedder = CohereTextEmbedder(api_key="test-api-key") + embedder = MagicMock() + text = "The food was delicious" + + result = embedder.run(text) + + assert all(isinstance(x, float) for x in result["embedding"])