Elasticsearch Embedding Retriever (#54)
* set scale_score default to False

* unrelated: replace text w content

* first implementation

* test

* fix some tests

* make tests more robust; skip unsupported ones

* rm unsupported test

* ignore import-not-found

* embedding retriever

* tests

* organize imports

* first chunk addressing PR feedback

* improve tests

* add docstrings

* more docstrings
anakin87 authored Nov 16, 2023
1 parent 6babb9a commit 5ecacc5
Showing 2 changed files with 159 additions and 0 deletions.
80 changes: 80 additions & 0 deletions document_stores/elasticsearch/src/elasticsearch_haystack/embedding_retriever.py
@@ -0,0 +1,80 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional

from haystack.preview import component, default_from_dict, default_to_dict
from haystack.preview.dataclasses import Document

from elasticsearch_haystack.document_store import ElasticsearchDocumentStore


@component
class ElasticsearchEmbeddingRetriever:
"""
Uses a vector similarity metric to retrieve documents from the ElasticsearchDocumentStore.
Needs to be connected to the ElasticsearchDocumentStore to run.
"""

def __init__(
self,
*,
document_store: ElasticsearchDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
num_candidates: Optional[int] = None,
):
"""
Create the ElasticsearchEmbeddingRetriever component.
:param document_store: An instance of ElasticsearchDocumentStore.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
Filters are applied during the approximate kNN search to ensure that top_k matching documents are returned.
:param top_k: Maximum number of Documents to return, defaults to 10
:param num_candidates: Number of approximate nearest neighbor candidates on each shard. Defaults to top_k * 10.
Increasing this value will improve search accuracy at the cost of slower search speeds.
You can read more about it in the Elasticsearch documentation:
https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#tune-approximate-knn-for-speed-accuracy
:raises ValueError: If `document_store` is not an instance of ElasticsearchDocumentStore.
"""
if not isinstance(document_store, ElasticsearchDocumentStore):
msg = "document_store must be an instance of ElasticsearchDocumentStore"
raise ValueError(msg)

self._document_store = document_store
self._filters = filters or {}
self._top_k = top_k
self._num_candidates = num_candidates

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
filters=self._filters,
top_k=self._top_k,
num_candidates=self._num_candidates,
document_store=self._document_store.to_dict(),
)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ElasticsearchEmbeddingRetriever":
data["init_parameters"]["document_store"] = ElasticsearchDocumentStore.from_dict(
data["init_parameters"]["document_store"]
)
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float]):
"""
Retrieve documents using a vector similarity metric.
:param query_embedding: Embedding of the query.
:return: List of Document similar to `query_embedding`.
"""
docs = self._document_store._embedding_retrieval(
query_embedding=query_embedding,
filters=self._filters,
top_k=self._top_k,
num_candidates=self._num_candidates,
)
return {"documents": docs}
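
For orientation, a minimal usage sketch of the new component (not part of this commit): it assumes a running Elasticsearch cluster at a hypothetical localhost URL and documents already indexed with embeddings of matching dimensionality; in a real pipeline the query embedding would come from the same text embedder used at indexing time.

from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
from elasticsearch_haystack.embedding_retriever import ElasticsearchEmbeddingRetriever

# Hypothetical host; any reachable Elasticsearch cluster with indexed,
# embedded documents would do.
document_store = ElasticsearchDocumentStore(hosts="http://localhost:9200")
retriever = ElasticsearchEmbeddingRetriever(document_store=document_store, top_k=5)

# A raw vector keeps the sketch self-contained; normally this would be
# produced by the same embedder that generated the stored document embeddings.
result = retriever.run(query_embedding=[0.1, 0.2, 0.3, 0.4])
for doc in result["documents"]:
    print(doc.content)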
79 changes: 79 additions & 0 deletions document_stores/elasticsearch/tests/test_embedding_retriever.py
@@ -0,0 +1,79 @@
# SPDX-FileCopyrightText: 2023-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import Mock, patch

from haystack.preview.dataclasses import Document

from elasticsearch_haystack.document_store import ElasticsearchDocumentStore
from elasticsearch_haystack.embedding_retriever import ElasticsearchEmbeddingRetriever


def test_init_default():
    mock_store = Mock(spec=ElasticsearchDocumentStore)
    retriever = ElasticsearchEmbeddingRetriever(document_store=mock_store)
    assert retriever._document_store == mock_store
    assert retriever._filters == {}
    assert retriever._top_k == 10
    assert retriever._num_candidates is None


@patch("elasticsearch_haystack.document_store.Elasticsearch")
def test_to_dict(_mock_elasticsearch_client):
    document_store = ElasticsearchDocumentStore(hosts="some fake host")
    retriever = ElasticsearchEmbeddingRetriever(document_store=document_store)
    res = retriever.to_dict()
    assert res == {
        "type": "ElasticsearchEmbeddingRetriever",
        "init_parameters": {
            "document_store": {
                "init_parameters": {
                    "hosts": "some fake host",
                    "index": "default",
                    "embedding_similarity_function": "cosine",
                },
                "type": "ElasticsearchDocumentStore",
            },
            "filters": {},
            "top_k": 10,
            "num_candidates": None,
        },
    }


@patch("elasticsearch_haystack.document_store.Elasticsearch")
def test_from_dict(_mock_elasticsearch_client):
    data = {
        "type": "ElasticsearchEmbeddingRetriever",
        "init_parameters": {
            "document_store": {
                "init_parameters": {"hosts": "some fake host", "index": "default"},
                "type": "ElasticsearchDocumentStore",
            },
            "filters": {},
            "top_k": 10,
            "num_candidates": None,
        },
    }
    retriever = ElasticsearchEmbeddingRetriever.from_dict(data)
    assert retriever._document_store
    assert retriever._filters == {}
    assert retriever._top_k == 10
    assert retriever._num_candidates is None


def test_run():
    mock_store = Mock(spec=ElasticsearchDocumentStore)
    mock_store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]
    retriever = ElasticsearchEmbeddingRetriever(document_store=mock_store)
    res = retriever.run(query_embedding=[0.5, 0.7])
    mock_store._embedding_retrieval.assert_called_once_with(
        query_embedding=[0.5, 0.7],
        filters={},
        top_k=10,
        num_candidates=None,
    )
    assert len(res) == 1
    assert len(res["documents"]) == 1
    assert res["documents"][0].content == "Test doc"
    assert res["documents"][0].embedding == [0.1, 0.2]
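
The new tests cover the defaults, serialization, and run, but not the ValueError raised when document_store is not an ElasticsearchDocumentStore. A hypothetical test sketch for that case (not part of this commit) could look like this:

import pytest

from elasticsearch_haystack.embedding_retriever import ElasticsearchEmbeddingRetriever


def test_init_with_invalid_document_store():
    # __init__ rejects anything that is not an ElasticsearchDocumentStore
    with pytest.raises(ValueError):
        ElasticsearchEmbeddingRetriever(document_store="not a document store")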
