Skip to content

Commit

Permalink
Add collection_name parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanocerza committed Jan 16, 2024
1 parent b930314 commit ceecf38
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
self,
*,
url: Optional[str] = None,
collection_name: str = "default",
auth_client_secret: Optional[AuthCredentials] = None,
timeout_config: TimeoutType = (10, 60),
proxies: Optional[Union[Dict, str]] = None,
Expand Down Expand Up @@ -79,6 +80,8 @@ def __init__(
:param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None.
For a full list of options see `weaviate.embedded.EmbeddedOptions`.
:param additional_config: Additional and advanced configuration options for weaviate, defaults to None.
:param collection_name: The name of the collection to use, defaults to "default".
If the collection does not exist it will be created.
"""
self._client = weaviate.Client(
url=url,
Expand All @@ -92,7 +95,14 @@ def __init__(
additional_config=additional_config,
)

# Test connection, it will raise an exception if it fails.
self._client.schema.get()

if not self._client.schema.exists(collection_name):
self._client.schema.create_class({"class": collection_name})

self._url = url
self._collection_name = collection_name
self._auth_client_secret = auth_client_secret
self._timeout_config = timeout_config
self._proxies = proxies
Expand All @@ -114,6 +124,7 @@ def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
url=self._url,
collection_name=self._collection_name,
auth_client_secret=auth_client_secret,
timeout_config=self._timeout_config,
proxies=self._proxies,
Expand Down
50 changes: 49 additions & 1 deletion integrations/weaviate/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import patch, MagicMock

from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore
from weaviate.auth import AuthApiKey
Expand All @@ -13,10 +13,55 @@


class TestWeaviateDocumentStore:
@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client")
def test_init(self, mock_weaviate_client_class):
mock_client = MagicMock()
mock_client.schema.exists.return_value = False
mock_weaviate_client_class.return_value = mock_client

WeaviateDocumentStore(
url="http://localhost:8080",
collection_name="my_collection",
auth_client_secret=AuthApiKey("my_api_key"),
proxies={"http": "http://proxy:1234"},
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
embedded_options=EmbeddedOptions(
persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH,
binary_path=DEFAULT_BINARY_PATH,
version="1.23.0",
hostname="127.0.0.1",
),
additional_config=Config(grpc_port_experimental=12345),
)

# Verify client is created with correct parameters
mock_weaviate_client_class.assert_called_once_with(
url="http://localhost:8080",
auth_client_secret=AuthApiKey("my_api_key"),
timeout_config=(10, 60),
proxies={"http": "http://proxy:1234"},
trust_env=False,
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
startup_period=5,
embedded_options=EmbeddedOptions(
persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH,
binary_path=DEFAULT_BINARY_PATH,
version="1.23.0",
hostname="127.0.0.1",
),
additional_config=Config(grpc_port_experimental=12345),
)

# Verify collection is created
mock_client.schema.get.assert_called_once()
mock_client.schema.exists.assert_called_once_with("my_collection")
mock_client.schema.create_class.assert_called_once_with({"class": "my_collection"})

@patch("haystack_integrations.document_stores.weaviate.document_store.weaviate")
def test_to_dict(self, _mock_weaviate):
document_store = WeaviateDocumentStore(
url="http://localhost:8080",
collection_name="my_collection",
auth_client_secret=AuthApiKey("my_api_key"),
proxies={"http": "http://proxy:1234"},
additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"},
Expand All @@ -32,6 +77,7 @@ def test_to_dict(self, _mock_weaviate):
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"collection_name": "my_collection",
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
Expand Down Expand Up @@ -67,6 +113,7 @@ def test_from_dict(self, _mock_weaviate):
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore",
"init_parameters": {
"url": "http://localhost:8080",
"collection_name": "my_collection",
"auth_client_secret": {
"type": "weaviate.auth.AuthApiKey",
"init_parameters": {"api_key": "my_api_key"},
Expand Down Expand Up @@ -97,6 +144,7 @@ def test_from_dict(self, _mock_weaviate):
)

assert document_store._url == "http://localhost:8080"
assert document_store._collection_name == "my_collection"
assert document_store._auth_client_secret == AuthApiKey("my_api_key")
assert document_store._timeout_config == (10, 60)
assert document_store._proxies == {"http": "http://proxy:1234"}
Expand Down

0 comments on commit ceecf38

Please sign in to comment.