Skip to content

Commit

Permalink
Implement interrupt(...) and Command({resume: ...}) (#690)
Browse files Browse the repository at this point in the history
  • Loading branch information
nfcampos authored Nov 27, 2024
1 parent b460d1c commit e55369d
Show file tree
Hide file tree
Showing 12 changed files with 192 additions and 52 deletions.
4 changes: 4 additions & 0 deletions libs/checkpoint/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import type {
} from "./types.js";
import {
ERROR,
INTERRUPT,
RESUME,
SCHEDULED,
type ChannelProtocol,
type SendProtocol,
Expand Down Expand Up @@ -203,4 +205,6 @@ export function maxChannelVersion(
export const WRITES_IDX_MAP: Record<string, number> = {
[ERROR]: -1,
[SCHEDULED]: -2,
[INTERRUPT]: -3,
[RESUME]: -4,
};
2 changes: 2 additions & 0 deletions libs/checkpoint/src/serde/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
export const TASKS = "__pregel_tasks";
export const ERROR = "__error__";
export const SCHEDULED = "__scheduled__";
export const INTERRUPT = "__interrupt__";
export const RESUME = "__resume__";

// Mirrors BaseChannel in "@langchain/langgraph"
export interface ChannelProtocol<
Expand Down
20 changes: 20 additions & 0 deletions libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export const MISSING = Symbol.for("__missing__");

export const INPUT = "__input__";
export const ERROR = "__error__";
export const CONFIG_KEY_SEND = "__pregel_send";
Expand All @@ -6,11 +8,13 @@ export const CONFIG_KEY_CHECKPOINTER = "__pregel_checkpointer";
export const CONFIG_KEY_RESUMING = "__pregel_resuming";
export const CONFIG_KEY_TASK_ID = "__pregel_task_id";
export const CONFIG_KEY_STREAM = "__pregel_stream";
export const CONFIG_KEY_RESUME_VALUE = "__pregel_resume_value";

// this one is part of public API
export const CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map";

export const INTERRUPT = "__interrupt__";
export const RESUME = "__resume__";
export const RUNTIME_PLACEHOLDER = "__pregel_runtime_placeholder__";
export const RECURSION_LIMIT_DEFAULT = 25;

Expand All @@ -22,9 +26,11 @@ export const PUSH = "__pregel_push";
export const PULL = "__pregel_pull";

export const TASK_NAMESPACE = "6ba7b831-9dad-11d1-80b4-00c04fd430c8";
export const NULL_TASK_ID = "00000000-0000-0000-0000-000000000000";

export const RESERVED = [
INTERRUPT,
RESUME,
ERROR,
TASKS,
CONFIG_KEY_SEND,
Expand Down Expand Up @@ -114,3 +120,17 @@ export type Interrupt = {
value: any;
when: "during";
};

export class Command<R = unknown> {
lg_name = "Command";

resume: R;

constructor(args: { resume: R }) {
this.resume = args.resume;
}
}

export function _isCommand(x: unknown): x is Command {
return typeof x === "object" && !!x && (x as Command).lg_name === "Command";
}
12 changes: 11 additions & 1 deletion libs/langgraph/src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ export class BaseLangGraphError extends Error {
}
}

export class GraphBubbleUp extends BaseLangGraphError {
get is_bubble_up() {
return true;
}
}

export class GraphRecursionError extends BaseLangGraphError {
constructor(message?: string, fields?: BaseLangGraphErrorFields) {
super(message, fields);
Expand All @@ -40,7 +46,7 @@ export class GraphValueError extends BaseLangGraphError {
}
}

export class GraphInterrupt extends BaseLangGraphError {
export class GraphInterrupt extends GraphBubbleUp {
interrupts: Interrupt[];

constructor(interrupts?: Interrupt[], fields?: BaseLangGraphErrorFields) {
Expand Down Expand Up @@ -74,6 +80,10 @@ export class NodeInterrupt extends GraphInterrupt {
}
}

export function isGraphBubbleUp(e?: Error): e is GraphBubbleUp {
return e !== undefined && (e as GraphBubbleUp).is_bubble_up === true;
}

export function isGraphInterrupt(
e?: GraphInterrupt | Error
): e is GraphInterrupt {
Expand Down
18 changes: 18 additions & 0 deletions libs/langgraph/src/interrupt.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { RunnableConfig } from "@langchain/core/runnables";
import { AsyncLocalStorageProviderSingleton } from "@langchain/core/singletons";
import { GraphInterrupt } from "./errors.js";
import { CONFIG_KEY_RESUME_VALUE, MISSING } from "./constants.js";

export function interrupt<I = unknown, R = unknown>(value: I): R {
const config: RunnableConfig | undefined =
AsyncLocalStorageProviderSingleton.getRunnableConfig();
if (!config) {
throw new Error("Called interrupt() outside the context of a graph.");
}
const resume = config.configurable?.[CONFIG_KEY_RESUME_VALUE];
if (resume !== MISSING) {
return resume as R;
} else {
throw new GraphInterrupt([{ value, when: "during" }]);
}
}
51 changes: 41 additions & 10 deletions libs/langgraph/src/pregel/algo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ import {
CHECKPOINT_NAMESPACE_END,
PUSH,
PULL,
RESUME,
CONFIG_KEY_RESUME_VALUE,
NULL_TASK_ID,
MISSING,
} from "../constants.js";
import { PregelExecutableTask, PregelTaskDescription } from "./types.js";
import { EmptyChannelError, InvalidUpdateError } from "../errors.js";
Expand Down Expand Up @@ -189,13 +193,19 @@ export function _localWrite(
commit(writes);
}

const IGNORE = new Set<string | number | symbol>([PUSH, RESUME, INTERRUPT]);

export function _applyWrites<Cc extends Record<string, BaseChannel>>(
checkpoint: Checkpoint,
channels: Cc,
tasks: WritesProtocol<keyof Cc>[],
// eslint-disable-next-line @typescript-eslint/no-explicit-any
getNextVersion?: (version: any, channel: BaseChannel) => any
): Record<string, PendingWriteValue[]> {
// if no task has triggers this is applying writes from the null task only
// so we don't do anything other than update the channels written to
const bumpStep = tasks.some((task) => task.triggers.length > 0);

// Filter out non instances of BaseChannel
const onlyChannels = Object.fromEntries(
Object.entries(channels).filter(([_, value]) => isBaseChannel(value))
Expand Down Expand Up @@ -240,7 +250,7 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
}

// Clear pending sends
if (checkpoint.pending_sends) {
if (checkpoint.pending_sends?.length && bumpStep) {
checkpoint.pending_sends = [];
}

Expand All @@ -252,7 +262,9 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
const pendingWritesByManaged = {} as Record<keyof Cc, PendingWriteValue[]>;
for (const task of tasks) {
for (const [chan, val] of task.writes) {
if (chan === TASKS) {
if (IGNORE.has(chan)) {
// do nothing
} else if (chan === TASKS) {
checkpoint.pending_sends.push({
node: (val as Send).node,
args: (val as Send).args,
Expand Down Expand Up @@ -313,14 +325,16 @@ export function _applyWrites<Cc extends Record<string, BaseChannel>>(
}

// Channels that weren't updated in this step are notified of a new step
for (const chan of Object.keys(onlyChannels)) {
if (!updatedChannels.has(chan)) {
const updated = onlyChannels[chan].update([]);
if (updated && getNextVersion !== undefined) {
checkpoint.channel_versions[chan] = getNextVersion(
maxVersion,
onlyChannels[chan]
);
if (bumpStep) {
for (const chan of Object.keys(onlyChannels)) {
if (!updatedChannels.has(chan)) {
const updated = onlyChannels[chan].update([]);
if (updated && getNextVersion !== undefined) {
checkpoint.channel_versions[chan] = getNextVersion(
maxVersion,
onlyChannels[chan]
);
}
}
}
}
Expand Down Expand Up @@ -350,6 +364,7 @@ export function _prepareNextTasks<
Cc extends StrRecord<string, BaseChannel>
>(
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
channels: Cc,
managed: ManagedValueMapping,
Expand All @@ -363,6 +378,7 @@ export function _prepareNextTasks<
Cc extends StrRecord<string, BaseChannel>
>(
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
channels: Cc,
managed: ManagedValueMapping,
Expand All @@ -376,6 +392,7 @@ export function _prepareNextTasks<
Cc extends StrRecord<string, BaseChannel>
>(
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
channels: Cc,
managed: ManagedValueMapping,
Expand All @@ -393,6 +410,7 @@ export function _prepareNextTasks<
const task = _prepareSingleTask(
[PUSH, i],
checkpoint,
pendingWrites,
processes,
channels,
managed,
Expand All @@ -410,6 +428,7 @@ export function _prepareNextTasks<
const task = _prepareSingleTask(
[PULL, name],
checkpoint,
pendingWrites,
processes,
channels,
managed,
Expand All @@ -430,6 +449,7 @@ export function _prepareSingleTask<
>(
taskPath: [string, string | number],
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
channels: Cc,
managed: ManagedValueMapping,
Expand All @@ -444,6 +464,7 @@ export function _prepareSingleTask<
>(
taskPath: [string, string | number],
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
channels: Cc,
managed: ManagedValueMapping,
Expand All @@ -458,6 +479,7 @@ export function _prepareSingleTask<
>(
taskPath: [string, string | number],
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
channels: Cc,
managed: ManagedValueMapping,
Expand All @@ -472,6 +494,7 @@ export function _prepareSingleTask<
>(
taskPath: [string, string | number],
checkpoint: ReadonlyCheckpoint,
pendingWrites: [string, string, unknown][] | undefined,
processes: Nn,
channels: Cc,
managed: ManagedValueMapping,
Expand Down Expand Up @@ -537,6 +560,9 @@ export function _prepareSingleTask<
metadata = { ...metadata, ...proc.metadata };
}
const writes: [keyof Cc, unknown][] = [];
const resume = pendingWrites?.find(
(w) => [taskId, NULL_TASK_ID].includes(w[0]) && w[1] === RESUME
);
return {
name: packet.node,
input: packet.args,
Expand Down Expand Up @@ -587,6 +613,7 @@ export function _prepareSingleTask<
...configurable[CONFIG_KEY_CHECKPOINT_MAP],
[parentNamespace]: checkpoint.id,
},
[CONFIG_KEY_RESUME_VALUE]: resume ? resume[2] : MISSING,
checkpoint_id: undefined,
checkpoint_ns: taskCheckpointNamespace,
},
Expand Down Expand Up @@ -661,6 +688,9 @@ export function _prepareSingleTask<
metadata = { ...metadata, ...proc.metadata };
}
const writes: [keyof Cc, unknown][] = [];
const resume = pendingWrites?.find(
(w) => [taskId, NULL_TASK_ID].includes(w[0]) && w[1] === RESUME
);
const taskCheckpointNamespace = `${checkpointNamespace}${CHECKPOINT_NAMESPACE_END}${taskId}`;
return {
name,
Expand Down Expand Up @@ -714,6 +744,7 @@ export function _prepareSingleTask<
...configurable[CONFIG_KEY_CHECKPOINT_MAP],
[parentNamespace]: checkpoint.id,
},
[CONFIG_KEY_RESUME_VALUE]: resume ? resume[2] : MISSING,
checkpoint_id: undefined,
checkpoint_ns: taskCheckpointNamespace,
},
Expand Down
17 changes: 9 additions & 8 deletions libs/langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import {
CHECKPOINT_NAMESPACE_END,
CONFIG_KEY_STREAM,
CONFIG_KEY_TASK_ID,
Command,
} from "../constants.js";
import {
PregelExecutableTask,
Expand All @@ -64,6 +65,7 @@ import {
GraphRecursionError,
GraphValueError,
InvalidUpdateError,
isGraphBubbleUp,
isGraphInterrupt,
} from "../errors.js";
import {
Expand Down Expand Up @@ -405,6 +407,7 @@ export class Pregel<
const nextTasks = Object.values(
_prepareNextTasks(
saved.checkpoint,
saved.pendingWrites,
this.nodes,
channels,
managed,
Expand Down Expand Up @@ -915,7 +918,7 @@ export class Pregel<
* @param options.debug Whether to print debug information during execution.
*/
override async stream(
input: PregelInputType,
input: PregelInputType | Command,
options?: Partial<PregelOptions<Nn, Cc, ConfigurableFieldType>>
): Promise<IterableReadableStream<PregelOutputType>> {
// The ensureConfig method called internally defaults recursionLimit to 25 if not
Expand Down Expand Up @@ -994,7 +997,7 @@ export class Pregel<
}

override async *_streamIterator(
input: PregelInputType,
input: PregelInputType | Command,
options?: Partial<PregelOptions<Nn, Cc>>
): AsyncGenerator<PregelOutputType> {
const streamSubgraphs = options?.subgraphs;
Expand Down Expand Up @@ -1126,11 +1129,11 @@ export class Pregel<
// Timeouts will be thrown
for await (const { task, error } of taskStream) {
if (error !== undefined) {
if (isGraphInterrupt(error)) {
if (isGraphBubbleUp(error)) {
if (loop.isNested) {
throw error;
}
if (error.interrupts.length) {
if (isGraphInterrupt(error) && error.interrupts.length) {
loop.putWrites(
task.id,
error.interrupts.map((interrupt) => [INTERRUPT, interrupt])
Expand All @@ -1140,13 +1143,11 @@ export class Pregel<
loop.putWrites(task.id, [
[ERROR, { message: error.message, name: error.name }],
]);
throw error;
}
} else {
loop.putWrites(task.id, task.writes);
}
if (error !== undefined && !isGraphInterrupt(error)) {
throw error;
}
}

if (debug) {
Expand Down Expand Up @@ -1244,7 +1245,7 @@ export class Pregel<
* @param options.debug Whether to print debug information during execution.
*/
override async invoke(
input: PregelInputType,
input: PregelInputType | Command,
options?: Partial<PregelOptions<Nn, Cc, ConfigurableFieldType>>
): Promise<PregelOutputType> {
const streamMode = options?.streamMode ?? "values";
Expand Down
Loading

0 comments on commit e55369d

Please sign in to comment.