diff --git a/src/milvus_haystack/document_store.py b/src/milvus_haystack/document_store.py index 66854be..20ae196 100644 --- a/src/milvus_haystack/document_store.py +++ b/src/milvus_haystack/document_store.py @@ -10,8 +10,21 @@ from haystack.document_stores.types import DuplicatePolicy from haystack.errors import FilterError from haystack.utils import Secret, deserialize_secrets_inplace -from pymilvus import AnnSearchRequest, MilvusException, RRFRanker +from pymilvus import ( + AnnSearchRequest, + Collection, + CollectionSchema, + DataType, + FieldSchema, + MilvusClient, + MilvusException, + RRFRanker, + connections, + utility, +) from pymilvus.client.abstract import BaseRanker +from pymilvus.client.types import LoadState +from pymilvus.orm.types import infer_dtype_bydata from milvus_haystack.filters import parse_filters @@ -122,17 +135,9 @@ def __init__( :param replica_number: Number of replicas. Defaults to 1. :param timeout: Timeout in seconds. Defaults to None. """ - try: - from pymilvus import Collection, utility - except ImportError as err: - err_msg = "Could not import pymilvus python package. Please install it with `pip install pymilvus`." - raise ValueError(err_msg) from err - # Default search params when one is not provided. self.default_search_params = { - "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, - "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, - "GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}}, + "FLAT": {"metric_type": "L2", "params": {}}, "IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, "IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}}, "IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, @@ -142,7 +147,16 @@ def __init__( "RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}}, "IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}}, "ANNOY": {"metric_type": "L2", "params": {"search_k": 10}}, + "SCANN": {"metric_type": "L2", "params": {"search_k": 10}}, "AUTOINDEX": {"metric_type": "L2", "params": {}}, + "GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}}, + "GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}}, + "GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}}, + "SPARSE_INVERTED_INDEX": { + "metric_type": "IP", + "params": {"drop_ratio_build": 0.2}, + }, + "SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}}, } self.collection_name = collection_name @@ -169,6 +183,9 @@ def __init__( # Create the connection to the server if connection_args is None: self.connection_args = DEFAULT_MILVUS_CONNECTION + self._milvus_client = MilvusClient( + **self.connection_args, + ) self.alias = self._create_connection_alias(self.connection_args) # type: ignore[arg-type] self.col: Optional[Collection] = None @@ -193,6 +210,11 @@ def __init__( ) self._dummy_value = 999.0 + @property + def client(self) -> MilvusClient: + """Get client.""" + return self._milvus_client + def count_documents(self) -> int: """ Returns how many documents are present in the document store. @@ -311,8 +333,6 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D :return: Number of documents written. """ - from pymilvus import Collection, MilvusException - documents_cp = [MilvusDocumentStore._discard_invalid_meta(doc) for doc in deepcopy(documents)] if len(documents_cp) > 0 and not isinstance(documents_cp[0], Document): err_msg = "param 'documents' must contain a list of objects of type Document" @@ -484,8 +504,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "MilvusDocumentStore": def _create_connection_alias(self, connection_args: dict) -> str: """Create the connection to the Milvus server.""" - from pymilvus import MilvusException, connections - connection_args_cp = copy.deepcopy(connection_args) # Grab the connection arguments that are used for checking existing connection host: str = connection_args_cp.get("host", None) @@ -568,15 +586,6 @@ def _init( ) def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = None) -> None: - from pymilvus import ( - Collection, - CollectionSchema, - DataType, - FieldSchema, - MilvusException, - ) - from pymilvus.orm.types import infer_dtype_bydata - # Determine embedding dim dim = len(embeddings[0]) fields = [] @@ -630,8 +639,6 @@ def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = Non def _extract_fields(self) -> None: """Grab the existing fields from the Collection""" - from pymilvus import Collection - if isinstance(self.col, Collection): schema = self.col.schema for x in schema.fields: @@ -639,8 +646,6 @@ def _extract_fields(self) -> None: def _create_index(self) -> None: """Create an index on the collection""" - from pymilvus import Collection, MilvusException - if isinstance(self.col, Collection) and self._get_index() is None: try: # If no index params, use a default HNSW based one @@ -694,8 +699,6 @@ def _create_index(self) -> None: def _create_search_params(self) -> None: """Generate search params based on the current index type""" - from pymilvus import Collection - if isinstance(self.col, Collection) and self.search_params is None: index = self._get_index() if index is not None: @@ -706,8 +709,6 @@ def _create_search_params(self) -> None: def _get_index(self) -> Optional[Dict[str, Any]]: """Return the vector index information if it exists""" - from pymilvus import Collection - if isinstance(self.col, Collection): for x in self.col.indexes: if x.field_name == self._vector_field: @@ -721,9 +722,6 @@ def _load( timeout: Optional[float] = None, ) -> None: """Load the collection if available.""" - from pymilvus import Collection, utility - from pymilvus.client.types import LoadState - if ( isinstance(self.col, Collection) and self._get_index() is not None @@ -901,6 +899,8 @@ def _map_ip_to_similarity(ip_score: float) -> float: """ return (ip_score + 1) / 2.0 + if not self.index_params: + return lambda x: x metric_type = self.index_params.get("metric_type", None) if metric_type == "L2": return _map_l2_to_similarity @@ -942,9 +942,6 @@ def _discard_invalid_meta(document: Document): """ Remove metadata fields with unsupported types from the document. """ - from pymilvus import DataType - from pymilvus.orm.types import infer_dtype_bydata - if not isinstance(document, Document): msg = f"Invalid document type: {type(document)}" raise ValueError(msg) diff --git a/tests/test_document_store.py b/tests/test_document_store.py index 9c07963..4609a5c 100644 --- a/tests/test_document_store.py +++ b/tests/test_document_store.py @@ -79,6 +79,6 @@ def test_to_and_from_dict(self, document_store: MilvusDocumentStore): assert document_store_dict == expected_dict reconstructed_document_store = MilvusDocumentStore.from_dict(document_store_dict) for field in vars(reconstructed_document_store): - if field.startswith("__") or field == "alias": + if field.startswith("__") or field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_document_store, field) == getattr(document_store, field) diff --git a/tests/test_embedding_retriever.py b/tests/test_embedding_retriever.py index 0dbd9e9..bf40272 100644 --- a/tests/test_embedding_retriever.py +++ b/tests/test_embedding_retriever.py @@ -148,7 +148,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore): continue elif field == "document_store": for doc_store_field in vars(document_store): - if doc_store_field.startswith("__") or doc_store_field == "alias": + if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( document_store, doc_store_field @@ -286,7 +286,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore): continue elif field == "document_store": for doc_store_field in vars(document_store): - if doc_store_field.startswith("__") or doc_store_field == "alias": + if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( document_store, doc_store_field @@ -433,7 +433,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore): continue elif field == "document_store": for doc_store_field in vars(document_store): - if doc_store_field.startswith("__") or doc_store_field == "alias": + if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]: continue assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr( document_store, doc_store_field