Skip to content

Commit

Permalink
Elasticsearch - refactor _search_documents (#57)
Browse files Browse the repository at this point in the history
* set scale_score default to False

* unrelated: replace text with content

* first implementation

* test

* fix some tests

* make tests more robust; skip unsupported ones

* rm unsupported test

* ignore import-not-found

* first chunk addressing PR feedback

* improve tests

* use _search_documents also in bm25 retrieval

* improve logic and tests

* fix format

* better format

* Update document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py

Co-authored-by: Silvano Cerza <[email protected]>

* Update document_stores/elasticsearch/src/elasticsearch_haystack/document_store.py

Co-authored-by: Silvano Cerza <[email protected]>

* remove wrong increment

* move ruff ignore error

---------

Co-authored-by: Silvano Cerza <[email protected]>
  • Loading branch information
anakin87 and silvanocerza authored Nov 20, 2023
1 parent 48c0d5f commit 1c6410e
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,10 @@ def _search_documents(self, **kwargs) -> List[Document]:
Calls the Elasticsearch client's search method and handles pagination.
"""

top_k = kwargs.get("size")
if top_k is None and "knn" in kwargs and "k" in kwargs["knn"]:
top_k = kwargs["knn"]["k"]

documents: List[Document] = []
from_ = 0
# Handle pagination
Expand All @@ -115,8 +119,12 @@ def _search_documents(self, **kwargs) -> List[Document]:
from_=from_,
**kwargs,
)

documents.extend(self._deserialize_document(hit) for hit in res["hits"]["hits"])
from_ = len(documents)

if top_k is not None and from_ >= top_k:
break
if from_ >= res["hits"]["total"]["value"]:
break
return documents
Expand Down Expand Up @@ -326,14 +334,13 @@ def _bm25_retrieval(
if filters:
body["query"]["bool"]["filter"] = _normalize_filters(filters)

res = self._client.search(index=self._index, **body)
documents = self._search_documents(**body)

docs = []
for hit in res["hits"]["hits"]:
if scale_score:
hit["_score"] = float(1 / (1 + np.exp(-np.asarray(hit["_score"] / BM25_SCALING_FACTOR))))
docs.append(self._deserialize_document(hit))
return docs
if scale_score:
for doc in documents:
doc.score = float(1 / (1 + np.exp(-np.asarray(doc.score / BM25_SCALING_FACTOR))))

return documents

def _embedding_retrieval(
self,
Expand Down
44 changes: 44 additions & 0 deletions document_stores/elasticsearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2023-present Silvano Cerza <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import random
from typing import List
from unittest.mock import patch

Expand Down Expand Up @@ -92,6 +94,34 @@ def test_bm25_retrieval(self, docstore: ElasticsearchDocumentStore):
assert "functional" in res[1].content
assert "functional" in res[2].content

def test_bm25_retrieval_pagination(self, docstore: ElasticsearchDocumentStore):
    """
    Check that BM25 retrieval handles pagination when more than 10 documents match.
    """
    # Every one of these 15 contents contains the word "programming",
    # so a query for it matches more documents than the requested top_k.
    contents = [
        "Haskell is a functional programming language",
        "Lisp is a functional programming language",
        "Exilir is a functional programming language",
        "F# is a functional programming language",
        "C# is a functional programming language",
        "C++ is an object oriented programming language",
        "Dart is an object oriented programming language",
        "Go is an object oriented programming language",
        "Python is a object oriented programming language",
        "Ruby is a object oriented programming language",
        "PHP is a object oriented programming language",
        "Java is an object oriented programming language",
        "Javascript is a programming language",
        "Typescript is a programming language",
        "C is a programming language",
    ]
    docstore.write_documents([Document(content=text) for text in contents])

    retrieved = docstore._bm25_retrieval("programming", top_k=11)

    # Exactly top_k documents come back, each of them a genuine match.
    assert len(retrieved) == 11
    for doc in retrieved:
        assert "programming" in doc.content

def test_bm25_retrieval_with_fuzziness(self, docstore: ElasticsearchDocumentStore):
docstore.write_documents(
[
Expand Down Expand Up @@ -282,6 +312,20 @@ def test_embedding_retrieval_w_filters(self, docstore: ElasticsearchDocumentStor
assert len(results) == 1
assert results[0].content == "Not very similar document with meta field"

def test_embedding_retrieval_pagination(self, docstore: ElasticsearchDocumentStore):
    """
    Check that embedding retrieval handles pagination when more than 10 documents match.
    """
    # 20 documents, each with a random 4-dimensional embedding; only the
    # number of returned results is asserted, so the values do not matter.
    documents = []
    for index in range(20):
        vector = [random.random() for _ in range(4)]  # noqa: S311
        documents.append(Document(content=f"Document {index}", embedding=vector))
    docstore.write_documents(documents)

    query_embedding = [0.1, 0.1, 0.1, 0.1]
    results = docstore._embedding_retrieval(query_embedding=query_embedding, top_k=11, filters={})

    assert len(results) == 11

def test_embedding_retrieval_query_documents_different_embedding_sizes(self, docstore: ElasticsearchDocumentStore):
"""
Test that the retrieval fails if the query embedding and the documents have different embedding sizes.
Expand Down

0 comments on commit 1c6410e

Please sign in to comment.