Skip to content

Commit

Permalink
chore!: Rename model_name to model in the Cohere integration (#222)
Browse files Browse the repository at this point in the history
* rename model_name to model in cohere chat generator

* fix chat generator tests

* rename model_name to model in cohere generator

* fix generator tests

* rename model_name to model in cohere document embedder

* fix document embedder tests

* rename model_name to model in cohere text embedder

* fix text embedder tests

* black
  • Loading branch information
ZanSara authored Jan 17, 2024
1 parent 95effa1 commit bc243de
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 63 deletions.
16 changes: 8 additions & 8 deletions integrations/cohere/src/cohere_haystack/chat/chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class CohereChatGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "command",
model: str = "command",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
Expand All @@ -37,7 +37,7 @@ def __init__(
Initialize the CohereChatGenerator instance.
:param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var.
:param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly,
:param model: The name of the model to use. Available models are: [command, command-light, command-nightly,
command-nightly-light]. Defaults to "command".
:param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
:param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
Expand Down Expand Up @@ -82,7 +82,7 @@ def __init__(
if generation_kwargs is None:
generation_kwargs = {}
self.api_key = api_key
self.model_name = model_name
self.model = model
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.generation_kwargs = generation_kwargs
Expand All @@ -93,7 +93,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
return {"model": self.model_name}
return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
Expand All @@ -103,7 +103,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
Expand Down Expand Up @@ -147,7 +147,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
chat_history = [self._message_to_dict(m) for m in messages[:-1]]
response = self.client.chat(
message=messages[-1].content,
model=self.model_name,
model=self.model,
stream=self.streaming_callback is not None,
chat_history=chat_history,
**generation_kwargs,
Expand All @@ -160,7 +160,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
chat_message = ChatMessage.from_assistant(content=response.texts)
chat_message.meta.update(
{
"model": self.model_name,
"model": self.model,
"usage": response.token_count,
"index": 0,
"finish_reason": response.finish_reason,
Expand Down Expand Up @@ -193,7 +193,7 @@ def _build_message(self, cohere_response):
message = ChatMessage.from_assistant(content=content)
message.meta.update(
{
"model": self.model_name,
"model": self.model,
"usage": cohere_response.token_count,
"index": 0,
"finish_reason": None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class CohereDocumentEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
model: str = "embed-english-v2.0",
input_type: str = "search_document",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
Expand All @@ -53,7 +53,7 @@ def __init__(
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found in the
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(
raise ValueError(msg)

self.api_key = api_key
self.model_name = model_name
self.model = model
self.input_type = input_type
self.api_base_url = api_base_url
self.truncate = truncate
Expand All @@ -106,7 +106,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
input_type=self.input_type,
api_base_url=self.api_base_url,
truncate=self.truncate,
Expand Down Expand Up @@ -160,7 +160,7 @@ def run(self, documents: List[Document]):
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
all_embeddings, metadata = asyncio.run(
get_async_response(cohere_client, texts_to_embed, self.model_name, self.input_type, self.truncate)
get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate)
)
else:
cohere_client = Client(
Expand All @@ -169,7 +169,7 @@ def run(self, documents: List[Document]):
all_embeddings, metadata = get_response(
cohere_client,
texts_to_embed,
self.model_name,
self.model,
self.input_type,
self.truncate,
self.batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class CohereTextEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "embed-english-v2.0",
model: str = "embed-english-v2.0",
input_type: str = "search_query",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
Expand All @@ -47,7 +47,7 @@ def __init__(
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found in the
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
raise ValueError(msg)

self.api_key = api_key
self.model_name = model_name
self.model = model
self.input_type = input_type
self.api_base_url = api_base_url
self.truncate = truncate
Expand All @@ -91,7 +91,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
input_type=self.input_type,
api_base_url=self.api_base_url,
truncate=self.truncate,
Expand All @@ -117,12 +117,12 @@ def run(self, text: str):
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = asyncio.run(
get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)
get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)
embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate)

return {"embedding": embedding[0], "meta": metadata}
10 changes: 5 additions & 5 deletions integrations/cohere/src/cohere_haystack/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class CohereGenerator:
def __init__(
self,
api_key: Optional[str] = None,
model_name: str = "command",
model: str = "command",
streaming_callback: Optional[Callable] = None,
api_base_url: Optional[str] = None,
**kwargs,
Expand All @@ -41,7 +41,7 @@ def __init__(
Instantiates a `CohereGenerator` component.
:param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var.
:param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly,
:param model: The name of the model to use. Available models are: [command, command-light, command-nightly,
command-nightly-light]. Defaults to "command".
:param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
:param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(
api_base_url = COHERE_API_URL

self.api_key = api_key
self.model_name = model_name
self.model = model
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.model_parameters = kwargs
Expand All @@ -107,7 +107,7 @@ def to_dict(self) -> Dict[str, Any]:

return default_to_dict(
self,
model_name=self.model_name,
model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.model_parameters,
Expand Down Expand Up @@ -142,7 +142,7 @@ def run(self, prompt: str):
:param prompt: The prompt to be sent to the generative model.
"""
response = self.client.generate(
model=self.model_name, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
)
if self.streaming_callback:
metadata_dict: Dict[str, Any] = {}
Expand Down
26 changes: 12 additions & 14 deletions integrations/cohere/tests/test_cohere_chat_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class TestCohereChatGenerator:
def test_init_default(self):
component = CohereChatGenerator(api_key="test-api-key")
assert component.api_key == "test-api-key"
assert component.model_name == "command"
assert component.model == "command"
assert component.streaming_callback is None
assert component.api_base_url == cohere.COHERE_API_URL
assert not component.generation_kwargs
Expand All @@ -72,13 +72,13 @@ def test_init_fail_wo_api_key(self, monkeypatch):
def test_init_with_parameters(self):
component = CohereChatGenerator(
api_key="test-api-key",
model_name="command-nightly",
model="command-nightly",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
assert component.api_key == "test-api-key"
assert component.model_name == "command-nightly"
assert component.model == "command-nightly"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
Expand All @@ -90,7 +90,7 @@ def test_to_dict_default(self):
assert data == {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model_name": "command",
"model": "command",
"streaming_callback": None,
"api_base_url": "https://api.cohere.ai",
"generation_kwargs": {},
Expand All @@ -101,7 +101,7 @@ def test_to_dict_default(self):
def test_to_dict_with_parameters(self):
component = CohereChatGenerator(
api_key="test-api-key",
model_name="command-nightly",
model="command-nightly",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -110,7 +110,7 @@ def test_to_dict_with_parameters(self):
assert data == {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model_name": "command-nightly",
"model": "command-nightly",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"api_base_url": "test-base-url",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -121,7 +121,7 @@ def test_to_dict_with_parameters(self):
def test_to_dict_with_lambda_streaming_callback(self):
component = CohereChatGenerator(
api_key="test-api-key",
model_name="command",
model="command",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -130,7 +130,7 @@ def test_to_dict_with_lambda_streaming_callback(self):
assert data == {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model_name": "command",
"model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "tests.test_cohere_chat_generator.<lambda>",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
Expand All @@ -143,14 +143,14 @@ def test_from_dict(self, monkeypatch):
data = {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model_name": "command",
"model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = CohereChatGenerator.from_dict(data)
assert component.model_name == "command"
assert component.model == "command"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
Expand All @@ -161,7 +161,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
data = {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
"model_name": "command",
"model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
Expand Down Expand Up @@ -260,9 +260,7 @@ def test_live_run(self):
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
component = CohereChatGenerator(
model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")
)
component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY"))
with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"):
component.run(chat_messages)

Expand Down
Loading

0 comments on commit bc243de

Please sign in to comment.