Skip to content

Commit

Permalink
Change default 'input_type' for CohereTextEmbedder (#99)
Browse files Browse the repository at this point in the history
* Change default 'input_type' for CohereTextEmbedder

* Update tests after updating the default value of input_type
  • Loading branch information
bilgeyucel authored Dec 13, 2023
1 parent 8454bda commit d45eca2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
input_type: str = "search_document",
input_type: str = "search_query",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
use_async_client: bool = False,
Expand Down
12 changes: 6 additions & 6 deletions integrations/cohere/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def test_init_default(self):

assert embedder.api_key == "test-api-key"
assert embedder.model_name == "embed-english-v2.0"
assert embedder.input_type == "search_document"
assert embedder.input_type == "search_query"
assert embedder.api_base_url == COHERE_API_URL
assert embedder.truncate == "END"
assert embedder.use_async_client is False
Expand All @@ -34,7 +34,7 @@ def test_init_with_parameters(self):
embedder = CohereTextEmbedder(
api_key="test-api-key",
model_name="embed-multilingual-v2.0",
input_type="search_query",
input_type="classification",
api_base_url="https://custom-api-base-url.com",
truncate="START",
use_async_client=True,
Expand All @@ -43,7 +43,7 @@ def test_init_with_parameters(self):
)
assert embedder.api_key == "test-api-key"
assert embedder.model_name == "embed-multilingual-v2.0"
assert embedder.input_type == "search_query"
assert embedder.input_type == "classification"
assert embedder.api_base_url == "https://custom-api-base-url.com"
assert embedder.truncate == "START"
assert embedder.use_async_client is True
Expand All @@ -60,7 +60,7 @@ def test_to_dict(self):
"type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder",
"init_parameters": {
"model_name": "embed-english-v2.0",
"input_type": "search_document",
"input_type": "search_query",
"api_base_url": COHERE_API_URL,
"truncate": "END",
"use_async_client": False,
Expand All @@ -76,7 +76,7 @@ def test_to_dict_with_custom_init_parameters(self):
embedder_component = CohereTextEmbedder(
api_key="test-api-key",
model_name="embed-multilingual-v2.0",
input_type="search_query",
input_type="classification",
api_base_url="https://custom-api-base-url.com",
truncate="START",
use_async_client=True,
Expand All @@ -88,7 +88,7 @@ def test_to_dict_with_custom_init_parameters(self):
"type": "cohere_haystack.embedders.text_embedder.CohereTextEmbedder",
"init_parameters": {
"model_name": "embed-multilingual-v2.0",
"input_type": "search_query",
"input_type": "classification",
"api_base_url": "https://custom-api-base-url.com",
"truncate": "START",
"use_async_client": True,
Expand Down

0 comments on commit d45eca2

Please sign in to comment.