From b541bf1c8cd3b3dfec8832c7b340bf24deb00943 Mon Sep 17 00:00:00 2001 From: Anush008 Date: Wed, 13 Dec 2023 07:06:49 +0530 Subject: [PATCH 1/7] feat: qdrant-haystack --- integrations/qdrant/LICENSE.txt | 73 +++ integrations/qdrant/README.md | 21 + integrations/qdrant/pyproject.toml | 160 ++++++ .../qdrant/src/qdrant_haystack/__about__.py | 4 + .../qdrant/src/qdrant_haystack/__init__.py | 7 + .../qdrant/src/qdrant_haystack/converters.py | 56 +++ .../src/qdrant_haystack/document_store.py | 458 +++++++++++++++++ .../qdrant/src/qdrant_haystack/filters.py | 211 ++++++++ .../qdrant/src/qdrant_haystack/retriever.py | 98 ++++ .../qdrant/src/qdrant_haystack/utils.py | 0 integrations/qdrant/tests/__init__.py | 3 + integrations/qdrant/tests/test_converters.py | 53 ++ .../qdrant/tests/test_dict_convertors.py | 102 ++++ .../qdrant/tests/test_document_store.py | 42 ++ integrations/qdrant/tests/test_filters.py | 115 +++++ .../qdrant/tests/test_legacy_filters.py | 459 ++++++++++++++++++ integrations/qdrant/tests/test_retriever.py | 113 +++++ 17 files changed, 1975 insertions(+) create mode 100644 integrations/qdrant/LICENSE.txt create mode 100644 integrations/qdrant/README.md create mode 100644 integrations/qdrant/pyproject.toml create mode 100644 integrations/qdrant/src/qdrant_haystack/__about__.py create mode 100644 integrations/qdrant/src/qdrant_haystack/__init__.py create mode 100644 integrations/qdrant/src/qdrant_haystack/converters.py create mode 100644 integrations/qdrant/src/qdrant_haystack/document_store.py create mode 100644 integrations/qdrant/src/qdrant_haystack/filters.py create mode 100644 integrations/qdrant/src/qdrant_haystack/retriever.py create mode 100644 integrations/qdrant/src/qdrant_haystack/utils.py create mode 100644 integrations/qdrant/tests/__init__.py create mode 100644 integrations/qdrant/tests/test_converters.py create mode 100644 integrations/qdrant/tests/test_dict_convertors.py create mode 100644 integrations/qdrant/tests/test_document_store.py 
create mode 100644 integrations/qdrant/tests/test_filters.py create mode 100644 integrations/qdrant/tests/test_legacy_filters.py create mode 100644 integrations/qdrant/tests/test_retriever.py diff --git a/integrations/qdrant/LICENSE.txt b/integrations/qdrant/LICENSE.txt new file mode 100644 index 000000000..137069b82 --- /dev/null +++ b/integrations/qdrant/LICENSE.txt @@ -0,0 +1,73 @@ +Apache License +Version 2.0, January 2004 +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 
+ +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +APPENDIX: How to apply the Apache License to your work. + +To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
+ +Copyright [yyyy] [name of copyright owner] + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/integrations/qdrant/README.md b/integrations/qdrant/README.md new file mode 100644 index 000000000..de124a2a2 --- /dev/null +++ b/integrations/qdrant/README.md @@ -0,0 +1,21 @@ +# qdrant-haystack + +[![PyPI - Version](https://img.shields.io/pypi/v/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/qdrant-haystack.svg)](https://pypi.org/project/qdrant-haystack) + +----- + +**Table of Contents** + +- [Installation](#installation) +- [License](#license) + +## Installation + +```console +pip install qdrant-haystack +``` + +## License + +`qdrant-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml new file mode 100644 index 000000000..b61c27243 --- /dev/null +++ b/integrations/qdrant/pyproject.toml @@ -0,0 +1,160 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "qdrant-haystack" +dynamic = ["version"] +description = 'An integration of Qdrant ANN vector database backend with Haystack' +readme = "README.md" +requires-python = ">=3.7" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "deepset GmbH", email = "info@deepset.ai" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "haystack-ai", + "qdrant-client", +] + +[project.urls] +Documentation = "https://github.com/unknown/qdrant-haystack#readme" +Issues = "https://github.com/unknown/qdrant-haystack/issues" +Source = "https://github.com/unknown/qdrant-haystack" + +[tool.hatch.version] +path = "src/qdrant_haystack/__about__.py" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", +] +[tool.hatch.envs.default.scripts] +test = "pytest {args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.7", "3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive {args:src/qdrant_haystack tests}" +style = [ + "ruff 
{args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["qdrant_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["qdrant_haystack", "tests"] +branch = true +parallel = true +omit = [ + "src/qdrant_haystack/__about__.py", +] + +[tool.coverage.paths] +qdrant_haystack = ["src/qdrant_haystack", "*/qdrant-haystack/src/qdrant_haystack"] +tests = ["tests", "*/qdrant-haystack/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] diff --git a/integrations/qdrant/src/qdrant_haystack/__about__.py b/integrations/qdrant/src/qdrant_haystack/__about__.py new file mode 100644 index 000000000..0e4fa27cf --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +__version__ = "0.0.1" diff --git 
a/integrations/qdrant/src/qdrant_haystack/__init__.py b/integrations/qdrant/src/qdrant_haystack/__init__.py new file mode 100644 index 000000000..765ced0ef --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from qdrant_haystack.document_store import QdrantDocumentStore + +__all__ = ("QdrantDocumentStore",) diff --git a/integrations/qdrant/src/qdrant_haystack/converters.py b/integrations/qdrant/src/qdrant_haystack/converters.py new file mode 100644 index 000000000..3fb6dabd6 --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/converters.py @@ -0,0 +1,56 @@ +import uuid +from typing import List, Union + +from haystack.dataclasses import Document +from qdrant_client.http import models as rest + + +class HaystackToQdrant: + """A converter from Haystack to Qdrant types.""" + + UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") + + def documents_to_batch( + self, + documents: List[Document], + *, + embedding_field: str, + ) -> List[rest.PointStruct]: + points = [] + for document in documents: + payload = document.to_dict(flatten=False) + vector = payload.pop(embedding_field) or {} + _id = self.convert_id(payload.get("id")) + + point = rest.PointStruct( + payload=payload, + vector=vector, + id=_id, + ) + points.append(point) + return points + + def convert_id(self, _id: str) -> str: + """ + Converts any string into a UUID-like format in a deterministic way. + + Qdrant does not accept any string as an id, so an internal id has to be + generated for each point. This is a deterministic way of doing so. 
+ """ + return uuid.uuid5(self.UUID_NAMESPACE, _id).hex + + +QdrantPoint = Union[rest.ScoredPoint, rest.Record] + + +class QdrantToHaystack: + def __init__(self, content_field: str, name_field: str, embedding_field: str): + self.content_field = content_field + self.name_field = name_field + self.embedding_field = embedding_field + + def point_to_document(self, point: QdrantPoint) -> Document: + payload = {**point.payload} + payload["embedding"] = point.vector if hasattr(point, "vector") else None + payload["score"] = point.score if hasattr(point, "score") else None + return Document.from_dict(payload) diff --git a/integrations/qdrant/src/qdrant_haystack/document_store.py b/integrations/qdrant/src/qdrant_haystack/document_store.py new file mode 100644 index 000000000..c4b709332 --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/document_store.py @@ -0,0 +1,458 @@ +import inspect +import logging +from itertools import islice +from typing import Any, ClassVar, Dict, Generator, List, Optional, Set, Union + +import numpy as np +import qdrant_client +from grpc import RpcError +from haystack import default_from_dict, default_to_dict +from haystack.dataclasses import Document +from haystack.document_stores import DuplicatePolicy +from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError +from haystack.utils.filters import convert +from qdrant_client import grpc +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse +from tqdm import tqdm + +from qdrant_haystack.converters import HaystackToQdrant, QdrantToHaystack +from qdrant_haystack.filters import QdrantFilterConverter + +logger = logging.getLogger(__name__) + + +class QdrantStoreError(DocumentStoreError): + pass + + +FilterType = Dict[str, Union[Dict[str, Any], List[Any], str, int, float, bool]] + + +def get_batches_from_generator(iterable, n): + """ + Batch elements of an iterable into fixed-length chunks or blocks. 
+ """ + it = iter(iterable) + x = tuple(islice(it, n)) + while x: + yield x + x = tuple(islice(it, n)) + + +class QdrantDocumentStore: + SIMILARITY: ClassVar[Dict[str, str]] = { + "cosine": rest.Distance.COSINE, + "dot_product": rest.Distance.DOT, + "l2": rest.Distance.EUCLID, + } + + def __init__( + self, + location: Optional[str] = None, + url: Optional[str] = None, + port: int = 6333, + grpc_port: int = 6334, + prefer_grpc: bool = False, # noqa: FBT001, FBT002 + https: Optional[bool] = None, + api_key: Optional[str] = None, + prefix: Optional[str] = None, + timeout: Optional[float] = None, + host: Optional[str] = None, + path: Optional[str] = None, + index: str = "Document", + embedding_dim: int = 768, + content_field: str = "content", + name_field: str = "name", + embedding_field: str = "embedding", + similarity: str = "cosine", + return_embedding: bool = False, # noqa: FBT001, FBT002 + progress_bar: bool = True, # noqa: FBT001, FBT002 + duplicate_documents: str = "overwrite", + recreate_index: bool = False, # noqa: FBT001, FBT002 + shard_number: Optional[int] = None, + replication_factor: Optional[int] = None, + write_consistency_factor: Optional[int] = None, + on_disk_payload: Optional[bool] = None, + hnsw_config: Optional[dict] = None, + optimizers_config: Optional[dict] = None, + wal_config: Optional[dict] = None, + quantization_config: Optional[dict] = None, + init_from: Optional[dict] = None, + wait_result_from_api: bool = True, # noqa: FBT001, FBT002 + metadata: Optional[dict] = None, + write_batch_size: int = 100, + scroll_size: int = 10_000, + ): + super().__init__() + + metadata = metadata or {} + self.client = qdrant_client.QdrantClient( + location=location, + url=url, + port=port, + grpc_port=grpc_port, + prefer_grpc=prefer_grpc, + https=https, + api_key=api_key, + prefix=prefix, + timeout=timeout, + host=host, + path=path, + metadata=metadata, + ) + + # Store the Qdrant client specific attributes + self.location = location + self.url = url + 
self.port = port + self.grpc_port = grpc_port + self.prefer_grpc = prefer_grpc + self.https = https + self.api_key = api_key + self.prefix = prefix + self.timeout = timeout + self.host = host + self.path = path + self.metadata = metadata + + # Store the Qdrant collection specific attributes + self.shard_number = shard_number + self.replication_factor = replication_factor + self.write_consistency_factor = write_consistency_factor + self.on_disk_payload = on_disk_payload + self.hnsw_config = hnsw_config + self.optimizers_config = optimizers_config + self.wal_config = wal_config + self.quantization_config = quantization_config + self.init_from = init_from + self.wait_result_from_api = wait_result_from_api + self.recreate_index = recreate_index + + # Make sure the collection is properly set up + self._set_up_collection(index, embedding_dim, recreate_index, similarity) + + self.embedding_dim = embedding_dim + self.content_field = content_field + self.name_field = name_field + self.embedding_field = embedding_field + self.similarity = similarity + self.index = index + self.return_embedding = return_embedding + self.progress_bar = progress_bar + self.duplicate_documents = duplicate_documents + self.qdrant_filter_converter = QdrantFilterConverter() + self.haystack_to_qdrant_converter = HaystackToQdrant() + self.qdrant_to_haystack = QdrantToHaystack( + content_field, + name_field, + embedding_field, + ) + self.write_batch_size = write_batch_size + self.scroll_size = scroll_size + + def count_documents(self) -> int: + try: + response = self.client.count( + collection_name=self.index, + ) + return response.count + except (UnexpectedResponse, ValueError): + # Qdrant local raises ValueError if the collection is not found, but + # with the remote server UnexpectedResponse is raised. Until that's unified, + # we need to catch both. 
+ return 0 + + def filter_documents( + self, + filters: Optional[Dict[str, Any]] = None, + ) -> List[Document]: + if filters and not isinstance(filters, dict): + msg = "Filter must be a dictionary" + raise ValueError(msg) + + if filters and "operator" not in filters: + filters = convert(filters) + return list( + self.get_documents_generator( + filters, + ) + ) + + def write_documents( + self, + documents: List[Document], + policy: DuplicatePolicy = DuplicatePolicy.FAIL, + ): + for doc in documents: + if not isinstance(doc, Document): + msg = f"DocumentStore.write_documents() expects a list of Documents but got an element of {type(doc)}." + raise ValueError(msg) + self._set_up_collection(self.index, self.embedding_dim, False, self.similarity) + + if len(documents) == 0: + logger.warning("Calling QdrantDocumentStore.write_documents() with empty list") + return + + document_objects = self._handle_duplicate_documents( + documents=documents, + index=self.index, + policy=policy, + ) + + batched_documents = get_batches_from_generator(document_objects, self.write_batch_size) + with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: + for document_batch in batched_documents: + batch = self.haystack_to_qdrant_converter.documents_to_batch( + document_batch, + embedding_field=self.embedding_field, + ) + + self.client.upsert( + collection_name=self.index, + points=batch, + wait=self.wait_result_from_api, + ) + + progress_bar.update(self.write_batch_size) + return len(document_objects) + + def delete_documents(self, ids: List[str]): + ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + try: + self.client.delete( + collection_name=self.index, + points_selector=ids, + wait=self.wait_result_from_api, + ) + except KeyError: + logger.warning( + "Called QdrantDocumentStore.delete_documents() on a non-existing ID", + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "QdrantDocumentStore": + return 
default_from_dict(cls, data) + + def to_dict(self) -> Dict[str, Any]: + params = inspect.signature(self.__init__).parameters # type: ignore + # All the __init__ params must be set as attributes + # Set as init_parms without default values + init_params = {k: getattr(self, k) for k in params} + return default_to_dict( + self, + **init_params, + ) + + def get_documents_generator( + self, + filters: Optional[Dict[str, Any]] = None, + ) -> Generator[Document, None, None]: + index = self.index + qdrant_filters = self.qdrant_filter_converter.convert(filters) + + next_offset = None + stop_scrolling = False + while not stop_scrolling: + records, next_offset = self.client.scroll( + collection_name=index, + scroll_filter=qdrant_filters, + limit=self.scroll_size, + offset=next_offset, + with_payload=True, + with_vectors=True, + ) + stop_scrolling = next_offset is None or ( + isinstance(next_offset, grpc.PointId) and next_offset.num == 0 and next_offset.uuid == "" + ) + + for record in records: + yield self.qdrant_to_haystack.point_to_document(record) + + def get_documents_by_id( + self, + ids: List[str], + index: Optional[str] = None, + ) -> List[Document]: + index = index or self.index + + documents: List[Document] = [] + + ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + records = self.client.retrieve( + collection_name=index, + ids=ids, + with_payload=True, + with_vectors=True, + ) + + for record in records: + documents.append(self.qdrant_to_haystack.point_to_document(record)) + return documents + + def query_by_embedding( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, # noqa: FBT001, FBT002 + return_embedding: bool = False, # noqa: FBT001, FBT002 + ) -> List[Document]: + qdrant_filters = self.qdrant_filter_converter.convert(filters) + + points = self.client.search( + collection_name=self.index, + query_vector=query_embedding, + query_filter=qdrant_filters, + 
limit=top_k, + with_vectors=return_embedding, + ) + + results = [self.qdrant_to_haystack.point_to_document(point) for point in points] + if scale_score: + for document in results: + score = document.score + if self.similarity == "cosine": + score = (score + 1) / 2 + else: + score = float(1 / (1 + np.exp(-score / 100))) + document.score = score + return results + + def _get_distance(self, similarity: str) -> rest.Distance: + try: + return self.SIMILARITY[similarity] + except KeyError as ke: + msg = ( + f"Provided similarity '{similarity}' is not supported by Qdrant " + f"document store. Please choose one of the options: " + f"{', '.join(self.SIMILARITY.keys())}" + ) + raise QdrantStoreError(msg) from ke + + def _set_up_collection( + self, + collection_name: str, + embedding_dim: int, + recreate_collection: bool, # noqa: FBT001 + similarity: str, + ): + distance = self._get_distance(similarity) + + if recreate_collection: + # There is no need to verify the current configuration of that + # collection. It might be just recreated again. + self._recreate_collection(collection_name, distance, embedding_dim) + return + + try: + # Check if the collection already exists and validate its + # current configuration with the parameters. + collection_info = self.client.get_collection(collection_name) + except (UnexpectedResponse, RpcError, ValueError): + # That indicates the collection does not exist, so it can be + # safely created with any configuration. + # + # Qdrant local raises ValueError if the collection is not found, but + # with the remote server UnexpectedResponse / RpcError is raised. + # Until that's unified, we need to catch both. 
+ self._recreate_collection(collection_name, distance, embedding_dim) + return + + current_distance = collection_info.config.params.vectors.distance + current_vector_size = collection_info.config.params.vectors.size + + if current_distance != distance: + msg = ( + f"Collection '{collection_name}' already exists in Qdrant, " + f"but it is configured with a similarity '{current_distance.name}'. " + f"If you want to use that collection, but with a different " + f"similarity, please set `recreate_collection=True` argument." + ) + raise ValueError(msg) + + if current_vector_size != embedding_dim: + msg = ( + f"Collection '{collection_name}' already exists in Qdrant, " + f"but it is configured with a vector size '{current_vector_size}'. " + f"If you want to use that collection, but with a different " + f"vector size, please set `recreate_collection=True` argument." + ) + raise ValueError(msg) + + def _recreate_collection(self, collection_name: str, distance, embedding_dim: int): + self.client.recreate_collection( + collection_name=collection_name, + vectors_config=rest.VectorParams( + size=embedding_dim, + distance=distance, + ), + shard_number=self.shard_number, + replication_factor=self.replication_factor, + write_consistency_factor=self.write_consistency_factor, + on_disk_payload=self.on_disk_payload, + hnsw_config=self.hnsw_config, + optimizers_config=self.optimizers_config, + wal_config=self.wal_config, + quantization_config=self.quantization_config, + init_from=self.init_from, + ) + + def _handle_duplicate_documents( + self, + documents: List[Document], + index: Optional[str] = None, + policy: DuplicatePolicy = None, + ): + """ + Checks whether any of the passed documents is already existing in the chosen index and returns a list of + documents that are not in the index yet. + + :param documents: A list of Haystack Document objects. + :param index: name of the index + :param duplicate_documents: Handle duplicates document based on parameter options. 
+ Parameter options : ( 'skip','overwrite','fail') + skip (default option): Ignore the duplicates documents + overwrite: Update any existing documents with the same ID when adding documents. + fail: an error is raised if the document ID of the document being added already + exists. + :return: A list of Haystack Document objects. + """ + + index = index or self.index + if policy in (DuplicatePolicy.SKIP, DuplicatePolicy.FAIL): + documents = self._drop_duplicate_documents(documents, index) + documents_found = self.get_documents_by_id(ids=[doc.id for doc in documents], index=index) + ids_exist_in_db: List[str] = [doc.id for doc in documents_found] + + if len(ids_exist_in_db) > 0 and policy == DuplicatePolicy.FAIL: + msg = f"Document with ids '{', '.join(ids_exist_in_db)} already exists in index = '{index}'." + raise DuplicateDocumentError(msg) + + documents = list(filter(lambda doc: doc.id not in ids_exist_in_db, documents)) + + return documents + + def _drop_duplicate_documents(self, documents: List[Document], index: Optional[str] = None) -> List[Document]: + """ + Drop duplicates documents based on same hash ID + + :param documents: A list of Haystack Document objects. + :param index: name of the index + :return: A list of Haystack Document objects. 
+ """ + _hash_ids: Set = set() + _documents: List[Document] = [] + + for document in documents: + if document.id in _hash_ids: + logger.info( + "Duplicate Documents: Document with id '%s' already exists in index '%s'", + document.id, + index or self.index, + ) + continue + _documents.append(document) + _hash_ids.add(document.id) + + return _documents diff --git a/integrations/qdrant/src/qdrant_haystack/filters.py b/integrations/qdrant/src/qdrant_haystack/filters.py new file mode 100644 index 000000000..cc6b2b6a5 --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/filters.py @@ -0,0 +1,211 @@ +from abc import ABC, abstractmethod +from typing import Any, List, Optional, Union + +from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError +from qdrant_client.http import models + +from qdrant_haystack.converters import HaystackToQdrant + +COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys() +LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys() + + +class BaseFilterConverter(ABC): + """Converts Haystack filters to a format accepted by an external tool.""" + + @abstractmethod + def convert(self, filter_term: Optional[Union[List[dict], dict]]) -> Optional[Any]: + raise NotImplementedError + + +class QdrantFilterConverter(BaseFilterConverter): + """Converts Haystack filters to the format used by Qdrant.""" + + def __init__(self): + self.haystack_to_qdrant_converter = HaystackToQdrant() + + def convert( + self, + filter_term: Optional[Union[List[dict], dict]] = None, + ) -> Optional[models.Filter]: + if not filter_term: + return None + + must_clauses, should_clauses, must_not_clauses = [], [], [] + + if isinstance(filter_term, dict): + filter_term = [filter_term] + + for item in filter_term: + operator = item.get("operator") + if operator is None: + msg = "Operator not found in filters" + raise FilterError(msg) + + if operator in LOGICAL_OPERATORS and "conditions" not in item: + msg = f"'conditions' not found for '{operator}'" + raise 
FilterError(msg) + + if operator == "AND": + must_clauses.append(self.convert(item.get("conditions", []))) + elif operator == "OR": + should_clauses.append(self.convert(item.get("conditions", []))) + elif operator == "NOT": + must_not_clauses.append(self.convert(item.get("conditions", []))) + elif operator in COMPARISON_OPERATORS: + field = item.get("field") + value = item.get("value") + if field is None or value is None: + msg = f"'field' or 'value' not found for '{operator}'" + raise FilterError(msg) + + must_clauses.extend( + self._parse_comparison_operation(comparison_operation=operator, key=field, value=value) + ) + else: + msg = f"Unknown operator {operator} used in filters" + raise FilterError(msg) + + payload_filter = models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) + + filter_result = self._squeeze_filter(payload_filter) + + return filter_result + + def _parse_comparison_operation( + self, comparison_operation: str, key: str, value: Union[dict, List, str, float] + ) -> List[models.Condition]: + conditions: List[models.Condition] = [] + + condition_builder_mapping = { + "==": self._build_eq_condition, + "in": self._build_in_condition, + "!=": self._build_ne_condition, + "not in": self._build_nin_condition, + ">": self._build_gt_condition, + ">=": self._build_gte_condition, + "<": self._build_lt_condition, + "<=": self._build_lte_condition, + } + + condition_builder = condition_builder_mapping.get(comparison_operation) + + if condition_builder is None: + msg = f"Unknown operator {comparison_operation} used in filters" + raise ValueError(msg) + + conditions.append(condition_builder(key, value)) + + return conditions + + def _build_eq_condition(self, key: str, value: models.ValueVariants) -> models.Condition: + if isinstance(value, str) and " " in value: + models.FieldCondition(key=key, match=models.MatchText(text=value)) + return models.FieldCondition(key=key, 
match=models.MatchValue(value=value)) + + def _build_in_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" + raise FilterError(msg) + return models.Filter( + should=[ + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " not in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + for item in value + ] + ) + + def _build_ne_condition(self, key: str, value: models.ValueVariants) -> models.Condition: + return models.Filter( + must_not=[ + models.FieldCondition(key=key, match=models.MatchText(text=value)) + if isinstance(value, str) and " " not in value + else models.FieldCondition(key=key, match=models.MatchValue(value=value)) + ] + ) + + def _build_nin_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" + raise FilterError(msg) + return models.Filter( + must_not=[ + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " not in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + for item in value + ] + ) + + def _build_lt_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) + return models.FieldCondition(key=key, range=models.Range(lt=value)) + + def _build_lte_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) + return models.FieldCondition(key=key, range=models.Range(lte=value)) + + def _build_gt_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) 
+ return models.FieldCondition(key=key, range=models.Range(gt=value)) + + def _build_gte_condition(self, key: str, value: float) -> models.Condition: + if not isinstance(value, (int, float)): + msg = f"Value {value} is not an int or float" + raise FilterError(msg) + return models.FieldCondition(key=key, range=models.Range(gte=value)) + + def _build_has_id_condition(self, id_values: List[models.ExtendedPointId]) -> models.HasIdCondition: + return models.HasIdCondition( + has_id=[ + # Ids are converted into their internal representation + self.haystack_to_qdrant_converter.convert_id(item) + for item in id_values + ] + ) + + def _squeeze_filter(self, payload_filter: models.Filter) -> models.Filter: + """ + Simplify given payload filter, if the nested structure might be unnested. + That happens if there is a single clause in that filter. + :param payload_filter: + :return: + """ + filter_parts = { + "must": payload_filter.must, + "should": payload_filter.should, + "must_not": payload_filter.must_not, + } + + total_clauses = sum(len(x) for x in filter_parts.values() if x is not None) + if total_clauses == 0 or total_clauses > 1: + return payload_filter + + # Payload filter has just a single clause provided (either must, should + # or must_not). If that single clause is also of a models.Filter type, + # then it might be returned instead. + for part_name, filter_part in filter_parts.items(): + if not filter_part: + continue + + subfilter = filter_part[0] + if not isinstance(subfilter, models.Filter): + # The inner statement is a simple condition like models.FieldCondition + # so it cannot be simplified. 
+ continue + + if subfilter.must: + return models.Filter(**{part_name: subfilter.must}) + + return payload_filter diff --git a/integrations/qdrant/src/qdrant_haystack/retriever.py b/integrations/qdrant/src/qdrant_haystack/retriever.py new file mode 100644 index 000000000..054ba96ac --- /dev/null +++ b/integrations/qdrant/src/qdrant_haystack/retriever.py @@ -0,0 +1,98 @@ +from typing import Any, Dict, List, Optional + +from haystack import Document, component, default_from_dict, default_to_dict + +from qdrant_haystack import QdrantDocumentStore + + +@component +class QdrantRetriever: + """ + A component for retrieving documents from an QdrantDocumentStore. + """ + + def __init__( + self, + document_store: QdrantDocumentStore, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + scale_score: bool = True, # noqa: FBT001, FBT002 + return_embedding: bool = False, # noqa: FBT001, FBT002 + ): + """ + Create a QdrantRetriever component. + + :param document_store: An instance of QdrantDocumentStore. + :param filters: A dictionary with filters to narrow down the search space. Default is None. + :param top_k: The maximum number of documents to retrieve. Default is 10. + :param scale_score: Whether to scale the scores of the retrieved documents or not. Default is True. + :param return_embedding: Whether to return the embedding of the retrieved Documents. Default is False. + + :raises ValueError: If 'document_store' is not an instance of QdrantDocumentStore. + """ + + if not isinstance(document_store, QdrantDocumentStore): + msg = "document_store must be an instance of QdrantDocumentStore" + raise ValueError(msg) + + self._document_store = document_store + + self._filters = filters + self._top_k = top_k + self._scale_score = scale_score + self._return_embedding = return_embedding + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. 
+ """ + d = default_to_dict( + self, + document_store=self._document_store, + filters=self._filters, + top_k=self._top_k, + scale_score=self._scale_score, + return_embedding=self._return_embedding, + ) + d["init_parameters"]["document_store"] = self._document_store.to_dict() + + return d + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "QdrantRetriever": + """ + Deserialize this component from a dictionary. + """ + document_store = QdrantDocumentStore.from_dict(data["init_parameters"]["document_store"]) + data["init_parameters"]["document_store"] = document_store + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run( + self, + query_embedding: List[float], + filters: Optional[Dict[str, Any]] = None, + top_k: Optional[int] = None, + scale_score: Optional[bool] = None, + return_embedding: Optional[bool] = None, + ): + """ + Run the Embedding Retriever on the given input data. + + :param query_embedding: Embedding of the query. + :param filters: A dictionary with filters to narrow down the search space. + :param top_k: The maximum number of documents to return. + :param scale_score: Whether to scale the scores of the retrieved documents or not. + :param return_embedding: Whether to return the embedding of the retrieved Documents. + :return: The retrieved documents. 
+ + """ + docs = self._document_store.query_by_embedding( + query_embedding=query_embedding, + filters=filters or self._filters, + top_k=top_k or self._top_k, + scale_score=scale_score or self._scale_score, + return_embedding=return_embedding or self._return_embedding, + ) + + return {"documents": docs} diff --git a/integrations/qdrant/src/qdrant_haystack/utils.py b/integrations/qdrant/src/qdrant_haystack/utils.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/qdrant/tests/__init__.py b/integrations/qdrant/tests/__init__.py new file mode 100644 index 000000000..e873bc332 --- /dev/null +++ b/integrations/qdrant/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/qdrant/tests/test_converters.py b/integrations/qdrant/tests/test_converters.py new file mode 100644 index 000000000..dc4866293 --- /dev/null +++ b/integrations/qdrant/tests/test_converters.py @@ -0,0 +1,53 @@ +import numpy as np +import pytest +from qdrant_client.http import models as rest + +from qdrant_haystack.converters import HaystackToQdrant, QdrantToHaystack + +CONTENT_FIELD = "content" +NAME_FIELD = "name" +EMBEDDING_FIELD = "vector" + + +@pytest.fixture +def haystack_to_qdrant() -> HaystackToQdrant: + return HaystackToQdrant() + + +@pytest.fixture +def qdrant_to_haystack() -> QdrantToHaystack: + return QdrantToHaystack( + content_field=CONTENT_FIELD, + name_field=NAME_FIELD, + embedding_field=EMBEDDING_FIELD, + ) + + +def test_convert_id_is_deterministic(haystack_to_qdrant: HaystackToQdrant): + first_id = haystack_to_qdrant.convert_id("test-id") + second_id = haystack_to_qdrant.convert_id("test-id") + assert first_id == second_id + + +def test_point_to_document_reverts_proper_structure_from_record( + qdrant_to_haystack: QdrantToHaystack, +): + point = rest.Record( + id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", + payload={ + "id": "my-id", + "id_hash_keys": ["content"], + 
"content": "Lorem ipsum", + "content_type": "text", + "meta": { + "test_field": 1, + }, + }, + vector=[1.0, 0.0, 0.0, 0.0], + ) + document = qdrant_to_haystack.point_to_document(point) + assert "my-id" == document.id + assert "Lorem ipsum" == document.content + assert "text" == document.content_type + assert {"test_field": 1} == document.meta + assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding) diff --git a/integrations/qdrant/tests/test_dict_convertors.py b/integrations/qdrant/tests/test_dict_convertors.py new file mode 100644 index 000000000..1a211655c --- /dev/null +++ b/integrations/qdrant/tests/test_dict_convertors.py @@ -0,0 +1,102 @@ +from qdrant_haystack import QdrantDocumentStore + + +def test_to_dict(): + document_store = QdrantDocumentStore(location=":memory:", index="test") + + expected = { + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "init_parameters": { + "location": ":memory:", + "url": None, + "port": 6333, + "grpc_port": 6334, + "prefer_grpc": False, + "https": None, + "api_key": None, + "prefix": None, + "timeout": None, + "host": None, + "path": None, + "index": "test", + "embedding_dim": 768, + "content_field": "content", + "name_field": "name", + "embedding_field": "embedding", + "similarity": "cosine", + "return_embedding": False, + "progress_bar": True, + "duplicate_documents": "overwrite", + "recreate_index": False, + "shard_number": None, + "replication_factor": None, + "write_consistency_factor": None, + "on_disk_payload": None, + "hnsw_config": None, + "optimizers_config": None, + "wal_config": None, + "quantization_config": None, + "init_from": None, + "wait_result_from_api": True, + "metadata": {}, + "write_batch_size": 100, + "scroll_size": 10000, + }, + } + + assert document_store.to_dict() == expected + + +def test_from_dict(): + document_store = QdrantDocumentStore.from_dict( + { + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "init_parameters": { + "location": 
":memory:", + "index": "test", + "embedding_dim": 768, + "content_field": "content", + "name_field": "name", + "embedding_field": "embedding", + "similarity": "cosine", + "return_embedding": False, + "progress_bar": True, + "duplicate_documents": "overwrite", + "recreate_index": True, + "shard_number": None, + "quantization_config": None, + "init_from": None, + "wait_result_from_api": True, + "metadata": {}, + "write_batch_size": 1000, + "scroll_size": 10000, + }, + } + ) + + assert all( + [ + document_store.index == "test", + document_store.content_field == "content", + document_store.name_field == "name", + document_store.embedding_field == "embedding", + document_store.similarity == "cosine", + document_store.return_embedding is False, + document_store.progress_bar, + document_store.duplicate_documents == "overwrite", + document_store.recreate_index is True, + document_store.shard_number is None, + document_store.replication_factor is None, + document_store.write_consistency_factor is None, + document_store.on_disk_payload is None, + document_store.hnsw_config is None, + document_store.optimizers_config is None, + document_store.wal_config is None, + document_store.quantization_config is None, + document_store.init_from is None, + document_store.wait_result_from_api, + document_store.metadata == {}, + document_store.write_batch_size == 1000, + document_store.scroll_size == 10000, + ] + ) diff --git a/integrations/qdrant/tests/test_document_store.py b/integrations/qdrant/tests/test_document_store.py new file mode 100644 index 000000000..bbc16b9df --- /dev/null +++ b/integrations/qdrant/tests/test_document_store.py @@ -0,0 +1,42 @@ +from typing import List + +import pytest +from haystack import Document +from haystack.document_stores import DuplicatePolicy +from haystack.document_stores.errors import DuplicateDocumentError +from haystack.testing.document_store import ( + CountDocumentsTest, + DeleteDocumentsTest, + WriteDocumentsTest, +) + +from qdrant_haystack 
# (continuation of tests/test_document_store.py)
class TestQdrantStoreBaseTests(CountDocumentsTest, WriteDocumentsTest, DeleteDocumentsTest):
    # Runs the generic haystack document-store suites against an in-memory Qdrant.

    @pytest.fixture
    def document_store(self) -> QdrantDocumentStore:
        return QdrantDocumentStore(
            ":memory:",
            recreate_index=True,
            return_embedding=True,
            wait_result_from_api=True,
        )

    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
        """
        Assert that two lists of Documents are equal.
        This is used in every test.
        """

        # Check that the lengths of the lists are the same
        assert len(received) == len(expected)

        # Check that the sets are equal, meaning the content and IDs match regardless of order
        assert {doc.id for doc in received} == {doc.id for doc in expected}

    def test_write_documents(self, document_store: QdrantDocumentStore):
        docs = [Document(id="1")]
        assert document_store.write_documents(docs) == 1
        with pytest.raises(DuplicateDocumentError):
            document_store.write_documents(docs, DuplicatePolicy.FAIL)


