From 804afd7db9242410c822b28fb8c5dfb1061b486d Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 12 Apr 2024 12:56:50 +0200 Subject: [PATCH] refactoring --- integrations/qdrant/pyproject.toml | 2 + .../components/retrievers/qdrant/retriever.py | 10 +- .../document_stores/qdrant/converters.py | 138 +++---- .../document_stores/qdrant/document_store.py | 95 ++--- .../document_stores/qdrant/filters.py | 366 +++++++++--------- integrations/qdrant/tests/test_converters.py | 49 +-- .../qdrant/tests/test_dict_converters.py | 3 - integrations/qdrant/tests/test_retriever.py | 2 - 8 files changed, 315 insertions(+), 350 deletions(-) diff --git a/integrations/qdrant/pyproject.toml b/integrations/qdrant/pyproject.toml index fb969fb4e..a566de955 100644 --- a/integrations/qdrant/pyproject.toml +++ b/integrations/qdrant/pyproject.toml @@ -103,6 +103,8 @@ ignore = [ "B027", # Allow boolean positional values in function calls, like `dict.get(... True)` "FBT003", + # Allow boolean arguments in function definition + "FBT001", "FBT002", # Ignore checks for possible passwords "S105", "S106", diff --git a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py index 12a67a3b7..0b7bfa1a4 100644 --- a/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py +++ b/integrations/qdrant/src/haystack_integrations/components/retrievers/qdrant/retriever.py @@ -33,8 +33,8 @@ def __init__( document_store: QdrantDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - scale_score: bool = True, # noqa: FBT001, FBT002 - return_embedding: bool = False, # noqa: FBT001, FBT002 + scale_score: bool = True, + return_embedding: bool = False, ): """ Create a QdrantEmbeddingRetriever component. @@ -137,7 +137,7 @@ class QdrantSparseRetriever: document_store = QdrantDocumentStore( ":memory:", recreate_index=True, - return_sparse_embedding=True, + return_embedding=True, wait_result_from_api=True, ) retriever = QdrantSparseRetriever(document_store=document_store) @@ -151,8 +151,8 @@ def __init__( document_store: QdrantDocumentStore, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - scale_score: bool = True, # noqa: FBT001, FBT002 - return_embedding: bool = False, # noqa: FBT001, FBT002 + scale_score: bool = True, + return_embedding: bool = False, ): """ Create a QdrantSparseRetriever component. diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py index 4bc7443ce..96bd4f37a 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/converters.py @@ -11,88 +11,70 @@ SPARSE_VECTORS_NAME = "text-sparse" -class HaystackToQdrant: - """A converter from Haystack to Qdrant types.""" - - UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") - - def documents_to_batch( - self, - documents: List[Document], - *, - embedding_field: str, - use_sparse_embeddings: bool, - sparse_embedding_field: str, - ) -> List[rest.PointStruct]: - points = [] - for document in documents: - payload = document.to_dict(flatten=False) - if use_sparse_embeddings: - vector = {} - - dense_vector = payload.pop(embedding_field, None) - if dense_vector is not None: - vector[DENSE_VECTORS_NAME] = dense_vector - - sparse_vector = payload.pop(sparse_embedding_field, None) - if sparse_vector is not None: - sparse_vector_instance = rest.SparseVector(**sparse_vector) - vector[SPARSE_VECTORS_NAME] = sparse_vector_instance - - else: - vector = payload.pop(embedding_field) or {} - _id = self.convert_id(payload.get("id")) - - point = rest.PointStruct( - payload=payload, - vector=vector, - id=_id, - ) - points.append(point) - return points - - def convert_id(self, _id: str) -> str: - """ - Converts any string into a UUID-like format in a deterministic way. - - Qdrant does not accept any string as an id, so an internal id has to be - generated for each point. This is a deterministic way of doing so. - """ - return uuid.uuid5(self.UUID_NAMESPACE, _id).hex +UUID_NAMESPACE = uuid.UUID("3896d314-1e95-4a3a-b45a-945f9f0b541d") + + +def convert_haystack_documents_to_qdrant_points( + documents: List[Document], + *, + embedding_field: str, + use_sparse_embeddings: bool, +) -> List[rest.PointStruct]: + points = [] + for document in documents: + payload = document.to_dict(flatten=False) + if use_sparse_embeddings: + vector = {} + + dense_vector = payload.pop(embedding_field, None) + if dense_vector is not None: + vector[DENSE_VECTORS_NAME] = dense_vector + + sparse_vector = payload.pop("sparse_embedding", None) + if sparse_vector is not None: + sparse_vector_instance = rest.SparseVector(**sparse_vector) + vector[SPARSE_VECTORS_NAME] = sparse_vector_instance + + else: + vector = payload.pop(embedding_field) or {} + _id = convert_id(payload.get("id")) + + point = rest.PointStruct( + payload=payload, + vector=vector, + id=_id, + ) + points.append(point) + return points + + +def convert_id(_id: str) -> str: + """ + Converts any string into a UUID-like format in a deterministic way. + + Qdrant does not accept any string as an id, so an internal id has to be + generated for each point. This is a deterministic way of doing so. + """ + return uuid.uuid5(UUID_NAMESPACE, _id).hex QdrantPoint = Union[rest.ScoredPoint, rest.Record] -class QdrantToHaystack: - def __init__( - self, - content_field: str, - name_field: str, - embedding_field: str, - use_sparse_embeddings: bool, # noqa: FBT001 - sparse_embedding_field: str, - ): - self.content_field = content_field - self.name_field = name_field - self.embedding_field = embedding_field - self.use_sparse_embeddings = use_sparse_embeddings - self.sparse_embedding_field = sparse_embedding_field - - def point_to_document(self, point: QdrantPoint) -> Document: - payload = {**point.payload} - payload["score"] = point.score if hasattr(point, "score") else None - if not self.use_sparse_embeddings: - payload["embedding"] = point.vector if hasattr(point, "vector") else None - else: - if hasattr(point, "vector") and point.vector is not None: - payload["embedding"] = point.vector.get(DENSE_VECTORS_NAME) +def convert_qdrant_point_to_haystack_document(point: QdrantPoint, use_sparse_embeddings: bool) -> Document: + payload = {**point.payload} + payload["score"] = point.score if hasattr(point, "score") else None + + if not use_sparse_embeddings: + payload["embedding"] = point.vector if hasattr(point, "vector") else None + elif hasattr(point, "vector") and point.vector is not None: + payload["embedding"] = point.vector.get(DENSE_VECTORS_NAME) - if hasattr(point, "vector") and point.vector is not None and SPARSE_VECTORS_NAME in point.vector: - parse_vector_dict = { - "indices": point.vector[SPARSE_VECTORS_NAME].indices, - "values": point.vector[SPARSE_VECTORS_NAME].values, - } - payload["sparse_embedding"] = parse_vector_dict + if SPARSE_VECTORS_NAME in point.vector: + parse_vector_dict = { + "indices": point.vector[SPARSE_VECTORS_NAME].indices, + "values": point.vector[SPARSE_VECTORS_NAME].values, + } + payload["sparse_embedding"] = parse_vector_dict - return Document.from_dict(payload) + return Document.from_dict(payload) diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py index a32c970ce..8771a3515 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/document_store.py @@ -12,20 +12,23 @@ from haystack.document_stores.errors import DocumentStoreError, DuplicateDocumentError from haystack.document_stores.types import DuplicatePolicy from haystack.utils import Secret, deserialize_secrets_inplace -from haystack.utils.filters import convert +from haystack.utils.filters import convert as convert_legacy_filters from qdrant_client import grpc from qdrant_client.http import models as rest from qdrant_client.http.exceptions import UnexpectedResponse from tqdm import tqdm -from .converters import HaystackToQdrant, QdrantToHaystack -from .filters import QdrantFilterConverter +from .converters import ( + DENSE_VECTORS_NAME, + SPARSE_VECTORS_NAME, + convert_haystack_documents_to_qdrant_points, + convert_id, + convert_qdrant_point_to_haystack_document, +) +from .filters import convert_filters_to_qdrant logger = logging.getLogger(__name__) -DENSE_VECTORS_NAME = "text-dense" -SPARSE_VECTORS_NAME = "text-sparse" - class QdrantStoreError(DocumentStoreError): pass @@ -58,7 +61,7 @@ def __init__( url: Optional[str] = None, port: int = 6333, grpc_port: int = 6334, - prefer_grpc: bool = False, # noqa: FBT001, FBT002 + prefer_grpc: bool = False, https: Optional[bool] = None, api_key: Optional[Secret] = None, prefix: Optional[str] = None, @@ -67,17 +70,16 @@ def __init__( path: Optional[str] = None, index: str = "Document", embedding_dim: int = 768, - on_disk: bool = False, # noqa: FBT001, FBT002 + on_disk: bool = False, content_field: str = "content", name_field: str = "name", embedding_field: str = "embedding", - use_sparse_embeddings: bool = False, # noqa: FBT001, FBT002 - sparse_embedding_field: str = "sparse_embedding", + use_sparse_embeddings: bool = False, similarity: str = "cosine", - return_embedding: bool = False, # noqa: FBT001, FBT002 - progress_bar: bool = True, # noqa: FBT001, FBT002 + return_embedding: bool = False, + progress_bar: bool = True, duplicate_documents: str = "overwrite", - recreate_index: bool = False, # noqa: FBT001, FBT002 + recreate_index: bool = False, shard_number: Optional[int] = None, replication_factor: Optional[int] = None, write_consistency_factor: Optional[int] = None, @@ -87,7 +89,7 @@ def __init__( wal_config: Optional[dict] = None, quantization_config: Optional[dict] = None, init_from: Optional[dict] = None, - wait_result_from_api: bool = True, # noqa: FBT001, FBT002 + wait_result_from_api: bool = True, metadata: Optional[dict] = None, write_batch_size: int = 100, scroll_size: int = 10_000, @@ -151,17 +153,11 @@ def __init__( self.content_field = content_field self.name_field = name_field self.embedding_field = embedding_field - self.sparse_embedding_field = sparse_embedding_field self.similarity = similarity self.index = index self.return_embedding = return_embedding self.progress_bar = progress_bar self.duplicate_documents = duplicate_documents - self.qdrant_filter_converter = QdrantFilterConverter() - self.haystack_to_qdrant_converter = HaystackToQdrant() - self.qdrant_to_haystack = QdrantToHaystack( - content_field, name_field, embedding_field, use_sparse_embeddings, sparse_embedding_field - ) self.write_batch_size = write_batch_size self.scroll_size = scroll_size @@ -186,7 +182,7 @@ def filter_documents( raise ValueError(msg) if filters and "operator" not in filters: - filters = convert(filters) + filters = convert_legacy_filters(filters) return list( self.get_documents_generator( filters, @@ -217,11 +213,10 @@ def write_documents( batched_documents = get_batches_from_generator(document_objects, self.write_batch_size) with tqdm(total=len(document_objects), disable=not self.progress_bar) as progress_bar: for document_batch in batched_documents: - batch = self.haystack_to_qdrant_converter.documents_to_batch( + batch = convert_haystack_documents_to_qdrant_points( document_batch, embedding_field=self.embedding_field, use_sparse_embeddings=self.use_sparse_embeddings, - sparse_embedding_field=self.sparse_embedding_field, ) self.client.upsert( @@ -234,7 +229,7 @@ def write_documents( return len(document_objects) def delete_documents(self, ids: List[str]): - ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + ids = [convert_id(_id) for _id in ids] try: self.client.delete( collection_name=self.index, @@ -267,7 +262,7 @@ def get_documents_generator( filters: Optional[Dict[str, Any]] = None, ) -> Generator[Document, None, None]: index = self.index - qdrant_filters = self.qdrant_filter_converter.convert(filters) + qdrant_filters = convert_filters_to_qdrant(filters) next_offset = None stop_scrolling = False @@ -285,7 +280,9 @@ def get_documents_generator( ) for record in records: - yield self.qdrant_to_haystack.point_to_document(record) + yield convert_qdrant_point_to_haystack_document( + record, use_sparse_embeddings=self.use_sparse_embeddings + ) def get_documents_by_id( self, @@ -296,7 +293,7 @@ def get_documents_by_id( documents: List[Document] = [] - ids = [self.haystack_to_qdrant_converter.convert_id(_id) for _id in ids] + ids = [convert_id(_id) for _id in ids] records = self.client.retrieve( collection_name=index, ids=ids, @@ -305,7 +302,9 @@ def get_documents_by_id( ) for record in records: - documents.append(self.qdrant_to_haystack.point_to_document(record)) + documents.append( + convert_qdrant_point_to_haystack_document(record, use_sparse_embeddings=self.use_sparse_embeddings) + ) return documents def query_by_sparse( @@ -313,8 +312,8 @@ def query_by_sparse( query_sparse_embedding: SparseEmbedding, filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - scale_score: bool = True, # noqa: FBT001, FBT002 - return_embedding: bool = False, # noqa: FBT001, FBT002 + scale_score: bool = True, + return_embedding: bool = False, ) -> List[Document]: if not self.use_sparse_embeddings: message = ( @@ -323,7 +322,7 @@ def query_by_sparse( ) raise QdrantStoreError(message) - qdrant_filters = self.qdrant_filter_converter.convert(filters) + qdrant_filters = convert_filters_to_qdrant(filters) query_indices = query_sparse_embedding.indices query_values = query_sparse_embedding.values points = self.client.search( @@ -339,7 +338,10 @@ def query_by_sparse( limit=top_k, with_vectors=return_embedding, ) - results = [self.qdrant_to_haystack.point_to_document(point) for point in points] + results = [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for point in points + ] if scale_score: for document in results: score = document.score @@ -352,10 +354,10 @@ def query_by_embedding( query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: int = 10, - scale_score: bool = True, # noqa: FBT001, FBT002 - return_embedding: bool = False, # noqa: FBT001, FBT002 + scale_score: bool = True, + return_embedding: bool = False, ) -> List[Document]: - qdrant_filters = self.qdrant_filter_converter.convert(filters) + qdrant_filters = convert_filters_to_qdrant(filters) points = self.client.search( collection_name=self.index, @@ -367,7 +369,10 @@ def query_by_embedding( limit=top_k, with_vectors=return_embedding, ) - results = [self.qdrant_to_haystack.point_to_document(point) for point in points] + results = [ + convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=self.use_sparse_embeddings) + for point in points + ] if scale_score: for document in results: score = document.score @@ -406,10 +411,10 @@ def _set_up_collection( self, collection_name: str, embedding_dim: int, - recreate_collection: bool, # noqa: FBT001 + recreate_collection: bool, similarity: str, - use_sparse_embeddings: bool, # noqa: FBT001 - on_disk: bool = False, # noqa: FBT001, FBT002 + use_sparse_embeddings: bool, + on_disk: bool = False, payload_fields_to_index: Optional[List[dict]] = None, ): distance = self._get_distance(similarity) @@ -490,13 +495,15 @@ def _recreate_collection( collection_name: str, distance, embedding_dim: int, - on_disk: bool, # noqa: FBT001 - use_sparse_embeddings: bool, # noqa: FBT001 + on_disk: bool, + use_sparse_embeddings: bool, ): - dense_vectors_config = rest.VectorParams(size=embedding_dim, on_disk=on_disk, distance=distance) + # dense vectors configuration + vectors_config = rest.VectorParams(size=embedding_dim, on_disk=on_disk, distance=distance) if use_sparse_embeddings: - vectors_config = {DENSE_VECTORS_NAME: dense_vectors_config} + # in this case, we need to define named vectors + vectors_config = {DENSE_VECTORS_NAME: vectors_config} sparse_vectors_config = { SPARSE_VECTORS_NAME: rest.SparseVectorParams( @@ -505,8 +512,6 @@ def _recreate_collection( ) ), } - else: - vectors_config = dense_vectors_config self.client.recreate_collection( collection_name=collection_name, diff --git a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py index 72a74a8b1..c4387b1e5 100644 --- a/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py +++ b/integrations/qdrant/src/haystack_integrations/document_stores/qdrant/filters.py @@ -4,226 +4,230 @@ from haystack.utils.filters import COMPARISON_OPERATORS, LOGICAL_OPERATORS, FilterError from qdrant_client.http import models -from .converters import HaystackToQdrant +from .converters import convert_id COMPARISON_OPERATORS = COMPARISON_OPERATORS.keys() LOGICAL_OPERATORS = LOGICAL_OPERATORS.keys() -class QdrantFilterConverter: +def convert_filters_to_qdrant( + filter_term: Optional[Union[List[dict], dict]] = None, +) -> Optional[models.Filter]: """Converts Haystack filters to the format used by Qdrant.""" - def __init__(self): - self.haystack_to_qdrant_converter = HaystackToQdrant() + if not filter_term: + return None - def convert( - self, - filter_term: Optional[Union[List[dict], dict]] = None, - ) -> Optional[models.Filter]: - if not filter_term: - return None + must_clauses, should_clauses, must_not_clauses = [], [], [] - must_clauses, should_clauses, must_not_clauses = [], [], [] + if isinstance(filter_term, dict): + filter_term = [filter_term] - if isinstance(filter_term, dict): - filter_term = [filter_term] + for item in filter_term: + operator = item.get("operator") + if operator is None: + msg = "Operator not found in filters" + raise FilterError(msg) - for item in filter_term: - operator = item.get("operator") - if operator is None: - msg = "Operator not found in filters" - raise FilterError(msg) + if operator in LOGICAL_OPERATORS and "conditions" not in item: + msg = f"'conditions' not found for '{operator}'" + raise FilterError(msg) - if operator in LOGICAL_OPERATORS and "conditions" not in item: - msg = f"'conditions' not found for '{operator}'" + if operator == "AND": + must_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) + elif operator == "OR": + should_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) + elif operator == "NOT": + must_not_clauses.append(convert_filters_to_qdrant(item.get("conditions", []))) + elif operator in COMPARISON_OPERATORS: + field = item.get("field") + value = item.get("value") + if field is None or value is None: + msg = f"'field' or 'value' not found for '{operator}'" raise FilterError(msg) - if operator == "AND": - must_clauses.append(self.convert(item.get("conditions", []))) - elif operator == "OR": - should_clauses.append(self.convert(item.get("conditions", []))) - elif operator == "NOT": - must_not_clauses.append(self.convert(item.get("conditions", []))) - elif operator in COMPARISON_OPERATORS: - field = item.get("field") - value = item.get("value") - if field is None or value is None: - msg = f"'field' or 'value' not found for '{operator}'" - raise FilterError(msg) - - must_clauses.extend( - self._parse_comparison_operation(comparison_operation=operator, key=field, value=value) - ) - else: - msg = f"Unknown operator {operator} used in filters" - raise FilterError(msg) + must_clauses.extend(_parse_comparison_operation(comparison_operation=operator, key=field, value=value)) + else: + msg = f"Unknown operator {operator} used in filters" + raise FilterError(msg) - payload_filter = models.Filter( - must=must_clauses or None, - should=should_clauses or None, - must_not=must_not_clauses or None, - ) + payload_filter = models.Filter( + must=must_clauses or None, + should=should_clauses or None, + must_not=must_not_clauses or None, + ) - filter_result = self._squeeze_filter(payload_filter) + filter_result = _squeeze_filter(payload_filter) - return filter_result + return filter_result - def _parse_comparison_operation( - self, comparison_operation: str, key: str, value: Union[dict, List, str, float] - ) -> List[models.Condition]: - conditions: List[models.Condition] = [] - condition_builder_mapping = { - "==": self._build_eq_condition, - "in": self._build_in_condition, - "!=": self._build_ne_condition, - "not in": self._build_nin_condition, - ">": self._build_gt_condition, - ">=": self._build_gte_condition, - "<": self._build_lt_condition, - "<=": self._build_lte_condition, - } +def _parse_comparison_operation( + comparison_operation: str, key: str, value: Union[dict, List, str, float] +) -> List[models.Condition]: + conditions: List[models.Condition] = [] - condition_builder = condition_builder_mapping.get(comparison_operation) + condition_builder_mapping = { + "==": _build_eq_condition, + "in": _build_in_condition, + "!=": _build_ne_condition, + "not in": _build_nin_condition, + ">": _build_gt_condition, + ">=": _build_gte_condition, + "<": _build_lt_condition, + "<=": _build_lte_condition, + } - if condition_builder is None: - msg = f"Unknown operator {comparison_operation} used in filters" - raise ValueError(msg) + condition_builder = condition_builder_mapping.get(comparison_operation) - conditions.append(condition_builder(key, value)) + if condition_builder is None: + msg = f"Unknown operator {comparison_operation} used in filters" + raise ValueError(msg) - return conditions + conditions.append(condition_builder(key, value)) - def _build_eq_condition(self, key: str, value: models.ValueVariants) -> models.Condition: - if isinstance(value, str) and " " in value: - models.FieldCondition(key=key, match=models.MatchText(text=value)) - return models.FieldCondition(key=key, match=models.MatchValue(value=value)) + return conditions - def _build_in_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: - if not isinstance(value, list): - msg = f"Value {value} is not a list" - raise FilterError(msg) - return models.Filter( - should=[ - ( - models.FieldCondition(key=key, match=models.MatchText(text=item)) - if isinstance(item, str) and " " not in item - else models.FieldCondition(key=key, match=models.MatchValue(value=item)) - ) - for item in value - ] - ) - - def _build_ne_condition(self, key: str, value: models.ValueVariants) -> models.Condition: - return models.Filter( - must_not=[ - ( - models.FieldCondition(key=key, match=models.MatchText(text=value)) - if isinstance(value, str) and " " not in value - else models.FieldCondition(key=key, match=models.MatchValue(value=value)) - ) - ] - ) - - def _build_nin_condition(self, key: str, value: List[models.ValueVariants]) -> models.Condition: - if not isinstance(value, list): - msg = f"Value {value} is not a list" - raise FilterError(msg) - return models.Filter( - must_not=[ - ( - models.FieldCondition(key=key, match=models.MatchText(text=item)) - if isinstance(item, str) and " " in item - else models.FieldCondition(key=key, match=models.MatchValue(value=item)) - ) - for item in value - ] - ) - - def _build_lt_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(lt=value)) - - if isinstance(value, (int, float)): - return models.FieldCondition(key=key, range=models.Range(lt=value)) - - msg = f"Value {value} is not an int or float or datetime string" - raise FilterError(msg) - def _build_lte_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(lte=value)) +def _build_eq_condition(key: str, value: models.ValueVariants) -> models.Condition: + if isinstance(value, str) and " " in value: + models.FieldCondition(key=key, match=models.MatchText(text=value)) + return models.FieldCondition(key=key, match=models.MatchValue(value=value)) - if isinstance(value, (int, float)): - return models.FieldCondition(key=key, range=models.Range(lte=value)) - msg = f"Value {value} is not an int or float or datetime string" +def _build_in_condition(key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" + raise FilterError(msg) + return models.Filter( + should=[ + ( + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " not in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + ) + for item in value + ] + ) + + +def _build_ne_condition(key: str, value: models.ValueVariants) -> models.Condition: + return models.Filter( + must_not=[ + ( + models.FieldCondition(key=key, match=models.MatchText(text=value)) + if isinstance(value, str) and " " not in value + else models.FieldCondition(key=key, match=models.MatchValue(value=value)) + ) + ] + ) + + +def _build_nin_condition(key: str, value: List[models.ValueVariants]) -> models.Condition: + if not isinstance(value, list): + msg = f"Value {value} is not a list" raise FilterError(msg) + return models.Filter( + must_not=[ + ( + models.FieldCondition(key=key, match=models.MatchText(text=item)) + if isinstance(item, str) and " " in item + else models.FieldCondition(key=key, match=models.MatchValue(value=item)) + ) + for item in value + ] + ) - def _build_gt_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(gt=value)) - if isinstance(value, (int, float)): - return models.FieldCondition(key=key, range=models.Range(gt=value)) +def _build_lt_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(lt=value)) - msg = f"Value {value} is not an int or float or datetime string" - raise FilterError(msg) + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(lt=value)) - def _build_gte_condition(self, key: str, value: Union[str, float, int]) -> models.Condition: - if isinstance(value, str) and is_datetime_string(value): - return models.FieldCondition(key=key, range=models.DatetimeRange(gte=value)) + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) - if isinstance(value, (int, float)): - return models.FieldCondition(key=key, range=models.Range(gte=value)) - msg = f"Value {value} is not an int or float or datetime string" - raise FilterError(msg) +def _build_lte_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(lte=value)) + + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(lte=value)) + + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) + - def _build_has_id_condition(self, id_values: List[models.ExtendedPointId]) -> models.HasIdCondition: - return models.HasIdCondition( - has_id=[ - # Ids are converted into their internal representation - self.haystack_to_qdrant_converter.convert_id(item) - for item in id_values - ] - ) - - def _squeeze_filter(self, payload_filter: models.Filter) -> models.Filter: - """ - Simplify given payload filter, if the nested structure might be unnested. - That happens if there is a single clause in that filter. - :param payload_filter: - :returns: - """ - filter_parts = { - "must": payload_filter.must, - "should": payload_filter.should, - "must_not": payload_filter.must_not, - } - - total_clauses = sum(len(x) for x in filter_parts.values() if x is not None) - if total_clauses == 0 or total_clauses > 1: - return payload_filter - - # Payload filter has just a single clause provided (either must, should - # or must_not). If that single clause is also of a models.Filter type, - # then it might be returned instead. - for part_name, filter_part in filter_parts.items(): - if not filter_part: - continue - - subfilter = filter_part[0] - if not isinstance(subfilter, models.Filter): - # The inner statement is a simple condition like models.FieldCondition - # so it cannot be simplified. - continue - - if subfilter.must: - return models.Filter(**{part_name: subfilter.must}) +def _build_gt_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(gt=value)) + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(gt=value)) + + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) + + +def _build_gte_condition(key: str, value: Union[str, float, int]) -> models.Condition: + if isinstance(value, str) and is_datetime_string(value): + return models.FieldCondition(key=key, range=models.DatetimeRange(gte=value)) + + if isinstance(value, (int, float)): + return models.FieldCondition(key=key, range=models.Range(gte=value)) + + msg = f"Value {value} is not an int or float or datetime string" + raise FilterError(msg) + + +def _build_has_id_condition(id_values: List[models.ExtendedPointId]) -> models.HasIdCondition: + return models.HasIdCondition( + has_id=[ + # Ids are converted into their internal representation + convert_id(item) + for item in id_values + ] + ) + + +def _squeeze_filter(payload_filter: models.Filter) -> models.Filter: + """ + Simplify given payload filter, if the nested structure might be unnested. + That happens if there is a single clause in that filter. + :param payload_filter: + :returns: + """ + filter_parts = { + "must": payload_filter.must, + "should": payload_filter.should, + "must_not": payload_filter.must_not, + } + + total_clauses = sum(len(x) for x in filter_parts.values() if x is not None) + if total_clauses == 0 or total_clauses > 1: return payload_filter + # Payload filter has just a single clause provided (either must, should + # or must_not). If that single clause is also of a models.Filter type, + # then it might be returned instead. + for part_name, filter_part in filter_parts.items(): + if not filter_part: + continue + + subfilter = filter_part[0] + if not isinstance(subfilter, models.Filter): + # The inner statement is a simple condition like models.FieldCondition + # so it cannot be simplified. + continue + + if subfilter.must: + return models.Filter(**{part_name: subfilter.must}) + + return payload_filter + def is_datetime_string(value: str) -> bool: try: diff --git a/integrations/qdrant/tests/test_converters.py b/integrations/qdrant/tests/test_converters.py index fd9d5f3ad..242c4cafe 100644 --- a/integrations/qdrant/tests/test_converters.py +++ b/integrations/qdrant/tests/test_converters.py @@ -1,40 +1,19 @@ import numpy as np -import pytest -from haystack_integrations.document_stores.qdrant.converters import HaystackToQdrant, QdrantToHaystack +from haystack_integrations.document_stores.qdrant.converters import ( + convert_id, + convert_qdrant_point_to_haystack_document, +) from qdrant_client.http import models as rest -CONTENT_FIELD = "content" -NAME_FIELD = "name" -EMBEDDING_FIELD = "vector" -SPARSE_EMBEDDING_FIELD = "sparse-vector" - -@pytest.fixture -def haystack_to_qdrant() -> HaystackToQdrant: - return HaystackToQdrant() - - -@pytest.fixture -def qdrant_to_haystack(request) -> QdrantToHaystack: - return QdrantToHaystack( - content_field=CONTENT_FIELD, - name_field=NAME_FIELD, - embedding_field=EMBEDDING_FIELD, - use_sparse_embeddings=request.param, - sparse_embedding_field=SPARSE_EMBEDDING_FIELD, - ) - - -def test_convert_id_is_deterministic(haystack_to_qdrant: HaystackToQdrant): - first_id = haystack_to_qdrant.convert_id("test-id") - second_id = haystack_to_qdrant.convert_id("test-id") +def test_convert_id_is_deterministic(): + first_id = convert_id("test-id") + second_id = convert_id("test-id") assert first_id == second_id -@pytest.mark.parametrize("qdrant_to_haystack", [True], indirect=True) -def test_point_to_document_reverts_proper_structure_from_record_with_sparse( - qdrant_to_haystack: QdrantToHaystack, -): +def test_point_to_document_reverts_proper_structure_from_record_with_sparse(): + point = rest.Record( id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", payload={ @@ -51,7 +30,7 @@ def test_point_to_document_reverts_proper_structure_from_record_with_sparse( "text-sparse": {"indices": [7, 1024, 367], "values": [0.1, 0.98, 0.33]}, }, ) - document = qdrant_to_haystack.point_to_document(point) + document = convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=True) assert "my-id" == document.id assert "Lorem ipsum" == document.content assert "text" == document.content_type @@ -60,10 +39,8 @@ def test_point_to_document_reverts_proper_structure_from_record_with_sparse( assert 0.0 == np.sum(np.array([1.0, 0.0, 0.0, 0.0]) - document.embedding) -@pytest.mark.parametrize("qdrant_to_haystack", [False], indirect=True) -def test_point_to_document_reverts_proper_structure_from_record_without_sparse( - qdrant_to_haystack: QdrantToHaystack, -): +def test_point_to_document_reverts_proper_structure_from_record_without_sparse(): + point = rest.Record( id="c7c62e8e-02b9-4ec6-9f88-46bd97b628b7", payload={ @@ -77,7 +54,7 @@ def test_point_to_document_reverts_proper_structure_from_record_without_sparse( }, vector=[1.0, 0.0, 0.0, 0.0], ) - document = qdrant_to_haystack.point_to_document(point) + document = convert_qdrant_point_to_haystack_document(point, use_sparse_embeddings=False) assert "my-id" == document.id assert "Lorem ipsum" == document.content assert "text" == document.content_type diff --git a/integrations/qdrant/tests/test_dict_converters.py b/integrations/qdrant/tests/test_dict_converters.py index 7e94fa083..6c8e46710 100644 --- a/integrations/qdrant/tests/test_dict_converters.py +++ b/integrations/qdrant/tests/test_dict_converters.py @@ -26,7 +26,6 @@ def test_to_dict(): "name_field": "name", "embedding_field": "embedding", "use_sparse_embeddings": False, - "sparse_embedding_field": "sparse_embedding", "similarity": "cosine", "return_embedding": False, "progress_bar": True, @@ -66,7 +65,6 @@ def test_from_dict(): "name_field": "name", "embedding_field": "embedding", "use_sparse_embeddings": True, - "sparse_embedding_field": "sparse_embedding", "similarity": "cosine", "return_embedding": False, "progress_bar": True, @@ -91,7 +89,6 @@ def test_from_dict(): document_store.name_field == "name", document_store.embedding_field == "embedding", document_store.use_sparse_embeddings is True, - document_store.sparse_embedding_field == "sparse_embedding", document_store.on_disk is False, document_store.similarity == "cosine", document_store.return_embedding is False, diff --git a/integrations/qdrant/tests/test_retriever.py b/integrations/qdrant/tests/test_retriever.py index 0f9452143..96e748220 100644 --- a/integrations/qdrant/tests/test_retriever.py +++ b/integrations/qdrant/tests/test_retriever.py @@ -50,7 +50,6 @@ def test_to_dict(self): "name_field": "name", "embedding_field": "embedding", "use_sparse_embeddings": False, - "sparse_embedding_field": "sparse_embedding", "similarity": "cosine", "return_embedding": False, "progress_bar": True, @@ -173,7 +172,6 @@ def test_to_dict(self): "name_field": "name", "embedding_field": "embedding", "use_sparse_embeddings": False, - "sparse_embedding_field": "sparse_embedding", "similarity": "cosine", "return_embedding": False, "progress_bar": True,