Skip to content

Commit

Permalink
Fix embedding retrieval top-k limit
Browse files Browse the repository at this point in the history
  • Loading branch information
basoko committed Nov 22, 2024
1 parent a979667 commit 1b2efb5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _format_query_response(responses, include_metadata, include_values):
return QueryResponse(final_res)

def _query(self, vector, top_k, filters=None):
query = {"sort": {"$vector": vector}, "options": {"limit": top_k, "includeSimilarity": True}}
query = {"sort": {"$vector": vector}, "limit": top_k, "includeSimilarity": True}

if filters is not None:
query["filter"] = filters
Expand All @@ -222,6 +222,7 @@ def find_documents(self, find_query):
filter=find_query.get("filter"),
sort=find_query.get("sort"),
limit=find_query.get("limit"),
include_similarity=find_query.get("includeSimilarity"),
projection={"*": 1},
)

Expand Down
45 changes: 45 additions & 0 deletions integrations/astra/tests/test_embedding_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os

import pytest
from haystack import Document
from haystack.document_stores.types import DuplicatePolicy

from haystack_integrations.document_stores.astra import AstraDocumentStore


@pytest.mark.integration
@pytest.mark.skipif(
os.environ.get("ASTRA_DB_APPLICATION_TOKEN", "") == "", reason="ASTRA_DB_APPLICATION_TOKEN env var not set"
)
@pytest.mark.skipif(os.environ.get("ASTRA_DB_API_ENDPOINT", "") == "", reason="ASTRA_DB_API_ENDPOINT env var not set")
class TestEmbeddingRetrieval:

@pytest.fixture
def document_store(self) -> AstraDocumentStore:
return AstraDocumentStore(
collection_name="haystack_integration",
duplicates_policy=DuplicatePolicy.OVERWRITE,
embedding_dimension=768,
)

@pytest.fixture(autouse=True)
def run_before_and_after_tests(self, document_store: AstraDocumentStore):
"""
Cleaning up document store
"""
document_store.delete_documents(delete_all=True)
assert document_store.count_documents() == 0

def test_search_with_top_k(self, document_store):
query_embedding = [0.1] * 768
common_embedding = [0.8] * 768

documents = [Document(content=f"This is document number {i}", embedding=common_embedding) for i in range(0, 3)]

document_store.write_documents(documents)

top_k = 2

result = document_store.search(query_embedding, top_k)

assert top_k == len(result)

0 comments on commit 1b2efb5

Please sign in to comment.