Skip to content

Commit

Permalink
refactor!: Use Secret for API keys in Cohere components (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
shadeMe authored Feb 9, 2024
1 parent f8a1019 commit bca7226
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 180 deletions.
77 changes: 27 additions & 50 deletions integrations/cohere/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ readme = "README.md"
requires-python = ">=3.7"
license = "Apache-2.0"
keywords = []
authors = [
{ name = "deepset GmbH", email = "[email protected]" },
]
authors = [{ name = "deepset GmbH", email = "[email protected]" }]
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
Expand All @@ -24,10 +22,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = [
"haystack-ai",
"cohere",
]
dependencies = ["haystack-ai>=2.0.0b6", "cohere"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/cohere#readme"
Expand All @@ -46,51 +41,25 @@ root = "../.."
git_describe_command = 'git describe --tags --match="integrations/cohere-v[0-9]*"'

[tool.hatch.envs.default]
dependencies = [
"coverage[toml]>=6.5",
"pytest",
"haystack-pydoc-tools",
]
dependencies = ["coverage[toml]>=6.5", "pytest", "haystack-pydoc-tools"]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
test-cov = "coverage run -m pytest {args:tests}"
cov-report = [
"- coverage combine",
"coverage report",
]
cov = [
"test-cov",
"cov-report",
]
docs = [
"pydoc-markdown pydoc/config.yml"
]
cov-report = ["- coverage combine", "coverage report"]
cov = ["test-cov", "cov-report"]
docs = ["pydoc-markdown pydoc/config.yml"]

[[tool.hatch.envs.all.matrix]]
python = ["3.7", "3.8", "3.9", "3.10", "3.11"]

[tool.hatch.envs.lint]
detached = true
dependencies = [
"black>=23.1.0",
"mypy>=1.0.0",
"ruff>=0.0.243",
]
dependencies = ["black>=23.1.0", "mypy>=1.0.0", "ruff>=0.0.243"]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive --explicit-package-bases {args:src/ tests}"
style = [
"ruff {args:.}",
"black --check --diff {args:.}",
]
fmt = [
"black {args:.}",
"ruff --fix {args:.}",
"style",
]
all = [
"style",
"typing",
]
style = ["ruff {args:.}", "black --check --diff {args:.}"]
fmt = ["black {args:.}", "ruff --fix {args:.}", "style"]
all = ["style", "typing"]

[tool.black]
target-version = ["py37"]
Expand Down Expand Up @@ -130,9 +99,18 @@ ignore = [
# Allow non-abstract empty methods in abstract base classes
"B027",
# Ignore checks for possible passwords
"S105", "S106", "S107",
"S105",
"S106",
"S107",
# Ignore complexity
"C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915",
"C901",
"PLR0911",
"PLR0912",
"PLR0913",
"PLR0915",
# Misc
"B008",
"S101",
]
unfixable = [
# Don't touch unused imports
Expand All @@ -155,15 +133,14 @@ branch = true
parallel = true

[tool.coverage.paths]
cohere_haystack = ["src/haystack_integrations", "*/cohere/src/haystack_integrations"]
cohere_haystack = [
"src/haystack_integrations",
"*/cohere/src/haystack_integrations",
]
tests = ["tests", "*/cohere/tests"]

[tool.coverage.report]
exclude_lines = [
"no cov",
"if __name__ == .__main__.:",
"if TYPE_CHECKING:",
]
exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]

[[tool.mypy.overrides]]
module = [
Expand All @@ -184,4 +161,4 @@ markers = [
"generators: generators tests",
"chat_generators: chat_generators tests",
]
log_cli = true
log_cli = true
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from typing import Any, Dict, List, Optional

from haystack import Document, component, default_to_dict
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response

from cohere import COHERE_API_URL, AsyncClient, Client
Expand Down Expand Up @@ -35,7 +35,7 @@ class CohereDocumentEmbedder:

def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
model: str = "embed-english-v2.0",
input_type: str = "search_document",
api_base_url: str = COHERE_API_URL,
Expand All @@ -51,8 +51,7 @@ def __init__(
"""
Create a CohereDocumentEmbedder component.
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param api_key: The Cohere API key.
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
Expand All @@ -78,15 +77,6 @@ def __init__(
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""

api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereDocumentEmbedder expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model = model
self.input_type = input_type
Expand All @@ -106,6 +96,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
api_key=self.api_key.to_dict(),
model=self.model,
input_type=self.input_type,
api_base_url=self.api_base_url,
Expand All @@ -119,6 +110,17 @@ def to_dict(self) -> Dict[str, Any]:
embedding_separator=self.embedding_separator,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CohereDocumentEmbedder":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])
return default_from_dict(cls, data)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
Expand Down Expand Up @@ -155,16 +157,25 @@ def run(self, documents: List[Document]):

texts_to_embed = self._prepare_texts_to_embed(documents)

api_key = self.api_key.resolve_value()
assert api_key is not None

if self.use_async_client:
cohere_client = AsyncClient(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
api_key,
api_url=self.api_base_url,
max_retries=self.max_retries,
timeout=self.timeout,
)
all_embeddings, metadata = asyncio.run(
get_async_response(cohere_client, texts_to_embed, self.model, self.input_type, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
api_key,
api_url=self.api_base_url,
max_retries=self.max_retries,
timeout=self.timeout,
)
all_embeddings, metadata = get_response(
cohere_client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

from haystack import component, default_to_dict
from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from haystack_integrations.components.embedders.cohere.utils import get_async_response, get_response

from cohere import COHERE_API_URL, AsyncClient, Client
Expand Down Expand Up @@ -33,7 +33,7 @@ class CohereTextEmbedder:

def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var(["COHERE_API_KEY", "CO_API_KEY"]),
model: str = "embed-english-v2.0",
input_type: str = "search_query",
api_base_url: str = COHERE_API_URL,
Expand All @@ -45,8 +45,7 @@ def __init__(
"""
Create a CohereTextEmbedder component.
:param api_key: The Cohere API key. It can be explicitly provided or automatically read from the environment
variable COHERE_API_KEY (recommended).
:param api_key: The Cohere API key.
:param model: The name of the model to use, defaults to `"embed-english-v2.0"`. Supported Models are:
`"embed-english-v3.0"`, `"embed-english-light-v3.0"`, `"embed-multilingual-v3.0"`,
`"embed-multilingual-light-v3.0"`, `"embed-english-v2.0"`, `"embed-english-light-v2.0"`,
Expand All @@ -67,15 +66,6 @@ def __init__(
:param timeout: Request timeout in seconds, defaults to `120`.
"""

api_key = api_key or os.environ.get("COHERE_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
msg = (
"CohereTextEmbedder expects an API key. "
"Set the COHERE_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model = model
self.input_type = input_type
Expand All @@ -91,6 +81,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
api_key=self.api_key.to_dict(),
model=self.model,
input_type=self.input_type,
api_base_url=self.api_base_url,
Expand All @@ -100,6 +91,17 @@ def to_dict(self) -> Dict[str, Any]:
timeout=self.timeout,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CohereTextEmbedder":
"""
Deserialize this component from a dictionary.
:param data: The dictionary representation of this component.
:return: The deserialized component instance.
"""
init_params = data.get("init_parameters", {})
deserialize_secrets_inplace(init_params, ["api_key"])
return default_from_dict(cls, data)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
"""Embed a string."""
Expand All @@ -112,16 +114,25 @@ def run(self, text: str):

# Establish connection to API

api_key = self.api_key.resolve_value()
assert api_key is not None

if self.use_async_client:
cohere_client = AsyncClient(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
api_key,
api_url=self.api_base_url,
max_retries=self.max_retries,
timeout=self.timeout,
)
embedding, metadata = asyncio.run(
get_async_response(cohere_client, [text], self.model, self.input_type, self.truncate)
)
else:
cohere_client = Client(
self.api_key, api_url=self.api_base_url, max_retries=self.max_retries, timeout=self.timeout
api_key,
api_url=self.api_base_url,
max_retries=self.max_retries,
timeout=self.timeout,
)
embedding, metadata = get_response(cohere_client, [text], self.model, self.input_type, self.truncate)

Expand Down
Loading

0 comments on commit bca7226

Please sign in to comment.