From 1aaa56f0483846d3797453110e50e1e3adc2715c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Sep 2024 08:48:09 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../jupyter_ai/chat_handlers/tools.py | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py index 2c0438d34..1c91cda1a 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/tools.py @@ -7,18 +7,16 @@ import numpy as np from jupyter_ai.models import HumanChatMessage from jupyter_ai_magics.providers import BaseProvider +from jupyter_core.paths import jupyter_config_dir from langchain_core.prompts import PromptTemplate from langchain_core.runnables import ConfigurableFieldSpec from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.tools import tool from langgraph.graph import MessagesState, StateGraph from langgraph.prebuilt import ToolNode -from langchain_core.runnables import ConfigurableFieldSpec -from langchain_core.runnables.history import RunnableWithMessageHistory from .base import BaseChatHandler, SlashCommandRoutingType -from jupyter_core.paths import jupyter_config_dir TOOLS_DIR = os.path.join(jupyter_config_dir(), "jupyter-ai", "tools") PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question. @@ -70,10 +68,7 @@ def __init__(self, *args, **kwargs): self.parser.add_argument("query", nargs=argparse.REMAINDER) self.tools_file_path = None - - def setup_llm( - self, provider: Type[BaseProvider], provider_params: Dict[str, str] - ): + def setup_llm(self, provider: Type[BaseProvider], provider_params: Dict[str, str]): """Sets up the LLM before creating the LLM Chain""" unified_parameters = { "verbose": True, @@ -109,7 +104,6 @@ def create_llm_chain( ) self.llm_chain = runnable - def get_tool_files(self) -> list: """ Gets required tool files from TOOLS_DIR @@ -155,7 +149,6 @@ def conditional_continue(state: MessagesState) -> Literal["tools", "__end__"]: return "tools" return "__end__" - def get_tools(file_paths): """Get all tool objects from the tool files""" if len(file_paths) > 0: @@ -163,7 +156,7 @@ def get_tools(file_paths): for file_path in file_paths: with open(file_path) as file: exec(file.read()) - try: # For each tool file, collect tool list + try: # For each tool file, collect tool list with open(file_path) as file: content = file.read() tree = ast.parse(content) @@ -226,19 +219,13 @@ async def process_message(self, message: HumanChatMessage): return if args.list: - tool_files = os.listdir( - os.path.join(Path.home(), TOOLS_DIR) - ) + tool_files = os.listdir(os.path.join(Path.home(), TOOLS_DIR)) self.reply(f"The available tools files are: {tool_files}") return elif args.tools: - self.tools_file_path = os.path.join( - Path.home(), TOOLS_DIR, args.tools - ) + self.tools_file_path = os.path.join(Path.home(), TOOLS_DIR, args.tools) else: - self.tools_file_path = os.path.join( - Path.home(), TOOLS_DIR - ) + self.tools_file_path = os.path.join(Path.home(), TOOLS_DIR) query = " ".join(args.query) if not query: