chore!: Rename model_name to model in the Cohere integration #222

Merged
merged 9 commits on Jan 17, 2024
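This is a breaking change: the model_name init parameter becomes model in CohereGenerator, CohereChatGenerator, CohereTextEmbedder, and CohereDocumentEmbedder, and the key written by to_dict and read by from_dict changes with it. A minimal migration sketch for caller code (hypothetical snippet, not part of the PR; it assumes the integration package is installed and a valid Cohere API key):

from cohere_haystack.generator import CohereGenerator

# Before this PR: CohereGenerator(api_key="...", model_name="command")
# After this PR the same argument is named "model"; "command" is still the default.
generator = CohereGenerator(api_key="your-cohere-api-key", model="command")
result = generator.run(prompt="Briefly describe the Eiffel Tower.")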
16 changes: 8 additions & 8 deletions integrations/cohere/src/cohere_haystack/chat/chat_generator.py
@@ -27,7 +27,7 @@ class CohereChatGenerator:
def __init__(
self,
api_key: Optional[str] = None,
- model_name: str = "command",
+ model: str = "command",
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
api_base_url: Optional[str] = None,
generation_kwargs: Optional[Dict[str, Any]] = None,
@@ -37,7 +37,7 @@ def __init__(
Initialize the CohereChatGenerator instance.

:param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var.
- :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly,
+ :param model: The name of the model to use. Available models are: [command, command-light, command-nightly,
command-nightly-light]. Defaults to "command".
:param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
:param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
@@ -82,7 +82,7 @@ def __init__(
if generation_kwargs is None:
generation_kwargs = {}
self.api_key = api_key
- self.model_name = model_name
+ self.model = model
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.generation_kwargs = generation_kwargs
@@ -93,7 +93,7 @@ def _get_telemetry_data(self) -> Dict[str, Any]:
"""
Data that is sent to Posthog for usage analytics.
"""
- return {"model": self.model_name}
+ return {"model": self.model}

def to_dict(self) -> Dict[str, Any]:
"""
@@ -103,7 +103,7 @@ def to_dict(self) -> Dict[str, Any]:
callback_name = serialize_callback_handler(self.streaming_callback) if self.streaming_callback else None
return default_to_dict(
self,
- model_name=self.model_name,
+ model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
generation_kwargs=self.generation_kwargs,
@@ -147,7 +147,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
chat_history = [self._message_to_dict(m) for m in messages[:-1]]
response = self.client.chat(
message=messages[-1].content,
- model=self.model_name,
+ model=self.model,
stream=self.streaming_callback is not None,
chat_history=chat_history,
**generation_kwargs,
@@ -160,7 +160,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
chat_message = ChatMessage.from_assistant(content=response.texts)
chat_message.meta.update(
{
- "model": self.model_name,
+ "model": self.model,
"usage": response.token_count,
"index": 0,
"finish_reason": response.finish_reason,
@@ -193,7 +193,7 @@ def _build_message(self, cohere_response):
message = ChatMessage.from_assistant(content=content)
message.meta.update(
{
- "model": self.model_name,
+ "model": self.model,
"usage": cohere_response.token_count,
"index": 0,
"finish_reason": None,
@@ -36,7 +36,7 @@ class CohereDocumentEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
- model_name: str = "embed-english-v2.0",
+ model: str = "embed-english-v2.0",
input_type: str = "search_document",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
@@ -53,7 +53,7 @@ def __init__(

:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
- :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
+ :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found in the
@@ -88,7 +88,7 @@ def __init__(
raise ValueError(msg)

self.api_key = api_key
- self.model_name = model_name
+ self.model = model
self.input_type = input_type
self.api_base_url = api_base_url
self.truncate = truncate
@@ -106,7 +106,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
- model_name=self.model_name,
+ model=self.model,
input_type=self.input_type,
api_base_url=self.api_base_url,
truncate=self.truncate,
@@ -160,7 +160,7 @@ def run(self, documents: List[Document]):
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
all_embeddings, metadata = asyncio.run(
- get_async_response(cohere_client, texts_to_embed, self.model_name, self.input_type, self.truncate)
+ get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate)
)
else:
cohere_client = Client(
@@ -169,7 +169,7 @@ def run(self, documents: List[Document]):
all_embeddings, metadata = get_response(
cohere_client,
texts_to_embed,
- self.model_name,
+ self.model,
self.input_type,
self.truncate,
self.batch_size,
@@ -34,7 +34,7 @@ class CohereTextEmbedder:
def __init__(
self,
api_key: Optional[str] = None,
- model_name: str = "embed-english-v2.0",
+ model: str = "embed-english-v2.0",
input_type: str = "search_query",
api_base_url: str = COHERE_API_URL,
truncate: str = "END",
@@ -47,7 +47,7 @@ def __init__(

:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
- :param model_name: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
+ :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
`"embed-multilingual-v2.0"`. This list of all supported models can be found in the
@@ -77,7 +77,7 @@ def __init__(
raise ValueError(msg)

self.api_key = api_key
- self.model_name = model_name
+ self.model = model
self.input_type = input_type
self.api_base_url = api_base_url
self.truncate = truncate
@@ -91,7 +91,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
- model_name=self.model_name,
+ model=self.model,
input_type=self.input_type,
api_base_url=self.api_base_url,
truncate=self.truncate,
@@ -117,12 +117,12 @@ def run(self, text: str):
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
embedding, metadata = asyncio.run(
- get_async_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)
+ get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
)
- embedding, metadata = get_response(cohere_client, [text], self.model_name, self.input_type, self.truncate)
+ embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate)

return {"embedding": embedding[0], "meta": metadata}
10 changes: 5 additions & 5 deletions integrations/cohere/src/cohere_haystack/generator.py
@@ -32,7 +32,7 @@ class CohereGenerator:
def __init__(
self,
api_key: Optional[str] = None,
- model_name: str = "command",
+ model: str = "command",
streaming_callback: Optional[Callable] = None,
api_base_url: Optional[str] = None,
**kwargs,
@@ -41,7 +41,7 @@ def __init__(
Instantiates a `CohereGenerator` component.

:param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var.
- :param model_name: The name of the model to use. Available models are: [command, command-light, command-nightly,
+ :param model: The name of the model to use. Available models are: [command, command-light, command-nightly,
command-nightly-light]. Defaults to "command".
:param streaming_callback: A callback function to be called with the streaming response. Defaults to None.
:param api_base_url: The base URL of the Cohere API. Defaults to "https://api.cohere.ai".
@@ -86,7 +86,7 @@ def __init__(
api_base_url = COHERE_API_URL

self.api_key = api_key
- self.model_name = model_name
+ self.model = model
self.streaming_callback = streaming_callback
self.api_base_url = api_base_url
self.model_parameters = kwargs
@@ -107,7 +107,7 @@ def to_dict(self) -> Dict[str, Any]:

return default_to_dict(
self,
- model_name=self.model_name,
+ model=self.model,
streaming_callback=callback_name,
api_base_url=self.api_base_url,
**self.model_parameters,
@@ -142,7 +142,7 @@ def run(self, prompt: str):
:param prompt: The prompt to be sent to the generative model.
"""
response = self.client.generate(
- model=self.model_name, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
+ model=self.model, prompt=prompt, stream=self.streaming_callback is not None, **self.model_parameters
)
if self.streaming_callback:
metadata_dict: Dict[str, Any] = {}
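The serialized format changes the same way: to_dict now writes a model key into init_parameters and from_dict expects it, so pipelines exported with the old model_name key need to be regenerated. A short round-trip sketch mirroring the updated tests below (hypothetical values; it assumes COHERE_API_KEY is set in the environment, which from_dict relies on):

from cohere_haystack.chat.chat_generator import CohereChatGenerator

generator = CohereChatGenerator(api_key="test-api-key", model="command-nightly")
data = generator.to_dict()
# The API key is not serialized; init_parameters now carries "model" instead of "model_name".
assert data["init_parameters"]["model"] == "command-nightly"
restored = CohereChatGenerator.from_dict(data)
assert restored.model == "command-nightly"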
26 changes: 12 additions & 14 deletions integrations/cohere/tests/test_cohere_chat_generator.py
@@ -57,7 +57,7 @@ class TestCohereChatGenerator:
def test_init_default(self):
component = CohereChatGenerator(api_key="test-api-key")
assert component.api_key == "test-api-key"
- assert component.model_name == "command"
+ assert component.model == "command"
assert component.streaming_callback is None
assert component.api_base_url == cohere.COHERE_API_URL
assert not component.generation_kwargs
@@ -72,13 +72,13 @@ def test_init_fail_wo_api_key(self, monkeypatch):
def test_init_with_parameters(self):
component = CohereChatGenerator(
api_key="test-api-key",
- model_name="command-nightly",
+ model="command-nightly",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
)
assert component.api_key == "test-api-key"
- assert component.model_name == "command-nightly"
+ assert component.model == "command-nightly"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@@ -90,7 +90,7 @@ def test_to_dict_default(self):
assert data == {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
- "model_name": "command",
+ "model": "command",
"streaming_callback": None,
"api_base_url": "https://api.cohere.ai",
"generation_kwargs": {},
@@ -101,7 +101,7 @@ def test_to_dict_with_parameters(self):
def test_to_dict_with_parameters(self):
component = CohereChatGenerator(
api_key="test-api-key",
- model_name="command-nightly",
+ model="command-nightly",
streaming_callback=default_streaming_callback,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@@ -110,7 +110,7 @@ def test_to_dict_with_parameters(self):
assert data == {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
- "model_name": "command-nightly",
+ "model": "command-nightly",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"api_base_url": "test-base-url",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -121,7 +121,7 @@ def test_to_dict_with_lambda_streaming_callback(self):
def test_to_dict_with_lambda_streaming_callback(self):
component = CohereChatGenerator(
api_key="test-api-key",
- model_name="command",
+ model="command",
streaming_callback=lambda x: x,
api_base_url="test-base-url",
generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
@@ -130,7 +130,7 @@ def test_to_dict_with_lambda_streaming_callback(self):
assert data == {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
- "model_name": "command",
+ "model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "tests.test_cohere_chat_generator.<lambda>",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -143,14 +143,14 @@ def test_from_dict(self, monkeypatch):
data = {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
- "model_name": "command",
+ "model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
},
}
component = CohereChatGenerator.from_dict(data)
- assert component.model_name == "command"
+ assert component.model == "command"
assert component.streaming_callback is default_streaming_callback
assert component.api_base_url == "test-base-url"
assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
@@ -161,7 +161,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
data = {
"type": "cohere_haystack.chat.chat_generator.CohereChatGenerator",
"init_parameters": {
- "model_name": "command",
+ "model": "command",
"api_base_url": "test-base-url",
"streaming_callback": "haystack.components.generators.utils.default_streaming_callback",
"generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
@@ -260,9 +260,7 @@ def test_live_run(self):
)
@pytest.mark.integration
def test_live_run_wrong_model(self, chat_messages):
- component = CohereChatGenerator(
-     model_name="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")
- )
+ component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY"))
with pytest.raises(cohere.CohereAPIError, match="finetuned model something-obviously-wrong is not valid"):
component.run(chat_messages)
