From 22b7704542e8a124fc42baa732eb78402fb4cd8e Mon Sep 17 00:00:00 2001 From: awinml <97467100+awinml@users.noreply.github.com> Date: Fri, 8 Dec 2023 17:37:11 +0530 Subject: [PATCH] Add v3 embed models support --- .../embedders/document_embedder.py | 40 ++++++++++++++++--- .../embedders/text_embedder.py | 35 +++++++++++++--- .../src/cohere_haystack/embedders/utils.py | 10 +++-- .../cohere/tests/test_document_embedder.py | 6 +++ .../cohere/tests/test_text_embedder.py | 6 +++ 5 files changed, 83 insertions(+), 14 deletions(-) diff --git a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py index 681471947..5ad027a29 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/document_embedder.py @@ -16,12 +16,28 @@ class CohereDocumentEmbedder: """ A component for computing Document embeddings using Cohere models. The embedding of each Document is stored in the `embedding` field of the Document. + + Usage Example: + ```python + from haystack import Document + from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder + + doc = Document(content="I love pizza!") + + document_embedder = CohereDocumentEmbedder() + + result = document_embedder.run([doc]) + print(result['documents'][0].embedding) + + # [-0.453125, 1.2236328, 2.0058594, ...] + ``` """ def __init__( self, api_key: Optional[str] = None, model_name: str = "embed-english-v2.0", + input_type: Optional[str] = "search_document", api_base_url: str = COHERE_API_URL, truncate: str = "END", use_async_client: bool = False, @@ -37,9 +53,15 @@ def __init__( :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. 
Supported Models are - `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, - `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, + `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, + `"embed-multilingual-v2.0"`. The list of all supported models can be found on the + [model documentation](https://docs.cohere.com/docs/models#representation). + :param input_type: Specifies the type of input you're giving to the model. Supported values are + "search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not + required for older versions of the embedding models (i.e. anything lower than v3), but is required for more + recent versions (i.e. anything bigger than v2). :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. 
In both @@ -68,6 +90,7 @@ def __init__( self.api_key = api_key self.model_name = model_name + self.input_type = input_type self.api_base_url = api_base_url self.truncate = truncate self.use_async_client = use_async_client @@ -85,6 +108,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model_name=self.model_name, + input_type=self.input_type, api_base_url=self.api_base_url, truncate=self.truncate, use_async_client=self.use_async_client, @@ -137,14 +161,20 @@ def run(self, documents: List[Document]): self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) all_embeddings, metadata = asyncio.run( - get_async_response(cohere_client, texts_to_embed, self.model_name, self.truncate) + get_async_response(cohere_client, texts_to_embed, self.model_name, self.input_type, self.truncate) ) else: cohere_client = Client( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) all_embeddings, metadata = get_response( - cohere_client, texts_to_embed, self.model_name, self.truncate, self.batch_size, self.progress_bar + cohere_client, + texts_to_embed, + self.model_name, + self.input_type, + self.truncate, + self.batch_size, + self.progress_bar, ) for doc, embeddings in zip(documents, all_embeddings): diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 936926b99..28a1d2f7d 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -15,12 +15,27 @@ class CohereTextEmbedder: """ A component for embedding strings using Cohere models. + + Usage Example: + ```python + from cohere_haystack.embedders.text_embedder import CohereTextEmbedder + + text_to_embed = "I love pizza!" 
+ + text_embedder = CohereTextEmbedder() + + print(text_embedder.run(text_to_embed)) + + # {'embedding': [-0.453125, 1.2236328, 2.0058594, ...], + # 'metadata': {'api_version': {'version': '1'}, 'billed_units': {'input_tokens': 4}}} + ``` """ def __init__( self, api_key: Optional[str] = None, model_name: str = "embed-english-v2.0", + input_type: Optional[str] = "search_document", api_base_url: str = COHERE_API_URL, truncate: str = "END", use_async_client: bool = False, @@ -32,9 +47,15 @@ def __init__( :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are - `"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`, - `"embed-multilingual-v2.0"`/ `"multilingual-22-12"`. + :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, + `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, + `"embed-multilingual-v2.0"`. The list of all supported models can be found on the + [model documentation](https://docs.cohere.com/docs/models#representation). + :param input_type: Specifies the type of input you're giving to the model. Supported values are + "search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not + required for older versions of the embedding models (i.e. anything lower than v3), but is required for more + recent versions (i.e. anything bigger than v2). :param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`. :param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to `"END"`. Passing START will discard the start of the input. 
END will discard the end of the input. In both @@ -58,6 +79,7 @@ def __init__( self.api_key = api_key self.model_name = model_name + self.input_type = input_type self.api_base_url = api_base_url self.truncate = truncate self.use_async_client = use_async_client @@ -71,6 +93,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, model_name=self.model_name, + input_type=self.input_type, api_base_url=self.api_base_url, truncate=self.truncate, use_async_client=self.use_async_client, @@ -94,11 +117,13 @@ def run(self, text: str): cohere_client = AsyncClient( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) - embedding, metadata = asyncio.run(get_async_response(cohere_client, [text], self.model_name, self.truncate)) + embedding, metadata = asyncio.run( + get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) + ) else: cohere_client = Client( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) - embedding, metadata = get_response(cohere_client, [text], self.model_name, self.truncate) + embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) return {"embedding": embedding[0], "metadata": metadata} diff --git a/integrations/cohere/src/cohere_haystack/embedders/utils.py b/integrations/cohere/src/cohere_haystack/embedders/utils.py index a3511008b..165d34acd 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/utils.py +++ b/integrations/cohere/src/cohere_haystack/embedders/utils.py @@ -9,11 +9,13 @@ API_BASE_URL = "https://api.cohere.ai/v1/embed" -async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, truncate): +async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate): all_embeddings: List[List[float]] = [] metadata: Dict[str, Any] = {} try: - response = await 
cohere_async_client.embed(texts=texts, model=model_name, truncate=truncate) + response = await cohere_async_client.embed( + texts=texts, model=model_name, input_type=input_type, truncate=truncate + ) if response.meta is not None: metadata = response.meta for emb in response.embeddings: @@ -27,7 +29,7 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], def get_response( - cohere_client: Client, texts: List[str], model_name, truncate, batch_size=32, progress_bar=False + cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False ) -> Tuple[List[List[float]], Dict[str, Any]]: """ We support batching with the sync client. @@ -42,7 +44,7 @@ def get_response( desc="Calculating embeddings", ): batch = texts[i : i + batch_size] - response = cohere_client.embed(batch, model=model_name, truncate=truncate) + response = cohere_client.embed(batch, model=model_name, input_type=input_type, truncate=truncate) for emb in response.embeddings: all_embeddings.append(emb) embeddings = [list(map(float, emb)) for emb in response.embeddings] diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index d6309704c..5b0ad5c3f 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -17,6 +17,7 @@ def test_init_default(self): embedder = CohereDocumentEmbedder(api_key="test-api-key") assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-english-v2.0" + assert embedder.input_type == "search_document" assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" assert embedder.use_async_client is False @@ -31,6 +32,7 @@ def test_init_with_parameters(self): embedder = CohereDocumentEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", 
truncate="START", use_async_client=True, @@ -43,6 +45,7 @@ def test_init_with_parameters(self): ) assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.input_type == "search_query" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" assert embedder.use_async_client is True @@ -60,6 +63,7 @@ def test_to_dict(self): "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", "init_parameters": { "model_name": "embed-english-v2.0", + "input_type": "search_document", "api_base_url": COHERE_API_URL, "truncate": "END", "use_async_client": False, @@ -76,6 +80,7 @@ def test_to_dict_with_custom_init_parameters(self): embedder_component = CohereDocumentEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -91,6 +96,7 @@ def test_to_dict_with_custom_init_parameters(self): "type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder", "init_parameters": { "model_name": "embed-multilingual-v2.0", + "input_type": "search_query", "api_base_url": "https://custom-api-base-url.com", "truncate": "START", "use_async_client": True, diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index d2aed79c1..9ec673c98 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -20,6 +20,7 @@ def test_init_default(self): assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-english-v2.0" + assert embedder.input_type == "search_document" assert embedder.api_base_url == COHERE_API_URL assert embedder.truncate == "END" assert embedder.use_async_client is False @@ -33,6 +34,7 @@ def test_init_with_parameters(self): embedder = CohereTextEmbedder( api_key="test-api-key", 
model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -41,6 +43,7 @@ def test_init_with_parameters(self): ) assert embedder.api_key == "test-api-key" assert embedder.model_name == "embed-multilingual-v2.0" + assert embedder.input_type == "search_query" assert embedder.api_base_url == "https://custom-api-base-url.com" assert embedder.truncate == "START" assert embedder.use_async_client is True @@ -57,6 +60,7 @@ def test_to_dict(self): "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { "model_name": "embed-english-v2.0", + "input_type": "search_document", "api_base_url": COHERE_API_URL, "truncate": "END", "use_async_client": False, @@ -72,6 +76,7 @@ def test_to_dict_with_custom_init_parameters(self): embedder_component = CohereTextEmbedder( api_key="test-api-key", model_name="embed-multilingual-v2.0", + input_type="search_query", api_base_url="https://custom-api-base-url.com", truncate="START", use_async_client=True, @@ -83,6 +88,7 @@ def test_to_dict_with_custom_init_parameters(self): "type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder", "init_parameters": { "model_name": "embed-multilingual-v2.0", + "input_type": "search_query", "api_base_url": "https://custom-api-base-url.com", "truncate": "START", "use_async_client": True,