Skip to content

Commit

Permalink
Fix #11737 issue (extra_tools option of create_pandas_dataframe_agent…
Browse files Browse the repository at this point in the history
… is not working) (#13203)

- **Description:** Fix #11737 issue (extra_tools option of
create_pandas_dataframe_agent is not working),
  - **Issue:** #11737 ,
  - **Dependencies:** no,
- **Tag maintainer:** @baskaryan, @eyurtsev, @hwchase17 I needed this
method at work, so I modified it myself and used it. There is a similar
issue(#11737) and PR(#13018) of @PyroGenesis, so I combined my code at
the original PR.
You may be busy, but it would be great help for me if you checked. Thank
you.
  - **Twitter handle:** @lunara_x 

If you need an .ipynb example about this, please tag me. 
I will share what I am working on after removing any work-related
content.

---------

Co-authored-by: Harrison Chase <[email protected]>
  • Loading branch information
eunhye1kim and hwchase17 authored Dec 5, 2023
1 parent 77a15fa commit f758c8a
Showing 1 changed file with 22 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def _get_multi_prompt(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
extra_tools: Sequence[BaseTool] = (),
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
num_dfs = len(dfs)
if suffix is not None:
suffix_to_use = suffix
Expand All @@ -55,12 +56,13 @@ def _get_multi_prompt(
df_locals = {}
for i, dataframe in enumerate(dfs):
df_locals[f"df{i + 1}"] = dataframe
tools = [PythonAstREPLTool(locals=df_locals)]

tools = [PythonAstREPLTool(locals=df_locals)] + list(extra_tools)
prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
tools,
prefix=prefix,
suffix=suffix_to_use,
input_variables=input_variables,
)

partial_prompt = prompt.partial()
if "dfs_head" in input_variables:
dfs_head = "\n\n".join([d.head(number_of_head_rows).to_markdown() for d in dfs])
Expand All @@ -77,7 +79,8 @@ def _get_single_prompt(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
extra_tools: Sequence[BaseTool] = (),
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
if suffix is not None:
suffix_to_use = suffix
include_df_head = True
Expand All @@ -96,10 +99,13 @@ def _get_single_prompt(
if prefix is None:
prefix = PREFIX

tools = [PythonAstREPLTool(locals={"df": df})]
tools = [PythonAstREPLTool(locals={"df": df})] + list(extra_tools)

prompt = ZeroShotAgent.create_prompt(
tools, prefix=prefix, suffix=suffix_to_use, input_variables=input_variables
tools,
prefix=prefix,
suffix=suffix_to_use,
input_variables=input_variables,
)

partial_prompt = prompt.partial()
Expand All @@ -117,7 +123,8 @@ def _get_prompt_and_tools(
input_variables: Optional[List[str]] = None,
include_df_in_prompt: Optional[bool] = True,
number_of_head_rows: int = 5,
) -> Tuple[BasePromptTemplate, List[PythonAstREPLTool]]:
extra_tools: Sequence[BaseTool] = (),
) -> Tuple[BasePromptTemplate, List[BaseTool]]:
try:
import pandas as pd

Expand All @@ -141,6 +148,7 @@ def _get_prompt_and_tools(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
extra_tools=extra_tools,
)
else:
if not isinstance(df, pd.DataFrame):
Expand All @@ -152,6 +160,7 @@ def _get_prompt_and_tools(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
extra_tools=extra_tools,
)


Expand Down Expand Up @@ -287,6 +296,7 @@ def create_pandas_dataframe_agent(
) -> AgentExecutor:
"""Construct a pandas agent from an LLM and dataframe."""
agent: BaseSingleActionAgent
base_tools: Sequence[BaseTool]
if agent_type == AgentType.ZERO_SHOT_REACT_DESCRIPTION:
prompt, base_tools = _get_prompt_and_tools(
df,
Expand All @@ -295,8 +305,9 @@ def create_pandas_dataframe_agent(
input_variables=input_variables,
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
extra_tools=extra_tools,
)
tools = base_tools + list(extra_tools)
tools = base_tools
llm_chain = LLMChain(
llm=llm,
prompt=prompt,
Expand All @@ -318,7 +329,7 @@ def create_pandas_dataframe_agent(
include_df_in_prompt=include_df_in_prompt,
number_of_head_rows=number_of_head_rows,
)
tools = base_tools + list(extra_tools)
tools = list(base_tools) + list(extra_tools)
agent = OpenAIFunctionsAgent(
llm=llm,
prompt=_prompt,
Expand Down

0 comments on commit f758c8a

Please sign in to comment.