Skip to content

Commit

Permalink
Update WeaviateDocumentStore authentication to use new Secret class (#…
Browse files Browse the repository at this point in the history
…425)

* Update WeaviateDocumentStore authentication to use new Secret class

* Fix linting

* Update docs config

* Export auth classes

* Change expires_in to non secret

* Use enum for serialization types

* Freeze dataclasses

* Fix linting

* Fix failing tests
  • Loading branch information
silvanocerza authored Feb 15, 2024
1 parent 14c3de8 commit 5ca05dc
Show file tree
Hide file tree
Showing 6 changed files with 415 additions and 36 deletions.
1 change: 1 addition & 0 deletions integrations/weaviate/pydoc/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ loaders:
search_path: [../src]
modules:
[
"haystack_integrations.document_stores.weaviate.auth",
"haystack_integrations.document_stores.weaviate.document_store",
"haystack_integrations.components.retrievers.weaviate.bm25_retriever",
"haystack_integrations.components.retrievers.weaviate.embedding_retriever",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword, AuthCredentials
from .document_store import WeaviateDocumentStore

__all__ = ["WeaviateDocumentStore"]
__all__ = [
"WeaviateDocumentStore",
"AuthApiKey",
"AuthBearerToken",
"AuthClientCredentials",
"AuthClientPassword",
"AuthCredentials",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, Dict, Type

from haystack.core.errors import DeserializationError
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from weaviate.auth import AuthApiKey as WeaviateAuthApiKey
from weaviate.auth import AuthBearerToken as WeaviateAuthBearerToken
from weaviate.auth import AuthClientCredentials as WeaviateAuthClientCredentials
from weaviate.auth import AuthClientPassword as WeaviateAuthClientPassword


class SupportedAuthTypes(Enum):
"""
Supported auth credentials for WeaviateDocumentStore.
"""

API_KEY = "api_key"
BEARER = "bearer"
CLIENT_CREDENTIALS = "client_credentials"
CLIENT_PASSWORD = "client_password"

def __str__(self):
return self.value

@staticmethod
def from_class(auth_class) -> "SupportedAuthTypes":
auth_types = {
AuthApiKey: SupportedAuthTypes.API_KEY,
AuthBearerToken: SupportedAuthTypes.BEARER,
AuthClientCredentials: SupportedAuthTypes.CLIENT_CREDENTIALS,
AuthClientPassword: SupportedAuthTypes.CLIENT_PASSWORD,
}
return auth_types[auth_class]


@dataclass(frozen=True)
class AuthCredentials(ABC):
"""
Base class for all auth credentials supported by WeaviateDocumentStore.
Can be used to deserialize from dict any of the supported auth credentials.
"""

def to_dict(self) -> Dict[str, Any]:
"""
Converts the object to a dictionary representation for serialization.
"""
_fields = {}
for _field in fields(self):
if _field.type is Secret:
_fields[_field.name] = getattr(self, _field.name).to_dict()
else:
_fields[_field.name] = getattr(self, _field.name)

return {"type": str(SupportedAuthTypes.from_class(self.__class__)), "init_parameters": _fields}

@staticmethod
def from_dict(data: Dict[str, Any]) -> "AuthCredentials":
"""
Converts a dictionary representation to an auth credentials object.
"""
if "type" not in data:
msg = "Missing 'type' in serialization data"
raise DeserializationError(msg)

auth_classes: Dict[str, Type[AuthCredentials]] = {
str(SupportedAuthTypes.API_KEY): AuthApiKey,
str(SupportedAuthTypes.BEARER): AuthBearerToken,
str(SupportedAuthTypes.CLIENT_CREDENTIALS): AuthClientCredentials,
str(SupportedAuthTypes.CLIENT_PASSWORD): AuthClientPassword,
}

return auth_classes[data["type"]]._from_dict(data)

@classmethod
@abstractmethod
def _from_dict(cls, data: Dict[str, Any]):
"""
Internal method to convert a dictionary representation to an auth credentials object.
All subclasses must implement this method.
"""

@abstractmethod
def resolve_value(self):
"""
Resolves all the secrets in the auth credentials object and returns the corresponding Weaviate object.
All subclasses must implement this method.
"""


@dataclass(frozen=True)
class AuthApiKey(AuthCredentials):
"""
AuthCredentials for API key authentication.
By default it will load `api_key` from the environment variable `WEAVIATE_API_KEY`.
"""

api_key: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_API_KEY"]))

@classmethod
def _from_dict(cls, data: Dict[str, Any]) -> "AuthApiKey":
deserialize_secrets_inplace(data["init_parameters"], ["api_key"])
return cls(**data["init_parameters"])

def resolve_value(self) -> WeaviateAuthApiKey:
return WeaviateAuthApiKey(api_key=self.api_key.resolve_value())


@dataclass(frozen=True)
class AuthBearerToken(AuthCredentials):
"""
AuthCredentials for Bearer token authentication.
By default it will load `access_token` from the environment variable `WEAVIATE_ACCESS_TOKEN`,
and `refresh_token` from the environment variable
`WEAVIATE_REFRESH_TOKEN`.
`WEAVIATE_REFRESH_TOKEN` environment variable is optional.
"""

access_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_ACCESS_TOKEN"]))
expires_in: int = field(default=60)
refresh_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_REFRESH_TOKEN"], strict=False))

@classmethod
def _from_dict(cls, data: Dict[str, Any]) -> "AuthBearerToken":
deserialize_secrets_inplace(data["init_parameters"], ["access_token", "refresh_token"])
return cls(**data["init_parameters"])

def resolve_value(self) -> WeaviateAuthBearerToken:
access_token = self.access_token.resolve_value()
refresh_token = self.refresh_token.resolve_value()

return WeaviateAuthBearerToken(
access_token=access_token,
expires_in=self.expires_in,
refresh_token=refresh_token,
)


@dataclass(frozen=True)
class AuthClientCredentials(AuthCredentials):
"""
AuthCredentials for client credentials authentication.
By default it will load `client_secret` from the environment variable `WEAVIATE_CLIENT_SECRET`, and
`scope` from the environment variable `WEAVIATE_SCOPE`.
`WEAVIATE_SCOPE` environment variable is optional, if set it can either be a string or a list of space
separated strings. e.g "scope1" or "scope1 scope2".
"""

client_secret: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_CLIENT_SECRET"]))
scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False))

@classmethod
def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientCredentials":
deserialize_secrets_inplace(data["init_parameters"], ["client_secret", "scope"])
return cls(**data["init_parameters"])

def resolve_value(self) -> WeaviateAuthClientCredentials:
return WeaviateAuthClientCredentials(
client_secret=self.client_secret.resolve_value(),
scope=self.scope.resolve_value(),
)


@dataclass(frozen=True)
class AuthClientPassword(AuthCredentials):
"""
AuthCredentials for username and password authentication.
By default it will load `username` from the environment variable `WEAVIATE_USERNAME`,
`password` from the environment variable `WEAVIATE_PASSWORD`, and
`scope` from the environment variable `WEAVIATE_SCOPE`.
`WEAVIATE_SCOPE` environment variable is optional, if set it can either be a string or a list of space
separated strings. e.g "scope1" or "scope1 scope2".
"""

username: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_USERNAME"]))
password: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_PASSWORD"]))
scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False))

@classmethod
def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientPassword":
deserialize_secrets_inplace(data["init_parameters"], ["username", "password", "scope"])
return cls(**data["init_parameters"])

def resolve_value(self) -> WeaviateAuthClientPassword:
return WeaviateAuthClientPassword(
username=self.username.resolve_value(),
password=self.password.resolve_value(),
scope=self.scope.resolve_value(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,16 @@
from haystack.document_stores.types.policy import DuplicatePolicy

import weaviate
from weaviate.auth import AuthCredentials
from weaviate.config import Config, ConnectionConfig
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

from ._filters import convert_filters
from .auth import AuthCredentials

Number = Union[int, float]
TimeoutType = Union[Tuple[Number, Number], Number]

# This simplifies a bit how we handle deserialization of the auth credentials.
# Otherwise we would need to use importlib to dynamically import the correct class.
_AUTH_CLASSES = {
"weaviate.auth.AuthClientCredentials": weaviate.auth.AuthClientCredentials,
"weaviate.auth.AuthClientPassword": weaviate.auth.AuthClientPassword,
"weaviate.auth.AuthBearerToken": weaviate.auth.AuthBearerToken,
"weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey,
}

# This is the default collection properties for Weaviate.
# It's a list of properties that will be created on the collection.
Expand Down Expand Up @@ -92,10 +84,10 @@ def __init__(
for more information on collections and their properties.
:param auth_client_secret: Authentication credentials, defaults to None.
Can be one of the following types depending on the authentication mode:
- `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
- `weaviate.auth.AuthClientPassword` to use username and password for oidc Resource Owner Password flow
- `weaviate.auth.AuthClientCredentials` to use a client secret for oidc client credential flow
- `weaviate.auth.AuthApiKey` to use an API key
- `AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
- `AuthClientPassword` to use username and password for oidc Resource Owner Password flow
- `AuthClientCredentials` to use a client secret for oidc client credential flow
- `AuthApiKey` to use an API key
:param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60).
It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout).
If only one real number is passed then both connect and read timeout will be set to
Expand Down Expand Up @@ -124,7 +116,7 @@ def __init__(
"""
self._client = weaviate.Client(
url=url,
auth_client_secret=auth_client_secret,
auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None,
timeout_config=timeout_config,
proxies=proxies,
trust_env=trust_env,
Expand Down Expand Up @@ -164,19 +156,14 @@ def __init__(
self._additional_config = additional_config

def to_dict(self) -> Dict[str, Any]:
auth_client_secret = None
if self._auth_client_secret:
# There are different types of AuthCredentials, so even thought it's a dataclass
# and we could just use asdict, we need to save the type too.
auth_client_secret = default_to_dict(self._auth_client_secret, **asdict(self._auth_client_secret))
embedded_options = asdict(self._embedded_options) if self._embedded_options else None
additional_config = asdict(self._additional_config) if self._additional_config else None

return default_to_dict(
self,
url=self._url,
collection_settings=self._collection_settings,
auth_client_secret=auth_client_secret,
auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None,
timeout_config=self._timeout_config,
proxies=self._proxies,
trust_env=self._trust_env,
Expand All @@ -193,8 +180,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore":
tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config
)
if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None:
auth_class = _AUTH_CLASSES[auth_client_secret["type"]]
data["init_parameters"]["auth_client_secret"] = default_from_dict(auth_class, auth_client_secret)
data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret)
if (embedded_options := data["init_parameters"].get("embedded_options")) is not None:
data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options)
if (additional_config := data["init_parameters"].get("additional_config")) is not None:
Expand Down
Loading

0 comments on commit 5ca05dc

Please sign in to comment.