Skip to content

Commit

Permalink
add score to document results
Browse files Browse the repository at this point in the history
Signed-off-by: ChengZi <[email protected]>
  • Loading branch information
zc277584121 committed Dec 12, 2024
1 parent 77b27de commit 9b210fb
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 53 deletions.
35 changes: 33 additions & 2 deletions src/milvus_haystack/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,8 @@ def _embedding_retrieval(
output_fields=output_fields,
timeout=None,
)
docs = self._parse_search_result(res, output_fields=output_fields)
distance_to_score_fn = self._select_score_fn()
docs = self._parse_search_result(res, output_fields=output_fields, distance_to_score_fn=distance_to_score_fn)
return docs

def _sparse_embedding_retrieval(
Expand Down Expand Up @@ -868,16 +869,46 @@ def _hybrid_retrieval(
docs = self._parse_search_result(res, output_fields=output_fields)
return docs

def _parse_search_result(self, result, output_fields=None) -> List[Document]:
def _parse_search_result(self, result, output_fields=None, distance_to_score_fn=lambda x: x) -> List[Document]:
if output_fields is None:
output_fields = self.fields[:]
docs = []
for res in result[0]:
data = {x: res.entity.get(x) for x in output_fields}
doc = self._parse_document(data)
doc.score = distance_to_score_fn(res.distance)
docs.append(doc)
return docs

def _select_score_fn(self):
def _map_l2_to_similarity(l2_distance: float) -> float:
"""Return a similarity score on a scale [0, 1].
It is recommended that the original vector is normalized,
Milvus only calculates the value before applying square root.
l2_distance range: (0 is most similar, 4 most dissimilar)
See
https://milvus.io/docs/metric.md?tab=floating#Euclidean-distance-L2
"""
return 1 - l2_distance / 4.0

def _map_ip_to_similarity(ip_score: float) -> float:
"""Return a similarity score on a scale [0, 1].
It is recommended that the original vector is normalized,
ip_score range: (1 is most similar, -1 most dissimilar)
See
https://milvus.io/docs/metric.md?tab=floating#Inner-product-IP
https://milvus.io/docs/metric.md?tab=floating#Cosine-Similarity
"""
return (ip_score + 1) / 2.0

metric_type = self.index_params.get("metric_type", None)
if metric_type == "L2":
return _map_l2_to_similarity
elif metric_type in ["IP", "COSINE"]:
return _map_ip_to_similarity
else:
return lambda x: x

def _parse_document(self, data: dict) -> Document:
# we store dummy vectors during writing documents if they are not provided,
# so we don't return them if they are dummy vectors
Expand Down
4 changes: 2 additions & 2 deletions tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
logger = logging.getLogger(__name__)

DEFAULT_CONNECTION_ARGS = {
"uri": "http://localhost:19530",
# "uri": "./milvus_test.db", # This uri works for Milvus Lite
# "uri": "http://localhost:19530", # This uri works for Milvus docker service
"uri": "./milvus_test.db", # This uri works for Milvus Lite
}


Expand Down
123 changes: 75 additions & 48 deletions tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
from typing import List

import numpy as np
import pytest
from haystack import Document, default_to_dict
from haystack.dataclasses.sparse_embedding import SparseEmbedding
Expand All @@ -17,11 +18,31 @@
logger = logging.getLogger(__name__)

DEFAULT_CONNECTION_ARGS = {
"uri": "http://localhost:19530",
# "uri": "./milvus_test.db", # This uri works for Milvus Lite
# "uri": "http://localhost:19530", # This uri works for Milvus docker service
"uri": "./milvus_test.db", # This uri works for Milvus Lite
}


def l2_normalization(x: List[float]) -> List[float]:
v = np.array(x)
l2_norm = np.linalg.norm(v)
if l2_norm == 0:
return np.zeros_like(v)
normalized_v = v / l2_norm
return normalized_v.tolist()


def assert_docs_equal_except_score(doc1: Document, doc2: Document):
from dataclasses import fields

field_names = [field.name for field in fields(Document) if field.name != "score"]

for field_name in field_names:
value1 = getattr(doc1, field_name)
value2 = getattr(doc2, field_name)
assert value1 == value2


class TestMilvusEmbeddingTests:
@pytest.fixture
def document_store(self) -> MilvusDocumentStore:
Expand All @@ -33,25 +54,27 @@ def document_store(self) -> MilvusDocumentStore:

def test_run(self, document_store: MilvusDocumentStore):
documents = []
doc = Document(
content="A Foo Document",
meta={
"name": "name_0",
"page": "100",
"chapter": "intro",
"number": 2,
"date": "1969-07-21T20:17:40",
},
embedding=[-10.0] * 128,
)
documents.append(doc)
for i in range(10):
doc = Document(
content="A Foo Document",
meta={
"name": f"name_{i}",
"page": "100",
"chapter": "intro",
"number": 2,
"date": "1969-07-21T20:17:40",
},
embedding=l2_normalization([0.5] * 63 + [0.1 * i]),
)
documents.append(doc)
document_store.write_documents(documents)
retriever = MilvusEmbeddingRetriever(
document_store,
)
query_embedding = [-10.0] * 128
query_embedding = l2_normalization([0.5] * 64)
res = retriever.run(query_embedding)
assert res["documents"] == documents
assert len(res["documents"]) == 10
assert_docs_equal_except_score(res["documents"][0], documents[5])

def test_to_dict(self, document_store: MilvusDocumentStore):
expected_dict = {
Expand Down Expand Up @@ -147,29 +170,31 @@ def document_store(self) -> MilvusDocumentStore:
@pytest.fixture
def documents(self) -> List[Document]:
documents = []
doc = Document(
content="A Foo Document",
meta={
"name": "name_0",
"page": "100",
"chapter": "intro",
"number": 2,
"date": "1969-07-21T20:17:40",
},
embedding=[-10.0] * 128,
sparse_embedding=SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]),
)
documents.append(doc)
for i in range(10):
doc = Document(
content="A Foo Document",
meta={
"name": f"name_{i}",
"page": "100",
"chapter": "intro",
"number": 2,
"date": "1969-07-21T20:17:40",
},
embedding=l2_normalization([0.5] * 64),
sparse_embedding=SparseEmbedding(indices=[0, 1, 2 + i], values=[1.0, 2.0, 3.0]),
)
documents.append(doc)
return documents

def test_run(self, document_store: MilvusDocumentStore, documents: List[Document]):
document_store.write_documents(documents)
retriever = MilvusSparseEmbeddingRetriever(
document_store,
)
sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0])
sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2 + 5], values=[1.0, 2.0, 3.0])
res = retriever.run(sparse_query_embedding)
assert res["documents"] == documents
assert len(res["documents"]) == 10
assert_docs_equal_except_score(res["documents"][0], documents[5])

