Skip to content

Commit

Permalink
docs: review cohere integration (#500)
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge authored Feb 29, 2024
1 parent 65715d6 commit e269dd8
Show file tree
Hide file tree
Showing 5 changed files with 155 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
class CohereDocumentEmbedder:
"""
A component for computing Document embeddings using Cohere models.
The embedding of each Document is stored in the `embedding` field of the Document.
Usage Example:
Usage example:
```python
from haystack import Document
from cohere_haystack.embedders.document_embedder import CohereDocumentEmbedder
Expand Down Expand Up @@ -49,32 +50,30 @@ def __init__(
embedding_separator: str = "\n",
):
"""
Create a CohereDocumentEmbedder component.
:param api_key: The Cohere API key.
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
:param api_key: the Cohere API key.
:param model: the name of the model to use. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found in the
[model documentation](https://docs.cohere.com/docs/models#representation).
:param input_type: Specifies the type of input you're giving to the model. Supported values are
"search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not
required for older versions of the embedding models (meaning anything lower than v3), but is required for more
recent versions (meaning anything bigger than v2).
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
:param input_type: specifies the type of input you're giving to the model. Supported values are
"search_document", "search_query", "classification" and "clustering". Not
required for older versions of the embedding models (meaning anything lower than v3), but is required for more
recent versions (meaning anything bigger than v2).
:param api_base_url: the Cohere API Base url.
:param truncate: how to truncate inputs that are too long, from the start or the end ("NONE"|"START"|"END").
Passing "START" will discard the start of the input. "END" will discard the end of the input. In both
cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use
If "NONE" is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: flag to select the AsyncClient. It is recommended to use
AsyncClient for applications with many concurrent calls.
:param max_retries: maximal number of retries for requests, defaults to `3`.
:param timeout: request timeout in seconds, defaults to `120`.
:param batch_size: Number of Documents to encode at once.
:param progress_bar: Whether to show a progress bar or not. Can be helpful to disable in production deployments
:param max_retries: maximum number of retries for requests.
:param timeout: request timeout in seconds.
:param batch_size: number of Documents to encode at once.
:param progress_bar: whether to show a progress bar or not. Can be helpful to disable in production deployments
to keep the logs clean.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
:param meta_fields_to_embed: list of meta fields that should be embedded along with the Document text.
:param embedding_separator: separator used to concatenate the meta fields to the Document text.
"""

self.api_key = api_key
Expand All @@ -92,7 +91,10 @@ def __init__(

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary omitting the api_key field.
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
Expand All @@ -113,9 +115,12 @@ def to_dict(self) -> Dict[str, Any]:
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])
Expand All @@ -137,13 +142,14 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:

@component.output_types(documents=List[Document], meta=Dict[str, Any])
def run(self, documents: List[Document]):
"""
Embed a list of Documents.
The embedding of each Document is stored in the `embedding` field of the Document.
"""Embed a list of `Documents`.
:param documents: A list of Documents to embed.
:param documents: documents to embed.
:returns: A dictionary with the following keys:
- `documents`: documents with the `embedding` field set.
- `meta`: metadata about the embedding process.
:raises TypeError: if the input is not a list of `Documents`.
"""

if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
msg = (
"CohereDocumentEmbedder expects a list of Documents as input."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ class CohereTextEmbedder:
"""
A component for embedding strings using Cohere models.
Usage Example:
Usage example:
```python
from cohere_haystack.embedders.text_embedder import CohereTextEmbedder
from haystack_integrations.components.embedders.cohere import CohereTextEmbedder
text_to_embed = "I love pizza!"
Expand All @@ -43,27 +43,25 @@ def __init__(
timeout: int = 120,
):
"""
Create a CohereTextEmbedder component.
:param api_key: The Cohere API key.
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
:param api_key: the Cohere API key.
:param model: the name of the model to use. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found in the
[model documentation](https://docs.cohere.com/docs/models#representation).
:param input_type: Specifies the type of input you're giving to the model. Supported values are
"search_document", "search_query", "classification" and "clustering". Defaults to "search_document". Not
required for older versions of the embedding models (meaning anything lower than v3), but is required for more
recent versions (meaning anything bigger than v2).
:param api_base_url: The Cohere API Base url, defaults to `https://api.cohere.ai/v1/embed`.
:param truncate: Truncate embeddings that are too long from start or end, ("NONE"|"START"|"END"), defaults to
`"END"`. Passing START will discard the start of the input. END will discard the end of the input. In both
:param input_type: specifies the type of input you're giving to the model. Supported values are
"search_document", "search_query", "classification" and "clustering". Not
required for older versions of the embedding models (meaning anything lower than v3), but is required for more
recent versions (meaning anything bigger than v2).
:param api_base_url: the Cohere API Base url.
:param truncate: how to truncate inputs that are too long, from the start or the end ("NONE"|"START"|"END"), defaults to
`"END"`. Passing "START" will discard the start of the input. "END" will discard the end of the input. In both
cases, input is discarded until the remaining input is exactly the maximum input token length for the model.
If NONE is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: Flag to select the AsyncClient, defaults to `False`. It is recommended to use
If "NONE" is selected, when the input exceeds the maximum input token length an error will be returned.
:param use_async_client: flag to select the AsyncClient. It is recommended to use
AsyncClient for applications with many concurrent calls.
:param max_retries: Maximum number of retries for requests, defaults to `3`.
:param timeout: Request timeout in seconds, defaults to `120`.
:param max_retries: maximum number of retries for requests.
:param timeout: request timeout in seconds.
"""

self.api_key = api_key
Expand All @@ -77,7 +75,10 @@ def __init__(

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary omitting the api_key field.
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
Expand All @@ -94,17 +95,27 @@ def to_dict(self) -> Dict[str, Any]:
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])
return default_from_dict(cls, data)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
"""Embed a string."""
"""Embed text.
:param text: the text to embed.
:returns: A dictionary with the following keys:
- "embedding": the embedding of the text.
- "meta": metadata about the request.
:raises TypeError: If the input is not a string.
"""
if not isinstance(text, str):
msg = (
"CohereTextEmbedder expects a string as input."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@


async def get_async_response(cohere_async_client: AsyncClient, texts: List[str], model_name, input_type, truncate):
"""Embeds a list of texts asynchronously using the Cohere API.
:param cohere_async_client: the Cohere `AsyncClient`
:param texts: the texts to embed
:param model_name: the name of the model to use
:param input_type: one of "classification", "clustering", "search_document", "search_query".
The type of input text provided to embed.
:param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length.
:returns: A tuple of the embeddings and metadata.
:raises ValueError: If an error occurs while querying the Cohere API.
"""
all_embeddings: List[List[float]] = []
metadata: Dict[str, Any] = {}
try:
Expand All @@ -30,9 +43,22 @@ async def get_async_response(cohere_async_client: AsyncClient, texts: List[str],
def get_response(
cohere_client: Client, texts: List[str], model_name, input_type, truncate, batch_size=32, progress_bar=False
) -> Tuple[List[List[float]], Dict[str, Any]]:
"""Embeds a list of texts using the Cohere API.
:param cohere_client: the Cohere `Client`
:param texts: the texts to embed
:param model_name: the name of the model to use
:param input_type: one of "classification", "clustering", "search_document", "search_query".
The type of input text provided to embed.
:param truncate: one of "NONE", "START", "END". How the API handles text longer than the maximum token length.
:param batch_size: the batch size to use
:param progress_bar: if `True`, show a progress bar
:returns: A tuple of the embeddings and metadata.
:raises ValueError: If an error occurs while querying the Cohere API.
"""
We support batching with the sync client.
"""

all_embeddings: List[List[float]] = []
metadata: Dict[str, Any] = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,27 @@

@component
class CohereChatGenerator:
"""Enables text generation using Cohere's chat endpoint. This component is designed to inference
Cohere's chat models.
"""
Enables text generation using Cohere's chat endpoint.
This component is designed to run inference with Cohere's chat models.
Users can pass any text generation parameters valid for the `cohere.Client.chat` method
directly to this component via the `generation_kwargs` parameter in `__init__` or the `generation_kwargs`
parameter of the `run` method.
Invocations are made using the 'cohere' package.
See [Cohere API](https://docs.cohere.com/reference/chat) for more details.
Example usage:
```python
from haystack_integrations.components.generators.cohere import CohereChatGenerator
component = CohereChatGenerator(api_key=Secret.from_token("test-api-key"))
response = component.run(chat_messages)
assert response["replies"]
```
"""

def __init__(
Expand All @@ -37,12 +49,12 @@ def __init__(
"""
Initialize the CohereChatGenerator instance.
:param api_key: The API key for the Cohere API.
:param api_key: the API key for the Cohere API.
:param model: The name of the model to use. Available models are: [command, command-light, command-nightly,
command-nightly-light]. Defaults to "command".
:param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
:param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
:param generation_kwargs: Additional model parameters. These will be used during generation. Refer to
command-nightly-light].
:param streaming_callback: a callback function to be called with the streaming response.
:param api_base_url: the base URL of the Cohere API.
:param generation_kwargs: additional model parameters. These will be used during generation. Refer to
https://docs.cohere.com/reference/chat for more details.
Some of the parameters are:
- 'chat_history': A list of previous messages between the user and the model, meant to give the model
Expand Down Expand Up @@ -89,8 +101,10 @@ def _get_telemetry_data(self) -> Dict[str, Any]:

def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:return: The serialized component as a dictionary.
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
Expand All @@ -105,9 +119,12 @@ def to_dict(self) -> Dict[str, Any]:
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
Deserializes the component from a dictionary.
:param data:
Dictionary to deserialize from.
:returns:
Deserialized component.
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])
Expand All @@ -126,12 +143,13 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
"""
Invoke the text generation inference based on the provided messages and generation parameters.
:param messages: A list of ChatMessage instances representing the input messages.
:param generation_kwargs: Additional keyword arguments for text generation. These parameters will
potentially override the parameters passed in the __init__ method.
For more details on the parameters supported by the Cohere API, refer to the
Cohere [documentation](https://docs.cohere.com/reference/chat).
:return: A list containing the generated responses as ChatMessage instances.
:param messages: list of `ChatMessage` instances representing the input messages.
:param generation_kwargs: additional keyword arguments for text generation. These parameters will
potentially override the parameters passed in the __init__ method.
For more details on the parameters supported by the Cohere API, refer to the
Cohere [documentation](https://docs.cohere.com/reference/chat).
:returns: A dictionary with the following keys:
- "replies": a list of `ChatMessage` instances representing the generated responses.
"""
# update generation kwargs by merging with the generation kwargs passed to the run method
generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
Expand Down
Loading

0 comments on commit e269dd8

Please sign in to comment.