Skip to content

Commit

Permalink
Implement WeaviateDocumentStore initialization and serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jan 10, 2024
1 parent 07fd392 commit 2d7abb9
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 3 deletions.
156 changes: 155 additions & 1 deletion integrations/weaviate/src/weaviate_haystack/document_store.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,156 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Union, Dict, Tuple, Optional, List, Any
from dataclasses import asdict

import weaviate
from weaviate.auth import AuthCredentials
from weaviate.embedded import EmbeddedOptions
from weaviate.config import Config, ConnectionConfig

from haystack.core.serialization import default_to_dict, default_from_dict
from haystack.dataclasses.document import Document
from haystack.document_stores.protocol import DuplicatePolicy

Number = Union[int, float]
TimeoutType = Union[Tuple[Number, Number], Number]

# This simplifies a bit how we handle deserialization of the auth credentials.
# Otherwise we would need to use importlib to dynamically import the correct class.
_AUTH_CLASSES = {
"weaviate.auth.AuthClientCredentials": weaviate.auth.AuthClientCredentials,
"weaviate.auth.AuthClientPassword": weaviate.auth.AuthClientPassword,
"weaviate.auth.AuthBearerToken": weaviate.auth.AuthBearerToken,
"weaviate.auth.AuthApiKey": weaviate.auth.AuthApiKey,
}


class WeaviateDocumentStore:
pass
"""
WeaviateDocumentStore is a Document Store for Weaviate.
"""

def __init__(
self,
*,
url: Optional[str] = None,
auth_client_secret: Optional[AuthCredentials] = None,
timeout_config: TimeoutType = (10, 60),
proxies: Optional[Union[Dict, str]] = None,
trust_env: bool = False,
additional_headers: Optional[Dict] = None,
startup_period: Optional[int] = 5,
embedded_options: Optional[EmbeddedOptions] = None,
additional_config: Optional[Config] = None,
):
"""
Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance.
:param url: The URL to the weaviate instance, defaults to None.
:param auth_client_secret: Authentication credentials, defaults to None.
Can be one of the following types depending on the authentication mode:
- `weaviate.auth.AuthBearerToken` to use existing access and (optionally, but recommended) refresh tokens
- `weaviate.auth.AuthClientPassword` to use username and password for oidc Resource Owner Password flow
- `weaviate.auth.AuthClientCredentials` to use a client secret for oidc client credential flow
- `weaviate.auth.AuthApiKey` to use an API key
:param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60).
It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout).
If only one real number is passed then both connect and read timeout will be set to
that value, by default (2, 20).
:param proxies: Proxy configuration, defaults to None.
Can be passed as a dict using the
``requests` format<https://docs.python-requests.org/en/stable/user/advanced/#proxies>`_,
or a string. If a string is passed it will be used for both HTTP and HTTPS requests.
:param trust_env: Whether to read proxies from the ENV variables, defaults to False.
Proxies will be read from the following ENV variables:
* `HTTP_PROXY`
* `http_proxy`
* `HTTPS_PROXY`
* `https_proxy`
If `proxies` is not None, `trust_env` is ignored.
:param additional_headers: Additional headers to include in the requests, defaults to None.
Can be used to set OpenAI/HuggingFace keys. OpenAI/HuggingFace key looks like this:
```
{"X-OpenAI-Api-Key": "<THE-KEY>"}, {"X-HuggingFace-Api-Key": "<THE-KEY>"}
```
:param startup_period: How many seconds the client will wait for Weaviate to start before
raising a RequestsConnectionError, defaults to 5.
:param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None.
For a full list of options see `weaviate.embedded.EmbeddedOptions`.
:param additional_config: Additional and advanced configuration options for weaviate, defaults to None.
"""
self._client = weaviate.Client(
url=url,
auth_client_secret=auth_client_secret,
timeout_config=timeout_config,
proxies=proxies,
trust_env=trust_env,
additional_headers=additional_headers,
startup_period=startup_period,
embedded_options=embedded_options,
additional_config=additional_config,
)

self._url = url
self._auth_client_secret = auth_client_secret
self._timeout_config = timeout_config
self._proxies = proxies
self._trust_env = trust_env
self._additional_headers = additional_headers
self._startup_period = startup_period
self._embedded_options = embedded_options
self._additional_config = additional_config

def to_dict(self) -> Dict[str, Any]:
auth_client_secret = None
if self._auth_client_secret:
# There are different types of AuthCredentials, so even thought it's a dataclass
# and we could just use asdict, we need to save the type too.
auth_client_secret = default_to_dict(self._auth_client_secret, **asdict(self._auth_client_secret))
embedded_options = asdict(self._embedded_options) if self._embedded_options else None
additional_config = asdict(self._additional_config) if self._additional_config else None