# --- new file: integrations/qdrant/tests/test_filters.py ---
from typing import List

import pytest
from haystack import Document
from haystack.testing.document_store import FilterDocumentsTest
from haystack.utils.filters import FilterError

from qdrant_haystack import QdrantDocumentStore


class TestQdrantStoreBaseTests(FilterDocumentsTest):
    @pytest.fixture
    def document_store(self) -> QdrantDocumentStore:
        return QdrantDocumentStore(
            ":memory:",
            recreate_index=True,
            return_embedding=True,
            wait_result_from_api=True,
        )

    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
        """
        Assert that two lists of Documents are equal.
        This is used in every test.
        """

        # Check that the lengths of the lists are the same
        assert len(received) == len(expected)

        # Check that the sets are equal, meaning the content and IDs match regardless of order
        assert {doc.id for doc in received} == {doc.id for doc in expected}

    def test_not_operator(self, document_store, filterable_docs):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(
            filters={
                "operator": "NOT",
                "conditions": [
                    {"field": "meta.number", "operator": "==", "value": 100},
                    {"field": "meta.name", "operator": "==", "value": "name_0"},
                ],
            }
        )
        self.assert_documents_are_equal(
            result,
            [d for d in filterable_docs if (d.meta.get("number") != 100 and d.meta.get("name") != "name_0")],
        )

    # ======== OVERRIDES FOR NONE VALUED FILTERS ========
    # Qdrant rejects None comparison values, so these base-suite tests are
    # overridden to expect a FilterError instead of a result set.
    # NOTE(review): the assertion inside each `with` block is unreachable once
    # FilterError is raised — presumably kept for parity with the base suite.

    def test_comparison_equal_with_none(self, document_store, filterable_docs):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            result = document_store.filter_documents(filters={"field": "meta.number", "operator": "==", "value": None})
            self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") is None])

    def test_comparison_not_equal_with_none(self, document_store, filterable_docs):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            result = document_store.filter_documents(filters={"field": "meta.number", "operator": "!=", "value": None})
            self.assert_documents_are_equal(result, [d for d in filterable_docs if d.meta.get("number") is not None])

    def test_comparison_greater_than_with_none(self, document_store, filterable_docs):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            result = document_store.filter_documents(filters={"field": "meta.number", "operator": ">", "value": None})
            self.assert_documents_are_equal(result, [])

    def test_comparison_greater_than_equal_with_none(self, document_store, filterable_docs):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            result = document_store.filter_documents(filters={"field": "meta.number", "operator": ">=", "value": None})
            self.assert_documents_are_equal(result, [])

    def test_comparison_less_than_with_none(self, document_store, filterable_docs):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            result = document_store.filter_documents(filters={"field": "meta.number", "operator": "<", "value": None})
            self.assert_documents_are_equal(result, [])

    def test_comparison_less_than_equal_with_none(self, document_store, filterable_docs):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            result = document_store.filter_documents(filters={"field": "meta.number", "operator": "<=", "value": None})
            self.assert_documents_are_equal(result, [])

    # ======== ========================== ========

    @pytest.mark.skip(reason="Qdrant doesn't support comparision with dataframe")
    def test_comparison_equal_with_dataframe(self, document_store, filterable_docs):
        ...

    @pytest.mark.skip(reason="Qdrant doesn't support comparision with dataframe")
    def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs):
        ...

    @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
    def test_comparison_greater_than_with_iso_date(self, document_store, filterable_docs):
        ...

    @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
    def test_comparison_greater_than_equal_with_iso_date(self, document_store, filterable_docs):
        ...

    @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
    def test_comparison_less_than_with_iso_date(self, document_store, filterable_docs):
        ...

    @pytest.mark.skip(reason="Qdrant doesn't support comparision with Dates")
    def test_comparison_less_than_equal_with_iso_date(self, document_store, filterable_docs):
        ...

    @pytest.mark.skip(reason="Cannot distinguish errors yet")
    def test_missing_top_level_operator_key(self, document_store, filterable_docs):
        ...


