Update tools.py
srdas committed Sep 15, 2024
1 parent 78b605c commit 7f72556
Showing 1 changed file with 39 additions and 24 deletions.
63 changes: 39 additions & 24 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py
@@ -1,9 +1,9 @@
 import argparse
-import math
+import ast

 # LangGraph imports for using tools
 import os
 import re
 from typing import Dict, Literal, Type

 import numpy as np
@@ -57,15 +57,14 @@ class ToolsChatHandler(BaseChatHandler):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

-        # self._retriever = retriever
         self.parser.prog = "/tools"
         self.parser.add_argument("query", nargs=argparse.REMAINDER)
         self.tools_file_path = os.path.join(
             self.output_dir, "mytools.py"
         )  # Maybe pass as parameter?
+        self.chat_provider = ""  # Default, updated with function `setChatProvider`

-    # https://python.langchain.com/v0.2/docs/integrations/platforms/
-
+    # https://python.langchain.com/v0.2/docs/integrations/platforms/
     def setChatProvider(self, provider):  # For selecting the model to bind tools with
         try:
             if "bedrock" in provider.name.lower():
@@ -88,6 +87,7 @@ def setChatProvider(self, provider):  # For selecting the model to bind tools with
             response = """The related chat provider is not supported."""
             self.reply(response)

+
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
     ):
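Only the Bedrock branch of `setChatProvider` is visible in this hunk. As a rough sketch of the dispatch it appears to implement, the method maps the active provider to the name of a LangChain chat class, which `useLLMwithTools` later resolves with `eval` and binds tools to. The `openai` branch below is a hypothetical illustration, not from the commit:

```python
# Hypothetical sketch; only the "bedrock" branch appears in the diff.
def setChatProvider(self, provider):
    try:
        if "bedrock" in provider.name.lower():
            return "ChatBedrock"  # langchain_aws.ChatBedrock
        elif "openai" in provider.name.lower():  # assumed branch
            return "ChatOpenAI"  # langchain_openai.ChatOpenAI
        else:
            response = """The related chat provider is not supported."""
            self.reply(response)
    except Exception as e:
        self.log.error(e)
```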
@@ -96,8 +96,8 @@ def create_llm_chain(
             **(self.get_model_parameters(provider, provider_params)),
         }
         llm = provider(**unified_parameters)
-        self.chat_provider = self.setChatProvider(provider)
         self.llm = llm
+        self.chat_provider = self.setChatProvider(provider)
         memory = ConversationBufferWindowMemory(
             memory_key="chat_history", return_messages=True, k=2
         )
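The chain keeps only a short window of conversation history (`k=2`). A minimal standalone sketch of how `ConversationBufferWindowMemory` behaves with these parameters (the sample exchanges are made up):

```python
from langchain.memory import ConversationBufferWindowMemory

memory = ConversationBufferWindowMemory(
    memory_key="chat_history", return_messages=True, k=2
)
memory.save_context({"input": "hi"}, {"output": "hello!"})
memory.save_context({"input": "2 + 2?"}, {"output": "4"})
memory.save_context({"input": "and doubled?"}, {"output": "8"})

# Only the last k=2 exchanges are returned; the first has been dropped.
print(memory.load_memory_variables({})["chat_history"])
```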
@@ -129,9 +129,13 @@ def get_tool_names(tools_file_path):
         """
         with open(tools_file_path) as file:
             content = file.read()
-        # Use a regular expression to find the function names
-        tool_pattern = r"@tool\n\s*def\s+(\w+)"
-        tools = re.findall(tool_pattern, content)
+        tree = ast.parse(content)
+        tools = []
+        for node in ast.walk(tree):
+            if isinstance(node, ast.FunctionDef):
+                for decorator in node.decorator_list:
+                    if isinstance(decorator, ast.Name) and decorator.id == 'tool':
+                        tools.append(node.name)
         return tools

     def toolChat(self, query):
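The new AST-based scan replaces the regex, which only matched a bare `@tool` immediately followed by the `def` line and could also match inside string literals. A self-contained sketch of the same logic (the `multiply` tool is a made-up example); note that only bare `@tool` decorators parse as `ast.Name`, so the call form `@tool(...)`, which parses as `ast.Call`, would not be matched here:

```python
import ast

source = '''
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b
'''

tree = ast.parse(source)
tools = [
    node.name
    for node in ast.walk(tree)
    if isinstance(node, ast.FunctionDef)
    and any(
        isinstance(dec, ast.Name) and dec.id == "tool"
        for dec in node.decorator_list
    )
]
print(tools)  # ['multiply']
```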
@@ -143,31 +147,45 @@ def toolChat(self, query):
         return response

     ##### MAIN FUNCTION #####
-    def useLLMwithTools(self, chat_provider, model_name, tools_file_path, query):
+
+    def useLLMwithTools(self, query):
+        """
+        LangGraph documentation : https://langchain-ai.github.io/langgraph/tutorials/introduction/
+        The code below:
+        1. Extracts the function names in the custom tools file
+        2. Adds the tools to the Tool Node
+        3. Binds the Tool Node to the LLM
+        4. Sets up a basic LangGraph with nodes and edges
+        5. Compiles the graph into a runnable app
+        6. This function is then called with a prompt
+        Every time a query is submitted the langgraph is rebuilt in case the tools file has been changed.
+        """
         def call_tool(state: MessagesState):
             messages = state["messages"]
             response = self.model_with_tools.invoke(messages)
             return {"messages": [response]}

-        # Read in the tools file
-        file_path = tools_file_path
+        # Read in the tools file, WARNING - THIS USES EXEC()
+        file_path = self.tools_file_path
         with open(file_path) as file:
             exec(file.read())

         # Get tool names and create node with tools
-        tools = ToolsChatHandler.get_tool_names(file_path)
-        tools = [eval(j) for j in tools]
+        tool_names = ToolsChatHandler.get_tool_names(file_path)
+        tools = [eval(j) for j in tool_names]
         tool_node = ToolNode(tools)

         # Bind tools to LLM
-        if chat_provider == "ChatBedrock":
-            self.model_with_tools = eval(chat_provider)(
-                model_id=model_name, model_kwargs={"temperature": 0}
+        # print("SELF.LLM_CLASS", self.llm.__class__.id, "MODEL", self.llm.model_id, "CHAT PROVIDER", eval(self.chat_provider.__class__), "SELF.LLM.CHAT_MODELS", self.llm.chat_models)
+        # self.model_with_tools = self.llm.__class__(
+        #     model_id=self.llm.model_id
+        # ).bind_tools(tools)
+        if self.chat_provider == "ChatBedrock":
+            self.model_with_tools = eval(self.chat_provider)(
+                model_id=self.llm.model_id,  # model_kwargs={"temperature": 0}
             ).bind_tools(tools)
         else:
-            self.model_with_tools = eval(chat_provider)(
-                model=model_name, temperature=0
+            self.model_with_tools = eval(self.chat_provider)(
+                model=self.llm.model_id,  # temperature=0
             ).bind_tools(tools)

         # Initialize graph
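The graph construction itself falls in the lines elided between this hunk and the next, but the docstring's steps 4-5 describe the standard LangGraph tool-calling loop. A minimal self-contained sketch under assumed names (`multiply`, `should_continue`, and `model_with_tools` are illustrations, not the commit's identifiers; `model_with_tools` stands in for any chat model after `.bind_tools(...)`):

```python
from langchain_core.tools import tool
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode


@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


tool_node = ToolNode([multiply])
# Assumed: any LangChain chat model after .bind_tools([multiply])
model_with_tools = ...


def call_tool(state: MessagesState):
    # Same role as the commit's inner `call_tool`: one LLM step over the messages
    return {"messages": [model_with_tools.invoke(state["messages"])]}


def should_continue(state: MessagesState):
    # Route to the tool node only if the last AI message requested a tool call
    return "tools" if state["messages"][-1].tool_calls else END


agentic_workflow = StateGraph(MessagesState)
agentic_workflow.add_node("agent", call_tool)
agentic_workflow.add_node("tools", tool_node)
agentic_workflow.add_edge(START, "agent")
agentic_workflow.add_conditional_edges("agent", should_continue)
agentic_workflow.add_edge("tools", "agent")
app = agentic_workflow.compile()

res = app.invoke({"messages": [("user", "What is 6 times 7?")]})
print(res["messages"][-1].content)
```

The commit passes the raw query string to `app.invoke({"messages": query})`; the `("user", ...)` tuple form above is the equivalent explicit message.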
@@ -185,7 +203,7 @@ def call_tool(state: MessagesState):
         app = agentic_workflow.compile()

         # Run query
-        # res = ToolsChatHandler.toolChat(self, query)
+        # res = ToolsChatHandler.toolChat(self, query)  # For all Human and AI messages, if needed later
         res = app.invoke({"messages": query})
         return res["messages"][-1].content

@@ -202,10 +220,7 @@ async def process_message(self, message: HumanChatMessage):

         try:
             with self.pending("Using LLM with tools ..."):
-                # result = await self.llm_chain.acall({"question": query})
-                response = self.useLLMwithTools(
-                    self.chat_provider, self.llm.model_id, self.tools_file_path, query
-                )
+                response = self.useLLMwithTools(query)
                 self.reply(response, message)
         except Exception as e:
             self.log.error(e)
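For reference, a hypothetical `mytools.py` of the shape this handler expects in `self.output_dir`: plain functions under bare `@tool` decorators, so that both the `exec()` load and the AST scan in `get_tool_names` pick them up. This file is not part of the commit:

```python
# mytools.py -- hypothetical example tools file
from langchain_core.tools import tool


@tool
def add(a: float, b: float) -> float:
    """Add two numbers."""
    return a + b


@tool
def hypotenuse(a: float, b: float) -> float:
    """Length of the hypotenuse of a right triangle."""
    return (a**2 + b**2) ** 0.5
```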
