diff --git a/examples/agent_executor/base.ipynb b/examples/agent_executor/base.ipynb index cbc3ecc2..7eeddb4d 100644 --- a/examples/agent_executor/base.ipynb +++ b/examples/agent_executor/base.ipynb @@ -79,7 +79,6 @@ "const AgentState = Annotation.Root({\n", " messages: Annotation({\n", " reducer: (x, y) => x.concat(y),\n", - " default: () => [],\n", " }),\n", "});" ] diff --git a/examples/how-tos/branching.ipynb b/examples/how-tos/branching.ipynb index fb04ae45..bdfc2504 100644 --- a/examples/how-tos/branching.ipynb +++ b/examples/how-tos/branching.ipynb @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -80,20 +80,13 @@ } ], "source": [ - "import { END, START, StateGraph } from \"@langchain/langgraph\";\n", - "import { StateGraphArgs } from \"@langchain/langgraph\";\n", - "\n", - "// Define the state type\n", - "interface IState {\n", - " aggregate: string[];\n", - "}\n", + "import { END, START, StateGraph, Annotation } from \"@langchain/langgraph\";\n", "\n", - "const graphState: StateGraphArgs[\"channels\"] = {\n", - " aggregate: {\n", - " value: (x: string[], y: string[]) => x.concat(y),\n", - " default: () => [],\n", - " },\n", - "};\n", + "const GraphState = Annotation.Root({\n", + " aggregate: Annotation({\n", + " reducer: (x, y) => x.concat(y),\n", + " })\n", + "})\n", "\n", "// Define the ReturnNodeValue class\n", "class ReturnNodeValue {\n", @@ -103,7 +96,7 @@ " this._value = value;\n", " }\n", "\n", - " public call(state: IState) {\n", + " public call(state: typeof GraphState.State) {\n", " console.log(`Adding ${this._value} to ${state.aggregate}`);\n", " return { aggregate: [this._value] };\n", " }\n", @@ -115,7 +108,7 @@ "const nodeC = new ReturnNodeValue(\"I'm C\");\n", "const nodeD = new ReturnNodeValue(\"I'm D\");\n", "\n", - "const builder = new StateGraph({ channels: graphState })\n", + "const builder = new StateGraph(GraphState)\n", " .addNode(\"a\", nodeA.call.bind(nodeA))\n", " .addEdge(START, \"a\")\n", " .addNode(\"b\", nodeB.call.bind(nodeB))\n", @@ -151,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -167,23 +160,14 @@ } ], "source": [ - "// Define the state type\n", - "interface IState2 {\n", - " // The operator.add reducer function makes this append-only\n", - " aggregate: string[];\n", - " which: string;\n", - "}\n", - "\n", - "const graphState2: StateGraphArgs[\"channels\"] = {\n", - " aggregate: {\n", - " value: (x: string[], y: string[]) => x.concat(y),\n", - " default: () => [],\n", - " },\n", - " which: {\n", - " value: (x: string, y: string) => (y ? y : x),\n", - " default: () => \"bc\",\n", - " },\n", - "};\n", + "const GraphStateConditionalBranching = Annotation.Root({\n", + " aggregate: Annotation({\n", + " reducer: (x, y) => x.concat(y),\n", + " }),\n", + " which: Annotation({\n", + " reducer: (x: string, y: string) => (y ?? x),\n", + " })\n", + "})\n", "\n", "// Create the graph\n", "const nodeA2 = new ReturnNodeValue(\"I'm A\");\n", @@ -192,14 +176,14 @@ "const nodeD2 = new ReturnNodeValue(\"I'm D\");\n", "const nodeE2 = new ReturnNodeValue(\"I'm E\");\n", "// Define the route function\n", - "function routeCDorBC(state: IState2): string[] {\n", + "function routeCDorBC(state: typeof GraphStateConditionalBranching.State): string[] {\n", " if (state.which === \"cd\") {\n", " return [\"c\", \"d\"];\n", " }\n", " return [\"b\", \"c\"];\n", "}\n", "\n", - "const builder2 = new StateGraph({ channels: graphState2 })\n", + "const builder2 = new StateGraph(GraphStateConditionalBranching)\n", " .addNode(\"a\", nodeA2.call.bind(nodeA2))\n", " .addEdge(START, \"a\")\n", " .addNode(\"b\", nodeB2.call.bind(nodeB2))\n", @@ -222,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -263,7 +247,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -282,6 +266,11 @@ } ], "source": [ + "type ScoredValue = {\n", + " value: string;\n", + " score: number;\n", + "};\n", + "\n", "const reduceFanouts = (left?: ScoredValue[], right?: ScoredValue[]) => {\n", " if (!left) {\n", " left = [];\n", @@ -293,34 +282,18 @@ " return left.concat(right);\n", "};\n", "\n", - "type ScoredValue = {\n", - " value: string;\n", - " score: number;\n", - "};\n", - "\n", - "// Define the state type\n", - "// 'value' defines the 'reducer', which determines how updates are applied\n", - "// 'default' defines the default value for the state\n", - "interface IState3 {\n", - " aggregate: string[];\n", - " which: string;\n", - " fanoutValues: ScoredValue[];\n", - "}\n", + "const GraphStateStableSorting = Annotation.Root({\n", + " aggregate: Annotation({\n", + " reducer: (x, y) => x.concat(y),\n", + " }),\n", + " which: Annotation({\n", + " reducer: (x: string, y: string) => (y ?? x),\n", + " }),\n", + " fanoutValues: Annotation({\n", + " reducer: reduceFanouts,\n", + " }),\n", + "})\n", "\n", - "const graphState3: StateGraphArgs[\"channels\"] = {\n", - " aggregate: {\n", - " value: (x: string[], y: string[]) => x.concat(y),\n", - " default: () => [],\n", - " },\n", - " which: {\n", - " value: (x: string, y: string) => (y ? y : x),\n", - " default: () => \"\",\n", - " },\n", - " fanoutValues: {\n", - " value: reduceFanouts,\n", - " default: () => [],\n", - " },\n", - "};\n", "\n", "class ParallelReturnNodeValue {\n", " private _value: string;\n", @@ -331,7 +304,7 @@ " this._score = score;\n", " }\n", "\n", - " public call(state: IState3) {\n", + " public call(state: typeof GraphStateStableSorting.State) {\n", " console.log(`Adding ${this._value} to ${state.aggregate}`);\n", " return { fanoutValues: [{ value: this._value, score: this._score }] };\n", " }\n", @@ -345,7 +318,7 @@ "const nodeC3 = new ParallelReturnNodeValue(\"I'm C\", 0.9);\n", "const nodeD3 = new ParallelReturnNodeValue(\"I'm D\", 0.3);\n", "\n", - "const aggregateFanouts = (state: { fanoutValues: ScoredValue[] }) => {\n", + "const aggregateFanouts = (state: typeof GraphStateStableSorting.State) => {\n", " // Sort by score (reversed)\n", " state.fanoutValues.sort((a, b) => b.score - a.score);\n", " return {\n", @@ -355,14 +328,14 @@ "};\n", "\n", "// Define the route function\n", - "function routeBCOrCD(state: { which: string }): string[] {\n", + "function routeBCOrCD(state: typeof GraphStateStableSorting.State): string[] {\n", " if (state.which === \"cd\") {\n", " return [\"c\", \"d\"];\n", " }\n", " return [\"b\", \"c\"];\n", "}\n", "\n", - "const builder3 = new StateGraph({ channels: graphState3 })\n", + "const builder3 = new StateGraph(GraphStateStableSorting)\n", " .addNode(\"a\", nodeA3.call.bind(nodeA3))\n", " .addEdge(START, \"a\")\n", " .addNode(\"b\", nodeB3.call.bind(nodeB3))\n", @@ -394,7 +367,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -416,13 +389,6 @@ "let g3result2 = await graph3.invoke({ aggregate: [], which: \"cd\" });\n", "console.log(\"Result 2: \", g3result2);\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {