From 3a61a61f81e3740a3d6f936951aeb2f5f46a321c Mon Sep 17 00:00:00 2001 From: David Basoko Date: Fri, 22 Nov 2024 10:16:42 +0100 Subject: [PATCH] Fix embedding retrieval top-k limit --- .../document_stores/astra/astra_client.py | 3 +- .../astra/tests/test_embedding_retrieval.py | 48 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 integrations/astra/tests/test_embedding_retrieval.py diff --git a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py index 6f2289786..1a3481e0c 100644 --- a/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py +++ b/integrations/astra/src/haystack_integrations/document_stores/astra/astra_client.py @@ -202,7 +202,7 @@ def _format_query_response(responses, include_metadata, include_values): return QueryResponse(final_res) def _query(self, vector, top_k, filters=None): - query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}} + query = {"sort": {"$vector": vector}, "limit": top_k, "includeSimilarity": True} if filters is not None: query["filter"] = filters @@ -222,6 +222,7 @@ def find_documents(self, find_query): filter=find_query.get("filter"), sort=find_query.get("sort"), limit=find_query.get("limit"), + include_similarity=find_query.get("includeSimilarity"), projection={"*": 1}, ) diff --git a/integrations/astra/tests/test_embedding_retrieval.py b/integrations/astra/tests/test_embedding_retrieval.py new file mode 100644 index 000000000..bf23fe9f5 --- /dev/null +++ b/integrations/astra/tests/test_embedding_retrieval.py @@ -0,0 +1,48 @@ +import os + +import pytest +from haystack import Document +from haystack.document_stores.types import DuplicatePolicy + +from haystack_integrations.document_stores.astra import AstraDocumentStore + + +@pytest.mark.integration +@pytest.mark.skipif( + os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set" +) +@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set") +class TestEmbeddingRetrieval: + + @pytest.fixture + def document_store(self) -> AstraDocumentStore: + return AstraDocumentStore( + collection_name="haystack_integration", + duplicates_policy=DuplicatePolicy.OVERWRITE, + embedding_dimension=768, + ) + + @pytest.fixture(autouse=True) + def run_before_and_after_tests(self, document_store: AstraDocumentStore): + """ + Cleaning up document store + """ + document_store.delete_documents(delete_all=True) + assert document_store.count_documents() == 0 + + def test_search_with_top_k(self, document_store): + query_embedding = [0.1] * 768 + common_embedding = [0.8] * 768 + + documents = [Document(content=f"This is document number {i}", embedding=common_embedding) for i in range(0, 3)] + + document_store.write_documents(documents) + + top_k = 2 + + result = document_store.search(query_embedding, top_k) + + assert top_k == len(result) + + for document in result: + assert document.score is not None