Skip to content

Commit

Permalink
Elasticsearch Document store - embedding retrieval (#52)
Browse files Browse the repository at this point in the history
* set scale_score default to False

* unrelated: replace text w content

* first implementation

* test

* fix some tests

* make tests more robust; skip unsupported ones

* rm unsupported test

* ignore import-not-found

* first chunk addressing PR feedback

* improve tests
  • Loading branch information
anakin87 authored Nov 16, 2023
1 parent 8df2edf commit 7d2b824
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 29 deletions.
2 changes: 1 addition & 1 deletion document_stores/elasticsearch/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
services:
elasticsearch:
image: "docker.elastic.co/elasticsearch/elasticsearch:8.10.0"
image: "docker.elastic.co/elasticsearch/elasticsearch:8.11.1"
ports:
- 9200:9200
restart: on-failure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

import numpy as np

Expand Down Expand Up @@ -33,7 +33,14 @@

@document_store
class ElasticsearchDocumentStore:
def __init__(self, *, hosts: Optional[Hosts] = None, index: str = "default", **kwargs):
def __init__(
self,
*,
hosts: Optional[Hosts] = None,
index: str = "default",
embedding_similarity_function: Literal["cosine", "dot_product", "l2_norm", "max_inner_product"] = "cosine",
**kwargs,
):
"""
Creates a new ElasticsearchDocumentStore instance.
Expand All @@ -45,19 +52,32 @@ def __init__(self, *, hosts: Optional[Hosts] = None, index: str = "default", **k
:param hosts: List of hosts running the Elasticsearch client. Defaults to None
:param index: Name of index in Elasticsearch, if it doesn't exist it will be created. Defaults to "default"
:param embedding_similarity_function: The similarity function used to compare Documents embeddings.
Defaults to "cosine". This parameter only takes effect if the index does not yet exist and is created.
To choose the most appropriate function, look for information about your embedding model.
To understand how document scores are computed, see the Elasticsearch documentation:
https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html#dense-vector-params
:param **kwargs: Optional arguments that ``Elasticsearch`` takes.
"""
self._hosts = hosts
self._client = Elasticsearch(hosts, **kwargs)
self._index = index
self._embedding_similarity_function = embedding_similarity_function
self._kwargs = kwargs

# Check client connection, this will raise if not connected
self._client.info()

# configure mapping for the embedding field
mappings = {
"properties": {
"embedding": {"type": "dense_vector", "index": True, "similarity": embedding_similarity_function}
}
}

# Create the index if it doesn't exist
if not self._client.indices.exists(index=index):
self._client.indices.create(index=index)
self._client.indices.create(index=index, mappings=mappings)

def to_dict(self) -> Dict[str, Any]:
# This is not the best solution to serialise this class but is the fastest to implement.
Expand All @@ -67,6 +87,7 @@ def to_dict(self) -> Dict[str, Any]:
self,
hosts=self._hosts,
index=self._index,
embedding_similarity_function=self._embedding_similarity_function,
**self._kwargs,
)

