From ceecf38d38d42a0c04da507e4153260c9e516a8e Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Tue, 16 Jan 2024 11:48:24 +0100 Subject: [PATCH] Add collection_name parameter --- .../weaviate/document_store.py | 11 ++++ .../weaviate/tests/test_document_store.py | 50 ++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 9317fb9de..4c15d707e 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -35,6 +35,7 @@ def __init__( self, *, url: Optional[str] = None, + collection_name: str = "default", auth_client_secret: Optional[AuthCredentials] = None, timeout_config: TimeoutType = (10, 60), proxies: Optional[Union[Dict, str]] = None, @@ -79,6 +80,8 @@ def __init__( :param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None. For a full list of options see `weaviate.embedded.EmbeddedOptions`. :param additional_config: Additional and advanced configuration options for weaviate, defaults to None. + :param collection_name: The name of the collection to use, defaults to "default". + If the collection does not exist it will be created. """ self._client = weaviate.Client( url=url, @@ -92,7 +95,14 @@ def __init__( additional_config=additional_config, ) + # Test connection, it will raise an exception if it fails. + self._client.schema.get() + + if not self._client.schema.exists(collection_name): + self._client.schema.create_class({"class": collection_name}) + self._url = url + self._collection_name = collection_name self._auth_client_secret = auth_client_secret self._timeout_config = timeout_config self._proxies = proxies @@ -114,6 +124,7 @@ def to_dict(self) -> Dict[str, Any]: return default_to_dict( self, url=self._url, + collection_name=self._collection_name, auth_client_secret=auth_client_secret, timeout_config=self._timeout_config, proxies=self._proxies, diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index 3c72934b1..1944d64fb 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import patch, MagicMock from haystack_integrations.document_stores.weaviate.document_store import WeaviateDocumentStore from weaviate.auth import AuthApiKey @@ -13,10 +13,55 @@ class TestWeaviateDocumentStore: + @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") + def test_init(self, mock_weaviate_client_class): + mock_client = MagicMock() + mock_client.schema.exists.return_value = False + mock_weaviate_client_class.return_value = mock_client + + WeaviateDocumentStore( + url="http://localhost:8080", + collection_name="my_collection", + auth_client_secret=AuthApiKey("my_api_key"), + proxies={"http": "http://proxy:1234"}, + additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, + embedded_options=EmbeddedOptions( + persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, + binary_path=DEFAULT_BINARY_PATH, + version="1.23.0", + hostname="127.0.0.1", + ), + additional_config=Config(grpc_port_experimental=12345), + ) + + # Verify client is created with correct parameters + mock_weaviate_client_class.assert_called_once_with( + url="http://localhost:8080", + auth_client_secret=AuthApiKey("my_api_key"), + timeout_config=(10, 60), + proxies={"http": "http://proxy:1234"}, + trust_env=False, + additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, + startup_period=5, + embedded_options=EmbeddedOptions( + persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, + binary_path=DEFAULT_BINARY_PATH, + version="1.23.0", + hostname="127.0.0.1", + ), + additional_config=Config(grpc_port_experimental=12345), + ) + + # Verify collection is created + mock_client.schema.get.assert_called_once() + mock_client.schema.exists.assert_called_once_with("my_collection") + mock_client.schema.create_class.assert_called_once_with({"class": "my_collection"}) + @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate") def test_to_dict(self, _mock_weaviate): document_store = WeaviateDocumentStore( url="http://localhost:8080", + collection_name="my_collection", auth_client_secret=AuthApiKey("my_api_key"), proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, @@ -32,6 +77,7 @@ def test_to_dict(self, _mock_weaviate): "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", "init_parameters": { "url": "http://localhost:8080", + "collection_name": "my_collection", "auth_client_secret": { "type": "weaviate.auth.AuthApiKey", "init_parameters": {"api_key": "my_api_key"}, @@ -67,6 +113,7 @@ def test_from_dict(self, _mock_weaviate): "type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", "init_parameters": { "url": "http://localhost:8080", + "collection_name": "my_collection", "auth_client_secret": { "type": "weaviate.auth.AuthApiKey", "init_parameters": {"api_key": "my_api_key"}, @@ -97,6 +144,7 @@ def test_from_dict(self, _mock_weaviate): ) assert document_store._url == "http://localhost:8080" + assert document_store._collection_name == "my_collection" assert document_store._auth_client_secret == AuthApiKey("my_api_key") assert document_store._timeout_config == (10, 60) assert document_store._proxies == {"http": "http://proxy:1234"}