Skip to content

Commit

Permalink
rename model_name to model in cohere text embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
ZanSara committed Jan 16, 2024
1 parent 346fa1d commit a0ce770
Showing 1 changed file with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CohereTextEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
model: str = "embed-english-v2.0",
input_type: str = "search_query",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
Expand All @@ -47,7 +47,7 @@ def __init__(
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found in the
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
raise ValueError(msg)

self.api_key = api_key
self.model_name = model_name
self.model = model
self.input_type = input_type
self.api_base_url = api_base_url
self.truncate = truncate
Expand All @@ -91,7 +91,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
input_type=self.input_type,
api_base_url=self.api_base_url,
truncate=self.truncate,
Expand All @@ -117,12 +117,12 @@ def run(self, text: str):
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = asyncio.run(
get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)
get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)
embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate)

return {"embedding": embedding[0], "meta": metadata}

0 comments on commit a0ce770

Please sign in to comment.