Expand All @@ -80,6 +101,26 @@ def count_documents(self) -> int:
"""
return self._client.count(index=self._index)["count"]

def _search_documents(self, **kwargs) -> List[Document]:
"""
Calls the Elasticsearch client's search method and handles pagination.
"""

documents: List[Document] = []
from_ = 0
# Handle pagination
while True:
res = self._client.search(
index=self._index,
from_=from_,
**kwargs,
)
documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"])
from_ = len(documents)
if from_ >= res["hits"]["total"]["value"]:
break
return documents

def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Document]:
"""
Returns the documents that match the filters provided.
Expand Down Expand Up @@ -152,20 +193,7 @@ def filter_documents(self, filters: Optional[Dict[str, Any]] = None) -> List[Doc
:return: a list of Documents that match the given filters.
"""
query = {"bool": {"filter": _normalize_filters(filters)}} if filters else None

documents: List[Document] = []
from_ = 0
# Handle pagination
while True:
res = self._client.search(
index=self._index,
query=query,
from_=from_,
)
documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"])
from_ = len(documents)
if from_ >= res["hits"]["total"]["value"]:
break
documents = self._search_documents(query=query)
return documents

def write_documents(self, documents: List[Document], policy: DuplicatePolicy = DuplicatePolicy.FAIL) -> None:
Expand Down Expand Up @@ -306,3 +334,53 @@ def _bm25_retrieval(
hit["_score"] = float(1 / (1 + np.exp(-np.asarray(hit["_score"] / BM25_SCALING_FACTOR))))
docs.append(self._deserialize_document(hit))
return docs

def _embedding_retrieval(
self,
query_embedding: List[float],
*,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
num_candidates: Optional[int] = None,
) -> List[Document]:
"""
Retrieves documents that are most similar to the query embedding using a vector similarity metric.
It uses the Elasticsearch's Approximate k-Nearest Neighbors search algorithm.
This method is not mean to be part of the public interface of
`ElasticsearchDocumentStore` nor called directly.
`ElasticsearchEmbeddingRetriever` uses this method directly and is the public interface for it.
:param query_embedding: Embedding of the query.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
:param num_candidates: Number of approximate nearest neighbor candidates on each shard. Defaults to top_k * 10.
Increasing this value will improve search accuracy at the cost of slower search speeds.
You can read more about it in the Elasticsearch documentation:
https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#tune-approximate-knn-for-speed-accuracy
:raises ValueError: If `query_embedding` is an empty list
:return: List of Document that are most similar to `query_embedding`
"""

if not query_embedding:
msg = "query_embedding must be a non-empty list of floats"
raise ValueError(msg)

if not num_candidates:
num_candidates = top_k * 10

body: Dict[str, Any] = {
"knn": {
"field": "embedding",
"query_vector": query_embedding,
"k": top_k,
"num_candidates": num_candidates,
},
}

if filters:
body["knn"]["filter"] = _normalize_filters(filters)

docs = self._search_documents(**body)
return docs
6 changes: 5 additions & 1 deletion document_stores/elasticsearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ def test_to_dict(_mock_elasticsearch_client):
"type": "ElasticsearchBM25Retriever",
"init_parameters": {
"document_store": {
"init_parameters": {"hosts": "some fake host", "index": "default"},
"init_parameters": {
"hosts": "some fake host",
"index": "default",
"embedding_similarity_function": "cosine",
},
"type": "ElasticsearchDocumentStore",
},
"filters": {},
Expand Down
80 changes: 70 additions & 10 deletions document_stores/elasticsearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import pandas as pd
import pytest
from elasticsearch.exceptions import BadRequestError # type: ignore[import-not-found]
from haystack.preview.dataclasses.document import Document
from haystack.preview.document_stores.errors import DuplicateDocumentError
from haystack.preview.document_stores.protocols import DuplicatePolicy
Expand All @@ -30,7 +31,13 @@ def docstore(self, request):
# Use a different index for each test so we can run them in parallel
index = f"{request.node.name}"

store = ElasticsearchDocumentStore(hosts=hosts, index=index)
# this similarity function is rarely used in practice, but it is robust for test cases with fake embeddings
# in fact, it works fine with vectors like [0.0] * 768, while cosine similarity would raise an exception
embedding_similarity_function = "max_inner_product"

store = ElasticsearchDocumentStore(
hosts=hosts, index=index, embedding_similarity_function=embedding_similarity_function
)
yield store
store._client.options(ignore_status=[400, 404]).indices.delete(index=index)

Expand All @@ -43,6 +50,7 @@ def test_to_dict(self, _mock_elasticsearch_client):
"init_parameters": {
"hosts": "some hosts",
"index": "default",
"embedding_similarity_function": "cosine",
},
}

Expand All @@ -53,11 +61,13 @@ def test_from_dict(self, _mock_elasticsearch_client):
"init_parameters": {
"hosts": "some hosts",
"index": "default",
"embedding_similarity_function": "cosine",
},
}
document_store = ElasticsearchDocumentStore.from_dict(data)
assert document_store._hosts == "some hosts"
assert document_store._index == "default"
assert document_store._embedding_similarity_function == "cosine"

def test_bm25_retrieval(self, docstore: ElasticsearchDocumentStore):
docstore.write_documents(
Expand Down Expand Up @@ -169,15 +179,6 @@ def test_in_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_
def test_in_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
pass

def test_ne_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
docstore.write_documents(filterable_docs)
embedding = [0.0] * 768
result = docstore.filter_documents(filters={"embedding": {"$ne": embedding}})
assert self.contains_same_docs(
result,
[doc for doc in filterable_docs if doc.embedding is None or not embedding == doc.embedding],
)

@pytest.mark.skip(reason="Not supported")
def test_nin_filter_table(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
pass
Expand All @@ -186,6 +187,26 @@ def test_nin_filter_table(self, docstore: ElasticsearchDocumentStore, filterable
def test_nin_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
pass

@pytest.mark.skip(reason="Not supported")
def test_eq_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
"""
If the embedding field is a dense vector (as expected), raise the following error:
elasticsearch.BadRequestError: BadRequestError(400, 'search_phase_execution_exception',
"failed to create query: Field [embedding] of type [dense_vector] doesn't support term queries")
"""
pass

@pytest.mark.skip(reason="Not supported")
def test_ne_filter_embedding(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
"""
If the embedding field is a dense vector (as expected), raise the following error:
elasticsearch.BadRequestError: BadRequestError(400, 'search_phase_execution_exception',
"failed to create query: Field [embedding] of type [dense_vector] doesn't support term queries")
"""
pass

def test_gt_filter_non_numeric(self, docstore: ElasticsearchDocumentStore, filterable_docs: List[Document]):
docstore.write_documents(filterable_docs)
result = docstore.filter_documents(filters={"page": {"$gt": "100"}})
Expand Down Expand Up @@ -231,3 +252,42 @@ def test_lte_filter_table(self, docstore: ElasticsearchDocumentStore, filterable
docstore.write_documents(filterable_docs)
result = docstore.filter_documents(filters={"dataframe": {"$lte": pd.DataFrame([[1, 2, 3], [-1, -2, -3]])}})
assert self.contains_same_docs(result, [d for d in filterable_docs if d.dataframe is not None])

def test_embedding_retrieval(self, docstore: ElasticsearchDocumentStore):
docs = [
Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
Document(content="Not very similar document", embedding=[0.0, 0.8, 0.3, 0.9]),
]
docstore.write_documents(docs)
results = docstore._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters={})
assert len(results) == 2
assert results[0].content == "Most similar document"
assert results[1].content == "2nd best document"

def test_embedding_retrieval_w_filters(self, docstore: ElasticsearchDocumentStore):
docs = [
Document(content="Most similar document", embedding=[1.0, 1.0, 1.0, 1.0]),
Document(content="2nd best document", embedding=[0.8, 0.8, 0.8, 1.0]),
Document(
content="Not very similar document with meta field",
embedding=[0.0, 0.8, 0.3, 0.9],
meta={"meta_field": "custom_value"},
),
]
docstore.write_documents(docs)

filters = {"meta_field": {"$eq": "custom_value"}}
results = docstore._embedding_retrieval(query_embedding=[0.1, 0.1, 0.1, 0.1], top_k=2, filters=filters)
assert len(results) == 1
assert results[0].content == "Not very similar document with meta field"

def test_embedding_retrieval_query_documents_different_embedding_sizes(self, docstore: ElasticsearchDocumentStore):
"""
Test that the retrieval fails if the query embedding and the documents have different embedding sizes.
"""
docs = [Document(content="Hello world", embedding=[0.1, 0.2, 0.3, 0.4])]
docstore.write_documents(docs)

with pytest.raises(BadRequestError):
docstore._embedding_retrieval(query_embedding=[0.1, 0.1])

0 comments on commit 7d2b824

Please sign in to comment.