return default_to_dict(
self,
url=self._url,
auth_client_secret=auth_client_secret,
timeout_config=self._timeout_config,
proxies=self._proxies,
trust_env=self._trust_env,
additional_headers=self._additional_headers,
startup_period=self._startup_period,
embedded_options=embedded_options,
additional_config=additional_config,
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore":
if (timeout_config := data["init_parameters"].get("timeout_config")) is not None:
data["init_parameters"]["timeout_config"] = (
tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config
)
if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None:
auth_class = _AUTH_CLASSES[auth_client_secret["type"]]
data["init_parameters"]["auth_client_secret"] = default_from_dict(auth_class, auth_client_secret)
if (embedded_options := data["init_parameters"].get("embedded_options")) is not None:
data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options)
if (additional_config := data["init_parameters"].get("additional_config")) is not None:
additional_config["connection_config"] = ConnectionConfig(**additional_config["connection_config"])
data["init_parameters"]["additional_config"] = Config(**additional_config)
return default_from_dict(
cls,
data,
)

def count_documents(self) -> int:
...

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
...

def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.NONE) -> int:
...

def delete_documents(self, document_ids: List[str]) -> None:
...
118 changes: 116 additions & 2 deletions integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,116 @@
def test():
assert True
from unittest.mock import patch

from weaviate.auth import AuthApiKey
from weaviate.embedded import (
EmbeddedOptions,
DEFAULT_BINARY_PATH,
DEFAULT_PERSISTENCE_DATA_PATH,
DEFAULT_PORT,
DEFAULT_GRPC_PORT,
)
from weaviate.config import Config

from weaviate_haystack.document_store import WeaviateDocumentStore


class TestWeaviateDocumentStore:
@patch("weaviate_haystack.document_store.weaviate")
def test_to_dict(self, _mock_weaviate):
document_store = WeaviateDocumentStore(
url="http://localhost:8080",
auth_client_secret=AuthApiKey("my_api_key"),
proxies={"http": "http://proxy:1234"},
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
embedded_options=EmbeddedOptions(
persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH,
binary_path=DEFAULT_BINARY_PATH,
version="1.23.0",
hostname="127.0.0.1",
),
additional_config=Config(grpc_port_experimental=12345),
)
assert document_store.to_dict() == {
"type": "weaviate_haystack.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
},
"timeout_config": (10, 60),
"proxies": {"http": "http://proxy:1234"},
"trust_env": False,
"additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
"startup_period": 5,
"embedded_options": {
"persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH,
"binary_path": DEFAULT_BINARY_PATH,
"version": "1.23.0",
"port": DEFAULT_PORT,
"hostname": "127.0.0.1",
"additional_env_vars": None,
"grpc_port": DEFAULT_GRPC_PORT,
},
"additional_config": {
"grpc_port_experimental": 12345,
"connection_config": {
"session_pool_connections": 20,
"session_pool_maxsize": 20,
},
},
},
}

@patch("weaviate_haystack.document_store.weaviate")
def test_from_dict(self, _mock_weaviate):
document_store = WeaviateDocumentStore.from_dict(
{
"type": "weaviate_haystack.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
},
"timeout_config": [10, 60],
"proxies": {"http": "http://proxy:1234"},
"trust_env": False,
"additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
"startup_period": 5,
"embedded_options": {
"persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH,
"binary_path": DEFAULT_BINARY_PATH,
"version": "1.23.0",
"port": DEFAULT_PORT,
"hostname": "127.0.0.1",
"additional_env_vars": None,
"grpc_port": DEFAULT_GRPC_PORT,
},
"additional_config": {
"grpc_port_experimental": 12345,
"connection_config": {
"session_pool_connections": 20,
"session_pool_maxsize": 20,
},
},
},
}
)

assert document_store._url == "http://localhost:8080"
assert document_store._auth_client_secret == AuthApiKey("my_api_key")
assert document_store._timeout_config == (10, 60)
assert document_store._proxies == {"http": "http://proxy:1234"}
assert not document_store._trust_env
assert document_store._additional_headers == {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}
assert document_store._startup_period == 5
assert document_store._embedded_options.persistence_data_path == DEFAULT_PERSISTENCE_DATA_PATH
assert document_store._embedded_options.binary_path == DEFAULT_BINARY_PATH
assert document_store._embedded_options.version == "1.23.0"
assert document_store._embedded_options.port == DEFAULT_PORT
assert document_store._embedded_options.hostname == "127.0.0.1"
assert document_store._embedded_options.additional_env_vars == None
assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT
assert document_store._additional_config.grpc_port_experimental == 12345
assert document_store._additional_config.connection_config.session_pool_connections == 20
assert document_store._additional_config.connection_config.session_pool_maxsize == 20

0 comments on commit 2d7abb9

Please sign in to comment.