From a9efc87cd9bd13bbbfc9f68f7b17893002393238 Mon Sep 17 00:00:00 2001 From: David Oplatka Date: Fri, 1 Nov 2024 14:11:39 -0700 Subject: [PATCH] Add UDF and Chain Reranking (#27) --- vectara_agentic/tools.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vectara_agentic/tools.py b/vectara_agentic/tools.py index 1cf3ded..5a4f5a7 100644 --- a/vectara_agentic/tools.py +++ b/vectara_agentic/tools.py @@ -7,7 +7,7 @@ import importlib import os -from typing import Callable, List, Any, Optional, Type +from typing import Callable, List, Dict, Any, Optional, Type from pydantic import BaseModel, Field from llama_index.core.tools import FunctionTool @@ -159,6 +159,8 @@ def create_rag_tool( reranker: str = "mmr", rerank_k: int = 50, mmr_diversity_bias: float = 0.2, + udf_expression: str = None, + rerank_chain: List[Dict] = None, include_citations: bool = True, fcs_threshold: float = 0.0, ) -> VectaraTool: @@ -178,10 +180,16 @@ def create_rag_tool( reranker (str, optional): The reranker mode. rerank_k (int, optional): Number of top-k documents for reranking. mmr_diversity_bias (float, optional): MMR diversity bias. + udf_expression (str, optional): the user defined expression for reranking results. + rerank_chain (List[Dict], optional): A list of rerankers to be applied sequentially. + Each dictionary should specify the "type" of reranker (mmr, slingshot, udf) + and any other parameters (e.g. "limit" or "cutoff" for any type, + "diversity_bias" for mmr, and "user_function" for udf). + If using slingshot/multilingual_reranker_v1, it must be first in the list. include_citations (bool, optional): Whether to include citations in the response. If True, uses markdown vectara citations that requires the Vectara scale plan. fcs_threshold (float, optional): a threshold for factual consistency. - If set above 0, the tool notifies the calling agent that it "cannot respond" if FCS is too low + If set above 0, the tool notifies the calling agent that it "cannot respond" if FCS is too low. Returns: VectaraTool: A VectaraTool object. @@ -225,6 +233,8 @@ def rag_function(*args, **kwargs) -> ToolOutput: reranker=reranker, rerank_k=rerank_k if rerank_k * self.num_corpora <= 100 else int(100 / self.num_corpora), mmr_diversity_bias=mmr_diversity_bias, + udf_expression=udf_expression, + rerank_chain=rerank_chain, n_sentence_before=n_sentences_before, n_sentence_after=n_sentences_after, lambda_val=lambda_val,