Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(langgraph): Allow multiple interrupts per node #713

Merged
merged 17 commits into from
Dec 10, 2024
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
"checkpointer",
"Checkpointers",
"Pregel"
]
],
"typescript.tsdk": "node_modules/typescript/lib"
}
6 changes: 6 additions & 0 deletions libs/checkpoint/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,9 @@ export const WRITES_IDX_MAP: Record<string, number> = {
[INTERRUPT]: -3,
[RESUME]: -4,
};

export function getCheckpointId(config: RunnableConfig): string {
return (
config.configurable?.checkpoint_id || config.configurable?.thread_ts || ""
);
}
93 changes: 53 additions & 40 deletions libs/checkpoint/src/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import {
CheckpointListOptions,
CheckpointTuple,
copyCheckpoint,
getCheckpointId,
WRITES_IDX_MAP,
} from "./base.js";
import { SerializerProtocol } from "./serde/base.js";
import {
Expand All @@ -29,7 +31,7 @@ export class MemorySaver extends BaseCheckpointSaver {
Record<string, Record<string, [Uint8Array, Uint8Array, string | undefined]>>
> = {};

writes: Record<string, CheckpointPendingWrite[]> = {};
writes: Record<string, Record<string, [string, string, Uint8Array]>> = {};

constructor(serde?: SerializerProtocol) {
super(serde);
Expand All @@ -42,13 +44,14 @@ export class MemorySaver extends BaseCheckpointSaver {
) {
let pendingSends: SendProtocol[] = [];
if (parentCheckpointId !== undefined) {
const key = _generateKey(threadId, checkpointNs, parentCheckpointId);
pendingSends = await Promise.all(
this.writes[_generateKey(threadId, checkpointNs, parentCheckpointId)]
Object.values(this.writes[key] || {})
?.filter(([_taskId, channel]) => {
return channel === TASKS;
})
.map(([_taskId, _channel, writes]) => {
return this.serde.loadsTyped("json", writes as string);
return this.serde.loadsTyped("json", writes);
}) ?? []
);
}
Expand All @@ -58,15 +61,13 @@ export class MemorySaver extends BaseCheckpointSaver {
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
let checkpoint_id = config.configurable?.checkpoint_id;
let checkpoint_id = getCheckpointId(config);

if (checkpoint_id) {
const saved = this.storage[thread_id]?.[checkpoint_ns]?.[checkpoint_id];
if (saved !== undefined) {
const [checkpoint, metadata, parentCheckpointId] = saved;
const writes =
this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ??
[];
const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id);
const pending_sends = await this._getPendingSends(
thread_id,
checkpoint_ns,
Expand All @@ -77,13 +78,15 @@ export class MemorySaver extends BaseCheckpointSaver {
pending_sends,
};
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value as string),
];
})
Object.values(this.writes[key] || {}).map(
async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value),
];
}
)
);
const checkpointTuple: CheckpointTuple = {
config,
Expand Down Expand Up @@ -114,9 +117,7 @@ export class MemorySaver extends BaseCheckpointSaver {
)[0];
const saved = checkpoints[checkpoint_id];
const [checkpoint, metadata, parentCheckpointId] = saved;
const writes =
this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ??
[];
const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id);
const pending_sends = await this._getPendingSends(
thread_id,
checkpoint_ns,
Expand All @@ -127,13 +128,15 @@ export class MemorySaver extends BaseCheckpointSaver {
pending_sends,
};
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value as string),
];
})
Object.values(this.writes[key] || {}).map(
async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value),
];
}
)
);
const checkpointTuple: CheckpointTuple = {
config: {
Expand Down Expand Up @@ -176,6 +179,7 @@ export class MemorySaver extends BaseCheckpointSaver {
? [config.configurable?.thread_id]
: Object.keys(this.storage);
const configCheckpointNamespace = config.configurable?.checkpoint_ns;
const configCheckpointId = config.configurable?.checkpoint_id;

for (const threadId of threadIds) {
for (const checkpointNamespace of Object.keys(
Expand All @@ -196,7 +200,12 @@ export class MemorySaver extends BaseCheckpointSaver {
checkpointId,
[checkpoint, metadataStr, parentCheckpointId],
] of sortedCheckpoints) {
// Filter by checkpoint ID
// Filter by checkpoint ID from config
if (configCheckpointId && checkpointId !== configCheckpointId) {
continue;
}

// Filter by checkpoint ID from before config
if (
before &&
before.configurable?.checkpoint_id &&
Expand Down Expand Up @@ -224,25 +233,23 @@ export class MemorySaver extends BaseCheckpointSaver {
// Limit search results
if (limit !== undefined) {
if (limit <= 0) break;
// eslint-disable-next-line no-param-reassign
limit -= 1;
}

const writes =
this.writes[
_generateKey(threadId, checkpointNamespace, checkpointId)
] ?? [];
const key = _generateKey(threadId, checkpointNamespace, checkpointId);
const writes = Object.values(this.writes[key] || {});
const pending_sends = await this._getPendingSends(
threadId,
checkpointNamespace,
parentCheckpointId
);

const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value as string),
await this.serde.loadsTyped("json", value),
];
})
);
Expand Down Expand Up @@ -336,16 +343,22 @@ export class MemorySaver extends BaseCheckpointSaver {
`Failed to put writes. The passed RunnableConfig is missing a required "checkpoint_id" field in its "configurable" property.`
);
}
const key = _generateKey(threadId, checkpointNamespace, checkpointId);
if (this.writes[key] === undefined) {
this.writes[key] = [];
const outerKey = _generateKey(threadId, checkpointNamespace, checkpointId);
const outerWrites_ = this.writes[outerKey];
if (this.writes[outerKey] === undefined) {
this.writes[outerKey] = {};
}
const pendingWrites: CheckpointPendingWrite[] = writes.map(
([channel, value]) => {
const [, serializedValue] = this.serde.dumpsTyped(value);
return [taskId, channel, serializedValue];
writes.forEach(([channel, value], idx) => {
const [, serializedValue] = this.serde.dumpsTyped(value);
const innerKey: [string, number] = [
taskId,
WRITES_IDX_MAP[channel] || idx,
];
const innerKeyStr = `${innerKey[0]},${innerKey[1]}`;
if (innerKey[1] >= 0 && outerWrites_ && innerKeyStr in outerWrites_) {
return;
}
);
this.writes[key].push(...pendingWrites);
this.writes[outerKey][innerKeyStr] = [taskId, channel, serializedValue];
});
}
}
3 changes: 3 additions & 0 deletions libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ export const CONFIG_KEY_RESUMING = "__pregel_resuming";
export const CONFIG_KEY_TASK_ID = "__pregel_task_id";
export const CONFIG_KEY_STREAM = "__pregel_stream";
export const CONFIG_KEY_RESUME_VALUE = "__pregel_resume_value";
export const CONFIG_KEY_WRITES = "__pregel_writes";
export const CONFIG_KEY_SCRATCHPAD = "__pregel_scratchpad";
export const CONFIG_KEY_CHECKPOINT_NS = "checkpoint_ns";

// this one is part of public API
export const CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map";
Expand Down
127 changes: 114 additions & 13 deletions libs/langgraph/src/interrupt.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,126 @@
import { RunnableConfig } from "@langchain/core/runnables";
import { AsyncLocalStorageProviderSingleton } from "@langchain/core/singletons";
import { CheckpointPendingWrite } from "@langchain/langgraph-checkpoint";
import { RunnableConfig } from "@langchain/core/runnables";
import { GraphInterrupt } from "./errors.js";
import { CONFIG_KEY_RESUME_VALUE, MISSING } from "./constants.js";
import {
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
CONFIG_KEY_SEND,
CHECKPOINT_NAMESPACE_SEPARATOR,
NULL_TASK_ID,
RESUME,
} from "./constants.js";
import { PregelScratchpad } from "./pregel/types.js";

/**
* Interrupts the execution of a graph node.
* This function can be used to pause execution of a node, and return the value of the `resume`
* input when the graph is re-invoked using `Command`.
* Multiple interrupts can be called within a single node, and each will be handled sequentially.
*
* When an interrupt is called:
* 1. If there's a `resume` value available (from a previous `Command`), it returns that value.
* 2. Otherwise, it throws a `GraphInterrupt` with the provided value
* 3. The graph can be resumed by passing a `Command` with a `resume` value
*
* @param value - The value to include in the interrupt. This will be available in task.interrupts[].value
* @returns The `resume` value provided when the graph is re-invoked with a Command
*
* @example
* ```typescript
* // Define a node that uses multiple interrupts
* const nodeWithInterrupts = () => {
* // First interrupt - will pause execution and include {value: 1} in task values
* const answer1 = interrupt({ value: 1 });
*
* // Second interrupt - only called after first interrupt is resumed
* const answer2 = interrupt({ value: 2 });
*
* // Use the resume values
* return { myKey: answer1 + " " + answer2 };
* };
*
* // Resume the graph after first interrupt
* await graph.stream(new Command({ resume: "answer 1" }));
*
* // Resume the graph after second interrupt
* await graph.stream(new Command({ resume: "answer 2" }));
* // Final result: { myKey: "answer 1 answer 2" }
* ```
*
* @throws {Error} If called outside the context of a graph
* @throws {GraphInterrupt} When no resume value is available
*/
export function interrupt<I = unknown, R = unknown>(value: I): R {
const config: RunnableConfig | undefined =
AsyncLocalStorageProviderSingleton.getRunnableConfig();
if (!config) {
throw new Error("Called interrupt() outside the context of a graph.");
}
const resume = config.configurable?.[CONFIG_KEY_RESUME_VALUE];
if (resume !== MISSING) {
return resume as R;

// Track interrupt index
const scratchpad: PregelScratchpad<R> =
config.configurable?.[CONFIG_KEY_SCRATCHPAD];
if (scratchpad.interruptCounter === undefined) {
scratchpad.interruptCounter = 0;
} else {
throw new GraphInterrupt([
{
value,
when: "during",
resumable: true,
ns: config.configurable?.checkpoint_ns?.split("|"),
},
]);
scratchpad.interruptCounter += 1;
}
const idx = scratchpad.interruptCounter;

// Find previous resume values
const taskId = config.configurable?.[CONFIG_KEY_TASK_ID];
const writes: CheckpointPendingWrite[] =
config.configurable?.[CONFIG_KEY_WRITES] ?? [];

if (!scratchpad.resume) {
const newResume = (writes.find(
(w) => w[0] === taskId && w[1] === RESUME
)?.[2] || []) as R | R[];
scratchpad.resume = Array.isArray(newResume) ? newResume : [newResume];
}

if (scratchpad.resume) {
if (idx < scratchpad.resume.length) {
return scratchpad.resume[idx];
}
}

// Find current resume value
if (!scratchpad.usedNullResume) {
scratchpad.usedNullResume = true;
const sortedWrites = [...writes].sort(
(a, b) => b[0].localeCompare(a[0]) // Sort in reverse order
);

for (const [tid, c, v] of sortedWrites) {
if (tid === NULL_TASK_ID && c === RESUME) {
if (scratchpad.resume.length !== idx) {
throw new Error(
`Resume length mismatch: ${scratchpad.resume.length} !== ${idx}`
);
}
scratchpad.resume.push(v as R);
const send = config.configurable?.[CONFIG_KEY_SEND];
if (send) {
send([[RESUME, scratchpad.resume]]);
}
return v as R;
}
}
}

// No resume value found
throw new GraphInterrupt([
{
value,
when: "during",
resumable: true,
ns: config.configurable?.[CONFIG_KEY_CHECKPOINT_NS]?.split(
CHECKPOINT_NAMESPACE_SEPARATOR
),
},
]);
}
Loading
Loading