From e727c87335948454035bf1c906c6d1662a7577f9 Mon Sep 17 00:00:00 2001
From: wlleiiwang
Date: Tue, 12 Nov 2024 11:49:31 +0800
Subject: [PATCH] fix use external embedding for tencent vectordb

---
 .../vectorstores/tencentvectordb.py | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/libs/community/langchain_community/vectorstores/tencentvectordb.py b/libs/community/langchain_community/vectorstores/tencentvectordb.py
index c3bda890fe2413..c48b3c02c36922 100644
--- a/libs/community/langchain_community/vectorstores/tencentvectordb.py
+++ b/libs/community/langchain_community/vectorstores/tencentvectordb.py
@@ -6,7 +6,7 @@
 import logging
 import time
 from enum import Enum
-from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast, Callable
 
 import numpy as np
 from langchain_core.documents import Document
@@ -168,8 +168,8 @@ def __init__(
         tcvectordb = guard_import("tcvectordb")
         tcollection = guard_import("tcvectordb.model.collection")
         enum = guard_import("tcvectordb.model.enum")
-
-        if t_vdb_embedding:
+        self.embedding_model = None
+        if embedding is None and t_vdb_embedding:
             embedding_model = [
                 model
                 for model in enum.EmbeddingModel
@@ -566,3 +566,17 @@ def max_marginal_relevance_search_by_vector(
         )
         # Reorder the values and return.
         return [documents[x] for x in new_ordering if x != -1]
+
+    def _select_relevance_score_fn(self) -> Callable[[float], float]:
+        metric_type = self.index_params.metric_type
+        if metric_type == "COSINE":
+            return self._cosine_relevance_score_fn
+        elif metric_type == "L2":
+            return self._euclidean_relevance_score_fn
+        elif metric_type == "IP":
+            return self._max_inner_product_relevance_score_fn
+        else:
+            raise ValueError(
+                "No supported normalization function"
+                f" for distance metric of type: {metric_type}."
+            )