Skip to content

Commit

Permalink
Merge pull request #40 from zc277584121/main
Browse files Browse the repository at this point in the history
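Optimize code and fix bug: move the pymilvus imports to module level, create a MilvusClient in the document store and expose it through a `client` property, extend the default search params, and fall back to an identity score mapping when no index params are set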
  • Loading branch information
zc277584121 authored Dec 19, 2024
2 parents d6c8ad1 + daf0ee4 commit f157ccb
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 41 deletions.
71 changes: 34 additions & 37 deletions src/milvus_haystack/document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,21 @@
from haystack.document_stores.types import DuplicatePolicy
from haystack.errors import FilterError
from haystack.utils import Secret, deserialize_secrets_inplace
from pymilvus import AnnSearchRequest, MilvusException, RRFRanker
from pymilvus import (
AnnSearchRequest,
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusClient,
MilvusException,
RRFRanker,
connections,
utility,
)
from pymilvus.client.abstract import BaseRanker
from pymilvus.client.types import LoadState
from pymilvus.orm.types import infer_dtype_bydata

from milvus_haystack.filters import parse_filters

Expand Down Expand Up @@ -122,17 +135,9 @@ def __init__(
:param replica_number: Number of replicas. Defaults to 1.
:param timeout: Timeout in seconds. Defaults to None.
"""
try:
from pymilvus import Collection, utility
except ImportError as err:
err_msg = "Could not import pymilvus python package. Please install it with `pip install pymilvus`."
raise ValueError(err_msg) from err

# Default search params when one is not provided.
self.default_search_params = {
"GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}},
"FLAT": {"metric_type": "L2", "params": {}},
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
Expand All @@ -142,7 +147,16 @@ def __init__(
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
"SCANN": {"metric_type": "L2", "params": {"search_k": 10}},
"AUTOINDEX": {"metric_type": "L2", "params": {}},
"GPU_CAGRA": {"metric_type": "L2", "params": {"itopk_size": 128}},
"GPU_IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"GPU_IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"SPARSE_INVERTED_INDEX": {
"metric_type": "IP",
"params": {"drop_ratio_build": 0.2},
},
"SPARSE_WAND": {"metric_type": "IP", "params": {"drop_ratio_build": 0.2}},
}

self.collection_name = collection_name
Expand All @@ -169,6 +183,9 @@ def __init__(
# Create the connection to the server
if connection_args is None:
self.connection_args = DEFAULT_MILVUS_CONNECTION
self._milvus_client = MilvusClient(
**self.connection_args,
)
self.alias = self._create_connection_alias(self.connection_args) # type: ignore[arg-type]
self.col: Optional[Collection] = None

Expand All @@ -193,6 +210,11 @@ def __init__(
)
self._dummy_value = 999.0

@property
def client(self) -> MilvusClient:
"""
Return the ``MilvusClient`` held by this document store.

This is the ``self._milvus_client`` instance that ``__init__`` builds from
``self.connection_args``.

:return: The stored ``MilvusClient`` instance.
"""
return self._milvus_client

def count_documents(self) -> int:
"""
Returns how many documents are present in the document store.
Expand Down Expand Up @@ -311,8 +333,6 @@ def write_documents(self, documents: List[Document], policy: DuplicatePolicy = D
:return: Number of documents written.
"""

from pymilvus import Collection, MilvusException

documents_cp = [MilvusDocumentStore._discard_invalid_meta(doc) for doc in deepcopy(documents)]
if len(documents_cp) > 0 and not isinstance(documents_cp[0], Document):
err_msg = "param 'documents' must contain a list of objects of type Document"
Expand Down Expand Up @@ -484,8 +504,6 @@ def from_dict(cls, data: Dict[str, Any]) -> "MilvusDocumentStore":

def _create_connection_alias(self, connection_args: dict) -> str:
"""Create the connection to the Milvus server."""
from pymilvus import MilvusException, connections

connection_args_cp = copy.deepcopy(connection_args)
# Grab the connection arguments that are used for checking existing connection
host: str = connection_args_cp.get("host", None)
Expand Down Expand Up @@ -568,15 +586,6 @@ def _init(
)

def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = None) -> None:
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusException,
)
from pymilvus.orm.types import infer_dtype_bydata

# Determine embedding dim
dim = len(embeddings[0])
fields = []
Expand Down Expand Up @@ -630,17 +639,13 @@ def _create_collection(self, embeddings: list, metas: Optional[List[Dict]] = Non

def _extract_fields(self) -> None:
"""
Append the name of every field in the collection schema to ``self.fields``.

Does nothing when ``self.col`` is not a ``Collection`` instance. Names are
appended, so entries already present in ``self.fields`` are kept.
"""
from pymilvus import Collection

# Only a real Collection exposes a schema to read the field names from.
if isinstance(self.col, Collection):
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)

def _create_index(self) -> None:
"""Create an index on the collection"""
from pymilvus import Collection, MilvusException

if isinstance(self.col, Collection) and self._get_index() is None:
try:
# If no index params, use a default HNSW based one
Expand Down Expand Up @@ -694,8 +699,6 @@ def _create_index(self) -> None:

def _create_search_params(self) -> None:
"""Generate search params based on the current index type"""
from pymilvus import Collection

if isinstance(self.col, Collection) and self.search_params is None:
index = self._get_index()
if index is not None:
Expand All @@ -706,8 +709,6 @@ def _create_search_params(self) -> None:

def _get_index(self) -> Optional[Dict[str, Any]]:
"""Return the vector index information if it exists"""
from pymilvus import Collection

if isinstance(self.col, Collection):
for x in self.col.indexes:
if x.field_name == self._vector_field:
Expand All @@ -721,9 +722,6 @@ def _load(
timeout: Optional[float] = None,
) -> None:
"""Load the collection if available."""
from pymilvus import Collection, utility
from pymilvus.client.types import LoadState

if (
isinstance(self.col, Collection)
and self._get_index() is not None
Expand Down Expand Up @@ -901,6 +899,8 @@ def _map_ip_to_similarity(ip_score: float) -> float:
"""
return (ip_score + 1) / 2.0

if not self.index_params:
return lambda x: x
metric_type = self.index_params.get("metric_type", None)
if metric_type == "L2":
return _map_l2_to_similarity
Expand Down Expand Up @@ -942,9 +942,6 @@ def _discard_invalid_meta(document: Document):
"""
Remove metadata fields with unsupported types from the document.
"""
from pymilvus import DataType
from pymilvus.orm.types import infer_dtype_bydata

if not isinstance(document, Document):
msg = f"Invalid document type: {type(document)}"
raise ValueError(msg)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ def test_to_and_from_dict(self, document_store: MilvusDocumentStore):
assert document_store_dict == expected_dict
reconstructed_document_store = MilvusDocumentStore.from_dict(document_store_dict)
for field in vars(reconstructed_document_store):
if field.startswith("__") or field == "alias":
if field.startswith("__") or field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_document_store, field) == getattr(document_store, field)
6 changes: 3 additions & 3 deletions tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore):
continue
elif field == "document_store":
for doc_store_field in vars(document_store):
if doc_store_field.startswith("__") or doc_store_field == "alias":
if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr(
document_store, doc_store_field
Expand Down Expand Up @@ -286,7 +286,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore):
continue
elif field == "document_store":
for doc_store_field in vars(document_store):
if doc_store_field.startswith("__") or doc_store_field == "alias":
if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr(
document_store, doc_store_field
Expand Down Expand Up @@ -433,7 +433,7 @@ def test_from_dict(self, document_store: MilvusDocumentStore):
continue
elif field == "document_store":
for doc_store_field in vars(document_store):
if doc_store_field.startswith("__") or doc_store_field == "alias":
if doc_store_field.startswith("__") or doc_store_field in ["alias", "_milvus_client"]:
continue
assert getattr(reconstructed_retriever.document_store, doc_store_field) == getattr(
document_store, doc_store_field
Expand Down

0 comments on commit f157ccb

Please sign in to comment.