diff --git a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py index 4ba8acd47..bfef97dc3 100644 --- a/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py +++ b/integrations/cohere/src/cohere_haystack/embedders/text_embedder.py @@ -34,7 +34,7 @@ class CohereTextEmbedder: def __init__( self, api_key: Optional[str] = None, - model_name: str = "embed-english-v2.0", + model: str = "embed-english-v2.0", input_type: str = "search_query", api_base_url: str = COHERE_API_URL, truncate: str = "END", @@ -47,7 +47,7 @@ def __init__( :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment variable COHERE_API_KEY (recommended). - :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: + :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, `"embed-multilingual-v2.0"`. This list of all supported models can be found in the @@ -77,7 +77,7 @@ def __init__( raise ValueError(msg) self.api_key = api_key - self.model_name = model_name + self.model = model self.input_type = input_type self.api_base_url = api_base_url self.truncate = truncate @@ -91,7 +91,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, - model_name=self.model_name, + model=self.model, input_type=self.input_type, api_base_url=self.api_base_url, truncate=self.truncate, @@ -117,12 +117,12 @@ def run(self, text: str): self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) embedding, metadata = asyncio.run( - get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) + get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate) ) else: cohere_client = Client( self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout ) - embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate) + embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) return {"embedding": embedding[0], "meta": metadata}