def test_fail_without_sparse_field(self, documents: List[Document]):
document_store = MilvusDocumentStore(
Expand Down Expand Up @@ -284,33 +309,35 @@ def document_store(self) -> MilvusDocumentStore:
@pytest.fixture
def documents(self) -> List[Document]:
documents = []
doc = Document(
content="A Foo Document",
meta={
"name": "name_0",
"page": "100",
"chapter": "intro",
"number": 2,
"date": "1969-07-21T20:17:40",
},
embedding=[-10.0] * 128,
sparse_embedding=SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0]),
)
documents.append(doc)
for i in range(10):
doc = Document(
content="A Foo Document",
meta={
"name": f"name_{i}",
"page": "100",
"chapter": "intro",
"number": 2,
"date": "1969-07-21T20:17:40",
},
embedding=l2_normalization([0.5] * 63 + [0.45 + 0.01 * i]),
sparse_embedding=SparseEmbedding(indices=[0, 1, 2 + i], values=[1.0, 2.0, 3.0]),
)
documents.append(doc)
return documents

def test_run(self, document_store: MilvusDocumentStore, documents: List[Document]):
document_store.write_documents(documents)
retriever = MilvusHybridRetriever(
document_store,
)
query_embedding = [-10.0] * 128
sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0])
query_embedding = l2_normalization([0.5] * 64)
sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2 + 5], values=[1.0, 2.0, 3.0])
res = retriever.run(
query_embedding=query_embedding,
query_sparse_embedding=sparse_query_embedding,
)
assert res["documents"] == documents
assert len(res["documents"]) == 10
assert_docs_equal_except_score(res["documents"][0], documents[5])

def test_fail_without_sparse_field(self, documents: List[Document]):
document_store = MilvusDocumentStore(
Expand All @@ -324,7 +351,7 @@ def test_fail_without_sparse_field(self, documents: List[Document]):
retriever = MilvusHybridRetriever(
document_store,
)
query_embedding = [-10.0] * 128
query_embedding = l2_normalization([0.5] * 64)
sparse_query_embedding = SparseEmbedding(indices=[0, 1, 2], values=[1.0, 2.0, 3.0])
with pytest.raises(MilvusStoreError):
retriever.run(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
logger = logging.getLogger(__name__)

DEFAULT_CONNECTION_ARGS = {
"uri": "http://localhost:19530",
"uri": "http://localhost:19530", # This uri works for Milvus docker service
# "uri": "./milvus_test.db", # This uri works for Milvus Lite
# Note: milvus lite may fail in some tests due to currently not supporting some expressions
}
Expand Down

0 comments on commit 9b210fb

Please sign in to comment.