Skip to content

Commit

Permalink
Add v3 embed models support
Browse files Browse the repository at this point in the history
  • Loading branch information
awinml committed Dec 8, 2023
1 parent 4076949 commit 22b7704
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,28 @@ class CohereDocumentEmbedder:
"""
A component for computing Document embeddings using Cohere models.
The embedding of each Document is stored in the `embedding` field of the Document.
Usage Example:
```python
from haystack import Document
from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder
doc = Document(content="I love pizza!")
document_embedder = CohereDocumentEmbedder()
result = document_embedder.run([doc])
print(result['documents'][0].embedding)
# [-0.453125, 1.2236328, 2.0058594, ...]
```
"""

def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
input_type: Optional[str] = "search_document",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
use_async_client: bool = False,
Expand All @@ -37,9 +53,15 @@ def __init__(
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`,
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found on the
[model documentation](https://docs.cohere.com/docs/models#representation).
:param input_type: Specifies the type of input you're giving to the model. Supported values are
"search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not
required for older versions of the embedding models (i.e. anything lower than v3), but is required for more
recent versions (i.e. anything bigger than v2).
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
Expand Down Expand Up @@ -68,6 +90,7 @@ def __init__(

self.api_key = api_key
self.model_name = model_name
self.input_type = input_type
self.api_base_url = api_base_url
self.truncate = truncate
self.use_async_client = use_async_client
Expand All @@ -85,6 +108,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model_name=self.model_name,
input_type=self.input_type,
api_base_url=self.api_base_url,
truncate=self.truncate,
use_async_client=self.use_async_client,
Expand Down Expand Up @@ -137,14 +161,20 @@ def run(self, documents: List[Document]):
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
all_embeddings, metadata = asyncio.run(
get_async_response(cohere_client, texts_to_embed, self.model_name, self.truncate)
get_async_response(cohere_client, texts_to_embed, self.model_name, self.input_type, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
all_embeddings, metadata = get_response(
cohere_client, texts_to_embed, self.model_name, self.truncate, self.batch_size, self.progress_bar
cohere_client,
texts_to_embed,
self.model_name,
self.input_type,
self.truncate,
self.batch_size,
self.progress_bar,
)

for doc, embeddings in zip(documents, all_embeddings):
Expand Down
35 changes: 30 additions & 5 deletions integrations/cohere/src/cohere_haystack/embedders/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,27 @@
class CohereTextEmbedder:
"""
A component for embedding strings using Cohere models.
Usage Example:
```python
from cohere_haystack.embedders.text_embedder import CohereTextEmbedder
text_to_embed = "I love pizza!"
text_embedder = CohereTextEmbedder()
print(text_embedder.run(text_to_embed))
# {'embedding': [-0.453125, 1.2236328, 2.0058594, ...]
# 'metadata': {'api_version': {'version': '1'}, 'billed_units': {'input_tokens': 4}}}
```
"""

def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
input_type: Optional[str] = "search_document",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
use_async_client: bool = False,
Expand All @@ -32,9 +47,15 @@ def __init__(
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`,
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found on the
[model documentation](https://docs.cohere.com/docs/models#representation).
:param input_type: Specifies the type of input you're giving to the model. Supported values are
"search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not
required for older versions of the embedding models (i.e. anything lower than v3), but is required for more
recent versions (i.e. anything bigger than v2).
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
Expand All @@ -58,6 +79,7 @@ def __init__(

self.api_key = api_key
self.model_name = model_name
self.input_type = input_type
self.api_base_url = api_base_url
self.truncate = truncate
self.use_async_client = use_async_client
Expand All @@ -71,6 +93,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
model_name=self.model_name,
input_type=self.input_type,
api_base_url=self.api_base_url,
truncate=self.truncate,
use_async_client=self.use_async_client,
Expand All @@ -94,11 +117,13 @@ def run(self, text: str):
cohere_client = AsyncClient(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = asyncio.run(get_async_response(cohere_client, [text], self.model_name, self.truncate))
embedding, metadata = asyncio.run(
get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = get_response(cohere_client, [text], self.model_name, self.truncate)
embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)

return {"embedding": embedding[0], "metadata": metadata}
10 changes: 6 additions & 4 deletions integrations/cohere/src/cohere_haystack/embedders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
API_BASE_URL = "https://api.cohere.ai/v1/embed"


async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, truncate):
async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate):
all_embeddings: List[List[float]] = []
metadata: Dict[str, Any] = {}
try:
response = await cohere_async_client.embed(texts=texts, model=model_name, truncate=truncate)
response = await cohere_async_client.embed(
texts=texts, model=model_name, input_type=input_type, truncate=truncate
)
if response.meta is not None:
metadata = response.meta
for emb in response.embeddings:
Expand All @@ -27,7 +29,7 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str],


def get_response(
cohere_client: Client, texts: List[str], model_name, truncate, batch_size=32, progress_bar=False
cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False
) -> Tuple[List[List[float]], Dict[str, Any]]:
"""
We support batching with the sync client.
Expand All @@ -42,7 +44,7 @@ def get_response(
desc="Calculating embeddings",
):
batch = texts[i : i + batch_size]
response = cohere_client.embed(batch, model=model_name, truncate=truncate)
response = cohere_client.embed(batch, model=model_name, input_type=input_type, truncate=truncate)
for emb in response.embeddings:
all_embeddings.append(emb)
embeddings = [list(map(float, emb)) for emb in response.embeddings]
Expand Down
6 changes: 6 additions & 0 deletions integrations/cohere/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def test_init_default(self):
embedder = CohereDocumentEmbedder(api_key="test-api-key")
assert embedder.api_key == "test-api-key"
assert embedder.model_name == "embed-english-v2.0"
assert embedder.input_type == "search_document"
assert embedder.api_base_url == COHERE_API_URL
assert embedder.truncate == "END"
assert embedder.use_async_client is False
Expand All @@ -31,6 +32,7 @@ def test_init_with_parameters(self):
embedder = CohereDocumentEmbedder(
api_key="test-api-key",
model_name="embed-multilingual-v2.0",
input_type="search_query",
api_base_url="https://custom-api-base-url.com",
truncate="START",
use_async_client=True,
Expand All @@ -43,6 +45,7 @@ def test_init_with_parameters(self):
)
assert embedder.api_key == "test-api-key"
assert embedder.model_name == "embed-multilingual-v2.0"
assert embedder.input_type == "search_query"
assert embedder.api_base_url == "https://custom-api-base-url.com"
assert embedder.truncate == "START"
assert embedder.use_async_client is True
Expand All @@ -60,6 +63,7 @@ def test_to_dict(self):
"type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder",
"init_parameters": {
"model_name": "embed-english-v2.0",
"input_type": "search_document",
"api_base_url": COHERE_API_URL,
"truncate": "END",
"use_async_client": False,
Expand All @@ -76,6 +80,7 @@ def test_to_dict_with_custom_init_parameters(self):
embedder_component = CohereDocumentEmbedder(
api_key="test-api-key",
model_name="embed-multilingual-v2.0",
input_type="search_query",
api_base_url="https://custom-api-base-url.com",
truncate="START",
use_async_client=True,
Expand All @@ -91,6 +96,7 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "cohere_haystack.embedders.document_embedder.CohereDocumentEmbedder",
"init_parameters": {
"model_name": "embed-multilingual-v2.0",
"input_type": "search_query",
"api_base_url": "https://custom-api-base-url.com",
"truncate": "START",
"use_async_client": True,
Expand Down
6 changes: 6 additions & 0 deletions integrations/cohere/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_init_default(self):

assert embedder.api_key == "test-api-key"
assert embedder.model_name == "embed-english-v2.0"
assert embedder.input_type == "search_document"
assert embedder.api_base_url == COHERE_API_URL
assert embedder.truncate == "END"
assert embedder.use_async_client is False
Expand All @@ -33,6 +34,7 @@ def test_init_with_parameters(self):
embedder = CohereTextEmbedder(
api_key="test-api-key",
model_name="embed-multilingual-v2.0",
input_type="search_query",
api_base_url="https://custom-api-base-url.com",
truncate="START",
use_async_client=True,
Expand All @@ -41,6 +43,7 @@ def test_init_with_parameters(self):
)
assert embedder.api_key == "test-api-key"
assert embedder.model_name == "embed-multilingual-v2.0"
assert embedder.input_type == "search_query"
assert embedder.api_base_url == "https://custom-api-base-url.com"
assert embedder.truncate == "START"
assert embedder.use_async_client is True
Expand All @@ -57,6 +60,7 @@ def test_to_dict(self):
"type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder",
"init_parameters": {
"model_name": "embed-english-v2.0",
"input_type": "search_document",
"api_base_url": COHERE_API_URL,
"truncate": "END",
"use_async_client": False,
Expand All @@ -72,6 +76,7 @@ def test_to_dict_with_custom_init_parameters(self):
embedder_component = CohereTextEmbedder(
api_key="test-api-key",
model_name="embed-multilingual-v2.0",
input_type="search_query",
api_base_url="https://custom-api-base-url.com",
truncate="START",
use_async_client=True,
Expand All @@ -83,6 +88,7 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder",
"init_parameters": {
"model_name": "embed-multilingual-v2.0",
"input_type": "search_query",
"api_base_url": "https://custom-api-base-url.com",
"truncate": "START",
"use_async_client": True,
Expand Down

0 comments on commit 22b7704

Please sign in to comment.