diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index 21d9b402..2ff7cec0 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -31,6 +31,7 @@ import { gatherIteratorSync, RunnableCallable } from "../utils.js"; import { InvalidUpdateError, NodeInterrupt } from "../errors.js"; import { StateDefinition, StateType } from "./annotation.js"; import type { LangGraphRunnableConfig } from "../pregel/runnable_types.js"; +import { isPregelLike } from "../pregel/utils/subgraph.js"; /** Special reserved node name denoting the start of a graph. */ export const START = "__start__"; @@ -128,9 +129,15 @@ export class Branch< export type NodeSpec = { runnable: Runnable; metadata?: Record; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + subgraphs?: Pregel[]; }; -export type AddNodeOptions = { metadata?: Record }; +export type AddNodeOptions = { + metadata?: Record; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + subgraphs?: Pregel[]; +}; export class Graph< N extends string = typeof END, @@ -202,12 +209,15 @@ export class Graph< throw new Error(`Node \`${key}\` is reserved.`); } + const runnable = _coerceToRunnable( + // Account for arbitrary state due to Send API + action as RunnableLike + ); + this.nodes[key as unknown as N] = { - runnable: _coerceToRunnable( - // Account for arbitrary state due to Send API - action as RunnableLike - ), + runnable, metadata: options?.metadata, + subgraphs: isPregelLike(runnable) ? [runnable] : options?.subgraphs, } as NodeSpecType; return this as Graph; @@ -462,6 +472,7 @@ export class CompiledGraph< channels: [], triggers: [], metadata: node.metadata, + subgraphs: node.subgraphs, }) .pipe(node.runnable) .pipe( diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index 71c95bcd..32b58b36 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -48,6 +48,7 @@ import { import type { RetryPolicy } from "../pregel/utils/index.js"; import { isConfiguredManagedValue, ManagedValueSpec } from "../managed/base.js"; import type { LangGraphRunnableConfig } from "../pregel/runnable_types.js"; +import { isPregelLike } from "../pregel/utils/subgraph.js"; const ROOT = "__root__"; @@ -314,11 +315,17 @@ export class StateGraph< if (options?.input !== undefined) { this._addSchema(options.input.spec); } + + const runnable = _coerceToRunnable(action) as unknown as Runnable; const nodeSpec: StateGraphNodeSpec = { - runnable: _coerceToRunnable(action) as unknown as Runnable, + runnable, retryPolicy: options?.retryPolicy, metadata: options?.metadata, input: options?.input?.spec ?? this._schemaDefinition, + subgraphs: isPregelLike(runnable) + ? // eslint-disable-next-line @typescript-eslint/no-explicit-any + [runnable as any] + : options?.subgraphs, }; this.nodes[key as unknown as N] = nodeSpec; @@ -541,6 +548,7 @@ export class CompiledStateGraph< bound: node?.runnable, metadata: node?.metadata, retryPolicy: node?.retryPolicy, + subgraphs: node?.subgraphs, }); } } diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index 871c2bf8..13cff6eb 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -541,6 +541,7 @@ export function _prepareSingleTask< name: packet.node, input: packet.args, proc: node, + subgraphs: proc.subgraphs, writes, config: patchConfig( mergeConfigs(config, { @@ -665,6 +666,7 @@ export function _prepareSingleTask< name, input: val, proc: node, + subgraphs: proc.subgraphs, writes, config: patchConfig( mergeConfigs(config, { diff --git a/libs/langgraph/src/pregel/debug.ts b/libs/langgraph/src/pregel/debug.ts index 04458b96..df83867d 100644 --- a/libs/langgraph/src/pregel/debug.ts +++ b/libs/langgraph/src/pregel/debug.ts @@ -192,7 +192,8 @@ export function* mapDebugCheckpoint< const taskStates: Record = {}; for (const task of tasks) { - if (!findSubgraphPregel(task.proc)) continue; + const candidates = task.subgraphs?.length ? task.subgraphs : [task.proc]; + if (!candidates.find(findSubgraphPregel)) continue; let taskNs = `${task.name as string}:${task.id}`; if (parentNs) taskNs = `${parentNs}|${taskNs}`; diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 39f20f49..dc8062f1 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -350,29 +350,36 @@ export class Pregel< // find the subgraph if any // eslint-disable-next-line @typescript-eslint/no-explicit-any type SubgraphPregelType = Pregel | undefined; - const graph = findSubgraphPregel(node.bound) as SubgraphPregelType; - // if found, yield recursively - if (graph !== undefined) { - if (name === namespace) { - yield [name, graph]; - return; - } - if (namespace === undefined) { - yield [name, graph]; - } - if (recurse) { - let newNamespace = namespace; - if (namespace !== undefined) { - newNamespace = namespace.slice(name.length + 1); + + const candidates = node.subgraphs?.length ? node.subgraphs : [node.bound]; + + for (const candidate of candidates) { + const graph = findSubgraphPregel(candidate) as SubgraphPregelType; + + if (graph !== undefined) { + if (name === namespace) { + yield [name, graph]; + return; } - for (const [subgraphName, subgraph] of graph.getSubgraphs( - newNamespace, - recurse - )) { - yield [ - `${name}${CHECKPOINT_NAMESPACE_SEPARATOR}${subgraphName}`, - subgraph, - ]; + + if (namespace === undefined) { + yield [name, graph]; + } + + if (recurse) { + let newNamespace = namespace; + if (namespace !== undefined) { + newNamespace = namespace.slice(name.length + 1); + } + for (const [subgraphName, subgraph] of graph.getSubgraphs( + newNamespace, + recurse + )) { + yield [ + `${name}${CHECKPOINT_NAMESPACE_SEPARATOR}${subgraphName}`, + subgraph, + ]; + } } } } diff --git a/libs/langgraph/src/pregel/read.ts b/libs/langgraph/src/pregel/read.ts index 20952b00..5912f05e 100644 --- a/libs/langgraph/src/pregel/read.ts +++ b/libs/langgraph/src/pregel/read.ts @@ -83,6 +83,7 @@ interface PregelNodeArgs config?: RunnableConfig; metadata?: Record; retryPolicy?: RetryPolicy; + subgraphs?: Runnable[]; } // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -117,6 +118,8 @@ export class PregelNode< retryPolicy?: RetryPolicy; + subgraphs?: Runnable[]; + constructor(fields: PregelNodeArgs) { const { channels, @@ -128,6 +131,7 @@ export class PregelNode< metadata, retryPolicy, tags, + subgraphs, } = fields; const mergedTags = [ ...(fields.config?.tags ? fields.config.tags : []), @@ -154,6 +158,7 @@ export class PregelNode< this.metadata = metadata ?? this.metadata; this.tags = mergedTags; this.retryPolicy = retryPolicy; + this.subgraphs = subgraphs; } getWriters(): Array { diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 4407bc8b..d216faa8 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -107,6 +107,7 @@ export interface PregelExecutableTask< readonly retry_policy?: RetryPolicy; readonly id: string; readonly path?: [string, ...(string | number)[]]; + readonly subgraphs?: Runnable[]; } export interface StateSnapshot { diff --git a/libs/langgraph/src/pregel/utils/subgraph.ts b/libs/langgraph/src/pregel/utils/subgraph.ts index 647572da..025c46bf 100644 --- a/libs/langgraph/src/pregel/utils/subgraph.ts +++ b/libs/langgraph/src/pregel/utils/subgraph.ts @@ -1,4 +1,8 @@ -import { RunnableSequence, Runnable } from "@langchain/core/runnables"; +import { + RunnableSequence, + Runnable, + RunnableLike, +} from "@langchain/core/runnables"; import type { PregelInterface } from "../types.js"; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -7,9 +11,10 @@ function isRunnableSequence( ): x is RunnableSequence { return "steps" in x && Array.isArray(x.steps); } -function isPregelLike( + +export function isPregelLike( // eslint-disable-next-line @typescript-eslint/no-explicit-any - x: PregelInterface | Runnable + x: PregelInterface | RunnableLike // eslint-disable-next-line @typescript-eslint/no-explicit-any ): x is PregelInterface { return ( diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 7f034de5..9f23e19b 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -10,6 +10,7 @@ import { jest, describe, beforeEach, + test, afterAll, } from "@jest/globals"; import { @@ -5946,58 +5947,94 @@ export function runPregelTests( }); describe("Subgraphs", () => { - it("nested graph interrupts parallel", async () => { - const InnerStateAnnotation = Annotation.Root({ - myKey: Annotation({ - reducer: (a, b) => a + b, - default: () => "", - }), - myOtherKey: Annotation, - }); - - const inner1 = async (state: typeof InnerStateAnnotation.State) => { - await new Promise((resolve) => { - setTimeout(resolve, 100); - }); - return { myKey: "got here", myOtherKey: state.myKey }; - }; - - const inner2 = (state: typeof InnerStateAnnotation.State) => { - return { - myKey: " and there", - myOtherKey: state.myKey, - }; - }; - - const inner = new StateGraph(InnerStateAnnotation) - .addNode("inner1", inner1) - .addNode("inner2", inner2) - .addEdge("inner1", "inner2") - .addEdge("__start__", "inner1"); - - const StateAnnotation = Annotation.Root({ - myKey: Annotation({ - reducer: (a, b) => a + b, - default: () => "", - }), - }); - - const outer1 = (_state: typeof StateAnnotation.State) => { - return { myKey: " and parallel" }; - }; - - const outer2 = (_state: typeof StateAnnotation.State) => { - return { myKey: " and back again" }; - }; - - const graph = new StateGraph(StateAnnotation) - .addNode("inner", inner.compile({ interruptBefore: ["inner2"] })) - .addNode("outer1", outer1) - .addNode("outer2", outer2) - .addEdge(START, "inner") - .addEdge(START, "outer1") - .addEdge(["inner", "outer1"], "outer2"); - + test.each([ + [ + "nested graph interrupts parallel", + (() => { + const inner = new StateGraph( + Annotation.Root({ + myKey: Annotation({ + reducer: (a, b) => a + b, + default: () => "", + }), + myOtherKey: Annotation, + }) + ) + .addNode("inner1", async (state) => { + await new Promise((resolve) => setTimeout(resolve, 100)); + return { myKey: "got here", myOtherKey: state.myKey }; + }) + .addNode("inner2", (state) => ({ + myKey: " and there", + myOtherKey: state.myKey, + })) + .addEdge("inner1", "inner2") + .addEdge("__start__", "inner1") + .compile({ interruptBefore: ["inner2"] }); + + const graph = new StateGraph( + Annotation.Root({ + myKey: Annotation({ + reducer: (a, b) => a + b, + default: () => "", + }), + }) + ) + .addNode("inner", inner) + .addNode("outer1", () => ({ myKey: " and parallel" })) + .addNode("outer2", () => ({ myKey: " and back again" })) + .addEdge(START, "inner") + .addEdge(START, "outer1") + .addEdge(["inner", "outer1"], "outer2"); + + return graph; + })(), + ], + [ + "nested graph interrupts parallel: subgraph in lambda", + (() => { + const inner = new StateGraph( + Annotation.Root({ + myKey: Annotation({ + reducer: (a, b) => a + b, + default: () => "", + }), + myOtherKey: Annotation, + }) + ) + .addNode("inner1", async (state) => { + await new Promise((resolve) => setTimeout(resolve, 100)); + return { myKey: "got here", myOtherKey: state.myKey }; + }) + .addNode("inner2", (state) => ({ + myKey: " and there", + myOtherKey: state.myKey, + })) + .addEdge("inner1", "inner2") + .addEdge("__start__", "inner1") + .compile({ interruptBefore: ["inner2"] }); + + const graph = new StateGraph( + Annotation.Root({ + myKey: Annotation({ + reducer: (a, b) => a + b, + default: () => "", + }), + }) + ) + .addNode("inner", (state, config) => inner.invoke(state, config), { + subgraphs: [inner], + }) + .addNode("outer1", () => ({ myKey: " and parallel" })) + .addNode("outer2", () => ({ myKey: " and back again" })) + .addEdge(START, "inner") + .addEdge(START, "outer1") + .addEdge(["inner", "outer1"], "outer2"); + + return graph; + })(), + ], + ])("%s", async (_name, graph) => { const checkpointer = await createCheckpointer(); const app = graph.compile({ checkpointer }); @@ -8074,35 +8111,82 @@ export function runPregelTests( ); }); - it("debug nested subgraph", async () => { - const state = Annotation.Root({ - messages: Annotation({ - reducer: (a, b) => a.concat(b), - default: () => [], - }), - }); - - const child = new StateGraph(state) - .addNode("c_one", () => ({ messages: ["c_one"] })) - .addNode("c_two", () => ({ messages: ["c_two"] })) - .addEdge(START, "c_one") - .addEdge("c_one", "c_two") - .addEdge("c_two", END); - - const parent = new StateGraph(state) - .addNode("p_one", () => ({ messages: ["p_one"] })) - .addNode("p_two", child.compile()) - .addEdge(START, "p_one") - .addEdge("p_one", "p_two") - .addEdge("p_two", END); + test.each([ + [ + "debug nested subgraph: default graph", + (() => { + const state = Annotation.Root({ + messages: Annotation({ + reducer: (a, b) => a.concat(b), + default: () => [], + }), + }); - const grandParent = new StateGraph(state) - .addNode("gp_one", () => ({ messages: ["gp_one"] })) - .addNode("gp_two", parent.compile()) - .addEdge(START, "gp_one") - .addEdge("gp_one", "gp_two") - .addEdge("gp_two", END); + const child = new StateGraph(state) + .addNode("c_one", () => ({ messages: ["c_one"] })) + .addNode("c_two", () => ({ messages: ["c_two"] })) + .addEdge(START, "c_one") + .addEdge("c_one", "c_two") + .addEdge("c_two", END); + + const parent = new StateGraph(state) + .addNode("p_one", () => ({ messages: ["p_one"] })) + .addNode("p_two", child.compile()) + .addEdge(START, "p_one") + .addEdge("p_one", "p_two") + .addEdge("p_two", END); + + const grandParent = new StateGraph(state) + .addNode("gp_one", () => ({ messages: ["gp_one"] })) + .addNode("gp_two", parent.compile()) + .addEdge(START, "gp_one") + .addEdge("gp_one", "gp_two") + .addEdge("gp_two", END); + + return grandParent; + })(), + ], + [ + "debug nested subgraph: subgraph as third argument", + (() => { + const state = Annotation.Root({ + messages: Annotation({ + reducer: (a, b) => a.concat(b), + default: () => [], + }), + }); + const child = new StateGraph(state) + .addNode("c_one", () => ({ messages: ["c_one"] })) + .addNode("c_two", () => ({ messages: ["c_two"] })) + .addEdge(START, "c_one") + .addEdge("c_one", "c_two") + .addEdge("c_two", END) + .compile(); + + const parent = new StateGraph(state) + .addNode("p_one", () => ({ messages: ["p_one"] })) + .addNode("p_two", (state, config) => child.invoke(state, config), { + subgraphs: [child], + }) + .addEdge(START, "p_one") + .addEdge("p_one", "p_two") + .addEdge("p_two", END) + .compile(); + + const grandParent = new StateGraph(state) + .addNode("gp_one", () => ({ messages: ["gp_one"] })) + .addNode("gp_two", (state, config) => parent.invoke(state, config), { + subgraphs: [parent], + }) + .addEdge(START, "gp_one") + .addEdge("gp_one", "gp_two") + .addEdge("gp_two", END); + + return grandParent; + })(), + ], + ])("%s", async (_title, grandParent) => { const checkpointer = await createCheckpointer(); const graph = grandParent.compile({ checkpointer });