diff --git a/langgraph/src/checkpoint/base.ts b/langgraph/src/checkpoint/base.ts index 721b8bbd..e4b5a5a3 100644 --- a/langgraph/src/checkpoint/base.ts +++ b/langgraph/src/checkpoint/base.ts @@ -54,21 +54,6 @@ export interface ReadonlyCheckpoint extends Readonly { >; } -export function getChannelVersion( - checkpoint: ReadonlyCheckpoint, - channel: string -): number { - return checkpoint.channel_versions[channel] ?? 0; -} - -export function getVersionSeen( - checkpoint: ReadonlyCheckpoint, - node: string, - channel: string -): number { - return checkpoint.versions_seen[node]?.[channel] ?? 0; -} - export function deepCopy(obj: T): T { if (typeof obj !== "object" || obj === null) { return obj; diff --git a/langgraph/src/checkpoint/memory.ts b/langgraph/src/checkpoint/memory.ts index 07e73ae1..66ecf2a1 100644 --- a/langgraph/src/checkpoint/memory.ts +++ b/langgraph/src/checkpoint/memory.ts @@ -36,7 +36,7 @@ export class MemorySaver extends BaseCheckpointSaver { async getTuple(config: RunnableConfig): Promise { const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns ?? ""; - const checkpoint_id = config.configurable?.checkpoint_id; + let checkpoint_id = config.configurable?.checkpoint_id; if (checkpoint_id) { const saved = this.storage[thread_id]?.[checkpoint_ns]?.[checkpoint_id]; @@ -71,10 +71,11 @@ export class MemorySaver extends BaseCheckpointSaver { } else { const checkpoints = this.storage[thread_id]?.[checkpoint_ns]; if (checkpoints !== undefined) { - const maxThreadTs = Object.keys(checkpoints).sort((a, b) => + // eslint-disable-next-line prefer-destructuring + checkpoint_id = Object.keys(checkpoints).sort((a, b) => b.localeCompare(a) )[0]; - const saved = checkpoints[maxThreadTs]; + const saved = checkpoints[checkpoint_id]; const [checkpoint, metadata, parentCheckpointId] = saved; const writes = this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ?? @@ -98,7 +99,7 @@ export class MemorySaver extends BaseCheckpointSaver { config: { configurable: { thread_id, - checkpoint_id: maxThreadTs, + checkpoint_id, checkpoint_ns, }, }, diff --git a/langgraph/src/pregel/algo.ts b/langgraph/src/pregel/algo.ts index 5581b096..13914600 100644 --- a/langgraph/src/pregel/algo.ts +++ b/langgraph/src/pregel/algo.ts @@ -15,8 +15,6 @@ import { Checkpoint, ReadonlyCheckpoint, copyCheckpoint, - getChannelVersion, - getVersionSeen, } from "../checkpoint/base.js"; import { PregelNode } from "./read.js"; import { readChannel, readChannels } from "./io.js"; @@ -38,6 +36,7 @@ import { All, PregelExecutableTask, PregelTaskDescription } from "./types.js"; import { PendingWrite, PendingWriteValue } from "../checkpoint/types.js"; import { EmptyChannelError, InvalidUpdateError } from "../errors.js"; import { uuid5 } from "../checkpoint/id.js"; +import { _getIdMetadata, getNullChannelVersion } from "./utils.js"; /** * Construct a type with a set of properties K of type T @@ -56,11 +55,21 @@ export const increment = (current?: number) => { return current !== undefined ? current + 1 : 1; }; -export async function executeTasks( - tasks: Array<() => Promise>, +export async function* executeTasks( + tasks: Record< + string, + () => Promise<{ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + task: PregelExecutableTask; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + result: any; + error: Error; + }> + >, stepTimeout?: number, signal?: AbortSignal -): Promise { + // eslint-disable-next-line @typescript-eslint/no-explicit-any +): AsyncGenerator> { if (stepTimeout && signal) { if ("any" in AbortSignal) { // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -77,22 +86,29 @@ export async function executeTasks( signal?.throwIfAborted(); // Start all tasks - const started = tasks.map((task) => task()); - - let listener: () => void; - // Wait for all tasks to settle - // If any tasks fail, or signal is aborted, the promise will reject - await Promise.all( - signal - ? [ - ...started, - new Promise((_resolve, reject) => { - listener = () => reject(new Error("Abort")); - signal?.addEventListener("abort", listener); - }).finally(() => signal?.removeEventListener("abort", listener)), - ] - : started + const executingTasks = Object.fromEntries( + Object.entries(tasks).map(([taskId, task]) => { + return [taskId, task()]; + }) ); + let listener: () => void; + const signalPromise = new Promise((_resolve, reject) => { + listener = () => reject(new Error("Abort")); + signal?.addEventListener("abort", listener); + }).finally(() => signal?.removeEventListener("abort", listener)); + + while (Object.keys(executingTasks).length > 0) { + const { task, error } = await Promise.race([ + ...Object.values(executingTasks), + signalPromise, + ]); + if (error !== undefined) { + // TODO: don't stop others if exception is interrupt + throw error; + } + yield task; + delete executingTasks[task.id]; + } } export function shouldInterrupt( @@ -109,7 +125,7 @@ export function shouldInterrupt( } else if (versionType === "string") { nullVersion = ""; } - const seen = checkpoint.versions_seen[INTERRUPT] || {}; + const seen = checkpoint.versions_seen[INTERRUPT] ?? {}; const anyChannelUpdated = Object.entries(checkpoint.channel_versions).some( ([chan, version]) => { @@ -277,8 +293,6 @@ export function _applyWrites>( ); } updatedChannels.add(chan); - } else { - console.warn(`Skipping write for channel ${chan} which has no readers`); } } @@ -313,7 +327,7 @@ export function _prepareNextTasks< config: RunnableConfig, forExecution: false, extra: NextTaskExtraFields -): [Checkpoint, Array]; +): PregelTaskDescription[]; export function _prepareNextTasks< Nn extends StrRecord, @@ -325,7 +339,7 @@ export function _prepareNextTasks< config: RunnableConfig, forExecution: true, extra: NextTaskExtraFields -): [Checkpoint, Array>]; +): PregelExecutableTask[]; export function _prepareNextTasks< Nn extends StrRecord, @@ -337,12 +351,8 @@ export function _prepareNextTasks< config: RunnableConfig, forExecution: boolean, extra: NextTaskExtraFields -): [ - Checkpoint, - PregelTaskDescription[] | PregelExecutableTask[] -] { +): PregelTaskDescription[] | PregelExecutableTask[] { const parentNamespace = config.configurable?.checkpoint_ns ?? ""; - const newCheckpoint = copyCheckpoint(checkpoint); const tasks: Array> = []; const taskDescriptions: Array = []; const { step, isResuming = false, checkpointer, manager } = extra; @@ -361,12 +371,12 @@ export function _prepareNextTasks< continue; } const triggers = [TASKS]; - const metadata = { + const metadata = _getIdMetadata({ langgraph_step: step, langgraph_node: packet.node, langgraph_triggers: triggers, - langgraph_task_idx: tasks.length, - }; + langgraph_task_idx: forExecution ? tasks.length : taskDescriptions.length, + }); const checkpointNamespace = parentNamespace === "" ? packet.node @@ -424,81 +434,41 @@ export function _prepareNextTasks< // Check if any processes should be run in next step // If so, prepare the values to be passed to them + const nullVersion = getNullChannelVersion(checkpoint.channel_versions); + if (nullVersion === undefined) { + return forExecution ? tasks : taskDescriptions; + } for (const [name, proc] of Object.entries(processes)) { - const updatedChannels = proc.triggers + const seen = checkpoint.versions_seen[name] ?? {}; + const triggers = proc.triggers .filter((chan) => { - try { - readChannel(channels, chan, false); - return true; - } catch (e) { - return false; - } + const result = readChannel(channels, chan, false, true); + const isEmptyChannelError = + // eslint-disable-next-line no-instanceof/no-instanceof + result instanceof Error && + result.name === EmptyChannelError.unminifiable_name; + return ( + !isEmptyChannelError && + (checkpoint.channel_versions[chan] ?? nullVersion) > + (seen[chan] ?? nullVersion) + ); }) - .filter( - (chan) => - getChannelVersion(newCheckpoint, chan) > - getVersionSeen(newCheckpoint, name, chan) - ); - - const hasUpdatedChannels = updatedChannels.length > 0; + .sort(); // If any of the channels read by this process were updated - if (hasUpdatedChannels) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - let val: any; - - // If all trigger channels subscribed by this process are not empty - // then invoke the process with the values of all non-empty channels - if (Array.isArray(proc.channels)) { - let emptyChannels = 0; - for (const chan of proc.channels) { - try { - val = readChannel(channels, chan, false); - break; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - if (e.name === EmptyChannelError.unminifiable_name) { - emptyChannels += 1; - continue; - } else { - throw e; - } - } - } - - if (emptyChannels === proc.channels.length) { - continue; - } - } else if (typeof proc.channels === "object") { - val = {}; - try { - for (const [k, chan] of Object.entries(proc.channels)) { - val[k] = readChannel(channels, chan, !proc.triggers.includes(chan)); - } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - if (e.name === EmptyChannelError.unminifiable_name) { - continue; - } else { - throw e; - } - } - } else { - throw new Error( - `Invalid channels type, expected list or dict, got ${proc.channels}` - ); + if (triggers.length > 0) { + const val = _procInput(proc, channels, forExecution); + if (val === undefined) { + continue; } - // If the process has a mapper, apply it to the value - if (proc.mapper !== undefined) { - val = proc.mapper(val); - } - - const metadata = { + const metadata = _getIdMetadata({ langgraph_step: step, langgraph_node: name, - langgraph_triggers: proc.triggers, - langgraph_task_idx: tasks.length, - }; + langgraph_triggers: triggers, + langgraph_task_idx: forExecution + ? tasks.length + : taskDescriptions.length, + }); const checkpointNamespace = parentNamespace === "" @@ -511,18 +481,6 @@ export function _prepareNextTasks< ); if (forExecution) { - // Update seen versions - if (!newCheckpoint.versions_seen[name]) { - newCheckpoint.versions_seen[name] = {}; - } - proc.triggers.forEach((chan: string) => { - const version = newCheckpoint.channel_versions[chan]; - if (version !== undefined) { - // side effect: updates newCheckpoint - newCheckpoint.versions_seen[name][chan] = version; - } - }); - const node = proc.getNode(); if (node !== undefined) { const writes: [keyof Cc, unknown][] = []; @@ -531,7 +489,7 @@ export function _prepareNextTasks< input: val, proc: node, writes, - triggers: proc.triggers, + triggers, config: patchConfig( mergeConfigs(config, proc.config, { metadata }), { @@ -551,7 +509,7 @@ export function _prepareNextTasks< { name, writes: writes as Array<[string, unknown]>, - triggers: proc.triggers, + triggers, } ), [CONFIG_KEY_CHECKPOINTER]: checkpointer, @@ -569,6 +527,61 @@ export function _prepareNextTasks< } } } + return forExecution ? tasks : taskDescriptions; +} + +function _procInput( + proc: PregelNode, + channels: StrRecord, + forExecution: boolean +) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + let val: any; + // If all trigger channels subscribed by this process are not empty + // then invoke the process with the values of all non-empty channels + if (Array.isArray(proc.channels)) { + let successfulRead = false; + for (const chan of proc.channels) { + try { + val = readChannel(channels, chan, false); + successfulRead = true; + break; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + if (e.name === EmptyChannelError.unminifiable_name) { + continue; + } else { + throw e; + } + } + } + if (!successfulRead) { + return; + } + } else if (typeof proc.channels === "object") { + val = {}; + try { + for (const [k, chan] of Object.entries(proc.channels)) { + val[k] = readChannel(channels, chan, !proc.triggers.includes(chan)); + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + if (e.name === EmptyChannelError.unminifiable_name) { + return; + } else { + throw e; + } + } + } else { + throw new Error( + `Invalid channels type, expected list or dict, got ${proc.channels}` + ); + } + + // If the process has a mapper, apply it to the value + if (forExecution && proc.mapper !== undefined) { + val = proc.mapper(val); + } - return [newCheckpoint, forExecution ? tasks : taskDescriptions]; + return val; } diff --git a/langgraph/src/pregel/debug.ts b/langgraph/src/pregel/debug.ts index 6f925ddc..33a54fd6 100644 --- a/langgraph/src/pregel/debug.ts +++ b/langgraph/src/pregel/debug.ts @@ -10,6 +10,7 @@ import { ERROR, TAG_HIDDEN, TASK_NAMESPACE } from "../constants.js"; import { EmptyChannelError } from "../errors.js"; import { PregelExecutableTask, PregelTaskDescription } from "./types.js"; import { readChannels } from "./io.js"; +import { _getIdMetadata } from "./utils.js"; type ConsoleColors = { start: string; @@ -86,12 +87,12 @@ export function* mapDebugTasks( if (config?.tags?.includes(TAG_HIDDEN)) continue; const metadata = { ...config?.metadata }; - const idMetadata = { + const idMetadata = _getIdMetadata({ langgraph_step: metadata.langgraph_step, langgraph_node: metadata.langgraph_node, langgraph_triggers: metadata.langgraph_triggers, langgraph_task_idx: metadata.langgraph_task_idx, - }; + }); yield { type: "task", @@ -120,12 +121,7 @@ export function* mapDebugTaskResults< if (config?.tags?.includes(TAG_HIDDEN)) continue; const metadata = { ...config?.metadata }; - const idMetadata = { - langgraph_step: metadata.langgraph_step, - langgraph_node: metadata.langgraph_node, - langgraph_triggers: metadata.langgraph_triggers, - langgraph_task_idx: metadata.langgraph_task_idx, - }; + const idMetadata = _getIdMetadata(metadata); yield { type: "task_result", diff --git a/langgraph/src/pregel/index.ts b/langgraph/src/pregel/index.ts index 6b4462d1..905152a0 100644 --- a/langgraph/src/pregel/index.ts +++ b/langgraph/src/pregel/index.ts @@ -37,6 +37,7 @@ import { CONFIG_KEY_CHECKPOINTER, CONFIG_KEY_READ, CONFIG_KEY_SEND, + ERROR, INTERRUPT, } from "../constants.js"; import { @@ -291,13 +292,13 @@ export class Pregel< const saved = await this.checkpointer.getTuple(config); const checkpoint = saved ? saved.checkpoint : emptyCheckpoint(); const channels = emptyChannels(this.channels, checkpoint); - const [, nextTasks] = _prepareNextTasks( + const nextTasks = _prepareNextTasks( checkpoint, this.nodes, channels, saved !== undefined ? saved.config : config, false, - { step: -1 } + { step: saved ? (saved.metadata?.step ?? -1) + 1 : -1 } ); return { values: readChannels(channels, this.streamChannelsAsIs), @@ -322,7 +323,7 @@ export class Pregel< } for await (const saved of this.checkpointer.list(config, options)) { const channels = emptyChannels(this.channels, saved.checkpoint); - const [, nextTasks] = _prepareNextTasks( + const nextTasks = _prepareNextTasks( saved.checkpoint, this.nodes, channels, @@ -688,31 +689,61 @@ export class Pregel< // execute tasks, and wait for one to fail or all to finish. // each task is independent from all other concurrent tasks // yield updates/debug output as each task finishes - const tasks = loop.tasks.map((pregelTask) => () => { - return pregelTask.proc.invoke(pregelTask.input, pregelTask.config); - }); - - await executeTasks(tasks, this.stepTimeout, config.signal); - - for (const task of loop.tasks) { - loop.putWrites(task.id, task.writes); - } + const tasks = Object.fromEntries( + loop.tasks + .filter((task) => task.writes.length === 0) + .map((pregelTask) => { + return [ + pregelTask.id, + async () => { + let error; + let result; + try { + result = await pregelTask.proc.invoke( + pregelTask.input, + pregelTask.config + ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + error = e; + error.pregelTaskId = pregelTask.id; + } + return { + task: pregelTask, + result, + error, + }; + }, + ]; + }) + ); - if (streamMode.includes("updates")) { - // TODO: Refactor - for await (const task of loop.tasks) { - yield* prefixGenerator( - mapOutputUpdates(outputKeys, [task]), - streamMode.length > 1 ? "updates" : undefined - ); + try { + for await (const task of executeTasks( + tasks, + this.stepTimeout, + config.signal + )) { + loop.putWrites(task.id, task.writes); + if (streamMode.includes("updates")) { + yield* prefixGenerator( + mapOutputUpdates(outputKeys, [task]), + streamMode.length > 1 ? "updates" : undefined + ); + } + if (streamMode.includes("debug")) { + yield* prefixGenerator( + mapDebugTaskResults(loop.step, [task], this.streamChannelsList), + streamMode.length > 1 ? "debug" : undefined + ); + } } - } - - if (streamMode.includes("debug")) { - yield* prefixGenerator( - mapDebugTaskResults(loop.step, loop.tasks, this.streamChannelsList), - streamMode.length > 1 ? "debug" : undefined - ); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + if (e.pregelTaskId) { + loop.putWrites(e.pregelTaskId, [[ERROR, { message: e.message }]]); + } + throw e; } if (debug) { diff --git a/langgraph/src/pregel/loop.ts b/langgraph/src/pregel/loop.ts index 03b370f2..506d8be3 100644 --- a/langgraph/src/pregel/loop.ts +++ b/langgraph/src/pregel/loop.ts @@ -23,6 +23,7 @@ import { import { CONFIG_KEY_READ, CONFIG_KEY_RESUMING, + ERROR, INPUT, INTERRUPT, } from "../constants.js"; @@ -322,7 +323,7 @@ export class PregelLoop { return false; } - const [, nextTasks] = _prepareNextTasks( + const nextTasks = _prepareNextTasks( this.checkpoint, this.graph.nodes, this.channels, @@ -364,6 +365,10 @@ export class PregelLoop { // if there are pending writes from a previous loop, apply them if (this.checkpointPendingWrites.length > 0) { for (const [tid, k, v] of this.checkpointPendingWrites) { + // TODO: Do the same for INTERRUPT + if (k === ERROR) { + continue; + } const task = this.tasks.find((t) => t.id === tid); if (task) { task.writes.push([k, v]); @@ -432,7 +437,7 @@ export class PregelLoop { )}` ); } - const [, discardTasks] = _prepareNextTasks( + const discardTasks = _prepareNextTasks( this.checkpoint, this.graph.nodes, this.channels, diff --git a/langgraph/src/pregel/utils.ts b/langgraph/src/pregel/utils.ts index 13c03a28..534e6af1 100644 --- a/langgraph/src/pregel/utils.ts +++ b/langgraph/src/pregel/utils.ts @@ -1,21 +1,25 @@ import type { ChannelVersions } from "../checkpoint/base.js"; +export function getNullChannelVersion(currentVersions: ChannelVersions) { + const versionValues = Object.values(currentVersions); + const versionType = + versionValues.length > 0 ? typeof versionValues[0] : undefined; + let nullVersion: number | string | undefined; + if (versionType === "number") { + nullVersion = 0; + } else if (versionType === "string") { + nullVersion = ""; + } + return nullVersion; +} + export function getNewChannelVersions( previousVersions: ChannelVersions, currentVersions: ChannelVersions ): ChannelVersions { // Get new channel versions if (Object.keys(previousVersions).length > 0) { - const versionValues = Object.values(currentVersions); - const versionType = - versionValues.length > 0 ? typeof versionValues[0] : undefined; - let nullVersion: number | string; - if (versionType === "number") { - nullVersion = 0; - } else if (versionType === "string") { - nullVersion = ""; - } - + const nullVersion = getNullChannelVersion(currentVersions); return Object.fromEntries( Object.entries(currentVersions).filter( ([k, v]) => v > (previousVersions[k] ?? nullVersion) @@ -36,3 +40,13 @@ export function _coerceToDict(value: any, defaultKey: string) { ? value : { [defaultKey]: value }; } + +// Order matters +export function _getIdMetadata(metadata: Record) { + return { + langgraph_step: metadata.langgraph_step, + langgraph_node: metadata.langgraph_node, + langgraph_triggers: metadata.langgraph_triggers, + langgraph_task_idx: metadata.langgraph_task_idx, + }; +} diff --git a/langgraph/src/tests/pregel.test.ts b/langgraph/src/tests/pregel.test.ts index 6317e968..49382ec0 100644 --- a/langgraph/src/tests/pregel.test.ts +++ b/langgraph/src/tests/pregel.test.ts @@ -1,5 +1,7 @@ /* eslint-disable no-process-env */ /* eslint-disable no-promise-executor-return */ +/* eslint-disable no-instanceof/no-instanceof */ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { it, expect, jest, describe } from "@jest/globals"; import { RunnableConfig, @@ -54,7 +56,7 @@ import { Checkpoint, CheckpointTuple } from "../checkpoint/base.js"; import { GraphRecursionError, InvalidUpdateError } from "../errors.js"; import { SqliteSaver } from "../checkpoint/sqlite.js"; import { uuid5, uuid6 } from "../checkpoint/id.js"; -import { INTERRUPT, Send, TASKS } from "../constants.js"; +import { ERROR, INTERRUPT, Send, TASKS } from "../constants.js"; describe("Channel", () => { describe("writeTo", () => { @@ -683,7 +685,7 @@ describe("_prepareNextTasks", () => { }; // call method / assertions - const [newCheckpoint, taskDescriptions] = _prepareNextTasks( + const taskDescriptions = _prepareNextTasks( checkpoint, processes, channels, @@ -703,8 +705,8 @@ describe("_prepareNextTasks", () => { }); // the returned checkpoint is a copy of the passed checkpoint without versionsSeen updated - expect(newCheckpoint.versions_seen.node1.channel1).toBe(1); - expect(newCheckpoint.versions_seen.node2.channel2).toBe(5); + expect(checkpoint.versions_seen.node1.channel1).toBe(1); + expect(checkpoint.versions_seen.node2.channel2).toBe(5); }); it("should return an array of PregelExecutableTasks", () => { @@ -801,7 +803,7 @@ describe("_prepareNextTasks", () => { }; // call method / assertions - const [newCheckpoint, tasks] = _prepareNextTasks( + const tasks = _prepareNextTasks( checkpoint, processes, channels, @@ -858,7 +860,7 @@ describe("_prepareNextTasks", () => { input: 100, proc: new RunnablePassthrough(), writes: [], - triggers: ["channel1", "channel2"], + triggers: ["channel1"], config: { tags: [], configurable: expect.any(Object), @@ -866,7 +868,7 @@ describe("_prepareNextTasks", () => { langgraph_node: "node2", langgraph_step: -1, langgraph_task_idx: 2, - langgraph_triggers: ["channel1", "channel2"], + langgraph_triggers: ["channel1"], }), recursionLimit: 25, runId: undefined, @@ -875,9 +877,10 @@ describe("_prepareNextTasks", () => { id: expect.any(String), }); - expect(newCheckpoint.versions_seen.node1.channel1).toBe(2); - expect(newCheckpoint.versions_seen.node2.channel1).toBe(2); - expect(newCheckpoint.versions_seen.node2.channel2).toBe(5); + // Should not update versions seen, that occurs when applying writes + expect(checkpoint.versions_seen.node1.channel1).toBe(1); + expect(checkpoint.versions_seen.node2.channel1).not.toBeDefined(); + expect(checkpoint.versions_seen.node2.channel2).toBe(5); }); }); @@ -1537,6 +1540,30 @@ it("should raise InvalidUpdateError when the same LastValue channel is updated t await expect(app.invoke(2)).rejects.toThrow(InvalidUpdateError); }); +it("should fail to process two processes in an invalid way", async () => { + const addOne = jest.fn((x: number): number => x + 1); + + const one = Channel.subscribeTo("input") + .pipe(addOne) + .pipe(Channel.writeTo(["output"])); + const two = Channel.subscribeTo("input") + .pipe(addOne) + .pipe(Channel.writeTo(["output"])); + + const app = new Pregel({ + nodes: { one, two }, + channels: { + output: new LastValue(), + input: new LastValue(), + }, + inputChannels: "input", + outputChannels: "output", + }); + + // LastValue channels can only be updated once per iteration + await expect(app.invoke(2)).rejects.toThrow(InvalidUpdateError); +}); + it("should process two inputs to two outputs validly", async () => { const addOne = jest.fn((x: number): number => x + 1); @@ -1562,6 +1589,115 @@ it("should process two inputs to two outputs validly", async () => { expect(await app.invoke(2)).toEqual([3, 3]); }); +it("pending writes resume", async () => { + const checkpointer = new MemorySaverAssertImmutable(); + const StateAnnotation = Annotation.Root({ + value: Annotation({ reducer: (a, b) => a + b }), + }); + class AwhileMaker extends RunnableLambda { + calls: number = 0; + + sleep: number; + + rtn: Record | Error; + + constructor(sleep: number, rtn: Record | Error) { + super({ + func: async () => { + this.calls += 1; + await new Promise((resolve) => setTimeout(resolve, this.sleep)); + if (this.rtn instanceof Error) { + throw this.rtn; + } + return this.rtn; + }, + }); + this.sleep = sleep; + this.rtn = rtn; + } + + reset() { + this.calls = 0; + } + } + + const one = new AwhileMaker(0.2, { value: 2 }); + const two = new AwhileMaker(0.6, new Error("I'm not good")); + const builder = new StateGraph(StateAnnotation) + .addNode("one", one) + .addNode("two", two) + .addEdge("__start__", "one") + .addEdge("__start__", "two") + .addEdge("one", "__end__") + // TODO: Add retry policy + .addEdge("two", "__end__"); + const graph = builder.compile({ checkpointer }); + const thread1 = { configurable: { thread_id: "1" } }; + await expect(graph.invoke({ value: 1 }, thread1)).rejects.toThrow( + "I'm not good" + ); + expect(one.calls).toEqual(1); + expect(two.calls).toEqual(1); + + const state = await graph.getState(thread1); + expect(state).toBeDefined(); + expect(state.values).toEqual({ value: 1 }); + expect(state.next).toEqual(["one", "two"]); + expect(state.tasks).toEqual([ + { id: expect.any(String), name: "one" }, + { + id: expect.any(String), + name: "two", + error: expect.objectContaining({ + message: "I'm not good", + }), + }, + ]); + expect(state.metadata).toEqual({ source: "loop", step: 0 }); + + // should contain pending write of "one" and should contain error from "two" + const checkpoint = await checkpointer.getTuple(thread1); + expect(checkpoint).toBeDefined(); + const expectedWrites = [ + [expect.any(String), "one", "one"], + [expect.any(String), "value", 2], + [ + expect.any(String), + ERROR, + expect.objectContaining({ + message: "I'm not good", + }), + ], + ]; + expect(checkpoint?.pendingWrites).toEqual( + expect.arrayContaining(expectedWrites) + ); + + // both non-error pending writes come from same task + const nonErrorWrites = checkpoint!.pendingWrites!.filter( + (w) => w[1] !== ERROR + ); + expect(nonErrorWrites[0][0]).toEqual(nonErrorWrites[1][0]); + const errorWrites = checkpoint!.pendingWrites!.filter((w) => w[1] === ERROR); + expect(errorWrites[0][0]).not.toEqual(nonErrorWrites[0][0]); + + // resume execution + await expect(graph.invoke(null, thread1)).rejects.toThrow("I'm not good"); + // node "one" succeeded previously, so shouldn't be called again + expect(one.calls).toEqual(1); + // node "two" should have been called once again + expect(two.calls).toEqual(2); + + // confirm no new checkpoints saved + const state2 = await graph.getState(thread1); + expect(state2.metadata).toEqual(state.metadata); + + // resume execution, without exception + two.rtn = { value: 3 }; + // both the pending write and the new write were applied, 1 + 2 + 3 = 6 + expect(await graph.invoke(null, thread1)).toEqual({ value: 6 }); +}); + it("should allow a conditional edge after a send", async () => { const State = { items: Annotation({ @@ -3297,7 +3433,7 @@ it("checkpoint events", async () => { id: anyStringSame("task3"), name: "finish", input: { my_key: "value prepared slow", market: "DE" }, - triggers: ["tool_two_fast", "tool_two_slow"], + triggers: ["tool_two_slow"], }, }, { diff --git a/langgraph/src/tests/tracing.test.ts b/langgraph/src/tests/tracing.test.ts index d6632eb9..e4ca271c 100644 --- a/langgraph/src/tests/tracing.test.ts +++ b/langgraph/src/tests/tracing.test.ts @@ -94,10 +94,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, { @@ -114,10 +111,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, { @@ -134,10 +128,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], ls_model_type: "chat", ls_stop: undefined, }), @@ -157,10 +148,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], ls_model_type: "chat", ls_stop: undefined, }), @@ -182,10 +170,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, { @@ -202,10 +187,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, { @@ -223,10 +205,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, { @@ -243,10 +222,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, { @@ -263,10 +239,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], ls_model_type: "chat", ls_stop: undefined, }), @@ -286,10 +259,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], ls_model_type: "chat", ls_stop: undefined, }), @@ -311,10 +281,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, { @@ -332,10 +299,7 @@ it("stream events for a multi-node graph", async () => { langgraph_node: "testnode", langgraph_step: 1, langgraph_task_idx: 0, - langgraph_triggers: [ - "start:testnode", - "branch:testnode:condition:testnode", - ], + langgraph_triggers: ["start:testnode"], }), }, {