Skip to content

Commit

Permalink
update branching how to
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Aug 19, 2024
1 parent 2add468 commit a1440b7
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 78 deletions.
1 change: 0 additions & 1 deletion examples/agent_executor/base.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
"const AgentState = Annotation.Root({\n",
" messages: Annotation<BaseMessage[]>({\n",
" reducer: (x, y) => x.concat(y),\n",
" default: () => [],\n",
" }),\n",
"});"
]
Expand Down
120 changes: 43 additions & 77 deletions examples/how-tos/branching.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -80,20 +80,13 @@
}
],
"source": [
"import { END, START, StateGraph } from \"@langchain/langgraph\";\n",
"import { StateGraphArgs } from \"@langchain/langgraph\";\n",
"\n",
"// Define the state type\n",
"interface IState {\n",
" aggregate: string[];\n",
"}\n",
"import { END, START, StateGraph, Annotation } from \"@langchain/langgraph\";\n",
"\n",
"const graphState: StateGraphArgs<IState>[\"channels\"] = {\n",
" aggregate: {\n",
" value: (x: string[], y: string[]) => x.concat(y),\n",
" default: () => [],\n",
" },\n",
"};\n",
"const GraphState = Annotation.Root({\n",
" aggregate: Annotation<string[]>({\n",
" reducer: (x, y) => x.concat(y),\n",
" })\n",
"})\n",
"\n",
"// Define the ReturnNodeValue class\n",
"class ReturnNodeValue {\n",
Expand All @@ -103,7 +96,7 @@
" this._value = value;\n",
" }\n",
"\n",
" public call(state: IState) {\n",
" public call(state: typeof GraphState.State) {\n",
" console.log(`Adding ${this._value} to ${state.aggregate}`);\n",
" return { aggregate: [this._value] };\n",
" }\n",
Expand All @@ -115,7 +108,7 @@
"const nodeC = new ReturnNodeValue(\"I'm C\");\n",
"const nodeD = new ReturnNodeValue(\"I'm D\");\n",
"\n",
"const builder = new StateGraph<IState>({ channels: graphState })\n",
"const builder = new StateGraph(GraphState)\n",
" .addNode(\"a\", nodeA.call.bind(nodeA))\n",
" .addEdge(START, \"a\")\n",
" .addNode(\"b\", nodeB.call.bind(nodeB))\n",
Expand Down Expand Up @@ -151,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -167,23 +160,14 @@
}
],
"source": [
"// Define the state type\n",
"interface IState2 {\n",
" // The operator.add reducer function makes this append-only\n",
" aggregate: string[];\n",
" which: string;\n",
"}\n",
"\n",
"const graphState2: StateGraphArgs<IState2>[\"channels\"] = {\n",
" aggregate: {\n",
" value: (x: string[], y: string[]) => x.concat(y),\n",
" default: () => [],\n",
" },\n",
" which: {\n",
" value: (x: string, y: string) => (y ? y : x),\n",
" default: () => \"bc\",\n",
" },\n",
"};\n",
"const GraphStateConditionalBranching = Annotation.Root({\n",
" aggregate: Annotation<string[]>({\n",
" reducer: (x, y) => x.concat(y),\n",
" }),\n",
" which: Annotation<string>({\n",
" reducer: (x: string, y: string) => (y ?? x),\n",
" })\n",
"})\n",
"\n",
"// Create the graph\n",
"const nodeA2 = new ReturnNodeValue(\"I'm A\");\n",
Expand All @@ -192,14 +176,14 @@
"const nodeD2 = new ReturnNodeValue(\"I'm D\");\n",
"const nodeE2 = new ReturnNodeValue(\"I'm E\");\n",
"// Define the route function\n",
"function routeCDorBC(state: IState2): string[] {\n",
"function routeCDorBC(state: typeof GraphStateConditionalBranching.State): string[] {\n",
" if (state.which === \"cd\") {\n",
" return [\"c\", \"d\"];\n",
" }\n",
" return [\"b\", \"c\"];\n",
"}\n",
"\n",
"const builder2 = new StateGraph<IState2>({ channels: graphState2 })\n",
"const builder2 = new StateGraph(GraphStateConditionalBranching)\n",
" .addNode(\"a\", nodeA2.call.bind(nodeA2))\n",
" .addEdge(START, \"a\")\n",
" .addNode(\"b\", nodeB2.call.bind(nodeB2))\n",
Expand All @@ -222,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -263,7 +247,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -282,6 +266,11 @@
}
],
"source": [
"type ScoredValue = {\n",
" value: string;\n",
" score: number;\n",
"};\n",
"\n",
"const reduceFanouts = (left?: ScoredValue[], right?: ScoredValue[]) => {\n",
" if (!left) {\n",
" left = [];\n",
Expand All @@ -293,34 +282,18 @@
" return left.concat(right);\n",
"};\n",
"\n",
"type ScoredValue = {\n",
" value: string;\n",
" score: number;\n",
"};\n",
"\n",
"// Define the state type\n",
"// 'value' defines the 'reducer', which determines how updates are applied\n",
"// 'default' defines the default value for the state\n",
"interface IState3 {\n",
" aggregate: string[];\n",
" which: string;\n",
" fanoutValues: ScoredValue[];\n",
"}\n",
"const GraphStateStableSorting = Annotation.Root({\n",
" aggregate: Annotation<string[]>({\n",
" reducer: (x, y) => x.concat(y),\n",
" }),\n",
" which: Annotation<string>({\n",
" reducer: (x: string, y: string) => (y ?? x),\n",
" }),\n",
" fanoutValues: Annotation<ScoredValue[]>({\n",
" reducer: reduceFanouts,\n",
" }),\n",
"})\n",
"\n",
"const graphState3: StateGraphArgs<IState3>[\"channels\"] = {\n",
" aggregate: {\n",
" value: (x: string[], y: string[]) => x.concat(y),\n",
" default: () => [],\n",
" },\n",
" which: {\n",
" value: (x: string, y: string) => (y ? y : x),\n",
" default: () => \"\",\n",
" },\n",
" fanoutValues: {\n",
" value: reduceFanouts,\n",
" default: () => [],\n",
" },\n",
"};\n",
"\n",
"class ParallelReturnNodeValue {\n",
" private _value: string;\n",
Expand All @@ -331,7 +304,7 @@
" this._score = score;\n",
" }\n",
"\n",
" public call(state: IState3) {\n",
" public call(state: typeof GraphStateStableSorting.State) {\n",
" console.log(`Adding ${this._value} to ${state.aggregate}`);\n",
" return { fanoutValues: [{ value: this._value, score: this._score }] };\n",
" }\n",
Expand All @@ -345,7 +318,7 @@
"const nodeC3 = new ParallelReturnNodeValue(\"I'm C\", 0.9);\n",
"const nodeD3 = new ParallelReturnNodeValue(\"I'm D\", 0.3);\n",
"\n",
"const aggregateFanouts = (state: { fanoutValues: ScoredValue[] }) => {\n",
"const aggregateFanouts = (state: typeof GraphStateStableSorting.State) => {\n",
" // Sort by score (reversed)\n",
" state.fanoutValues.sort((a, b) => b.score - a.score);\n",
" return {\n",
Expand All @@ -355,14 +328,14 @@
"};\n",
"\n",
"// Define the route function\n",
"function routeBCOrCD(state: { which: string }): string[] {\n",
"function routeBCOrCD(state: typeof GraphStateStableSorting.State): string[] {\n",
" if (state.which === \"cd\") {\n",
" return [\"c\", \"d\"];\n",
" }\n",
" return [\"b\", \"c\"];\n",
"}\n",
"\n",
"const builder3 = new StateGraph<IState3>({ channels: graphState3 })\n",
"const builder3 = new StateGraph(GraphStateStableSorting)\n",
" .addNode(\"a\", nodeA3.call.bind(nodeA3))\n",
" .addEdge(START, \"a\")\n",
" .addNode(\"b\", nodeB3.call.bind(nodeB3))\n",
Expand Down Expand Up @@ -394,7 +367,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -416,13 +389,6 @@
"let g3result2 = await graph3.invoke({ aggregate: [], which: \"cd\" });\n",
"console.log(\"Result 2: \", g3result2);\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit a1440b7

Please sign in to comment.