Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cohere] Add text and document embedders #80

Merged
merged 11 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions integrations/cohere/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,56 @@

**Table of Contents**

- [Installation](#installation)
- [License](#license)
- [cohere-haystack](#cohere-haystack)
- [Installation](#installation)
- [Contributing](#contributing)
- [License](#license)

## Installation

```console
pip install cohere-haystack
```

## Contributing

`hatch` is the best way to interact with this project, to install it:
```sh
pip install hatch
```

With `hatch` installed, to run all the tests:
```
hatch run test
```
> Note: integration tests will be skipped unless the env var COHERE_API_KEY is set. The api key needs to be valid
> in order to pass the tests.

To only run unit tests:
```
hatch run test -m"not integration"
```

To only run embedders tests:
```
hatch run test -m"embedders"
```

To only run generators tests:
```
hatch run test -m"generators"
```

Markers can be combined, for example you can run only integration tests for embedders with:
```
hatch run test -m"integrations and embedders"
```

To run the linters `ruff` and `mypy`:
```
hatch run lint:all
```

## License

`cohere-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license.
17 changes: 12 additions & 5 deletions integrations/cohere/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ select = [
"E",
"EM",
"F",
"FBT",
"I",
"ICN",
"ISC",
Expand All @@ -118,8 +117,6 @@ select = [
ignore = [
# Allow non-abstract empty methods in abstract base classes
"B027",
# Allow boolean positional values in function calls, like `dict.get(... True)`
"FBT003",
# Ignore checks for possible passwords
"S105", "S106", "S107",
# Ignore complexity
Expand Down Expand Up @@ -163,6 +160,16 @@ exclude_lines = [
module = [
"cohere.*",
"haystack.*",
"pytest.*"
"pytest.*",
"numpy.*",
]
ignore_missing_imports = true
ignore_missing_imports = true

[tool.pytest.ini_options]
addopts = "--strict-markers"
markers = [
"integration: integration tests",
"embedders: embedders tests",
"generators: generators tests",
]
log_cli = true
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from typing import Any, Dict, List, Optional

from cohere import COHERE_API_URL, AsyncClient, Client
from haystack import Document, component, default_to_dict

from cohere_haystack.embedders.utils import get_async_response, get_response


@component
class CohereDocumentEmbedder:
"""
A component for computing Document embeddings using Cohere models.
The embedding of each Document is stored in the `embedding` field of the Document.
"""

def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
use_async_client: bool = False,
max_retries: int = 3,
timeout: int = 120,
batch_size: int = 32,
progress_bar: bool = True,
metadata_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Create a CohereDocumentEmbedder component.

:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`,
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use
AsyncClient for applications with many concurrent calls.
:param max_retries: maximal number of retries for requests, defaults to `3`.
:param timeout: request timeout in seconds, defaults to `120`.
:param batch_size: Number of Documents to encode at once.
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
to keep the logs clean.
:param metadata_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""

if api_key is None:
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
msg = (
"CohereDocumentEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg

self.api_key = api_key
self.model_name = model_name
self.api_base_url = api_base_url
self.truncate = truncate
self.use_async_client = use_async_client
self.max_retries = max_retries
self.timeout = timeout
self.batch_size = batch_size
self.progress_bar = progress_bar
self.metadata_fields_to_embed = metadata_fields_to_embed or []
self.embedding_separator = embedding_separator

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary omitting the api_key field.
"""
return default_to_dict(
self,
model_name=self.model_name,
api_base_url=self.api_base_url,
truncate=self.truncate,
use_async_client=self.use_async_client,
max_retries=self.max_retries,
timeout=self.timeout,
batch_size=self.batch_size,
progress_bar=self.progress_bar,
metadata_fields_to_embed=self.metadata_fields_to_embed,
embedding_separator=self.embedding_separator,
)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
"""
texts_to_embed: List[str] = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.metadata_fields_to_embed if doc.meta.get(key) is not None
]

text_to_embed = self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) # noqa: RUF005
texts_to_embed.append(text_to_embed)
return texts_to_embed

@component.output_types(documents=List[Document], metadata=Dict[str, Any])
def run(self, documents: List[Document]):
"""
Embed a list of Documents.
The embedding of each Document is stored in the `embedding` field of the Document.

:param documents: A list of Documents to embed.
"""

if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
msg = (
"CohereDocumentEmbedder expects a list of Documents as input."
"In case you want to embed a string, please use the CohereTextEmbedder."
)
raise TypeError(msg)

if not documents:
# return early if we were passed an empty list
return {"documents": [], "metadata": {}}

texts_to_embed = self._prepare_texts_to_embed(documents)

if self.use_async_client:
cohere_client = AsyncClient(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
all_embeddings, metadata = asyncio.run(
get_async_response(cohere_client, texts_to_embed, self.model_name, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
all_embeddings, metadata = get_response(
cohere_client, texts_to_embed, self.model_name, self.truncate, self.batch_size, self.progress_bar
)

for doc, embeddings in zip(documents, all_embeddings):
doc.embedding = embeddings

return {"documents": documents, "metadata": metadata}
104 changes: 104 additions & 0 deletions integrations/cohere/src/cohere_haystack/embedders/text_embedder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from typing import Any, Dict, List, Optional

from cohere import COHERE_API_URL, AsyncClient, Client
from haystack import component, default_to_dict

from cohere_haystack.embedders.utils import get_async_response, get_response


@component
class CohereTextEmbedder:
"""
A component for embedding strings using Cohere models.
"""

def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
use_async_client: bool = False,
max_retries: int = 3,
timeout: int = 120,
):
"""
Create a CohereTextEmbedder component.

:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are
`"embed-english-v2.0"`/ `"large"`, `"embed-english-light-v2.0"`/ `"small"`,
`"embed-multilingual-v2.0"`/ `"multilingual-22-12"`.
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use
AsyncClient for applications with many concurrent calls.
:param max_retries: Maximum number of retries for requests, defaults to `3`.
:param timeout: Request timeout in seconds, defaults to `120`.
"""

if api_key is None:
try:
api_key = os.environ["COHERE_API_KEY"]
except KeyError as error_msg:
msg = (
"CohereTextEmbedder expects an Cohere API key. Please provide one by setting the environment "
"variable COHERE_API_KEY (recommended) or by passing it explicitly."
)
raise ValueError(msg) from error_msg

self.api_key = api_key
self.model_name = model_name
self.api_base_url = api_base_url
self.truncate = truncate
self.use_async_client = use_async_client
self.max_retries = max_retries
self.timeout = timeout

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary omitting the api_key field.
"""
return default_to_dict(
self,
model_name=self.model_name,
api_base_url=self.api_base_url,
truncate=self.truncate,
use_async_client=self.use_async_client,
max_retries=self.max_retries,
timeout=self.timeout,
)

@component.output_types(embedding=List[float], metadata=Dict[str, Any])
def run(self, text: str):
"""Embed a string."""
if not isinstance(text, str):
msg = (
"CohereTextEmbedder expects a string as input."
"In case you want to embed a list of Documents, please use the CohereDocumentEmbedder."
)
raise TypeError(msg)

# Establish connection to API

if self.use_async_client:
cohere_client = AsyncClient(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = asyncio.run(get_async_response(cohere_client, [text], self.model_name, self.truncate))
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = get_response(cohere_client, [text], self.model_name, self.truncate)

return {"embedding": embedding[0], "metadata": metadata}
Loading