From 2d1c891804cd049411f720630ba6c4338101a2e1 Mon Sep 17 00:00:00 2001 From: Brace Sproul Date: Tue, 10 Dec 2024 14:49:06 -0800 Subject: [PATCH] fix(langgraph): Allow multiple interrupts per node (#713) --- .vscode/settings.json | 3 +- libs/checkpoint/src/base.ts | 6 ++ libs/checkpoint/src/memory.ts | 93 +++++++++-------- libs/langgraph/src/constants.ts | 3 + libs/langgraph/src/interrupt.ts | 127 +++++++++++++++++++++--- libs/langgraph/src/pregel/algo.ts | 31 +++--- libs/langgraph/src/pregel/index.ts | 12 ++- libs/langgraph/src/pregel/io.ts | 18 +++- libs/langgraph/src/pregel/loop.ts | 43 ++++++-- libs/langgraph/src/pregel/types.ts | 6 ++ libs/langgraph/src/tests/pregel.test.ts | 70 +++++++++++++ 11 files changed, 322 insertions(+), 90 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 713fbf4a9..3e0ded24d 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,5 +4,6 @@ "checkpointer", "Checkpointers", "Pregel" - ] + ], + "typescript.tsdk": "node_modules/typescript/lib" } \ No newline at end of file diff --git a/libs/checkpoint/src/base.ts b/libs/checkpoint/src/base.ts index 3512436cd..6d7616bf8 100644 --- a/libs/checkpoint/src/base.ts +++ b/libs/checkpoint/src/base.ts @@ -208,3 +208,9 @@ export const WRITES_IDX_MAP: Record = { [INTERRUPT]: -3, [RESUME]: -4, }; + +export function getCheckpointId(config: RunnableConfig): string { + return ( + config.configurable?.checkpoint_id || config.configurable?.thread_ts || "" + ); +} diff --git a/libs/checkpoint/src/memory.ts b/libs/checkpoint/src/memory.ts index adc192cea..415f9120a 100644 --- a/libs/checkpoint/src/memory.ts +++ b/libs/checkpoint/src/memory.ts @@ -5,6 +5,8 @@ import { CheckpointListOptions, CheckpointTuple, copyCheckpoint, + getCheckpointId, + WRITES_IDX_MAP, } from "./base.js"; import { SerializerProtocol } from "./serde/base.js"; import { @@ -29,7 +31,7 @@ export class MemorySaver extends BaseCheckpointSaver { Record> > = {}; - writes: Record = {}; + writes: Record> = {}; constructor(serde?: SerializerProtocol) { super(serde); @@ -42,13 +44,14 @@ export class MemorySaver extends BaseCheckpointSaver { ) { let pendingSends: SendProtocol[] = []; if (parentCheckpointId !== undefined) { + const key = _generateKey(threadId, checkpointNs, parentCheckpointId); pendingSends = await Promise.all( - this.writes[_generateKey(threadId, checkpointNs, parentCheckpointId)] + Object.values(this.writes[key] || {}) ?.filter(([_taskId, channel]) => { return channel === TASKS; }) .map(([_taskId, _channel, writes]) => { - return this.serde.loadsTyped("json", writes as string); + return this.serde.loadsTyped("json", writes); }) ?? [] ); } @@ -58,15 +61,13 @@ export class MemorySaver extends BaseCheckpointSaver { async getTuple(config: RunnableConfig): Promise { const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns ?? ""; - let checkpoint_id = config.configurable?.checkpoint_id; + let checkpoint_id = getCheckpointId(config); if (checkpoint_id) { const saved = this.storage[thread_id]?.[checkpoint_ns]?.[checkpoint_id]; if (saved !== undefined) { const [checkpoint, metadata, parentCheckpointId] = saved; - const writes = - this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ?? - []; + const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id); const pending_sends = await this._getPendingSends( thread_id, checkpoint_ns, @@ -77,13 +78,15 @@ export class MemorySaver extends BaseCheckpointSaver { pending_sends, }; const pendingWrites: CheckpointPendingWrite[] = await Promise.all( - writes.map(async ([taskId, channel, value]) => { - return [ - taskId, - channel, - await this.serde.loadsTyped("json", value as string), - ]; - }) + Object.values(this.writes[key] || {}).map( + async ([taskId, channel, value]) => { + return [ + taskId, + channel, + await this.serde.loadsTyped("json", value), + ]; + } + ) ); const checkpointTuple: CheckpointTuple = { config, @@ -114,9 +117,7 @@ export class MemorySaver extends BaseCheckpointSaver { )[0]; const saved = checkpoints[checkpoint_id]; const [checkpoint, metadata, parentCheckpointId] = saved; - const writes = - this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ?? - []; + const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id); const pending_sends = await this._getPendingSends( thread_id, checkpoint_ns, @@ -127,13 +128,15 @@ export class MemorySaver extends BaseCheckpointSaver { pending_sends, }; const pendingWrites: CheckpointPendingWrite[] = await Promise.all( - writes.map(async ([taskId, channel, value]) => { - return [ - taskId, - channel, - await this.serde.loadsTyped("json", value as string), - ]; - }) + Object.values(this.writes[key] || {}).map( + async ([taskId, channel, value]) => { + return [ + taskId, + channel, + await this.serde.loadsTyped("json", value), + ]; + } + ) ); const checkpointTuple: CheckpointTuple = { config: { @@ -176,6 +179,7 @@ export class MemorySaver extends BaseCheckpointSaver { ? [config.configurable?.thread_id] : Object.keys(this.storage); const configCheckpointNamespace = config.configurable?.checkpoint_ns; + const configCheckpointId = config.configurable?.checkpoint_id; for (const threadId of threadIds) { for (const checkpointNamespace of Object.keys( @@ -196,7 +200,12 @@ export class MemorySaver extends BaseCheckpointSaver { checkpointId, [checkpoint, metadataStr, parentCheckpointId], ] of sortedCheckpoints) { - // Filter by checkpoint ID + // Filter by checkpoint ID from config + if (configCheckpointId && checkpointId !== configCheckpointId) { + continue; + } + + // Filter by checkpoint ID from before config if ( before && before.configurable?.checkpoint_id && @@ -224,25 +233,23 @@ export class MemorySaver extends BaseCheckpointSaver { // Limit search results if (limit !== undefined) { if (limit <= 0) break; - // eslint-disable-next-line no-param-reassign limit -= 1; } - const writes = - this.writes[ - _generateKey(threadId, checkpointNamespace, checkpointId) - ] ?? []; + const key = _generateKey(threadId, checkpointNamespace, checkpointId); + const writes = Object.values(this.writes[key] || {}); const pending_sends = await this._getPendingSends( threadId, checkpointNamespace, parentCheckpointId ); + const pendingWrites: CheckpointPendingWrite[] = await Promise.all( writes.map(async ([taskId, channel, value]) => { return [ taskId, channel, - await this.serde.loadsTyped("json", value as string), + await this.serde.loadsTyped("json", value), ]; }) ); @@ -336,16 +343,22 @@ export class MemorySaver extends BaseCheckpointSaver { `Failed to put writes. The passed RunnableConfig is missing a required "checkpoint_id" field in its "configurable" property.` ); } - const key = _generateKey(threadId, checkpointNamespace, checkpointId); - if (this.writes[key] === undefined) { - this.writes[key] = []; + const outerKey = _generateKey(threadId, checkpointNamespace, checkpointId); + const outerWrites_ = this.writes[outerKey]; + if (this.writes[outerKey] === undefined) { + this.writes[outerKey] = {}; } - const pendingWrites: CheckpointPendingWrite[] = writes.map( - ([channel, value]) => { - const [, serializedValue] = this.serde.dumpsTyped(value); - return [taskId, channel, serializedValue]; + writes.forEach(([channel, value], idx) => { + const [, serializedValue] = this.serde.dumpsTyped(value); + const innerKey: [string, number] = [ + taskId, + WRITES_IDX_MAP[channel] || idx, + ]; + const innerKeyStr = `${innerKey[0]},${innerKey[1]}`; + if (innerKey[1] >= 0 && outerWrites_ && innerKeyStr in outerWrites_) { + return; } - ); - this.writes[key].push(...pendingWrites); + this.writes[outerKey][innerKeyStr] = [taskId, channel, serializedValue]; + }); } } diff --git a/libs/langgraph/src/constants.ts b/libs/langgraph/src/constants.ts index e6022bf24..4357bd241 100644 --- a/libs/langgraph/src/constants.ts +++ b/libs/langgraph/src/constants.ts @@ -9,6 +9,9 @@ export const CONFIG_KEY_RESUMING = "__pregel_resuming"; export const CONFIG_KEY_TASK_ID = "__pregel_task_id"; export const CONFIG_KEY_STREAM = "__pregel_stream"; export const CONFIG_KEY_RESUME_VALUE = "__pregel_resume_value"; +export const CONFIG_KEY_WRITES = "__pregel_writes"; +export const CONFIG_KEY_SCRATCHPAD = "__pregel_scratchpad"; +export const CONFIG_KEY_CHECKPOINT_NS = "checkpoint_ns"; // this one is part of public API export const CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map"; diff --git a/libs/langgraph/src/interrupt.ts b/libs/langgraph/src/interrupt.ts index 593a4de39..6b58b7dc4 100644 --- a/libs/langgraph/src/interrupt.ts +++ b/libs/langgraph/src/interrupt.ts @@ -1,25 +1,126 @@ -import { RunnableConfig } from "@langchain/core/runnables"; import { AsyncLocalStorageProviderSingleton } from "@langchain/core/singletons"; +import { CheckpointPendingWrite } from "@langchain/langgraph-checkpoint"; +import { RunnableConfig } from "@langchain/core/runnables"; import { GraphInterrupt } from "./errors.js"; -import { CONFIG_KEY_RESUME_VALUE, MISSING } from "./constants.js"; +import { + CONFIG_KEY_CHECKPOINT_NS, + CONFIG_KEY_SCRATCHPAD, + CONFIG_KEY_TASK_ID, + CONFIG_KEY_WRITES, + CONFIG_KEY_SEND, + CHECKPOINT_NAMESPACE_SEPARATOR, + NULL_TASK_ID, + RESUME, +} from "./constants.js"; +import { PregelScratchpad } from "./pregel/types.js"; +/** + * Interrupts the execution of a graph node. + * This function can be used to pause execution of a node, and return the value of the `resume` + * input when the graph is re-invoked using `Command`. + * Multiple interrupts can be called within a single node, and each will be handled sequentially. + * + * When an interrupt is called: + * 1. If there's a `resume` value available (from a previous `Command`), it returns that value. + * 2. Otherwise, it throws a `GraphInterrupt` with the provided value + * 3. The graph can be resumed by passing a `Command` with a `resume` value + * + * @param value - The value to include in the interrupt. This will be available in task.interrupts[].value + * @returns The `resume` value provided when the graph is re-invoked with a Command + * + * @example + * ```typescript + * // Define a node that uses multiple interrupts + * const nodeWithInterrupts = () => { + * // First interrupt - will pause execution and include {value: 1} in task values + * const answer1 = interrupt({ value: 1 }); + * + * // Second interrupt - only called after first interrupt is resumed + * const answer2 = interrupt({ value: 2 }); + * + * // Use the resume values + * return { myKey: answer1 + " " + answer2 }; + * }; + * + * // Resume the graph after first interrupt + * await graph.stream(new Command({ resume: "answer 1" })); + * + * // Resume the graph after second interrupt + * await graph.stream(new Command({ resume: "answer 2" })); + * // Final result: { myKey: "answer 1 answer 2" } + * ``` + * + * @throws {Error} If called outside the context of a graph + * @throws {GraphInterrupt} When no resume value is available + */ export function interrupt(value: I): R { const config: RunnableConfig | undefined = AsyncLocalStorageProviderSingleton.getRunnableConfig(); if (!config) { throw new Error("Called interrupt() outside the context of a graph."); } - const resume = config.configurable?.[CONFIG_KEY_RESUME_VALUE]; - if (resume !== MISSING) { - return resume as R; + + // Track interrupt index + const scratchpad: PregelScratchpad = + config.configurable?.[CONFIG_KEY_SCRATCHPAD]; + if (scratchpad.interruptCounter === undefined) { + scratchpad.interruptCounter = 0; } else { - throw new GraphInterrupt([ - { - value, - when: "during", - resumable: true, - ns: config.configurable?.checkpoint_ns?.split("|"), - }, - ]); + scratchpad.interruptCounter += 1; + } + const idx = scratchpad.interruptCounter; + + // Find previous resume values + const taskId = config.configurable?.[CONFIG_KEY_TASK_ID]; + const writes: CheckpointPendingWrite[] = + config.configurable?.[CONFIG_KEY_WRITES] ?? []; + + if (!scratchpad.resume) { + const newResume = (writes.find( + (w) => w[0] === taskId && w[1] === RESUME + )?.[2] || []) as R | R[]; + scratchpad.resume = Array.isArray(newResume) ? newResume : [newResume]; } + + if (scratchpad.resume) { + if (idx < scratchpad.resume.length) { + return scratchpad.resume[idx]; + } + } + + // Find current resume value + if (!scratchpad.usedNullResume) { + scratchpad.usedNullResume = true; + const sortedWrites = [...writes].sort( + (a, b) => b[0].localeCompare(a[0]) // Sort in reverse order + ); + + for (const [tid, c, v] of sortedWrites) { + if (tid === NULL_TASK_ID && c === RESUME) { + if (scratchpad.resume.length !== idx) { + throw new Error( + `Resume length mismatch: ${scratchpad.resume.length} !== ${idx}` + ); + } + scratchpad.resume.push(v as R); + const send = config.configurable?.[CONFIG_KEY_SEND]; + if (send) { + send([[RESUME, scratchpad.resume]]); + } + return v as R; + } + } + } + + // No resume value found + throw new GraphInterrupt([ + { + value, + when: "during", + resumable: true, + ns: config.configurable?.[CONFIG_KEY_CHECKPOINT_NS]?.split( + CHECKPOINT_NAMESPACE_SEPARATOR + ), + }, + ]); } diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index 0bdf167e2..5f3300a3c 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -43,9 +43,9 @@ import { PUSH, PULL, RESUME, - CONFIG_KEY_RESUME_VALUE, NULL_TASK_ID, - MISSING, + CONFIG_KEY_SCRATCHPAD, + CONFIG_KEY_WRITES, } from "../constants.js"; import { PregelExecutableTask, PregelTaskDescription } from "./types.js"; import { EmptyChannelError, InvalidUpdateError } from "../errors.js"; @@ -166,7 +166,6 @@ export function _localWrite( // eslint-disable-next-line @typescript-eslint/no-explicit-any commit: (writes: [string, any][]) => any, processes: Record, - channels: Record, managed: ManagedValueMapping, // eslint-disable-next-line @typescript-eslint/no-explicit-any writes: [string, any][] @@ -187,8 +186,6 @@ export function _localWrite( } // replace any runtime values with placeholders managed.replaceRuntimeValues(step, value.args); - } else if (!(chan in channels) && !managed.get(chan)) { - console.warn(`Skipping write for channel '${chan}' which has no readers`); } } commit(writes); @@ -589,9 +586,6 @@ export function _prepareSingleTask< metadata = { ...metadata, ...proc.metadata }; } const writes: [keyof Cc, unknown][] = []; - const resume = pendingWrites?.find( - (w) => [taskId, NULL_TASK_ID].includes(w[0]) && w[1] === RESUME - ); return { name: packet.node, input: packet.args, @@ -615,7 +609,6 @@ export function _prepareSingleTask< step, (items: [keyof Cc, unknown][]) => writes.push(...items), processes, - channels, managed, writes_ ), @@ -643,9 +636,11 @@ export function _prepareSingleTask< ...configurable[CONFIG_KEY_CHECKPOINT_MAP], [parentNamespace]: checkpoint.id, }, - [CONFIG_KEY_RESUME_VALUE]: resume - ? resume[2] - : configurable[CONFIG_KEY_RESUME_VALUE] ?? MISSING, + [CONFIG_KEY_WRITES]: [ + ...(pendingWrites || []), + ...(configurable[CONFIG_KEY_WRITES] || []), + ].filter((w) => w[0] === NULL_TASK_ID || w[0] === taskId), + [CONFIG_KEY_SCRATCHPAD]: {}, checkpoint_id: undefined, checkpoint_ns: taskCheckpointNamespace, }, @@ -721,9 +716,6 @@ export function _prepareSingleTask< metadata = { ...metadata, ...proc.metadata }; } const writes: [keyof Cc, unknown][] = []; - const resume = pendingWrites?.find( - (w) => [taskId, NULL_TASK_ID].includes(w[0]) && w[1] === RESUME - ); const taskCheckpointNamespace = `${checkpointNamespace}${CHECKPOINT_NAMESPACE_END}${taskId}`; return { name, @@ -750,7 +742,6 @@ export function _prepareSingleTask< writes.push(...items); }, processes, - channels, managed, writes_ ), @@ -778,9 +769,11 @@ export function _prepareSingleTask< ...configurable[CONFIG_KEY_CHECKPOINT_MAP], [parentNamespace]: checkpoint.id, }, - [CONFIG_KEY_RESUME_VALUE]: resume - ? resume[2] - : configurable[CONFIG_KEY_RESUME_VALUE] ?? MISSING, + [CONFIG_KEY_WRITES]: [ + ...(pendingWrites || []), + ...(configurable[CONFIG_KEY_WRITES] || []), + ].filter((w) => w[0] === NULL_TASK_ID || w[0] === taskId), + [CONFIG_KEY_SCRATCHPAD]: {}, checkpoint_id: undefined, checkpoint_ns: taskCheckpointNamespace, }, diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 3c3edbe6a..81226dfdb 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -53,6 +53,7 @@ import { Command, NULL_TASK_ID, INPUT, + RESUME, PUSH, } from "../constants.js"; import { @@ -1227,10 +1228,13 @@ export class Pregel< throw error; } if (isGraphInterrupt(error) && error.interrupts.length) { - loop.putWrites( - task.id, - error.interrupts.map((interrupt) => [INTERRUPT, interrupt]) - ); + const interrupts: PendingWrite[] = + error.interrupts.map((interrupt) => [INTERRUPT, interrupt]); + const resumes = task.writes.filter((w) => w[0] === RESUME); + if (resumes.length) { + interrupts.push(...resumes); + } + loop.putWrites(task.id, interrupts); } } else { loop.putWrites(task.id, [ diff --git a/libs/langgraph/src/pregel/io.ts b/libs/langgraph/src/pregel/io.ts index 5c6b1e214..178282dd1 100644 --- a/libs/langgraph/src/pregel/io.ts +++ b/libs/langgraph/src/pregel/io.ts @@ -1,4 +1,7 @@ -import type { PendingWrite } from "@langchain/langgraph-checkpoint"; +import type { + CheckpointPendingWrite, + PendingWrite, +} from "@langchain/langgraph-checkpoint"; import { validate } from "uuid"; import type { BaseChannel } from "../channels/base.js"; @@ -64,7 +67,8 @@ export function readChannels( * Map input chunk to a sequence of pending writes in the form (channel, value). */ export function* mapCommand( - cmd: Command + cmd: Command, + pendingWrites: CheckpointPendingWrite[] ): Generator<[string, string, unknown]> { if (cmd.graph === Command.PARENT) { throw new InvalidUpdateError("There is no parent graph."); @@ -97,7 +101,15 @@ export function* mapCommand( Object.keys(cmd.resume).every(validate) ) { for (const [tid, resume] of Object.entries(cmd.resume)) { - yield [tid, RESUME, resume]; + // Find existing resume values for this task ID + const existing = (pendingWrites.find( + ([id, type]) => id === tid && type === RESUME + )?.[2] ?? []) as unknown[]; + + // Ensure we have an array and append the resume value + existing.push(resume); + + yield [tid, RESUME, existing]; } } else { yield [NULL_TASK_ID, RESUME, cmd.resume]; diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index d525f3ea4..3881c37f5 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -13,6 +13,7 @@ import { All, BaseStore, AsyncBatchedStore, + WRITES_IDX_MAP, } from "@langchain/langgraph-checkpoint"; import { @@ -445,16 +446,31 @@ export class PregelLoop { * @param writes */ putWrites(taskId: string, writes: PendingWrite[]) { - if (writes.length === 0) { + let writesCopy = writes; + if (writesCopy.length === 0) { return; } + + // deduplicate writes to special channels, last write wins + if (writesCopy.every(([key]) => key in WRITES_IDX_MAP)) { + writesCopy = Array.from( + new Map(writesCopy.map((w) => [w[0], w])).values() + ); + } // save writes - const pendingWrites: CheckpointPendingWrite[] = writes.map( - ([key, value]) => { - return [taskId, key, value]; + for (const [c, v] of writesCopy) { + if (c in WRITES_IDX_MAP) { + const idx = this.checkpointPendingWrites.findIndex( + (w) => w[0] === taskId && w[1] === c + ); + if (idx !== -1) { + this.checkpointPendingWrites[idx] = [taskId, c, v]; + } else { + this.checkpointPendingWrites.push([taskId, c, v]); + } } - ); - this.checkpointPendingWrites.push(...pendingWrites); + } + const putWritePromise = this.checkpointer?.putWrites( { ...this.checkpointConfig, @@ -464,14 +480,15 @@ export class PregelLoop { checkpoint_id: this.checkpoint.id, }, }, - writes, + writesCopy, taskId ); if (putWritePromise !== undefined) { this.checkpointerPromises.push(putWritePromise); } + if (this.tasks) { - this._outputWrites(taskId, writes); + this._outputWrites(taskId, writesCopy); } } @@ -541,7 +558,10 @@ export class PregelLoop { if (![INPUT_DONE, INPUT_RESUMING].includes(this.input)) { await this._first(inputKeys); } else if ( - Object.values(this.tasks).every((task) => task.writes.length > 0) + Object.values(this.tasks).every( + (task) => + task.writes.filter(([c]) => !(c in WRITES_IDX_MAP)).length > 0 + ) ) { const writes = Object.values(this.tasks).flatMap((t) => t.writes); // All tasks have finished @@ -723,7 +743,10 @@ export class PregelLoop { if (_isCommand(this.input)) { const writes: { [key: string]: PendingWrite[] } = {}; // group writes by task id - for (const [tid, key, value] of mapCommand(this.input)) { + for (const [tid, key, value] of mapCommand( + this.input, + this.checkpointPendingWrites + )) { if (writes[tid] === undefined) { writes[tid] = []; } diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index e1ab1d3e0..b98cdb023 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -226,3 +226,9 @@ export interface StateSnapshot { */ readonly tasks: PregelTaskDescription[]; } + +export type PregelScratchpad = { + interruptCounter: number; + usedNullResume: boolean; + resume: Resume[]; +}; diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index 3c19d073b..0e017048f 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -9246,6 +9246,76 @@ graph TD; foo: "abc|node-1|node-2", }); }); + + it("can throw a node interrupt multiple times in a single node", async () => { + const GraphAnnotation = Annotation.Root({ + myKey: Annotation({ + reducer: (a, b) => a + b, + }), + }); + + const nodeOne = (_: typeof GraphAnnotation.State) => { + const answer = interrupt({ value: 1 }); + const answer2 = interrupt({ value: 2 }); + return { myKey: answer + " " + answer2 }; + }; + + const graph = new StateGraph(GraphAnnotation) + .addNode("one", nodeOne) + .addEdge(START, "one") + .compile({ checkpointer: await createCheckpointer() }); + + const config = { + configurable: { thread_id: "test_multi_interrupt" }, + streamMode: "values" as const, + }; + const firstResult = await gatherIterator( + graph.stream( + { + myKey: "DE", + }, + config + ) + ); + expect(firstResult).toBeDefined(); + const firstState = await graph.getState(config); + expect(firstState.tasks).toHaveLength(1); + expect(firstState.tasks[0].interrupts).toHaveLength(1); + expect(firstState.tasks[0].interrupts[0].value).toEqual({ + value: 1, + }); + + const secondResult = await gatherIterator( + graph.stream( + new Command({ + resume: "answer 1", + }), + config + ) + ); + expect(secondResult).toBeDefined(); + + const secondState = await graph.getState(config); + expect(secondState.tasks).toHaveLength(1); + expect(secondState.tasks[0].interrupts).toHaveLength(1); + expect(secondState.tasks[0].interrupts[0].value).toEqual({ + value: 2, + }); + + const thirdResult = await gatherIterator( + graph.stream( + new Command({ + resume: "answer 2", + }), + config + ) + ); + expect(thirdResult[thirdResult.length - 1].myKey).toEqual( + "DEanswer 1 answer 2" + ); + const thirdState = await graph.getState(config); + expect(thirdState.tasks).toHaveLength(0); + }); } runPregelTests(() => new MemorySaverAssertImmutable());