From bca722616d44e223c1dc37ff1cace92283d89e00 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Fri, 9 Feb 2024 17:23:59 +0100 Subject: [PATCH] refactor!: Use `Secret` for API keys in Cohere components (#386) --- integrations/cohere/pyproject.toml | 77 +++++++----------- .../embedders/cohere/document_embedder.py | 43 ++++++---- .../embedders/cohere/text_embedder.py | 45 +++++++---- .../generators/cohere/chat/chat_generator.py | 19 ++--- .../components/generators/cohere/generator.py | 19 ++--- .../tests/test_cohere_chat_generator.py | 79 +++++++++++-------- .../cohere/tests/test_cohere_generators.py | 47 ++++++----- .../cohere/tests/test_document_embedder.py | 21 ++--- .../cohere/tests/test_text_embedder.py | 21 ++--- 9 files changed, 191 insertions(+), 180 deletions(-) diff --git a/integrations/cohere/pyproject.toml b/integrations/cohere/pyproject.toml index 332471674..4b612aca5 100644 --- a/integrations/cohere/pyproject.toml +++ b/integrations/cohere/pyproject.toml @@ -10,9 +10,7 @@ readme = "README.md" requires-python = ">=3.7" license = "Apache-2.0" keywords = [] -authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, -] +authors = [{ name = "deepset GmbH", email = "info@deepset.ai" }] classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python", @@ -24,10 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = [ - "haystack-ai", - "cohere", -] +dependencies = ["haystack-ai>=2.0.0b6", "cohere"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme" @@ -46,51 +41,25 @@ root = "../.." 
git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]*"' [tool.hatch.envs.default] -dependencies = [ - "coverage[toml]>=6.5", - "pytest", - "haystack-pydoc-tools", -] +dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"] [tool.hatch.envs.default.scripts] test = "pytest {args:tests}" test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] -docs = [ - "pydoc-markdown pydoc/config.yml" -] +cov-report = ["- coverage combine", "coverage report"] +cov = ["test-cov", "cov-report"] +docs = ["pydoc-markdown pydoc/config.yml"] [[tool.hatch.envs.all.matrix]] python = ["3.7", "3.8", "3.9", "3.10", "3.11"] [tool.hatch.envs.lint] detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] +dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"] [tool.hatch.envs.lint.scripts] typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +style = ["ruff {args:.}", "black --check --diff {args:.}"] +fmt = ["black {args:.}", "ruff --fix {args:.}", "style"] +all = ["style", "typing"] [tool.black] target-version = ["py37"] @@ -130,9 +99,18 @@ ignore = [ # Allow non-abstract empty methods in abstract base classes "B027", # Ignore checks for possible passwords - "S105", "S106", "S107", + "S105", + "S106", + "S107", # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", + "C901", + "PLR0911", + "PLR0912", + "PLR0913", + "PLR0915", + # Misc + "B008", + "S101", ] unfixable = [ # Don't touch unused imports @@ -155,15 +133,14 @@ branch = true parallel = true [tool.coverage.paths] -cohere_haystack = ["src/haystack_integrations", "*/cohere/src/haystack_integrations"] +cohere_haystack = [ + 
"src/haystack_integrations", + "*/cohere/src/haystack_integrations", +] tests = ["tests", "*/cohere/tests"] [tool.coverage.report] -exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", -] +exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] [[tool.mypy.overrides]] module = [ @@ -184,4 +161,4 @@ markers = [ "generators: generators tests", "chat_generators: chat_generators tests", ] -log_cli = true \ No newline at end of file +log_cli = true diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py index 69308ad19..b09258128 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/document_embedder.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 import asyncio -import os from typing import Any, Dict, List, Optional -from haystack import Document, component, default_to_dict +from haystack import Document, component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response from cohere import COHERE_API_URL, AsyncClient, Client @@ -35,7 +35,7 @@ class CohereDocumentEmbedder: def __init__( self, - api_key: Optional[str] = None, + api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), model: str = "embed-english-v2.0", input_type: str = "search_document", api_base_url: str = COHERE_API_URL, @@ -51,8 +51,7 @@ def __init__( """ Create a CohereDocumentEmbedder component. - :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment - variable COHERE_API_KEY (recommended). + :param api_key: The Cohere API key. 
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, @@ -78,15 +77,6 @@ def __init__( :param embedding_separator: Separator used to concatenate the meta fields to the Document text. """ - api_key = api_key or os.environ.get("COHERE_API_KEY") - # we check whether api_key is None or an empty string - if not api_key: - msg = ( - "CohereDocumentEmbedder expects an API key. " - "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." - ) - raise ValueError(msg) - self.api_key = api_key self.model = model self.input_type = input_type @@ -106,6 +96,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, + api_key=self.api_key.to_dict(), model=self.model, input_type=self.input_type, api_base_url=self.api_base_url, @@ -119,6 +110,17 @@ def to_dict(self) -> Dict[str, Any]: embedding_separator=self.embedding_separator, ) + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. + """ + init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) + return default_from_dict(cls, data) + def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]: """ Prepare the texts to embed by concatenating the Document text with the metadata fields to embed. 
@@ -155,16 +157,25 @@ def run(self, documents: List[Document]): texts_to_embed = self._prepare_texts_to_embed(documents) + api_key = self.api_key.resolve_value() + assert api_key is not None + if self.use_async_client: cohere_client = AsyncClient( - self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + api_key, + api_url=self.api_base_url, + max_retries=self.max_retries, + timeout=self.timeout, ) all_embeddings, metadata = asyncio.run( get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate) ) else: cohere_client = Client( - self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + api_key, + api_url=self.api_base_url, + max_retries=self.max_retries, + timeout=self.timeout, ) all_embeddings, metadata = get_response( cohere_client, diff --git a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py index 2fa922004..448a49dec 100644 --- a/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py +++ b/integrations/cohere/src/haystack_integrations/components/embedders/cohere/text_embedder.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: Apache-2.0 import asyncio -import os -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List -from haystack import component, default_to_dict +from haystack import component, default_from_dict, default_to_dict +from haystack.utils import Secret, deserialize_secrets_inplace from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response from cohere import COHERE_API_URL, AsyncClient, Client @@ -33,7 +33,7 @@ class CohereTextEmbedder: def __init__( self, - api_key: Optional[str] = None, + api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), model: str = "embed-english-v2.0", 
input_type: str = "search_query", api_base_url: str = COHERE_API_URL, @@ -45,8 +45,7 @@ def __init__( """ Create a CohereTextEmbedder component. - :param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment - variable COHERE_API_KEY (recommended). + :param api_key: The Cohere API key. :param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are: `"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`, `"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`, @@ -67,15 +66,6 @@ def __init__( :param timeout: Request timeout in seconds, defaults to `120`. """ - api_key = api_key or os.environ.get("COHERE_API_KEY") - # we check whether api_key is None or an empty string - if not api_key: - msg = ( - "CohereTextEmbedder expects an API key. " - "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." - ) - raise ValueError(msg) - self.api_key = api_key self.model = model self.input_type = input_type @@ -91,6 +81,7 @@ def to_dict(self) -> Dict[str, Any]: """ return default_to_dict( self, + api_key=self.api_key.to_dict(), model=self.model, input_type=self.input_type, api_base_url=self.api_base_url, @@ -100,6 +91,17 @@ def to_dict(self) -> Dict[str, Any]: timeout=self.timeout, ) + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder": + """ + Deserialize this component from a dictionary. + :param data: The dictionary representation of this component. + :return: The deserialized component instance. 
+ """ + init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) + return default_from_dict(cls, data) + @component.output_types(embedding=List[float], meta=Dict[str, Any]) def run(self, text: str): """Embed a string.""" @@ -112,16 +114,25 @@ def run(self, text: str): # Establish connection to API + api_key = self.api_key.resolve_value() + assert api_key is not None + if self.use_async_client: cohere_client = AsyncClient( - self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + api_key, + api_url=self.api_base_url, + max_retries=self.max_retries, + timeout=self.timeout, ) embedding, metadata = asyncio.run( get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate) ) else: cohere_client = Client( - self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout + api_key, + api_url=self.api_base_url, + max_retries=self.max_retries, + timeout=self.timeout, ) embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py index c632bed83..600cf3cf9 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py @@ -1,11 +1,11 @@ import logging -import os from typing import Any, Callable, Dict, List, Optional from haystack import component, default_from_dict, default_to_dict from haystack.components.generators.utils import deserialize_callback_handler, serialize_callback_handler from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk from haystack.lazy_imports import LazyImport +from haystack.utils import Secret, 
deserialize_secrets_inplace with LazyImport(message="Run 'pip install cohere'") as cohere_import: import cohere @@ -27,7 +27,7 @@ class CohereChatGenerator: def __init__( self, - api_key: Optional[str] = None, + api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), model: str = "command", streaming_callback: Optional[Callable[[StreamingChunk], None]] = None, api_base_url: Optional[str] = None, @@ -37,7 +37,7 @@ def __init__( """ Initialize the CohereChatGenerator instance. - :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var. + :param api_key: The API key for the Cohere API. :param model: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command". :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. @@ -69,15 +69,6 @@ def __init__( """ cohere_import.check() - api_key = api_key or os.environ.get("COHERE_API_KEY") - # we check whether api_key is None or an empty string - if not api_key: - msg = ( - "CohereChatGenerator expects an API key. " - "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." 
- ) - raise ValueError(msg) - if not api_base_url: api_base_url = cohere.COHERE_API_URL if generation_kwargs is None: @@ -88,7 +79,7 @@ def __init__( self.api_base_url = api_base_url self.generation_kwargs = generation_kwargs self.model_parameters = kwargs - self.client = cohere.Client(api_key=self.api_key, api_url=self.api_base_url) + self.client = cohere.Client(api_key=self.api_key.resolve_value(), api_url=self.api_base_url) def _get_telemetry_data(self) -> Dict[str, Any]: """ @@ -107,6 +98,7 @@ def to_dict(self) -> Dict[str, Any]: model=self.model, streaming_callback=callback_name, api_base_url=self.api_base_url, + api_key=self.api_key.to_dict(), generation_kwargs=self.generation_kwargs, ) @@ -118,6 +110,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereChatGenerator": :return: The deserialized component instance. """ init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) serialized_callback_handler = init_params.get("streaming_callback") if serialized_callback_handler: data["init_parameters"]["streaming_callback"] = deserialize_callback_handler(serialized_callback_handler) diff --git a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py index 92fed51aa..4927839d2 100644 --- a/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py +++ b/integrations/cohere/src/haystack_integrations/components/generators/cohere/generator.py @@ -2,12 +2,12 @@ # # SPDX-License-Identifier: Apache-2.0 import logging -import os import sys from typing import Any, Callable, Dict, List, Optional, cast from haystack import DeserializationError, component, default_from_dict, default_to_dict from haystack.dataclasses import StreamingChunk +from haystack.utils import Secret, deserialize_secrets_inplace from cohere import COHERE_API_URL, Client from cohere.responses import Generations 
@@ -33,7 +33,7 @@ class CohereGenerator: def __init__( self, - api_key: Optional[str] = None, + api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]), model: str = "command", streaming_callback: Optional[Callable] = None, api_base_url: Optional[str] = None, @@ -42,7 +42,7 @@ def __init__( """ Instantiates a `CohereGenerator` component. - :param api_key: The API key for the Cohere API. If not set, it will be read from the COHERE_API_KEY env var. + :param api_key: The API key for the Cohere API. :param model: The name of the model to use. Available models are: [command, command-light, command-nightly, command-nightly-light]. Defaults to "command". :param streaming_callback: A callback function to be called with the streaming response. Defaults to None. @@ -75,15 +75,6 @@ def __init__( - 'logit_bias': Used to prevent the model from generating unwanted tokens or to incentivize it to include desired tokens. The format is {token_id: bias} where bias is a float between -10 and 10. """ - api_key = api_key or os.environ.get("COHERE_API_KEY") - # we check whether api_key is None or an empty string - if not api_key: - msg = ( - "CohereGenerator expects an API key. " - "Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly." 
- ) - raise ValueError(msg) - if not api_base_url: api_base_url = COHERE_API_URL @@ -92,7 +83,7 @@ def __init__( self.streaming_callback = streaming_callback self.api_base_url = api_base_url self.model_parameters = kwargs - self.client = Client(api_key=self.api_key, api_url=self.api_base_url) + self.client = Client(api_key=self.api_key.resolve_value(), api_url=self.api_base_url) def to_dict(self) -> Dict[str, Any]: """ @@ -112,6 +103,7 @@ def to_dict(self) -> Dict[str, Any]: model=self.model, streaming_callback=callback_name, api_base_url=self.api_base_url, + api_key=self.api_key.to_dict(), **self.model_parameters, ) @@ -121,6 +113,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "CohereGenerator": Deserialize this component from a dictionary. """ init_params = data.get("init_parameters", {}) + deserialize_secrets_inplace(init_params, ["api_key"]) streaming_callback = None if "streaming_callback" in init_params and init_params["streaming_callback"] is not None: parts = init_params["streaming_callback"].split(".") diff --git a/integrations/cohere/tests/test_cohere_chat_generator.py b/integrations/cohere/tests/test_cohere_chat_generator.py index 556535e10..f1d15db08 100644 --- a/integrations/cohere/tests/test_cohere_chat_generator.py +++ b/integrations/cohere/tests/test_cohere_chat_generator.py @@ -5,6 +5,7 @@ import pytest from haystack.components.generators.utils import print_streaming_chunk from haystack.dataclasses import ChatMessage, ChatRole, StreamingChunk +from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereChatGenerator pytestmark = pytest.mark.chat_generators @@ -53,9 +54,11 @@ def chat_messages(): class TestCohereChatGenerator: @pytest.mark.unit - def test_init_default(self): - component = CohereChatGenerator(api_key="test-api-key") - assert component.api_key == "test-api-key" + def test_init_default(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") + + component = 
CohereChatGenerator() + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.model == "command" assert component.streaming_callback is None assert component.api_base_url == cohere.COHERE_API_URL @@ -64,42 +67,47 @@ def test_init_default(self): @pytest.mark.unit def test_init_fail_wo_api_key(self, monkeypatch): monkeypatch.delenv("COHERE_API_KEY", raising=False) + monkeypatch.delenv("CO_API_KEY", raising=False) with pytest.raises(ValueError): CohereChatGenerator() @pytest.mark.unit def test_init_with_parameters(self): component = CohereChatGenerator( - api_key="test-api-key", + api_key=Secret.from_token("test-api-key"), model="command-nightly", streaming_callback=print_streaming_chunk, api_base_url="test-base-url", generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"}, ) - assert component.api_key == "test-api-key" + assert component.api_key == Secret.from_token("test-api-key") assert component.model == "command-nightly" assert component.streaming_callback is print_streaming_chunk assert component.api_base_url == "test-base-url" assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"} @pytest.mark.unit - def test_to_dict_default(self): - component = CohereChatGenerator(api_key="test-api-key") + def test_to_dict_default(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") + component = CohereChatGenerator() data = component.to_dict() assert data == { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command", "streaming_callback": None, + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": "https://api.cohere.ai", "generation_kwargs": {}, }, } @pytest.mark.unit - def test_to_dict_with_parameters(self): + def test_to_dict_with_parameters(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") + 
monkeypatch.setenv("CO_API_KEY", "fake-api-key") component = CohereChatGenerator( - api_key="test-api-key", + api_key=Secret.from_env_var("ENV_VAR", strict=False), model="command-nightly", streaming_callback=print_streaming_chunk, api_base_url="test-base-url", @@ -110,6 +118,7 @@ def test_to_dict_with_parameters(self): "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command-nightly", + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "api_base_url": "test-base-url", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, @@ -117,9 +126,9 @@ def test_to_dict_with_parameters(self): } @pytest.mark.unit - def test_to_dict_with_lambda_streaming_callback(self): + def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") component = CohereChatGenerator( - api_key="test-api-key", model="command", streaming_callback=lambda x: x, api_base_url="test-base-url", @@ -131,6 +140,7 @@ def test_to_dict_with_lambda_streaming_callback(self): "init_parameters": { "model": "command", "api_base_url": "test-base-url", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": "tests.test_cohere_chat_generator.", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, @@ -139,11 +149,13 @@ def test_to_dict_with_lambda_streaming_callback(self): @pytest.mark.unit def test_from_dict(self, monkeypatch): monkeypatch.setenv("COHERE_API_KEY", "fake-api-key") + monkeypatch.setenv("CO_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.generators.cohere.chat.chat_generator.CohereChatGenerator", "init_parameters": { "model": "command", "api_base_url": "test-base-url", + "api_key": {"env_vars": ["ENV_VAR"], "strict": 
False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, @@ -162,6 +174,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): "init_parameters": { "model": "command", "api_base_url": "test-base-url", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"}, }, @@ -171,7 +184,7 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch): @pytest.mark.unit def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002 - component = CohereChatGenerator(api_key="test-api-key") + component = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) response = component.run(chat_messages) # check that the component returns the correct ChatMessage response @@ -183,14 +196,14 @@ def test_run(self, chat_messages, mock_chat_response): # noqa: ARG002 @pytest.mark.unit def test_message_to_dict(self, chat_messages): - obj = CohereChatGenerator(api_key="api-key") + obj = CohereChatGenerator(api_key=Secret.from_token("test-api-key")) dictionary = [obj._message_to_dict(message) for message in chat_messages] assert dictionary == [{"user_name": "Chatbot", "text": "What's the capital of France"}] @pytest.mark.unit def test_run_with_params(self, chat_messages, mock_chat_response): component = CohereChatGenerator( - api_key="test-api-key", generation_kwargs={"max_tokens": 10, "temperature": 0.5} + api_key=Secret.from_token("test-api-key"), generation_kwargs={"max_tokens": 10, "temperature": 0.5} ) response = component.run(chat_messages) @@ -216,7 +229,9 @@ def streaming_callback_fn(chunk: StreamingChunk): streaming_call_count += 1 assert isinstance(chunk, StreamingChunk) - generator = CohereChatGenerator(api_key="test-api-key", 
streaming_callback=streaming_callback_fn) + generator = CohereChatGenerator( + api_key=Secret.from_token("test-api-key"), streaming_callback=streaming_callback_fn + ) # Create a fake streamed response # self needed here, don't remove @@ -239,27 +254,25 @@ def mock_iter(self): # noqa: ARG001 assert [isinstance(reply, ChatMessage) for reply in response["replies"]] @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_live_run(self): chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] - component = CohereChatGenerator( - api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8} - ) + component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages) assert len(results["replies"]) == 1 message: ChatMessage = results["replies"][0] assert "Paris" in message.content @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_live_run_wrong_model(self, chat_messages): - component = CohereChatGenerator(model="something-obviously-wrong", api_key=os.environ.get("COHERE_API_KEY")) + component = CohereChatGenerator(model="something-obviously-wrong") with pytest.raises( cohere.CohereAPIError, match="model not found, make sure the correct model ID was used and that you have access 
to the model.", @@ -267,8 +280,8 @@ def test_live_run_wrong_model(self, chat_messages): component.run(chat_messages) @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_live_run_streaming(self): @@ -282,7 +295,7 @@ def __call__(self, chunk: StreamingChunk) -> None: self.responses += chunk.content if chunk.content else "" callback = Callback() - component = CohereChatGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) + component = CohereChatGenerator(streaming_callback=callback) results = component.run( [ChatMessage(content="What's the capital of France? answer in a word", role=ChatRole.USER, name=None)] ) @@ -297,15 +310,13 @@ def __call__(self, chunk: StreamingChunk) -> None: assert "Paris" in callback.responses @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_live_run_with_connector(self): chat_messages = [ChatMessage(content="What's the capital of France", role=ChatRole.USER, name="", meta={})] - component = CohereChatGenerator( - api_key=os.environ.get("COHERE_API_KEY"), generation_kwargs={"temperature": 0.8} - ) + component = CohereChatGenerator(generation_kwargs={"temperature": 0.8}) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 message: ChatMessage = 
results["replies"][0] @@ -314,8 +325,8 @@ def test_live_run_with_connector(self): assert message.meta["citations"] is not None @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_live_run_streaming_with_connector(self): @@ -330,7 +341,7 @@ def __call__(self, chunk: StreamingChunk) -> None: callback = Callback() chat_messages = [ChatMessage(content="What's the capital of France? answer in a word", role=None, name=None)] - component = CohereChatGenerator(os.environ.get("COHERE_API_KEY"), streaming_callback=callback) + component = CohereChatGenerator(streaming_callback=callback) results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]}) assert len(results["replies"]) == 1 diff --git a/integrations/cohere/tests/test_cohere_generators.py b/integrations/cohere/tests/test_cohere_generators.py index 5b12374a7..dcfa5f27a 100644 --- a/integrations/cohere/tests/test_cohere_generators.py +++ b/integrations/cohere/tests/test_cohere_generators.py @@ -6,6 +6,7 @@ import pytest from cohere import COHERE_API_URL from haystack.components.generators.utils import print_streaming_chunk +from haystack.utils import Secret from haystack_integrations.components.generators.cohere import CohereGenerator pytestmark = pytest.mark.generators @@ -13,8 +14,8 @@ class TestCohereGenerator: def test_init_default(self): - component = CohereGenerator(api_key="test-api-key") - assert component.api_key == "test-api-key" + component = CohereGenerator() + assert component.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert component.model == "command" assert component.streaming_callback is None assert 
component.api_base_url == COHERE_API_URL @@ -23,34 +24,38 @@ def test_init_default(self): def test_init_with_parameters(self): callback = lambda x: x # noqa: E731 component = CohereGenerator( - api_key="test-api-key", + api_key=Secret.from_token("test-api-key"), model="command-light", max_tokens=10, some_test_param="test-params", streaming_callback=callback, api_base_url="test-base-url", ) - assert component.api_key == "test-api-key" + assert component.api_key == Secret.from_token("test-api-key") assert component.model == "command-light" assert component.streaming_callback == callback assert component.api_base_url == "test-base-url" assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} - def test_to_dict_default(self): - component = CohereGenerator(api_key="test-api-key") + def test_to_dict_default(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") + component = CohereGenerator() data = component.to_dict() assert data == { "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "streaming_callback": None, "api_base_url": COHERE_API_URL, }, } - def test_to_dict_with_parameters(self): + def test_to_dict_with_parameters(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") + monkeypatch.setenv("CO_API_KEY", "fake-api-key") component = CohereGenerator( - api_key="test-api-key", + api_key=Secret.from_env_var("ENV_VAR", strict=False), model="command-light", max_tokens=10, some_test_param="test-params", @@ -65,13 +70,14 @@ def test_to_dict_with_parameters(self): "max_tokens": 10, "some_test_param": "test-params", "api_base_url": "test-base-url", + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", }, } - def 
test_to_dict_with_lambda_streaming_callback(self): + def test_to_dict_with_lambda_streaming_callback(self, monkeypatch): + monkeypatch.setenv("COHERE_API_KEY", "test-api-key") component = CohereGenerator( - api_key="test-api-key", model="command", max_tokens=10, some_test_param="test-params", @@ -84,6 +90,7 @@ def test_to_dict_with_lambda_streaming_callback(self): "init_parameters": { "model": "command", "streaming_callback": "tests.test_cohere_generators.<lambda>", + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "api_base_url": "test-base-url", "max_tokens": 10, "some_test_param": "test-params", @@ -91,26 +98,28 @@ def test_to_dict_with_lambda_streaming_callback(self): } def test_from_dict(self, monkeypatch): - monkeypatch.setenv("COHERE_API_KEY", "test-key") + monkeypatch.setenv("COHERE_API_KEY", "fake-api-key") + monkeypatch.setenv("CO_API_KEY", "fake-api-key") data = { "type": "haystack_integrations.components.generators.cohere.generator.CohereGenerator", "init_parameters": { "model": "command", "max_tokens": 10, + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "some_test_param": "test-params", "api_base_url": "test-base-url", "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk", }, } component: CohereGenerator = CohereGenerator.from_dict(data) - assert component.api_key == "test-key" + assert component.api_key == Secret.from_env_var("ENV_VAR", strict=False) assert component.model == "command" assert component.streaming_callback == print_streaming_chunk assert component.api_base_url == "test-base-url" assert component.model_parameters == {"max_tokens": 10, "some_test_param": "test-params"} def test_check_truncated_answers(self, caplog): - component = CohereGenerator(api_key="test-api-key") + component = CohereGenerator(api_key=Secret.from_token("test-api-key")) meta = [{"finish_reason": "MAX_TOKENS"}] component._check_truncated_answers(meta) assert
caplog.records[0].message == ( @@ -119,8 +128,8 @@ def test_check_truncated_answers(self, caplog): ) @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called CO_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_cohere_generator_run(self): @@ -132,8 +141,8 @@ def test_cohere_generator_run(self): assert results["meta"][0]["finish_reason"] == "COMPLETE" @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_cohere_generator_run_wrong_model(self): @@ -147,8 +156,8 @@ def test_cohere_generator_run_wrong_model(self): component.run(prompt="What's the capital of France?") @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_cohere_generator_run_streaming(self): diff --git a/integrations/cohere/tests/test_document_embedder.py b/integrations/cohere/tests/test_document_embedder.py index efe8eb36a..ee15a6b30 100644 --- a/integrations/cohere/tests/test_document_embedder.py +++ b/integrations/cohere/tests/test_document_embedder.py @@ -6,6 +6,7 @@ import pytest from cohere import COHERE_API_URL from haystack 
import Document +from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import CohereDocumentEmbedder pytestmark = pytest.mark.embedders @@ -13,8 +14,8 @@ class TestCohereDocumentEmbedder: def test_init_default(self): - embedder = CohereDocumentEmbedder(api_key="test-api-key") - assert embedder.api_key == "test-api-key" + embedder = CohereDocumentEmbedder() + assert embedder.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert embedder.model == "embed-english-v2.0" assert embedder.input_type == "search_document" assert embedder.api_base_url == COHERE_API_URL @@ -29,7 +30,7 @@ def test_init_default(self): def test_init_with_parameters(self): embedder = CohereDocumentEmbedder( - api_key="test-api-key", + api_key=Secret.from_token("test-api-key"), model="embed-multilingual-v2.0", input_type="search_query", api_base_url="https://custom-api-base-url.com", @@ -42,7 +43,7 @@ def test_init_with_parameters(self): meta_fields_to_embed=["test_field"], embedding_separator="-", ) - assert embedder.api_key == "test-api-key" + assert embedder.api_key == Secret.from_token("test-api-key") assert embedder.model == "embed-multilingual-v2.0" assert embedder.input_type == "search_query" assert embedder.api_base_url == "https://custom-api-base-url.com" @@ -56,11 +57,12 @@ def test_init_with_parameters(self): assert embedder.embedding_separator == "-" def test_to_dict(self): - embedder_component = CohereDocumentEmbedder(api_key="test-api-key") + embedder_component = CohereDocumentEmbedder() component_dict = embedder_component.to_dict() assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", "init_parameters": { + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "model": "embed-english-v2.0", "input_type": "search_document", "api_base_url": COHERE_API_URL, @@ -77,7 +79,7 @@ def test_to_dict(self): def 
test_to_dict_with_custom_init_parameters(self): embedder_component = CohereDocumentEmbedder( - api_key="test-api-key", + api_key=Secret.from_env_var("ENV_VAR", strict=False), model="embed-multilingual-v2.0", input_type="search_query", api_base_url="https://custom-api-base-url.com", @@ -94,6 +96,7 @@ def test_to_dict_with_custom_init_parameters(self): assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.document_embedder.CohereDocumentEmbedder", "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "model": "embed-multilingual-v2.0", "input_type": "search_query", "api_base_url": "https://custom-api-base-url.com", @@ -109,8 +112,8 @@ def test_to_dict_with_custom_init_parameters(self): } @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_run(self): @@ -131,7 +134,7 @@ def test_run(self): assert isinstance(doc.embedding[0], float) def test_run_wrong_input_format(self): - embedder = CohereDocumentEmbedder(api_key="test-api-key") + embedder = CohereDocumentEmbedder(api_key=Secret.from_token("test-api-key")) with pytest.raises(TypeError, match="CohereDocumentEmbedder expects a list of Documents as input"): embedder.run(documents="text") diff --git a/integrations/cohere/tests/test_text_embedder.py b/integrations/cohere/tests/test_text_embedder.py index 657d8df83..d7f3147ca 100644 --- a/integrations/cohere/tests/test_text_embedder.py +++ b/integrations/cohere/tests/test_text_embedder.py @@ -5,6 +5,7 @@ import pytest from cohere import COHERE_API_URL +from haystack.utils import Secret from haystack_integrations.components.embedders.cohere import 
CohereTextEmbedder pytestmark = pytest.mark.embedders @@ -15,9 +16,9 @@ def test_init_default(self): """ Test default initialization parameters for CohereTextEmbedder. """ - embedder = CohereTextEmbedder(api_key="test-api-key") + embedder = CohereTextEmbedder() - assert embedder.api_key == "test-api-key" + assert embedder.api_key == Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]) assert embedder.model == "embed-english-v2.0" assert embedder.input_type == "search_query" assert embedder.api_base_url == COHERE_API_URL @@ -31,7 +32,7 @@ def test_init_with_parameters(self): Test custom initialization parameters for CohereTextEmbedder. """ embedder = CohereTextEmbedder( - api_key="test-api-key", + api_key=Secret.from_token("test-api-key"), model="embed-multilingual-v2.0", input_type="classification", api_base_url="https://custom-api-base-url.com", @@ -40,7 +41,7 @@ def test_init_with_parameters(self): max_retries=5, timeout=60, ) - assert embedder.api_key == "test-api-key" + assert embedder.api_key == Secret.from_token("test-api-key") assert embedder.model == "embed-multilingual-v2.0" assert embedder.input_type == "classification" assert embedder.api_base_url == "https://custom-api-base-url.com" @@ -53,11 +54,12 @@ def test_to_dict(self): """ Test serialization of this component to a dictionary, using default initialization parameters. 
""" - embedder_component = CohereTextEmbedder(api_key="test-api-key") + embedder_component = CohereTextEmbedder() component_dict = embedder_component.to_dict() assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.text_embedder.CohereTextEmbedder", "init_parameters": { + "api_key": {"env_vars": ["COHERE_API_KEY", "CO_API_KEY"], "strict": True, "type": "env_var"}, "model": "embed-english-v2.0", "input_type": "search_query", "api_base_url": COHERE_API_URL, @@ -73,7 +75,7 @@ def test_to_dict_with_custom_init_parameters(self): Test serialization of this component to a dictionary, using custom initialization parameters. """ embedder_component = CohereTextEmbedder( - api_key="test-api-key", + api_key=Secret.from_env_var("ENV_VAR", strict=False), model="embed-multilingual-v2.0", input_type="classification", api_base_url="https://custom-api-base-url.com", @@ -86,6 +88,7 @@ def test_to_dict_with_custom_init_parameters(self): assert component_dict == { "type": "haystack_integrations.components.embedders.cohere.text_embedder.CohereTextEmbedder", "init_parameters": { + "api_key": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}, "model": "embed-multilingual-v2.0", "input_type": "classification", "api_base_url": "https://custom-api-base-url.com", @@ -100,15 +103,15 @@ def test_run_wrong_input_format(self): """ Test for checking incorrect input when creating embedding. 
""" - embedder = CohereTextEmbedder(api_key="test-api-key") + embedder = CohereTextEmbedder(api_key=Secret.from_token("test-api-key")) list_integers_input = ["text_snippet_1", "text_snippet_2"] with pytest.raises(TypeError, match="CohereTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) @pytest.mark.skipif( - not os.environ.get("COHERE_API_KEY", None), - reason="Export an env var called COHERE_API_KEY containing the Cohere API key to run this test.", + not os.environ.get("COHERE_API_KEY", None) and not os.environ.get("CO_API_KEY", None), + reason="Export an env var called COHERE_API_KEY/CO_API_KEY containing the Cohere API key to run this test.", ) @pytest.mark.integration def test_run(self):