From 2d7abb9e7960afef758fb5d20de4f9335a1333d7 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Wed, 10 Jan 2024 18:29:31 +0100 Subject: [PATCH] Implement WeaviateDocumentStore initialization and serialization --- .../src/weaviate_haystack/document_store.py | 156 +++++++++++++++++- .../weaviate/tests/test_document_store.py | 118 ++++++++++++- 2 files changed, 271 insertions(+), 3 deletions(-) diff --git a/integrations/weaviate/src/weaviate_haystack/document_store.py b/integrations/weaviate/src/weaviate_haystack/document_store.py index 267c3a6d0..ba48241fe 100644 --- a/integrations/weaviate/src/weaviate_haystack/document_store.py +++ b/integrations/weaviate/src/weaviate_haystack/document_store.py @@ -1,2 +1,156 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Union, Dict, Tuple, Optional, List, Any +from dataclasses import asdict + +import weaviate +from weaviate.auth import AuthCredentials +from weaviate.embedded import EmbeddedOptions +from weaviate.config import Config, ConnectionConfig + +from haystack.core.serialization import default_to_dict, default_from_dict +from haystack.dataclasses.document import Document +from haystack.document_stores.protocol import DuplicatePolicy + +Number = Union[int, float] +TimeoutType = Union[Tuple[Number, Number], Number] + +# This simplifies a bit how we handle deserialization of the auth credentials. +# Otherwise we would need to use importlib to dynamically import the correct class. +_AUTH_CLASSES = { + "weaviate.auth.AuthClientCredentials": weaviate.auth.AuthClientCredentials, + "weaviate.auth.AuthClientPassword": weaviate.auth.AuthClientPassword, + "weaviate.auth.AuthBearerToken": weaviate.auth.AuthBearerToken, + "weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey, +} + + class WeaviateDocumentStore: - pass + """ + WeaviateDocumentStore is a Document Store for Weaviate. + """ + + def __init__( + self, + *, + url: Optional[str] = None, + auth_client_secret: Optional[AuthCredentials] = None, + timeout_config: TimeoutType = (10, 60), + proxies: Optional[Union[Dict, str]] = None, + trust_env: bool = False, + additional_headers: Optional[Dict] = None, + startup_period: Optional[int] = 5, + embedded_options: Optional[EmbeddedOptions] = None, + additional_config: Optional[Config] = None, + ): + """ + Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance. + + :param url: The URL to the weaviate instance, defaults to None. + :param auth_client_secret: Authentication credentials, defaults to None. + Can be one of the following types depending on the authentication mode: + - `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens + - `weaviate.auth.AuthClientPassword` to use username and password for oidc Resource Owner Password flow + - `weaviate.auth.AuthClientCredentials` to use a client secret for oidc client credential flow + - `weaviate.auth.AuthApiKey` to use an API key + :param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60). + It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout). + If only one real number is passed then both connect and read timeout will be set to + that value, by default (2, 20). + :param proxies: Proxy configuration, defaults to None. + Can be passed as a dict using the + ``requests` format`_, + or a string. If a string is passed it will be used for both HTTP and HTTPS requests. + :param trust_env: Whether to read proxies from the ENV variables, defaults to False. + Proxies will be read from the following ENV variables: + * `HTTP_PROXY` + * `http_proxy` + * `HTTPS_PROXY` + * `https_proxy` + If `proxies` is not None, `trust_env` is ignored. + :param additional_headers: Additional headers to include in the requests, defaults to None. + Can be used to set OpenAI/HuggingFace keys. OpenAI/HuggingFace key looks like this: + ``` + {"X-OpenAI-Api-Key": ""}, {"X-HuggingFace-Api-Key": ""} + ``` + :param startup_period: How many seconds the client will wait for Weaviate to start before + raising a RequestsConnectionError, defaults to 5. + :param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None. + For a full list of options see `weaviate.embedded.EmbeddedOptions`. + :param additional_config: Additional and advanced configuration options for weaviate, defaults to None. + """ + self._client = weaviate.Client( + url=url, + auth_client_secret=auth_client_secret, + timeout_config=timeout_config, + proxies=proxies, + trust_env=trust_env, + additional_headers=additional_headers, + startup_period=startup_period, + embedded_options=embedded_options, + additional_config=additional_config, + ) + + self._url = url + self._auth_client_secret = auth_client_secret + self._timeout_config = timeout_config + self._proxies = proxies + self._trust_env = trust_env + self._additional_headers = additional_headers + self._startup_period = startup_period + self._embedded_options = embedded_options + self._additional_config = additional_config + + def to_dict(self) -> Dict[str, Any]: + auth_client_secret = None + if self._auth_client_secret: + # There are different types of AuthCredentials, so even thought it's a dataclass + # and we could just use asdict, we need to save the type too. + auth_client_secret = default_to_dict(self._auth_client_secret, **asdict(self._auth_client_secret)) + embedded_options = asdict(self._embedded_options) if self._embedded_options else None + additional_config = asdict(self._additional_config) if self._additional_config else None + + return default_to_dict( + self, + url=self._url, + auth_client_secret=auth_client_secret, + timeout_config=self._timeout_config, + proxies=self._proxies, + trust_env=self._trust_env, + additional_headers=self._additional_headers, + startup_period=self._startup_period, + embedded_options=embedded_options, + additional_config=additional_config, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore": + if (timeout_config := data["init_parameters"].get("timeout_config")) is not None: + data["init_parameters"]["timeout_config"] = ( + tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config + ) + if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None: + auth_class = _AUTH_CLASSES[auth_client_secret["type"]] + data["init_parameters"]["auth_client_secret"] = default_from_dict(auth_class, auth_client_secret) + if (embedded_options := data["init_parameters"].get("embedded_options")) is not None: + data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options) + if (additional_config := data["init_parameters"].get("additional_config")) is not None: + additional_config["connection_config"] = ConnectionConfig(**additional_config["connection_config"]) + data["init_parameters"]["additional_config"] = Config(**additional_config) + return default_from_dict( + cls, + data, + ) + + def count_documents(self) -> int: + ... + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + ... + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int: + ... + + def delete_documents(self, document_ids: List[str]) -> None: + ... diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 993c50b78..90ac4e848 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -1,2 +1,116 @@ -def test(): - assert True +from unittest.mock import patch + +from weaviate.auth import AuthApiKey +from weaviate.embedded import ( + EmbeddedOptions, + DEFAULT_BINARY_PATH, + DEFAULT_PERSISTENCE_DATA_PATH, + DEFAULT_PORT, + DEFAULT_GRPC_PORT, +) +from weaviate.config import Config + +from weaviate_haystack.document_store import WeaviateDocumentStore + + +class TestWeaviateDocumentStore: + @patch("weaviate_haystack.document_store.weaviate") + def test_to_dict(self, _mock_weaviate): + document_store = WeaviateDocumentStore( + url="http://localhost:8080", + auth_client_secret=AuthApiKey("my_api_key"), + proxies={"http": "http://proxy:1234"}, + additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, + embedded_options=EmbeddedOptions( + persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, + binary_path=DEFAULT_BINARY_PATH, + version="1.23.0", + hostname="127.0.0.1", + ), + additional_config=Config(grpc_port_experimental=12345), + ) + assert document_store.to_dict() == { + "type": "weaviate_haystack.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": "http://localhost:8080", + "auth_client_secret": { + "type": "weaviate.auth.AuthApiKey", + "init_parameters": {"api_key": "my_api_key"}, + }, + "timeout_config": (10, 60), + "proxies": {"http": "http://proxy:1234"}, + "trust_env": False, + "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, + "startup_period": 5, + "embedded_options": { + "persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH, + "binary_path": DEFAULT_BINARY_PATH, + "version": "1.23.0", + "port": DEFAULT_PORT, + "hostname": "127.0.0.1", + "additional_env_vars": None, + "grpc_port": DEFAULT_GRPC_PORT, + }, + "additional_config": { + "grpc_port_experimental": 12345, + "connection_config": { + "session_pool_connections": 20, + "session_pool_maxsize": 20, + }, + }, + }, + } + + @patch("weaviate_haystack.document_store.weaviate") + def test_from_dict(self, _mock_weaviate): + document_store = WeaviateDocumentStore.from_dict( + { + "type": "weaviate_haystack.document_store.WeaviateDocumentStore", + "init_parameters": { + "url": "http://localhost:8080", + "auth_client_secret": { + "type": "weaviate.auth.AuthApiKey", + "init_parameters": {"api_key": "my_api_key"}, + }, + "timeout_config": [10, 60], + "proxies": {"http": "http://proxy:1234"}, + "trust_env": False, + "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, + "startup_period": 5, + "embedded_options": { + "persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH, + "binary_path": DEFAULT_BINARY_PATH, + "version": "1.23.0", + "port": DEFAULT_PORT, + "hostname": "127.0.0.1", + "additional_env_vars": None, + "grpc_port": DEFAULT_GRPC_PORT, + }, + "additional_config": { + "grpc_port_experimental": 12345, + "connection_config": { + "session_pool_connections": 20, + "session_pool_maxsize": 20, + }, + }, + }, + } + ) + + assert document_store._url == "http://localhost:8080" + assert document_store._auth_client_secret == AuthApiKey("my_api_key") + assert document_store._timeout_config == (10, 60) + assert document_store._proxies == {"http": "http://proxy:1234"} + assert not document_store._trust_env + assert document_store._additional_headers == {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"} + assert document_store._startup_period == 5 + assert document_store._embedded_options.persistence_data_path == DEFAULT_PERSISTENCE_DATA_PATH + assert document_store._embedded_options.binary_path == DEFAULT_BINARY_PATH + assert document_store._embedded_options.version == "1.23.0" + assert document_store._embedded_options.port == DEFAULT_PORT + assert document_store._embedded_options.hostname == "127.0.0.1" + assert document_store._embedded_options.additional_env_vars == None + assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT + assert document_store._additional_config.grpc_port_experimental == 12345 + assert document_store._additional_config.connection_config.session_pool_connections == 20 + assert document_store._additional_config.connection_config.session_pool_maxsize == 20