Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 committed Apr 12, 2024
1 parent ee819c8 commit 804afd7
Show file tree
Hide file tree
Showing 8 changed files with 315 additions and 350 deletions.
2 changes: 2 additions & 0 deletions integrations/qdrant/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ ignore = [
"B027",
# Allow boolean positional values in function calls, like `dict.get(... True)`
"FBT003",
# Allow boolean arguments in function definition
"FBT001", "FBT002",
# Ignore checks for possible passwords
"S105",
"S106",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
document_store: QdrantDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True, # noqa: FBT001, FBT002
return_embedding: bool = False, # noqa: FBT001, FBT002
scale_score: bool = True,
return_embedding: bool = False,
):
"""
Create a QdrantEmbeddingRetriever component.
Expand Down Expand Up @@ -137,7 +137,7 @@ class QdrantSparseRetriever:
document_store = QdrantDocumentStore(
":memory:",
recreate_index=True,
return_sparse_embedding=True,
return_embedding=True,
wait_result_from_api=True,
)
retriever = QdrantSparseRetriever(document_store=document_store)
Expand All @@ -151,8 +151,8 @@ def __init__(
document_store: QdrantDocumentStore,
filters: Optional[Dict[str, Any]] = None,
top_k: int = 10,
scale_score: bool = True, # noqa: FBT001, FBT002
return_embedding: bool = False, # noqa: FBT001, FBT002
scale_score: bool = True,
return_embedding: bool = False,
):
"""
Create a QdrantSparseRetriever component.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,88 +11,70 @@
SPARSE_VECTORS_NAME = "text-sparse"


class HaystackToQdrant:
"""A converter from Haystack to Qdrant types."""

UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d")

def documents_to_batch(
self,
documents: List[Document],
*,
embedding_field: str,
use_sparse_embeddings: bool,
sparse_embedding_field: str,
) -> List[rest.PointStruct]:
points = []
for document in documents:
payload = document.to_dict(flatten=False)
if use_sparse_embeddings:
vector = {}

dense_vector = payload.pop(embedding_field, None)
if dense_vector is not None:
vector[DENSE_VECTORS_NAME] = dense_vector

sparse_vector = payload.pop(sparse_embedding_field, None)
if sparse_vector is not None:
sparse_vector_instance = rest.SparseVector(**sparse_vector)
vector[SPARSE_VECTORS_NAME] = sparse_vector_instance

else:
vector = payload.pop(embedding_field) or {}
_id = self.convert_id(payload.get("id"))

point = rest.PointStruct(
payload=payload,
vector=vector,
id=_id,
)
points.append(point)
return points

def convert_id(self, _id: str) -> str:
"""
Converts any string into a UUID-like format in a deterministic way.
Qdrant does not accept any string as an id, so an internal id has to be
generated for each point. This is a deterministic way of doing so.
"""
return uuid.uuid5(self.UUID_NAMESPACE, _id).hex
UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d")


def convert_haystack_documents_to_qdrant_points(
documents: List[Document],
*,
embedding_field: str,
use_sparse_embeddings: bool,
) -> List[rest.PointStruct]:
points = []
for document in documents:
payload = document.to_dict(flatten=False)
if use_sparse_embeddings:
vector = {}

dense_vector = payload.pop(embedding_field, None)
if dense_vector is not None:
vector[DENSE_VECTORS_NAME] = dense_vector

sparse_vector = payload.pop("sparse_embedding", None)
if sparse_vector is not None:
sparse_vector_instance = rest.SparseVector(**sparse_vector)
vector[SPARSE_VECTORS_NAME] = sparse_vector_instance

else:
vector = payload.pop(embedding_field) or {}
_id = convert_id(payload.get("id"))

point = rest.PointStruct(
payload=payload,
vector=vector,
id=_id,
)
points.append(point)
return points


def convert_id(_id: str) -> str:
"""
Converts any string into a UUID-like format in a deterministic way.
Qdrant does not accept any string as an id, so an internal id has to be
generated for each point. This is a deterministic way of doing so.
"""
return uuid.uuid5(UUID_NAMESPACE, _id).hex


QdrantPoint = Union[rest.ScoredPoint, rest.Record]


class QdrantToHaystack:
def __init__(
self,
content_field: str,
name_field: str,
embedding_field: str,
use_sparse_embeddings: bool, # noqa: FBT001
sparse_embedding_field: str,
):
self.content_field = content_field
self.name_field = name_field
self.embedding_field = embedding_field
self.use_sparse_embeddings = use_sparse_embeddings
self.sparse_embedding_field = sparse_embedding_field

def point_to_document(self, point: QdrantPoint) -> Document:
payload = {**point.payload}
payload["score"] = point.score if hasattr(point, "score") else None
if not self.use_sparse_embeddings:
payload["embedding"] = point.vector if hasattr(point, "vector") else None
else:
if hasattr(point, "vector") and point.vector is not None:
payload["embedding"] = point.vector.get(DENSE_VECTORS_NAME)
def convert_qdrant_point_to_haystack_document(point: QdrantPoint, use_sparse_embeddings: bool) -> Document:
payload = {**point.payload}
payload["score"] = point.score if hasattr(point, "score") else None

if not use_sparse_embeddings:
payload["embedding"] = point.vector if hasattr(point, "vector") else None
elif hasattr(point, "vector") and point.vector is not None:
payload["embedding"] = point.vector.get(DENSE_VECTORS_NAME)

if hasattr(point, "vector") and point.vector is not None and SPARSE_VECTORS_NAME in point.vector:
parse_vector_dict = {
"indices": point.vector[SPARSE_VECTORS_NAME].indices,
"values": point.vector[SPARSE_VECTORS_NAME].values,
}
payload["sparse_embedding"] = parse_vector_dict
if SPARSE_VECTORS_NAME in point.vector:
parse_vector_dict = {
"indices": point.vector[SPARSE_VECTORS_NAME].indices,
"values": point.vector[SPARSE_VECTORS_NAME].values,
}
payload["sparse_embedding"] = parse_vector_dict

return Document.from_dict(payload)
return Document.from_dict(payload)
Loading

0 comments on commit 804afd7

Please sign in to comment.