diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py index 38d048f37..fbf7f9fe6 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py @@ -7,18 +7,16 @@ import numpy as np from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider +from jupyter_core.paths import jupyter_config_dir from langchain_core.prompts import PromptTemplate from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.tools import tool from langgraph.graph import MessagesState, StateGraph from langgraph.prebuilt import ToolNode -from langchain_core.runnables import ConfigurableFieldSpec -from langchain_core.runnables.history import RunnableWithMessageHistory from .base import BaseChatHandler, SlashCommandRoutingType -from jupyter_core.paths import jupyter_config_dir TOOLS_DIR = os.path.join(jupyter_config_dir(), "jupyter-ai", "tools") PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. @@ -85,10 +83,7 @@ def __init__(self, *args, **kwargs): self.parser.add_argument("query", nargs=argparse.REMAINDER) self.tools_file_path = None - - def setup_llm( - self, provider: Type[BaseProvider], provider_params: Dict[str, str] - ): + def setup_llm(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): """Sets up the LLM before creating the LLM Chain""" unified_parameters = { "verbose": True, @@ -124,7 +119,6 @@ def create_llm_chain( ) self.llm_chain = runnable - def get_tool_files(self) -> list: """ Gets required tool files from TOOLS_DIR @@ -179,7 +173,7 @@ def get_tools(file_paths: list) -> list: for file_path in file_paths: with open(file_path) as file: exec(file.read()) - try: # For each tool file, collect tool list + try: # For each tool file, collect tool list with open(file_path) as file: content = file.read() tree = ast.parse(content) @@ -246,19 +240,13 @@ async def process_message(self, message: HumanChatMessage): return if args.list: - tool_files = os.listdir( - os.path.join(Path.home(), TOOLS_DIR) - ) + tool_files = os.listdir(os.path.join(Path.home(), TOOLS_DIR)) self.reply(f"The available tools files are: {tool_files}") return elif args.tools: - self.tools_file_path = os.path.join( - Path.home(), TOOLS_DIR, args.tools - ) + self.tools_file_path = os.path.join(Path.home(), TOOLS_DIR, args.tools) else: - self.tools_file_path = os.path.join( - Path.home(), TOOLS_DIR - ) + self.tools_file_path = os.path.join(Path.home(), TOOLS_DIR) query = " ".join(args.query) if not query: