From e55369d1c642b09ddf432ca6609e4a8df641b207 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Wed, 27 Nov 2024 15:20:04 -0800 Subject: [PATCH] Implement interrupt(...) and Command({resume: ...}) (#690) --- libs/checkpoint/src/base.ts | 4 ++ libs/checkpoint/src/serde/types.ts | 2 + libs/langgraph/src/constants.ts | 20 +++++++++ libs/langgraph/src/errors.ts | 12 +++++- libs/langgraph/src/interrupt.ts | 18 ++++++++ libs/langgraph/src/pregel/algo.ts | 51 +++++++++++++++++----- libs/langgraph/src/pregel/index.ts | 17 ++++---- libs/langgraph/src/pregel/io.ts | 23 +++++++++- libs/langgraph/src/pregel/loop.ts | 57 ++++++++++++++----------- libs/langgraph/src/pregel/retry.ts | 4 +- libs/langgraph/src/tests/pregel.test.ts | 33 +++++++++++--- libs/langgraph/src/web.ts | 3 +- 12 files changed, 192 insertions(+), 52 deletions(-) create mode 100644 libs/langgraph/src/interrupt.ts diff --git a/libs/checkpoint/src/base.ts b/libs/checkpoint/src/base.ts index 5382b4e5f..3512436cd 100644 --- a/libs/checkpoint/src/base.ts +++ b/libs/checkpoint/src/base.ts @@ -8,6 +8,8 @@ import type { } from "./types.js"; import { ERROR, + INTERRUPT, + RESUME, SCHEDULED, type ChannelProtocol, type SendProtocol, @@ -203,4 +205,6 @@ export function maxChannelVersion( export const WRITES_IDX_MAP: Record = { [ERROR]: -1, [SCHEDULED]: -2, + [INTERRUPT]: -3, + [RESUME]: -4, }; diff --git a/libs/checkpoint/src/serde/types.ts b/libs/checkpoint/src/serde/types.ts index aaf429812..a2dfeac08 100644 --- a/libs/checkpoint/src/serde/types.ts +++ b/libs/checkpoint/src/serde/types.ts @@ -1,6 +1,8 @@ export const TASKS = "__pregel_tasks"; export const ERROR = "__error__"; export const SCHEDULED = "__scheduled__"; +export const INTERRUPT = "__interrupt__"; +export const RESUME = "__resume__"; // Mirrors BaseChannel in "@langchain/langgraph" export interface ChannelProtocol< diff --git a/libs/langgraph/src/constants.ts b/libs/langgraph/src/constants.ts index b19e7d3b1..229785469 100644 --- a/libs/langgraph/src/constants.ts +++ b/libs/langgraph/src/constants.ts @@ -1,3 +1,5 @@ +export const MISSING = Symbol.for("__missing__"); + export const INPUT = "__input__"; export const ERROR = "__error__"; export const CONFIG_KEY_SEND = "__pregel_send"; @@ -6,11 +8,13 @@ export const CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer"; export const CONFIG_KEY_RESUMING = "__pregel_resuming"; export const CONFIG_KEY_TASK_ID = "__pregel_task_id"; export const CONFIG_KEY_STREAM = "__pregel_stream"; +export const CONFIG_KEY_RESUME_VALUE = "__pregel_resume_value"; // this one is part of public API export const CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map"; export const INTERRUPT = "__interrupt__"; +export const RESUME = "__resume__"; export const RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__"; export const RECURSION_LIMIT_DEFAULT = 25; @@ -22,9 +26,11 @@ export const PUSH = "__pregel_push"; export const PULL = "__pregel_pull"; export const TASK_NAMESPACE = "6ba7b831-9dad-11d1-80b4-00c04fd430c8"; +export const NULL_TASK_ID = "00000000-0000-0000-0000-000000000000"; export const RESERVED = [ INTERRUPT, + RESUME, ERROR, TASKS, CONFIG_KEY_SEND, @@ -114,3 +120,17 @@ export type Interrupt = { value: any; when: "during"; }; + +export class Command { + lg_name = "Command"; + + resume: R; + + constructor(args: { resume: R }) { + this.resume = args.resume; + } +} + +export function _isCommand(x: unknown): x is Command { + return typeof x === "object" && !!x && (x as Command).lg_name === "Command"; +} diff --git a/libs/langgraph/src/errors.ts b/libs/langgraph/src/errors.ts index f54c9d6f6..1c0358c1f 100644 --- a/libs/langgraph/src/errors.ts +++ b/libs/langgraph/src/errors.ts @@ -18,6 +18,12 @@ export class BaseLangGraphError extends Error { } } +export class GraphBubbleUp extends BaseLangGraphError { + get is_bubble_up() { + return true; + } +} + export class GraphRecursionError extends BaseLangGraphError { constructor(message?: string, fields?: BaseLangGraphErrorFields) { super(message, fields); @@ -40,7 +46,7 @@ export class GraphValueError extends BaseLangGraphError { } } -export class GraphInterrupt extends BaseLangGraphError { +export class GraphInterrupt extends GraphBubbleUp { interrupts: Interrupt[]; constructor(interrupts?: Interrupt[], fields?: BaseLangGraphErrorFields) { @@ -74,6 +80,10 @@ export class NodeInterrupt extends GraphInterrupt { } } +export function isGraphBubbleUp(e?: Error): e is GraphBubbleUp { + return e !== undefined && (e as GraphBubbleUp).is_bubble_up === true; +} + export function isGraphInterrupt( e?: GraphInterrupt | Error ): e is GraphInterrupt { diff --git a/libs/langgraph/src/interrupt.ts b/libs/langgraph/src/interrupt.ts new file mode 100644 index 000000000..ad033d5ee --- /dev/null +++ b/libs/langgraph/src/interrupt.ts @@ -0,0 +1,18 @@ +import { RunnableConfig } from "@langchain/core/runnables"; +import { AsyncLocalStorageProviderSingleton } from "@langchain/core/singletons"; +import { GraphInterrupt } from "./errors.js"; +import { CONFIG_KEY_RESUME_VALUE, MISSING } from "./constants.js"; + +export function interrupt(value: I): R { + const config: RunnableConfig | undefined = + AsyncLocalStorageProviderSingleton.getRunnableConfig(); + if (!config) { + throw new Error("Called interrupt() outside the context of a graph."); + } + const resume = config.configurable?.[CONFIG_KEY_RESUME_VALUE]; + if (resume !== MISSING) { + return resume as R; + } else { + throw new GraphInterrupt([{ value, when: "during" }]); + } +} diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index 7b1d71bc2..e353aa462 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -42,6 +42,10 @@ import { CHECKPOINT_NAMESPACE_END, PUSH, PULL, + RESUME, + CONFIG_KEY_RESUME_VALUE, + NULL_TASK_ID, + MISSING, } from "../constants.js"; import { PregelExecutableTask, PregelTaskDescription } from "./types.js"; import { EmptyChannelError, InvalidUpdateError } from "../errors.js"; @@ -189,6 +193,8 @@ export function _localWrite( commit(writes); } +const IGNORE = new Set([PUSH, RESUME, INTERRUPT]); + export function _applyWrites>( checkpoint: Checkpoint, channels: Cc, @@ -196,6 +202,10 @@ export function _applyWrites>( // eslint-disable-next-line @typescript-eslint/no-explicit-any getNextVersion?: (version: any, channel: BaseChannel) => any ): Record { + // if no task has triggers this is applying writes from the null task only + // so we don't do anything other than update the channels written to + const bumpStep = tasks.some((task) => task.triggers.length > 0); + // Filter out non instances of BaseChannel const onlyChannels = Object.fromEntries( Object.entries(channels).filter(([_, value]) => isBaseChannel(value)) @@ -240,7 +250,7 @@ export function _applyWrites>( } // Clear pending sends - if (checkpoint.pending_sends) { + if (checkpoint.pending_sends?.length && bumpStep) { checkpoint.pending_sends = []; } @@ -252,7 +262,9 @@ export function _applyWrites>( const pendingWritesByManaged = {} as Record; for (const task of tasks) { for (const [chan, val] of task.writes) { - if (chan === TASKS) { + if (IGNORE.has(chan)) { + // do nothing + } else if (chan === TASKS) { checkpoint.pending_sends.push({ node: (val as Send).node, args: (val as Send).args, @@ -313,14 +325,16 @@ export function _applyWrites>( } // Channels that weren't updated in this step are notified of a new step - for (const chan of Object.keys(onlyChannels)) { - if (!updatedChannels.has(chan)) { - const updated = onlyChannels[chan].update([]); - if (updated && getNextVersion !== undefined) { - checkpoint.channel_versions[chan] = getNextVersion( - maxVersion, - onlyChannels[chan] - ); + if (bumpStep) { + for (const chan of Object.keys(onlyChannels)) { + if (!updatedChannels.has(chan)) { + const updated = onlyChannels[chan].update([]); + if (updated && getNextVersion !== undefined) { + checkpoint.channel_versions[chan] = getNextVersion( + maxVersion, + onlyChannels[chan] + ); + } } } } @@ -350,6 +364,7 @@ export function _prepareNextTasks< Cc extends StrRecord >( checkpoint: ReadonlyCheckpoint, + pendingWrites: [string, string, unknown][] | undefined, processes: Nn, channels: Cc, managed: ManagedValueMapping, @@ -363,6 +378,7 @@ export function _prepareNextTasks< Cc extends StrRecord >( checkpoint: ReadonlyCheckpoint, + pendingWrites: [string, string, unknown][] | undefined, processes: Nn, channels: Cc, managed: ManagedValueMapping, @@ -376,6 +392,7 @@ export function _prepareNextTasks< Cc extends StrRecord >( checkpoint: ReadonlyCheckpoint, + pendingWrites: [string, string, unknown][] | undefined, processes: Nn, channels: Cc, managed: ManagedValueMapping, @@ -393,6 +410,7 @@ export function _prepareNextTasks< const task = _prepareSingleTask( [PUSH, i], checkpoint, + pendingWrites, processes, channels, managed, @@ -410,6 +428,7 @@ export function _prepareNextTasks< const task = _prepareSingleTask( [PULL, name], checkpoint, + pendingWrites, processes, channels, managed, @@ -430,6 +449,7 @@ export function _prepareSingleTask< >( taskPath: [string, string | number], checkpoint: ReadonlyCheckpoint, + pendingWrites: [string, string, unknown][] | undefined, processes: Nn, channels: Cc, managed: ManagedValueMapping, @@ -444,6 +464,7 @@ export function _prepareSingleTask< >( taskPath: [string, string | number], checkpoint: ReadonlyCheckpoint, + pendingWrites: [string, string, unknown][] | undefined, processes: Nn, channels: Cc, managed: ManagedValueMapping, @@ -458,6 +479,7 @@ export function _prepareSingleTask< >( taskPath: [string, string | number], checkpoint: ReadonlyCheckpoint, + pendingWrites: [string, string, unknown][] | undefined, processes: Nn, channels: Cc, managed: ManagedValueMapping, @@ -472,6 +494,7 @@ export function _prepareSingleTask< >( taskPath: [string, string | number], checkpoint: ReadonlyCheckpoint, + pendingWrites: [string, string, unknown][] | undefined, processes: Nn, channels: Cc, managed: ManagedValueMapping, @@ -537,6 +560,9 @@ export function _prepareSingleTask< metadata = { ...metadata, ...proc.metadata }; } const writes: [keyof Cc, unknown][] = []; + const resume = pendingWrites?.find( + (w) => [taskId, NULL_TASK_ID].includes(w[0]) && w[1] === RESUME + ); return { name: packet.node, input: packet.args, @@ -587,6 +613,7 @@ export function _prepareSingleTask< ...configurable[CONFIG_KEY_CHECKPOINT_MAP], [parentNamespace]: checkpoint.id, }, + [CONFIG_KEY_RESUME_VALUE]: resume ? resume[2] : MISSING, checkpoint_id: undefined, checkpoint_ns: taskCheckpointNamespace, }, @@ -661,6 +688,9 @@ export function _prepareSingleTask< metadata = { ...metadata, ...proc.metadata }; } const writes: [keyof Cc, unknown][] = []; + const resume = pendingWrites?.find( + (w) => [taskId, NULL_TASK_ID].includes(w[0]) && w[1] === RESUME + ); const taskCheckpointNamespace = `${checkpointNamespace}${CHECKPOINT_NAMESPACE_END}${taskId}`; return { name, @@ -714,6 +744,7 @@ export function _prepareSingleTask< ...configurable[CONFIG_KEY_CHECKPOINT_MAP], [parentNamespace]: checkpoint.id, }, + [CONFIG_KEY_RESUME_VALUE]: resume ? resume[2] : MISSING, checkpoint_id: undefined, checkpoint_ns: taskCheckpointNamespace, }, diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index a6efea636..bfbdb21e6 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -49,6 +49,7 @@ import { CHECKPOINT_NAMESPACE_END, CONFIG_KEY_STREAM, CONFIG_KEY_TASK_ID, + Command, } from "../constants.js"; import { PregelExecutableTask, @@ -64,6 +65,7 @@ import { GraphRecursionError, GraphValueError, InvalidUpdateError, + isGraphBubbleUp, isGraphInterrupt, } from "../errors.js"; import { @@ -405,6 +407,7 @@ export class Pregel< const nextTasks = Object.values( _prepareNextTasks( saved.checkpoint, + saved.pendingWrites, this.nodes, channels, managed, @@ -915,7 +918,7 @@ export class Pregel< * @param options.debug Whether to print debug information during execution. */ override async stream( - input: PregelInputType, + input: PregelInputType | Command, options?: Partial> ): Promise> { // The ensureConfig method called internally defaults recursionLimit to 25 if not @@ -994,7 +997,7 @@ export class Pregel< } override async *_streamIterator( - input: PregelInputType, + input: PregelInputType | Command, options?: Partial> ): AsyncGenerator { const streamSubgraphs = options?.subgraphs; @@ -1126,11 +1129,11 @@ export class Pregel< // Timeouts will be thrown for await (const { task, error } of taskStream) { if (error !== undefined) { - if (isGraphInterrupt(error)) { + if (isGraphBubbleUp(error)) { if (loop.isNested) { throw error; } - if (error.interrupts.length) { + if (isGraphInterrupt(error) && error.interrupts.length) { loop.putWrites( task.id, error.interrupts.map((interrupt) => [INTERRUPT, interrupt]) @@ -1140,13 +1143,11 @@ export class Pregel< loop.putWrites(task.id, [ [ERROR, { message: error.message, name: error.name }], ]); + throw error; } } else { loop.putWrites(task.id, task.writes); } - if (error !== undefined && !isGraphInterrupt(error)) { - throw error; - } } if (debug) { @@ -1244,7 +1245,7 @@ export class Pregel< * @param options.debug Whether to print debug information during execution. */ override async invoke( - input: PregelInputType, + input: PregelInputType | Command, options?: Partial> ): Promise { const streamMode = options?.streamMode ?? "values"; diff --git a/libs/langgraph/src/pregel/io.ts b/libs/langgraph/src/pregel/io.ts index 4cb0aacc1..eb16553db 100644 --- a/libs/langgraph/src/pregel/io.ts +++ b/libs/langgraph/src/pregel/io.ts @@ -1,7 +1,9 @@ import type { PendingWrite } from "@langchain/langgraph-checkpoint"; +import { validate } from "uuid"; + import type { BaseChannel } from "../channels/base.js"; import type { PregelExecutableTask } from "./types.js"; -import { TAG_HIDDEN } from "../constants.js"; +import { Command, NULL_TASK_ID, RESUME, TAG_HIDDEN } from "../constants.js"; import { EmptyChannelError } from "../errors.js"; export function readChannel( @@ -50,6 +52,25 @@ export function readChannels( } } +export function* mapCommand( + cmd: Command +): Generator<[string, string, unknown]> { + if (cmd.resume) { + if ( + typeof cmd.resume === "object" && + !!cmd.resume && + Object.keys(cmd.resume).length && + Object.keys(cmd.resume).every(validate) + ) { + for (const [tid, resume] of Object.entries(cmd.resume)) { + yield [tid, RESUME, resume]; + } + } else { + yield [NULL_TASK_ID, RESUME, cmd.resume]; + } + } +} + /** * Map input chunk to a sequence of pending writes in the form [channel, value]. */ diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index 32cb2e7b5..0c86ccc63 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -22,7 +22,9 @@ import { } from "../channels/base.js"; import { PregelExecutableTask, StreamMode } from "./types.js"; import { + _isCommand, CHECKPOINT_NAMESPACE_SEPARATOR, + Command, CONFIG_KEY_CHECKPOINT_MAP, CONFIG_KEY_READ, CONFIG_KEY_RESUMING, @@ -30,8 +32,8 @@ import { ERROR, INPUT, INTERRUPT, + RESUME, TAG_HIDDEN, - TASKS, } from "../constants.js"; import { _applyWrites, @@ -46,6 +48,7 @@ import { prefixGenerator, } from "../utils.js"; import { + mapCommand, mapInput, mapOutputUpdates, mapOutputValues, @@ -71,14 +74,13 @@ import { LangGraphRunnableConfig } from "./runnable_types.js"; const INPUT_DONE = Symbol.for("INPUT_DONE"); const INPUT_RESUMING = Symbol.for("INPUT_RESUMING"); const DEFAULT_LOOP_LIMIT = 25; -const SPECIAL_CHANNELS = [ERROR, INTERRUPT]; // [namespace, streamMode, payload] export type StreamChunk = [string[], StreamMode, unknown]; export type PregelLoopInitializeParams = { // eslint-disable-next-line @typescript-eslint/no-explicit-any - input?: any; + input?: any | Command; config: RunnableConfig; checkpointer?: BaseCheckpointSaver; outputKeys: string | string[]; @@ -93,7 +95,7 @@ export type PregelLoopInitializeParams = { type PregelLoopParams = { // eslint-disable-next-line @typescript-eslint/no-explicit-any - input?: any; + input?: any | Command; config: RunnableConfig; checkpointer?: BaseCheckpointSaver; checkpoint: Checkpoint; @@ -185,7 +187,7 @@ function createDuplexStream(...streams: IterableReadableWritableStream[]) { export class PregelLoop { // eslint-disable-next-line @typescript-eslint/no-explicit-any - protected input?: any; + protected input?: any | Command; // eslint-disable-next-line @typescript-eslint/no-explicit-any output: any; @@ -227,8 +229,6 @@ export class PregelLoop { protected skipDoneTasks: boolean; - protected taskWritesLeft: number = 0; - protected prevCheckpointConfig: RunnableConfig | undefined; status: @@ -297,7 +297,9 @@ export class PregelLoop { config.configurable[CONFIG_KEY_STREAM] ); } - const skipDoneTasks = config.configurable?.checkpoint_id === undefined; + const skipDoneTasks = config.configurable + ? !("checkpoint_id" in config.configurable) + : true; const isNested = CONFIG_KEY_READ in (config.configurable ?? {}); if ( !isNested && @@ -446,18 +448,6 @@ export class PregelLoop { if (writes.length === 0) { return; } - // adjust taskWritesLeft - const firstChannel = writes[0][0]; - const anyChannelIsSend = writes.find(([channel]) => channel === TASKS); - const alwaysSave = - anyChannelIsSend || SPECIAL_CHANNELS.includes(firstChannel); - if (!alwaysSave && !this.taskWritesLeft) { - return this._outputWrites(taskId, writes); - } else if (firstChannel !== INTERRUPT) { - // INTERRUPT makes us want to save the last task's writes - // so we don't decrement tasksWritesLeft in that case - this.taskWritesLeft -= 1; - } // save writes const pendingWrites: CheckpointPendingWrite[] = writes.map( ([key, value]) => { @@ -480,7 +470,9 @@ export class PregelLoop { if (putWritePromise !== undefined) { this.checkpointerPromises.push(putWritePromise); } - this._outputWrites(taskId, writes); + if (this.tasks) { + this._outputWrites(taskId, writes); + } } _outputWrites(taskId: string, writes: [string, unknown][], cached = false) { @@ -605,6 +597,7 @@ export class PregelLoop { const nextTasks = _prepareNextTasks( this.checkpoint, + this.checkpointPendingWrites, this.nodes, this.channels, this.managed, @@ -619,7 +612,6 @@ export class PregelLoop { } ); this.tasks = nextTasks; - this.taskWritesLeft = Object.values(this.tasks).length - 1; // Produce debug output if (this.checkpointer) { @@ -649,7 +641,7 @@ export class PregelLoop { // if there are pending writes from a previous loop, apply them if (this.skipDoneTasks && this.checkpointPendingWrites.length > 0) { for (const [tid, k, v] of this.checkpointPendingWrites) { - if (k === ERROR || k === INTERRUPT) { + if (k === ERROR || k === INTERRUPT || k === RESUME) { continue; } const task = Object.values(this.tasks).find((t) => t.id === tid); @@ -745,8 +737,24 @@ export class PregelLoop { ) ); this._emit(valuesOutput); - // map inputs to channel updates + } else if (_isCommand(this.input)) { + const writes: { [key: string]: PendingWrite[] } = {}; + // group writes by task id + for (const [tid, key, value] of mapCommand(this.input)) { + if (writes[tid] === undefined) { + writes[tid] = []; + } + writes[tid].push([key, value]); + } + if (Object.keys(writes).length === 0) { + throw new EmptyInputError("Received empty Command input"); + } + // save writes + for (const [tid, ws] of Object.entries(writes)) { + this.putWrites(tid, ws); + } } else { + // map inputs to channel updates const inputWrites = await gatherIterator(mapInput(inputKeys, this.input)); if (inputWrites.length === 0) { throw new EmptyInputError( @@ -755,6 +763,7 @@ export class PregelLoop { } const discardTasks = _prepareNextTasks( this.checkpoint, + this.checkpointPendingWrites, this.nodes, this.channels, this.managed, diff --git a/libs/langgraph/src/pregel/retry.ts b/libs/langgraph/src/pregel/retry.ts index 6e17c987a..60094130e 100644 --- a/libs/langgraph/src/pregel/retry.ts +++ b/libs/langgraph/src/pregel/retry.ts @@ -1,4 +1,4 @@ -import { getSubgraphsSeenSet, isGraphInterrupt } from "../errors.js"; +import { getSubgraphsSeenSet, isGraphBubbleUp } from "../errors.js"; import { PregelExecutableTask } from "./types.js"; import type { RetryPolicy } from "./utils/index.js"; @@ -129,7 +129,7 @@ async function _runWithRetry( } catch (e: any) { error = e; error.pregelTaskId = pregelTask.id; - if (isGraphInterrupt(error)) { + if (isGraphBubbleUp(error)) { break; } if (resolvedRetryPolicy === undefined) { diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index f9b396869..d7d8bd348 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -88,11 +88,13 @@ import { MultipleSubgraphsError, NodeInterrupt, } from "../errors.js"; -import { ERROR, INTERRUPT, PULL, PUSH, Send } from "../constants.js"; +import { Command, ERROR, INTERRUPT, PULL, PUSH, Send } from "../constants.js"; import { ManagedValueMapping } from "../managed/base.js"; import { SharedValue } from "../managed/shared_value.js"; import { MessagesAnnotation } from "../graph/messages_annotation.js"; import { LangGraphRunnableConfig } from "../pregel/runnable_types.js"; +import { initializeAsyncLocalStorageSingleton } from "../setup/async_local_storage.js"; +import { interrupt } from "../interrupt.js"; expect.extend({ toHaveKeyStartingWith(received: object, prefix: string) { @@ -120,6 +122,11 @@ export function runPregelTests( afterAll(teardown); } + beforeAll(() => { + // Will occur naturally if user imports from main `@langchain/langgraph` endpoint. + initializeAsyncLocalStorageSingleton(); + }); + describe("Channel", () => { describe("writeTo", () => { it("should return a ChannelWrite instance with the expected writes", () => { @@ -860,6 +867,7 @@ export function runPregelTests( const taskDescriptions = Object.values( _prepareNextTasks( checkpoint, + [], processes, channels, managed, @@ -988,6 +996,7 @@ export function runPregelTests( const tasks = Object.values( _prepareNextTasks( checkpoint, + [], processes, channels, managed, @@ -2699,10 +2708,9 @@ export function runPregelTests( s: typeof StateAnnotation.State ): Partial => { toolTwoNodeCount += 1; - if (s.market === "DE") { - throw new NodeInterrupt("Just because..."); - } - return { my_key: " all good" }; + const answer: string = + s.market === "DE" ? interrupt("Just because...") : " all good"; + return { my_key: answer }; }; const toolTwoGraph = new StateGraph(StateAnnotation) @@ -2790,6 +2798,21 @@ export function runPregelTests( await gatherIterator(toolTwoCheckpointer.list(thread1, { limit: 2 })) ).slice(-1)[0].config, }); + + // resume execution + expect( + await gatherIterator( + toolTwo.stream(new Command({ resume: " this is great" }), { + configurable: { thread_id: "1" }, + }) + ) + ).toEqual([ + { + tool_two: { + my_key: " this is great", + }, + }, + ]); }); it("should not cancel node on other node interrupted", async () => { diff --git a/libs/langgraph/src/web.ts b/libs/langgraph/src/web.ts index 5518c38ca..182513785 100644 --- a/libs/langgraph/src/web.ts +++ b/libs/langgraph/src/web.ts @@ -31,7 +31,8 @@ export { } from "./channels/index.js"; export { type AnnotationRoot as _INTERNAL_ANNOTATION_ROOT } from "./graph/index.js"; export { type RetryPolicy } from "./pregel/utils/index.js"; -export { Send } from "./constants.js"; +export { Send, Command, type Interrupt } from "./constants.js"; +export { interrupt } from "./interrupt.js"; export { MemorySaver,