From 89b0feb3dc0bbfd996a50beaf8b9c2ba47429500 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Tue, 10 Dec 2024 02:25:02 -0800 Subject: [PATCH] Fix validation and drawing for edgeless graphs --- libs/langgraph/src/constants.ts | 21 +++++++++ libs/langgraph/src/graph/graph.ts | 16 +++++++ libs/langgraph/src/graph/state.ts | 6 ++- libs/langgraph/src/pregel/read.ts | 5 +++ libs/langgraph/src/tests/pregel.test.ts | 59 +++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 1 deletion(-) diff --git a/libs/langgraph/src/constants.ts b/libs/langgraph/src/constants.ts index 5b8447e9..e6022bf2 100644 --- a/libs/langgraph/src/constants.ts +++ b/libs/langgraph/src/constants.ts @@ -125,13 +125,34 @@ export type Interrupt = { }; export type CommandParams = { + /** + * Value to resume execution with. To be used together with {@link interrupt}. + */ resume?: R; + /** + * Graph to send the command to. Supported values are: + * - None: the current graph (default) + * - GraphCommand.PARENT: closest parent graph + */ graph?: string; + /** + * Update to apply to the graph's state. + */ // eslint-disable-next-line @typescript-eslint/no-explicit-any update?: Record; + /** + * Can be one of the following: + * - name of the node to navigate to next (any node that belongs to the specified `graph`) + * - sequence of node names to navigate to next + * - `Send` object (to execute a node with the input provided) + * - sequence of `Send` objects + */ goto?: string | Send | (string | Send)[]; }; +/** + * One or more commands to update the graph's state and send messages to nodes. + */ export class Command { lg_name = "Command"; diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 1253e9c6..24c80e9b 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -156,12 +156,14 @@ export type NodeSpec = { metadata?: Record; // eslint-disable-next-line @typescript-eslint/no-explicit-any subgraphs?: Pregel[]; + ends?: string[]; }; export type AddNodeOptions = { metadata?: Record; // eslint-disable-next-line @typescript-eslint/no-explicit-any subgraphs?: Pregel[]; + ends?: string[]; }; export class Graph< @@ -243,6 +245,7 @@ export class Graph< runnable, metadata: options?.metadata, subgraphs: isPregelLike(runnable) ? [runnable] : options?.subgraphs, + ends: options?.ends, } as NodeSpecType; return this as Graph; @@ -452,6 +455,11 @@ export class Graph< } } } + for (const node of Object.values(this.nodes)) { + for (const target of node.ends ?? []) { + allTargets.add(target); + } + } // validate targets for (const node of Object.keys(this.nodes)) { if (!allTargets.has(node)) { @@ -519,6 +527,7 @@ export class CompiledGraph< triggers: [], metadata: node.metadata, subgraphs: node.subgraphs, + ends: node.ends, }) .pipe(node.runnable) .pipe( @@ -758,6 +767,13 @@ export class CompiledGraph< } } } + for (const [key, node] of Object.entries(this.builder.nodes) as [N, NodeSpec][]) { + if (node.ends !== undefined) { + for (const end of node.ends) { + addEdge(_escapeMermaidKeywords(key), _escapeMermaidKeywords(end), undefined, true); + } + } + } return graph; } diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index 3d40de7d..e46fe1d7 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -340,6 +340,7 @@ export class StateGraph< ? // eslint-disable-next-line @typescript-eslint/no-explicit-any [runnable as any] : options?.subgraphs, + ends: options?.ends, }; this.nodes[key as unknown as N] = nodeSpec; @@ -439,7 +440,9 @@ export class StateGraph< compiled.attachBranch(START, SELF, _getControlBranch() as Branch, { withReader: false, }); - for (const [key] of Object.entries>(this.nodes)) { + for (const [key] of Object.entries>( + this.nodes + )) { compiled.attachBranch( key as N, SELF, @@ -603,6 +606,7 @@ export class CompiledStateGraph< metadata: node?.metadata, retryPolicy: node?.retryPolicy, subgraphs: node?.subgraphs, + ends: node?.ends, }); } } diff --git a/libs/langgraph/src/pregel/read.ts b/libs/langgraph/src/pregel/read.ts index 63a95de7..fdb122ce 100644 --- a/libs/langgraph/src/pregel/read.ts +++ b/libs/langgraph/src/pregel/read.ts @@ -84,6 +84,7 @@ interface PregelNodeArgs metadata?: Record; retryPolicy?: RetryPolicy; subgraphs?: Runnable[]; + ends?: string[]; } // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -120,6 +121,8 @@ export class PregelNode< subgraphs?: Runnable[]; + ends?: string[]; + constructor(fields: PregelNodeArgs) { const { channels, @@ -132,6 +135,7 @@ export class PregelNode< retryPolicy, tags, subgraphs, + ends, } = fields; const mergedTags = [ ...(fields.config?.tags ? fields.config.tags : []), @@ -159,6 +163,7 @@ export class PregelNode< this.tags = mergedTags; this.retryPolicy = retryPolicy; this.subgraphs = subgraphs; + this.ends = ends; } getWriters(): Array { diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 9347270c..2428359f 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -2005,6 +2005,65 @@ export function runPregelTests( expect(res).toEqual({ items: ["0", "1", "2", "2", "3"] }); }); + it("should support a simple edgeless graph", async () => { + const StateAnnotation = Annotation.Root({ + foo: Annotation, + }); + + const nodeA = async (state: typeof StateAnnotation.State) => { + console.log("Called A"); + const goto = state.foo === "foo" ? "nodeB" : "nodeC"; + return new Command({ + update: { + foo: "a", + }, + goto, + }); + }; + + const nodeB = async (state: typeof StateAnnotation.State) => { + console.log("Called B"); + return { + foo: state.foo + "|b", + }; + }; + + const nodeC = async (state: typeof StateAnnotation.State) => { + console.log("Called C"); + return { + foo: state.foo + "|c", + }; + }; + + const graph = new StateGraph(StateAnnotation) + .addNode("nodeA", nodeA, { + ends: ["nodeB", "nodeC"], + }) + .addNode("nodeB", nodeB) + .addNode("nodeC", nodeC) + .addEdge("__start__", "nodeA") + .compile(); + + const drawableGraph = await graph.getGraphAsync(); + const mermaid = drawableGraph.drawMermaid(); + // console.log(mermaid); + expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%% +graph TD; + __start__([

__start__

]):::first + nodeA(nodeA) + nodeB(nodeB) + nodeC(nodeC) + __start__ --> nodeA; + nodeA -.-> nodeB; + nodeA -.-> nodeC; + classDef default fill:#f2f0ff,line-height:1.2; + classDef first fill-opacity:0; + classDef last fill:#bfb6fc; +`) + expect(await graph.invoke({ foo: "foo" })).toEqual({ foo: "a|b" }); + expect(await graph.invoke({ foo: "" })).toEqual({ foo: "a|c" }); + }); + it("should handle send sequences correctly", async () => { const StateAnnotation = Annotation.Root({ items: Annotation({