Skip to content

Commit

Permalink
globally import (#24)
Browse files Browse the repository at this point in the history
Signed-off-by: ChengZi <[email protected]>
  • Loading branch information
zc277584121 authored Nov 11, 2024
1 parent 7a408b7 commit 6d81c3e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 37 deletions.
49 changes: 13 additions & 36 deletions libs/milvus/langchain_milvus/vectorstores/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,19 @@
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pymilvus import MilvusClient, RRFRanker, WeightedRanker
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusClient,
MilvusException,
RRFRanker,
WeightedRanker,
utility,
)
from pymilvus.client.types import LoadState # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore

from langchain_milvus import MilvusCollectionHybridSearchRetriever
from langchain_milvus.utils.sparse import BaseSparseEmbedding
Expand Down Expand Up @@ -280,14 +292,6 @@ def __init__(
metadata_schema: Optional[dict[str, Any]] = None,
):
"""Initialize the Milvus vector store."""
try:
from pymilvus import Collection, MilvusClient, utility
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)

# Default search params when one is not provided.
self.default_search_params = {
"FLAT": {"metric_type": "L2", "params": {}},
Expand Down Expand Up @@ -451,15 +455,6 @@ def _init(
def _create_collection(
self, embeddings: List[list], metadatas: Optional[list[dict]] = None
) -> None:
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusException,
)
from pymilvus.orm.types import infer_dtype_bydata # type: ignore

fields = []
vector_fields: List[str] = self._as_list(self._vector_field)
# If enable_dynamic_field, we don't need to create fields, and just pass it.
Expand Down Expand Up @@ -632,8 +627,6 @@ def _create_collection(
raise e

def _get_field_schema_from_dict(self, field_name: str, schema_dict: dict): # type: ignore[no-untyped-def]
from pymilvus import FieldSchema

assert "dtype" in schema_dict, (
f"Please provide `dtype` in the schema dict. "
f"Existing keys are: {schema_dict.keys()}"
Expand All @@ -645,17 +638,13 @@ def _get_field_schema_from_dict(self, field_name: str, schema_dict: dict): # ty

def _extract_fields(self) -> None:
"""Grab the existing fields from the Collection"""
from pymilvus import Collection

if isinstance(self.col, Collection):
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)

def _get_index(self, field_name: Optional[str] = None) -> Optional[dict[str, Any]]:
"""Return the vector index information if it exists"""
from pymilvus import Collection

if not self._is_multi_vector:
field_name: str = field_name or self._vector_field # type: ignore

Expand All @@ -667,8 +656,6 @@ def _get_index(self, field_name: Optional[str] = None) -> Optional[dict[str, Any

def _create_index(self) -> None:
"""Create an index on the collection"""
from pymilvus import Collection, MilvusException

if isinstance(self.col, Collection) and self._get_index() is None:
embeddings_functions: List[EmbeddingType] = self._as_list(
self.embedding_func
Expand Down Expand Up @@ -740,8 +727,6 @@ def _create_search_params(self) -> None:
"""Generate search params based on the current index type"""
import copy

from pymilvus import Collection

if isinstance(self.col, Collection) and self.search_params is None:
vector_fields: List[str] = self._as_list(self._vector_field)
search_params_list: List[dict] = []
Expand All @@ -768,9 +753,6 @@ def _load(
timeout: Optional[float] = None,
) -> None:
"""Load the collection if available."""
from pymilvus import Collection, utility
from pymilvus.client.types import LoadState # type: ignore

timeout = self.timeout or timeout
if (
isinstance(self.col, Collection)
Expand Down Expand Up @@ -934,7 +916,6 @@ def add_embeddings(
Returns:
List[str]: The resulting keys for each inserted element.
"""
from pymilvus import Collection, MilvusException

if not self._is_multi_vector:
embeddings = [[embedding] for embedding in embeddings] # type: ignore
Expand Down Expand Up @@ -1585,8 +1566,6 @@ def get_pks(self, expr: str, **kwargs: Any) -> List[int] | None:
List[int]: List of IDs (Primary Keys)
"""

from pymilvus import MilvusException

if self.col is None:
logger.debug("No existing collection to get pk.")
return None
Expand Down Expand Up @@ -1617,8 +1596,6 @@ def upsert( # type: ignore
List[str]: IDs of the added texts.
"""

from pymilvus import MilvusException

if documents is None or len(documents) == 0:
logger.debug("No documents to upsert.")
return None
Expand Down
3 changes: 2 additions & 1 deletion libs/milvus/langchain_milvus/vectorstores/zilliz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
from typing import List, Optional, Union, cast

from pymilvus import Collection, MilvusException

from langchain_milvus.vectorstores.milvus import EmbeddingType, Milvus

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -73,7 +75,6 @@ class Zilliz(Milvus):

def _create_index(self) -> None:
"""Create an index on the collection"""
from pymilvus import Collection, MilvusException

self.index_params = cast(Optional[Union[dict, List[dict]]], self.index_params) # type: ignore

Expand Down

0 comments on commit 6d81c3e

Please sign in to comment.