Skip to content

Commit

Permalink
Jina: update secrets management (#411)
Browse files Browse the repository at this point in the history
* jina update secrets management

* fix wrong message
  • Loading branch information
anakin87 authored Feb 14, 2024
1 parent e7e3ece commit 2f451e9
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, Dict, List, Optional, Tuple

import requests
from haystack import Document, component, default_to_dict
from haystack import Document, component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace
from tqdm import tqdm

JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"
Expand Down Expand Up @@ -35,7 +35,7 @@ class JinaDocumentEmbedder:

def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008
model: str = "jina-embeddings-v2-base-en",
prefix: str = "",
suffix: str = "",
Expand All @@ -46,8 +46,7 @@ def __init__(
):
"""
Create a JinaDocumentEmbedder component.
:param api_key: The Jina API key. It can be explicitly provided or automatically read from the
environment variable JINA_API_KEY (recommended).
:param api_key: The Jina API key.
:param model: The name of the Jina model to use. Check the list of available models on `https://jina.ai/embeddings/`
:param prefix: A string to add to the beginning of each text.
:param suffix: A string to add to the end of each text.
Expand All @@ -57,16 +56,15 @@ def __init__(
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document text.
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
"""

api_key = api_key or os.environ.get("JINA_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
resolved_api_key = api_key.resolve_value()
if resolved_api_key is None:
msg = (
"JinaDocumentEmbedder expects an API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model_name = model
self.prefix = prefix
self.suffix = suffix
Expand All @@ -77,7 +75,7 @@ def __init__(
self._session = requests.Session()
self._session.headers.update(
{
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {resolved_api_key}",
"Accept-Encoding": "identity",
"Content-type": "application/json",
}
Expand All @@ -96,6 +94,7 @@ def to_dict(self) -> Dict[str, Any]:
"""
return default_to_dict(
self,
api_key=self.api_key.to_dict(),
model=self.model_name,
prefix=self.prefix,
suffix=self.suffix,
Expand All @@ -105,6 +104,11 @@ def to_dict(self) -> Dict[str, Any]:
embedding_separator=self.embedding_separator,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "JinaDocumentEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)

def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

import requests
from haystack import component, default_to_dict
from haystack import component, default_from_dict, default_to_dict
from haystack.utils import Secret, deserialize_secrets_inplace

JINA_API_URL: str = "https://api.jina.ai/v1/embeddings"

Expand Down Expand Up @@ -33,7 +33,7 @@ class JinaTextEmbedder:

def __init__(
self,
api_key: Optional[str] = None,
api_key: Secret = Secret.from_env_var("JINA_API_KEY"), # noqa: B008
model: str = "jina-embeddings-v2-base-en",
prefix: str = "",
suffix: str = "",
Expand All @@ -48,22 +48,22 @@ def __init__(
:param suffix: A string to add to the end of each text.
"""

api_key = api_key or os.environ.get("JINA_API_KEY")
# we check whether api_key is None or an empty string
if not api_key:
resolved_api_key = api_key.resolve_value()
if resolved_api_key is None:
msg = (
"JinaTextEmbedder expects an API key. "
"Set the JINA_API_KEY environment variable (recommended) or pass it explicitly."
)
raise ValueError(msg)

self.api_key = api_key
self.model_name = model
self.prefix = prefix
self.suffix = suffix
self._session = requests.Session()
self._session.headers.update(
{
"Authorization": f"Bearer {api_key}",
"Authorization": f"Bearer {resolved_api_key}",
"Accept-Encoding": "identity",
"Content-type": "application/json",
}
Expand All @@ -81,7 +81,14 @@ def to_dict(self) -> Dict[str, Any]:
to the constructor.
"""

return default_to_dict(self, model=self.model_name, prefix=self.prefix, suffix=self.suffix)
return default_to_dict(
self, api_key=self.api_key.to_dict(), model=self.model_name, prefix=self.prefix, suffix=self.suffix
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "JinaTextEmbedder":
deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
return default_from_dict(cls, data)

@component.output_types(embedding=List[float], meta=Dict[str, Any])
def run(self, text: str):
Expand Down
33 changes: 21 additions & 12 deletions integrations/jina/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import requests
from haystack import Document
from haystack.utils import Secret
from haystack_integrations.components.embedders.jina import JinaDocumentEmbedder


Expand All @@ -28,6 +29,7 @@ def test_init_default(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
embedder = JinaDocumentEmbedder()

assert embedder.api_key == Secret.from_env_var("JINA_API_KEY")
assert embedder.model_name == "jina-embeddings-v2-base-en"
assert embedder.prefix == ""
assert embedder.suffix == ""
Expand All @@ -38,7 +40,7 @@ def test_init_default(self, monkeypatch):

def test_init_with_parameters(self):
embedder = JinaDocumentEmbedder(
api_key="fake-api-key",
api_key=Secret.from_token("fake-api-key"),
model="model",
prefix="prefix",
suffix="suffix",
Expand All @@ -47,6 +49,8 @@ def test_init_with_parameters(self):
meta_fields_to_embed=["test_field"],
embedding_separator=" | ",
)

assert embedder.api_key == Secret.from_token("fake-api-key")
assert embedder.model_name == "model"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
Expand All @@ -60,12 +64,14 @@ def test_init_fail_wo_api_key(self, monkeypatch):
with pytest.raises(ValueError):
JinaDocumentEmbedder()

def test_to_dict(self):
component = JinaDocumentEmbedder(api_key="fake-api-key")
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
component = JinaDocumentEmbedder()
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "jina-embeddings-v2-base-en",
"prefix": "",
"suffix": "",
Expand All @@ -76,9 +82,9 @@ def test_to_dict(self):
},
}

def test_to_dict_with_custom_init_parameters(self):
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
component = JinaDocumentEmbedder(
api_key="fake-api-key",
model="model",
prefix="prefix",
suffix="suffix",
Expand All @@ -91,6 +97,7 @@ def test_to_dict_with_custom_init_parameters(self):
assert data == {
"type": "haystack_integrations.components.embedders.jina.document_embedder.JinaDocumentEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "model",
"prefix": "prefix",
"suffix": "suffix",
Expand All @@ -107,7 +114,7 @@ def test_prepare_texts_to_embed_w_metadata(self):
]

embedder = JinaDocumentEmbedder(
api_key="fake-api-key", meta_fields_to_embed=["meta_field"], embedding_separator=" | "
api_key=Secret.from_token("fake-api-key"), meta_fields_to_embed=["meta_field"], embedding_separator=" | "
)

prepared_texts = embedder._prepare_texts_to_embed(documents)
Expand All @@ -124,7 +131,9 @@ def test_prepare_texts_to_embed_w_metadata(self):
def test_prepare_texts_to_embed_w_suffix(self):
documents = [Document(content=f"document number {i}") for i in range(5)]

embedder = JinaDocumentEmbedder(api_key="fake-api-key", prefix="my_prefix ", suffix=" my_suffix")
embedder = JinaDocumentEmbedder(
api_key=Secret.from_token("fake-api-key"), prefix="my_prefix ", suffix=" my_suffix"
)

prepared_texts = embedder._prepare_texts_to_embed(documents)

Expand All @@ -140,7 +149,7 @@ def test_embed_batch(self):
texts = ["text 1", "text 2", "text 3", "text 4", "text 5"]

with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(api_key="fake-api-key", model="model")
embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key"), model="model")

embeddings, metadata = embedder._embed_batch(texts_to_embed=texts, batch_size=2)

Expand All @@ -162,7 +171,7 @@ def test_run(self):
model = "jina-embeddings-v2-base-en"
with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(
api_key="fake-api-key",
api_key=Secret.from_token("fake-api-key"),
model=model,
prefix="prefix ",
suffix=" suffix",
Expand Down Expand Up @@ -192,7 +201,7 @@ def test_run_custom_batch_size(self):
model = "jina-embeddings-v2-base-en"
with patch("requests.sessions.Session.post", side_effect=mock_session_post_response):
embedder = JinaDocumentEmbedder(
api_key="fake-api-key",
api_key=Secret.from_token("fake-api-key"),
model=model,
prefix="prefix ",
suffix=" suffix",
Expand All @@ -217,7 +226,7 @@ def test_run_custom_batch_size(self):
assert metadata == {"model": model, "usage": {"prompt_tokens": 2 * 4, "total_tokens": 2 * 4}}

def test_run_wrong_input_format(self):
embedder = JinaDocumentEmbedder(api_key="fake-api-key")
embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))

string_input = "text"
list_integers_input = [1, 2, 3]
Expand All @@ -229,7 +238,7 @@ def test_run_wrong_input_format(self):
embedder.run(documents=list_integers_input)

def test_run_on_empty_list(self):
embedder = JinaDocumentEmbedder(api_key="fake-api-key")
embedder = JinaDocumentEmbedder(api_key=Secret.from_token("fake-api-key"))

empty_list_input = []
result = embedder.run(documents=empty_list_input)
Expand Down
22 changes: 15 additions & 7 deletions integrations/jina/tests/test_text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pytest
import requests
from haystack.utils import Secret
from haystack_integrations.components.embedders.jina import JinaTextEmbedder


Expand All @@ -14,17 +15,19 @@ def test_init_default(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
embedder = JinaTextEmbedder()

assert embedder.api_key == Secret.from_env_var("JINA_API_KEY")
assert embedder.model_name == "jina-embeddings-v2-base-en"
assert embedder.prefix == ""
assert embedder.suffix == ""

def test_init_with_parameters(self):
embedder = JinaTextEmbedder(
api_key="fake-api-key",
api_key=Secret.from_token("fake-api-key"),
model="model",
prefix="prefix",
suffix="suffix",
)
assert embedder.api_key == Secret.from_token("fake-api-key")
assert embedder.model_name == "model"
assert embedder.prefix == "prefix"
assert embedder.suffix == "suffix"
Expand All @@ -34,21 +37,23 @@ def test_init_fail_wo_api_key(self, monkeypatch):
with pytest.raises(ValueError):
JinaTextEmbedder()

def test_to_dict(self):
component = JinaTextEmbedder(api_key="fake-api-key")
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
component = JinaTextEmbedder()
data = component.to_dict()
assert data == {
"type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "jina-embeddings-v2-base-en",
"prefix": "",
"suffix": "",
},
}

def test_to_dict_with_custom_init_parameters(self):
def test_to_dict_with_custom_init_parameters(self, monkeypatch):
monkeypatch.setenv("JINA_API_KEY", "fake-api-key")
component = JinaTextEmbedder(
api_key="fake-api-key",
model="model",
prefix="prefix",
suffix="suffix",
Expand All @@ -57,6 +62,7 @@ def test_to_dict_with_custom_init_parameters(self):
assert data == {
"type": "haystack_integrations.components.embedders.jina.text_embedder.JinaTextEmbedder",
"init_parameters": {
"api_key": {"env_vars": ["JINA_API_KEY"], "strict": True, "type": "env_var"},
"model": "model",
"prefix": "prefix",
"suffix": "suffix",
Expand All @@ -80,7 +86,9 @@ def test_run(self):

mock_post.return_value = mock_response

embedder = JinaTextEmbedder(api_key="fake-api-key", model=model, prefix="prefix ", suffix=" suffix")
embedder = JinaTextEmbedder(
api_key=Secret.from_token("fake-api-key"), model=model, prefix="prefix ", suffix=" suffix"
)
result = embedder.run(text="The food was delicious")

assert len(result["embedding"]) == 3
Expand All @@ -91,7 +99,7 @@ def test_run(self):
}

def test_run_wrong_input_format(self):
embedder = JinaTextEmbedder(api_key="fake-api-key")
embedder = JinaTextEmbedder(api_key=Secret.from_token("fake-api-key"))

list_integers_input = [1, 2, 3]

Expand Down

0 comments on commit 2f451e9

Please sign in to comment.