Skip to content

Commit

Permalink
refactor: updated base agent to accept runnables, deprecated llmChain
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Oct 16, 2023
1 parent 21f0433 commit f99c304
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 28 deletions.
43 changes: 30 additions & 13 deletions langchain/src/agents/agent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,14 @@ export interface AgentArgs {
* include a variable called "agent_scratchpad" where the agent can put its
* intermediary work.
*/
export abstract class Agent extends BaseSingleActionAgent {
llmChain: LLMChain;
export abstract class Agent<
RunInput extends ChainValues & {
agent_scratchpad?: string | BaseMessage[];
stop?: string[];
} = any,
RunOutput extends AgentAction | AgentFinish = any
> extends BaseSingleActionAgent {
runnable: Runnable<RunInput, RunOutput>;

outputParser: AgentActionOutputParser | undefined;

Expand All @@ -288,13 +294,23 @@ export abstract class Agent extends BaseSingleActionAgent {
}

get inputKeys(): string[] {
return this.llmChain.inputKeys.filter((k) => k !== "agent_scratchpad");
// eslint-disable-next-line no-instanceof/no-instanceof
if (this.runnable instanceof LLMChain) {
return this.runnable.inputKeys.filter((k) => k !== "agent_scratchpad");
}
return [];
}

constructor(input: AgentInput) {
super(input);

this.llmChain = input.llmChain;
if (!input.runnable && !input.llmChain) {
throw new Error(
`Runnable and LLMChain are both missing, one is required.`
);
}

this.runnable = input.runnable;
this._allowedTools = input.allowedTools;
this.outputParser = input.outputParser;
}
Expand Down Expand Up @@ -385,12 +401,12 @@ export abstract class Agent extends BaseSingleActionAgent {

private async _plan(
steps: AgentStep[],
inputs: ChainValues,
inputs: RunInput,
suffix?: string,
callbackManager?: CallbackManager
): Promise<AgentAction | AgentFinish> {
const thoughts = await this.constructScratchPad(steps);
const newInputs: ChainValues = {
const newInputs: RunInput = {
...inputs,
agent_scratchpad: suffix ? `${thoughts}${suffix}` : thoughts,
};
Expand All @@ -399,15 +415,16 @@ export abstract class Agent extends BaseSingleActionAgent {
newInputs.stop = this._stop();
}

const output = await this.llmChain.predict(newInputs, callbackManager);
const output = await this.runnable.invoke(newInputs, callbackManager);
console.log({
newInputs,
output,
});
if (!this.outputParser) {
throw new Error("Output parser not set");
}
return this.outputParser.parse(output, callbackManager);
return output;
// if (!this.outputParser) {
// throw new Error("Output parser not set");
// }
// return this.outputParser.parse(output, callbackManager);
}

/**
Expand All @@ -421,7 +438,7 @@ export abstract class Agent extends BaseSingleActionAgent {
*/
plan(
steps: AgentStep[],
inputs: ChainValues,
inputs: RunInput,
callbackManager?: CallbackManager
): Promise<AgentAction | AgentFinish> {
return this._plan(steps, inputs, undefined, callbackManager);
Expand All @@ -433,7 +450,7 @@ export abstract class Agent extends BaseSingleActionAgent {
async returnStoppedResponse(
earlyStoppingMethod: StoppingMethod,
steps: AgentStep[],
inputs: ChainValues,
inputs: RunInput,
callbackManager?: CallbackManager
): Promise<AgentFinish> {
if (earlyStoppingMethod === "force") {
Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/chat/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ export class ChatAgent extends Agent {
args?.outputParser ?? ChatAgent.getDefaultOutputParser();

return new ChatAgent({
llmChain: chain,
runnable: chain,
outputParser,
allowedTools: tools.map((t) => t.name),
});
Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/chat_convo/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ export class ChatConversationalAgent extends Agent {
callbacks: args?.callbacks ?? args?.callbackManager,
});
return new ChatConversationalAgent({
llmChain: chain,
runnable: chain,
outputParser,
allowedTools: tools.map((t) => t.name),
});
Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import { Runnable } from "../schema/runnable/base.js";
* properties specific to agent execution.
*/
export interface AgentExecutorInput extends ChainInputs {
agent: BaseSingleActionAgent | BaseMultiActionAgent | Runnable;
agent: BaseSingleActionAgent | BaseMultiActionAgent;
tools: this["agent"]["ToolType"][];
returnIntermediateSteps?: boolean;
maxIterations?: number;
Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/mrkl/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ export class ZeroShotAgent extends Agent {
});

return new ZeroShotAgent({
llmChain: chain,
runnable: chain,
allowedTools: tools.map((t) => t.name),
outputParser,
});
Expand Down
4 changes: 2 additions & 2 deletions langchain/src/agents/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ export class OpenAIAgent extends Agent {
callbacks: args?.callbacks,
});
return new OpenAIAgent({
llmChain: chain,
runnable: chain,
allowedTools: tools.map((t) => t.name),
tools,
});
Expand Down Expand Up @@ -233,7 +233,7 @@ export class OpenAIAgent extends Agent {
}

// Split inputs between prompt and llm
const llm = this.llmChain.llm as ChatOpenAI;
const llm = this.runnable.llm as ChatOpenAI;
const valuesForPrompt = { ...newInputs };
const valuesForLLM: (typeof llm)["CallOptions"] = {
tools: this.tools,
Expand Down
2 changes: 1 addition & 1 deletion langchain/src/agents/structured_chat/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ export class StructuredChatAgent extends Agent {
});

return new StructuredChatAgent({
llmChain: chain,
runnable: chain,
outputParser,
allowedTools: tools.map((t) => t.name),
});
Expand Down
8 changes: 6 additions & 2 deletions langchain/src/agents/tests/agent.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import { WebBrowser } from "../../tools/webbrowser.js";
import { Tool } from "../../tools/base.js";
import { ChatOpenAI } from "../../chat_models/openai.js";
import { RunnableSequence } from "../../schema/runnable/base.js";
import { RunnableAgent } from "../agent.js";

test("Run agent from hub", async () => {
const model = new OpenAI({ temperature: 0, modelName: "text-babbage-001" });
Expand Down Expand Up @@ -63,8 +62,13 @@ test("Pass runnable to agent executor", async () => {
outputParser,
]);

const agent = new ZeroShotAgent({
runnable,
allowedTools: tools.map((t) => t.name),
});

const executor = AgentExecutor.fromAgentAndTools({
agent: runnable,
agent,
tools,
});
const res = await executor.invoke({
Expand Down
21 changes: 17 additions & 4 deletions langchain/src/agents/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { LLMChain } from "../chains/llm_chain.js";
import { SerializedLLMChain } from "../chains/serde.js";
import { AgentAction, AgentFinish, ChainValues } from "../schema/index.js";
import {
AgentAction,
AgentFinish,
BaseMessage,
ChainValues,
} from "../schema/index.js";
import { BaseOutputParser } from "../schema/output_parser.js";
import { Runnable } from "../schema/runnable/base.js";

Expand All @@ -9,11 +14,19 @@ import { Runnable } from "../schema/runnable/base.js";
* LLMChain instance, an optional output parser, and an optional list of
* allowed tools.
*/
export interface AgentInput {
llmChain: LLMChain;
export type AgentInput<
RunInput extends ChainValues & {
agent_scratchpad?: string | BaseMessage[];
stop?: string[];
} = any,
RunOutput extends AgentAction | AgentFinish = any
> = {
/** @deprecated - use runnable instead */
llmChain?: LLMChain;
runnable: Runnable<RunInput, RunOutput>;
outputParser: AgentActionOutputParser | undefined;
allowedTools?: string[];
}
};

/**
* Interface defining the input for creating an agent that uses runnables.
Expand Down
2 changes: 0 additions & 2 deletions langchain/src/schema/runnable/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ export abstract class Runnable<
> extends Serializable {
protected lc_runnable = true;

declare ToolType: StructuredTool;

abstract invoke(
input: RunInput,
options?: Partial<CallOptions>
Expand Down

0 comments on commit f99c304

Please sign in to comment.