Skip to content

Commit

Permalink
Add UDF and Chain Reranking (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-oplatka authored Nov 1, 2024
1 parent 868aa8e commit a9efc87
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions vectara_agentic/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import importlib
import os

from typing import Callable, List, Any, Optional, Type
from typing import Callable, List, Dict, Any, Optional, Type
from pydantic import BaseModel, Field

from llama_index.core.tools import FunctionTool
Expand Down Expand Up @@ -159,6 +159,8 @@ def create_rag_tool(
reranker: str = "mmr",
rerank_k: int = 50,
mmr_diversity_bias: float = 0.2,
udf_expression: str = None,
rerank_chain: List[Dict] = None,
include_citations: bool = True,
fcs_threshold: float = 0.0,
) -> VectaraTool:
Expand All @@ -178,10 +180,16 @@ def create_rag_tool(
reranker (str, optional): The reranker mode.
rerank_k (int, optional): Number of top-k documents for reranking.
mmr_diversity_bias (float, optional): MMR diversity bias.
udf_expression (str, optional): the user defined expression for reranking results.
rerank_chain (List[Dict], optional): A list of rerankers to be applied sequentially.
Each dictionary should specify the "type" of reranker (mmr, slingshot, udf)
and any other parameters (e.g. "limit" or "cutoff" for any type,
"diversity_bias" for mmr, and "user_function" for udf).
If using slingshot/multilingual_reranker_v1, it must be first in the list.
include_citations (bool, optional): Whether to include citations in the response.
If True, uses markdown vectara citations that requires the Vectara scale plan.
fcs_threshold (float, optional): a threshold for factual consistency.
If set above 0, the tool notifies the calling agent that it "cannot respond" if FCS is too low
If set above 0, the tool notifies the calling agent that it "cannot respond" if FCS is too low.
Returns:
VectaraTool: A VectaraTool object.
Expand Down Expand Up @@ -225,6 +233,8 @@ def rag_function(*args, **kwargs) -> ToolOutput:
reranker=reranker,
rerank_k=rerank_k if rerank_k * self.num_corpora <= 100 else int(100 / self.num_corpora),
mmr_diversity_bias=mmr_diversity_bias,
udf_expression=udf_expression,
rerank_chain=rerank_chain,
n_sentence_before=n_sentences_before,
n_sentence_after=n_sentences_after,
lambda_val=lambda_val,
Expand Down

0 comments on commit a9efc87

Please sign in to comment.