From 253bc7fc1c6e46ac7cd22f93a2608239465232c2 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 16 Oct 2023 15:22:02 +0200 Subject: [PATCH 01/21] Add elastichsearch-haystack package --- document_stores/elasticsearch/.gitignore | 163 +++++++++ document_stores/elasticsearch/LICENSE | 201 +++++++++++ document_stores/elasticsearch/README.md | 32 ++ .../elasticsearch/docker-compose.yml | 15 + document_stores/elasticsearch/pyproject.toml | 178 ++++++++++ .../src/elasticsearch_haystack/__about__.py | 4 + .../src/elasticsearch_haystack/__init__.py | 6 + .../elasticsearch_haystack/bm25_retriever.py | 51 +++ .../elasticsearch_haystack/document_store.py | 321 ++++++++++++++++++ .../src/elasticsearch_haystack/filters.py | 136 ++++++++ .../elasticsearch/tests/__init__.py | 3 + .../tests/test_bm25_retriever.py | 74 ++++ .../tests/test_document_store.py | 188 ++++++++++ .../elasticsearch/tests/test_filters.py | 169 +++++++++ 14 files changed, 1541 insertions(+) create mode 100644 document_stores/elasticsearch/.gitignore create mode 100644 document_stores/elasticsearch/LICENSE create mode 100644 document_stores/elasticsearch/README.md create mode 100644 document_stores/elasticsearch/docker-compose.yml create mode 100644 document_stores/elasticsearch/pyproject.toml create mode 100644 document_stores/elasticsearch/src/elasticsearch_haystack/__about__.py create mode 100644 document_stores/elasticsearch/src/elasticsearch_haystack/__init__.py create mode 100644 document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py create mode 100644 document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py create mode 100644 document_stores/elasticsearch/src/elasticsearch_haystack/filters.py create mode 100644 document_stores/elasticsearch/tests/__init__.py create mode 100644 document_stores/elasticsearch/tests/test_bm25_retriever.py create mode 100644 document_stores/elasticsearch/tests/test_document_store.py create mode 100644 document_stores/elasticsearch/tests/test_filters.py diff --git a/document_stores/elasticsearch/.gitignore b/document_stores/elasticsearch/.gitignore new file mode 100644 index 000000000..d1c340c1f --- /dev/null +++ b/document_stores/elasticsearch/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# VS Code +.vscode diff --git a/document_stores/elasticsearch/LICENSE b/document_stores/elasticsearch/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/document_stores/elasticsearch/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/document_stores/elasticsearch/README.md b/document_stores/elasticsearch/README.md new file mode 100644 index 000000000..7e70ad6e4 --- /dev/null +++ b/document_stores/elasticsearch/README.md @@ -0,0 +1,32 @@ +[![test](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/document_stores_elasticsearch.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/document_stores_elasticsearch.yml) + +[![PyPI - Version](https://img.shields.io/pypi/v/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) +[![PyPI - Python Version](https://img.shields.io/pypi/pyversions/elasticsearch-haystack.svg)](https://pypi.org/project/elasticsearch-haystack) + +# Elasticsearch Document Store + +Document Store for Haystack 2.x, supports ElasticSearch 8. + +## Installation + +```console +pip install elasticsearch-haystack +``` + +## Testing + +To run tests first start a Docker container running ElasticSearch. We provide a utility `docker-compose.yml` for that: + +```console +docker-compose up +``` + +Then run tests: + +```console +hatch run test +``` + +## License + +`elasticsearch-haystack` is distributed under the terms of the [Apache-2.0](https://spdx.org/licenses/Apache-2.0.html) license. diff --git a/document_stores/elasticsearch/docker-compose.yml b/document_stores/elasticsearch/docker-compose.yml new file mode 100644 index 000000000..6d21941b7 --- /dev/null +++ b/document_stores/elasticsearch/docker-compose.yml @@ -0,0 +1,15 @@ +services: + elasticsearch: + image: "docker.elastic.co/elasticsearch/elasticsearch:8.10.0" + ports: + - 9200:9200 + restart: on-failure + environment: + - discovery.type=single-node + - xpack.security.enabled=false + - "ES_JAVA_OPTS=-Xms1024m -Xmx1024m" + healthcheck: + test: curl --fail http://localhost:9200/_cat/health || exit 1 + interval: 10s + timeout: 1s + retries: 10 \ No newline at end of file diff --git a/document_stores/elasticsearch/pyproject.toml b/document_stores/elasticsearch/pyproject.toml new file mode 100644 index 000000000..8da7e2f0d --- /dev/null +++ b/document_stores/elasticsearch/pyproject.toml @@ -0,0 +1,178 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "elasticsearch-haystack" +dynamic = ["version"] +description = 'Haystack 2.x Document Store for ElasticSearch' +readme = "README.md" +requires-python = ">=3.8" +license = "Apache-2.0" +keywords = [] +authors = [ + { name = "Silvano Cerza", email = "silvanocerza@gmail.com" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + # we distribute the preview version of Haystack 2.0 under the package "haystack-ai" + "haystack-ai", + "elasticsearch>=8,<9" +] + +[project.urls] +Documentation = "https://github.com/deepset-ai/elasticsearch-haystack#readme" +Issues = "https://github.com/deepset-ai/elasticsearch-haystack/issues" +Source = "https://github.com/deepset-ai/elasticsearch-haystack" + +[tool.hatch.version] +path = "src/elasticsearch_haystack/__about__.py" + +[tool.hatch.envs.default] +dependencies = [ + "coverage[toml]>=6.5", + "pytest", + "pytest-xdist", +] +[tool.hatch.envs.default.scripts] +test = "pytest 
{args:tests}" +test-cov = "coverage run -m pytest {args:tests}" +cov-report = [ + "- coverage combine", + "coverage report", +] +cov = [ + "test-cov", + "cov-report", +] + +[[tool.hatch.envs.all.matrix]] +python = ["3.8", "3.9", "3.10", "3.11"] + +[tool.hatch.envs.lint] +detached = true +dependencies = [ + "black>=23.1.0", + "mypy>=1.0.0", + "ruff>=0.0.243", +] +[tool.hatch.envs.lint.scripts] +typing = "mypy --install-types --non-interactive {args:src/elasticsearch_haystack tests}" +style = [ + "ruff {args:.}", + "black --check --diff {args:.}", +] +fmt = [ + "black {args:.}", + "ruff --fix {args:.}", + "style", +] +all = [ + "style", + "typing", +] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.black] +target-version = ["py37"] +line-length = 120 +skip-string-normalization = true + +[tool.ruff] +target-version = "py37" +line-length = 120 +select = [ + "A", + "ARG", + "B", + "C", + "DTZ", + "E", + "EM", + "F", + "FBT", + "I", + "ICN", + "ISC", + "N", + "PLC", + "PLE", + "PLR", + "PLW", + "Q", + "RUF", + "S", + "T", + "TID", + "UP", + "W", + "YTT", +] +ignore = [ + # Allow non-abstract empty methods in abstract base classes + "B027", + # Allow boolean positional values in function calls, like `dict.get(... True)` + "FBT003", + # Ignore checks for possible passwords + "S105", "S106", "S107", + # Ignore complexity + "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", +] +unfixable = [ + # Don't touch unused imports + "F401", +] + +[tool.ruff.isort] +known-first-party = ["elasticsearch_haystack"] + +[tool.ruff.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.per-file-ignores] +# Tests can use magic values, assertions, and relative imports +"tests/**/*" = ["PLR2004", "S101", "TID252"] + +[tool.coverage.run] +source_pkgs = ["elasticsearch_haystack", "tests"] +branch = true +parallel = true +omit = [ + "src/elasticsearch_haystack/__about__.py", +] + +[tool.coverage.paths] +elasticsearch_haystack = ["src/elasticsearch_haystack", "*/elasticsearch-haystack/src/elasticsearch_haystack"] +tests = ["tests", "*/elasticsearch-haystack/tests"] + +[tool.coverage.report] +exclude_lines = [ + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", +] + +[tool.pytest.ini_options] +minversion = "6.0" +markers = [ + "unit: unit tests", + "integration: integration tests" +] + +[[tool.mypy.overrides]] +module = [ + "haystack.*", + "pytest.*" +] +ignore_missing_imports = true diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/__about__.py b/document_stores/elasticsearch/src/elasticsearch_haystack/__about__.py new file mode 100644 index 000000000..f3717f266 --- /dev/null +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/__about__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2023-present Silvano Cerza +# +# SPDX-License-Identifier: Apache-2.0 +__version__ = "0.0.1" diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/__init__.py b/document_stores/elasticsearch/src/elasticsearch_haystack/__init__.py new file mode 100644 index 000000000..af32a762d --- /dev/null +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/__init__.py @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: 2023-present Silvano Cerza +# +# SPDX-License-Identifier: Apache-2.0 +from elasticsearch_haystack.document_store import ElasticsearchDocumentStore + +__all__ = ["ElasticsearchDocumentStore"] diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py b/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py 
new file mode 100644
index 000000000..485ce2a15
--- /dev/null
+++ b/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py
@@ -0,0 +1,51 @@
+# SPDX-FileCopyrightText: 2023-present Silvano Cerza
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any, Dict, List, Optional
+
+from haystack.preview import component, default_from_dict, default_to_dict
+from haystack.preview.dataclasses import Document
+
+from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
+
+
+@component
+class ElasticsearchBM25Retriever:
+    def __init__(
+        self,
+        *,
+        document_store: ElasticsearchDocumentStore,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+        scale_score: bool = True,
+    ):
+        if not isinstance(document_store, ElasticsearchDocumentStore):
+            msg = "document_store must be an instance of ElasticsearchDocumentStore"
+            raise ValueError(msg)
+
+        self._document_store = document_store
+        self._filters = filters or {}
+        self._top_k = top_k
+        self._scale_score = scale_score
+
+    def to_dict(self) -> Dict[str, Any]:
+        return default_to_dict(
+            self,
+            filters=self._filters,
+            top_k=self._top_k,
+            scale_score=self._scale_score,
+            document_store=self._document_store.to_dict(),
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchBM25Retriever":
+        data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict(
+            data["init_parameters"]["document_store"]
+        )
+        return default_from_dict(cls, data)
+
+    @component.output_types(documents=List[Document])
+    def run(self, query: str):
+        docs = self._document_store._bm25_retrieval(
+            query=query, filters=self._filters, top_k=self._top_k, scale_score=self._scale_score
+        )
+        return {"documents": docs}
diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py
new file mode 100644
index 000000000..16659d465
--- /dev/null
+++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py
@@ -0,0 +1,321 @@
+# SPDX-FileCopyrightText: 2023-present Silvano Cerza
+#
+# SPDX-License-Identifier: Apache-2.0
+import json
+import logging
+from typing import Any, Dict, List, Mapping, Optional, Union
+
+import numpy as np
+from elastic_transport import NodeConfig
+from elasticsearch import Elasticsearch, helpers
+from haystack.preview import default_from_dict, default_to_dict
+from haystack.preview.dataclasses import Document
+from haystack.preview.document_stores.decorator import document_store
+from haystack.preview.document_stores.errors import DuplicateDocumentError
+from haystack.preview.document_stores.protocols import DuplicatePolicy
+from pandas import DataFrame
+
+from elasticsearch_haystack.filters import _normalize_filters
+
+logger = logging.getLogger(__name__)
+
+Hosts = Union[str, List[Union[str, Mapping[str, Union[str, int]], NodeConfig]]]
+
+
+@document_store
+class ElasticsearchDocumentStore:
+    def __init__(self, *, hosts: Optional[Hosts] = None, index: str = "default", **kwargs):
+        """
+        Creates a new ElasticsearchDocumentStore instance.
+
+        For more information on connection parameters, see the official Elasticsearch documentation:
+        https://www.elastic.co/guide/en/elasticsearch/client/python-api/current/connecting.html
+
+        For the full list of supported kwargs, see the official Elasticsearch reference:
+        https://elasticsearch-py.readthedocs.io/en/stable/api.html#module-elasticsearch
+
+        :param hosts: List of hosts running the Elasticsearch client.
Defaults to None + :param index: Name of index in Elasticsearch, if it doesn't exist it will be created. Defaults to "default" + :param **kwargs: Optional arguments that ``Elasticsearch`` takes. + """ + self._hosts = hosts + self._client = Elasticsearch(hosts, **kwargs) + self._index = index + self._kwargs = kwargs + + # Check client connection, this will raise if not connected + self._client.info() + + # Create the index if it doesn't exist + if not self._client.indices.exists(index=index): + self._client.indices.create(index=index) + + def to_dict(self) -> Dict[str, Any]: + # This is not the best solution to serialise this class but is the fastest to implement. + # Not all kwargs types can be serialised to text so this can fail. We must serialise each + # type explicitly to handle this properly. + return default_to_dict( + self, + hosts=self._hosts, + index=self._index, + **self._kwargs, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchDocumentStore": + return default_from_dict(cls, data) + + def count_documents(self) -> int: + """ + Returns how many documents are present in the document store. + """ + return self._client.count(index=self._index)["count"] + + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: + """ + Returns the documents that match the filters provided. + + Filters are defined as nested dictionaries. The keys of the dictionaries can be a logical operator (`"$and"`, + `"$or"`, `"$not"`), a comparison operator (`"$eq"`, `$ne`, `"$in"`, `$nin`, `"$gt"`, `"$gte"`, `"$lt"`, + `"$lte"`) or a metadata field name. + + Logical operator keys take a dictionary of metadata field names and/or logical operators as value. Metadata + field names take a dictionary of comparison operators as value. Comparison operator keys take a single value or + (in case of `"$in"`) a list of values as value. If no logical operator is provided, `"$and"` is used as default + operation. If no comparison operator is provided, `"$eq"` (or `"$in"` if the comparison value is a list) is used + as default operation. + + Example: + + ```python + filters = { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": {"$in": ["economy", "politics"]}, + "publisher": {"$eq": "nytimes"} + } + } + } + # or simpler using default operators + filters = { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": { + "genre": ["economy", "politics"], + "publisher": "nytimes" + } + } + ``` + + To use the same logical operator multiple times on the same level, logical operators can take a list of + dictionaries as value. + + Example: + + ```python + filters = { + "$or": [ + { + "$and": { + "Type": "News Paper", + "Date": { + "$lt": "2019-01-01" + } + } + }, + { + "$and": { + "Type": "Blog Post", + "Date": { + "$gte": "2019-01-01" + } + } + } + ] + } + ``` + + :param filters: the filters to apply to the document list. + :return: a list of Documents that match the given filters. + """ + query = {"bool": {"filter": _normalize_filters(filters)}} if filters else None + + res = self._client.search( + index=self._index, + query=query, + ) + + return [self._deserialize_document(hit) for hit in res["hits"]["hits"]] + + def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: + """ + Writes (or overwrites) documents into the store. + + :param documents: a list of documents. 
+        :param policy: documents with the same ID count as duplicates. When duplicates are met,
+            the store can:
+             - skip: keep the existing document and ignore the new one.
+             - overwrite: remove the old document and write the new one.
+             - fail: an error is raised.
+        :raises DuplicateDocumentError: Exception raised on duplicate documents if `policy=DuplicatePolicy.FAIL`
+        :return: None
+        """
+        if len(documents) > 0:
+            if not isinstance(documents[0], Document):
+                msg = "param 'documents' must contain a list of objects of type Document"
+                raise ValueError(msg)
+
+        action = "index" if policy == DuplicatePolicy.OVERWRITE else "create"
+        _, errors = helpers.bulk(
+            client=self._client,
+            actions=(
+                {"_op_type": action, "_id": doc.id, "_source": self._serialize_document(doc)} for doc in documents
+            ),
+            refresh="wait_for",
+            index=self._index,
+            raise_on_error=False,
+        )
+        if errors and policy == DuplicatePolicy.FAIL:
+            # TODO: Handle errors in a better way: we're assuming that all errors
+            # are related to duplicate documents, but that could very well be wrong.
+
+            # mypy complains that `errors` could be either `int` or a `list` of `dict`s.
+            # Since the type depends on the parameters passed to `helpers.bulk()` we know
+            # for sure that it will be a `list`.
+            ids = ", ".join(e["create"]["_id"] for e in errors)  # type: ignore[union-attr]
+            msg = f"IDs '{ids}' already exist in the document store."
+            raise DuplicateDocumentError(msg)
+
+    def _deserialize_document(self, hit: Dict[str, Any]) -> Document:
+        """
+        Creates a Document from the search hit provided.
+        This is mostly useful in self.filter_documents().
+        """
+        data = hit["_source"]
+
+        if "highlight" in hit:
+            data["metadata"]["highlighted"] = hit["highlight"]
+        data["score"] = hit["_score"]
+
+        if array := data["array"]:
+            data["array"] = np.asarray(array, dtype=np.float32)
+        if dataframe := data["dataframe"]:
+            data["dataframe"] = DataFrame.from_dict(json.loads(dataframe))
+        if embedding := data["embedding"]:
+            data["embedding"] = np.asarray(embedding, dtype=np.float32)
+
+        # We can't use Document.from_dict() as the data dictionary contains
+        # all the metadata fields
+        return Document(
+            id=data["id"],
+            text=data["text"],
+            array=data["array"],
+            dataframe=data["dataframe"],
+            blob=data["blob"],
+            mime_type=data["mime_type"],
+            metadata=data["metadata"],
+            id_hash_keys=data["id_hash_keys"],
+            score=data["score"],
+            embedding=data["embedding"],
+        )
+
+    def _serialize_document(self, doc: Document) -> Dict[str, Any]:
+        """
+        Serializes Document to a dictionary handling conversion of Pandas' dataframe
+        and NumPy arrays if present.
+        """
+        # We don't use doc.flatten() because we want to keep the metadata field,
+        # as it makes it easier to recreate the Document object when calling
+        # self.filter_documents().
+        # Otherwise we'd have to filter out the fields that are not part of the
+        # Document dataclass and keep them as metadata. This is faster and easier.
+        res = {**doc.to_dict(), **doc.metadata}
+        if res["array"] is not None:
+            res["array"] = res["array"].tolist()
+        if res["dataframe"] is not None:
+            # Convert dataframe to a json string
+            res["dataframe"] = res["dataframe"].to_json()
+        if res["embedding"] is not None:
+            res["embedding"] = res["embedding"].tolist()
+        return res
+
+    def delete_documents(self, document_ids: List[str]) -> None:
+        """
+        Deletes all documents with an ID in `document_ids` from the document store.
+
+        :param document_ids: the IDs of the documents to delete
+        """
+        helpers.bulk(
+            client=self._client,
+            actions=({"_op_type": "delete", "_id": id_} for id_ in document_ids),
+            refresh="wait_for",
+            index=self._index,
+            raise_on_error=False,
+        )
+
+    def _bm25_retrieval(
+        self,
+        query: str,
+        *,
+        filters: Optional[Dict[str, Any]] = None,
+        top_k: int = 10,
+        scale_score: bool = True,
+    ) -> List[Document]:
+        """
+        Elasticsearch uses the BM25 search algorithm by default.
+        Even though this method is called `bm25_retrieval`, it searches for `query`
+        using whichever search algorithm `_client` was configured with.
+
+        This method is not meant to be part of the public interface of
+        `ElasticsearchDocumentStore`, nor to be called directly.
+        `ElasticsearchBM25Retriever` uses it internally and is the public interface for it.
+
+        `query` must be a non-empty string, otherwise a `ValueError` will be raised.
+
+        :param query: String to search in saved Documents' text.
+        :param filters: Filters applied to the retrieved Documents, for more info
+            see `ElasticsearchDocumentStore.filter_documents`, defaults to None
+        :param top_k: Maximum number of Documents to return, defaults to 10
+        :param scale_score: If `True`, scales the Documents' scores between 0 and 1, defaults to True
+        :raises ValueError: If `query` is an empty string
+        :return: List of Documents that match `query`
+        """
+
+        if not query:
+            msg = "query must be a non-empty string"
+            raise ValueError(msg)
+
+        body: Dict[str, Any] = {
+            "size": top_k,
+            "query": {
+                "bool": {
+                    "must": [
+                        {
+                            "multi_match": {
+                                "query": query,
+                                "type": "most_fields",
+                                "operator": "AND",
+                            }
+                        }
+                    ]
+                }
+            },
+        }
+
+        if filters:
+            body["query"]["bool"]["filter"] = _normalize_filters(filters)
+
+        res = self._client.search(index=self._index, **body)
+
+        docs = []
+        for hit in res["hits"]["hits"]:
+            if scale_score:
+                # Squash the unbounded BM25 score into the 0-1 range with a sigmoid;
+                # dividing by 8 flattens the curve so typical scores don't saturate at 1.
+                hit["_score"] = float(1 / (1 + np.exp(-np.asarray(hit["_score"] / 8))))
+            docs.append(self._deserialize_document(hit))
+        return docs
diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py
new file mode 100644
index 000000000..31f111e07
--- /dev/null
+++ b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py
@@ -0,0 +1,136 @@
+from typing import Any, Dict, List, Union
+
+import numpy as np
+from haystack.preview.errors import FilterError
+from pandas import DataFrame
+
+
+def _normalize_filters(filters: Union[List[Dict], Dict], logical_condition="") -> Dict[str, Any]:
+    """
+    Converts Haystack filters into Elasticsearch-compatible filters.
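+
+    A minimal illustration, mirroring a case exercised in tests/test_filters.py:
+    a single leaf condition is returned unwrapped, without an enclosing `bool` query.
+
+    ```python
+    _normalize_filters({"text": "A Foo Document 1"})
+    # -> {"match": {"text": "A Foo Document 1"}}
+    ```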
+ """ + if not isinstance(filters, dict) and not isinstance(filters, list): + msg = "Filters must be either a dictionary or a list" + raise FilterError(msg) + conditions = [] + if isinstance(filters, dict): + filters = [filters] + for filter_ in filters: + for operator, value in filter_.items(): + if operator in ["$not", "$and", "$or"]: + # Logical operators + conditions.append(_normalize_filters(value, operator)) + else: + # Comparison operators + conditions.extend(_parse_comparison(operator, value)) + + if len(conditions) == 1: + return conditions[0] + + conditions = _normalize_ranges(conditions) + + if logical_condition == "$not": + return {"bool": {"must_not": conditions}} + elif logical_condition == "$or": + return {"bool": {"should": conditions}} + + # If no logical condition is specified we default to "$and" + return {"bool": {"must": conditions}} + + +def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) -> List: + result: List[Dict[str, Any]] = [] + if isinstance(comparison, dict): + for comparator, val in comparison.items(): + if comparator == "$eq": + if isinstance(val, list): + result.append( + { + "terms_set": { + field: { + "terms": val, + "minimum_should_match_script": { + "source": f"Math.max(params.num_terms, doc['{field}'].size())" + }, + } + } + } + ) + result.append({"term": {field: val}}) + elif comparator == "$ne": + if isinstance(val, list): + msg = f"{field}'s value can't be a list when using '{comparator}' comparator" + raise FilterError(msg) + result.append({"bool": {"must_not": {"term": {field: val}}}}) + elif comparator == "$in": + if not isinstance(val, list): + msg = f"{field}'s value must be a list when using '{comparator}' comparator" + raise FilterError(msg) + result.append({"terms": {field: val}}) + elif comparator == "$nin": + if not isinstance(val, list): + msg = f"{field}'s value must be a list when using '{comparator}' comparator" + raise FilterError(msg) + result.append({"bool": {"must_not": {"terms": {field: val}}}}) + elif comparator in ["$gt", "$gte", "$lt", "$lte"]: + if isinstance(val, list): + msg = f"{field}'s value can't be a list when using '{comparator}' comparator" + raise FilterError(msg) + result.append({"range": {field: {comparator[1:]: val}}}) + elif comparator in ["$not", "$or"]: + result.append(_normalize_filters(val, comparator)) + elif comparator == "$and" and isinstance(val, list): + # We're assuming there are no duplicate items in the list + flat_filters = {k: v for d in val for k, v in d.items()} + result.extend(_parse_comparison(field, flat_filters)) + elif comparator == "$and": + result.append(_normalize_filters({field: val}, comparator)) + else: + msg = f"Unknown comparator '{comparator}'" + raise FilterError(msg) + elif isinstance(comparison, list): + result.append({"terms": {field: comparison}}) + elif isinstance(comparison, np.ndarray): + result.append({"terms": {field: comparison.tolist()}}) + elif isinstance(comparison, DataFrame): + # We're saving dataframes as json strings so we compare them as such + result.append({"match": {field: comparison.to_json()}}) + elif isinstance(comparison, str): + # We can't use "term" for text fields as ElasticSearch changes the value of text. 
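+        # A "text" field is analyzed (tokenized and lowercased) at indexing time, so an
+        # exact "term" lookup against the original string would usually match nothing.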
+ # More info here: + # https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-term-query.html#query-dsl-term-query + result.append({"match": {field: comparison}}) + else: + result.append({"term": {field: comparison}}) + return result + + +def _normalize_ranges(conditions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Merges range conditions acting on a same field. + + Example usage: + + ```python + conditions = [ + {"range": {"date": {"lt": "2021-01-01"}}}, + {"range": {"date": {"gte": "2015-01-01"}}}, + ] + conditions = _normalize_ranges(conditions) + assert conditions == [ + {"range": {"date": {"lt": "2021-01-01", "gte": "2015-01-01"}}}, + ] + ``` + """ + range_conditions = [next(iter(c["range"].items())) for c in conditions if "range" in c] + if range_conditions: + conditions = [c for c in conditions if "range" not in c] + range_conditions_dict: Dict[str, Any] = {} + for field_name, comparison in range_conditions: + if field_name not in range_conditions_dict: + range_conditions_dict[field_name] = {} + range_conditions_dict[field_name].update(comparison) + + for field_name, comparisons in range_conditions_dict.items(): + conditions.append({"range": {field_name: comparisons}}) + return conditions diff --git a/document_stores/elasticsearch/tests/__init__.py b/document_stores/elasticsearch/tests/__init__.py new file mode 100644 index 000000000..ec55bfc66 --- /dev/null +++ b/document_stores/elasticsearch/tests/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2023-present Silvano Cerza +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/document_stores/elasticsearch/tests/test_bm25_retriever.py b/document_stores/elasticsearch/tests/test_bm25_retriever.py new file mode 100644 index 000000000..530e85e3a --- /dev/null +++ b/document_stores/elasticsearch/tests/test_bm25_retriever.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: 2023-present Silvano Cerza +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import Mock, patch + +from haystack.preview.dataclasses import Document + +from elasticsearch_haystack.bm25_retriever import ElasticsearchBM25Retriever +from elasticsearch_haystack.document_store import ElasticsearchDocumentStore + + +def test_init_default(): + mock_store = Mock(spec=ElasticsearchDocumentStore) + retriever = ElasticsearchBM25Retriever(document_store=mock_store) + assert retriever._document_store == mock_store + assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._scale_score + + +@patch("elasticsearch_haystack.document_store.Elasticsearch") +def test_to_dict(_mock_elasticsearch_client): + document_store = ElasticsearchDocumentStore(hosts="some fake host") + retriever = ElasticsearchBM25Retriever(document_store=document_store) + res = retriever.to_dict() + assert res == { + "type": "ElasticsearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "ElasticsearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "scale_score": True, + }, + } + + +@patch("elasticsearch_haystack.document_store.Elasticsearch") +def test_from_dict(_mock_elasticsearch_client): + data = { + "type": "ElasticsearchBM25Retriever", + "init_parameters": { + "document_store": { + "init_parameters": {"hosts": "some fake host", "index": "default"}, + "type": "ElasticsearchDocumentStore", + }, + "filters": {}, + "top_k": 10, + "scale_score": True, + }, + } + retriever = ElasticsearchBM25Retriever.from_dict(data) + assert retriever._document_store 
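+    # Stricter than the truthiness check above: from_dict should rebuild an actual
+    # ElasticsearchDocumentStore instance (see ElasticsearchBM25Retriever.from_dict).
+    assert isinstance(retriever._document_store, ElasticsearchDocumentStore)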
+ assert retriever._filters == {} + assert retriever._top_k == 10 + assert retriever._scale_score + + +def test_run(): + mock_store = Mock(spec=ElasticsearchDocumentStore) + mock_store._bm25_retrieval.return_value = [Document(text="Test doc")] + retriever = ElasticsearchBM25Retriever(document_store=mock_store) + res = retriever.run(query="some query") + mock_store._bm25_retrieval.assert_called_once_with( + query="some query", + filters={}, + top_k=10, + scale_score=True, + ) + assert len(res) == 1 + assert len(res["documents"]) == 1 + assert res["documents"][0].text == "Test doc" diff --git a/document_stores/elasticsearch/tests/test_document_store.py b/document_stores/elasticsearch/tests/test_document_store.py new file mode 100644 index 000000000..b40e04ba5 --- /dev/null +++ b/document_stores/elasticsearch/tests/test_document_store.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: 2023-present Silvano Cerza +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import patch + +import pytest +from haystack.preview.dataclasses.document import Document +from haystack.preview.document_stores.errors import DuplicateDocumentError +from haystack.preview.document_stores.protocols import DuplicatePolicy +from haystack.preview.testing.document_store import DocumentStoreBaseTests + +from elasticsearch_haystack.document_store import ElasticsearchDocumentStore + + +class TestDocumentStore(DocumentStoreBaseTests): + """ + Common test cases will be provided by `DocumentStoreBaseTests` but + you can add more to this class. + """ + + @pytest.fixture + def docstore(self, request): + """ + This is the most basic requirement for the child class: provide + an instance of this document store so the base class can use it. + """ + hosts = ["http://localhost:9200"] + # Use a different index for each test so we can run them in parallel + index = f"{request.node.name}" + + store = ElasticsearchDocumentStore(hosts=hosts, index=index) + yield store + store._client.options(ignore_status=[400, 404]).indices.delete(index=index) + + @patch("elasticsearch_haystack.document_store.Elasticsearch") + def test_to_dict(self, _mock_elasticsearch_client): + document_store = ElasticsearchDocumentStore(hosts="some hosts") + res = document_store.to_dict() + assert res == { + "type": "ElasticsearchDocumentStore", + "init_parameters": { + "hosts": "some hosts", + "index": "default", + }, + } + + @patch("elasticsearch_haystack.document_store.Elasticsearch") + def test_from_dict(self, _mock_elasticsearch_client): + data = { + "type": "ElasticsearchDocumentStore", + "init_parameters": { + "hosts": "some hosts", + "index": "default", + }, + } + document_store = ElasticsearchDocumentStore.from_dict(data) + assert document_store._hosts == "some hosts" + assert document_store._index == "default" + + def test_bm25_retrieval(self, docstore: ElasticsearchDocumentStore): + docstore.write_documents( + [ + Document(text="Haskell is a functional programming language"), + Document(text="Lisp is a functional programming language"), + Document(text="Exilir is a functional programming language"), + Document(text="F# is a functional programming language"), + Document(text="C# is a functional programming language"), + Document(text="C++ is an object oriented programming language"), + Document(text="Dart is an object oriented programming language"), + Document(text="Go is an object oriented programming language"), + Document(text="Python is a object oriented programming language"), + Document(text="Ruby is a object oriented programming language"), + 
Document(text="PHP is a object oriented programming language"), + ] + ) + + res = docstore._bm25_retrieval("functional", top_k=3) + assert len(res) == 3 + assert "functional" in res[0].text + assert "functional" in res[1].text + assert "functional" in res[2].text + + def test_write_duplicate_fail(self, docstore: ElasticsearchDocumentStore): + """ + Verify `DuplicateDocumentError` is raised when trying to write duplicate files. + + `DocumentStoreBaseTests` declares this test but we override it since we return + a different error message that it expects. + """ + doc = Document(text="test doc") + docstore.write_documents([doc]) + with pytest.raises(DuplicateDocumentError): + docstore.write_documents(documents=[doc], policy=DuplicatePolicy.FAIL) + assert docstore.filter_documents(filters={"id": doc.id}) == [doc] + + def test_delete_not_empty(self, docstore: ElasticsearchDocumentStore): + """ + Verifies delete properly deletes specified document. + + `DocumentStoreBaseTests` declares this test but we override it since we + want `delete_documents` to be idempotent. + """ + doc = Document(text="test doc") + docstore.write_documents([doc]) + + docstore.delete_documents([doc.id]) + + res = docstore.filter_documents(filters={"id": doc.id}) + assert res == [] + + def test_delete_empty(self, docstore: ElasticsearchDocumentStore): + """ + Verifies delete doesn't raises when trying to delete a non-existing document. + + `DocumentStoreBaseTests` declares this test but we override it since we + want `delete_documents` to be idempotent. + """ + docstore.delete_documents(["test"]) + + def test_delete_not_empty_nonexisting(self, docstore: ElasticsearchDocumentStore): + """ + Verifies delete properly deletes specified document in DocumentStore containing + multiple documents. + + `DocumentStoreBaseTests` declares this test but we override it since we + want `delete_documents` to be idempotent. 
+ """ + doc = Document(text="test doc") + docstore.write_documents([doc]) + + docstore.delete_documents(["non_existing"]) + + assert docstore.filter_documents(filters={"id": doc.id}) == [doc] + + # The tests below are filters not supported by ElasticsearchDocumentStore + def test_in_filter_table(self): + pass + + def test_in_filter_embedding(self): + pass + + def test_ne_filter_table(self): + pass + + def test_ne_filter_embedding(self): + pass + + def test_nin_filter_table(self): + pass + + def test_nin_filter_embedding(self): + pass + + def test_gt_filter_non_numeric(self): + pass + + def test_gt_filter_table(self): + pass + + def test_gt_filter_embedding(self): + pass + + def test_gte_filter_non_numeric(self): + pass + + def test_gte_filter_table(self): + pass + + def test_gte_filter_embedding(self): + pass + + def test_lt_filter_non_numeric(self): + pass + + def test_lt_filter_table(self): + pass + + def test_lt_filter_embedding(self): + pass + + def test_lte_filter_non_numeric(self): + pass + + def test_lte_filter_table(self): + pass + + def test_lte_filter_embedding(self): + pass diff --git a/document_stores/elasticsearch/tests/test_filters.py b/document_stores/elasticsearch/tests/test_filters.py new file mode 100644 index 000000000..2e8a320d4 --- /dev/null +++ b/document_stores/elasticsearch/tests/test_filters.py @@ -0,0 +1,169 @@ +import pytest +from haystack.preview.errors import FilterError + +from elasticsearch_haystack.filters import _normalize_filters, _normalize_ranges + +filters_data = [ + ( + { + "$and": { + "type": {"$eq": "article"}, + "$or": {"genre": {"$in": ["economy", "politics"]}, "publisher": {"$eq": "nytimes"}}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + } + }, + { + "bool": { + "must": [ + {"term": {"type": "article"}}, + { + "bool": { + "should": [ + {"terms": {"genre": ["economy", "politics"]}}, + {"term": {"publisher": "nytimes"}}, + ] + } + }, + {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, + {"range": {"rating": {"gte": 3}}}, + ] + } + }, + ), + ( + { + "$or": [ + {"Type": "News Paper", "Date": {"$lt": "2019-01-01"}}, + {"Type": "Blog Post", "Date": {"$gte": "2019-01-01"}}, + ] + }, + { + "bool": { + "should": [ + {"match": {"Type": "News Paper"}}, + {"match": {"Type": "Blog Post"}}, + {"range": {"Date": {"gte": "2019-01-01", "lt": "2019-01-01"}}}, + ] + } + }, + ), + ( + { + "$and": { + "type": {"$eq": "article"}, + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": {"genre": {"$in": ["economy", "politics"]}, "publisher": {"$eq": "nytimes"}}, + } + }, + { + "bool": { + "must": [ + {"term": {"type": "article"}}, + { + "bool": { + "should": [ + {"terms": {"genre": ["economy", "politics"]}}, + {"term": {"publisher": "nytimes"}}, + ] + } + }, + {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, + {"range": {"rating": {"gte": 3}}}, + ] + } + }, + ), + ( + { + "type": "article", + "date": {"$gte": "2015-01-01", "$lt": "2021-01-01"}, + "rating": {"$gte": 3}, + "$or": {"genre": ["economy", "politics"], "publisher": "nytimes"}, + }, + { + "bool": { + "must": [ + {"match": {"type": "article"}}, + { + "bool": { + "should": [ + {"terms": {"genre": ["economy", "politics"]}}, + {"match": {"publisher": "nytimes"}}, + ] + } + }, + {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, + {"range": {"rating": {"gte": 3}}}, + ] + } + }, + ), + ({"text": "A Foo Document 1"}, {"match": {"text": "A Foo Document 1"}}), + ( + {"$or": {"name": {"$or": [{"$eq": "name_0"}, 
{"$eq": "name_1"}]}, "number": {"$lt": 1.0}}}, + { + "bool": { + "should": [ + {"bool": {"should": [{"match": {"$eq": "name_0"}}, {"match": {"$eq": "name_1"}}]}}, + {"range": {"number": {"lt": 1.0}}}, + ] + } + }, + ), + ( + {"$and": {"number": {"$and": {"$lte": 2, "$gte": 0}}, "name": {"$in": ["name_0", "name_1"]}}}, + { + "bool": { + "must": [ + {"bool": {"must": [{"range": {"number": {"lte": 2, "gte": 0}}}]}}, + {"terms": {"name": ["name_0", "name_1"]}}, + ] + } + }, + ), + ( + {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]}, + { + "bool": { + "must": [ + {"terms": {"name": ["name_0", "name_1"]}}, + {"range": {"number": {"lte": 2, "gte": 0}}}, + ] + } + }, + ), + ( + {"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}, + {"bool": {"must": [{"range": {"number": {"lte": 2, "gte": 0}}}]}}, + ), +] + + +@pytest.mark.parametrize("filters, expected", filters_data) +def test_normalize_filters(filters, expected): + result = _normalize_filters(filters) + assert result == expected + + +def test_normalize_filters_raises_with_malformed_filters(): + with pytest.raises(FilterError): + _normalize_filters("not a filter") + + with pytest.raises(FilterError): + _normalize_filters({"number": {"page": "100"}}) + + with pytest.raises(FilterError): + _normalize_filters({"number": {"page": {"chapter": "intro"}}}) + + +def test_normalize_ranges(): + conditions = [ + {"range": {"date": {"lt": "2021-01-01"}}}, + {"range": {"date": {"gte": "2015-01-01"}}}, + ] + conditions = _normalize_ranges(conditions) + assert conditions == [ + {"range": {"date": {"lt": "2021-01-01", "gte": "2015-01-01"}}}, + ] From 3ecff1cc0e630823fff6dafc0094d6e707386971 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 16 Oct 2023 15:22:24 +0200 Subject: [PATCH 02/21] Add workflow to run elastichsearch-haystack tests --- .../document_stores_elasticsearch.yml | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 .github/workflows/document_stores_elasticsearch.yml diff --git a/.github/workflows/document_stores_elasticsearch.yml b/.github/workflows/document_stores_elasticsearch.yml new file mode 100644 index 000000000..da425eb75 --- /dev/null +++ b/.github/workflows/document_stores_elasticsearch.yml @@ -0,0 +1,52 @@ +# This workflow comes from https://github.com/ofek/hatch-mypyc +# https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml +name: Test / ElasticsearchDocumentStore + +on: + push: + branches: + - main + pull_request: + paths: + - "document_stores/elasticsearch/**" + - ".github/workflows/document_stores_elasticsearch.yml" + +concurrency: + group: test-${{ github.head_ref }} + cancel-in-progress: true + +env: + PYTHONUNBUFFERED: "1" + FORCE_COLOR: "1" + +jobs: + run: + name: Python ${{ matrix.python-version }} on ${{ startsWith(matrix.os, 'macos-') && 'macOS' || startsWith(matrix.os, 'windows-') && 'Windows' || 'Linux' }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: ["3.8", "3.9", "3.10", "3.11"] + + steps: + - name: Support longpaths + if: matrix.os == 'windows-latest' + run: git config --system core.longpaths true + + - uses: actions/checkout@v3 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Hatch + run: pip install --upgrade hatch + + - name: Lint + if: matrix.python-version == '3.9' && runner.os == 'Linux' + run: hatch run lint:all 
+ + - name: Run tests + run: hatch run cov From af64dd670e417a242ac61289318d5d511931219e Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 16 Oct 2023 15:28:41 +0200 Subject: [PATCH 03/21] Fix missing working directory --- .github/workflows/document_stores_elasticsearch.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/document_stores_elasticsearch.yml b/.github/workflows/document_stores_elasticsearch.yml index da425eb75..d5580e091 100644 --- a/.github/workflows/document_stores_elasticsearch.yml +++ b/.github/workflows/document_stores_elasticsearch.yml @@ -45,8 +45,10 @@ jobs: run: pip install --upgrade hatch - name: Lint + working-directory: document_stores/elasticsearch if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all - name: Run tests + working-directory: document_stores/elasticsearch run: hatch run cov From 3e6d25384101812971006a190671db1d882f2df4 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 16 Oct 2023 15:33:26 +0200 Subject: [PATCH 04/21] Fix linting --- .../src/elasticsearch_haystack/document_store.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index 16659d465..65d91bce7 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -6,8 +6,10 @@ from typing import Any, Dict, List, Mapping, Optional, Union import numpy as np -from elastic_transport import NodeConfig -from elasticsearch import Elasticsearch, helpers + +# There are no import stubs for elastic_transport and elasticsearch so mypy fails +from elastic_transport import NodeConfig # type: ignore[import-not-found] +from elasticsearch import Elasticsearch, helpers # type: ignore[import-not-found] from haystack.preview import default_from_dict, default_to_dict from haystack.preview.dataclasses import Document from haystack.preview.document_stores.decorator import document_store @@ -173,7 +175,12 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D _, errors = helpers.bulk( client=self._client, actions=( - {"_op_type": action, "_id": doc.id, "_source": self._serialize_document(doc)} for doc in documents + { + "_op_type": action, + "_id": doc.id, + "_source": self._serialize_document(doc), + } + for doc in documents ), refresh="wait_for", index=self._index, From b3825f7468cf60d8074cce465d020b27ff347972 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 16 Oct 2023 15:35:09 +0200 Subject: [PATCH 05/21] Rename workflow --- .github/workflows/document_stores_elasticsearch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/document_stores_elasticsearch.yml b/.github/workflows/document_stores_elasticsearch.yml index d5580e091..8b084331c 100644 --- a/.github/workflows/document_stores_elasticsearch.yml +++ b/.github/workflows/document_stores_elasticsearch.yml @@ -1,6 +1,6 @@ # This workflow comes from https://github.com/ofek/hatch-mypyc # https://github.com/ofek/hatch-mypyc/blob/5a198c0ba8660494d02716cfc9d79ce4adfb1442/.github/workflows/test.yml -name: Test / ElasticsearchDocumentStore +name: Test / document_stores / elasticsearch on: push: From aa18cc199dc5e95b33ffcb2e9f95515774168312 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 16 Oct 2023 16:11:09 +0200 Subject: [PATCH 06/21] Run 
ElasticSearch container in CI for testing --- .github/workflows/document_stores_elasticsearch.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/document_stores_elasticsearch.yml b/.github/workflows/document_stores_elasticsearch.yml index 8b084331c..06a93e014 100644 --- a/.github/workflows/document_stores_elasticsearch.yml +++ b/.github/workflows/document_stores_elasticsearch.yml @@ -49,6 +49,10 @@ jobs: if: matrix.python-version == '3.9' && runner.os == 'Linux' run: hatch run lint:all + - name: Run ElasticSearch container + working-directory: document_stores/elasticsearch + run: docker-compose up -d + - name: Run tests working-directory: document_stores/elasticsearch run: hatch run cov From a052341015a456aea70dba50b7620461edda243a Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Mon, 16 Oct 2023 16:15:12 +0200 Subject: [PATCH 07/21] Remove windows and macos --- .github/workflows/document_stores_elasticsearch.yml | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/.github/workflows/document_stores_elasticsearch.yml b/.github/workflows/document_stores_elasticsearch.yml index 06a93e014..f8ac8acff 100644 --- a/.github/workflows/document_stores_elasticsearch.yml +++ b/.github/workflows/document_stores_elasticsearch.yml @@ -26,14 +26,10 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-latest] + os: [ubuntu-latest] python-version: ["3.8", "3.9", "3.10", "3.11"] steps: - - name: Support longpaths - if: matrix.os == 'windows-latest' - run: git config --system core.longpaths true - - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -46,7 +42,7 @@ jobs: - name: Lint working-directory: document_stores/elasticsearch - if: matrix.python-version == '3.9' && runner.os == 'Linux' + if: matrix.python-version == '3.9' run: hatch run lint:all - name: Run ElasticSearch container From 697de2549bcb71a00592ec8dc84f226bfc806ab6 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 12:50:16 +0100 Subject: [PATCH 08/21] Rework filters conversion --- .../src/elasticsearch_haystack/filters.py | 30 ++--- .../elasticsearch/tests/test_filters.py | 114 +++++++++++------- 2 files changed, 87 insertions(+), 57 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py index 31f111e07..f7d884e2a 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py @@ -1,6 +1,5 @@ from typing import Any, Dict, List, Union -import numpy as np from haystack.preview.errors import FilterError from pandas import DataFrame @@ -24,10 +23,10 @@ def _normalize_filters(filters: Union[List[Dict], Dict], logical_condition="") - # Comparison operators conditions.extend(_parse_comparison(operator, value)) - if len(conditions) == 1: - return conditions[0] - - conditions = _normalize_ranges(conditions) + if len(conditions) > 1: + conditions = _normalize_ranges(conditions) + else: + conditions = conditions[0] if logical_condition == "$not": return {"bool": {"must_not": conditions}} @@ -42,6 +41,8 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) -> result: List[Dict[str, Any]] = [] if isinstance(comparison, dict): for comparator, val in comparison.items(): + if isinstance(val, DataFrame): + val = val.to_json() if comparator == "$eq": if isinstance(val, list): result.append( @@ -59,9 +60,11 @@ def 
_parse_comparison(field: str, comparison: Union[Dict, List, str, float]) -> result.append({"term": {field: val}}) elif comparator == "$ne": if isinstance(val, list): - msg = f"{field}'s value can't be a list when using '{comparator}' comparator" - raise FilterError(msg) - result.append({"bool": {"must_not": {"term": {field: val}}}}) + result.append({"bool": {"must_not": {"terms": {field: val}}}}) + else: + result.append( + {"bool": {"must_not": {"match": {field: {"query": val, "minimum_should_match": "100%"}}}}} + ) elif comparator == "$in": if not isinstance(val, list): msg = f"{field}'s value must be a list when using '{comparator}' comparator" @@ -73,8 +76,8 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) -> raise FilterError(msg) result.append({"bool": {"must_not": {"terms": {field: val}}}}) elif comparator in ["$gt", "$gte", "$lt", "$lte"]: - if isinstance(val, list): - msg = f"{field}'s value can't be a list when using '{comparator}' comparator" + if not isinstance(val, str) and not isinstance(val, int) and not isinstance(val, float): + msg = f"{field}'s value must be 'str', 'int', 'float' types when using '{comparator}' comparator" raise FilterError(msg) result.append({"range": {field: {comparator[1:]: val}}}) elif comparator in ["$not", "$or"]: @@ -90,16 +93,13 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) -> raise FilterError(msg) elif isinstance(comparison, list): result.append({"terms": {field: comparison}}) - elif isinstance(comparison, np.ndarray): - result.append({"terms": {field: comparison.tolist()}}) elif isinstance(comparison, DataFrame): - # We're saving dataframes as json strings so we compare them as such - result.append({"match": {field: comparison.to_json()}}) + result.append({"match": {field: {"query": comparison.to_json(), "minimum_should_match": "100%"}}}) elif isinstance(comparison, str): # We can't use "term" for text fields as ElasticSearch changes the value of text. 
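+        # Text fields are analyzed at index time, so a "term" query would compare
+        # the query string against the analyzed tokens rather than the original text.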
# More info here: # https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-term-query.html#query-dsl-term-query - result.append({"match": {field: comparison}}) + result.append({"match": {field: {"query": comparison, "minimum_should_match": "100%"}}}) else: result.append({"term": {field: comparison}}) return result diff --git a/document_stores/elasticsearch/tests/test_filters.py b/document_stores/elasticsearch/tests/test_filters.py index 2e8a320d4..02e823232 100644 --- a/document_stores/elasticsearch/tests/test_filters.py +++ b/document_stores/elasticsearch/tests/test_filters.py @@ -15,19 +15,23 @@ }, { "bool": { - "must": [ - {"term": {"type": "article"}}, - { - "bool": { - "should": [ - {"terms": {"genre": ["economy", "politics"]}}, - {"term": {"publisher": "nytimes"}}, - ] - } - }, - {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, - {"range": {"rating": {"gte": 3}}}, - ] + "must": { + "bool": { + "must": [ + {"term": {"type": "article"}}, + { + "bool": { + "should": [ + {"terms": {"genre": ["economy", "politics"]}}, + {"term": {"publisher": "nytimes"}}, + ] + } + }, + {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, + {"range": {"rating": {"gte": 3}}}, + ] + } + } } }, ), @@ -40,11 +44,15 @@ }, { "bool": { - "should": [ - {"match": {"Type": "News Paper"}}, - {"match": {"Type": "Blog Post"}}, - {"range": {"Date": {"gte": "2019-01-01", "lt": "2019-01-01"}}}, - ] + "must": { + "bool": { + "should": [ + {"match": {"Type": {"query": "News Paper", "minimum_should_match": "100%"}}}, + {"match": {"Type": {"query": "Blog Post", "minimum_should_match": "100%"}}}, + {"range": {"Date": {"lt": "2019-01-01", "gte": "2019-01-01"}}}, + ] + } + } } }, ), @@ -59,19 +67,23 @@ }, { "bool": { - "must": [ - {"term": {"type": "article"}}, - { - "bool": { - "should": [ - {"terms": {"genre": ["economy", "politics"]}}, - {"term": {"publisher": "nytimes"}}, - ] - } - }, - {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, - {"range": {"rating": {"gte": 3}}}, - ] + "must": { + "bool": { + "must": [ + {"term": {"type": "article"}}, + { + "bool": { + "should": [ + {"terms": {"genre": ["economy", "politics"]}}, + {"term": {"publisher": "nytimes"}}, + ] + } + }, + {"range": {"date": {"gte": "2015-01-01", "lt": "2021-01-01"}}}, + {"range": {"rating": {"gte": 3}}}, + ] + } + } } }, ), @@ -85,12 +97,12 @@ { "bool": { "must": [ - {"match": {"type": "article"}}, + {"match": {"type": {"query": "article", "minimum_should_match": "100%"}}}, { "bool": { "should": [ {"terms": {"genre": ["economy", "politics"]}}, - {"match": {"publisher": "nytimes"}}, + {"match": {"publisher": {"query": "nytimes", "minimum_should_match": "100%"}}}, ] } }, @@ -100,15 +112,29 @@ } }, ), - ({"text": "A Foo Document 1"}, {"match": {"text": "A Foo Document 1"}}), + ( + {"text": "A Foo Document 1"}, + {"bool": {"must": {"match": {"text": {"query": "A Foo Document 1", "minimum_should_match": "100%"}}}}}, + ), ( {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}}, { "bool": { - "should": [ - {"bool": {"should": [{"match": {"$eq": "name_0"}}, {"match": {"$eq": "name_1"}}]}}, - {"range": {"number": {"lt": 1.0}}}, - ] + "must": { + "bool": { + "should": [ + { + "bool": { + "should": [ + {"match": {"$eq": {"query": "name_0", "minimum_should_match": "100%"}}}, + {"match": {"$eq": {"query": "name_1", "minimum_should_match": "100%"}}}, + ] + } + }, + {"range": {"number": {"lt": 1.0}}}, + ] + } + } } }, ), @@ -116,10 +142,14 @@ {"$and": {"number": {"$and": 
{"$lte": 2, "$gte": 0}}, "name": {"$in": ["name_0", "name_1"]}}}, { "bool": { - "must": [ - {"bool": {"must": [{"range": {"number": {"lte": 2, "gte": 0}}}]}}, - {"terms": {"name": ["name_0", "name_1"]}}, - ] + "must": { + "bool": { + "must": [ + {"bool": {"must": [{"range": {"number": {"lte": 2, "gte": 0}}}]}}, + {"terms": {"name": ["name_0", "name_1"]}}, + ] + } + } } }, ), From 35a04af5073764a3be4fc5c9723b4e05f70bf36c Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 12:54:23 +0100 Subject: [PATCH 09/21] Update Document serialization --- .../elasticsearch_haystack/document_store.py | 59 +++++-------------- 1 file changed, 16 insertions(+), 43 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index 65d91bce7..4c0a259c3 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -150,8 +150,21 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc index=self._index, query=query, ) + documents = [self._deserialize_document(hit) for hit in res["hits"]["hits"]] + total = res["hits"]["total"]["value"] - return [self._deserialize_document(hit) for hit in res["hits"]["hits"]] + from_ = len(documents) + while len(documents) < total: + res = self._client.search( + index=self._index, + query=query, + from_=from_, + ) + + documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"]) + from_ = len(documents) + + return documents def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: """ @@ -178,7 +191,7 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D { "_op_type": action, "_id": doc.id, - "_source": self._serialize_document(doc), + "_source": doc.to_dict(), } for doc in documents ), @@ -208,47 +221,7 @@ def _deserialize_document(self, hit: Dict[str, Any]) -> Document: data["metadata"]["highlighted"] = hit["highlight"] data["score"] = hit["_score"] - if array := data["array"]: - data["array"] = np.asarray(array, dtype=np.float32) - if dataframe := data["dataframe"]: - data["dataframe"] = DataFrame.from_dict(json.loads(dataframe)) - if embedding := data["embedding"]: - data["embedding"] = np.asarray(embedding, dtype=np.float32) - - # We can't use Document.from_dict() as the data dictionary contains - # all the metadata fields - return Document( - id=data["id"], - text=data["text"], - array=data["array"], - dataframe=data["dataframe"], - blob=data["blob"], - mime_type=data["mime_type"], - metadata=data["metadata"], - id_hash_keys=data["id_hash_keys"], - score=data["score"], - embedding=data["embedding"], - ) - - def _serialize_document(self, doc: Document) -> Dict[str, Any]: - """ - Serializes Document to a dictionary handling conversion of Pandas' dataframe - and NumPy arrays if present. - """ - # We don't use doc.flatten() cause we want to keep the metadata field - # as it makes it easier to recreate the Document object when calling - # self.filter_document(). - # Otherwise we'd have to filter out the fields that are not part of the - # Document dataclass and keep them as metadata. This is faster and easier. 
- res = {**doc.to_dict(), **doc.metadata} - if res["array"] is not None: - res["array"] = res["array"].tolist() - if res["dataframe"] is not None: - # Convert dataframe to a json string - res["dataframe"] = res["dataframe"].to_json() - if res["embedding"] is not None: - res["embedding"] = res["embedding"].tolist() - return res + return Document.from_dict(data) def delete_documents(self, document_ids: List[str]) -> None: """ From c3beb27a6120179bae80b95895905437fbb8de1b Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 12:54:41 +0100 Subject: [PATCH 10/21] Handle pagination of search results in filter_documents() --- .../src/elasticsearch_haystack/document_store.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index 4c0a259c3..aa300db21 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -146,13 +146,17 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc """ query = {"bool": {"filter": _normalize_filters(filters)}} if filters else None + # Get first page of the search results res = self._client.search( index=self._index, query=query, ) documents = [self._deserialize_document(hit) for hit in res["hits"]["hits"]] + + # This is the number of total documents that match the query total = res["hits"]["total"]["value"] + # Keep querying until we have all the documents from_ = len(documents) while len(documents) < total: res = self._client.search( @@ -161,7 +165,9 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc from_=from_, ) + # Add new documents to the list documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"]) + # Update the cursor from_ = len(documents) return documents From 4346f7c3904a4d2f1b71c5f8f48397f1954deeb7 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 12:55:03 +0100 Subject: [PATCH 11/21] Update tests --- .../tests/test_bm25_retriever.py | 4 +- .../tests/test_document_store.py | 697 ++++++++++++++++-- 2 files changed, 651 insertions(+), 50 deletions(-) diff --git a/document_stores/elasticsearch/tests/test_bm25_retriever.py b/document_stores/elasticsearch/tests/test_bm25_retriever.py index 530e85e3a..51dfe5140 100644 --- a/document_stores/elasticsearch/tests/test_bm25_retriever.py +++ b/document_stores/elasticsearch/tests/test_bm25_retriever.py @@ -60,7 +60,7 @@ def test_from_dict(_mock_elasticsearch_client): def test_run(): mock_store = Mock(spec=ElasticsearchDocumentStore) - mock_store._bm25_retrieval.return_value = [Document(text="Test doc")] + mock_store._bm25_retrieval.return_value = [Document(content="Test doc")] retriever = ElasticsearchBM25Retriever(document_store=mock_store) res = retriever.run(query="some query") mock_store._bm25_retrieval.assert_called_once_with( @@ -71,4 +71,4 @@ def test_run(): ) assert len(res) == 1 assert len(res["documents"]) == 1 - assert res["documents"][0].text == "Test doc" + assert res["documents"][0].content == "Test doc" diff --git a/document_stores/elasticsearch/tests/test_document_store.py b/document_stores/elasticsearch/tests/test_document_store.py index b40e04ba5..0af495017 100644 --- a/document_stores/elasticsearch/tests/test_document_store.py +++ b/document_stores/elasticsearch/tests/test_document_store.py @@ -1,6 +1,7 @@ # 
SPDX-FileCopyrightText: 2023-present Silvano Cerza # # SPDX-License-Identifier: Apache-2.0 +from typing import List from unittest.mock import patch import pytest @@ -8,6 +9,9 @@ from haystack.preview.document_stores.errors import DuplicateDocumentError from haystack.preview.document_stores.protocols import DuplicatePolicy from haystack.preview.testing.document_store import DocumentStoreBaseTests +from haystack.preview.errors import FilterError +import pandas as pd +import numpy as np from elasticsearch_haystack.document_store import ElasticsearchDocumentStore @@ -60,25 +64,25 @@ def test_from_dict(self, _mock_elasticsearch_client): def test_bm25_retrieval(self, docstore: ElasticsearchDocumentStore): docstore.write_documents( [ - Document(text="Haskell is a functional programming language"), - Document(text="Lisp is a functional programming language"), - Document(text="Exilir is a functional programming language"), - Document(text="F# is a functional programming language"), - Document(text="C# is a functional programming language"), - Document(text="C++ is an object oriented programming language"), - Document(text="Dart is an object oriented programming language"), - Document(text="Go is an object oriented programming language"), - Document(text="Python is a object oriented programming language"), - Document(text="Ruby is a object oriented programming language"), - Document(text="PHP is a object oriented programming language"), + Document(content="Haskell is a functional programming language"), + Document(content="Lisp is a functional programming language"), + Document(content="Exilir is a functional programming language"), + Document(content="F# is a functional programming language"), + Document(content="C# is a functional programming language"), + Document(content="C++ is an object oriented programming language"), + Document(content="Dart is an object oriented programming language"), + Document(content="Go is an object oriented programming language"), + Document(content="Python is a object oriented programming language"), + Document(content="Ruby is a object oriented programming language"), + Document(content="PHP is a object oriented programming language"), ] ) res = docstore._bm25_retrieval("functional", top_k=3) assert len(res) == 3 - assert "functional" in res[0].text - assert "functional" in res[1].text - assert "functional" in res[2].text + assert "functional" in res[0].content + assert "functional" in res[1].content + assert "functional" in res[2].content def test_write_duplicate_fail(self, docstore: ElasticsearchDocumentStore): """ @@ -87,7 +91,7 @@ def test_write_duplicate_fail(self, docstore: ElasticsearchDocumentStore): `DocumentStoreBaseTests` declares this test but we override it since we return a different error message that it expects. 
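+        Note that the assertion below only checks the exception type, not the
+        message itself.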
""" - doc = Document(text="test doc") + doc = Document(content="test doc") docstore.write_documents([doc]) with pytest.raises(DuplicateDocumentError): docstore.write_documents(documents=[doc], policy=DuplicatePolicy.FAIL) @@ -132,57 +136,654 @@ def test_delete_not_empty_nonexisting(self, docstore: ElasticsearchDocumentStore assert docstore.filter_documents(filters={"id": doc.id}) == [doc] - # The tests below are filters not supported by ElasticsearchDocumentStore - def test_in_filter_table(self): - pass + #### - def test_in_filter_embedding(self): - pass + def test_count_empty(self, docstore: ElasticsearchDocumentStore): + assert docstore.count_documents() == 0 - def test_ne_filter_table(self): - pass + def test_count_not_empty(self, docstore: ElasticsearchDocumentStore): + docstore.write_documents( + [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")] + ) + assert docstore.count_documents() == 3 + + def test_no_filter_empty(self, docstore: ElasticsearchDocumentStore): + assert docstore.filter_documents() == [] + assert docstore.filter_documents(filters={}) == [] + + def test_no_filter_not_empty(self, docstore: ElasticsearchDocumentStore): + docs = [Document(content="test doc")] + docstore.write_documents(docs) + assert docstore.filter_documents() == docs + assert docstore.filter_documents(filters={}) == docs + + def test_filter_simple_metadata_value(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": "100"}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + def test_filter_simple_list_single_element( + self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] + ): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": ["100"]}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + def test_filter_document_content(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"content": "A Foo Document 1"}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.content == "A Foo Document 1"]) + + def test_filter_document_dataframe(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) + assert self.contains_same_docs( + result, + [doc for doc in filterable_docs if doc.dataframe is not None and doc.dataframe.equals(pd.DataFrame([1]))], + ) - def test_ne_filter_embedding(self): - pass + def test_filter_simple_list_one_value(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": ["100"]}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]]) - def test_nin_filter_table(self): - pass + def test_filter_simple_list(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": ["100", "123"]}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if doc.meta.get("page") in 
["100", "123"]] + ) - def test_nin_filter_embedding(self): - pass + def test_incorrect_filter_name(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"non_existing_meta_field": ["whatever"]}) + assert len(result) == 0 + + def test_incorrect_filter_type(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + with pytest.raises(FilterError): + docstore.filter_documents(filters="something odd") # type: ignore + + def test_incorrect_filter_value(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": ["nope"]}) + assert len(result) == 0 + + def test_incorrect_filter_nesting(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + with pytest.raises(FilterError): + docstore.filter_documents(filters={"number": {"page": "100"}}) + + def test_deeper_incorrect_filter_nesting( + self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] + ): + docstore.write_documents(filterable_docs) + with pytest.raises(FilterError): + docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}}) + + def test_eq_filter_explicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$eq": "100"}}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + def test_eq_filter_implicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": "100"}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) + + def test_eq_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if isinstance(doc.dataframe, pd.DataFrame) and doc.dataframe.equals(pd.DataFrame([1])) + ], + ) - def test_gt_filter_non_numeric(self): - pass + def test_eq_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + embedding = [0.0] * 768 + result = docstore.filter_documents(filters={"embedding": embedding}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if embedding == doc.embedding]) + + def test_in_filter_explicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$in": ["100", "123", "n.a."]}}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] + ) - def test_gt_filter_table(self): - pass + def test_in_filter_implicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": ["100", "123", "n.a."]}) + assert self.contains_same_docs( + result, [doc for 
doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] + ) - def test_gt_filter_embedding(self): + @pytest.mark.skip(reason="Not supported") + def test_in_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): pass - def test_gte_filter_non_numeric(self): + @pytest.mark.skip(reason="Not supported") + def test_in_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): pass - def test_gte_filter_table(self): - pass + def test_ne_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$ne": "100"}}) + assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"]) + + def test_ne_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"dataframe": {"$ne": pd.DataFrame([1])}}) + assert self.contains_same_docs( + result, + [doc for doc in filterable_docs if doc.dataframe is None or not doc.dataframe.equals(pd.DataFrame([1]))], + ) - def test_gte_filter_embedding(self): - pass + def test_ne_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + embedding = [0.0] * 768 + result = docstore.filter_documents(filters={"embedding": {"$ne": embedding}}) + assert self.contains_same_docs( + result, + [doc for doc in filterable_docs if doc.embedding is None or not embedding == doc.embedding], + ) - def test_lt_filter_non_numeric(self): + @pytest.mark.skip(reason="Not supported") + def test_nin_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): pass - def test_lt_filter_table(self): + @pytest.mark.skip(reason="Not supported") + def test_nin_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): pass - def test_lt_filter_embedding(self): - pass + def test_nin_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]] + ) - def test_lte_filter_non_numeric(self): - pass + def test_gt_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$gt": 0.0}}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] > 0] + ) - def test_lte_filter_table(self): - pass + def test_gt_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$gt": "100"}}) + assert self.contains_same_docs( + result, [d for d in filterable_docs if "page" in d.meta and d.meta["page"] > "100"] + ) - def test_lte_filter_embedding(self): - pass + def test_gt_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"dataframe": {"$gt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) + 
assert result == [] + + def test_gt_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + embedding_zeros = np.zeros([768, 1]).astype(np.float32) + with pytest.raises(FilterError): + docstore.filter_documents(filters={"embedding": {"$gt": embedding_zeros}}) + + def test_gte_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$gte": -2}}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] >= -2] + ) + + def test_gte_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$gte": "100"}}) + assert self.contains_same_docs( + result, [d for d in filterable_docs if "page" in d.meta and d.meta["page"] >= "100"] + ) + + def test_gte_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"dataframe": {"$gte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) + assert result == [] + + def test_gte_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + embedding_zeros = np.zeros([768, 1]).astype(np.float32) + with pytest.raises(FilterError): + docstore.filter_documents(filters={"embedding": {"$gte": embedding_zeros}}) + + def test_lt_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$lt": 0.0}}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] < 0] + ) + + def test_lt_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$lt": "100"}}) + assert result == [] + + def test_lt_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"dataframe": {"$lt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) + assert self.contains_same_docs(result, [d for d in filterable_docs if d.dataframe is not None]) + + def test_lt_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + embedding_ones = np.ones([768, 1]).astype(np.float32) + with pytest.raises(FilterError): + docstore.filter_documents(filters={"embedding": {"$lt": embedding_ones}}) + + def test_lte_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$lte": 2.0}}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] <= 2.0] + ) + + def test_lte_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"page": {"$lte": "100"}}) + assert 
self.contains_same_docs( + result, [d for d in filterable_docs if "page" in d.meta and d.meta["page"] <= "100"] + ) + + def test_lte_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"dataframe": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) + assert self.contains_same_docs(result, [d for d in filterable_docs if d.dataframe is not None]) + + def test_lte_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + embedding_ones = np.ones([768, 1]).astype(np.float32) + with pytest.raises(FilterError): + docstore.filter_documents(filters={"embedding": {"$lte": embedding_ones}}) + + def test_filter_simple_implicit_and_with_multi_key_dict( + self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] + ): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0.0}}) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] >= 0.0 and doc.meta["number"] <= 2.0 + ], + ) + + def test_filter_simple_explicit_and_with_multikey_dict( + self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] + ): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$and": {"$gte": 0, "$lte": 2}}}) + assert self.contains_same_docs( + result, [doc for doc in filterable_docs if "number" in doc.meta and 0 <= doc.meta["number"] <= 2] + ) + + def test_filter_simple_explicit_and_with_list( + self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] + ): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 + ], + ) + + def test_filter_simple_implicit_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0}}) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 + ], + ) + + def test_filter_nested_explicit_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + filters = {"$and": {"number": {"$and": {"$lte": 2, "$gte": 0}}, "name": {"$in": ["name_0", "name_1"]}}} + result = docstore.filter_documents(filters=filters) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if ( + "number" in doc.meta + and doc.meta["number"] >= 0 + and doc.meta["number"] <= 2 + and doc.meta["name"] in ["name_0", "name_1"] + ) + ], + ) + + def test_filter_nested_implicit_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + filters_simplified = {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]} + result = docstore.filter_documents(filters=filters_simplified) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if ( + "number" in doc.meta + and doc.meta["number"] <= 2 
+ and doc.meta["number"] >= 0 + and doc.meta.get("name") in ["name_0", "name_1"] + ) + ], + ) + + def test_filter_simple_or(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + filters = {"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} + result = docstore.filter_documents(filters=filters) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if (("number" in doc.meta and doc.meta["number"] < 1) or doc.meta.get("name") in ["name_0", "name_1"]) + ], + ) + + def test_filter_nested_or(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + filters = {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}} + result = docstore.filter_documents(filters=filters) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if (doc.meta.get("name") in ["name_0", "name_1"] or ("number" in doc.meta and doc.meta["number"] < 1)) + ], + ) + + def test_filter_nested_and_or_explicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + filters_simplified = { + "$and": {"page": {"$eq": "123"}, "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} + } + result = docstore.filter_documents(filters=filters_simplified) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("page") in ["123"] + and ( + doc.meta.get("name") in ["name_0", "name_1"] + or ("number" in doc.meta and doc.meta["number"] < 1) + ) + ) + ], + ) + + def test_filter_nested_and_or_implicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + filters_simplified = { + "page": {"$eq": "123"}, + "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}, + } + result = docstore.filter_documents(filters=filters_simplified) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if ( + doc.meta.get("page") in ["123"] + and ( + doc.meta.get("name") in ["name_0", "name_1"] + or ("number" in doc.meta and doc.meta["number"] < 1) + ) + ) + ], + ) + + def test_filter_nested_or_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): + docstore.write_documents(filterable_docs) + filters_simplified = { + "$or": { + "number": {"$lt": 1}, + "$and": {"name": {"$in": ["name_0", "name_1"]}, "$not": {"chapter": {"$eq": "intro"}}}, + } + } + result = docstore.filter_documents(filters=filters_simplified) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if ( + ("number" in doc.meta and doc.meta["number"] < 1) + or ( + doc.meta.get("name") in ["name_0", "name_1"] + and ("chapter" in doc.meta and doc.meta["chapter"] != "intro") + ) + ) + ], + ) + + def test_filter_nested_multiple_identical_operators_same_level( + self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] + ): + docstore.write_documents(filterable_docs) + filters = { + "$or": [ + {"$and": {"name": {"$in": ["name_0", "name_1"]}, "page": "100"}}, + {"$and": {"chapter": {"$in": ["intro", "abstract"]}, "page": "123"}}, + ] + } + result = docstore.filter_documents(filters=filters) + assert self.contains_same_docs( + result, + [ + doc + for doc in filterable_docs + if ( + (doc.meta.get("name") in ["name_0", "name_1"] and 
doc.meta.get("page") == "100") + or (doc.meta.get("chapter") in ["intro", "abstract"] and doc.meta.get("page") == "123") + ) + ], + ) + + def test_write(self, docstore: ElasticsearchDocumentStore): + doc = Document(content="test doc") + docstore.write_documents([doc]) + assert docstore.filter_documents(filters={"id": doc.id}) == [doc] + + def test_write_duplicate_skip(self, docstore: ElasticsearchDocumentStore): + doc = Document(content="test doc") + docstore.write_documents([doc]) + docstore.write_documents(documents=[doc], policy=DuplicatePolicy.SKIP) + assert docstore.filter_documents(filters={"id": doc.id}) == [doc] + + def test_write_duplicate_overwrite(self, docstore: ElasticsearchDocumentStore): + doc1 = Document(content="test doc 1") + doc2 = Document(content="test doc 2") + object.__setattr__(doc2, "id", doc1.id) # Make two docs with different content but same ID + + docstore.write_documents([doc2]) + assert docstore.filter_documents(filters={"id": doc1.id}) == [doc2] + docstore.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE) + assert docstore.filter_documents(filters={"id": doc1.id}) == [doc1] + + def test_write_not_docs(self, docstore: ElasticsearchDocumentStore): + with pytest.raises(ValueError): + docstore.write_documents(["not a document for sure"]) # type: ignore + + def test_write_not_list(self, docstore: ElasticsearchDocumentStore): + with pytest.raises(ValueError): + docstore.write_documents("not a list actually") # type: ignore + + # The tests below are filters not supported by ElasticsearchDocumentStore + # def test_in_filter_table(self): + # pass + + # def test_in_filter_embedding(self): + # pass + + # def test_ne_filter_table(self): + # pass + + # def test_ne_filter_embedding(self): + # pass + + # def test_nin_filter_table(self): + # pass + + # def test_nin_filter_embedding(self): + # pass + + # def test_gt_filter_non_numeric(self): + # pass + + # def test_gt_filter_table(self): + # pass + + # def test_gt_filter_embedding(self): + # pass + + # def test_gte_filter_non_numeric(self): + # pass + + # def test_gte_filter_table(self): + # pass + + # def test_gte_filter_embedding(self): + # pass + + # def test_lt_filter_non_numeric(self): + # pass + + # def test_lt_filter_table(self): + # pass + + # def test_lt_filter_embedding(self): + # pass + + # def test_lte_filter_non_numeric(self): + # pass + + # def test_lte_filter_table(self): + # pass + + # def test_lte_filter_embedding(self): + # pass + + # def test_nin_filter(self, docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}}) + # expected = [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]] + # assert self.contains_same_docs(result, expected) + + # def test_filter_document_text(self, docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # result = docstore.filter_documents(filters={"text": "A Foo Document 1"}) + # assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.text == "A Foo Document 1"]) + + # def test_filter_document_dataframe(self, docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) + # expected = ( + # [doc for doc in filterable_docs if doc.dataframe is not None and doc.dataframe.equals(pd.DataFrame([1]))], + # ) + # assert self.contains_same_docs(result, expected) + + # def test_eq_filter_table(self, 
docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) + # assert self.contains_same_docs( + # result, + # [ + # doc + # for doc in filterable_docs + # if isinstance(doc.dataframe, pd.DataFrame) and doc.dataframe.equals(pd.DataFrame([1])) + # ], + # ) + + # def test_ne_filter(self, docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # result = docstore.filter_documents(filters={"page": {"$ne": "100"}}) + # assert self.contains_same_docs( + # result, + # [doc for doc in filterable_docs if doc.meta.get("page") != "100"], + # ) + + # def test_nin_filter_table(self, docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # result = docstore.filter_documents(filters={"dataframe": {"$nin": [pd.DataFrame([1]), pd.DataFrame([0])]}}) + # assert self.contains_same_docs( + # result, + # [ + # doc + # for doc in filterable_docs + # if not isinstance(doc.dataframe, pd.DataFrame) + # or (not doc.dataframe.equals(pd.DataFrame([1])) and not doc.dataframe.equals(pd.DataFrame([0]))) + # ], + # ) + + # def test_filter_nested_or(self, docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # filters = { + # "$or": { + # "name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, + # "number": {"$lt": 1.0}, + # } + # } + # result = docstore.filter_documents(filters=filters) + # assert self.contains_same_docs( + # result, + # [ + # doc + # for doc in filterable_docs + # if (doc.meta.get("name") in ["name_0", "name_1"] or ("number" in doc.meta and doc.meta["number"] < 1)) + # ], + # ) + + # def test_filter_nested_or_and(self, docstore, filterable_docs): + # docstore.write_documents(filterable_docs) + # filters_simplified = { + # "$or": { + # "number": {"$lt": 1}, + # "$and": { + # "name": {"$in": ["name_0", "name_1"]}, + # "$not": {"chapter": {"$eq": "intro"}}, + # }, + # } + # } + # result = docstore.filter_documents(filters=filters_simplified) + # assert self.contains_same_docs( + # result, + # [ + # doc + # for doc in filterable_docs + # if ( + # ("number" in doc.meta and doc.meta["number"] < 1) + # or ( + # doc.meta.get("name") in ["name_0", "name_1"] + # and ("chapter" in doc.meta and doc.meta["chapter"] != "intro") + # ) + # ) + # ], + # ) From f4c2963ca18cf9da56002ef48c0cf85c40648311 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 16:06:32 +0100 Subject: [PATCH 12/21] Handle filters conrner case --- .../src/elasticsearch_haystack/filters.py | 13 ++++++++++++- .../elasticsearch/tests/test_filters.py | 19 +++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py index f7d884e2a..66adaacf7 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py @@ -81,7 +81,18 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) -> raise FilterError(msg) result.append({"range": {field: {comparator[1:]: val}}}) elif comparator in ["$not", "$or"]: - result.append(_normalize_filters(val, comparator)) + if isinstance(val, list): + # This handles corner cases like this: + # `{"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}}` + # If we don't handle it like this we'd lose the "name" field and the + # generated query would be wrong and return unexpected results. 
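+            # For instance, the filter above is converted to:
+            # `{"bool": {"should": [{"term": {"name": "name_0"}}, {"term": {"name": "name_1"}}]}}`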
+ comparisons = [_parse_comparison(field, v)[0] for v in val] + if comparator == "$not": + result.append({"bool": {"must_not": comparisons}}) + elif comparator == "$or": + result.append({"bool": {"should": comparisons}}) + else: + result.append(_normalize_filters(val, comparator)) elif comparator == "$and" and isinstance(val, list): # We're assuming there are no duplicate items in the list flat_filters = {k: v for d in val for k, v in d.items()} diff --git a/document_stores/elasticsearch/tests/test_filters.py b/document_stores/elasticsearch/tests/test_filters.py index 02e823232..efaa168b0 100644 --- a/document_stores/elasticsearch/tests/test_filters.py +++ b/document_stores/elasticsearch/tests/test_filters.py @@ -126,8 +126,8 @@ { "bool": { "should": [ - {"match": {"$eq": {"query": "name_0", "minimum_should_match": "100%"}}}, - {"match": {"$eq": {"query": "name_1", "minimum_should_match": "100%"}}}, + {"term": {"name": "name_0"}}, + {"term": {"name": "name_1"}}, ] } }, @@ -168,6 +168,21 @@ {"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}, {"bool": {"must": [{"range": {"number": {"lte": 2, "gte": 0}}}]}}, ), + ( + {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}}, + { + "bool": { + "must": { + "bool": { + "should": [ + {"term": {"name": "name_0"}}, + {"term": {"name": "name_1"}}, + ] + } + } + } + }, + ), ] From 23330a144dd28547a9ff62371c788ebc1dfdea63 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 16:33:13 +0100 Subject: [PATCH 13/21] Cleanup unchanged tests --- .../tests/test_document_store.py | 581 ------------------ 1 file changed, 581 deletions(-) diff --git a/document_stores/elasticsearch/tests/test_document_store.py b/document_stores/elasticsearch/tests/test_document_store.py index 0af495017..99d2f7058 100644 --- a/document_stores/elasticsearch/tests/test_document_store.py +++ b/document_stores/elasticsearch/tests/test_document_store.py @@ -136,133 +136,6 @@ def test_delete_not_empty_nonexisting(self, docstore: ElasticsearchDocumentStore assert docstore.filter_documents(filters={"id": doc.id}) == [doc] - #### - - def test_count_empty(self, docstore: ElasticsearchDocumentStore): - assert docstore.count_documents() == 0 - - def test_count_not_empty(self, docstore: ElasticsearchDocumentStore): - docstore.write_documents( - [Document(content="test doc 1"), Document(content="test doc 2"), Document(content="test doc 3")] - ) - assert docstore.count_documents() == 3 - - def test_no_filter_empty(self, docstore: ElasticsearchDocumentStore): - assert docstore.filter_documents() == [] - assert docstore.filter_documents(filters={}) == [] - - def test_no_filter_not_empty(self, docstore: ElasticsearchDocumentStore): - docs = [Document(content="test doc")] - docstore.write_documents(docs) - assert docstore.filter_documents() == docs - assert docstore.filter_documents(filters={}) == docs - - def test_filter_simple_metadata_value(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": "100"}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - def test_filter_simple_list_single_element( - self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] - ): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": ["100"]}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - def 
test_filter_document_content(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"content": "A Foo Document 1"}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.content == "A Foo Document 1"]) - - def test_filter_document_dataframe(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) - assert self.contains_same_docs( - result, - [doc for doc in filterable_docs if doc.dataframe is not None and doc.dataframe.equals(pd.DataFrame([1]))], - ) - - def test_filter_simple_list_one_value(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": ["100"]}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100"]]) - - def test_filter_simple_list(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": ["100", "123"]}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] - ) - - def test_incorrect_filter_name(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"non_existing_meta_field": ["whatever"]}) - assert len(result) == 0 - - def test_incorrect_filter_type(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - with pytest.raises(FilterError): - docstore.filter_documents(filters="something odd") # type: ignore - - def test_incorrect_filter_value(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": ["nope"]}) - assert len(result) == 0 - - def test_incorrect_filter_nesting(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - with pytest.raises(FilterError): - docstore.filter_documents(filters={"number": {"page": "100"}}) - - def test_deeper_incorrect_filter_nesting( - self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] - ): - docstore.write_documents(filterable_docs) - with pytest.raises(FilterError): - docstore.filter_documents(filters={"number": {"page": {"chapter": "intro"}}}) - - def test_eq_filter_explicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": {"$eq": "100"}}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - def test_eq_filter_implicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": "100"}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") == "100"]) - - def test_eq_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - 
docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if isinstance(doc.dataframe, pd.DataFrame) and doc.dataframe.equals(pd.DataFrame([1])) - ], - ) - - def test_eq_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - embedding = [0.0] * 768 - result = docstore.filter_documents(filters={"embedding": embedding}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if embedding == doc.embedding]) - - def test_in_filter_explicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": {"$in": ["100", "123", "n.a."]}}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] - ) - - def test_in_filter_implicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": ["100", "123", "n.a."]}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if doc.meta.get("page") in ["100", "123"]] - ) - @pytest.mark.skip(reason="Not supported") def test_in_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): pass @@ -271,19 +144,6 @@ def test_in_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_ def test_in_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): pass - def test_ne_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": {"$ne": "100"}}) - assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.meta.get("page") != "100"]) - - def test_ne_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"dataframe": {"$ne": pd.DataFrame([1])}}) - assert self.contains_same_docs( - result, - [doc for doc in filterable_docs if doc.dataframe is None or not doc.dataframe.equals(pd.DataFrame([1]))], - ) - def test_ne_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) embedding = [0.0] * 768 @@ -301,20 +161,6 @@ def test_nin_filter_table(self, docstore: ElasticsearchDocumentStore, filterable def test_nin_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): pass - def test_nin_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]] - ) - - def test_gt_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$gt": 0.0}}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] > 
0] - ) - def test_gt_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$gt": "100"}}) @@ -327,19 +173,6 @@ def test_gt_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_ result = docstore.filter_documents(filters={"dataframe": {"$gt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) assert result == [] - def test_gt_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - embedding_zeros = np.zeros([768, 1]).astype(np.float32) - with pytest.raises(FilterError): - docstore.filter_documents(filters={"embedding": {"$gt": embedding_zeros}}) - - def test_gte_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$gte": -2}}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] >= -2] - ) - def test_gte_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$gte": "100"}}) @@ -352,19 +185,6 @@ def test_gte_filter_table(self, docstore: ElasticsearchDocumentStore, filterable result = docstore.filter_documents(filters={"dataframe": {"$gte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) assert result == [] - def test_gte_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - embedding_zeros = np.zeros([768, 1]).astype(np.float32) - with pytest.raises(FilterError): - docstore.filter_documents(filters={"embedding": {"$gte": embedding_zeros}}) - - def test_lt_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$lt": 0.0}}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] < 0] - ) - def test_lt_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$lt": "100"}}) @@ -375,19 +195,6 @@ def test_lt_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_ result = docstore.filter_documents(filters={"dataframe": {"$lt": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) assert self.contains_same_docs(result, [d for d in filterable_docs if d.dataframe is not None]) - def test_lt_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - embedding_ones = np.ones([768, 1]).astype(np.float32) - with pytest.raises(FilterError): - docstore.filter_documents(filters={"embedding": {"$lt": embedding_ones}}) - - def test_lte_filter(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$lte": 2.0}}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if "number" in doc.meta and doc.meta["number"] <= 2.0] - ) - def test_lte_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, 
filterable_docs: List[Document]): docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"page": {"$lte": "100"}}) @@ -399,391 +206,3 @@ def test_lte_filter_table(self, docstore: ElasticsearchDocumentStore, filterable docstore.write_documents(filterable_docs) result = docstore.filter_documents(filters={"dataframe": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}}) assert self.contains_same_docs(result, [d for d in filterable_docs if d.dataframe is not None]) - - def test_lte_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - embedding_ones = np.ones([768, 1]).astype(np.float32) - with pytest.raises(FilterError): - docstore.filter_documents(filters={"embedding": {"$lte": embedding_ones}}) - - def test_filter_simple_implicit_and_with_multi_key_dict( - self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] - ): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0.0}}) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] >= 0.0 and doc.meta["number"] <= 2.0 - ], - ) - - def test_filter_simple_explicit_and_with_multikey_dict( - self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] - ): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$and": {"$gte": 0, "$lte": 2}}}) - assert self.contains_same_docs( - result, [doc for doc in filterable_docs if "number" in doc.meta and 0 <= doc.meta["number"] <= 2] - ) - - def test_filter_simple_explicit_and_with_list( - self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] - ): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$and": [{"$lte": 2}, {"$gte": 0}]}}) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 - ], - ) - - def test_filter_simple_implicit_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - result = docstore.filter_documents(filters={"number": {"$lte": 2.0, "$gte": 0}}) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if "number" in doc.meta and doc.meta["number"] <= 2.0 and doc.meta["number"] >= 0.0 - ], - ) - - def test_filter_nested_explicit_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - filters = {"$and": {"number": {"$and": {"$lte": 2, "$gte": 0}}, "name": {"$in": ["name_0", "name_1"]}}} - result = docstore.filter_documents(filters=filters) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if ( - "number" in doc.meta - and doc.meta["number"] >= 0 - and doc.meta["number"] <= 2 - and doc.meta["name"] in ["name_0", "name_1"] - ) - ], - ) - - def test_filter_nested_implicit_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - filters_simplified = {"number": {"$lte": 2, "$gte": 0}, "name": ["name_0", "name_1"]} - result = docstore.filter_documents(filters=filters_simplified) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if ( - "number" in doc.meta - and 
doc.meta["number"] <= 2 - and doc.meta["number"] >= 0 - and doc.meta.get("name") in ["name_0", "name_1"] - ) - ], - ) - - def test_filter_simple_or(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - filters = {"$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} - result = docstore.filter_documents(filters=filters) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if (("number" in doc.meta and doc.meta["number"] < 1) or doc.meta.get("name") in ["name_0", "name_1"]) - ], - ) - - def test_filter_nested_or(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - filters = {"$or": {"name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, "number": {"$lt": 1.0}}} - result = docstore.filter_documents(filters=filters) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if (doc.meta.get("name") in ["name_0", "name_1"] or ("number" in doc.meta and doc.meta["number"] < 1)) - ], - ) - - def test_filter_nested_and_or_explicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - filters_simplified = { - "$and": {"page": {"$eq": "123"}, "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}} - } - result = docstore.filter_documents(filters=filters_simplified) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if ( - doc.meta.get("page") in ["123"] - and ( - doc.meta.get("name") in ["name_0", "name_1"] - or ("number" in doc.meta and doc.meta["number"] < 1) - ) - ) - ], - ) - - def test_filter_nested_and_or_implicit(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - filters_simplified = { - "page": {"$eq": "123"}, - "$or": {"name": {"$in": ["name_0", "name_1"]}, "number": {"$lt": 1.0}}, - } - result = docstore.filter_documents(filters=filters_simplified) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if ( - doc.meta.get("page") in ["123"] - and ( - doc.meta.get("name") in ["name_0", "name_1"] - or ("number" in doc.meta and doc.meta["number"] < 1) - ) - ) - ], - ) - - def test_filter_nested_or_and(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]): - docstore.write_documents(filterable_docs) - filters_simplified = { - "$or": { - "number": {"$lt": 1}, - "$and": {"name": {"$in": ["name_0", "name_1"]}, "$not": {"chapter": {"$eq": "intro"}}}, - } - } - result = docstore.filter_documents(filters=filters_simplified) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if ( - ("number" in doc.meta and doc.meta["number"] < 1) - or ( - doc.meta.get("name") in ["name_0", "name_1"] - and ("chapter" in doc.meta and doc.meta["chapter"] != "intro") - ) - ) - ], - ) - - def test_filter_nested_multiple_identical_operators_same_level( - self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document] - ): - docstore.write_documents(filterable_docs) - filters = { - "$or": [ - {"$and": {"name": {"$in": ["name_0", "name_1"]}, "page": "100"}}, - {"$and": {"chapter": {"$in": ["intro", "abstract"]}, "page": "123"}}, - ] - } - result = docstore.filter_documents(filters=filters) - assert self.contains_same_docs( - result, - [ - doc - for doc in filterable_docs - if ( - (doc.meta.get("name") in ["name_0", 
"name_1"] and doc.meta.get("page") == "100") - or (doc.meta.get("chapter") in ["intro", "abstract"] and doc.meta.get("page") == "123") - ) - ], - ) - - def test_write(self, docstore: ElasticsearchDocumentStore): - doc = Document(content="test doc") - docstore.write_documents([doc]) - assert docstore.filter_documents(filters={"id": doc.id}) == [doc] - - def test_write_duplicate_skip(self, docstore: ElasticsearchDocumentStore): - doc = Document(content="test doc") - docstore.write_documents([doc]) - docstore.write_documents(documents=[doc], policy=DuplicatePolicy.SKIP) - assert docstore.filter_documents(filters={"id": doc.id}) == [doc] - - def test_write_duplicate_overwrite(self, docstore: ElasticsearchDocumentStore): - doc1 = Document(content="test doc 1") - doc2 = Document(content="test doc 2") - object.__setattr__(doc2, "id", doc1.id) # Make two docs with different content but same ID - - docstore.write_documents([doc2]) - assert docstore.filter_documents(filters={"id": doc1.id}) == [doc2] - docstore.write_documents(documents=[doc1], policy=DuplicatePolicy.OVERWRITE) - assert docstore.filter_documents(filters={"id": doc1.id}) == [doc1] - - def test_write_not_docs(self, docstore: ElasticsearchDocumentStore): - with pytest.raises(ValueError): - docstore.write_documents(["not a document for sure"]) # type: ignore - - def test_write_not_list(self, docstore: ElasticsearchDocumentStore): - with pytest.raises(ValueError): - docstore.write_documents("not a list actually") # type: ignore - - # The tests below are filters not supported by ElasticsearchDocumentStore - # def test_in_filter_table(self): - # pass - - # def test_in_filter_embedding(self): - # pass - - # def test_ne_filter_table(self): - # pass - - # def test_ne_filter_embedding(self): - # pass - - # def test_nin_filter_table(self): - # pass - - # def test_nin_filter_embedding(self): - # pass - - # def test_gt_filter_non_numeric(self): - # pass - - # def test_gt_filter_table(self): - # pass - - # def test_gt_filter_embedding(self): - # pass - - # def test_gte_filter_non_numeric(self): - # pass - - # def test_gte_filter_table(self): - # pass - - # def test_gte_filter_embedding(self): - # pass - - # def test_lt_filter_non_numeric(self): - # pass - - # def test_lt_filter_table(self): - # pass - - # def test_lt_filter_embedding(self): - # pass - - # def test_lte_filter_non_numeric(self): - # pass - - # def test_lte_filter_table(self): - # pass - - # def test_lte_filter_embedding(self): - # pass - - # def test_nin_filter(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # result = docstore.filter_documents(filters={"page": {"$nin": ["100", "123", "n.a."]}}) - # expected = [doc for doc in filterable_docs if doc.meta.get("page") not in ["100", "123"]] - # assert self.contains_same_docs(result, expected) - - # def test_filter_document_text(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # result = docstore.filter_documents(filters={"text": "A Foo Document 1"}) - # assert self.contains_same_docs(result, [doc for doc in filterable_docs if doc.text == "A Foo Document 1"]) - - # def test_filter_document_dataframe(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) - # expected = ( - # [doc for doc in filterable_docs if doc.dataframe is not None and doc.dataframe.equals(pd.DataFrame([1]))], - # ) - # assert self.contains_same_docs(result, expected) - - # def 
test_eq_filter_table(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # result = docstore.filter_documents(filters={"dataframe": pd.DataFrame([1])}) - # assert self.contains_same_docs( - # result, - # [ - # doc - # for doc in filterable_docs - # if isinstance(doc.dataframe, pd.DataFrame) and doc.dataframe.equals(pd.DataFrame([1])) - # ], - # ) - - # def test_ne_filter(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # result = docstore.filter_documents(filters={"page": {"$ne": "100"}}) - # assert self.contains_same_docs( - # result, - # [doc for doc in filterable_docs if doc.meta.get("page") != "100"], - # ) - - # def test_nin_filter_table(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # result = docstore.filter_documents(filters={"dataframe": {"$nin": [pd.DataFrame([1]), pd.DataFrame([0])]}}) - # assert self.contains_same_docs( - # result, - # [ - # doc - # for doc in filterable_docs - # if not isinstance(doc.dataframe, pd.DataFrame) - # or (not doc.dataframe.equals(pd.DataFrame([1])) and not doc.dataframe.equals(pd.DataFrame([0]))) - # ], - # ) - - # def test_filter_nested_or(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # filters = { - # "$or": { - # "name": {"$or": [{"$eq": "name_0"}, {"$eq": "name_1"}]}, - # "number": {"$lt": 1.0}, - # } - # } - # result = docstore.filter_documents(filters=filters) - # assert self.contains_same_docs( - # result, - # [ - # doc - # for doc in filterable_docs - # if (doc.meta.get("name") in ["name_0", "name_1"] or ("number" in doc.meta and doc.meta["number"] < 1)) - # ], - # ) - - # def test_filter_nested_or_and(self, docstore, filterable_docs): - # docstore.write_documents(filterable_docs) - # filters_simplified = { - # "$or": { - # "number": {"$lt": 1}, - # "$and": { - # "name": {"$in": ["name_0", "name_1"]}, - # "$not": {"chapter": {"$eq": "intro"}}, - # }, - # } - # } - # result = docstore.filter_documents(filters=filters_simplified) - # assert self.contains_same_docs( - # result, - # [ - # doc - # for doc in filterable_docs - # if ( - # ("number" in doc.meta and doc.meta["number"] < 1) - # or ( - # doc.meta.get("name") in ["name_0", "name_1"] - # and ("chapter" in doc.meta and doc.meta["chapter"] != "intro") - # ) - # ) - # ], - # ) From 42cdc31237836ff092eaf61432e5a95642968aba Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 16:37:45 +0100 Subject: [PATCH 14/21] Add missing dependency --- document_stores/elasticsearch/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/document_stores/elasticsearch/pyproject.toml b/document_stores/elasticsearch/pyproject.toml index 8da7e2f0d..297e882ce 100644 --- a/document_stores/elasticsearch/pyproject.toml +++ b/document_stores/elasticsearch/pyproject.toml @@ -26,7 +26,8 @@ classifiers = [ dependencies = [ # we distribute the preview version of Haystack 2.0 under the package "haystack-ai" "haystack-ai", - "elasticsearch>=8,<9" + "elasticsearch>=8,<9", + "typing_extensions", # This is not a direct dependency, but `haystack-ai` is missing it cause `canals` is missing it ] [project.urls] From 6223a363298bf06a41e7126e769baff01f6db1f2 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Thu, 9 Nov 2023 16:46:24 +0100 Subject: [PATCH 15/21] Fix linting issues --- .../src/elasticsearch_haystack/document_store.py | 2 -- .../elasticsearch/src/elasticsearch_haystack/filters.py | 8 ++++++-- .../elasticsearch/tests/test_document_store.py | 4 
+--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index aa300db21..302ea09f6 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2023-present Silvano Cerza # # SPDX-License-Identifier: Apache-2.0 -import json import logging from typing import Any, Dict, List, Mapping, Optional, Union @@ -15,7 +14,6 @@ from haystack.preview.document_stores.decorator import document_store from haystack.preview.document_stores.errors import DuplicateDocumentError from haystack.preview.document_stores.protocols import DuplicatePolicy -from pandas import DataFrame from elasticsearch_haystack.filters import _normalize_filters diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py index 66adaacf7..78adae585 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/filters.py @@ -26,7 +26,9 @@ def _normalize_filters(filters: Union[List[Dict], Dict], logical_condition="") - if len(conditions) > 1: conditions = _normalize_ranges(conditions) else: - conditions = conditions[0] + # mypy is complaining we're assigning a dict to a list of dicts. + # We're ok with this as we're returning right after this. + conditions = conditions[0] # type: ignore[assignment] if logical_condition == "$not": return {"bool": {"must_not": conditions}} @@ -42,7 +44,9 @@ def _parse_comparison(field: str, comparison: Union[Dict, List, str, float]) -> if isinstance(comparison, dict): for comparator, val in comparison.items(): if isinstance(val, DataFrame): - val = val.to_json() + # Ruff is complaining we're overriding the loop variable `var` + # but we actually want to override it. So we ignore the error. 
+ val = val.to_json() # noqa: PLW2901 if comparator == "$eq": if isinstance(val, list): result.append( diff --git a/document_stores/elasticsearch/tests/test_document_store.py b/document_stores/elasticsearch/tests/test_document_store.py index 99d2f7058..22d613cbc 100644 --- a/document_stores/elasticsearch/tests/test_document_store.py +++ b/document_stores/elasticsearch/tests/test_document_store.py @@ -4,14 +4,12 @@ from typing import List from unittest.mock import patch +import pandas as pd import pytest from haystack.preview.dataclasses.document import Document from haystack.preview.document_stores.errors import DuplicateDocumentError from haystack.preview.document_stores.protocols import DuplicatePolicy from haystack.preview.testing.document_store import DocumentStoreBaseTests -from haystack.preview.errors import FilterError -import pandas as pd -import numpy as np from elasticsearch_haystack.document_store import ElasticsearchDocumentStore From ef13674594c9620d9d809431c5b4dcd90565a6a8 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Fri, 10 Nov 2023 10:42:55 +0100 Subject: [PATCH 16/21] Update project urls --- document_stores/elasticsearch/pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/document_stores/elasticsearch/pyproject.toml b/document_stores/elasticsearch/pyproject.toml index 297e882ce..8861f188a 100644 --- a/document_stores/elasticsearch/pyproject.toml +++ b/document_stores/elasticsearch/pyproject.toml @@ -31,9 +31,9 @@ dependencies = [ ] [project.urls] -Documentation = "https://github.com/deepset-ai/elasticsearch-haystack#readme" -Issues = "https://github.com/deepset-ai/elasticsearch-haystack/issues" -Source = "https://github.com/deepset-ai/elasticsearch-haystack" +Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/document_stores/elasticsearch#readme" +Issues = "https://github.com/deepset-ai/haystack-core-integrations/issues" +Source = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/document_stores/elasticsearch" [tool.hatch.version] path = "src/elasticsearch_haystack/__about__.py" From 6fca29a6e4e245a905446800128e73ea3682ef4f Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Fri, 10 Nov 2023 10:49:50 +0100 Subject: [PATCH 17/21] Simplify pagination handling --- .../elasticsearch_haystack/document_store.py | 23 +++++-------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index 302ea09f6..27e31e732 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -144,30 +144,19 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc """ query = {"bool": {"filter": _normalize_filters(filters)}} if filters else None - # Get first page of the search results - res = self._client.search( - index=self._index, - query=query, - ) - documents = [self._deserialize_document(hit) for hit in res["hits"]["hits"]] - - # This is the number of total documents that match the query - total = res["hits"]["total"]["value"] - - # Keep querying until we have all the documents - from_ = len(documents) - while len(documents) < total: + documents = [] + from_ = 0 + # Handle pagination + while True: res = self._client.search( index=self._index, query=query, from_=from_, ) - - # Add new documents to the list 
documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"]) - # Update the cursor from_ = len(documents) - + if from_ >= res["hits"]["total"]["value"]: + break return documents def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: From 1574a3525564cf4e406d5896ede4ad3ce4407775 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Fri, 10 Nov 2023 10:53:01 +0100 Subject: [PATCH 18/21] Better document magic numbers --- .../src/elasticsearch_haystack/document_store.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index 27e31e732..073b5f7c7 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -21,6 +21,14 @@ Hosts = Union[str, List[Union[str, Mapping[str, Union[str, int]], NodeConfig]]] +# document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to +# True (default). Scaling uses the expit function (inverse of the logit function) after applying a scaling factor +# (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method). +# Larger scaling factor decreases scaled scores. For example, an input of 10 is scaled to 0.99 with BM25_SCALING_FACTOR=2 +# but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically. Increase the default if most +# unscaled scores are larger than expected (>30) and otherwise would incorrectly all be mapped to scores ~1. +BM25_SCALING_FACTOR = 8 + @document_store class ElasticsearchDocumentStore: @@ -289,6 +297,6 @@ def _bm25_retrieval( docs = [] for hit in res["hits"]["hits"]: if scale_score: - hit["_score"] = float(1 / (1 + np.exp(-np.asarray(hit["_score"] / 8)))) + hit["_score"] = float(1 / (1 + np.exp(-np.asarray(hit["_score"] / BM25_SCALING_FACTOR)))) docs.append(self._deserialize_document(hit)) return docs From dca94a5ab6e8b8c3af03a75b73840af42582440c Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Fri, 10 Nov 2023 11:13:38 +0100 Subject: [PATCH 19/21] Raise correct error when instantiating BM25Retriever --- .../elasticsearch/src/elasticsearch_haystack/bm25_retriever.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py b/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py index 485ce2a15..d784845dc 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py @@ -20,7 +20,7 @@ def __init__( scale_score: bool = True, ): if not isinstance(document_store, ElasticsearchDocumentStore): - raise + raise ValueError("document_store must be an instance of ElasticsearchDocumentStore") self._document_store = document_store self._filters = filters or {} From b196287f8ce6aa5bbcd24dd3c021b968a96a13eb Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Fri, 10 Nov 2023 11:18:39 +0100 Subject: [PATCH 20/21] Remove bad test docstring --- document_stores/elasticsearch/tests/test_document_store.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/document_stores/elasticsearch/tests/test_document_store.py b/document_stores/elasticsearch/tests/test_document_store.py index 22d613cbc..215df160e 100644 --- 
a/document_stores/elasticsearch/tests/test_document_store.py +++ b/document_stores/elasticsearch/tests/test_document_store.py @@ -121,9 +121,6 @@ def test_delete_empty(self, docstore: ElasticsearchDocumentStore): def test_delete_not_empty_nonexisting(self, docstore: ElasticsearchDocumentStore): """ - Verifies delete properly deletes specified document in DocumentStore containing - multiple documents. - `DocumentStoreBaseTests` declares this test but we override it since we want `delete_documents` to be idempotent. """ From 83ddad64c066bd0ccc54cae26fed02ec60d38174 Mon Sep 17 00:00:00 2001 From: Silvano Cerza Date: Fri, 10 Nov 2023 11:26:44 +0100 Subject: [PATCH 21/21] Fix linting issues --- .../src/elasticsearch_haystack/bm25_retriever.py | 3 ++- .../src/elasticsearch_haystack/document_store.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py b/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py index d784845dc..feb39f42a 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/bm25_retriever.py @@ -20,7 +20,8 @@ def __init__( scale_score: bool = True, ): if not isinstance(document_store, ElasticsearchDocumentStore): - raise ValueError("document_store must be an instance of ElasticsearchDocumentStore") + msg = "document_store must be an instance of ElasticsearchDocumentStore" + raise ValueError(msg) self._document_store = document_store self._filters = filters or {} diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py index 073b5f7c7..de059f685 100644 --- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py +++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py @@ -24,9 +24,10 @@ # document scores are essentially unbounded and will be scaled to values between 0 and 1 if scale_score is set to # True (default). Scaling uses the expit function (inverse of the logit function) after applying a scaling factor # (e.g., BM25_SCALING_FACTOR for the bm25_retrieval method). -# Larger scaling factor decreases scaled scores. For example, an input of 10 is scaled to 0.99 with BM25_SCALING_FACTOR=2 -# but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically. Increase the default if most -# unscaled scores are larger than expected (>30) and otherwise would incorrectly all be mapped to scores ~1. +# Larger scaling factor decreases scaled scores. For example, an input of 10 is scaled to 0.99 with +# BM25_SCALING_FACTOR=2 but to 0.78 with BM25_SCALING_FACTOR=8 (default). The defaults were chosen empirically. +# Increase the default if most unscaled scores are larger than expected (>30) and otherwise would incorrectly +# all be mapped to scores ~1. BM25_SCALING_FACTOR = 8 @@ -152,7 +153,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc """ query = {"bool": {"filter": _normalize_filters(filters)}} if filters else None - documents = [] + documents: List[Document] = [] from_ = 0 # Handle pagination while True: