From fdd8ff8441812ab7591adb2c4a14d6d43c1170da Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Mon, 16 Oct 2023 15:12:56 -0700 Subject: [PATCH] refactor: two output parsing methods for if runnable vs no runnable is passed in --- langchain/src/agents/agent.ts | 25 +++++++++++++++++++++---- langchain/src/agents/types.ts | 7 ++++++- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/langchain/src/agents/agent.ts b/langchain/src/agents/agent.ts index 23bd57735fd5..46cae57f66c8 100644 --- a/langchain/src/agents/agent.ts +++ b/langchain/src/agents/agent.ts @@ -19,7 +19,6 @@ import { StoppingMethod, } from "./types.js"; import { Runnable } from "../schema/runnable/base.js"; -import { StringOutputParser } from "../schema/output_parser.js"; /** * Record type for arguments passed to output parsers. @@ -416,12 +415,30 @@ export abstract class Agent< newInputs.stop = this._stop(); } - const output = await this.runnable.invoke(newInputs, callbackManager); - + /** + * The output type for this is a little weird, depending on if a + * runnable was passed in or not. + * + * If a runnable is passed in (and not an LLMChain), then the output is + * `AgentAction | AgentFinish`. + * If an LLMChain was passed, the output will be `{ text: string }`. + */ + const output = (await this.runnable.invoke(newInputs, callbackManager)) as + | { + text: string; + } + | AgentAction + | AgentFinish; + if (!this.outputParser) { throw new Error("Output parser not set"); } - return this.outputParser.parse(output, callbackManager); + + if ("text" in output) { + return this.outputParser.parse(output.text, callbackManager); + } + + return this.outputParser.parseAgentOutput(output, callbackManager); } /** diff --git a/langchain/src/agents/types.ts b/langchain/src/agents/types.ts index cfb9fd05263d..6b20a612b3a9 100644 --- a/langchain/src/agents/types.ts +++ b/langchain/src/agents/types.ts @@ -1,3 +1,4 @@ +import { Callbacks } from "../callbacks/manager.js"; import { LLMChain } from "../chains/llm_chain.js"; import { SerializedLLMChain } from "../chains/serde.js"; import { @@ -47,7 +48,11 @@ export interface RunnableAgentInput< */ export abstract class AgentActionOutputParser extends BaseOutputParser< AgentAction | AgentFinish -> {} +> { + async parseAgentOutput(output: AgentFinish | AgentAction, _callbackManager?: Callbacks) { + return output; + } +} /** * Type representing the stopping method for an agent. It can be either