# --- new file: integrations/qdrant/tests/test_legacy_filters.py ---
from typing import List

import pytest
from haystack import Document
from haystack.document_stores import DocumentStore
from haystack.testing.document_store import LegacyFilterDocumentsTest
from haystack.utils.filters import FilterError

from qdrant_haystack import QdrantDocumentStore

# The tests below are from haystack.testing.document_store.LegacyFilterDocumentsTest
# Updated to include `meta` prefix for filter keys wherever necessary
# And skip tests that are not supported in Qdrant(Dataframes, embeddings)


class TestQdrantLegacyFilterDocuments(LegacyFilterDocumentsTest):
    """
    Utility class to test a Document Store `filter_documents` method using different types of legacy filters
    """

    @pytest.fixture
    def document_store(self) -> QdrantDocumentStore:
        return QdrantDocumentStore(
            ":memory:",
            recreate_index=True,
            return_embedding=True,
            wait_result_from_api=True,
        )

    def assert_documents_are_equal(self, received: List[Document], expected: List[Document]):
        """
        Assert that two lists of Documents are equal.
        This is used in every test.
        """
# (continuation of TestQdrantLegacyFilterDocuments.assert_documents_are_equal)

        # Check that the lengths of the lists are the same
        assert len(received) == len(expected)

        # Check that the sets are equal, meaning the content and IDs match regardless of order
        assert {doc.id for doc in received} == {doc.id for doc in expected}

    def test_filter_simple_metadata_value(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": "100"})
        self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"])

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_filter_document_dataframe(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    def test_eq_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": {"$eq": "100"}})
        self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"])

    def test_eq_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": "100"})
        self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"])

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_eq_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_eq_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    # LegacyFilterDocumentsNotEqualTest

    def test_ne_filter(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": {"$ne": "100"}})
        self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"])

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_ne_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_ne_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    # LegacyFilterDocumentsInTest

    def test_filter_simple_list_single_element(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": ["100"]})
        self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"])

    def test_filter_simple_list_one_value(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": ["100"]})
        self.assert_documents_are_equal(result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]])

    def test_filter_simple_list(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": ["100", "123"]})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]],
        )

    def test_incorrect_filter_value(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": ["nope"]})
        self.assert_documents_are_equal(result, [])

    def test_in_filter_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": {"$in": ["100", "123", "n.a."]}})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]],
        )

    def test_in_filter_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": ["100", "123", "n.a."]})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]],
        )

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_in_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_in_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    # LegacyFilterDocumentsNotInTest

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_nin_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_nin_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    def test_nin_filter(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.page": {"$nin": ["100", "123", "n.a."]}})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]],
        )

    # LegacyFilterDocumentsGreaterThanTest

    def test_gt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.number": {"$gt": 0.0}})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] > 0],
        )

    def test_gt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            document_store.filter_documents(filters={"meta.page": {"$gt": "100"}})

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_gt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_gt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    # LegacyFilterDocumentsGreaterThanEqualTest

    def test_gte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.number": {"$gte": -2}})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] >= -2],
        )

    def test_gte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            document_store.filter_documents(filters={"meta.page": {"$gte": "100"}})

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_gte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_gte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    # LegacyFilterDocumentsLessThanTest

    def test_lt_filter(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.number": {"$lt": 0.0}})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] < 0],
        )

    def test_lt_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            document_store.filter_documents(filters={"meta.page": {"$lt": "100"}})

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_lt_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_lt_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    # LegacyFilterDocumentsLessThanEqualTest

    def test_lte_filter(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0}})
        self.assert_documents_are_equal(
            result,
            [doc for doc in filterable_docs if doc.meta.get("number") is not None and doc.meta["number"] <= 2.0],
        )

    def test_lte_filter_non_numeric(self, document_store: DocumentStore, filterable_docs: List[Document]):
        document_store.write_documents(filterable_docs)
        with pytest.raises(FilterError):
            document_store.filter_documents(filters={"meta.page": {"$lte": "100"}})

    @pytest.mark.skip(reason="Dataframe filtering is not supported in Qdrant")
    def test_lte_filter_table(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...

    @pytest.mark.skip(reason="Embedding filtering is not supported in Qdrant")
    def test_lte_filter_embedding(self, document_store: DocumentStore, filterable_docs: List[Document]):
        ...
+ + # LegacyFilterDocumentsSimpleLogicalTest + + def test_filter_simple_or(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters = { + "$or": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + } + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if (doc.meta.get("number") is not None and doc.meta["number"] < 1) + or doc.meta.get("name") in ["name_0", "name_1"] + ], + ) + + def test_filter_simple_implicit_and_with_multi_key_dict( + self, document_store: DocumentStore, filterable_docs: List[Document] + ): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0.0}}) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] >= 0.0 and doc.meta["number"] <= 2.0 + ], + ) + + def test_filter_simple_explicit_and_with_list(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 + ], + ) + + def test_filter_simple_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + result = document_store.filter_documents(filters={"meta.number": {"$lte": 2.0, "$gte": 0}}) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 + ], + ) + + # LegacyFilterDocumentsNestedLogicalTest( + + def 
test_filter_nested_implicit_and(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "meta.number": {"$lte": 2, "$gte": 0}, + "meta.name": ["name_0", "name_1"], + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + "number" in doc.meta + and doc.meta["number"] <= 2 + and doc.meta["number"] >= 0 + and doc.meta.get("name") in ["name_0", "name_1"] + ) + ], + ) + + def test_filter_nested_or(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters = { + "$or": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + } + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("name") in ["name_0", "name_1"] + or (doc.meta.get("number") is not None and doc.meta["number"] < 1) + ) + ], + ) + + def test_filter_nested_and_or_explicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "$and": { + "meta.page": {"$eq": "123"}, + "$or": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + }, + } + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("page") in ["123"] + and ( + doc.meta.get("name") in ["name_0", "name_1"] + or ("number" in doc.meta and doc.meta["number"] < 1) + ) + ) + ], + ) + + def test_filter_nested_and_or_implicit(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "meta.page": {"$eq": "123"}, + "$or": { + 
"meta.name": {"$in": ["name_0", "name_1"]}, + "meta.number": {"$lt": 1.0}, + }, + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("page") in ["123"] + and ( + doc.meta.get("name") in ["name_0", "name_1"] + or ("number" in doc.meta and doc.meta["number"] < 1) + ) + ) + ], + ) + + def test_filter_nested_or_and(self, document_store: DocumentStore, filterable_docs: List[Document]): + document_store.write_documents(filterable_docs) + filters_simplified = { + "$or": { + "meta.number": {"$lt": 1}, + "$and": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "$not": {"meta.chapter": {"$eq": "intro"}}, + }, + } + } + result = document_store.filter_documents(filters=filters_simplified) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + (doc.meta.get("number") is not None and doc.meta["number"] < 1) + or (doc.meta.get("name") in ["name_0", "name_1"] and (doc.meta.get("chapter") != "intro")) + ) + ], + ) + + def test_filter_nested_multiple_identical_operators_same_level( + self, document_store: DocumentStore, filterable_docs: List[Document] + ): + document_store.write_documents(filterable_docs) + filters = { + "$or": [ + { + "$and": { + "meta.name": {"$in": ["name_0", "name_1"]}, + "meta.page": "100", + } + }, + { + "$and": { + "meta.chapter": {"$in": ["intro", "abstract"]}, + "meta.page": "123", + } + }, + ] + } + result = document_store.filter_documents(filters=filters) + self.assert_documents_are_equal( + result, + [ + doc + for doc in filterable_docs + if ( + (doc.meta.get("name") in ["name_0", "name_1"] and doc.meta.get("page") == "100") + or (doc.meta.get("chapter") in ["intro", "abstract"] and doc.meta.get("page") == "123") + ) + ], + ) + + def test_no_filter_not_empty(self, document_store: DocumentStore): + docs = [Document(content="test doc")] + document_store.write_documents(docs) + 
self.assert_documents_are_equal(document_store.filter_documents(), docs) + self.assert_documents_are_equal(document_store.filter_documents(filters={}), docs) diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py new file mode 100644 index 000000000..22eabfaad --- /dev/null +++ b/integrations/qdrant/tests/test_retriever.py @@ -0,0 +1,113 @@ +from typing import List + +from haystack.dataclasses import Document +from haystack.testing.document_store import ( + FilterableDocsFixtureMixin, + _random_embeddings, +) + +from qdrant_haystack import QdrantDocumentStore +from qdrant_haystack.retriever import QdrantRetriever + + +class TestQdrantRetriever(FilterableDocsFixtureMixin): + def test_init_default(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantRetriever(document_store=document_store) + assert retriever._document_store == document_store + assert retriever._filters is None + assert retriever._top_k == 10 + assert retriever._return_embedding is False + + def test_to_dict(self): + document_store = QdrantDocumentStore(location=":memory:", index="test") + retriever = QdrantRetriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "qdrant_haystack.retriever.QdrantRetriever", + "init_parameters": { + "document_store": { + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + "init_parameters": { + "location": ":memory:", + "url": None, + "port": 6333, + "grpc_port": 6334, + "prefer_grpc": False, + "https": None, + "api_key": None, + "prefix": None, + "timeout": None, + "host": None, + "path": None, + "index": "test", + "embedding_dim": 768, + "content_field": "content", + "name_field": "name", + "embedding_field": "embedding", + "similarity": "cosine", + "return_embedding": False, + "progress_bar": True, + "duplicate_documents": "overwrite", + "recreate_index": False, + "shard_number": None, + "replication_factor": None, + 
"write_consistency_factor": None, + "on_disk_payload": None, + "hnsw_config": None, + "optimizers_config": None, + "wal_config": None, + "quantization_config": None, + "init_from": None, + "wait_result_from_api": True, + "metadata": {}, + "write_batch_size": 100, + "scroll_size": 10000, + }, + }, + "filters": None, + "top_k": 10, + "scale_score": True, + "return_embedding": False, + }, + } + + def test_from_dict(self): + data = { + "type": "qdrant_haystack.retriever.QdrantRetriever", + "init_parameters": { + "document_store": { + "init_parameters": {"location": ":memory:", "index": "test"}, + "type": "qdrant_haystack.document_store.QdrantDocumentStore", + }, + "filters": None, + "top_k": 5, + "scale_score": False, + "return_embedding": True, + }, + } + retriever = QdrantRetriever.from_dict(data) + assert isinstance(retriever._document_store, QdrantDocumentStore) + assert retriever._document_store.index == "test" + assert retriever._filters is None + assert retriever._top_k == 5 + assert retriever._scale_score is False + assert retriever._return_embedding is True + + def test_run(self, filterable_docs: List[Document]): + document_store = QdrantDocumentStore(location=":memory:", index="Boi") + + document_store.write_documents(filterable_docs) + + retriever = QdrantRetriever(document_store=document_store) + + results: List[Document] = retriever.run(query_embedding=_random_embeddings(768)) + + assert len(results["documents"]) == 10 # type: ignore + + results = retriever.run(query_embedding=_random_embeddings(768), top_k=5, return_embedding=False) + + assert len(results["documents"]) == 5 # type: ignore + + for document in results["documents"]: # type: ignore + assert document.embedding is None From 47d1fee6cebfafacb85f043f6aae756f16508fb0 Mon Sep 17 00:00:00 2001 From: Anush Date: Wed, 13 Dec 2023 07:10:05 +0530 Subject: [PATCH 2/7] ci: Create qdrant.yml --- .github/workflows/qdrant.yml | 56 ++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create 
mode 100644 .github/workflows/qdrant.yml diff --git a/.github/workflows/qdrant.yml b/.github/workflows/qdrant.yml new file mode 100644 index 000000000..2bbf4f63a --- /dev/null +++ b/.github/workflows/qdrant.yml @@ -0,0 +1,56 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / qdrant + +on: + schedule: + - cron: "0 0 * * *" + pull_request: + paths: + - 'integrations/qdrant/**' + - '.github/workflows/qdrant.yml' + +defaults: + run: + working-directory: integrations/qdrant + +concurrency: + group: qdrant-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ['3.9', '3.10'] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + working-directory: . 
+ run: git config --system core.longpaths true + + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all + + - name: Run tests + run: hatch run cov From e0aeb5ab98e919a612a67f42000e9717758c0cda Mon Sep 17 00:00:00 2001 From: Anush Date: Wed, 13 Dec 2023 07:26:21 +0530 Subject: [PATCH 3/7] docs: Update README.md, mypy overrides (#1) * docs: Update README.md * chore: mypy overrides --- integrations/qdrant/README.md | 7 +++++++ integrations/qdrant/pyproject.toml | 16 +++++++++++++--- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/integrations/qdrant/README.md b/integrations/qdrant/README.md index de124a2a2..4fbe69f6e 100644 --- a/integrations/qdrant/README.md +++ b/integrations/qdrant/README.md @@ -16,6 +16,13 @@ pip install qdrant-haystack ``` +## Testing +The test suites use Qdran't in-memory instance. No additional steps required. + +```console +hatch run test +``` + ## License `qdrant-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. 
diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index b61c27243..bf4387f2f 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -30,9 +30,9 @@ dependencies = [ ] [project.urls] -Documentation = "https://github.com/unknown/qdrant-haystack#readme" -Issues = "https://github.com/unknown/qdrant-haystack/issues" -Source = "https://github.com/unknown/qdrant-haystack" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/qdrant/README.md" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations" [tool.hatch.version] path = "src/qdrant_haystack/__about__.py" @@ -158,3 +158,13 @@ exclude_lines = [ "if __name__ == .__main__.:", "if TYPE_CHECKING:", ] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "pytest.*", + "qdrant_client.*", + "numpy", + "grpc" +] +ignore_missing_imports = true \ No newline at end of file From e1d96c3e1f9644fc7862429bc9f8fdc8729c8bed Mon Sep 17 00:00:00 2001 From: Anush Date: Wed, 13 Dec 2023 07:30:19 +0530 Subject: [PATCH 4/7] docs: README.md typo fix --- integrations/qdrant/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/qdrant/README.md b/integrations/qdrant/README.md index 4fbe69f6e..18bf5b7a8 100644 --- a/integrations/qdrant/README.md +++ b/integrations/qdrant/README.md @@ -17,7 +17,7 @@ pip install qdrant-haystack ``` ## Testing -The test suites use Qdran't in-memory instance. No additional steps required. +The test suites use Qdrant's in-memory instance. No additional steps required. 
```console hatch run test From da6f2c1711b7eb0882361a6913dfb4f934ddc231 Mon Sep 17 00:00:00 2001 From: Anush Date: Wed, 13 Dec 2023 07:50:52 +0530 Subject: [PATCH 5/7] chore: update pyproject.toml (#2) * chore: pin pyproject.toml version * Update pyproject.toml * Update pyproject.toml --- integrations/qdrant/pyproject.toml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index bf4387f2f..f8209e0c9 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -11,7 +11,8 @@ requires-python = ">=3.7" license = "Apache-2.0" keywords = [] authors = [ - { name = "deepset GmbH", email = "info@deepset.ai" }, + { name = "Kacper Ɓukawski", email = "kacper.lukawski@qdrant.com" }, + { name = "Anush Shetty", email = "anush.shetty@qdrant.com" }, ] classifiers = [ "Development Status :: 4 - Beta", @@ -30,9 +31,9 @@ dependencies = [ ] [project.urls] +Source = "https://github.com/deepset-ai/haystack-core-integrations" Documentation = "https://github.com/deepset-ai/haystack-core-integrations/blob/main/integrations/qdrant/README.md" Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" -Source = "https://github.com/deepset-ai/haystack-core-integrations" [tool.hatch.version] path = "src/qdrant_haystack/__about__.py" @@ -167,4 +168,4 @@ module = [ "numpy", "grpc" ] -ignore_missing_imports = true \ No newline at end of file +ignore_missing_imports = true From 538bf9b4eee90b988d93d3910e41409d5c0e3872 Mon Sep 17 00:00:00 2001 From: Anush Date: Fri, 15 Dec 2023 18:42:18 +0530 Subject: [PATCH 6/7] Delete integrations/qdrant/src/qdrant_haystack/utils.py --- integrations/qdrant/src/qdrant_haystack/utils.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 integrations/qdrant/src/qdrant_haystack/utils.py diff --git a/integrations/qdrant/src/qdrant_haystack/utils.py b/integrations/qdrant/src/qdrant_haystack/utils.py deleted file 
mode 100644 index e69de29bb..000000000 From 726accbab63c78e13c2dcc492a1d4de9e3cb7816 Mon Sep 17 00:00:00 2001 From: Julian Risch Date: Fri, 15 Dec 2023 14:22:37 +0100 Subject: [PATCH 7/7] Rename test_dict_convertors.py to test_dict_converters.py --- .../tests/{test_dict_convertors.py => test_dict_converters.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename integrations/qdrant/tests/{test_dict_convertors.py => test_dict_converters.py} (100%) diff --git a/integrations/qdrant/tests/test_dict_convertors.py b/integrations/qdrant/tests/test_dict_converters.py similarity index 100% rename from integrations/qdrant/tests/test_dict_convertors.py rename to integrations/qdrant/tests/test_dict_converters.py