diff --git a/integrations/weaviate/pydoc/config.yml b/integrations/weaviate/pydoc/config.yml index 84334c2e6..82ddbb057 100644 --- a/integrations/weaviate/pydoc/config.yml +++ b/integrations/weaviate/pydoc/config.yml @@ -3,6 +3,7 @@ loaders: search_path: [../src] modules: [ + "haystack_integrations.document_stores.weaviate.auth", "haystack_integrations.document_stores.weaviate.document_store", "haystack_integrations.components.retrievers.weaviate.bm25_retriever", "haystack_integrations.components.retrievers.weaviate.embedding_retriever", diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py index c6f9b8776..87c7b6b01 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/__init__.py @@ -1,6 +1,14 @@ # SPDX-FileCopyrightText: 2023-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 +from .auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword, AuthCredentials from .document_store import WeaviateDocumentStore -__all__ = ["WeaviateDocumentStore"] +__all__ = [ + "WeaviateDocumentStore", + "AuthApiKey", + "AuthBearerToken", + "AuthClientCredentials", + "AuthClientPassword", + "AuthCredentials", +] diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py new file mode 100644 index 000000000..4c3898130 --- /dev/null +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/auth.py @@ -0,0 +1,191 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field, fields +from enum import Enum +from typing import Any, Dict, Type + +from haystack.core.errors import DeserializationError +from haystack.utils.auth import Secret, deserialize_secrets_inplace + +from weaviate.auth import AuthApiKey as WeaviateAuthApiKey +from weaviate.auth import AuthBearerToken as WeaviateAuthBearerToken +from weaviate.auth import AuthClientCredentials as WeaviateAuthClientCredentials +from weaviate.auth import AuthClientPassword as WeaviateAuthClientPassword + + +class SupportedAuthTypes(Enum): + """ + Supported auth credentials for WeaviateDocumentStore. + """ + + API_KEY = "api_key" + BEARER = "bearer" + CLIENT_CREDENTIALS = "client_credentials" + CLIENT_PASSWORD = "client_password" + + def __str__(self): + return self.value + + @staticmethod + def from_class(auth_class) -> "SupportedAuthTypes": + auth_types = { + AuthApiKey: SupportedAuthTypes.API_KEY, + AuthBearerToken: SupportedAuthTypes.BEARER, + AuthClientCredentials: SupportedAuthTypes.CLIENT_CREDENTIALS, + AuthClientPassword: SupportedAuthTypes.CLIENT_PASSWORD, + } + return auth_types[auth_class] + + +@dataclass(frozen=True) +class AuthCredentials(ABC): + """ + Base class for all auth credentials supported by WeaviateDocumentStore. + Can be used to deserialize from dict any of the supported auth credentials. + """ + + def to_dict(self) -> Dict[str, Any]: + """ + Converts the object to a dictionary representation for serialization. + """ + _fields = {} + for _field in fields(self): + if _field.type is Secret: + _fields[_field.name] = getattr(self, _field.name).to_dict() + else: + _fields[_field.name] = getattr(self, _field.name) + + return {"type": str(SupportedAuthTypes.from_class(self.__class__)), "init_parameters": _fields} + + @staticmethod + def from_dict(data: Dict[str, Any]) -> "AuthCredentials": + """ + Converts a dictionary representation to an auth credentials object. + """ + if "type" not in data: + msg = "Missing 'type' in serialization data" + raise DeserializationError(msg) + + auth_classes: Dict[str, Type[AuthCredentials]] = { + str(SupportedAuthTypes.API_KEY): AuthApiKey, + str(SupportedAuthTypes.BEARER): AuthBearerToken, + str(SupportedAuthTypes.CLIENT_CREDENTIALS): AuthClientCredentials, + str(SupportedAuthTypes.CLIENT_PASSWORD): AuthClientPassword, + } + + return auth_classes[data["type"]]._from_dict(data) + + @classmethod + @abstractmethod + def _from_dict(cls, data: Dict[str, Any]): + """ + Internal method to convert a dictionary representation to an auth credentials object. + All subclasses must implement this method. + """ + + @abstractmethod + def resolve_value(self): + """ + Resolves all the secrets in the auth credentials object and returns the corresponding Weaviate object. + All subclasses must implement this method. + """ + + +@dataclass(frozen=True) +class AuthApiKey(AuthCredentials): + """ + AuthCredentials for API key authentication. + By default it will load `api_key` from the environment variable `WEAVIATE_API_KEY`. + """ + + api_key: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_API_KEY"])) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthApiKey": + deserialize_secrets_inplace(data["init_parameters"], ["api_key"]) + return cls(**data["init_parameters"]) + + def resolve_value(self) -> WeaviateAuthApiKey: + return WeaviateAuthApiKey(api_key=self.api_key.resolve_value()) + + +@dataclass(frozen=True) +class AuthBearerToken(AuthCredentials): + """ + AuthCredentials for Bearer token authentication. + By default it will load `access_token` from the environment variable `WEAVIATE_ACCESS_TOKEN`, + and `refresh_token` from the environment variable + `WEAVIATE_REFRESH_TOKEN`. + `WEAVIATE_REFRESH_TOKEN` environment variable is optional. + """ + + access_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_ACCESS_TOKEN"])) + expires_in: int = field(default=60) + refresh_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_REFRESH_TOKEN"], strict=False)) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthBearerToken": + deserialize_secrets_inplace(data["init_parameters"], ["access_token", "refresh_token"]) + return cls(**data["init_parameters"]) + + def resolve_value(self) -> WeaviateAuthBearerToken: + access_token = self.access_token.resolve_value() + refresh_token = self.refresh_token.resolve_value() + + return WeaviateAuthBearerToken( + access_token=access_token, + expires_in=self.expires_in, + refresh_token=refresh_token, + ) + + +@dataclass(frozen=True) +class AuthClientCredentials(AuthCredentials): + """ + AuthCredentials for client credentials authentication. + By default it will load `client_secret` from the environment variable `WEAVIATE_CLIENT_SECRET`, and + `scope` from the environment variable `WEAVIATE_SCOPE`. + `WEAVIATE_SCOPE` environment variable is optional, if set it can either be a string or a list of space + separated strings. e.g "scope1" or "scope1 scope2". + """ + + client_secret: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_CLIENT_SECRET"])) + scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False)) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientCredentials": + deserialize_secrets_inplace(data["init_parameters"], ["client_secret", "scope"]) + return cls(**data["init_parameters"]) + + def resolve_value(self) -> WeaviateAuthClientCredentials: + return WeaviateAuthClientCredentials( + client_secret=self.client_secret.resolve_value(), + scope=self.scope.resolve_value(), + ) + + +@dataclass(frozen=True) +class AuthClientPassword(AuthCredentials): + """ + AuthCredentials for username and password authentication. + By default it will load `username` from the environment variable `WEAVIATE_USERNAME`, + `password` from the environment variable `WEAVIATE_PASSWORD`, and + `scope` from the environment variable `WEAVIATE_SCOPE`. + `WEAVIATE_SCOPE` environment variable is optional, if set it can either be a string or a list of space + separated strings. e.g "scope1" or "scope1 scope2". + """ + + username: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_USERNAME"])) + password: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_PASSWORD"])) + scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False)) + + @classmethod + def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientPassword": + deserialize_secrets_inplace(data["init_parameters"], ["username", "password", "scope"]) + return cls(**data["init_parameters"]) + + def resolve_value(self) -> WeaviateAuthClientPassword: + return WeaviateAuthClientPassword( + username=self.username.resolve_value(), + password=self.password.resolve_value(), + scope=self.scope.resolve_value(), + ) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 38f0b38cd..071fe336b 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -11,24 +11,16 @@ from haystack.document_stores.types.policy import DuplicatePolicy import weaviate -from weaviate.auth import AuthCredentials from weaviate.config import Config, ConnectionConfig from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 from ._filters import convert_filters +from .auth import AuthCredentials Number = Union[int, float] TimeoutType = Union[Tuple[Number, Number], Number] -# This simplifies a bit how we handle deserialization of the auth credentials. -# Otherwise we would need to use importlib to dynamically import the correct class. -_AUTH_CLASSES = { - "weaviate.auth.AuthClientCredentials": weaviate.auth.AuthClientCredentials, - "weaviate.auth.AuthClientPassword": weaviate.auth.AuthClientPassword, - "weaviate.auth.AuthBearerToken": weaviate.auth.AuthBearerToken, - "weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey, -} # This is the default collection properties for Weaviate. # It's a list of properties that will be created on the collection. @@ -92,10 +84,10 @@ def __init__( for more information on collections and their properties. :param auth_client_secret: Authentication credentials, defaults to None. Can be one of the following types depending on the authentication mode: - - `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens - - `weaviate.auth.AuthClientPassword` to use username and password for oidc Resource Owner Password flow - - `weaviate.auth.AuthClientCredentials` to use a client secret for oidc client credential flow - - `weaviate.auth.AuthApiKey` to use an API key + - `AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens + - `AuthClientPassword` to use username and password for oidc Resource Owner Password flow + - `AuthClientCredentials` to use a client secret for oidc client credential flow + - `AuthApiKey` to use an API key :param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60). It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout). If only one real number is passed then both connect and read timeout will be set to @@ -124,7 +116,7 @@ def __init__( """ self._client = weaviate.Client( url=url, - auth_client_secret=auth_client_secret, + auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None, timeout_config=timeout_config, proxies=proxies, trust_env=trust_env, @@ -164,11 +156,6 @@ def __init__( self._additional_config = additional_config def to_dict(self) -> Dict[str, Any]: - auth_client_secret = None - if self._auth_client_secret: - # There are different types of AuthCredentials, so even thought it's a dataclass - # and we could just use asdict, we need to save the type too. - auth_client_secret = default_to_dict(self._auth_client_secret, **asdict(self._auth_client_secret)) embedded_options = asdict(self._embedded_options) if self._embedded_options else None additional_config = asdict(self._additional_config) if self._additional_config else None @@ -176,7 +163,7 @@ def to_dict(self) -> Dict[str, Any]: self, url=self._url, collection_settings=self._collection_settings, - auth_client_secret=auth_client_secret, + auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None, timeout_config=self._timeout_config, proxies=self._proxies, trust_env=self._trust_env, @@ -193,8 +180,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore": tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config ) if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None: - auth_class = _AUTH_CLASSES[auth_client_secret["type"]] - data["init_parameters"]["auth_client_secret"] = default_from_dict(auth_class, auth_client_secret) + data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret) if (embedded_options := data["init_parameters"].get("embedded_options")) is not None: data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options) if (additional_config := data["init_parameters"].get("additional_config")) is not None: diff --git a/integrations/weaviate/tests/test_auth.py b/integrations/weaviate/tests/test_auth.py new file mode 100644 index 000000000..aad1b9e65 --- /dev/null +++ b/integrations/weaviate/tests/test_auth.py @@ -0,0 +1,186 @@ +from haystack_integrations.document_stores.weaviate.auth import ( + AuthApiKey, + AuthBearerToken, + AuthClientCredentials, + AuthClientPassword, + AuthCredentials, +) +from weaviate.auth import AuthApiKey as WeaviateAuthApiKey +from weaviate.auth import AuthBearerToken as WeaviateAuthBearerToken +from weaviate.auth import AuthClientCredentials as WeaviateAuthClientCredentials +from weaviate.auth import AuthClientPassword as WeaviateAuthClientPassword + + +class TestAuthApiKey: + def test_init(self): + credentials = AuthApiKey() + assert credentials.api_key._env_vars == ["WEAVIATE_API_KEY"] + assert credentials.api_key._strict + + def test_to_dict(self): + credentials = AuthApiKey() + assert credentials.to_dict() == { + "type": "api_key", + "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}}, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "fake_key") + credentials = AuthCredentials.from_dict( + { + "type": "api_key", + "init_parameters": {"api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"}}, + } + ) + assert isinstance(credentials, AuthApiKey) + assert credentials.api_key._env_vars == ["WEAVIATE_API_KEY"] + assert credentials.api_key._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "fake_key") + credentials = AuthApiKey() + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthApiKey) + assert resolved.api_key == "fake_key" + + +class TestAuthBearerToken: + def test_init(self): + credentials = AuthBearerToken() + assert credentials.access_token._env_vars == ["WEAVIATE_ACCESS_TOKEN"] + assert credentials.access_token._strict + assert credentials.expires_in == 60 + assert credentials.refresh_token._env_vars == ["WEAVIATE_REFRESH_TOKEN"] + assert not credentials.refresh_token._strict + + def test_to_dict(self): + credentials = AuthBearerToken() + assert credentials.to_dict() == { + "type": "bearer", + "init_parameters": { + "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, + "expires_in": 60, + "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"}, + }, + } + + def test_from_dict(self): + credentials = AuthCredentials.from_dict( + { + "type": "bearer", + "init_parameters": { + "access_token": {"env_vars": ["WEAVIATE_ACCESS_TOKEN"], "strict": True, "type": "env_var"}, + "expires_in": 10, + "refresh_token": {"env_vars": ["WEAVIATE_REFRESH_TOKEN"], "strict": False, "type": "env_var"}, + }, + } + ) + assert credentials.access_token._env_vars == ["WEAVIATE_ACCESS_TOKEN"] + assert credentials.access_token._strict + assert credentials.expires_in == 10 + assert credentials.refresh_token._env_vars == ["WEAVIATE_REFRESH_TOKEN"] + assert not credentials.refresh_token._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_ACCESS_TOKEN", "fake_key") + monkeypatch.setenv("WEAVIATE_REFRESH_TOKEN", "fake_refresh_token") + credentials = AuthBearerToken(expires_in=10) + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthBearerToken) + assert resolved.access_token == "fake_key" + assert resolved.expires_in == 10 + assert resolved.refresh_token == "fake_refresh_token" + + +class TestAuthClientCredentials: + def test_init(self): + credentials = AuthClientCredentials() + assert credentials.client_secret._env_vars == ["WEAVIATE_CLIENT_SECRET"] + assert credentials.client_secret._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_to_dict(self): + credentials = AuthClientCredentials() + assert credentials.to_dict() == { + "type": "client_credentials", + "init_parameters": { + "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + + def test_from_dict(self): + credentials = AuthCredentials.from_dict( + { + "type": "client_credentials", + "init_parameters": { + "client_secret": {"env_vars": ["WEAVIATE_CLIENT_SECRET"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + ) + assert credentials.client_secret._env_vars == ["WEAVIATE_CLIENT_SECRET"] + assert credentials.client_secret._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_CLIENT_SECRET", "fake_secret") + monkeypatch.setenv("WEAVIATE_SCOPE", "fake_scope another_fake_scope") + credentials = AuthClientCredentials() + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthClientCredentials) + assert resolved.client_secret == "fake_secret" + assert resolved.scope_list == ["fake_scope", "another_fake_scope"] + + +class TestAuthClientPassword: + def test_init(self): + credentials = AuthClientPassword() + assert credentials.username._env_vars == ["WEAVIATE_USERNAME"] + assert credentials.username._strict + assert credentials.password._env_vars == ["WEAVIATE_PASSWORD"] + assert credentials.password._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_to_dict(self): + credentials = AuthClientPassword() + assert credentials.to_dict() == { + "type": "client_password", + "init_parameters": { + "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"}, + "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + + def test_from_dict(self): + credentials = AuthCredentials.from_dict( + { + "type": "client_password", + "init_parameters": { + "username": {"env_vars": ["WEAVIATE_USERNAME"], "strict": True, "type": "env_var"}, + "password": {"env_vars": ["WEAVIATE_PASSWORD"], "strict": True, "type": "env_var"}, + "scope": {"env_vars": ["WEAVIATE_SCOPE"], "strict": False, "type": "env_var"}, + }, + } + ) + assert credentials.username._env_vars == ["WEAVIATE_USERNAME"] + assert credentials.username._strict + assert credentials.password._env_vars == ["WEAVIATE_PASSWORD"] + assert credentials.password._strict + assert credentials.scope._env_vars == ["WEAVIATE_SCOPE"] + assert not credentials.scope._strict + + def test_resolve_value(self, monkeypatch): + monkeypatch.setenv("WEAVIATE_USERNAME", "fake_username") + monkeypatch.setenv("WEAVIATE_PASSWORD", "fake_password") + monkeypatch.setenv("WEAVIATE_SCOPE", "fake_scope another_fake_scope") + credentials = AuthClientPassword() + resolved = credentials.resolve_value() + assert isinstance(resolved, WeaviateAuthClientPassword) + assert resolved.username == "fake_username" + assert resolved.password == "fake_password" + assert resolved.scope_list == ["fake_scope", "another_fake_scope"] diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 359af3670..a2b32d578 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -15,6 +15,7 @@ FilterDocumentsTest, WriteDocumentsTest, ) +from haystack_integrations.document_stores.weaviate.auth import AuthApiKey from haystack_integrations.document_stores.weaviate.document_store import ( DOCUMENT_COLLECTION_PROPERTIES, WeaviateDocumentStore, @@ -23,7 +24,7 @@ from numpy import array_equal as np_array_equal from numpy import float32 as np_float32 from pandas import DataFrame -from weaviate.auth import AuthApiKey +from weaviate.auth import AuthApiKey as WeaviateAuthApiKey from weaviate.config import Config from weaviate.embedded import ( DEFAULT_BINARY_PATH, @@ -145,15 +146,15 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do assert received_meta.get(key) == expected_meta.get(key) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") - def test_init(self, mock_weaviate_client_class): + def test_init(self, mock_weaviate_client_class, monkeypatch): mock_client = MagicMock() mock_client.schema.exists.return_value = False mock_weaviate_client_class.return_value = mock_client - + monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") WeaviateDocumentStore( url="http://localhost:8080", collection_settings={"class": "My_collection"}, - auth_client_secret=AuthApiKey("my_api_key"), + auth_client_secret=AuthApiKey(), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( @@ -168,7 +169,7 @@ def test_init(self, mock_weaviate_client_class): # Verify client is created with correct parameters mock_weaviate_client_class.assert_called_once_with( url="http://localhost:8080", - auth_client_secret=AuthApiKey("my_api_key"), + auth_client_secret=WeaviateAuthApiKey("my_api_key"), timeout_config=(10, 60), proxies={"http": "http://proxy:1234"}, trust_env=False, @@ -191,10 +192,11 @@ def test_init(self, mock_weaviate_client_class): ) @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") - def test_to_dict(self, _mock_weaviate): + def test_to_dict(self, _mock_weaviate, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") document_store = WeaviateDocumentStore( url="http://localhost:8080", - auth_client_secret=AuthApiKey("my_api_key"), + auth_client_secret=AuthApiKey(), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( @@ -222,8 +224,10 @@ def test_to_dict(self, _mock_weaviate): ], }, "auth_client_secret": { - "type": "weaviate.auth.AuthApiKey", - "init_parameters": {"api_key": "my_api_key"}, + "type": "api_key", + "init_parameters": { + "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} + }, }, "timeout_config": (10, 60), "proxies": {"http": "http://proxy:1234"}, @@ -250,7 +254,8 @@ def test_to_dict(self, _mock_weaviate): } @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") - def test_from_dict(self, _mock_weaviate): + def test_from_dict(self, _mock_weaviate, monkeypatch): + monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") document_store = WeaviateDocumentStore.from_dict( { "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", @@ -258,8 +263,10 @@ def test_from_dict(self, _mock_weaviate): "url": "http://localhost:8080", "collection_settings": None, "auth_client_secret": { - "type": "weaviate.auth.AuthApiKey", - "init_parameters": {"api_key": "my_api_key"}, + "type": "api_key", + "init_parameters": { + "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} + }, }, "timeout_config": [10, 60], "proxies": {"http": "http://proxy:1234"}, @@ -299,7 +306,7 @@ def test_from_dict(self, _mock_weaviate): {"name": "score", "dataType": ["number"]}, ], } - assert document_store._auth_client_secret == AuthApiKey("my_api_key") + assert document_store._auth_client_secret == AuthApiKey() assert document_store._timeout_config == (10, 60) assert document_store._proxies == {"http": "http://proxy:1234"} assert not document_store._trust_env