diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py index acc14b536..90429fd7b 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py @@ -1,9 +1,9 @@ import argparse import math +import ast # LangGraph imports for using tools import os -import re from typing import Dict, Literal, Type import numpy as np @@ -57,15 +57,14 @@ class ToolsChatHandler(BaseChatHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # self._retriever = retriever self.parser.prog = "/tools" self.parser.add_argument("query", nargs=argparse.REMAINDER) self.tools_file_path = os.path.join( self.output_dir, "mytools.py" ) # Maybe pass as parameter? - self.chat_provider = "" # Default, updated with function `setChatProvider` - # https://python.langchain.com/v0.2/docs/integrations/platforms/ + +# https://python.langchain.com/v0.2/docs/integrations/platforms/ def setChatProvider(self, provider): # For selecting the model to bind tools with try: if "bedrock" in provider.name.lower(): @@ -88,6 +87,7 @@ def setChatProvider(self, provider): # For selecting the model to bind tools wi response = """The related chat provider is not supported.""" self.reply(response) + def create_llm_chain( self, provider: Type[BaseProvider], provider_params: Dict[str, str] ): @@ -96,8 +96,8 @@ def create_llm_chain( **(self.get_model_parameters(provider, provider_params)), } llm = provider(**unified_parameters) - self.chat_provider = self.setChatProvider(provider) self.llm = llm + self.chat_provider = self.setChatProvider(provider) memory = ConversationBufferWindowMemory( memory_key="chat_history", return_messages=True, k=2 ) @@ -129,9 +129,13 @@ def get_tool_names(tools_file_path): """ with open(tools_file_path) as file: content = file.read() - # Use a regular expression to find the function names - tool_pattern = r"@tool\n\s*def\s+(\w+)" - tools = re.findall(tool_pattern, content) + tree = ast.parse(content) + tools = [] + for node in ast.walk(tree): + if isinstance(node, ast.FunctionDef): + for decorator in node.decorator_list: + if isinstance(decorator, ast.Name) and decorator.id == 'tool': + tools.append(node.name) return tools def toolChat(self, query): @@ -143,31 +147,45 @@ def toolChat(self, query): return response ##### MAIN FUNCTION ##### - def useLLMwithTools(self, chat_provider, model_name, tools_file_path, query): - + def useLLMwithTools(self, query): + """ + LangGraph documentation : https://langchain-ai.github.io/langgraph/tutorials/introduction/ + The code below: + 1. Extracts the function names in the custom tools file + 2. Adds the tools to the Tool Node + 3. Binds the Tool Node to the LLM + 4. Sets up a basic LangGraph with nodes and edges + 5. Compiles the graph into a runnable app + 6. This function is then called with a prompt + Every time a query is submitted the langgraph is rebuilt in case the tools file has been changed. + """ def call_tool(state: MessagesState): messages = state["messages"] response = self.model_with_tools.invoke(messages) return {"messages": [response]} - # Read in the tools file - file_path = tools_file_path + # Read in the tools file, WARNING - THIS USES EXEC() + file_path = self.tools_file_path with open(file_path) as file: exec(file.read()) # Get tool names and create node with tools - tools = ToolsChatHandler.get_tool_names(file_path) - tools = [eval(j) for j in tools] + tool_names = ToolsChatHandler.get_tool_names(file_path) + tools = [eval(j) for j in tool_names] tool_node = ToolNode(tools) # Bind tools to LLM - if chat_provider == "ChatBedrock": - self.model_with_tools = eval(chat_provider)( - model_id=model_name, model_kwargs={"temperature": 0} + # print("SELF.LLM_CLASS", self.llm.__class__.id, "MODEL", self.llm.model_id, "CHAT PROVIDER", eval(self.chat_provider.__class__), "SELF.LLM.CHAT_MODELS", self.llm.chat_models) + # self.model_with_tools = self.llm.__class__( + # model_id=self.llm.model_id + # ).bind_tools(tools) + if self.chat_provider == "ChatBedrock": + self.model_with_tools = eval(self.chat_provider)( + model_id=self.llm.model_id, #model_kwargs={"temperature": 0} ).bind_tools(tools) else: - self.model_with_tools = eval(chat_provider)( - model=model_name, temperature=0 + self.model_with_tools = eval(self.chat_provider)( + model=self.llm.model_id, #temperature=0 ).bind_tools(tools) # Initialize graph @@ -185,7 +203,7 @@ def call_tool(state: MessagesState): app = agentic_workflow.compile() # Run query - # res = ToolsChatHandler.toolChat(self, query) + # res = ToolsChatHandler.toolChat(self, query) # For all Human and AI messages, if needed later res = app.invoke({"messages": query}) return res["messages"][-1].content @@ -202,10 +220,7 @@ async def process_message(self, message: HumanChatMessage): try: with self.pending("Using LLM with tools ..."): - # result = await self.llm_chain.acall({"question": query}) - response = self.useLLMwithTools( - self.chat_provider, self.llm.model_id, self.tools_file_path, query - ) + response = self.useLLMwithTools(query) self.reply(response, message) except Exception as e: self.log.error(e)