Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update WeaviateDocumentStore authentication to use new Secret class #425

Merged
merged 9 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integrations/weaviate/pydoc/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ loaders:
search_path: [../src]
modules:
[
"haystack_integrations.document_stores.weaviate.auth",
"haystack_integrations.document_stores.weaviate.document_store",
"haystack_integrations.components.retrievers.weaviate.bm25_retriever",
"haystack_integrations.components.retrievers.weaviate.embedding_retriever",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from .auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword, AuthCredentials
from .document_store import WeaviateDocumentStore

# Public API of the package: the document store plus the auth credential classes.
# Defined once; an earlier single-item definition would be overwritten immediately anyway.
__all__ = [
    "WeaviateDocumentStore",
    "AuthApiKey",
    "AuthBearerToken",
    "AuthClientCredentials",
    "AuthClientPassword",
    "AuthCredentials",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, Dict, Type

from haystack.core.errors import DeserializationError
from haystack.utils.auth import Secret, deserialize_secrets_inplace

from weaviate.auth import AuthApiKey as WeaviateAuthApiKey
from weaviate.auth import AuthBearerToken as WeaviateAuthBearerToken
from weaviate.auth import AuthClientCredentials as WeaviateAuthClientCredentials
from weaviate.auth import AuthClientPassword as WeaviateAuthClientPassword


class SupportedAuthTypes(Enum):
    """
    Enumeration of the authentication kinds usable with WeaviateDocumentStore.

    The value of each member is the plain string tag that identifies the auth kind.
    """

    API_KEY = "api_key"
    BEARER = "bearer"
    CLIENT_CREDENTIALS = "client_credentials"
    CLIENT_PASSWORD = "client_password"

    def __str__(self):
        # Render a member as its raw tag instead of "SupportedAuthTypes.<NAME>".
        return self.value

    @staticmethod
    def from_class(auth_class) -> "SupportedAuthTypes":
        """
        Returns the enum member associated with one of the auth credentials classes.

        The lookup is by exact class; any class that is not listed raises `KeyError`.
        """
        member_by_class = dict(
            (
                (AuthApiKey, SupportedAuthTypes.API_KEY),
                (AuthBearerToken, SupportedAuthTypes.BEARER),
                (AuthClientCredentials, SupportedAuthTypes.CLIENT_CREDENTIALS),
                (AuthClientPassword, SupportedAuthTypes.CLIENT_PASSWORD),
            )
        )
        return member_by_class[auth_class]


@dataclass(frozen=True)
class AuthCredentials(ABC):
    """
    Base class for all auth credentials supported by WeaviateDocumentStore.

    Provides the shared serialization logic (`to_dict`) and a dispatcher (`from_dict`) that can
    rebuild any of the supported auth credentials from its dictionary form.
    """

    def to_dict(self) -> Dict[str, Any]:
        """
        Converts the object to a dictionary representation for serialization.

        :returns: A dictionary holding the auth "type" tag and the "init_parameters" of the object.
        """
        _fields = {}
        for _field in fields(self):
            value = getattr(self, _field.name)
            # Decide on the runtime value rather than on the declared annotation: `_field.type` is a
            # plain string when annotations are postponed, so an identity check against `Secret`
            # would silently skip the secret and leave it unserialized.
            _fields[_field.name] = value.to_dict() if isinstance(value, Secret) else value

        return {"type": str(SupportedAuthTypes.from_class(self.__class__)), "init_parameters": _fields}

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "AuthCredentials":
        """
        Converts a dictionary representation to an auth credentials object.

        The input dictionary is left untouched: secrets are deserialized on a copy.

        :param data: Dictionary with a "type" tag and an "init_parameters" mapping.
        :returns: An instance of the auth credentials class matching `data["type"]`.
        :raises DeserializationError: If "type" is missing or is not one of the supported auth types.
        """
        if "type" not in data:
            msg = "Missing 'type' in serialization data"
            raise DeserializationError(msg)

        auth_classes: Dict[str, Type[AuthCredentials]] = {
            str(SupportedAuthTypes.API_KEY): AuthApiKey,
            str(SupportedAuthTypes.BEARER): AuthBearerToken,
            str(SupportedAuthTypes.CLIENT_CREDENTIALS): AuthClientCredentials,
            str(SupportedAuthTypes.CLIENT_PASSWORD): AuthClientPassword,
        }

        auth_class = auth_classes.get(data["type"])
        if auth_class is None:
            # Report unknown types the same way as a missing type instead of leaking a bare KeyError.
            msg = f"Unknown auth type {data['type']!r}, supported types are: {', '.join(auth_classes)}"
            raise DeserializationError(msg)

        # Subclasses deserialize secrets in place; hand them a shallow copy so the caller's
        # dictionary is not mutated and can be deserialized again.
        return auth_class._from_dict({**data, "init_parameters": dict(data["init_parameters"])})

    @classmethod
    @abstractmethod
    def _from_dict(cls, data: Dict[str, Any]):
        """
        Internal method to convert a dictionary representation to an auth credentials object.
        All subclasses must implement this method.
        """

    @abstractmethod
    def resolve_value(self):
        """
        Resolves all the secrets in the auth credentials object and returns the corresponding Weaviate object.
        All subclasses must implement this method.
        """


@dataclass(frozen=True)
class AuthApiKey(AuthCredentials):
    """
    Auth credentials for authenticating with an API key.

    Unless passed explicitly, `api_key` is read from the `WEAVIATE_API_KEY` environment variable.
    """

    api_key: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_API_KEY"]))

    @classmethod
    def _from_dict(cls, data: Dict[str, Any]) -> "AuthApiKey":
        # Turn the serialized secret back into a Secret (in place) before building the instance.
        init_params = data["init_parameters"]
        deserialize_secrets_inplace(init_params, ["api_key"])
        return cls(**init_params)

    def resolve_value(self) -> WeaviateAuthApiKey:
        # Resolve the secret, then wrap it in the Weaviate client's own auth object.
        resolved_key = self.api_key.resolve_value()
        return WeaviateAuthApiKey(api_key=resolved_key)


@dataclass(frozen=True)
class AuthBearerToken(AuthCredentials):
    """
    Auth credentials for authenticating with a Bearer token.

    Unless passed explicitly, `access_token` is read from the `WEAVIATE_ACCESS_TOKEN` environment
    variable and `refresh_token` from the `WEAVIATE_REFRESH_TOKEN` environment variable.
    Setting `WEAVIATE_REFRESH_TOKEN` is optional.
    """

    access_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_ACCESS_TOKEN"]))
    expires_in: int = 60
    refresh_token: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_REFRESH_TOKEN"], strict=False))

    @classmethod
    def _from_dict(cls, data: Dict[str, Any]) -> "AuthBearerToken":
        # Both tokens are stored as serialized secrets; restore them in place first.
        init_params = data["init_parameters"]
        deserialize_secrets_inplace(init_params, ["access_token", "refresh_token"])
        return cls(**init_params)

    def resolve_value(self) -> WeaviateAuthBearerToken:
        # The access token is resolved before the refresh token, as keyword arguments
        # are evaluated left to right.
        return WeaviateAuthBearerToken(
            access_token=self.access_token.resolve_value(),
            expires_in=self.expires_in,
            refresh_token=self.refresh_token.resolve_value(),
        )


@dataclass(frozen=True)
class AuthClientCredentials(AuthCredentials):
    """
    Auth credentials for the client credentials authentication flow.

    Unless passed explicitly, `client_secret` is read from the `WEAVIATE_CLIENT_SECRET` environment
    variable and `scope` from the `WEAVIATE_SCOPE` environment variable.
    Setting `WEAVIATE_SCOPE` is optional. When set, it holds either a single scope or several scopes
    separated by spaces, e.g. "scope1" or "scope1 scope2".
    """

    client_secret: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_CLIENT_SECRET"]))
    scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False))

    @classmethod
    def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientCredentials":
        # Restore the serialized secrets in place, then build the instance from the parameters.
        init_params = data["init_parameters"]
        deserialize_secrets_inplace(init_params, ["client_secret", "scope"])
        return cls(**init_params)

    def resolve_value(self) -> WeaviateAuthClientCredentials:
        # Resolve the secrets in declaration order, then hand them to the Weaviate auth object.
        resolved_secret = self.client_secret.resolve_value()
        resolved_scope = self.scope.resolve_value()
        return WeaviateAuthClientCredentials(client_secret=resolved_secret, scope=resolved_scope)


