Skip to content

Commit

Permalink
refine functions (#20)
Browse files: browse the repository at this point in the history
Signed-off-by: ChengZi <[email protected]>
  • Loading branch information
zc277584121 authored Nov 8, 2024
1 parent 5465429 commit 9f007e4
Showing 1 changed file with 24 additions and 14 deletions.
38 changes: 24 additions & 14 deletions libs/milvus/langchain_milvus/vectorstores/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pymilvus import RRFRanker, WeightedRanker
from pymilvus import MilvusClient, RRFRanker, WeightedRanker

from langchain_milvus import MilvusCollectionHybridSearchRetriever
from langchain_milvus.utils.sparse import BaseSparseEmbedding
Expand Down Expand Up @@ -409,7 +409,7 @@ def embeddings(self) -> Union[EmbeddingType, List[EmbeddingType]]: # type: igno
return self.embedding_func

@property
def client(self) -> Any:
def client(self) -> MilvusClient:
"""Get client."""
return self._milvus_client

Expand All @@ -419,15 +419,11 @@ def _is_multi_vector(self) -> bool:

@property
def _is_sparse(self) -> bool:
if self.index_params is None:
return False
indexes_params = self._as_list(self.index_params)
if len(indexes_params) > 1:
return False
index_type = indexes_params[0]["index_type"]
if "SPARSE" in index_type:
embedding_func: List[EmbeddingType] = self._as_list(self.embedding_func)
if self._is_sparse_embedding(embedding_func[0]):
return True
return False
else:
return False

@staticmethod
def _is_sparse_embedding(embeddings_function: EmbeddingType) -> bool:
Expand Down Expand Up @@ -1396,12 +1392,20 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
- etc.
"""
if self.index_params is None:
raise ValueError("No index params provided.")
if not self.col or not self.col.indexes:
raise ValueError(
"No index params provided. Could not determine relevance function."
)
if self._is_multi_vector:
raise ValueError("No supported normalization function for multi vectors.")
raise ValueError(
"No supported normalization function for multi vectors. "
"Could not determine relevance function."
)
if self._is_sparse:
raise ValueError("No supported normalization function for sparse indexes.")
raise ValueError(
"No supported normalization function for sparse indexes. "
"Could not determine relevance function."
)

def _map_l2_to_similarity(l2_distance: float) -> float:
"""Return a similarity score on a scale [0, 1].
Expand All @@ -1423,6 +1427,12 @@ def _map_ip_to_similarity(ip_score: float) -> float:
"""
return (ip_score + 1) / 2.0

if self.index_params is None:
logger.warning(
"No index params provided. Could not determine relevance function. "
"Use L2 distance as default."
)
return _map_l2_to_similarity
indexes_params = self._as_list(self.index_params)
metric_type = indexes_params[0]["metric_type"]
if metric_type == "L2":
Expand Down

0 comments on commit 9f007e4

Please sign in to comment.