Skip to content

Commit

Permalink
fix(langgraph): Allow multiple interrupts per node (#713)
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul authored Dec 10, 2024
1 parent a97965d commit 2d1c891
Show file tree
Hide file tree
Showing 11 changed files with 322 additions and 90 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
"checkpointer",
"Checkpointers",
"Pregel"
]
],
"typescript.tsdk": "node_modules/typescript/lib"
}
6 changes: 6 additions & 0 deletions libs/checkpoint/src/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,9 @@ export const WRITES_IDX_MAP: Record<string, number> = {
[INTERRUPT]: -3,
[RESUME]: -4,
};

export function getCheckpointId(config: RunnableConfig): string {
return (
config.configurable?.checkpoint_id || config.configurable?.thread_ts || ""
);
}
93 changes: 53 additions & 40 deletions libs/checkpoint/src/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import {
CheckpointListOptions,
CheckpointTuple,
copyCheckpoint,
getCheckpointId,
WRITES_IDX_MAP,
} from "./base.js";
import { SerializerProtocol } from "./serde/base.js";
import {
Expand All @@ -29,7 +31,7 @@ export class MemorySaver extends BaseCheckpointSaver {
Record<string, Record<string, [Uint8Array, Uint8Array, string | undefined]>>
> = {};

writes: Record<string, CheckpointPendingWrite[]> = {};
writes: Record<string, Record<string, [string, string, Uint8Array]>> = {};

constructor(serde?: SerializerProtocol) {
super(serde);
Expand All @@ -42,13 +44,14 @@ export class MemorySaver extends BaseCheckpointSaver {
) {
let pendingSends: SendProtocol[] = [];
if (parentCheckpointId !== undefined) {
const key = _generateKey(threadId, checkpointNs, parentCheckpointId);
pendingSends = await Promise.all(
this.writes[_generateKey(threadId, checkpointNs, parentCheckpointId)]
Object.values(this.writes[key] || {})
?.filter(([_taskId, channel]) => {
return channel === TASKS;
})
.map(([_taskId, _channel, writes]) => {
return this.serde.loadsTyped("json", writes as string);
return this.serde.loadsTyped("json", writes);
}) ?? []
);
}
Expand All @@ -58,15 +61,13 @@ export class MemorySaver extends BaseCheckpointSaver {
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
let checkpoint_id = config.configurable?.checkpoint_id;
let checkpoint_id = getCheckpointId(config);

if (checkpoint_id) {
const saved = this.storage[thread_id]?.[checkpoint_ns]?.[checkpoint_id];
if (saved !== undefined) {
const [checkpoint, metadata, parentCheckpointId] = saved;
const writes =
this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ??
[];
const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id);
const pending_sends = await this._getPendingSends(
thread_id,
checkpoint_ns,
Expand All @@ -77,13 +78,15 @@ export class MemorySaver extends BaseCheckpointSaver {
pending_sends,
};
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value as string),
];
})
Object.values(this.writes[key] || {}).map(
async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value),
];
}
)
);
const checkpointTuple: CheckpointTuple = {
config,
Expand Down Expand Up @@ -114,9 +117,7 @@ export class MemorySaver extends BaseCheckpointSaver {
)[0];
const saved = checkpoints[checkpoint_id];
const [checkpoint, metadata, parentCheckpointId] = saved;
const writes =
this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ??
[];
const key = _generateKey(thread_id, checkpoint_ns, checkpoint_id);
const pending_sends = await this._getPendingSends(
thread_id,
checkpoint_ns,
Expand All @@ -127,13 +128,15 @@ export class MemorySaver extends BaseCheckpointSaver {
pending_sends,
};
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value as string),
];
})
Object.values(this.writes[key] || {}).map(
async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value),
];
}
)
);
const checkpointTuple: CheckpointTuple = {
config: {
Expand Down Expand Up @@ -176,6 +179,7 @@ export class MemorySaver extends BaseCheckpointSaver {
? [config.configurable?.thread_id]
: Object.keys(this.storage);
const configCheckpointNamespace = config.configurable?.checkpoint_ns;
const configCheckpointId = config.configurable?.checkpoint_id;

for (const threadId of threadIds) {
for (const checkpointNamespace of Object.keys(
Expand All @@ -196,7 +200,12 @@ export class MemorySaver extends BaseCheckpointSaver {
checkpointId,
[checkpoint, metadataStr, parentCheckpointId],
] of sortedCheckpoints) {
// Filter by checkpoint ID
// Filter by checkpoint ID from config
if (configCheckpointId && checkpointId !== configCheckpointId) {
continue;
}

// Filter by checkpoint ID from before config
if (
before &&
before.configurable?.checkpoint_id &&
Expand Down Expand Up @@ -224,25 +233,23 @@ export class MemorySaver extends BaseCheckpointSaver {
// Limit search results
if (limit !== undefined) {
if (limit <= 0) break;
// eslint-disable-next-line no-param-reassign
limit -= 1;
}

const writes =
this.writes[
_generateKey(threadId, checkpointNamespace, checkpointId)
] ?? [];
const key = _generateKey(threadId, checkpointNamespace, checkpointId);
const writes = Object.values(this.writes[key] || {});
const pending_sends = await this._getPendingSends(
threadId,
checkpointNamespace,
parentCheckpointId
);

const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [
taskId,
channel,
await this.serde.loadsTyped("json", value as string),
await this.serde.loadsTyped("json", value),
];
})
);
Expand Down Expand Up @@ -336,16 +343,22 @@ export class MemorySaver extends BaseCheckpointSaver {
`Failed to put writes. The passed RunnableConfig is missing a required "checkpoint_id" field in its "configurable" property.`
);
}
const key = _generateKey(threadId, checkpointNamespace, checkpointId);
if (this.writes[key] === undefined) {
this.writes[key] = [];
const outerKey = _generateKey(threadId, checkpointNamespace, checkpointId);
const outerWrites_ = this.writes[outerKey];
if (this.writes[outerKey] === undefined) {
this.writes[outerKey] = {};
}
const pendingWrites: CheckpointPendingWrite[] = writes.map(
([channel, value]) => {
const [, serializedValue] = this.serde.dumpsTyped(value);
return [taskId, channel, serializedValue];
writes.forEach(([channel, value], idx) => {
const [, serializedValue] = this.serde.dumpsTyped(value);
const innerKey: [string, number] = [
taskId,
WRITES_IDX_MAP[channel] || idx,
];
const innerKeyStr = `${innerKey[0]},${innerKey[1]}`;
if (innerKey[1] >= 0 && outerWrites_ && innerKeyStr in outerWrites_) {
return;
}
);
this.writes[key].push(...pendingWrites);
this.writes[outerKey][innerKeyStr] = [taskId, channel, serializedValue];
});
}
}
3 changes: 3 additions & 0 deletions libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ export const CONFIG_KEY_RESUMING = "__pregel_resuming";
export const CONFIG_KEY_TASK_ID = "__pregel_task_id";
export const CONFIG_KEY_STREAM = "__pregel_stream";
export const CONFIG_KEY_RESUME_VALUE = "__pregel_resume_value";
export const CONFIG_KEY_WRITES = "__pregel_writes";
export const CONFIG_KEY_SCRATCHPAD = "__pregel_scratchpad";
export const CONFIG_KEY_CHECKPOINT_NS = "checkpoint_ns";

// this one is part of public API
export const CONFIG_KEY_CHECKPOINT_MAP = "checkpoint_map";
Expand Down
127 changes: 114 additions & 13 deletions libs/langgraph/src/interrupt.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,126 @@
import { RunnableConfig } from "@langchain/core/runnables";
import { AsyncLocalStorageProviderSingleton } from "@langchain/core/singletons";
import { CheckpointPendingWrite } from "@langchain/langgraph-checkpoint";
import { RunnableConfig } from "@langchain/core/runnables";
import { GraphInterrupt } from "./errors.js";
import { CONFIG_KEY_RESUME_VALUE, MISSING } from "./constants.js";
import {
CONFIG_KEY_CHECKPOINT_NS,
CONFIG_KEY_SCRATCHPAD,
CONFIG_KEY_TASK_ID,
CONFIG_KEY_WRITES,
CONFIG_KEY_SEND,
CHECKPOINT_NAMESPACE_SEPARATOR,
NULL_TASK_ID,
RESUME,
} from "./constants.js";
import { PregelScratchpad } from "./pregel/types.js";

/**
* Interrupts the execution of a graph node.
* This function can be used to pause execution of a node, and return the value of the `resume`
* input when the graph is re-invoked using `Command`.
* Multiple interrupts can be called within a single node, and each will be handled sequentially.
*
* When an interrupt is called:
* 1. If there's a `resume` value available (from a previous `Command`), it returns that value.
* 2. Otherwise, it throws a `GraphInterrupt` with the provided value
* 3. The graph can be resumed by passing a `Command` with a `resume` value
*
* @param value - The value to include in the interrupt. This will be available in task.interrupts[].value
* @returns The `resume` value provided when the graph is re-invoked with a Command
*
* @example
* ```typescript
* // Define a node that uses multiple interrupts
* const nodeWithInterrupts = () => {
* // First interrupt - will pause execution and include {value: 1} in task values
* const answer1 = interrupt({ value: 1 });
*
* // Second interrupt - only called after first interrupt is resumed
* const answer2 = interrupt({ value: 2 });
*
* // Use the resume values
* return { myKey: answer1 + " " + answer2 };
* };
*
* // Resume the graph after first interrupt
* await graph.stream(new Command({ resume: "answer 1" }));
*
* // Resume the graph after second interrupt
* await graph.stream(new Command({ resume: "answer 2" }));
* // Final result: { myKey: "answer 1 answer 2" }
* ```
*
* @throws {Error} If called outside the context of a graph
* @throws {GraphInterrupt} When no resume value is available
*/
export function interrupt<I = unknown, R = unknown>(value: I): R {
const config: RunnableConfig | undefined =
AsyncLocalStorageProviderSingleton.getRunnableConfig();
if (!config) {
throw new Error("Called interrupt() outside the context of a graph.");
}
const resume = config.configurable?.[CONFIG_KEY_RESUME_VALUE];
if (resume !== MISSING) {
return resume as R;

// Track interrupt index
const scratchpad: PregelScratchpad<R> =
config.configurable?.[CONFIG_KEY_SCRATCHPAD];
if (scratchpad.interruptCounter === undefined) {
scratchpad.interruptCounter = 0;
} else {
throw new GraphInterrupt([
{
value,
when: "during",
resumable: true,
ns: config.configurable?.checkpoint_ns?.split("|"),
},
]);
scratchpad.interruptCounter += 1;
}
const idx = scratchpad.interruptCounter;

// Find previous resume values
const taskId = config.configurable?.[CONFIG_KEY_TASK_ID];
const writes: CheckpointPendingWrite[] =
config.configurable?.[CONFIG_KEY_WRITES] ?? [];

if (!scratchpad.resume) {
const newResume = (writes.find(
(w) => w[0] === taskId && w[1] === RESUME
)?.[2] || []) as R | R[];
scratchpad.resume = Array.isArray(newResume) ? newResume : [newResume];
}

if (scratchpad.resume) {
if (idx < scratchpad.resume.length) {
return scratchpad.resume[idx];
}
}

// Find current resume value
if (!scratchpad.usedNullResume) {
scratchpad.usedNullResume = true;
const sortedWrites = [...writes].sort(
(a, b) => b[0].localeCompare(a[0]) // Sort in reverse order
);

for (const [tid, c, v] of sortedWrites) {
if (tid === NULL_TASK_ID && c === RESUME) {
if (scratchpad.resume.length !== idx) {
throw new Error(
`Resume length mismatch: ${scratchpad.resume.length} !== ${idx}`
);
}
scratchpad.resume.push(v as R);
const send = config.configurable?.[CONFIG_KEY_SEND];
if (send) {
send([[RESUME, scratchpad.resume]]);
}
return v as R;
}
}
}

// No resume value found
throw new GraphInterrupt([
{
value,
when: "during",
resumable: true,
ns: config.configurable?.[CONFIG_KEY_CHECKPOINT_NS]?.split(
CHECKPOINT_NAMESPACE_SEPARATOR
),
},
]);
}
Loading

0 comments on commit 2d1c891

Please sign in to comment.