diff --git a/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py index ef5e1eae8a566..cc5205a3dcc83 100644 --- a/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py +++ b/libs/experimental/langchain_experimental/agents/agent_toolkits/pandas/base.py @@ -33,7 +33,8 @@ def _get_multi_prompt( input_variables: Optional[List[str]] = None, include_df_in_prompt: Optional[bool] = True, number_of_head_rows: int = 5, -) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + extra_tools: Sequence[BaseTool] = (), +) -> Tuple[BasePromptTemplate, List[BaseTool]]: num_dfs = len(dfs) if suffix is not None: suffix_to_use = suffix @@ -55,12 +56,13 @@ def _get_multi_prompt( df_locals = {} for i, dataframe in enumerate(dfs): df_locals[f"df{i + 1}"] = dataframe - tools = [PythonAstREPLTool(locals=df_locals)] - + tools = [PythonAstREPLTool(locals=df_locals)] + list(extra_tools) prompt = ZeroShotAgent.create_prompt( - tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables + tools, + prefix=prefix, + suffix=suffix_to_use, + input_variables=input_variables, ) - partial_prompt = prompt.partial() if "dfs_head" in input_variables: dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs]) @@ -77,7 +79,8 @@ def _get_single_prompt( input_variables: Optional[List[str]] = None, include_df_in_prompt: Optional[bool] = True, number_of_head_rows: int = 5, -) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + extra_tools: Sequence[BaseTool] = (), +) -> Tuple[BasePromptTemplate, List[BaseTool]]: if suffix is not None: suffix_to_use = suffix include_df_head = True @@ -96,10 +99,13 @@ def _get_single_prompt( if prefix is None: prefix = PREFIX - tools = [PythonAstREPLTool(locals={"df": df})] + tools = [PythonAstREPLTool(locals={"df": df})] + list(extra_tools) prompt = ZeroShotAgent.create_prompt( - tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables + tools, + prefix=prefix, + suffix=suffix_to_use, + input_variables=input_variables, ) partial_prompt = prompt.partial() @@ -117,7 +123,8 @@ def _get_prompt_and_tools( input_variables: Optional[List[str]] = None, include_df_in_prompt: Optional[bool] = True, number_of_head_rows: int = 5, -) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]: + extra_tools: Sequence[BaseTool] = (), +) -> Tuple[BasePromptTemplate, List[BaseTool]]: try: import pandas as pd @@ -141,6 +148,7 @@ def _get_prompt_and_tools( input_variables=input_variables, include_df_in_prompt=include_df_in_prompt, number_of_head_rows=number_of_head_rows, + extra_tools=extra_tools, ) else: if not isinstance(df, pd.DataFrame): @@ -152,6 +160,7 @@ def _get_prompt_and_tools( input_variables=input_variables, include_df_in_prompt=include_df_in_prompt, number_of_head_rows=number_of_head_rows, + extra_tools=extra_tools, ) @@ -287,6 +296,7 @@ def create_pandas_dataframe_agent( ) -> AgentExecutor: """Construct a pandas agent from an LLM and dataframe.""" agent: BaseSingleActionAgent + base_tools: Sequence[BaseTool] if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION: prompt, base_tools = _get_prompt_and_tools( df, @@ -295,8 +305,9 @@ def create_pandas_dataframe_agent( input_variables=input_variables, include_df_in_prompt=include_df_in_prompt, number_of_head_rows=number_of_head_rows, + extra_tools=extra_tools, ) - tools = base_tools + list(extra_tools) + tools = base_tools llm_chain = LLMChain( llm=llm, prompt=prompt, @@ -318,7 +329,7 @@ def create_pandas_dataframe_agent( include_df_in_prompt=include_df_in_prompt, number_of_head_rows=number_of_head_rows, ) - tools = base_tools + list(extra_tools) + tools = list(base_tools) + list(extra_tools) agent = OpenAIFunctionsAgent( llm=llm, prompt=_prompt,