@dataclass(frozen=True)
class AuthClientPassword(AuthCredentials):
    """
    Auth credentials for username and password authentication.

    Unless passed explicitly, `username` is read from the `WEAVIATE_USERNAME` environment variable,
    `password` from the `WEAVIATE_PASSWORD` environment variable and `scope` from the
    `WEAVIATE_SCOPE` environment variable.
    Setting `WEAVIATE_SCOPE` is optional. When set, it holds either a single scope or several scopes
    separated by spaces, e.g. "scope1" or "scope1 scope2".
    """

    username: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_USERNAME"]))
    password: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_PASSWORD"]))
    scope: Secret = field(default_factory=lambda: Secret.from_env_var(["WEAVIATE_SCOPE"], strict=False))

    @classmethod
    def _from_dict(cls, data: Dict[str, Any]) -> "AuthClientPassword":
        # Restore the serialized secrets in place, then build the instance from the parameters.
        init_params = data["init_parameters"]
        deserialize_secrets_inplace(init_params, ["username", "password", "scope"])
        return cls(**init_params)

    def resolve_value(self) -> WeaviateAuthClientPassword:
        # Resolve the secrets in declaration order, then hand them to the Weaviate auth object.
        resolved_username = self.username.resolve_value()
        resolved_password = self.password.resolve_value()
        resolved_scope = self.scope.resolve_value()
        return WeaviateAuthClientPassword(
            username=resolved_username,
            password=resolved_password,
            scope=resolved_scope,
        )
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,16 @@
from haystack.document_stores.types.policy import DuplicatePolicy

import weaviate
from weaviate.auth import AuthCredentials
from weaviate.config import Config, ConnectionConfig
from weaviate.embedded import EmbeddedOptions
from weaviate.util import generate_uuid5

from ._filters import convert_filters
from .auth import AuthCredentials

Number = Union[int, float]
TimeoutType = Union[Tuple[Number, Number], Number]

# This simplifies a bit how we handle deserialization of the auth credentials.
# Otherwise we would need to use importlib to dynamically import the correct class.
_AUTH_CLASSES = {
"weaviate.auth.AuthClientCredentials": weaviate.auth.AuthClientCredentials,
"weaviate.auth.AuthClientPassword": weaviate.auth.AuthClientPassword,
"weaviate.auth.AuthBearerToken": weaviate.auth.AuthBearerToken,
"weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey,
}

# This is the default collection properties for Weaviate.
# It's a list of properties that will be created on the collection.
Expand Down Expand Up @@ -92,10 +84,10 @@ def __init__(
for more information on collections and their properties.
:param auth_client_secret: Authentication credentials, defaults to None.
Can be one of the following types depending on the authentication mode:
- `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
- `weaviate.auth.AuthClientPassword` to use username and password for oidc Resource Owner Password flow
- `weaviate.auth.AuthClientCredentials` to use a client secret for oidc client credential flow
- `weaviate.auth.AuthApiKey` to use an API key
- `AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
- `AuthClientPassword` to use username and password for oidc Resource Owner Password flow
- `AuthClientCredentials` to use a client secret for oidc client credential flow
- `AuthApiKey` to use an API key
:param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60).
It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout).
If only one real number is passed then both connect and read timeout will be set to
Expand Down Expand Up @@ -124,7 +116,7 @@ def __init__(
"""
self._client = weaviate.Client(
url=url,
auth_client_secret=auth_client_secret,
auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None,
shadeMe marked this conversation as resolved.
Show resolved Hide resolved
timeout_config=timeout_config,
proxies=proxies,
trust_env=trust_env,
Expand Down Expand Up @@ -164,19 +156,14 @@ def __init__(
self._additional_config = additional_config

def to_dict(self) -> Dict[str, Any]:
auth_client_secret = None
if self._auth_client_secret:
# There are different types of AuthCredentials, so even though it's a dataclass
# and we could just use asdict, we need to save the type too.
auth_client_secret = default_to_dict(self._auth_client_secret, **asdict(self._auth_client_secret))
embedded_options = asdict(self._embedded_options) if self._embedded_options else None
additional_config = asdict(self._additional_config) if self._additional_config else None

return default_to_dict(
self,
url=self._url,
collection_settings=self._collection_settings,
auth_client_secret=auth_client_secret,
auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None,
timeout_config=self._timeout_config,
proxies=self._proxies,
trust_env=self._trust_env,
Expand All @@ -193,8 +180,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore":
tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config
)
if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None:
auth_class = _AUTH_CLASSES[auth_client_secret["type"]]
data["init_parameters"]["auth_client_secret"] = default_from_dict(auth_class, auth_client_secret)
data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret)
if (embedded_options := data["init_parameters"].get("embedded_options")) is not None:
data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options)
if (additional_config := data["init_parameters"].get("additional_config")) is not None:
Expand Down
Loading
Loading