Enabling use of custom tools (functions) with LLMs
Showing 7 changed files with 323 additions and 0 deletions.
@@ -0,0 +1,77 @@
# Using your custom tools library in the chat interface

LLMs often handle complex mathematical formulas well and return correct answers, but not always: they are prone to arithmetic errors when evaluating the formulas themselves. Even for textual responses, using custom functions can constrain responses to formats and content that are more accurate and acceptable.

Jupyter AI includes a slash command `/tools` that directs the LLM to use functions from a tools library that you provide. This library is a single file named `mytools.py`, which may be stored in the default directory, that is, the one from which Jupyter is started. An example tools file containing just three functions is shown below. Make sure to add the `@tool` decorator to each function and to import, within each function, any packages the function needs. The functions below are common financial formulas that are in wide use, so you might expect an LLM to have been trained on them. While that is true, we will see that the LLM is unable to execute the math in these formulas accurately.

```
@tool
def BlackMertonScholes_Call(S: float,  # current stock price
                            K: float,  # exercise price of the option
                            T: float,  # option maturity in years
                            d: float,  # annualized dividend rate
                            r: float,  # annualized risk free interest rate
                            v: float,  # stock volatility
                            ):
    """Black-Scholes-Merton option pricing model for call options"""
    import numpy as np
    from scipy.stats import norm
    d1 = (np.log(S/K) + (r - d + 0.5*v**2)*T) / (v*np.sqrt(T))
    d2 = d1 - v*np.sqrt(T)
    call_option_price = S*np.exp(-d*T)*norm.cdf(d1) - K*np.exp(-r*T)*norm.cdf(d2)
    return call_option_price


@tool
def BlackMertonScholes_Put(S: float,  # current stock price
                           K: float,  # exercise price of the option
                           T: float,  # option maturity in years
                           d: float,  # annualized dividend rate
                           r: float,  # annualized risk free interest rate
                           v: float,  # stock volatility
                           ):
    """Black-Scholes-Merton option pricing model for put options"""
    import numpy as np
    from scipy.stats import norm
    d1 = (np.log(S/K) + (r - d + 0.5*v**2)*T) / (v*np.sqrt(T))
    d2 = d1 - v*np.sqrt(T)
    put_option_price = K*np.exp(-r*T)*norm.cdf(-d2) - S*np.exp(-d*T)*norm.cdf(-d1)
    return put_option_price


@tool
def calculate_monthly_payment(principal, annual_interest_rate, loan_term_years):
    """
    Calculate the monthly mortgage payment.
    Args:
        principal (float): The principal amount of the loan.
        annual_interest_rate (float): The annual interest rate as a decimal (e.g., 0.06 for 6%).
        loan_term_years (int): The loan term in years.
    Returns:
        float: The monthly mortgage payment.
    """
    import math
    # Convert the annual interest rate to a monthly interest rate
    monthly_interest_rate = annual_interest_rate / 12
    # Calculate the number of monthly payments
    num_payments = loan_term_years * 12
    # Calculate the monthly payment using the annuity formula
    monthly_payment = (principal * monthly_interest_rate) / (1 - math.pow(1 + monthly_interest_rate, -num_payments))
    return monthly_payment
```

Each function contains the `@tool` decorator and the required imports. Note also the docstring that describes what each tool does; this helps direct the LLM to the relevant tool. Providing sufficient guidance within each function is helpful, in the form of docstrings, argument annotations, and explicit argument comments, examples of which are shown in the code above. For example, default values mentioned in comments will be used by the LLM if the user forgets to provide them (see the explicit mention of a 6% interest rate in the `calculate_monthly_payment` function above).

When the `/tools` command is used, Jupyter AI binds the custom tools file to the LLM currently in use and builds a [LangGraph](https://langchain-ai.github.io/langgraph/) computational graph. It uses this graph to respond to the query, calling the appropriate tools if any are available.

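If you are curious what that binding looks like, the following is a minimal sketch of the standard LangGraph tool-calling pattern, simplified from the handler code; it is not the exact Jupyter AI implementation. It assumes the `langgraph`, `langchain-openai`, `numpy`, and `scipy` packages are installed and that an OpenAI model is accessible; the model name and the `put_price` tool are illustrative only.

```
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI  # any chat model that supports .bind_tools()
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, MessagesState

@tool
def put_price(S: float, K: float, T: float, d: float, r: float, v: float) -> float:
    """Black-Scholes-Merton price of a European put option."""
    import numpy as np
    from scipy.stats import norm
    d1 = (np.log(S/K) + (r - d + 0.5*v**2)*T) / (v*np.sqrt(T))
    d2 = d1 - v*np.sqrt(T)
    return K*np.exp(-r*T)*norm.cdf(-d2) - S*np.exp(-d*T)*norm.cdf(-d1)

tools = [put_price]
model_with_tools = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tools)

def call_model(state: MessagesState):
    # The model either answers directly or emits a tool call
    return {"messages": [model_with_tools.invoke(state["messages"])]}

def route(state: MessagesState):
    # Route to the tool node when the last message requests a tool
    return "tools" if state["messages"][-1].tool_calls else "__end__"

graph = StateGraph(MessagesState)
graph.add_node("agent", call_model)
graph.add_node("tools", ToolNode(tools))
graph.add_edge("__start__", "agent")
graph.add_conditional_edges("agent", route)
graph.add_edge("tools", "agent")
app = graph.compile()

result = app.invoke({"messages": [("human",
    "Price a put: S=100, K=101, T=1 year, r=3%, dividend=0, volatility=20%")]})
print(result["messages"][-1].content)
```
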
As an example, submit this query in the chat interface without using tools: "What is the price of a put option where the stock price is 100, the exercise price is 101, the time to maturity is 1 year, the risk free rate is 3%, the dividend rate is zero, and the stock volatility is 20%?" The correct answer to this query is $6.93. However, although the LLM returns the correct formula, it computes the answer incorrectly:

<img src="../_static/tools_wrong_answer.png"
    width="75%"
    alt='Incorrect use of the Black-Merton-Scholes formula for pricing a put option.'
    class="screenshot" />

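Indeed, evaluating the put formula from `mytools.py` directly confirms the expected value. The snippet below is a quick standalone check (it assumes `numpy` and `scipy` are installed; it is not part of Jupyter AI):

```
import numpy as np
from scipy.stats import norm

S, K, T, d, r, v = 100, 101, 1.0, 0.0, 0.03, 0.20
d1 = (np.log(S/K) + (r - d + 0.5*v**2)*T) / (v*np.sqrt(T))
d2 = d1 - v*np.sqrt(T)
put = K*np.exp(-r*T)*norm.cdf(-d2) - S*np.exp(-d*T)*norm.cdf(-d1)
print(round(put, 2))  # prints 6.93
```
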
Next, use the `/tools` command with the same query to get the correct answer.

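The full message entered in the chat box is simply the same question prefixed with the slash command:

```
/tools What is the price of a put option where the stock price is 100, the exercise price is 101, the time to maturity is 1 year, the risk free rate is 3%, the dividend rate is zero, and the stock volatility is 20%?
```
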
<img src="../_static/tools_correct_answer.png"
    width="75%"
    alt='Correct use of the Black-Merton-Scholes formula for pricing a put option.'
    class="screenshot" />

You can try the other tools in this example or build your own custom tools file to experiment with this feature.
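New tools follow the same pattern as the three shown above. For instance, a hypothetical additional entry in `mytools.py` might look like this (the `future_value` function is purely illustrative and not part of the example file):

```
@tool
def future_value(principal: float,    # initial amount invested
                 annual_rate: float,  # annual interest rate as a decimal (e.g., 0.05 for 5%)
                 years: int,          # investment horizon in years
                 ):
    """Future value of a lump sum with annual compounding."""
    import math
    return principal * math.pow(1 + annual_rate, years)
```
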
@@ -0,0 +1,236 @@
##### TO DO #####
# - icon for /tools
# - Where to put the mytools.py file?
# - Pass tools file location on startup, or in chat, or keep one file, or handle multiple tool files.
# - Handling the different providers (Chat*) and models (model_id)
# - To integrate with chat history and memory or not?
# - How to suppress the problem with % sign messing up output?
# - Show full exchange or only the answer?
# - Error handling
# - Documentation
# - What's the best way to add this to magics?
# - Long term: Using the more advanced features of LangGraph, Agents, Multi-agentic workflows, etc.


import argparse
from typing import Dict, Type

from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.prompts import PromptTemplate

from .base import BaseChatHandler, SlashCommandRoutingType

# LangGraph imports for using tools
import os
import re
import numpy as np
import math
from typing import Literal

from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, MessagesState

# Chat providers (add more as needed)
from langchain_aws import ChatBedrock
from langchain_ollama import ChatOllama
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_cohere import ChatCohere
from langchain_google_genai import ChatGoogleGenerativeAI


PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:
Format the answer to be as pretty as possible.
"""
CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)


class ToolsChatHandler(BaseChatHandler):
    """Processes messages prefixed with /tools. This actor binds a
    <tool_name>.py collection of tools to the LLM and builds a
    computational graph to direct queries to tools that apply to
    the prompt. If there is no appropriate tool, the LLM defaults
    to a standard chat response without using tools.
    """

    id = "tools"
    name = "Use tools with LLM"
    help = "Ask a question that uses your custom tools"
    routing_type = SlashCommandRoutingType(slash_id="tools")

    uses_llm = True

    # def __init__(self, retriever, *args, **kwargs):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # self._retriever = retriever
        self.parser.prog = "/tools"
        self.parser.add_argument("query", nargs=argparse.REMAINDER)
        self.tools_file_path = os.path.join(self.output_dir, 'mytools.py')  # Maybe pass as parameter?
        self.chat_provider = ""  # Default, updated by `setChatProvider`

    # https://python.langchain.com/v0.2/docs/integrations/platforms/
    def setChatProvider(self, provider):  # For selecting the model to bind tools with
        try:
            if "bedrock" in provider.name.lower():
                chat_provider = "ChatBedrock"
            elif "ollama" in provider.name.lower():
                chat_provider = "ChatOllama"
            elif "anthropic" in provider.name.lower():
                chat_provider = "ChatAnthropic"
            elif "azure" in provider.name.lower():
                chat_provider = "AzureChatOpenAI"
            elif "openai" in provider.name.lower():
                chat_provider = "ChatOpenAI"
            elif "cohere" in provider.name.lower():
                chat_provider = "ChatCohere"
            elif "google" in provider.name.lower():
                chat_provider = "ChatGoogleGenerativeAI"
            else:
                # No matching provider; handled by the except clause below
                raise ValueError(f"Unsupported chat provider: {provider.name}")
            return chat_provider
        except Exception as e:
            self.log.error(e)
            response = """The related chat provider is not supported."""
            self.reply(response)

    def create_llm_chain(
        self, provider: Type[BaseProvider], provider_params: Dict[str, str]
    ):
        unified_parameters = {
            **provider_params,
            **(self.get_model_parameters(provider, provider_params)),
        }
        llm = provider(**unified_parameters)
        self.chat_provider = self.setChatProvider(provider)
        self.llm = llm
        memory = ConversationBufferWindowMemory(
            memory_key="chat_history", return_messages=True, k=2
        )
        self.llm_chain = LLMChain(
            llm=self.llm,
            prompt=CONDENSE_PROMPT,
            memory=memory,
            verbose=False,
        )

    # #### TOOLS FOR USE WITH LANGGRAPH #####
    """
    Bind tools to LLM and provide chat functionality.
    Call:
        /tools <query>
    """

    def conditional_continue(state: MessagesState) -> Literal["tools", "__end__"]:
        # Continue to the tool node if the last message contains tool calls;
        # otherwise end the graph.
        messages = state["messages"]
        last_message = messages[-1]
        if last_message.tool_calls:
            return "tools"
        return "__end__"

    def get_tool_names(tools_file_path):
        """
        Read a file and extract the function names following the @tool decorator.
        Args:
            tools_file_path (str): The path to the tools file.
        Returns:
            list: A list of function names.
        """
        with open(tools_file_path, 'r') as file:
            content = file.read()
        # Use a regular expression to find the function names
        tool_pattern = r'@tool\n\s*def\s+(\w+)'
        tools = re.findall(tool_pattern, content)
        return tools

    def toolChat(self, query):
        # Stream responses from the compiled graph for the given query
        print("TOOL CHAT", query)
        for chunk in self.app.stream({"messages": [("human", query)]}, stream_mode="values"):
            response = chunk["messages"][-1].pretty_print()
        return response

    ##### MAIN FUNCTION #####
    def useLLMwithTools(self, chat_provider, model_name, tools_file_path, query):

        def call_tool(state: MessagesState):
            messages = state["messages"]
            response = self.model_with_tools.invoke(messages)
            return {"messages": [response]}

        # Read in the tools file
        file_path = tools_file_path
        with open(file_path) as file:
            exec(file.read())

        # Get tool names and create node with tools
        tools = ToolsChatHandler.get_tool_names(file_path)
        tools = [eval(j) for j in tools]
        tool_node = ToolNode(tools)

        # Bind tools to LLM
        if chat_provider == "ChatBedrock":
            self.model_with_tools = eval(chat_provider)(
                model_id=model_name, model_kwargs={"temperature": 0}
            ).bind_tools(tools)
        else:
            self.model_with_tools = eval(chat_provider)(model=model_name, temperature=0).bind_tools(tools)

        # Initialize graph
        agentic_workflow = StateGraph(MessagesState)
        # Define the agent and tool nodes we will cycle between
        agentic_workflow.add_node("agent", call_tool)
        agentic_workflow.add_node("tools", tool_node)
        # Add edges to the graph
        agentic_workflow.add_edge("__start__", "agent")
        agentic_workflow.add_conditional_edges("agent", ToolsChatHandler.conditional_continue)
        agentic_workflow.add_edge("tools", "agent")
        # Compile graph
        app = agentic_workflow.compile()

        # Run query
        # res = ToolsChatHandler.toolChat(self, query)
        res = app.invoke({"messages": query})
        return res["messages"][-1].content


    async def process_message(self, message: HumanChatMessage):
        args = self.parse_args(message)
        if args is None:
            return
        query = " ".join(args.query)
        if not query:
            self.reply(f"{self.parser.format_usage()}", message)
            return

        self.get_llm_chain()

        try:
            with self.pending("Using LLM with tools ..."):
                # result = await self.llm_chain.acall({"question": query})
                response = self.useLLMwithTools(
                    self.chat_provider,
                    self.llm.model_id,
                    self.tools_file_path,
                    query,
                )
                self.reply(response, message)
        except Exception as e:
            self.log.error(e)
            response = """Sorry, tool usage failed.
                Either (i) this LLM does not accept tools, (ii) there is an error in
                the custom tools file, (iii) you may also want to check the
                location of the tools file, or (iv) you may need to install the
                `langchain_<provider_name>` package. (v) Finally, check that you have
                authorized access to the LLM."""
            self.reply(response, message)