From 8ef71cdd57f35e53688859a92cb52cabd5fb3cfb Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 19 Dec 2024 15:41:13 +0800 Subject: [PATCH 1/5] Check when self.index_params is None Signed-off-by: ChengZi --- src/milvus_haystack/document_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index 66854be..f5b5657 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -901,6 +901,8 @@ def _map_ip_to_similarity(ip_score: float) -> float: """ return (ip_score + 1) / 2.0 + if not self.index_params: + return lambda x: x metric_type = self.index_params.get("metric_type", None) if metric_type == "L2": return _map_l2_to_similarity From 465a27bcefb46edbefa62987c844e0d7a3bdfe11 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 19 Dec 2024 15:48:35 +0800 Subject: [PATCH 2/5] re-organize import code Signed-off-by: ChengZi --- src/milvus_haystack/document_store.py | 47 ++++++++------------------- 1 file changed, 13 insertions(+), 34 deletions(-) diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index f5b5657..8b1612f 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -10,8 +10,20 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.errors import FilterError from haystack.utils import Secret, deserialize_secrets_inplace -from pymilvus import AnnSearchRequest, MilvusException, RRFRanker +from pymilvus import ( + AnnSearchRequest, + Collection, + CollectionSchema, + DataType, + FieldSchema, + MilvusException, + RRFRanker, + connections, + utility, +) from pymilvus.client.abstract import BaseRanker +from pymilvus.client.types import LoadState +from pymilvus.orm.types import infer_dtype_bydata from milvus_haystack.filters import parse_filters @@ -122,12 +134,6 @@ def __init__( :param replica_number: Number of replicas. Defaults to 1. :param timeout: Timeout in seconds. Defaults to None. """ - try: - from pymilvus import Collection, utility - except ImportError as err: - err_msg = "Could not import pymilvus python package. Please install it with `pip install pymilvus`." - raise ValueError(err_msg) from err - # Default search params when one is not provided. self.default_search_params = { "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, @@ -311,8 +317,6 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :return: Number of documents written. """ - from pymilvus import Collection, MilvusException - documents_cp = [MilvusDocumentStore._discard_invalid_meta(doc) for doc in deepcopy(documents)] if len(documents_cp) > 0 and not isinstance(documents_cp[0], Document): err_msg = "param 'documents' must contain a list of objects of type Document" @@ -484,8 +488,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "MilvusDocumentStore": def _create_connection_alias(self, connection_args: dict) -> str: """Create the connection to the Milvus server.""" - from pymilvus import MilvusException, connections - connection_args_cp = copy.deepcopy(connection_args) # Grab the connection arguments that are used for checking existing connection host: str = connection_args_cp.get("host", None) @@ -568,15 +570,6 @@ def _init( ) def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = None) -> None: - from pymilvus import ( - Collection, - CollectionSchema, - DataType, - FieldSchema, - MilvusException, - ) - from pymilvus.orm.types import infer_dtype_bydata - # Determine embedding dim dim = len(embeddings[0]) fields = [] @@ -630,8 +623,6 @@ def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = Non def _extract_fields(self) -> None: """Grab the existing fields from the Collection""" - from pymilvus import Collection - if isinstance(self.col, Collection): schema = self.col.schema for x in schema.fields: @@ -639,8 +630,6 @@ def _extract_fields(self) -> None: def _create_index(self) -> None: """Create an index on the collection""" - from pymilvus import Collection, MilvusException - if isinstance(self.col, Collection) and self._get_index() is None: try: # If no index params, use a default HNSW based one @@ -694,8 +683,6 @@ def _create_index(self) -> None: def _create_search_params(self) -> None: """Generate search params based on the current index type""" - from pymilvus import Collection - if isinstance(self.col, Collection) and self.search_params is None: index = self._get_index() if index is not None: @@ -706,8 +693,6 @@ def _create_search_params(self) -> None: def _get_index(self) -> Optional[Dict[str, Any]]: """Return the vector index information if it exists""" - from pymilvus import Collection - if isinstance(self.col, Collection): for x in self.col.indexes: if x.field_name == self._vector_field: @@ -721,9 +706,6 @@ def _load( timeout: Optional[float] = None, ) -> None: """Load the collection if available.""" - from pymilvus import Collection, utility - from pymilvus.client.types import LoadState - if ( isinstance(self.col, Collection) and self._get_index() is not None @@ -944,9 +926,6 @@ def _discard_invalid_meta(document: Document): """ Remove metadata fields with unsupported types from the document. """ - from pymilvus import DataType - from pymilvus.orm.types import infer_dtype_bydata - if not isinstance(document, Document): msg = f"Invalid document type: {type(document)}" raise ValueError(msg) From 7b6c07617f9a3ddc6184933ac15a00bc51a0b25c Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 19 Dec 2024 15:53:22 +0800 Subject: [PATCH 3/5] introduce milvus client Signed-off-by: ChengZi --- src/milvus_haystack/document_store.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index 8b1612f..2550175 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -16,6 +16,7 @@ CollectionSchema, DataType, FieldSchema, + MilvusClient, MilvusException, RRFRanker, connections, @@ -175,6 +176,9 @@ def __init__( # Create the connection to the server if connection_args is None: self.connection_args = DEFAULT_MILVUS_CONNECTION + self._milvus_client = MilvusClient( + **self.connection_args, + ) self.alias = self._create_connection_alias(self.connection_args) # type: ignore[arg-type] self.col: Optional[Collection] = None @@ -199,6 +203,11 @@ def __init__( ) self._dummy_value = 999.0 + @property + def client(self) -> MilvusClient: + """Get client.""" + return self._milvus_client + def count_documents(self) -> int: """ Returns how many documents are present in the document store. From 0f6821574d1a9c136c56e8c8f7c8197897b0f2b2 Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 19 Dec 2024 15:57:35 +0800 Subject: [PATCH 4/5] enrich default_search_params Signed-off-by: ChengZi --- src/milvus_haystack/document_store.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index 2550175..20ae196 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -137,9 +137,7 @@ def __init__( """ # Default search params when one is not provided. self.default_search_params = { - "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, - "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, - "GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}}, + "FLAT": {"metric_type": "L2", "params": {}}, "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, @@ -149,7 +147,16 @@ def __init__( "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, + "SCANN": {"metric_type": "L2", "params": {"search_k": 10}}, "AUTOINDEX": {"metric_type": "L2", "params": {}}, + "GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}}, + "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, + "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "SPARSE_INVERTED_INDEX": { + "metric_type": "IP", + "params": {"drop_ratio_build": 0.2}, + }, + "SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}}, } self.collection_name = collection_name From daf0ee41b1b33feb4ce36ea575b1afe39d75ef9f Mon Sep 17 00:00:00 2001 From: Cheney Zhang Date: Thu, 19 Dec 2024 17:16:14 +0800 Subject: [PATCH 5/5] fix persistence unittest Signed-off-by: ChengZi --- tests/test_document_store.py | 2 +- tests/test_embedding_retriever.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_document_store.py b/tests/test_document_store.py index 9c07963..4609a5c 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -79,6 +79,6 @@ def test_to_and_from_dict(self, document_store: MilvusDocumentStore): assert document_store_dict == expected_dict reconstructed_document_store = MilvusDocumentStore.from_dict(document_store_dict) for field in vars(reconstructed_document_store): - if field.startswith("__") or field == "alias": + if field.startswith("__") or field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_document_store, field) == getattr(document_store, field) diff --git a/tests/test_embedding_retriever.py b/tests/test_embedding_retriever.py index 0dbd9e9..bf40272 100644 --- a/tests/test_embedding_retriever.py +++ b/tests/test_embedding_retriever.py @@ -148,7 +148,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore): continue elif field == "document_store": for doc_store_field in vars(document_store): - if doc_store_field.startswith("__") or doc_store_field == "alias": + if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( document_store, doc_store_field @@ -286,7 +286,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore): continue elif field == "document_store": for doc_store_field in vars(document_store): - if doc_store_field.startswith("__") or doc_store_field == "alias": + if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( document_store, doc_store_field @@ -433,7 +433,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore): continue elif field == "document_store": for doc_store_field in vars(document_store): - if doc_store_field.startswith("__") or doc_store_field == "alias": + if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( document_store, doc_store_field