diff --git a/integrations/weaviate/docker-compose.yml b/integrations/weaviate/docker-compose.yml index c61b0ed57..f7f033eee 100644 --- a/integrations/weaviate/docker-compose.yml +++ b/integrations/weaviate/docker-compose.yml @@ -8,7 +8,7 @@ services: - '8080' - --scheme - http - image: semitechnologies/weaviate:1.23.2 + image: semitechnologies/weaviate:1.24.1 ports: - 8080:8080 - 50051:50051 diff --git a/integrations/weaviate/pyproject.toml b/integrations/weaviate/pyproject.toml index 421c2ce18..54d9ec21b 100644 --- a/integrations/weaviate/pyproject.toml +++ b/integrations/weaviate/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ ] dependencies = [ "haystack-ai", - "weaviate-client==3.*", + "weaviate-client", "haystack-pydoc-tools", "python-dateutil", ] diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py index a192c6947..a2201f0a5 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/_filters.py @@ -4,8 +4,11 @@ from haystack.errors import FilterError from pandas import DataFrame +import weaviate +from weaviate.collections.classes.filters import Filter, FilterReturn -def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]: + +def convert_filters(filters: Dict[str, Any]) -> FilterReturn: """ Convert filters from Haystack format to Weaviate format. 
""" @@ -14,7 +17,7 @@ def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]: raise FilterError(msg) if "field" in filters: - return {"operator": "And", "operands": [_parse_comparison_condition(filters)]} + return Filter.all_of([_parse_comparison_condition(filters)]) return _parse_logical_condition(filters) @@ -29,7 +32,7 @@ def convert_filters(filters: Dict[str, Any]) -> Dict[str, Any]: "not in": "in", "AND": "OR", "OR": "AND", - "NOT": "AND", + "NOT": "OR", } @@ -51,7 +54,13 @@ def _invert_condition(filters: Dict[str, Any]) -> Dict[str, Any]: return inverted_condition -def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: +LOGICAL_OPERATORS = { + "AND": Filter.all_of, + "OR": Filter.any_of, +} + + +def _parse_logical_condition(condition: Dict[str, Any]) -> FilterReturn: if "operator" not in condition: msg = f"'operator' key missing in {condition}" raise FilterError(msg) @@ -67,7 +76,7 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: operands.append(_parse_logical_condition(c)) else: operands.append(_parse_comparison_condition(c)) - return {"operator": operator.lower().capitalize(), "operands": operands} + return LOGICAL_OPERATORS[operator](operands) elif operator == "NOT": inverted_conditions = _invert_condition(condition) return _parse_logical_condition(inverted_conditions) @@ -76,28 +85,6 @@ def _parse_logical_condition(condition: Dict[str, Any]) -> Dict[str, Any]: raise FilterError(msg) -def _infer_value_type(value: Any) -> str: - if value is None: - return "valueNull" - - if isinstance(value, bool): - return "valueBoolean" - if isinstance(value, int): - return "valueInt" - if isinstance(value, float): - return "valueNumber" - - if isinstance(value, str): - try: - parser.isoparse(value) - return "valueDate" - except ValueError: - return "valueText" - - msg = f"Unknown value type {type(value)}" - raise FilterError(msg) - - def _handle_date(value: Any) -> str: if isinstance(value, str): try: @@ -107,25 
+94,22 @@ def _handle_date(value: Any) -> str: return value -def _equal(field: str, value: Any) -> Dict[str, Any]: +def _equal(field: str, value: Any) -> FilterReturn: if value is None: - return {"path": field, "operator": "IsNull", "valueBoolean": True} - return {"path": field, "operator": "Equal", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).is_none(True) + return weaviate.classes.query.Filter.by_property(field).equal(_handle_date(value)) -def _not_equal(field: str, value: Any) -> Dict[str, Any]: +def _not_equal(field: str, value: Any) -> FilterReturn: if value is None: - return {"path": field, "operator": "IsNull", "valueBoolean": False} - return { - "operator": "Or", - "operands": [ - {"path": field, "operator": "NotEqual", _infer_value_type(value): _handle_date(value)}, - {"path": field, "operator": "IsNull", "valueBoolean": True}, - ], - } + return weaviate.classes.query.Filter.by_property(field).is_none(False) + return weaviate.classes.query.Filter.by_property(field).not_equal( + _handle_date(value) + ) | weaviate.classes.query.Filter.by_property(field).is_none(True) -def _greater_than(field: str, value: Any) -> Dict[str, Any]: + +def _greater_than(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '>' is used we create a filter that would return a Document # if it has a field set and not set at the same time. 
@@ -144,10 +128,10 @@ def _greater_than(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "GreaterThan", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).greater_than(_handle_date(value)) -def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: +def _greater_than_equal(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '>=' is used we create a filter that would return a Document # if it has a field set and not set at the same time. @@ -166,10 +150,10 @@ def _greater_than_equal(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "GreaterThanEqual", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).greater_or_equal(_handle_date(value)) -def _less_than(field: str, value: Any) -> Dict[str, Any]: +def _less_than(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '<' is used we create a filter that would return a Document # if it has a field set and not set at the same time. 
@@ -188,10 +172,10 @@ def _less_than(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "LessThan", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).less_than(_handle_date(value)) -def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: +def _less_than_equal(field: str, value: Any) -> FilterReturn: if value is None: # When the value is None and '<=' is used we create a filter that would return a Document # if it has a field set and not set at the same time. @@ -210,22 +194,23 @@ def _less_than_equal(field: str, value: Any) -> Dict[str, Any]: if type(value) in [list, DataFrame]: msg = f"Filter value can't be of type {type(value)} using operators '>', '>=', '<', '<='" raise FilterError(msg) - return {"path": field, "operator": "LessThanEqual", _infer_value_type(value): _handle_date(value)} + return weaviate.classes.query.Filter.by_property(field).less_or_equal(_handle_date(value)) -def _in(field: str, value: Any) -> Dict[str, Any]: +def _in(field: str, value: Any) -> FilterReturn: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators" raise FilterError(msg) - return {"operator": "And", "operands": [_equal(field, v) for v in value]} + return weaviate.classes.query.Filter.by_property(field).contains_any(value) -def _not_in(field: str, value: Any) -> Dict[str, Any]: +def _not_in(field: str, value: Any) -> FilterReturn: if not isinstance(value, list): msg = f"{field}'s value must be a list when using 'in' or 'not in' comparators" raise FilterError(msg) - return {"operator": "And", "operands": [_not_equal(field, v) for v in value]} + operands = [weaviate.classes.query.Filter.by_property(field).not_equal(v) for v in value] + return Filter.all_of(operands) COMPARISON_OPERATORS = 
{ @@ -240,7 +225,7 @@ def _not_in(field: str, value: Any) -> Dict[str, Any]: } -def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: +def _parse_comparison_condition(condition: Dict[str, Any]) -> FilterReturn: field: str = condition["field"] if field.startswith("meta."): @@ -265,15 +250,11 @@ def _parse_comparison_condition(condition: Dict[str, Any]) -> Dict[str, Any]: return COMPARISON_OPERATORS[operator](field, value) -def _match_no_document(field: str) -> Dict[str, Any]: +def _match_no_document(field: str) -> FilterReturn: """ Returns a filters that will match no Document, this is used to keep the behavior consistent between different Document Stores. """ - return { - "operator": "And", - "operands": [ - {"path": field, "operator": "IsNull", "valueBoolean": False}, - {"path": field, "operator": "IsNull", "valueBoolean": True}, - ], - } + + operands = [weaviate.classes.query.Filter.by_property(field).is_none(val) for val in [False, True]] + return Filter.all_of(operands) diff --git a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py index 071fe336b..34fefa0a5 100644 --- a/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py +++ b/integrations/weaviate/src/haystack_integrations/document_stores/weaviate/document_store.py @@ -2,6 +2,8 @@ # # SPDX-License-Identifier: Apache-2.0 import base64 +import datetime +import json from dataclasses import asdict from typing import Any, Dict, List, Optional, Tuple, Union @@ -11,7 +13,8 @@ from haystack.document_stores.types.policy import DuplicatePolicy import weaviate -from weaviate.config import Config, ConnectionConfig +from weaviate.collections.classes.data import DataObject +from weaviate.config import AdditionalConfig from weaviate.embedded import EmbeddedOptions from weaviate.util import generate_uuid5 @@ -42,6 +45,16 @@ 
{"name": "score", "dataType": ["number"]}, ] +# This is the default limit used when querying documents with WeaviateDocumentStore. +# +# We picked this as QUERY_MAXIMUM_RESULTS defaults to 10000, trying to get that many +# documents at once will fail, even if the query is paginated. +# This value will ensure we get the most documents possible without hitting that limit, it would +# still fail if the user lowers the QUERY_MAXIMUM_RESULTS environment variable for their Weaviate instance. +# +# See WeaviateDocumentStore._query_with_filters() for more information. +DEFAULT_QUERY_LIMIT = 9999 + class WeaviateDocumentStore: """ @@ -54,13 +67,11 @@ def __init__( url: Optional[str] = None, collection_settings: Optional[Dict[str, Any]] = None, auth_client_secret: Optional[AuthCredentials] = None, - timeout_config: TimeoutType = (10, 60), - proxies: Optional[Union[Dict, str]] = None, - trust_env: bool = False, additional_headers: Optional[Dict] = None, - startup_period: Optional[int] = 5, embedded_options: Optional[EmbeddedOptions] = None, - additional_config: Optional[Config] = None, + additional_config: Optional[AdditionalConfig] = None, + grpc_port: int = 50051, + grpc_secure: bool = False, ): """ Create a new instance of WeaviateDocumentStore and connects to the Weaviate instance. @@ -88,46 +99,35 @@ def __init__( - `AuthClientPassword` to use username and password for oidc Resource Owner Password flow - `AuthClientCredentials` to use a client secret for oidc client credential flow - `AuthApiKey` to use an API key - :param timeout_config: Timeout configuration for all requests to the Weaviate server, defaults to (10, 60). - It can be a real number or, a tuple of two real numbers: (connect timeout, read timeout). - If only one real number is passed then both connect and read timeout will be set to - that value, by default (2, 20). - :param proxies: Proxy configuration, defaults to None. - Can be passed as a dict using the - ``requests` format`_, - or a string. 
If a string is passed it will be used for both HTTP and HTTPS requests. - :param trust_env: Whether to read proxies from the ENV variables, defaults to False. - Proxies will be read from the following ENV variables: - * `HTTP_PROXY` - * `http_proxy` - * `HTTPS_PROXY` - * `https_proxy` - If `proxies` is not None, `trust_env` is ignored. :param additional_headers: Additional headers to include in the requests, defaults to None. Can be used to set OpenAI/HuggingFace keys. OpenAI/HuggingFace key looks like this: ``` {"X-OpenAI-Api-Key": ""}, {"X-HuggingFace-Api-Key": ""} ``` - :param startup_period: How many seconds the client will wait for Weaviate to start before - raising a RequestsConnectionError, defaults to 5. :param embedded_options: If set create an embedded Weaviate cluster inside the client, defaults to None. For a full list of options see `weaviate.embedded.EmbeddedOptions`. :param additional_config: Additional and advanced configuration options for weaviate, defaults to None. + :param grpc_port: The port to use for the gRPC connection, defaults to 50051. + :param grpc_secure: Whether to use a secure channel for the underlying gRPC API. """ - self._client = weaviate.Client( - url=url, + # proxies, timeout_config, trust_env are part of additional_config now + # startup_period has been removed + self._client = weaviate.WeaviateClient( + connection_params=( + weaviate.connect.base.ConnectionParams.from_url(url=url, grpc_port=grpc_port, grpc_secure=grpc_secure) + if url + else None + ), auth_client_secret=auth_client_secret.resolve_value() if auth_client_secret else None, - timeout_config=timeout_config, - proxies=proxies, - trust_env=trust_env, + additional_config=additional_config, additional_headers=additional_headers, - startup_period=startup_period, embedded_options=embedded_options, - additional_config=additional_config, + skip_init_checks=False, ) + self._client.connect() # Test connection, it will raise an exception if it fails. 
- self._client.schema.get() + self._client.collections._get_all(simple=True) if collection_settings is None: collection_settings = { @@ -141,64 +141,53 @@ def __init__( # Set the properties if they're not set collection_settings["properties"] = collection_settings.get("properties", DOCUMENT_COLLECTION_PROPERTIES) - if not self._client.schema.exists(collection_settings["class"]): - self._client.schema.create_class(collection_settings) + if not self._client.collections.exists(collection_settings["class"]): + self._client.collections.create_from_dict(collection_settings) self._url = url self._collection_settings = collection_settings self._auth_client_secret = auth_client_secret - self._timeout_config = timeout_config - self._proxies = proxies - self._trust_env = trust_env self._additional_headers = additional_headers - self._startup_period = startup_period self._embedded_options = embedded_options self._additional_config = additional_config + self._collection = self._client.collections.get(collection_settings["class"]) def to_dict(self) -> Dict[str, Any]: embedded_options = asdict(self._embedded_options) if self._embedded_options else None - additional_config = asdict(self._additional_config) if self._additional_config else None + additional_config = ( + json.loads(self._additional_config.model_dump_json(by_alias=True)) if self._additional_config else None + ) return default_to_dict( self, url=self._url, collection_settings=self._collection_settings, auth_client_secret=self._auth_client_secret.to_dict() if self._auth_client_secret else None, - timeout_config=self._timeout_config, - proxies=self._proxies, - trust_env=self._trust_env, additional_headers=self._additional_headers, - startup_period=self._startup_period, embedded_options=embedded_options, additional_config=additional_config, ) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "WeaviateDocumentStore": - if (timeout_config := data["init_parameters"].get("timeout_config")) is not None: - 
data["init_parameters"]["timeout_config"] = ( - tuple(timeout_config) if isinstance(timeout_config, list) else timeout_config - ) if (auth_client_secret := data["init_parameters"].get("auth_client_secret")) is not None: data["init_parameters"]["auth_client_secret"] = AuthCredentials.from_dict(auth_client_secret) if (embedded_options := data["init_parameters"].get("embedded_options")) is not None: data["init_parameters"]["embedded_options"] = EmbeddedOptions(**embedded_options) if (additional_config := data["init_parameters"].get("additional_config")) is not None: - additional_config["connection_config"] = ConnectionConfig(**additional_config["connection_config"]) - data["init_parameters"]["additional_config"] = Config(**additional_config) + data["init_parameters"]["additional_config"] = AdditionalConfig(**additional_config) return default_from_dict( cls, data, ) def count_documents(self) -> int: - collection_name = self._collection_settings["class"] - res = self._client.query.aggregate(collection_name).with_meta_count().do() - return res.get("data", {}).get("Aggregate", {}).get(collection_name, [{}])[0].get("meta", {}).get("count", 0) + total = self._collection.aggregate.over_all(total_count=True).total_count + return total if total else 0 def _to_data_object(self, document: Document) -> Dict[str, Any]: """ - Convert a Document to a Weviate data object ready to be saved. + Convert a Document to a Weaviate data object ready to be saved. """ data = document.to_dict() # Weaviate forces a UUID as an id. @@ -216,95 +205,82 @@ def _to_data_object(self, document: Document) -> Dict[str, Any]: return data - def _to_document(self, data: Dict[str, Any]) -> Document: + def _to_document(self, data: DataObject[Dict[str, Any], None]) -> Document: """ Convert a data object read from Weaviate into a Document. 
""" - data["id"] = data.pop("_original_id") - data["embedding"] = data["_additional"].pop("vector") if data["_additional"].get("vector") else None + document_data = data.properties + document_data["id"] = document_data.pop("_original_id") + if isinstance(data.vector, List): + document_data["embedding"] = data.vector + elif isinstance(data.vector, Dict): + document_data["embedding"] = data.vector.get("default") + else: + document_data["embedding"] = None - if (blob_data := data.get("blob_data")) is not None: - data["blob"] = { + if (blob_data := document_data.get("blob_data")) is not None: + document_data["blob"] = { "data": base64.b64decode(blob_data), - "mime_type": data.get("blob_mime_type"), + "mime_type": document_data.get("blob_mime_type"), } - # We always delete these fields as they're not part of the Document dataclass - data.pop("blob_data") - data.pop("blob_mime_type") - - # We don't need these fields anymore, this usually only contains the uuid - # used by Weaviate to identify the object and the embedding vector that we already extracted. - del data["_additional"] - - return Document.from_dict(data) - - def _query_paginated(self, properties: List[str], cursor=None): - collection_name = self._collection_settings["class"] - query = ( - self._client.query.get( - collection_name, - properties, - ) - .with_additional(["id vector"]) - .with_limit(100) - ) - - if cursor: - # Fetch the next set of results - result = query.with_after(cursor).do() - else: - # Fetch the first set of results - result = query.do() - - if "errors" in result: - errors = [e["message"] for e in result.get("errors", {})] - msg = "\n".join(errors) - msg = f"Failed to query documents in Weaviate. 
Errors:\n{msg}" - raise DocumentStoreError(msg) - - return result["data"]["Get"][collection_name] - - def _query_with_filters(self, properties: List[str], filters: Dict[str, Any]) -> List[Dict[str, Any]]: - collection_name = self._collection_settings["class"] - query = ( - self._client.query.get( - collection_name, - properties, - ) - .with_additional(["id vector"]) - .with_where(convert_filters(filters)) - ) - - result = query.do() - if "errors" in result: - errors = [e["message"] for e in result.get("errors", {})] - msg = "\n".join(errors) - msg = f"Failed to query documents in Weaviate. Errors:\n{msg}" - raise DocumentStoreError(msg) + # We always delete these fields as they're not part of the Document dataclass + document_data.pop("blob_data", None) + document_data.pop("blob_mime_type", None) + + for key, value in document_data.items(): + if isinstance(value, datetime.datetime): + document_data[key] = value.strftime("%Y-%m-%dT%H:%M:%SZ") + + return Document.from_dict(document_data) + + def _query(self) -> List[Dict[str, Any]]: + properties = [p.name for p in self._collection.config.get().properties] + try: + result = self._collection.iterator(include_vector=True, return_properties=properties) + except weaviate.exceptions.WeaviateQueryError as e: + msg = f"Failed to query documents in Weaviate. Error: {e.message}" + raise DocumentStoreError(msg) from e + return result - return result["data"]["Get"][collection_name] + def _query_with_filters(self, filters: Dict[str, Any]) -> List[Dict[str, Any]]: + properties = [p.name for p in self._collection.config.get().properties] + # When querying with filters we need to paginate using limit and offset as using + # a cursor with after is not possible. 
See the official docs: + # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#cursor-with-after + # + # Nonetheless there's also another issue, paginating with limit and offset is not efficient + # and it's still restricted by the QUERY_MAXIMUM_RESULTS environment variable. + # If the sum of limit and offset is greater than QUERY_MAXIMUM_RESULTS an error is raised. + # See the official docs for more: + # https://weaviate.io/developers/weaviate/api/graphql/additional-operators#performance-considerations + offset = 0 + partial_result = None + result = [] + # Keep querying until we get all documents matching the filters + while partial_result is None or len(partial_result.objects) == DEFAULT_QUERY_LIMIT: + try: + partial_result = self._collection.query.fetch_objects( + filters=convert_filters(filters), + include_vector=True, + limit=DEFAULT_QUERY_LIMIT, + offset=offset, + return_properties=properties, + ) + except weaviate.exceptions.WeaviateQueryError as e: + msg = f"Failed to query documents in Weaviate. Error: {e.message}" + raise DocumentStoreError(msg) from e + result.extend(partial_result.objects) + offset += DEFAULT_QUERY_LIMIT + return result def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: - properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) - properties = [prop["name"] for prop in properties] - - if filters: - result = self._query_with_filters(properties, filters) - return [self._to_document(doc) for doc in result] - result = [] - - cursor = None - while batch := self._query_paginated(properties, cursor): - # Take the cursor before we convert the batch to Documents as we manipulate - # the batch dictionary and might lose that information. 
- cursor = batch[-1]["_additional"]["id"] - - for doc in batch: - result.append(self._to_document(doc)) - # Move the cursor to the last returned uuid - return result + if filters: + result = self._query_with_filters(filters) + else: + result = self._query() + return [self._to_document(doc) for doc in result] def _batch_write(self, documents: List[Document]) -> int: """ @@ -312,33 +288,35 @@ def _batch_write(self, documents: List[Document]) -> int: Documents with the same id will be overwritten. Raises in case of errors. """ - statuses = [] - for doc in documents: - if not isinstance(doc, Document): - msg = f"Expected a Document, got '{type(doc)}' instead." - raise ValueError(msg) - if self._client.batch.num_objects() == self._client.batch.recommended_num_objects: - # Batch is full, let's create the objects - statuses.extend(self._client.batch.create_objects()) - self._client.batch.add_data_object( - uuid=generate_uuid5(doc.id), - data_object=self._to_data_object(doc), - class_name=self._collection_settings["class"], - vector=doc.embedding, + + with self._client.batch.dynamic() as batch: + for doc in documents: + if not isinstance(doc, Document): + msg = f"Expected a Document, got '{type(doc)}' instead." + raise ValueError(msg) + + batch.add_object( + properties=self._to_data_object(doc), + collection=self._collection.name, + uuid=generate_uuid5(doc.id), + vector=doc.embedding, + ) + if failed_objects := self._client.batch.failed_objects: + # We fallback to use the UUID if the _original_id is not present, this is just to be on the safe side. + mapped_objects = {} + for obj in failed_objects: + properties = obj.object_.properties or {} + # We get the object uuid just in case the _original_id is not present. + # That's extremely unlikely to happen but let's stay on the safe side. + id_ = properties.get("_original_id", obj.object_.uuid) + mapped_objects[id_] = obj.message + + msg = "\n".join( + [ + f"Failed to write object with id '{id_}'. 
Error: '{message}'" + for id_, message in mapped_objects.items() + ] ) - # Write remaining documents - statuses.extend(self._client.batch.create_objects()) - - errors = [] - # Gather errors and number of written documents - for status in statuses: - result_status = status.get("result", {}).get("status") - if result_status == "FAILED": - errors.extend([e["message"] for e in status["result"]["errors"]["error"]]) - - if errors: - msg = "\n".join(errors) - msg = f"Failed to write documents in Weaviate. Errors:\n{msg}" raise DocumentStoreError(msg) # If the document already exists we get no status message back from Weaviate. @@ -359,22 +337,19 @@ def _write(self, documents: List[Document], policy: DuplicatePolicy) -> int: msg = f"Expected a Document, got '{type(doc)}' instead." raise ValueError(msg) - if policy == DuplicatePolicy.SKIP and self._client.data_object.exists( - uuid=generate_uuid5(doc.id), - class_name=self._collection_settings["class"], - ): + if policy == DuplicatePolicy.SKIP and self._collection.data.exists(uuid=generate_uuid5(doc.id)): # This Document already exists, we skip it continue try: - self._client.data_object.create( + self._collection.data.insert( uuid=generate_uuid5(doc.id), - data_object=self._to_data_object(doc), - class_name=self._collection_settings["class"], + properties=self._to_data_object(doc), vector=doc.embedding, ) + written += 1 - except weaviate.exceptions.ObjectAlreadyExistsException: + except weaviate.exceptions.UnexpectedStatusCodeError: if policy == DuplicatePolicy.FAIL: duplicate_errors_ids.append(doc.id) if duplicate_errors_ids: @@ -397,37 +372,21 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D return self._write(documents, policy) def delete_documents(self, document_ids: List[str]) -> None: - self._client.batch.delete_objects( - class_name=self._collection_settings["class"], - where={ - "path": ["id"], - "operator": "ContainsAny", - "valueTextArray": [generate_uuid5(doc_id) for doc_id in 
document_ids], - }, - ) + weaviate_ids = [generate_uuid5(doc_id) for doc_id in document_ids] + self._collection.data.delete_many(where=weaviate.classes.query.Filter.by_id().contains_any(weaviate_ids)) def _bm25_retrieval( self, query: str, filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None ) -> List[Document]: - collection_name = self._collection_settings["class"] - properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) - properties = [prop["name"] for prop in properties] - - query_builder = ( - self._client.query.get(collection_name, properties=properties) - .with_bm25(query=query, properties=["content"]) - .with_additional(["vector"]) + result = self._collection.query.bm25( + query=query, + filters=convert_filters(filters) if filters else None, + limit=top_k, + include_vector=True, + query_properties=["content"], ) - if filters: - query_builder = query_builder.with_where(convert_filters(filters)) - - if top_k: - query_builder = query_builder.with_limit(top_k) - - result = query_builder.do() - - return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]] + return [self._to_document(doc) for doc in result.objects] def _embedding_retrieval( self, @@ -441,30 +400,13 @@ def _embedding_retrieval( msg = "Can't use 'distance' and 'certainty' parameters together" raise ValueError(msg) - collection_name = self._collection_settings["class"] - properties = self._client.schema.get(self._collection_settings["class"]).get("properties", []) - properties = [prop["name"] for prop in properties] - - near_vector: Dict[str, Union[float, List[float]]] = { - "vector": query_embedding, - } - if distance is not None: - near_vector["distance"] = distance - - if certainty is not None: - near_vector["certainty"] = certainty - - query_builder = ( - self._client.query.get(collection_name, properties=properties) - .with_near_vector(near_vector) - .with_additional(["vector"]) + result = 
self._collection.query.near_vector( + near_vector=query_embedding, + distance=distance, + certainty=certainty, + include_vector=True, + filters=convert_filters(filters) if filters else None, + limit=top_k, ) - if filters: - query_builder = query_builder.with_where(convert_filters(filters)) - - if top_k: - query_builder = query_builder.with_limit(top_k) - - result = query_builder.do() - return [self._to_document(doc) for doc in result["data"]["Get"][collection_name]] + return [self._to_document(doc) for doc in result.objects] diff --git a/integrations/weaviate/tests/test_bm25_retriever.py b/integrations/weaviate/tests/test_bm25_retriever.py index 83f90735b..23b7c8f92 100644 --- a/integrations/weaviate/tests/test_bm25_retriever.py +++ b/integrations/weaviate/tests/test_bm25_retriever.py @@ -38,11 +38,7 @@ def test_to_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, "embedded_options": None, "additional_config": None, }, @@ -76,11 +72,7 @@ def test_from_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, "embedded_options": None, "additional_config": None, }, diff --git a/integrations/weaviate/tests/test_document_store.py b/integrations/weaviate/tests/test_document_store.py index a2b32d578..4c1659a86 100644 --- a/integrations/weaviate/tests/test_document_store.py +++ b/integrations/weaviate/tests/test_document_store.py @@ -7,6 +7,7 @@ from dateutil import parser from haystack.dataclasses.byte_stream import ByteStream from haystack.dataclasses.document import Document +from haystack.document_stores.errors import DocumentStoreError from haystack.testing.document_store import ( TEST_EMBEDDING_1, TEST_EMBEDDING_2, @@ -24,8 +25,10 @@ from numpy import array_equal as np_array_equal from numpy import float32 as np_float32 
from pandas import DataFrame -from weaviate.auth import AuthApiKey as WeaviateAuthApiKey -from weaviate.config import Config +from weaviate.collections.classes.data import DataObject + +# from weaviate.auth import AuthApiKey as WeaviateAuthApiKey +from weaviate.config import AdditionalConfig, ConnectionConfig, Proxies, Timeout from weaviate.embedded import ( DEFAULT_BINARY_PATH, DEFAULT_GRPC_PORT, @@ -53,7 +56,7 @@ def document_store(self, request) -> WeaviateDocumentStore: collection_settings=collection_settings, ) yield store - store._client.schema.delete_class(collection_settings["class"]) + store._client.collections.delete(collection_settings["class"]) @pytest.fixture def filterable_docs(self) -> List[Document]: @@ -145,49 +148,48 @@ def assert_documents_are_equal(self, received: List[Document], expected: List[Do for key in meta_keys: assert received_meta.get(key) == expected_meta.get(key) - @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.Client") + @patch("haystack_integrations.document_stores.weaviate.document_store.weaviate.WeaviateClient") def test_init(self, mock_weaviate_client_class, monkeypatch): mock_client = MagicMock() - mock_client.schema.exists.return_value = False + mock_client.collections.exists.return_value = False mock_weaviate_client_class.return_value = mock_client monkeypatch.setenv("WEAVIATE_API_KEY", "my_api_key") WeaviateDocumentStore( - url="http://localhost:8080", collection_settings={"class": "My_collection"}, auth_client_secret=AuthApiKey(), - proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, binary_path=DEFAULT_BINARY_PATH, - version="1.23.0", + version="1.23.7", hostname="127.0.0.1", ), - additional_config=Config(grpc_port_experimental=12345), + additional_config=AdditionalConfig( + proxies={"http": "http://proxy:1234"}, trust_env=False, timeout=(10, 60) + ), ) 
# Verify client is created with correct parameters + mock_weaviate_client_class.assert_called_once_with( - url="http://localhost:8080", - auth_client_secret=WeaviateAuthApiKey("my_api_key"), - timeout_config=(10, 60), - proxies={"http": "http://proxy:1234"}, - trust_env=False, + auth_client_secret=AuthApiKey().resolve_value(), + connection_params=None, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, - startup_period=5, embedded_options=EmbeddedOptions( persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, binary_path=DEFAULT_BINARY_PATH, - version="1.23.0", + version="1.23.7", hostname="127.0.0.1", ), - additional_config=Config(grpc_port_experimental=12345), + skip_init_checks=False, + additional_config=AdditionalConfig( + proxies={"http": "http://proxy:1234"}, trust_env=False, timeout=(10, 60) + ), ) # Verify collection is created - mock_client.schema.get.assert_called_once() - mock_client.schema.exists.assert_called_once_with("My_collection") - mock_client.schema.create_class.assert_called_once_with( + mock_client.collections.exists.assert_called_once_with("My_collection") + mock_client.collections.create_from_dict.assert_called_once_with( {"class": "My_collection", "properties": DOCUMENT_COLLECTION_PROPERTIES} ) @@ -197,7 +199,6 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): document_store = WeaviateDocumentStore( url="http://localhost:8080", auth_client_secret=AuthApiKey(), - proxies={"http": "http://proxy:1234"}, additional_headers={"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, embedded_options=EmbeddedOptions( persistence_data_path=DEFAULT_PERSISTENCE_DATA_PATH, @@ -205,7 +206,12 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): version="1.23.0", hostname="127.0.0.1", ), - additional_config=Config(grpc_port_experimental=12345), + additional_config=AdditionalConfig( + connection=ConnectionConfig(), + timeout=(30, 90), + trust_env=False, + proxies={"http": "http://proxy:1234"}, + ), ) assert document_store.to_dict() == { 
"type": "haystack_integrations.document_stores.weaviate.document_store.WeaviateDocumentStore", @@ -229,11 +235,7 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} }, }, - "timeout_config": (10, 60), - "proxies": {"http": "http://proxy:1234"}, - "trust_env": False, "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, - "startup_period": 5, "embedded_options": { "persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH, "binary_path": DEFAULT_BINARY_PATH, @@ -244,11 +246,14 @@ def test_to_dict(self, _mock_weaviate, monkeypatch): "grpc_port": DEFAULT_GRPC_PORT, }, "additional_config": { - "grpc_port_experimental": 12345, - "connection_config": { + "connection": { "session_pool_connections": 20, - "session_pool_maxsize": 20, + "session_pool_maxsize": 100, + "session_pool_max_retries": 3, }, + "proxies": {"http": "http://proxy:1234", "https": None, "grpc": None}, + "timeout": [30, 90], + "trust_env": False, }, }, } @@ -268,11 +273,7 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): "api_key": {"env_vars": ["WEAVIATE_API_KEY"], "strict": True, "type": "env_var"} }, }, - "timeout_config": [10, 60], - "proxies": {"http": "http://proxy:1234"}, - "trust_env": False, "additional_headers": {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"}, - "startup_period": 5, "embedded_options": { "persistence_data_path": DEFAULT_PERSISTENCE_DATA_PATH, "binary_path": DEFAULT_BINARY_PATH, @@ -283,11 +284,13 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): "grpc_port": DEFAULT_GRPC_PORT, }, "additional_config": { - "grpc_port_experimental": 12345, - "connection_config": { + "connection": { "session_pool_connections": 20, "session_pool_maxsize": 20, }, + "proxies": {"http": "http://proxy:1234"}, + "timeout": [10, 60], + "trust_env": False, }, }, } @@ -307,11 +310,10 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): ], } assert document_store._auth_client_secret == 
AuthApiKey() - assert document_store._timeout_config == (10, 60) - assert document_store._proxies == {"http": "http://proxy:1234"} - assert not document_store._trust_env + assert document_store._additional_config.timeout == Timeout(query=10, insert=60) + assert document_store._additional_config.proxies == Proxies(http="http://proxy:1234", https=None, grpc=None) + assert not document_store._additional_config.trust_env assert document_store._additional_headers == {"X-HuggingFace-Api-Key": "MY_HUGGINGFACE_KEY"} - assert document_store._startup_period == 5 assert document_store._embedded_options.persistence_data_path == DEFAULT_PERSISTENCE_DATA_PATH assert document_store._embedded_options.binary_path == DEFAULT_BINARY_PATH assert document_store._embedded_options.version == "1.23.0" @@ -319,9 +321,8 @@ def test_from_dict(self, _mock_weaviate, monkeypatch): assert document_store._embedded_options.hostname == "127.0.0.1" assert document_store._embedded_options.additional_env_vars is None assert document_store._embedded_options.grpc_port == DEFAULT_GRPC_PORT - assert document_store._additional_config.grpc_port_experimental == 12345 - assert document_store._additional_config.connection_config.session_pool_connections == 20 - assert document_store._additional_config.connection_config.session_pool_maxsize == 20 + assert document_store._additional_config.connection.session_pool_connections == 20 + assert document_store._additional_config.connection.session_pool_maxsize == 20 def test_to_data_object(self, document_store, test_files_path): doc = Document(content="test doc") @@ -353,18 +354,18 @@ def test_to_data_object(self, document_store, test_files_path): def test_to_document(self, document_store, test_files_path): image = ByteStream.from_file_path(test_files_path / "robot1.jpg", mime_type="image/jpeg") - data = { - "_additional": { - "vector": [1, 2, 3], + data = DataObject( + properties={ + "_original_id": "123", + "content": "some content", + "blob_data": 
base64.b64encode(image.data).decode(), + "blob_mime_type": "image/jpeg", + "dataframe": None, + "score": None, + "key": "value", }, - "_original_id": "123", - "content": "some content", - "blob_data": base64.b64encode(image.data).decode(), - "blob_mime_type": "image/jpeg", - "dataframe": None, - "score": None, - "meta": {"key": "value"}, - } + vector={"default": [1, 2, 3]}, + ) doc = document_store._to_document(data) assert doc.id == "123" @@ -626,3 +627,22 @@ def test_embedding_retrieval_with_certainty(self, document_store): def test_embedding_retrieval_with_distance_and_certainty(self, document_store): with pytest.raises(ValueError): document_store._embedding_retrieval(query_embedding=[], distance=0.1, certainty=0.1) + + def test_filter_documents_below_default_limit(self, document_store): + docs = [] + for index in range(9998): + docs.append(Document(content="This is some content", meta={"index": index})) + document_store.write_documents(docs) + result = document_store.filter_documents( + {"field": "content", "operator": "==", "value": "This is some content"} + ) + + assert len(result) == 9998 + + def test_filter_documents_over_default_limit(self, document_store): + docs = [] + for index in range(10000): + docs.append(Document(content="This is some content", meta={"index": index})) + document_store.write_documents(docs) + with pytest.raises(DocumentStoreError): + document_store.filter_documents({"field": "content", "operator": "==", "value": "This is some content"}) diff --git a/integrations/weaviate/tests/test_embedding_retriever.py b/integrations/weaviate/tests/test_embedding_retriever.py index 7f07d8a24..a406c40db 100644 --- a/integrations/weaviate/tests/test_embedding_retriever.py +++ b/integrations/weaviate/tests/test_embedding_retriever.py @@ -49,11 +49,7 @@ def test_to_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, 
"embedded_options": None, "additional_config": None, }, @@ -89,11 +85,7 @@ def test_from_dict(_mock_weaviate): ], }, "auth_client_secret": None, - "timeout_config": (10, 60), - "proxies": None, - "trust_env": False, "additional_headers": None, - "startup_period": 5, "embedded_options": None, "additional_config": None, }, diff --git a/integrations/weaviate/tests/test_filters.py b/integrations/weaviate/tests/test_filters.py index cf38d84be..c32d69e2f 100644 --- a/integrations/weaviate/tests/test_filters.py +++ b/integrations/weaviate/tests/test_filters.py @@ -19,7 +19,7 @@ def test_invert_conditions(): inverted = _invert_condition(filters) assert inverted == { - "operator": "AND", + "operator": "OR", "conditions": [ {"field": "meta.number", "operator": "!=", "value": 100}, {"field": "meta.name", "operator": "!=", "value": "name_0"},