
Adding new slash command /tools
Enabling use of custom tools (functions) with LLMs
srdas committed Sep 11, 2024
1 parent 83cbd8e commit 564ab9b
Showing 7 changed files with 323 additions and 0 deletions.
Binary file added docs/source/_static/tools_correct_answer.png
Binary file added docs/source/_static/tools_wrong_answer.png
6 changes: 6 additions & 0 deletions docs/source/users/index.md
@@ -600,10 +600,16 @@ contents of the failing cell.
class="screenshot" style="max-width:65%" />


### Using custom tools in chat

To use your own custom tools in chat, create a tools file named `mytools.py` in the default directory, that is, the directory from which JupyterLab was started. Then use the `/tools` command followed by your query, as in the example below. For details on how to build your custom tools file and on usage of the `/tools` command, refer to [Using your custom tools library in the chat interface](tools.md).
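
For example, if your `mytools.py` defines a mortgage-payment tool like the one in the linked page, a query might look like this (the exact wording is up to you):

```
/tools What is the monthly payment on a $400,000 mortgage at 6% annual interest for 30 years?
```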


### Additional chat commands

To start a new conversation, use the `/clear` command. This will clear the chat panel and reset the model's memory.


## The `%ai` and `%%ai` magic commands

Jupyter AI can also be used in notebooks via Jupyter AI magics. This section
77 changes: 77 additions & 0 deletions docs/source/users/tools.md
@@ -0,0 +1,77 @@
# Using your custom tools library in the chat interface

In many situations LLMs appear to handle complex mathematical formulas quite well, yet the answers they return are often incorrect. Even for textual responses, using custom functions can constrain responses to formats and content that are more accurate and acceptable.

Jupyter AI includes a slash command `/tools` that directs the LLM to use functions from a tools library that you provide. This is a single file named `mytools.py`, stored in the default directory, that is, the one from which JupyterLab is started. An example tools file containing just three functions is shown below. Make sure to add the `@tool` decorator to each function, and to import within each function any packages that the function needs. The functions below are common financial formulas in wide use, so you might expect an LLM to have been trained on them. While that is likely true, we will see that the LLM is unable to execute the math in these formulas accurately.

```python
@tool
def BlackMertonScholes_Call(S: float,  # current stock price
                            K: float,  # exercise price of the option
                            T: float,  # option maturity in years
                            d: float,  # annualized dividend rate
                            r: float,  # annualized risk-free interest rate
                            v: float,  # stock volatility
                            ):
    """Black-Scholes-Merton option pricing model for call options"""
    import numpy as np
    from scipy.stats import norm
    d1 = (np.log(S/K) + (r - d + 0.5*v**2)*T) / (v*np.sqrt(T))
    d2 = d1 - v*np.sqrt(T)
    call_option_price = S*np.exp(-d*T)*norm.cdf(d1) - K*np.exp(-r*T)*norm.cdf(d2)
    return call_option_price


@tool
def BlackMertonScholes_Put(S: float,  # current stock price
                           K: float,  # exercise price of the option
                           T: float,  # option maturity in years
                           d: float,  # annualized dividend rate
                           r: float,  # annualized risk-free interest rate
                           v: float,  # stock volatility
                           ):
    """Black-Scholes-Merton option pricing model for put options"""
    import numpy as np
    from scipy.stats import norm
    d1 = (np.log(S/K) + (r - d + 0.5*v**2)*T) / (v*np.sqrt(T))
    d2 = d1 - v*np.sqrt(T)
    put_option_price = K*np.exp(-r*T)*norm.cdf(-d2) - S*np.exp(-d*T)*norm.cdf(-d1)
    return put_option_price


@tool
def calculate_monthly_payment(principal, annual_interest_rate, loan_term_years):
    """
    Calculate the monthly mortgage payment.

    Args:
        principal (float): The principal amount of the loan.
        annual_interest_rate (float): The annual interest rate as a decimal (e.g., 0.06 for 6%).
        loan_term_years (int): The loan term in years.

    Returns:
        float: The monthly mortgage payment.
    """
    import math
    # Convert the annual interest rate to a monthly interest rate
    monthly_interest_rate = annual_interest_rate / 12
    # Number of monthly payments over the life of the loan
    num_payments = loan_term_years * 12
    # Monthly payment from the standard annuity formula
    monthly_payment = (principal * monthly_interest_rate) / (1 - math.pow(1 + monthly_interest_rate, -num_payments))
    return monthly_payment
```

Each function carries the `@tool` decorator and its required imports. Note also the docstring that describes what each tool does; this helps direct the LLM to the relevant tool. Providing sufficient guidance within each function is helpful, in the form of docstrings, variable annotations, and explicit argument comments, examples of which appear in the code above. Default values mentioned in comments will be used by the LLM if the user forgets to provide them (for example, see the explicit mention of a 6% interest rate in the `calculate_monthly_payment` function above).
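
For instance, a hypothetical query that omits the interest rate can still be answered, because the LLM picks up the 6% default from the argument comment (the model's exact reply will vary):

```
/tools What is the monthly payment on a $500,000 loan over 30 years?
```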

When the `/tools` command is used, Jupyter AI binds the functions in the custom tools file to the LLM currently in use and builds a [LangGraph](https://langchain-ai.github.io/langgraph/) computational graph. It then uses this graph to respond to the query, invoking the appropriate tools where available.
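
The sketch below shows roughly what this binding looks like, assuming an OpenAI chat model; the model name and the toy `add` tool are stand-ins for illustration, and Jupyter AI performs the equivalent steps for whichever provider is active:

```python
from typing import Literal

from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph
from langgraph.prebuilt import ToolNode


@tool
def add(a: float, b: float) -> float:
    """Add two numbers."""
    return a + b


tools = [add]
model_with_tools = ChatOpenAI(model="gpt-4o-mini", temperature=0).bind_tools(tools)


def call_model(state: MessagesState):
    # Let the model respond; it may emit tool calls instead of a final answer
    return {"messages": [model_with_tools.invoke(state["messages"])]}


def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
    # Route to the tool node only if the model requested a tool call
    return "tools" if state["messages"][-1].tool_calls else "__end__"


workflow = StateGraph(MessagesState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", ToolNode(tools))
workflow.add_edge("__start__", "agent")
workflow.add_conditional_edges("agent", should_continue)
workflow.add_edge("tools", "agent")
app = workflow.compile()

result = app.invoke({"messages": [("human", "What is 2.5 + 4.5?")]})
print(result["messages"][-1].content)
```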

As an example, submit this query in the chat interface without using tools: "What is the price of a put option where the stock price is 100, the exercise price is 101, the time to maturity is 1 year, the risk-free rate is 3%, the dividend rate is zero, and the stock volatility is 20%?" The correct answer to this query is $6.93. Although the LLM returns the correct formula, it computes the answer incorrectly:

<img src="../_static/tools_wrong_answer.png"
width="75%"
alt='Incorrect use of the Black-Merton-Scholes formula for pricing a put option.'
class="screenshot" />

Next, use the `/tools` command with the same query to get the correct answer:

<img src="../_static/tools_correct_answer.png"
width="75%"
    alt='Correct answer for the Black-Merton-Scholes put option price, computed via the custom tools.'
class="screenshot" />
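
You can check this number independently outside the chat; the snippet below evaluates the same put-pricing formula as the sample `mytools.py` with the inputs from the query, and reproduces the ≈ $6.93 answer:

```python
import numpy as np
from scipy.stats import norm

# Inputs from the query above
S, K, T, d, r, v = 100.0, 101.0, 1.0, 0.0, 0.03, 0.20
d1 = (np.log(S / K) + (r - d + 0.5 * v**2) * T) / (v * np.sqrt(T))
d2 = d1 - v * np.sqrt(T)
put = K * np.exp(-r * T) * norm.cdf(-d2) - S * np.exp(-d * T) * norm.cdf(-d1)
print(round(put, 2))  # 6.93
```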

You can try the other tools in this example or build your own custom tools file to experiment with this feature.
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/__init__.py
@@ -7,3 +7,4 @@
from .generate import GenerateChatHandler
from .help import HelpChatHandler
from .learn import LearnChatHandler
from .tools import ToolsChatHandler
236 changes: 236 additions & 0 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py
@@ -0,0 +1,236 @@
##### TO DO #####
# - icon for /tools
# - Where to put the mytools.py file?
# - Pass tools file location on startup, or in chat, or keep one file, or handle multiple tool files.
# - Handling the different providers (Chat*) and models (model_id)
# - To integrate with chat history and memory or not?
# - How to suppress the problem with % sign messing up output?
# - Show full exchange or only the answer?
# - Error handling
# - Documentation
# - What's the best way to add this to magics?
# - Long term: Using the more advanced features of LangGraph, Agents, Multi-agentic workflows, etc.


import argparse
from typing import Dict, Type

from jupyter_ai.models import HumanChatMessage
from jupyter_ai_magics.providers import BaseProvider
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.prompts import PromptTemplate
from langchain.chains import LLMChain

from .base import BaseChatHandler, SlashCommandRoutingType

# LangGraph imports for using tools
import os
import re
import numpy as np
import math
from typing import Literal

from langchain_core.messages import AIMessage
from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode
from langgraph.graph import StateGraph, MessagesState

# Chat Providers (add more as needed)
from langchain_aws import ChatBedrock
from langchain_ollama import ChatOllama
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI, AzureChatOpenAI
from langchain_cohere import ChatCohere
from langchain_google_genai import ChatGoogleGenerativeAI


PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:
Format the answer to be as pretty as possible.
"""
CONDENSE_PROMPT = PromptTemplate.from_template(PROMPT_TEMPLATE)


class ToolsChatHandler(BaseChatHandler):
    """Processes messages prefixed with /tools. This handler binds
    the collection of tools in the user's tools file to the LLM and
    builds a computational graph that routes queries to any tools
    that apply to the prompt. If no appropriate tool is available,
    the LLM falls back to a standard chat response without using
    tools.
    """

id = "tools"
name = "Use tools with LLM"
help = "Ask a question that uses your custom tools"
routing_type = SlashCommandRoutingType(slash_id="tools")

uses_llm = True

# def __init__(self, retriever, *args, **kwargs):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# self._retriever = retriever
self.parser.prog = "/tools"
self.parser.add_argument("query", nargs=argparse.REMAINDER)
self.tools_file_path = os.path.join(self.output_dir, 'mytools.py') # Maybe pass as parameter?
self.chat_provider = "" # Default, updated with function `setChatProvider`


# https://python.langchain.com/v0.2/docs/integrations/platforms/
    def setChatProvider(self, provider):  # For selecting the model to bind tools with
        try:
            if "bedrock" in provider.name.lower():
                chat_provider = "ChatBedrock"
            elif "ollama" in provider.name.lower():
                chat_provider = "ChatOllama"
            elif "anthropic" in provider.name.lower():
                chat_provider = "ChatAnthropic"
            elif "azure" in provider.name.lower():
                chat_provider = "AzureChatOpenAI"
            elif "openai" in provider.name.lower():
                chat_provider = "ChatOpenAI"
            elif "cohere" in provider.name.lower():
                chat_provider = "ChatCohere"
            elif "google" in provider.name.lower():
                chat_provider = "ChatGoogleGenerativeAI"
            else:
                # Guard against returning an unbound name for unknown providers
                raise ValueError(f"Unsupported chat provider: {provider.name}")
            return chat_provider
        except Exception as e:
            self.log.error(e)
            response = """The related chat provider is not supported."""
            self.reply(response)


def create_llm_chain(
self, provider: Type[BaseProvider], provider_params: Dict[str, str]
):
unified_parameters = {
**provider_params,
**(self.get_model_parameters(provider, provider_params)),
}
llm = provider(**unified_parameters)
self.chat_provider = self.setChatProvider(provider)
self.llm = llm
memory = ConversationBufferWindowMemory(
memory_key="chat_history", return_messages=True, k=2
)
self.llm_chain = LLMChain(llm=self.llm,
prompt=CONDENSE_PROMPT,
memory=memory,
verbose=False)


# #### TOOLS FOR USE WITH LANGGRAPH #####
"""
Bind tools to LLM and provide chat functionality.
Call:
/tools <query>
"""

def conditional_continue(state: MessagesState) -> Literal["tools", "__end__"]:
messages = state["messages"]
last_message = messages[-1]
if last_message.tool_calls:
return "tools"
return "__end__"

def get_tool_names(tools_file_path):
"""
Read a file and extract the function names following the @tool decorator.
Args:
file_path (str): The path to the file.
Returns:
list: A list of function names.
"""
with open(tools_file_path, 'r') as file:
content = file.read()
# Use a regular expression to find the function names
tool_pattern = r'@tool\n\s*def\s+(\w+)'
tools = re.findall(tool_pattern, content)
return tools

    def toolChat(self, query):
        """Stream the tool-augmented exchange (currently unused; see useLLMwithTools)."""
        for chunk in self.app.stream({"messages": [("human", query)]}, stream_mode="values"):
            response = chunk["messages"][-1].pretty_print()
        return response


##### MAIN FUNCTION #####
def useLLMwithTools(self, chat_provider, model_name, tools_file_path, query):

def call_tool(state: MessagesState):
messages = state["messages"]
response = self.model_with_tools.invoke(messages)
return {"messages": [response]}

# Read in the tools file
file_path = tools_file_path
with open(file_path) as file:
exec(file.read())

# Get tool names and create node with tools
tools = ToolsChatHandler.get_tool_names(file_path)
tools = [eval(j) for j in tools]
tool_node = ToolNode(tools)

# Bind tools to LLM
if chat_provider=="ChatBedrock":
self.model_with_tools = eval(chat_provider)(model_id=model_name,
model_kwargs={"temperature": 0}).bind_tools(tools)
else:
self.model_with_tools = eval(chat_provider)(model=model_name, temperature=0).bind_tools(tools)

# Initialize graph
agentic_workflow = StateGraph(MessagesState)
# Define the agent and tool nodes we will cycle between
agentic_workflow.add_node("agent", call_tool)
agentic_workflow.add_node("tools", tool_node)
# Add edges to the graph
agentic_workflow.add_edge("__start__", "agent")
agentic_workflow.add_conditional_edges("agent", ToolsChatHandler.conditional_continue)
agentic_workflow.add_edge("tools", "agent")
# Compile graph
app = agentic_workflow.compile()

# Run query
# res = ToolsChatHandler.toolChat(self, query)
res = app.invoke({"messages": query})
return res["messages"][-1].content


async def process_message(self, message: HumanChatMessage):
args = self.parse_args(message)
if args is None:
return
query = " ".join(args.query)
if not query:
self.reply(f"{self.parser.format_usage()}", message)
return

self.get_llm_chain()

try:
with self.pending("Using LLM with tools ..."):
# result = await self.llm_chain.acall({"question": query})
response = self.useLLMwithTools(self.chat_provider,
self.llm.model_id,
self.tools_file_path,
query)
self.reply(response, message)
except Exception as e:
self.log.error(e)
response = """Sorry, tool usage failed.
Either (i) this LLM does not accept tools, (ii) there an error in
the custom tools file, (iii) you may also want to check the
location of the tools file, or (iv) you may need to install the
`langchain_<provider_name>` package. (v) Finally, check that you have
authorized access to the LLM."""
self.reply(response, message)

3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/extension.py
@@ -21,6 +21,7 @@
GenerateChatHandler,
HelpChatHandler,
LearnChatHandler,
ToolsChatHandler,
)
from .completions.handlers import DefaultInlineCompletionHandler
from .config_manager import ConfigManager
@@ -315,6 +316,7 @@ def initialize_settings(self):
export_chat_handler = ExportChatHandler(**chat_handler_kwargs)

fix_chat_handler = FixChatHandler(**chat_handler_kwargs)
tools_chat_handler = ToolsChatHandler(**chat_handler_kwargs)

chat_handlers["default"] = default_chat_handler
chat_handlers["/ask"] = ask_chat_handler
@@ -323,6 +325,7 @@
chat_handlers["/learn"] = learn_chat_handler
chat_handlers["/export"] = export_chat_handler
chat_handlers["/fix"] = fix_chat_handler
chat_handlers["/tools"] = tools_chat_handler

slash_command_pattern = r"^[a-zA-Z0-9_]+$"
for chat_handler_ep in chat_handler_eps:
