diff --git a/integrations/cohere/README.md b/integrations/cohere/README.md index 79cefed21..86a43bf83 100644 --- a/integrations/cohere/README.md +++ b/integrations/cohere/README.md @@ -7,8 +7,10 @@ **Table of Contents** -- [Installation](#installation) -- [License](#license) +- [cohere-haystack](#cohere-haystack) + - [Installation](#installation) + - [Contributing](#contributing) + - [License](#license) ## Installation @@ -16,6 +18,45 @@ pip install cohere-haystack ``` +## Contributing + +`hatch` is the best way to interact with this project, to install it: +```sh +pip install hatch +``` + +With `hatch` installed, to run all the tests: +``` +hatch run test +``` +> Note: integration tests will be skipped unless the env var COHERE_API_KEY is set. The api key needs to be valid +> in order to pass the tests. + +To only run unit tests: +``` +hatch run test -m"not integration" +``` + +To only run embedders tests: +``` +hatch run test -m"embedders" +``` + +To only run generators tests: +``` +hatch run test -m"generators" +``` + +Markers can be combined, for example you can run only integration tests for embedders with: +``` +hatch run test -m"integrations and embedders" +``` + +To run the linters `ruff` and `mypy`: +``` +hatch run lint:all +``` + ## License `cohere-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index e291907fd..5d589df7b 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -97,7 +97,6 @@ select = [ "E", "EM", "F", - "FBT", "I", "ICN", "ISC", @@ -118,8 +117,6 @@ select = [ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", - # Allow boolean positional values in function calls, like `dict.get(... True)` - "FBT003", # Ignore checks for possible passwords "S105", "S106", "S107", # Ignore complexity @@ -163,6 +160,16 @@ exclude_lines = [ module = [ "cohere.*", "haystack.*", - "pytest.*" + "pytest.*", + "numpy.*", ] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true + +[tool.pytest.ini_options] +addopts = "--strict-markers" +markers = [ + "integration: integration tests", + "embedders: embedders tests", + "generators: generators tests", +] +log_cli = true \ No newline at end of file diff --git a/integrations/cohere/src/cohere_haystack/embedders/__init__.py b/integrations/cohere/src/cohere_haystack/embedders/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py new file mode 100644 index 000000000..681471947 --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os +from typing import Any, Dict, List, Optional + +from cohere import COHERE_API_URL, AsyncClient, Client +from haystack import Document, component, default_to_dict + +from cohere_haystack.embedders.utils import get_async_response, get_response + + +@component +class CohereDocumentEmbedder: + """ + A component for computing Document embeddings using Cohere models. + The embedding of each Document is stored in the `embedding` field of the Document. + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "embed-english-v2.0", + api_base_url: str = COHERE_API_URL, + truncate: str = "END", + use_async_client: bool = False, + max_retries: int = 3, + timeout: int = 120, + batch_size: int = 32, + progress_bar: bool = True, + metadata_fields_to_embed: Optional[List[str]] = None, + embedding_separator: str = "\n", + ): + """ + Create a CohereDocumentEmbedder component. + + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment + variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are + `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, + `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to + `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both + cases, input is discarded until the remaining input is exactly the maximum input token length for the model. + If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + AsyncClient for applications with many concurrent calls. + :param max_retries: maximal number of retries for requests, defaults to `3`. + :param timeout: request timeout in seconds, defaults to `120`. + :param batch_size: Number of Documents to encode at once. + :param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments + to keep the logs clean. + :param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text. + :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + """ + + if api_key is None: + try: + api_key = os.environ["COHERE_API_KEY"] + except KeyError as error_msg: + msg = ( + "CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment " + "variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) + raise ValueError(msg) from error_msg + + self.api_key = api_key + self.model_name = model_name + self.api_base_url = api_base_url + self.truncate = truncate + self.use_async_client = use_async_client + self.max_retries = max_retries + self.timeout = timeout + self.batch_size = batch_size + self.progress_bar = progress_bar + self.metadata_fields_to_embed = metadata_fields_to_embed or [] + self.embedding_separator = embedding_separator + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary omitting the api_key field. + """ + return default_to_dict( + self, + model_name=self.model_name, + api_base_url=self.api_base_url, + truncate=self.truncate, + use_async_client=self.use_async_client, + max_retries=self.max_retries, + timeout=self.timeout, + batch_size=self.batch_size, + progress_bar=self.progress_bar, + metadata_fields_to_embed=self.metadata_fields_to_embed, + embedding_separator=self.embedding_separator, + ) + + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: + """ + Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. + """ + texts_to_embed: List[str] = [] + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None + ] + + text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005 + texts_to_embed.append(text_to_embed) + return texts_to_embed + + @component.output_types(documents=List[Document], metadata=Dict[str, Any]) + def run(self, documents: List[Document]): + """ + Embed a list of Documents. + The embedding of each Document is stored in the `embedding` field of the Document. + + :param documents: A list of Documents to embed. + """ + + if not isinstance(documents, list) or documents and not isinstance(documents[0], Document): + msg = ( + "CohereDocumentEmbedder expects a list of Documents as input." + "In case you want to embed a string, please use the CohereTextEmbedder." + ) + raise TypeError(msg) + + if not documents: + # return early if we were passed an empty list + return {"documents": [], "metadata": {}} + + texts_to_embed = self._prepare_texts_to_embed(documents) + + if self.use_async_client: + cohere_client = AsyncClient( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + all_embeddings, metadata = asyncio.run( + get_async_response(cohere_client, texts_to_embed, self.model_name, self.truncate) + ) + else: + cohere_client = Client( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + all_embeddings, metadata = get_response( + cohere_client, texts_to_embed, self.model_name, self.truncate, self.batch_size, self.progress_bar + ) + + for doc, embeddings in zip(documents, all_embeddings): + doc.embedding = embeddings + + return {"documents": documents, "metadata": metadata} diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py new file mode 100644 index 000000000..936926b99 --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import asyncio +import os +from typing import Any, Dict, List, Optional + +from cohere import COHERE_API_URL, AsyncClient, Client +from haystack import component, default_to_dict + +from cohere_haystack.embedders.utils import get_async_response, get_response + + +@component +class CohereTextEmbedder: + """ + A component for embedding strings using Cohere models. + """ + + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "embed-english-v2.0", + api_base_url: str = COHERE_API_URL, + truncate: str = "END", + use_async_client: bool = False, + max_retries: int = 3, + timeout: int = 120, + ): + """ + Create a CohereTextEmbedder component. + + :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment + variable COHERE_API_KEY (recommended). + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are + `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, + `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. + :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to + `"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both + cases, input is discarded until the remaining input is exactly the maximum input token length for the model. + If NONE is selected, when the input exceeds the maximum input token length an error will be returned. + :param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use + AsyncClient for applications with many concurrent calls. + :param max_retries: Maximum number of retries for requests, defaults to `3`. + :param timeout: Request timeout in seconds, defaults to `120`. + """ + + if api_key is None: + try: + api_key = os.environ["COHERE_API_KEY"] + except KeyError as error_msg: + msg = ( + "CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment " + "variable COHERE_API_KEY (recommended) or by passing it explicitly." + ) + raise ValueError(msg) from error_msg + + self.api_key = api_key + self.model_name = model_name + self.api_base_url = api_base_url + self.truncate = truncate + self.use_async_client = use_async_client + self.max_retries = max_retries + self.timeout = timeout + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary omitting the api_key field. + """ + return default_to_dict( + self, + model_name=self.model_name, + api_base_url=self.api_base_url, + truncate=self.truncate, + use_async_client=self.use_async_client, + max_retries=self.max_retries, + timeout=self.timeout, + ) + + @component.output_types(embedding=List[float], metadata=Dict[str, Any]) + def run(self, text: str): + """Embed a string.""" + if not isinstance(text, str): + msg = ( + "CohereTextEmbedder expects a string as input." + "In case you want to embed a list of Documents, please use the CohereDocumentEmbedder." + ) + raise TypeError(msg) + + # Establish connection to API + + if self.use_async_client: + cohere_client = AsyncClient( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + embedding, metadata = asyncio.run(get_async_response(cohere_client, [text], self.model_name, self.truncate)) + else: + cohere_client = Client( + self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + ) + embedding, metadata = get_response(cohere_client, [text], self.model_name, self.truncate) + + return {"embedding": embedding[0], "metadata": metadata} diff --git a/integrations/cohere/src/cohere_haystack/embedders/utils.py b/integrations/cohere/src/cohere_haystack/embedders/utils.py new file mode 100644 index 000000000..a3511008b --- /dev/null +++ b/integrations/cohere/src/cohere_haystack/embedders/utils.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Dict, List, Tuple + +from cohere import AsyncClient, Client, CohereError +from tqdm import tqdm + +API_BASE_URL = "https://api.cohere.ai/v1/embed" + + +async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, truncate): + all_embeddings: List[List[float]] = [] + metadata: Dict[str, Any] = {} + try: + response = await cohere_async_client.embed(texts=texts, model=model_name, truncate=truncate) + if response.meta is not None: + metadata = response.meta + for emb in response.embeddings: + all_embeddings.append(emb) + + return all_embeddings, metadata + + except CohereError as error_response: + msg = error_response.message + raise ValueError(msg) from error_response + + +def get_response( + cohere_client: Client, texts: List[str], model_name, truncate, batch_size=32, progress_bar=False +) -> Tuple[List[List[float]], Dict[str, Any]]: + """ + We support batching with the sync client. + """ + all_embeddings: List[List[float]] = [] + metadata: Dict[str, Any] = {} + + try: + for i in tqdm( + range(0, len(texts), batch_size), + disable=not progress_bar, + desc="Calculating embeddings", + ): + batch = texts[i : i + batch_size] + response = cohere_client.embed(batch, model=model_name, truncate=truncate) + for emb in response.embeddings: + all_embeddings.append(emb) + embeddings = [list(map(float, emb)) for emb in response.embeddings] + all_embeddings.extend(embeddings) + if response.meta is not None: + metadata = response.meta + + return all_embeddings, metadata + + except CohereError as error_response: + msg = error_response.message + raise ValueError(msg) from error_response diff --git a/integrations/cohere/src/cohere_haystack/generator.py b/integrations/cohere/src/cohere_haystack/generator.py index 4b18fb75d..a07225804 100644 --- a/integrations/cohere/src/cohere_haystack/generator.py +++ b/integrations/cohere/src/cohere_haystack/generator.py @@ -4,13 +4,11 @@ import logging import os import sys -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, cast +from cohere import COHERE_API_URL, Client +from cohere.responses import Generations from haystack import DeserializationError, component, default_from_dict, default_to_dict -from haystack.lazy_imports import LazyImport - -with LazyImport(message="Run 'pip install cohere'") as cohere_import: - from cohere import COHERE_API_URL, Client logger = logging.getLogger(__name__) @@ -75,8 +73,6 @@ def __init__( - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10. """ - cohere_import.check() - if not api_key: api_key = os.environ.get("COHERE_API_KEY") if not api_key: @@ -159,7 +155,7 @@ def run(self, prompt: str): self._check_truncated_answers(metadata) return {"replies": replies, "metadata": metadata} - metadata = [{"finish_reason": resp.finish_reason} for resp in response] + metadata = [{"finish_reason": resp.finish_reason} for resp in cast(Generations, response)] replies = [resp.text for resp in response] self._check_truncated_answers(metadata) return {"replies": replies, "metadata": metadata} diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index d267847a4..9462f364d 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -4,9 +4,12 @@ import os import pytest +from cohere import COHERE_API_URL from cohere_haystack.generator import CohereGenerator +pytestmark = pytest.mark.generators + def default_streaming_callback(chunk): """ @@ -16,16 +19,13 @@ def default_streaming_callback(chunk): print(chunk.text, flush=True, end="") # noqa: T201 -@pytest.mark.integration class TestCohereGenerator: def test_init_default(self): - import cohere - component = CohereGenerator(api_key="test-api-key") assert component.api_key == "test-api-key" assert component.model_name == "command" assert component.streaming_callback is None - assert component.api_base_url == cohere.COHERE_API_URL + assert component.api_base_url == COHERE_API_URL assert component.model_parameters == {} def test_init_with_parameters(self): @@ -45,8 +45,6 @@ def test_init_with_parameters(self): assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} def test_to_dict_default(self): - import cohere - component = CohereGenerator(api_key="test-api-key") data = component.to_dict() assert data == { @@ -54,7 +52,7 @@ def test_to_dict_default(self): "init_parameters": { "model_name": "command", "streaming_callback": None, - "api_base_url": cohere.COHERE_API_URL, + "api_base_url": COHERE_API_URL, }, } @@ -112,7 +110,7 @@ def test_from_dict(self, monkeypatch): "streaming_callback": "tests.test_cohere_generators.default_streaming_callback", }, } - component = CohereGenerator.from_dict(data) + component: CohereGenerator = CohereGenerator.from_dict(data) assert component.api_key == "test-key" assert component.model_name == "command" assert component.streaming_callback == default_streaming_callback @@ -134,7 +132,7 @@ def test_check_truncated_answers(self, caplog): ) @pytest.mark.integration def test_cohere_generator_run(self): - component = CohereGenerator(api_key=os.environ.get("COHERE_API_KEY")) + component = CohereGenerator() results = component.run(prompt="What's the capital of France?") assert len(results["replies"]) == 1 assert "Paris" in results["replies"][0] @@ -149,7 +147,7 @@ def test_cohere_generator_run(self): def test_cohere_generator_run_wrong_model_name(self): import cohere - component = CohereGenerator(model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) + component = CohereGenerator(model_name="something-obviously-wrong") with pytest.raises( cohere.CohereAPIError, match="model not found, make sure the correct model ID was used and that you have access to the model.", @@ -171,7 +169,7 @@ def __call__(self, chunk): return chunk callback = Callback() - component = CohereGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) + component = CohereGenerator(streaming_callback=callback) results = component.run(prompt="What's the capital of France?") assert len(results["replies"]) == 1 diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py new file mode 100644 index 000000000..d6309704c --- /dev/null +++ b/integrations/cohere/tests/test_document_embedder.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +from cohere import COHERE_API_URL +from haystack import Document + +from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder + +pytestmark = pytest.mark.embedders + + +class TestCohereDocumentEmbedder: + def test_init_default(self): + embedder = CohereDocumentEmbedder(api_key="test-api-key") + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == COHERE_API_URL + assert embedder.truncate == "END" + assert embedder.use_async_client is False + assert embedder.max_retries == 3 + assert embedder.timeout == 120 + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.metadata_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + + def test_init_with_parameters(self): + embedder = CohereDocumentEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["test_field"], + embedding_separator="-", + ) + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.truncate == "START" + assert embedder.use_async_client is True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.metadata_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == "-" + + def test_to_dict(self): + embedder_component = CohereDocumentEmbedder(api_key="test-api-key") + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", + "init_parameters": { + "model_name": "embed-english-v2.0", + "api_base_url": COHERE_API_URL, + "truncate": "END", + "use_async_client": False, + "max_retries": 3, + "timeout": 120, + "batch_size": 32, + "progress_bar": True, + "metadata_fields_to_embed": [], + "embedding_separator": "\n", + }, + } + + def test_to_dict_with_custom_init_parameters(self): + embedder_component = CohereDocumentEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + batch_size=64, + progress_bar=False, + metadata_fields_to_embed=["text_field"], + embedding_separator="-", + ) + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", + "init_parameters": { + "model_name": "embed-multilingual-v2.0", + "api_base_url": "https://custom-api-base-url.com", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + "batch_size": 64, + "progress_bar": False, + "metadata_fields_to_embed": ["text_field"], + "embedding_separator": "-", + }, + } + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = CohereDocumentEmbedder() + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + result = embedder.run(docs) + docs_with_embeddings = result["documents"] + + assert isinstance(docs_with_embeddings, list) + assert len(docs_with_embeddings) == len(docs) + for doc in docs_with_embeddings: + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) + + def test_run_wrong_input_format(self): + embedder = CohereDocumentEmbedder(api_key="test-api-key") + + with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents="text") + with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): + embedder.run(documents=[1, 2, 3]) + + assert embedder.run(documents=[]) == {"documents": [], "metadata": {}} diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py new file mode 100644 index 000000000..d2aed79c1 --- /dev/null +++ b/integrations/cohere/tests/test_text_embedder.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os + +import pytest +from cohere import COHERE_API_URL + +from cohere_haystack.embedders.text_embedder import CohereTextEmbedder + +pytestmark = pytest.mark.embedders + + +class TestCohereTextEmbedder: + def test_init_default(self): + """ + Test default initialization parameters for CohereTextEmbedder. + """ + embedder = CohereTextEmbedder(api_key="test-api-key") + + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-english-v2.0" + assert embedder.api_base_url == COHERE_API_URL + assert embedder.truncate == "END" + assert embedder.use_async_client is False + assert embedder.max_retries == 3 + assert embedder.timeout == 120 + + def test_init_with_parameters(self): + """ + Test custom initialization parameters for CohereTextEmbedder. + """ + embedder = CohereTextEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + ) + assert embedder.api_key == "test-api-key" + assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.api_base_url == "https://custom-api-base-url.com" + assert embedder.truncate == "START" + assert embedder.use_async_client is True + assert embedder.max_retries == 5 + assert embedder.timeout == 60 + + def test_to_dict(self): + """ + Test serialization of this component to a dictionary, using default initialization parameters. + """ + embedder_component = CohereTextEmbedder(api_key="test-api-key") + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", + "init_parameters": { + "model_name": "embed-english-v2.0", + "api_base_url": COHERE_API_URL, + "truncate": "END", + "use_async_client": False, + "max_retries": 3, + "timeout": 120, + }, + } + + def test_to_dict_with_custom_init_parameters(self): + """ + Test serialization of this component to a dictionary, using custom initialization parameters. + """ + embedder_component = CohereTextEmbedder( + api_key="test-api-key", + model_name="embed-multilingual-v2.0", + api_base_url="https://custom-api-base-url.com", + truncate="START", + use_async_client=True, + max_retries=5, + timeout=60, + ) + component_dict = embedder_component.to_dict() + assert component_dict == { + "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", + "init_parameters": { + "model_name": "embed-multilingual-v2.0", + "api_base_url": "https://custom-api-base-url.com", + "truncate": "START", + "use_async_client": True, + "max_retries": 5, + "timeout": 60, + }, + } + + def test_run_wrong_input_format(self): + """ + Test for checking incorrect input when creating embedding. + """ + embedder = CohereTextEmbedder(api_key="test-api-key") + list_integers_input = ["text_snippet_1", "text_snippet_2"] + + with pytest.raises(TypeError, match="CohereTextEmbedder expects a string as input"): + embedder.run(text=list_integers_input) + + @pytest.mark.skipif( + not os.environ.get("COHERE_API_KEY", None), + reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = CohereTextEmbedder() + text = "The food was delicious" + result = embedder.run(text) + assert all(isinstance(x, float) for x in result["embedding"])