From 7d2b824ad27be1514a38a111d2d2a2480e12eaf8 Mon Sep 17 00:00:00 2001
From: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Date: Thu, 16 Nov 2023 18:05:42 +0100
Subject: [PATCH] Elasticsearch Document store - embedding retrieval (#52)

* set scale_score default to False
* unrelated: replace text w content
* first implementation
* test
* fix some tests
* make tests more robust; skip unsupported ones
* rm unsupported test
* ignore import-not-found
* first chunk addressing PR feedback
* improve tests
---
 .../elasticsearch/docker-compose.yml          |   2 +-
 .../elasticsearch_haystack/document_store.py  | 112 +++++++++++++++---
 .../tests/test_bm25_retriever.py              |   6 +-
 .../tests/test_document_store.py              |  80 +++++++++++--
 4 files changed, 171 insertions(+), 29 deletions(-)

diff --git a/document_stores/elasticsearch/docker-compose.yml b/document_stores/elasticsearch/docker-compose.yml
index 6d21941b7..66dba73f5 100644
--- a/document_stores/elasticsearch/docker-compose.yml
+++ b/document_stores/elasticsearch/docker-compose.yml
@@ -1,6 +1,6 @@
 services:
   elasticsearch:
-    image: "docker.elastic.co/elasticsearch/elasticsearch:8.10.0"
+    image: "docker.elastic.co/elasticsearch/elasticsearch:8.11.1"
     ports:
       - 9200:9200
     restart: on-failure
diff --git a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py
index 740b54180..083918d71 100644
--- a/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py
+++ b/document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: Apache-2.0
 import logging
-from typing import Any, Dict, List, Mapping, Optional, Union
+from typing import Any, Dict, List, Literal, Mapping, Optional, Union
 
 import numpy as np
 
@@ -33,7 +33,14 @@
 
 @document_store
 class ElasticsearchDocumentStore:
-    def __init__(self, *, hosts: Optional[Hosts] = None, index: str = "default", **kwargs):
+    def __init__(
+        self,
+        *,
+        hosts: Optional[Hosts] = None,
+        index: str = "default",
+        embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine",
+        **kwargs,
+    ):
         """
         Creates a new ElasticsearchDocumentStore instance.
 
@@ -45,19 +52,32 @@ def __init__(self, *, hosts: Optional[Hosts] = None, index: str = "default", **k
 
         :param hosts: List of hosts running the Elasticsearch client. Defaults to None
         :param index: Name of index in Elasticsearch, if it doesn't exist it will be created. Defaults to "default"
+        :param embedding_similarity_function: The similarity function used to compare Documents' embeddings.
+            Defaults to "cosine". This parameter only takes effect if the index does not yet exist and is created.
+            To choose the most appropriate function, look for information about your embedding model.
+            To understand how document scores are computed, see the Elasticsearch documentation:
+            https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params
         :param **kwargs: Optional arguments that ``Elasticsearch`` takes.
""" self._hosts = hosts self._client = Elasticsearch(hosts, **kwargs) self._index = index + self._embedding_similarity_function = embedding_similarity_function self._kwargs = kwargs # Check client connection, this will raise if not connected self._client.info() + # configure mapping for the embedding field + mappings = { + "properties": { + "embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function} + } + } + # Create the index if it doesn't exist if not self._client.indices.exists(index=index): - self._client.indices.create(index=index) + self._client.indices.create(index=index, mappings=mappings) def to_dict(self) -> Dict[str, Any]: # This is not the best solution to serialise this class but is the fastest to implement. @@ -67,6 +87,7 @@ def to_dict(self) -> Dict[str, Any]: self, hosts=self._hosts, index=self._index, + embedding_similarity_function=self._embedding_similarity_function, **self._kwargs, ) @@ -80,6 +101,26 @@ def count_documents(self) -> int: """ return self._client.count(index=self._index)["count"] + def _search_documents(self, **kwargs) -> List[Document]: + """ + Calls the Elasticsearch client's search method and handles pagination. + """ + + documents: List[Document] = [] + from_ = 0 + # Handle pagination + while True: + res = self._client.search( + index=self._index, + from_=from_, + **kwargs, + ) + documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"]) + from_ = len(documents) + if from_ >= res["hits"]["total"]["value"]: + break + return documents + def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]: """ Returns the documents that match the filters provided. @@ -152,20 +193,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc :return: a list of Documents that match the given filters. """ query = {"bool": {"filter": _normalize_filters(filters)}} if filters else None - - documents: List[Document] = [] - from_ = 0 - # Handle pagination - while True: - res = self._client.search( - index=self._index, - query=query, - from_=from_, - ) - documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"]) - from_ = len(documents) - if from_ >= res["hits"]["total"]["value"]: - break + documents = self._search_documents(query=query) return documents def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None: @@ -306,3 +334,53 @@ def _bm25_retrieval( hit["_score"] = float(1 / (1 + np.exp(-np.asarray(hit["_score"] / BM25_SCALING_FACTOR)))) docs.append(self._deserialize_document(hit)) return docs + + def _embedding_retrieval( + self, + query_embedding: List[float], + *, + filters: Optional[Dict[str, Any]] = None, + top_k: int = 10, + num_candidates: Optional[int] = None, + ) -> List[Document]: + """ + Retrieves documents that are most similar to the query embedding using a vector similarity metric. + It uses the Elasticsearch's Approximate k-Nearest Neighbors search algorithm. + + This method is not mean to be part of the public interface of + `ElasticsearchDocumentStore` nor called directly. + `ElasticsearchEmbeddingRetriever` uses this method directly and is the public interface for it. + + :param query_embedding: Embedding of the query. + :param filters: Filters applied to the retrieved Documents. Defaults to None. + Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned. 
+        :param top_k: Maximum number of Documents to return, defaults to 10
+        :param num_candidates: Number of approximate nearest neighbor candidates on each shard. Defaults to top_k * 10.
+            Increasing this value will improve search accuracy at the cost of slower search speeds.
+            You can read more about it in the Elasticsearch documentation:
+            https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#tune-approximate-knn-for-speed-accuracy
+        :raises ValueError: If `query_embedding` is an empty list
+        :return: List of Documents that are most similar to `query_embedding`
+        """
+
+        if not query_embedding:
+            msg = "query_embedding must be a non-empty list of floats"
+            raise ValueError(msg)
+
+        if not num_candidates:
+            num_candidates = top_k * 10
+
+        body: Dict[str, Any] = {
+            "knn": {
+                "field": "embedding",
+                "query_vector": query_embedding,
+                "k": top_k,
+                "num_candidates": num_candidates,
+            },
+        }
+
+        if filters:
+            body["knn"]["filter"] = _normalize_filters(filters)
+
+        docs = self._search_documents(**body)
+        return docs
diff --git a/document_stores/elasticsearch/tests/test_bm25_retriever.py b/document_stores/elasticsearch/tests/test_bm25_retriever.py
index 86c5aac3a..9139368d9 100644
--- a/document_stores/elasticsearch/tests/test_bm25_retriever.py
+++ b/document_stores/elasticsearch/tests/test_bm25_retriever.py
@@ -27,7 +27,11 @@ def test_to_dict(_mock_elasticsearch_client):
         "type": "ElasticsearchBM25Retriever",
         "init_parameters": {
             "document_store": {
-                "init_parameters": {"hosts": "some fake host", "index": "default"},
+                "init_parameters": {
+                    "hosts": "some fake host",
+                    "index": "default",
+                    "embedding_similarity_function": "cosine",
+                },
                 "type": "ElasticsearchDocumentStore",
             },
             "filters": {},
diff --git a/document_stores/elasticsearch/tests/test_document_store.py b/document_stores/elasticsearch/tests/test_document_store.py
index 130da8340..11443546c 100644
--- a/document_stores/elasticsearch/tests/test_document_store.py
+++ b/document_stores/elasticsearch/tests/test_document_store.py
@@ -6,6 +6,7 @@
 
 import pandas as pd
 import pytest
+from elasticsearch.exceptions import BadRequestError  # type: ignore[import-not-found]
 from haystack.preview.dataclasses.document import Document
 from haystack.preview.document_stores.errors import DuplicateDocumentError
 from haystack.preview.document_stores.protocols import DuplicatePolicy
@@ -30,7 +31,13 @@ def docstore(self, request):
         # Use a different index for each test so we can run them in parallel
         index = f"{request.node.name}"
 
-        store = ElasticsearchDocumentStore(hosts=hosts, index=index)
+        # this similarity function is rarely used in practice, but it is robust for test cases with fake embeddings
+        # in fact, it works fine with vectors like [0.0] * 768, while cosine similarity would raise an exception
+        embedding_similarity_function = "max_inner_product"
+
+        store = ElasticsearchDocumentStore(
+            hosts=hosts, index=index, embedding_similarity_function=embedding_similarity_function
+        )
         yield store
         store._client.options(ignore_status=[400, 404]).indices.delete(index=index)
 
@@ -43,6 +50,7 @@ def test_to_dict(self, _mock_elasticsearch_client):
             "init_parameters": {
                 "hosts": "some hosts",
                 "index": "default",
+                "embedding_similarity_function": "cosine",
             },
         }
 
@@ -53,11 +61,13 @@ def test_from_dict(self, _mock_elasticsearch_client):
             "init_parameters": {
                 "hosts": "some hosts",
                 "index": "default",
+                "embedding_similarity_function": "cosine",
             },
         }
         document_store = ElasticsearchDocumentStore.from_dict(data)
         assert document_store._hosts == "some hosts"
         assert document_store._index == "default"
+        assert document_store._embedding_similarity_function == "cosine"
 
     def test_bm25_retrieval(self, docstore: ElasticsearchDocumentStore):
         docstore.write_documents(
@@ -169,15 +179,6 @@ def test_in_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_
     def test_in_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
         pass
 
-    def test_ne_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
-        docstore.write_documents(filterable_docs)
-        embedding = [0.0] * 768
-        result = docstore.filter_documents(filters={"embedding": {"$ne": embedding}})
-        assert self.contains_same_docs(
-            result,
-            [doc for doc in filterable_docs if doc.embedding is None or not embedding == doc.embedding],
-        )
-
     @pytest.mark.skip(reason="Not supported")
     def test_nin_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
         pass
@@ -186,6 +187,26 @@ def test_nin_filter_table(self, docstore: ElasticsearchDocumentStore, filterable
     def test_nin_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
         pass
 
+    @pytest.mark.skip(reason="Not supported")
+    def test_eq_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
+        """
+        If the embedding field is a dense vector (as expected), the following error is raised:
+
+        elasticsearch.BadRequestError: BadRequestError(400, 'search_phase_execution_exception',
+        "failed to create query: Field [embedding] of type [dense_vector] doesn't support term queries")
+        """
+        pass
+
+    @pytest.mark.skip(reason="Not supported")
+    def test_ne_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
+        """
+        If the embedding field is a dense vector (as expected), the following error is raised:
+
+        elasticsearch.BadRequestError: BadRequestError(400, 'search_phase_execution_exception',
+        "failed to create query: Field [embedding] of type [dense_vector] doesn't support term queries")
+        """
+        pass
+
     def test_gt_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
         docstore.write_documents(filterable_docs)
         result = docstore.filter_documents(filters={"page": {"$gt": "100"}})
@@ -231,3 +252,42 @@ def test_lte_filter_table(self, docstore: ElasticsearchDocumentStore, filterable
         docstore.write_documents(filterable_docs)
         result = docstore.filter_documents(filters={"dataframe": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
         assert self.contains_same_docs(result, [d for d in filterable_docs if d.dataframe is not None])
+
+    def test_embedding_retrieval(self, docstore: ElasticsearchDocumentStore):
+        docs = [
+            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]),
+        ]
+        docstore.write_documents(docs)
+        results = docstore._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={})
+        assert len(results) == 2
+        assert results[0].content == "Most similar document"
+        assert results[1].content == "2nd best document"
+
+    def test_embedding_retrieval_w_filters(self, docstore: ElasticsearchDocumentStore):
+        docs = [
+            Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
+            Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
+            Document(
+                content="Not very similar document with meta field",
+                embedding=[0.0, 0.8, 0.3, 0.9],
+                meta={"meta_field": "custom_value"},
+            ),
+        ]
+        docstore.write_documents(docs)
+
+        filters = {"meta_field": {"$eq": "custom_value"}}
+        results = docstore._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters=filters)
+        assert len(results) == 1
+        assert results[0].content == "Not very similar document with meta field"
+
+    def test_embedding_retrieval_query_documents_different_embedding_sizes(self, docstore: ElasticsearchDocumentStore):
+        """
+        Test that the retrieval fails if the query embedding and the documents have different embedding sizes.
+        """
+        docs = [Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])]
+        docstore.write_documents(docs)
+
+        with pytest.raises(BadRequestError):
+            docstore._embedding_retrieval(query_embedding=[0.1, 0.1])