From 8f79c026a90e9a85f6777528de4e71a8842c9b70 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 15:57:50 +0100 Subject: [PATCH 1/9] Update WeaviateDocumentStore authentication to use new Secret class --- .../document_stores/weaviate/auth.py | 170 ++++++++++++++++ .../weaviate/document_store.py | 30 +-- integrations/weaviate/tests/test_auth.py | 189 ++++++++++++++++++ .../weaviate/tests/test_document_store.py | 33 +-- 4 files changed, 387 insertions(+), 35 deletions(-) create mode 100644 integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py create mode 100644 integrations/weaviate/tests/test_auth.py diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py new file mode 100644 index 000000000..f67269604 --- /dev/null +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -0,0 +1,170 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, fields +from typing import Any, Dict + +from haystack.core.errors import DeserializationError +from haystack.core.serialization import default_from_dict, default_to_dict +from haystack.utils.auth import Secret, deserialize_secrets_inplace + +from weaviate.auth import AuthApiKey as WeaviateAuthApiKey +from weaviate.auth import AuthBearerToken as WeaviateAuthBearerToken +from weaviate.auth import AuthClientCredentials as WeaviateAuthClientCredentials +from weaviate.auth import AuthClientPassword as WeaviateAuthClientPassword + + +@dataclass +class AuthCredentials(ABC): + """ + Base class for all auth credentials supported by WeaviateDocumentStore. + Can be used to deserialize from dict any of the supported auth credentials. + """ + + def to_dict(self) -> Dict[str, Any]: + """ + Converts the object to a dictionary representation for serialization. + """ + # We assume all fields are Secret instances + _fields = {f.name: getattr(self, f.name).to_dict() for f in fields(self)} + + return default_to_dict( + self, + **_fields, + ) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> "AuthCredentials": + """ + Converts a dictionary representation to an auth credentials object. + """ + if "type" not in data: + msg = "Missing 'type' in serialization data" + raise DeserializationError(msg) + return _AUTH_CLASSES[data["type"]]._from_dict(data) + + @classmethod + @abstractmethod + def _from_dict(cls, data: Dict[str, Any]): + """ + Internal method to convert a dictionary representation to an auth credentials object. + All subclasses must implement this method. + """ + + @abstractmethod + def resolve_value(self): + """ + Resolves all the secrets in the auth credentials object and returns the corresponding Weaviate object. + All subclasses must implement this method. + """ + + +@dataclass +class AuthApiKey(AuthCredentials): + """ + AuthCredentials for API key authentication. + By default it will load `api_key` from the environment variable `WEAVIATE_API_KEY`. + """ + + api_key: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_API_KEY"])) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthApiKey": + deserialize_secrets_inplace(data["init_parameters"], ["api_key"]) + return default_from_dict(cls, data) + + def resolve_value(self) -> WeaviateAuthApiKey: + return WeaviateAuthApiKey(api_key=self.api_key.resolve_value()) + + +@dataclass +class AuthBearerToken(AuthCredentials): + """ + AuthCredentials for Bearer token authentication. + By default it will load `access_token` from the environment variable `WEAVIATE_ACCESS_TOKEN`, + `expires_in` from the environment variable `WEAVIATE_EXPIRES_IN`, and `refresh_token` from the environment variable + `WEAVIATE_REFRESH_TOKEN`. + `WEAVIATE_EXPIRES_IN` environment variable is optional, if set must be an integer. + `WEAVIATE_REFRESH_TOKEN` environment variable is optional. + """ + + access_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_ACCESS_TOKEN"])) + expires_in: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_EXPIRES_IN"], strict=False)) + refresh_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_REFRESH_TOKEN"], strict=False)) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthBearerToken": + deserialize_secrets_inplace(data["init_parameters"], ["access_token", "expires_in", "refresh_token"]) + return default_from_dict(cls, data) + + def resolve_value(self) -> WeaviateAuthBearerToken: + access_token = self.access_token.resolve_value() + expires_in = self.expires_in.resolve_value() + refresh_token = self.refresh_token.resolve_value() + expires_in = int(expires_in) if expires_in else 60 + + return WeaviateAuthBearerToken( + access_token=access_token, + expires_in=expires_in, + refresh_token=refresh_token, + ) + + +@dataclass +class AuthClientCredentials(AuthCredentials): + """ + AuthCredentials for client credentials authentication. + By default it will load `client_secret` from the environment variable `WEAVIATE_CLIENT_SECRET`, and + `scope` from the environment variable `WEAVIATE_SCOPE`. + `WEAVIATE_SCOPE` environment variable is optional, if set it can either be a string or a list of space + separated strings. e.g "scope1" or "scope1 scope2". + """ + + client_secret: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_CLIENT_SECRET"])) + scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False)) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientCredentials": + deserialize_secrets_inplace(data["init_parameters"], ["client_secret", "scope"]) + return default_from_dict(cls, data) + + def resolve_value(self) -> WeaviateAuthClientCredentials: + return WeaviateAuthClientCredentials( + client_secret=self.client_secret.resolve_value(), + scope=self.scope.resolve_value(), + ) + + +@dataclass +class AuthClientPassword(AuthCredentials): + """ + AuthCredentials for username and password authentication. + By default it will load `username` from the environment variable `WEAVIATE_USERNAME`, + `password` from the environment variable `WEAVIATE_PASSWORD`, and + `scope` from the environment variable `WEAVIATE_SCOPE`. + `WEAVIATE_SCOPE` environment variable is optional, if set it can either be a string or a list of space + separated strings. e.g "scope1" or "scope1 scope2". + """ + + username: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_USERNAME"])) + password: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_PASSWORD"])) + scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False)) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientPassword": + deserialize_secrets_inplace(data["init_parameters"], ["username", "password", "scope"]) + return default_from_dict(cls, data) + + def resolve_value(self) -> WeaviateAuthClientPassword: + return WeaviateAuthClientPassword( + username=self.username.resolve_value(), + password=self.password.resolve_value(), + scope=self.scope.resolve_value(), + ) + + +# This simplifies a bit how we handle deserialization of the auth credentials. +_AUTH_CLASSES = { + "haystack_integrations.document_stores.weaviate.auth.AuthClientCredentials": AuthClientCredentials, + "haystack_integrations.document_stores.weaviate.auth.AuthClientPassword": AuthClientPassword, + "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken": AuthBearerToken, + "haystack_integrations.document_stores.weaviate.auth.AuthApiKey": AuthApiKey, +} diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 38f0b38cd..071fe336b 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -11,24 +11,16 @@ from haystack.document_stores.types.policy import DuplicatePolicy import weaviate -from weaviate.auth import AuthCredentials from weaviate.config import Config, ConnectionConfig from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 from ._filters import convert_filters +from .auth import AuthCredentials Number = Union[int, float] TimeoutType = Union[Tuple[Number, Number], Number] -# This simplifies a bit how we handle deserialization of the auth credentials. -# Otherwise we would need to use importlib to dynamically import the correct class. -_AUTH_CLASSES = { - "weaviate.auth.AuthClientCredentials": weaviate.auth.AuthClientCredentials, - "weaviate.auth.AuthClientPassword": weaviate.auth.AuthClientPassword, - "weaviate.auth.AuthBearerToken": weaviate.auth.AuthBearerToken, - "weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey, -} # This is the default collection properties for Weaviate. # It's a list of properties that will be created on the collection. @@ -92,10 +84,10 @@ def __init__( for more information on collections and their properties. :param auth_client_secret: Authentication credentials, defaults to None. Can be one of the following types depending on the authentication mode: - - `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens - - `weaviate.auth.AuthClientPassword` to use username and password for oidc Resource Owner Password flow - - `weaviate.auth.AuthClientCredentials` to use a client secret for oidc client credential flow - - `weaviate.auth.AuthApiKey` to use an API key + - `AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens + - `AuthClientPassword` to use username and password for oidc Resource Owner Password flow + - `AuthClientCredentials` to use a client secret for oidc client credential flow + - `AuthApiKey` to use an API key :param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60). It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout). If only one real number is passed then both connect and read timeout will be set to @@ -124,7 +116,7 @@ def __init__( """ self._client = weaviate.Client( url=url, - auth_client_secret=auth_client_secret, + auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None, timeout_config=timeout_config, proxies=proxies, trust_env=trust_env, @@ -164,11 +156,6 @@ def __init__( self._additional_config = additional_config def to_dict(self) -> Dict[str, Any]: - auth_client_secret = None - if self._auth_client_secret: - # There are different types of AuthCredentials, so even thought it's a dataclass - # and we could just use asdict, we need to save the type too. - auth_client_secret = default_to_dict(self._auth_client_secret, **asdict(self._auth_client_secret)) embedded_options = asdict(self._embedded_options) if self._embedded_options else None additional_config = asdict(self._additional_config) if self._additional_config else None @@ -176,7 +163,7 @@ def to_dict(self) -> Dict[str, Any]: self, url=self._url, collection_settings=self._collection_settings, - auth_client_secret=auth_client_secret, + auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None, timeout_config=self._timeout_config, proxies=self._proxies, trust_env=self._trust_env, @@ -193,8 +180,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore": tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config ) if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None: - auth_class = _AUTH_CLASSES[auth_client_secret["type"]] - data["init_parameters"]["auth_client_secret"] = default_from_dict(auth_class, auth_client_secret) + data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret) if (embedded_options := data["init_parameters"].get("embedded_options")) is not None: data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options) if (additional_config := data["init_parameters"].get("additional_config")) is not None: diff --git a/integrations/weaviate/tests/test_auth.py b/integrations/weaviate/tests/test_auth.py new file mode 100644 index 000000000..14bdf1317 --- /dev/null +++ b/integrations/weaviate/tests/test_auth.py @@ -0,0 +1,189 @@ +from haystack_integrations.document_stores.weaviate.auth import ( + AuthApiKey, + AuthBearerToken, + AuthClientCredentials, + AuthClientPassword, + AuthCredentials, +) +from weaviate.auth import AuthApiKey as WeaviateAuthApiKey +from weaviate.auth import AuthBearerToken as WeaviateAuthBearerToken +from weaviate.auth import AuthClientCredentials as WeaviateAuthClientCredentials +from weaviate.auth import AuthClientPassword as WeaviateAuthClientPassword + + +class TestAuthApiKey: + def test_init(self): + credentials = AuthApiKey() + assert credentials.api_key._env_vars == ["WEAVIATE_API_KEY"] + assert credentials.api_key._strict + + def test_to_dict(self): + credentials = AuthApiKey() + assert credentials.to_dict() == { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}}, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "fake_key") + credentials = AuthCredentials.from_dict( + { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}}, + } + ) + assert isinstance(credentials, AuthApiKey) + assert credentials.api_key._env_vars == ["WEAVIATE_API_KEY"] + assert credentials.api_key._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "fake_key") + credentials = AuthApiKey() + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthApiKey) + assert resolved.api_key == "fake_key" + + +class TestAuthBearerToken: + def test_init(self): + credentials = AuthBearerToken() + assert credentials.access_token._env_vars == ["WEAVIATE_ACCESS_TOKEN"] + assert credentials.access_token._strict + assert credentials.expires_in._env_vars == ["WEAVIATE_EXPIRES_IN"] + assert not credentials.expires_in._strict + assert credentials.refresh_token._env_vars == ["WEAVIATE_REFRESH_TOKEN"] + assert not credentials.refresh_token._strict + + def test_to_dict(self): + credentials = AuthBearerToken() + assert credentials.to_dict() == { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken", + "init_parameters": { + "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, + "expires_in": {"env_vars": ["WEAVIATE_EXPIRES_IN"], "strict": False, "type": "env_var"}, + "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"}, + }, + } + + def test_from_dict(self): + credentials = AuthCredentials.from_dict( + { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken", + "init_parameters": { + "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, + "expires_in": {"env_vars": ["WEAVIATE_EXPIRES_IN"], "strict": False, "type": "env_var"}, + "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"}, + }, + } + ) + assert credentials.access_token._env_vars == ["WEAVIATE_ACCESS_TOKEN"] + assert credentials.access_token._strict + assert credentials.expires_in._env_vars == ["WEAVIATE_EXPIRES_IN"] + assert not credentials.expires_in._strict + assert credentials.refresh_token._env_vars == ["WEAVIATE_REFRESH_TOKEN"] + assert not credentials.refresh_token._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_ACCESS_TOKEN", "fake_key") + monkeypatch.setenv("WEAVIATE_EXPIRES_IN", "10") + monkeypatch.setenv("WEAVIATE_REFRESH_TOKEN", "fake_refresh_token") + credentials = AuthBearerToken() + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthBearerToken) + assert resolved.access_token == "fake_key" + assert resolved.expires_in == 10 + assert resolved.refresh_token == "fake_refresh_token" + + +class TestAuthClientCredentials: + def test_init(self): + credentials = AuthClientCredentials() + assert credentials.client_secret._env_vars == ["WEAVIATE_CLIENT_SECRET"] + assert credentials.client_secret._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_to_dict(self): + credentials = AuthClientCredentials() + assert credentials.to_dict() == { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientCredentials", + "init_parameters": { + "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + + def test_from_dict(self): + credentials = AuthCredentials.from_dict( + { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientCredentials", + "init_parameters": { + "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + ) + assert credentials.client_secret._env_vars == ["WEAVIATE_CLIENT_SECRET"] + assert credentials.client_secret._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_CLIENT_SECRET", "fake_secret") + monkeypatch.setenv("WEAVIATE_SCOPE", "fake_scope another_fake_scope") + credentials = AuthClientCredentials() + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthClientCredentials) + assert resolved.client_secret == "fake_secret" + assert resolved.scope_list == ["fake_scope", "another_fake_scope"] + + +class TestAuthClientPassword: + def test_init(self): + credentials = AuthClientPassword() + assert credentials.username._env_vars == ["WEAVIATE_USERNAME"] + assert credentials.username._strict + assert credentials.password._env_vars == ["WEAVIATE_PASSWORD"] + assert credentials.password._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_to_dict(self): + credentials = AuthClientPassword() + assert credentials.to_dict() == { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientPassword", + "init_parameters": { + "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"}, + "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + + def test_from_dict(self): + credentials = AuthCredentials.from_dict( + { + "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientPassword", + "init_parameters": { + "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"}, + "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + ) + assert credentials.username._env_vars == ["WEAVIATE_USERNAME"] + assert credentials.username._strict + assert credentials.password._env_vars == ["WEAVIATE_PASSWORD"] + assert credentials.password._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_USERNAME", "fake_username") + monkeypatch.setenv("WEAVIATE_PASSWORD", "fake_password") + monkeypatch.setenv("WEAVIATE_SCOPE", "fake_scope another_fake_scope") + credentials = AuthClientPassword() + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthClientPassword) + assert resolved.username == "fake_username" + assert resolved.password == "fake_password" + assert resolved.scope_list == ["fake_scope", "another_fake_scope"] diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 359af3670..f1512b23b 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -15,6 +15,7 @@ FilterDocumentsTest, WriteDocumentsTest, ) +from haystack_integrations.document_stores.weaviate.auth import AuthApiKey from haystack_integrations.document_stores.weaviate.document_store import ( DOCUMENT_COLLECTION_PROPERTIES, WeaviateDocumentStore, @@ -23,7 +24,7 @@ from numpy import array_equal as np_array_equal from numpy import float32 as np_float32 from pandas import DataFrame -from weaviate.auth import AuthApiKey +from weaviate.auth import AuthApiKey as WeaviateAuthApiKey from weaviate.config import Config from weaviate.embedded import ( DEFAULT_BINARY_PATH, @@ -145,15 +146,15 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert received_meta.get(key) == expected_meta.get(key) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") - def test_init(self, mock_weaviate_client_class): + def test_init(self, mock_weaviate_client_class, monkeypatch): mock_client = MagicMock() mock_client.schema.exists.return_value = False mock_weaviate_client_class.return_value = mock_client - + monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") WeaviateDocumentStore( url="http://localhost:8080", collection_settings={"class": "My_collection"}, - auth_client_secret=AuthApiKey("my_api_key"), + auth_client_secret=AuthApiKey(), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( @@ -168,7 +169,7 @@ def test_init(self, mock_weaviate_client_class): # Verify client is created with correct parameters mock_weaviate_client_class.assert_called_once_with( url="http://localhost:8080", - auth_client_secret=AuthApiKey("my_api_key"), + auth_client_secret=WeaviateAuthApiKey("my_api_key"), timeout_config=(10, 60), proxies={"http": "http://proxy:1234"}, trust_env=False, @@ -191,10 +192,11 @@ def test_init(self, mock_weaviate_client_class): ) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") - def test_to_dict(self, _mock_weaviate): + def test_to_dict(self, _mock_weaviate, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") document_store = WeaviateDocumentStore( url="http://localhost:8080", - auth_client_secret=AuthApiKey("my_api_key"), + auth_client_secret=AuthApiKey(), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( @@ -222,8 +224,10 @@ def test_to_dict(self, _mock_weaviate): ], }, "auth_client_secret": { - "type": "weaviate.auth.AuthApiKey", - "init_parameters": {"api_key": "my_api_key"}, + "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "init_parameters": { + "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} + }, }, "timeout_config": (10, 60), "proxies": {"http": "http://proxy:1234"}, @@ -250,7 +254,8 @@ def test_to_dict(self, _mock_weaviate): } @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") - def test_from_dict(self, _mock_weaviate): + def test_from_dict(self, _mock_weaviate, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") document_store = WeaviateDocumentStore.from_dict( { "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", @@ -258,8 +263,10 @@ def test_from_dict(self, _mock_weaviate): "url": "http://localhost:8080", "collection_settings": None, "auth_client_secret": { - "type": "weaviate.auth.AuthApiKey", - "init_parameters": {"api_key": "my_api_key"}, + "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "init_parameters": { + "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} + }, }, "timeout_config": [10, 60], "proxies": {"http": "http://proxy:1234"}, @@ -299,7 +306,7 @@ def test_from_dict(self, _mock_weaviate): {"name": "score", "dataType": ["number"]}, ], } - assert document_store._auth_client_secret == AuthApiKey("my_api_key") + assert document_store._auth_client_secret == AuthApiKey() assert document_store._timeout_config == (10, 60) assert document_store._proxies == {"http": "http://proxy:1234"} assert not document_store._trust_env From 37c5264398c57b36817094db75dc5915c32be852 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 16:02:33 +0100 Subject: [PATCH 2/9] Fix linting --- .../haystack_integrations/document_stores/weaviate/auth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py index f67269604..a2f0799ee 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields -from typing import Any, Dict +from typing import Any, Dict, Type from haystack.core.errors import DeserializationError from haystack.core.serialization import default_from_dict, default_to_dict @@ -162,7 +162,7 @@ def resolve_value(self) -> WeaviateAuthClientPassword: # This simplifies a bit how we handle deserialization of the auth credentials. -_AUTH_CLASSES = { +_AUTH_CLASSES: Dict[str, Type[AuthCredentials]] = { "haystack_integrations.document_stores.weaviate.auth.AuthClientCredentials": AuthClientCredentials, "haystack_integrations.document_stores.weaviate.auth.AuthClientPassword": AuthClientPassword, "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken": AuthBearerToken, From c4d0e8e58b718813d2f104cdb798f7031efbbaab Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 16:11:53 +0100 Subject: [PATCH 3/9] Update docs config --- integrations/weaviate/pydoc/config.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/integrations/weaviate/pydoc/config.yml b/integrations/weaviate/pydoc/config.yml index 84334c2e6..82ddbb057 100644 --- a/integrations/weaviate/pydoc/config.yml +++ b/integrations/weaviate/pydoc/config.yml @@ -3,6 +3,7 @@ loaders: search_path: [../src] modules: [ + "haystack_integrations.document_stores.weaviate.auth", "haystack_integrations.document_stores.weaviate.document_store", "haystack_integrations.components.retrievers.weaviate.bm25_retriever", "haystack_integrations.components.retrievers.weaviate.embedding_retriever", From 71e1b1e887ea36c4381f7304ea74d279b33af5da Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 16:29:01 +0100 Subject: [PATCH 4/9] Export auth classes --- .../document_stores/weaviate/__init__.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py index c6f9b8776..87c7b6b01 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py @@ -1,6 +1,14 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword, AuthCredentials from .document_store import WeaviateDocumentStore -__all__ = ["WeaviateDocumentStore"] +__all__ = [ + "WeaviateDocumentStore", + "AuthApiKey", + "AuthBearerToken", + "AuthClientCredentials", + "AuthClientPassword", + "AuthCredentials", +] From 3e34e8152a4e65d93152f7c55a39b939265aed24 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 17:17:44 +0100 Subject: [PATCH 5/9] Change expires_in to non secret --- .../document_stores/weaviate/auth.py | 19 ++++++++++--------- integrations/weaviate/tests/test_auth.py | 13 +++++-------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py index a2f0799ee..3dcf08b36 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -23,8 +23,12 @@ def to_dict(self) -> Dict[str, Any]: """ Converts the object to a dictionary representation for serialization. """ - # We assume all fields are Secret instances - _fields = {f.name: getattr(self, f.name).to_dict() for f in fields(self)} + _fields = {} + for field in fields(self): + if field.type is Secret: + _fields[field.name] = getattr(self, field.name).to_dict() + else: + _fields[field.name] = getattr(self, field.name) return default_to_dict( self, @@ -80,30 +84,27 @@ class AuthBearerToken(AuthCredentials): """ AuthCredentials for Bearer token authentication. By default it will load `access_token` from the environment variable `WEAVIATE_ACCESS_TOKEN`, - `expires_in` from the environment variable `WEAVIATE_EXPIRES_IN`, and `refresh_token` from the environment variable + and `refresh_token` from the environment variable `WEAVIATE_REFRESH_TOKEN`. - `WEAVIATE_EXPIRES_IN` environment variable is optional, if set must be an integer. `WEAVIATE_REFRESH_TOKEN` environment variable is optional. """ access_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_ACCESS_TOKEN"])) - expires_in: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_EXPIRES_IN"], strict=False)) + expires_in: int = field(default=60) refresh_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_REFRESH_TOKEN"], strict=False)) @classmethod def _from_dict(cls, data: Dict[str, Any]) -> "AuthBearerToken": - deserialize_secrets_inplace(data["init_parameters"], ["access_token", "expires_in", "refresh_token"]) + deserialize_secrets_inplace(data["init_parameters"], ["access_token", "refresh_token"]) return default_from_dict(cls, data) def resolve_value(self) -> WeaviateAuthBearerToken: access_token = self.access_token.resolve_value() - expires_in = self.expires_in.resolve_value() refresh_token = self.refresh_token.resolve_value() - expires_in = int(expires_in) if expires_in else 60 return WeaviateAuthBearerToken( access_token=access_token, - expires_in=expires_in, + expires_in=self.expires_in, refresh_token=refresh_token, ) diff --git a/integrations/weaviate/tests/test_auth.py b/integrations/weaviate/tests/test_auth.py index 14bdf1317..896976975 100644 --- a/integrations/weaviate/tests/test_auth.py +++ b/integrations/weaviate/tests/test_auth.py @@ -49,8 +49,7 @@ def test_init(self): credentials = AuthBearerToken() assert credentials.access_token._env_vars == ["WEAVIATE_ACCESS_TOKEN"] assert credentials.access_token._strict - assert credentials.expires_in._env_vars == ["WEAVIATE_EXPIRES_IN"] - assert not credentials.expires_in._strict + assert credentials.expires_in == 60 assert credentials.refresh_token._env_vars == ["WEAVIATE_REFRESH_TOKEN"] assert not credentials.refresh_token._strict @@ -60,7 +59,7 @@ def test_to_dict(self): "type": "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken", "init_parameters": { "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, - "expires_in": {"env_vars": ["WEAVIATE_EXPIRES_IN"], "strict": False, "type": "env_var"}, + "expires_in": 60, "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"}, }, } @@ -71,23 +70,21 @@ def test_from_dict(self): "type": "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken", "init_parameters": { "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, - "expires_in": {"env_vars": ["WEAVIATE_EXPIRES_IN"], "strict": False, "type": "env_var"}, + "expires_in": 10, "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"}, }, } ) assert credentials.access_token._env_vars == ["WEAVIATE_ACCESS_TOKEN"] assert credentials.access_token._strict - assert credentials.expires_in._env_vars == ["WEAVIATE_EXPIRES_IN"] - assert not credentials.expires_in._strict + assert credentials.expires_in == 10 assert credentials.refresh_token._env_vars == ["WEAVIATE_REFRESH_TOKEN"] assert not credentials.refresh_token._strict def test_resolve_value(self, monkeypatch): monkeypatch.setenv("WEAVIATE_ACCESS_TOKEN", "fake_key") - monkeypatch.setenv("WEAVIATE_EXPIRES_IN", "10") monkeypatch.setenv("WEAVIATE_REFRESH_TOKEN", "fake_refresh_token") - credentials = AuthBearerToken() + credentials = AuthBearerToken(expires_in=10) resolved = credentials.resolve_value() assert isinstance(resolved, WeaviateAuthBearerToken) assert resolved.access_token == "fake_key" From 57d0b1554cca8ff41b1a384ee60eadb37f055771 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 17:40:34 +0100 Subject: [PATCH 6/9] Use enum for serialization types --- .../document_stores/weaviate/auth.py | 57 +++++++++++++------ integrations/weaviate/tests/test_auth.py | 16 +++--- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py index 3dcf08b36..9c37ecfaf 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields +from enum import Enum from typing import Any, Dict, Type from haystack.core.errors import DeserializationError @@ -12,6 +13,30 @@ from weaviate.auth import AuthClientPassword as WeaviateAuthClientPassword +class SupportedAuthTypes(Enum): + """ + Supported auth credentials for WeaviateDocumentStore. + """ + + API_KEY = "api_key" + BEARER = "bearer" + CLIENT_CREDENTIALS = "client_credentials" + CLIENT_PASSWORD = "client_password" + + def __str__(self): + return self.value + + @staticmethod + def from_class(auth_class) -> "SupportedAuthTypes": + auth_types = { + AuthApiKey: SupportedAuthTypes.API_KEY, + AuthBearerToken: SupportedAuthTypes.BEARER, + AuthClientCredentials: SupportedAuthTypes.CLIENT_CREDENTIALS, + AuthClientPassword: SupportedAuthTypes.CLIENT_PASSWORD, + } + return auth_types[auth_class] + + @dataclass class AuthCredentials(ABC): """ @@ -30,10 +55,7 @@ def to_dict(self) -> Dict[str, Any]: else: _fields[field.name] = getattr(self, field.name) - return default_to_dict( - self, - **_fields, - ) + return {"type": str(SupportedAuthTypes.from_class(self.__class__)), "init_parameters": _fields} @staticmethod def from_dict(data: Dict[str, Any]) -> "AuthCredentials": @@ -43,7 +65,15 @@ def from_dict(data: Dict[str, Any]) -> "AuthCredentials": if "type" not in data: msg = "Missing 'type' in serialization data" raise DeserializationError(msg) - return _AUTH_CLASSES[data["type"]]._from_dict(data) + + auth_classes: Dict[str, Type[AuthCredentials]] = { + str(SupportedAuthTypes.API_KEY): AuthApiKey, + str(SupportedAuthTypes.BEARER): AuthBearerToken, + str(SupportedAuthTypes.CLIENT_CREDENTIALS): AuthClientCredentials, + str(SupportedAuthTypes.CLIENT_PASSWORD): AuthClientPassword, + } + + return auth_classes[data["type"]]._from_dict(data) @classmethod @abstractmethod @@ -73,7 +103,7 @@ class AuthApiKey(AuthCredentials): @classmethod def _from_dict(cls, data: Dict[str, Any]) -> "AuthApiKey": deserialize_secrets_inplace(data["init_parameters"], ["api_key"]) - return default_from_dict(cls, data) + return cls(**data["init_parameters"]) def resolve_value(self) -> WeaviateAuthApiKey: return WeaviateAuthApiKey(api_key=self.api_key.resolve_value()) @@ -96,7 +126,7 @@ class AuthBearerToken(AuthCredentials): @classmethod def _from_dict(cls, data: Dict[str, Any]) -> "AuthBearerToken": deserialize_secrets_inplace(data["init_parameters"], ["access_token", "refresh_token"]) - return default_from_dict(cls, data) + return cls(**data["init_parameters"]) def resolve_value(self) -> WeaviateAuthBearerToken: access_token = self.access_token.resolve_value() @@ -125,7 +155,7 @@ class AuthClientCredentials(AuthCredentials): @classmethod def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientCredentials": deserialize_secrets_inplace(data["init_parameters"], ["client_secret", "scope"]) - return default_from_dict(cls, data) + return cls(**data["init_parameters"]) def resolve_value(self) -> WeaviateAuthClientCredentials: return WeaviateAuthClientCredentials( @@ -152,7 +182,7 @@ class AuthClientPassword(AuthCredentials): @classmethod def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientPassword": deserialize_secrets_inplace(data["init_parameters"], ["username", "password", "scope"]) - return default_from_dict(cls, data) + return cls(**data["init_parameters"]) def resolve_value(self) -> WeaviateAuthClientPassword: return WeaviateAuthClientPassword( @@ -160,12 +190,3 @@ def resolve_value(self) -> WeaviateAuthClientPassword: password=self.password.resolve_value(), scope=self.scope.resolve_value(), ) - - -# This simplifies a bit how we handle deserialization of the auth credentials. -_AUTH_CLASSES: Dict[str, Type[AuthCredentials]] = { - "haystack_integrations.document_stores.weaviate.auth.AuthClientCredentials": AuthClientCredentials, - "haystack_integrations.document_stores.weaviate.auth.AuthClientPassword": AuthClientPassword, - "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken": AuthBearerToken, - "haystack_integrations.document_stores.weaviate.auth.AuthApiKey": AuthApiKey, -} diff --git a/integrations/weaviate/tests/test_auth.py b/integrations/weaviate/tests/test_auth.py index 896976975..aad1b9e65 100644 --- a/integrations/weaviate/tests/test_auth.py +++ b/integrations/weaviate/tests/test_auth.py @@ -20,7 +20,7 @@ def test_init(self): def test_to_dict(self): credentials = AuthApiKey() assert credentials.to_dict() == { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "type": "api_key", "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}}, } @@ -28,7 +28,7 @@ def test_from_dict(self, monkeypatch): monkeypatch.setenv("WEAVIATE_API_KEY", "fake_key") credentials = AuthCredentials.from_dict( { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "type": "api_key", "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}}, } ) @@ -56,7 +56,7 @@ def test_init(self): def test_to_dict(self): credentials = AuthBearerToken() assert credentials.to_dict() == { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken", + "type": "bearer", "init_parameters": { "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, "expires_in": 60, @@ -67,7 +67,7 @@ def test_to_dict(self): def test_from_dict(self): credentials = AuthCredentials.from_dict( { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthBearerToken", + "type": "bearer", "init_parameters": { "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, "expires_in": 10, @@ -103,7 +103,7 @@ def test_init(self): def test_to_dict(self): credentials = AuthClientCredentials() assert credentials.to_dict() == { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientCredentials", + "type": "client_credentials", "init_parameters": { "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"}, "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, @@ -113,7 +113,7 @@ def test_to_dict(self): def test_from_dict(self): credentials = AuthCredentials.from_dict( { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientCredentials", + "type": "client_credentials", "init_parameters": { "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"}, "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, @@ -148,7 +148,7 @@ def test_init(self): def test_to_dict(self): credentials = AuthClientPassword() assert credentials.to_dict() == { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientPassword", + "type": "client_password", "init_parameters": { "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"}, "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"}, @@ -159,7 +159,7 @@ def test_to_dict(self): def test_from_dict(self): credentials = AuthCredentials.from_dict( { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthClientPassword", + "type": "client_password", "init_parameters": { "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"}, "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"}, From 5bf2f8e66d3b7fb2be9aa2f5386e037b15a0f5b1 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 17:41:18 +0100 Subject: [PATCH 7/9] Freeze dataclasses --- .../document_stores/weaviate/auth.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py index 9c37ecfaf..093a32da2 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -37,7 +37,7 @@ def from_class(auth_class) -> "SupportedAuthTypes": return auth_types[auth_class] -@dataclass +@dataclass(frozen=True) class AuthCredentials(ABC): """ Base class for all auth credentials supported by WeaviateDocumentStore. @@ -91,7 +91,7 @@ def resolve_value(self): """ -@dataclass +@dataclass(frozen=True) class AuthApiKey(AuthCredentials): """ AuthCredentials for API key authentication. @@ -109,7 +109,7 @@ def resolve_value(self) -> WeaviateAuthApiKey: return WeaviateAuthApiKey(api_key=self.api_key.resolve_value()) -@dataclass +@dataclass(frozen=True) class AuthBearerToken(AuthCredentials): """ AuthCredentials for Bearer token authentication. @@ -139,7 +139,7 @@ def resolve_value(self) -> WeaviateAuthBearerToken: ) -@dataclass +@dataclass(frozen=True) class AuthClientCredentials(AuthCredentials): """ AuthCredentials for client credentials authentication. @@ -164,7 +164,7 @@ def resolve_value(self) -> WeaviateAuthClientCredentials: ) -@dataclass +@dataclass(frozen=True) class AuthClientPassword(AuthCredentials): """ AuthCredentials for username and password authentication. From 74406e5c8b14ec2f1a5c07f33370a3bb3467bffa Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 17:42:55 +0100 Subject: [PATCH 8/9] Fix linting --- .../document_stores/weaviate/auth.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py index 093a32da2..4c3898130 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Type from haystack.core.errors import DeserializationError -from haystack.core.serialization import default_from_dict, default_to_dict from haystack.utils.auth import Secret, deserialize_secrets_inplace from weaviate.auth import AuthApiKey as WeaviateAuthApiKey @@ -49,11 +48,11 @@ def to_dict(self) -> Dict[str, Any]: Converts the object to a dictionary representation for serialization. """ _fields = {} - for field in fields(self): - if field.type is Secret: - _fields[field.name] = getattr(self, field.name).to_dict() + for _field in fields(self): + if _field.type is Secret: + _fields[_field.name] = getattr(self, _field.name).to_dict() else: - _fields[field.name] = getattr(self, field.name) + _fields[_field.name] = getattr(self, _field.name) return {"type": str(SupportedAuthTypes.from_class(self.__class__)), "init_parameters": _fields} From 39dd9d250df701e47b76132ba6ed886befb56987 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 15 Feb 2024 17:48:38 +0100 Subject: [PATCH 9/9] Fix failing tests --- integrations/weaviate/tests/test_document_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index f1512b23b..a2b32d578 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -224,7 +224,7 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): ], }, "auth_client_secret": { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "type": "api_key", "init_parameters": { "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} }, @@ -263,7 +263,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): "url": "http://localhost:8080", "collection_settings": None, "auth_client_secret": { - "type": "haystack_integrations.document_stores.weaviate.auth.AuthApiKey", + "type": "api_key", "init_parameters": { "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} },