From a22691895663f35cd75344ce2a9bc7171df589d5 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Wed, 21 Aug 2024 11:17:54 -0700 Subject: [PATCH] Add pending writes to debug state, fix forking state bug (#338) * Add pending writes to debug state * Fix checkpoint forking bug * Fix test --- langgraph/src/channels/base.ts | 27 +-- langgraph/src/pregel/debug.ts | 4 +- langgraph/src/pregel/index.ts | 61 ++++++- langgraph/src/pregel/types.ts | 4 + langgraph/src/tests/pregel.test.ts | 260 +++++++++++++++++++++++++++++ 5 files changed, 336 insertions(+), 20 deletions(-) diff --git a/langgraph/src/channels/base.ts b/langgraph/src/channels/base.ts index a8de4514..1fc98b1e 100644 --- a/langgraph/src/channels/base.ts +++ b/langgraph/src/channels/base.ts @@ -84,20 +84,25 @@ export function emptyChannels>( export function createCheckpoint( checkpoint: ReadonlyCheckpoint, - channels: Record>, + channels: Record> | undefined, step: number ): Checkpoint { // eslint-disable-next-line @typescript-eslint/no-explicit-any - const values: Record = {}; - for (const k of Object.keys(channels)) { - try { - values[k] = channels[k].checkpoint(); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (error: any) { - if (error.name === EmptyChannelError.unminifiable_name) { - // no-op - } else { - throw error; // Rethrow unexpected errors + let values: Record; + if (channels === undefined) { + values = checkpoint.channel_values; + } else { + values = {}; + for (const k of Object.keys(channels)) { + try { + values[k] = channels[k].checkpoint(); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (error: any) { + if (error.name === EmptyChannelError.unminifiable_name) { + // no-op + } else { + throw error; // Rethrow unexpected errors + } } } } diff --git a/langgraph/src/pregel/debug.ts b/langgraph/src/pregel/debug.ts index bc5ff688..6f925ddc 100644 --- a/langgraph/src/pregel/debug.ts +++ b/langgraph/src/pregel/debug.ts @@ -207,8 +207,8 @@ export function* 
mapDebugCheckpoint< }; } -function tasksWithWrites( - tasks: readonly PregelExecutableTask[], +export function tasksWithWrites( + tasks: PregelTaskDescription[] | readonly PregelExecutableTask[], pendingWrites: CheckpointPendingWrite[] ): PregelTaskDescription[] { return tasks.map((task): PregelTaskDescription => { diff --git a/langgraph/src/pregel/index.ts b/langgraph/src/pregel/index.ts index 528dc48d..6b4462d1 100644 --- a/langgraph/src/pregel/index.ts +++ b/langgraph/src/pregel/index.ts @@ -10,6 +10,7 @@ import { getCallbackManagerForConfig, patchConfig, } from "@langchain/core/runnables"; +import { IterableReadableStream } from "@langchain/core/utils/stream"; import { BaseChannel, createCheckpoint, @@ -29,6 +30,7 @@ import { printStepCheckpoint, printStepTasks, printStepWrites, + tasksWithWrites, } from "./debug.js"; import { ChannelWrite, ChannelWriteEntry, PASSTHROUGH } from "./write.js"; import { @@ -278,6 +280,9 @@ export class Pregel< } } + /** + * Get the current state of the graph. + */ async getState(config: RunnableConfig): Promise { if (!this.checkpointer) { throw new GraphValueError("No checkpointer set"); @@ -297,6 +302,7 @@ export class Pregel< return { values: readChannels(channels, this.streamChannelsAsIs), next: nextTasks.map((task) => task.name), + tasks: tasksWithWrites(nextTasks, saved?.pendingWrites ?? []), metadata: saved?.metadata, config: saved ? saved.config : config, createdAt: saved?.checkpoint.ts, @@ -304,6 +310,9 @@ export class Pregel< }; } + /** + * Get the history of the state of the graph. + */ async *getStateHistory( config: RunnableConfig, options?: CheckpointListOptions @@ -324,6 +333,7 @@ export class Pregel< yield { values: readChannels(channels, this.streamChannelsAsIs), next: nextTasks.map((task) => task.name), + tasks: tasksWithWrites(nextTasks, saved.pendingWrites ?? 
[]), metadata: saved.metadata, config: saved.config, createdAt: saved.checkpoint.ts, @@ -332,6 +342,11 @@ export class Pregel< } } + /** + * Update the state of the graph with the given values, as if they came from + * node `asNode`. If `asNode` is not provided, it will be set to the last node + * that updated the state, if not ambiguous. + */ async updateState( config: RunnableConfig, values: Record | unknown, @@ -361,10 +376,10 @@ export class Pregel< }; // Find last node that updated the state, if not provided - if (values === undefined && asNode === undefined) { + if (values == null && asNode === undefined) { return await this.checkpointer.put( checkpointConfig, - createCheckpoint(checkpoint, {}, step), + createCheckpoint(checkpoint, undefined, step), { source: "update", step, @@ -380,7 +395,7 @@ export class Pregel< }) .flat() .find((v) => !!v); - if (asNode === undefined && !nonNullVersion) { + if (asNode === undefined && nonNullVersion === undefined) { if ( typeof this.inputChannels === "string" && this.nodes[this.inputChannels] !== undefined @@ -562,7 +577,29 @@ export class Pregel< ]; } - async *_streamIterator( + /** + * Stream graph steps for a single input. + * @param input The input to the graph. + * @param options The configuration to use for the run. + * @param options.streamMode The mode to stream output. Defaults to value set on initialization. + * Options are "values", "updates", and "debug". Default is "values". + * values: Emit the current values of the state for each step. + * updates: Emit only the updates to the state for each step. + * Output is a dict with the node name as key and the updated values as value. + * debug: Emit debug events for each step. + * @param options.outputKeys The keys to stream. Defaults to all non-context channels. + * @param options.interruptBefore Nodes to interrupt before. + * @param options.interruptAfter Nodes to interrupt after. + * @param options.debug Whether to print debug information during execution. 
+ */ + override async stream( + input: PregelInputType, + options?: Partial> + ): Promise> { + return super.stream(input, options); + } + + override async *_streamIterator( input: PregelInputType, options?: Partial> ): AsyncGenerator { @@ -722,10 +759,20 @@ export class Pregel< /** * Run the graph with a single input and config. - * @param input - * @param options + * @param input The input to the graph. + * @param options The configuration to use for the run. + * @param options.streamMode The mode to stream output. Defaults to value set on initialization. + * Options are "values", "updates", and "debug". Default is "values". + * values: Emit the current values of the state for each step. + * updates: Emit only the updates to the state for each step. + * Output is a dict with the node name as key and the updated values as value. + * debug: Emit debug events for each step. + * @param options.outputKeys The keys to stream. Defaults to all non-context channels. + * @param options.interruptBefore Nodes to interrupt before. + * @param options.interruptAfter Nodes to interrupt after. + * @param options.debug Whether to print debug information during execution. */ - async invoke( + override async invoke( input: PregelInputType, options?: Partial> ): Promise { diff --git a/langgraph/src/pregel/types.ts b/langgraph/src/pregel/types.ts index cee40e7d..43f5c35e 100644 --- a/langgraph/src/pregel/types.ts +++ b/langgraph/src/pregel/types.ts @@ -114,6 +114,10 @@ export interface StateSnapshot { * @default undefined */ readonly parentConfig?: RunnableConfig | undefined; + /** + * Tasks to execute in this step. If already attempted, may contain an error. 
+ */ + readonly tasks: PregelTaskDescription[]; } export type All = "*"; diff --git a/langgraph/src/tests/pregel.test.ts b/langgraph/src/tests/pregel.test.ts index 56b548d6..6317e968 100644 --- a/langgraph/src/tests/pregel.test.ts +++ b/langgraph/src/tests/pregel.test.ts @@ -972,6 +972,21 @@ it("should process input and write kwargs correctly", async () => { }); }); +// TODO: Check undefined too +const FALSEY_VALUES = [null, 0, "", [], {}, new Set()]; +it.each(FALSEY_VALUES)( + "should process falsey value: %p", + async (falsyValue) => { + const graph = new Graph() + .addNode("return_falsy_const", () => falsyValue) + .addEdge(START, "return_falsy_const") + .addEdge("return_falsy_const", END) + .compile(); + + expect(await graph.invoke(1)).toBe(falsyValue); + } +); + it("should invoke single process in out objects", async () => { const addOne = jest.fn((x: number): number => x + 1); const chain = Channel.subscribeTo("input") @@ -1223,6 +1238,237 @@ it("should process batch with two processes and delays with graph", async () => expect(await graph.batch([3, 2, 1, 3, 5])).toEqual([5, 4, 3, 5, 7]); }); +it("should invoke two processes with input/output and interrupt", async () => { + const checkpointer = new MemorySaverAssertImmutable(); + const addOne = jest.fn((x: number) => { + return x + 1; + }); + const one = Channel.subscribeTo("input") + .pipe(addOne) + .pipe(Channel.writeTo(["inbox"])); + const two = Channel.subscribeTo("inbox") + .pipe(addOne) + .pipe(Channel.writeTo(["output"])); + + const app = new Pregel({ + nodes: { one, two }, + channels: { + inbox: new LastValue(), + output: new LastValue(), + input: new LastValue(), + }, + inputChannels: "input", + outputChannels: "output", + checkpointer, + interruptAfter: ["one"], + }); + + const thread1 = { configurable: { thread_id: "1" } }; + const thread2 = { configurable: { thread_id: "2" } }; + + // start execution, stop at inbox + expect(await app.invoke(2, thread1)).toBeUndefined(); + + // inbox == 3 + let 
checkpoint = await checkpointer.get(thread1); + expect(checkpoint?.channel_values.inbox).toBe(3); + + // resume execution, finish + expect(await app.invoke(null, thread1)).toBe(4); + + // start execution again, stop at inbox + expect(await app.invoke(20, thread1)).toBeUndefined(); + + // inbox == 21 + checkpoint = await checkpointer.get(thread1); + expect(checkpoint).not.toBeUndefined(); + expect(checkpoint?.channel_values.inbox).toBe(21); + + // send a new value in, interrupting the previous execution + expect(await app.invoke(3, thread1)).toBeUndefined(); + expect(await app.invoke(null, thread1)).toBe(5); + + // start execution again, stopping at inbox + expect(await app.invoke(20, thread2)).toBeUndefined(); + + // inbox == 21 + let snapshot = await app.getState(thread2); + expect(snapshot.values.inbox).toBe(21); + expect(snapshot.next).toEqual(["two"]); + + // update the state, resume + await app.updateState(thread2, 25, "one"); + expect(await app.invoke(null, thread2)).toBe(26); + + // no pending tasks + snapshot = await app.getState(thread2); + expect(snapshot.next).toEqual([]); + + // list history + const history = await gatherIterator(app.getStateHistory(thread1)); + expect(history).toEqual([ + expect.objectContaining({ + values: { inbox: 4, output: 5, input: 3 }, + tasks: [], + next: [], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + metadata: { source: "loop", step: 6, writes: 5 }, + createdAt: expect.any(String), + parentConfig: history[1].config, + }), + expect.objectContaining({ + values: { inbox: 4, output: 4, input: 3 }, + tasks: [{ id: expect.any(String), name: "two" }], + next: ["two"], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + metadata: { source: "loop", step: 5 }, + createdAt: expect.any(String), + parentConfig: history[2].config, + }), + expect.objectContaining({ + values: { inbox: 21, output: 4, 
input: 3 }, + tasks: [{ id: expect.any(String), name: "one" }], + next: ["one"], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + metadata: { source: "input", step: 4, writes: 3 }, + createdAt: expect.any(String), + parentConfig: history[3].config, + }), + expect.objectContaining({ + values: { inbox: 21, output: 4, input: 20 }, + tasks: [{ id: expect.any(String), name: "two" }], + next: ["two"], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + metadata: { source: "loop", step: 3 }, + createdAt: expect.any(String), + parentConfig: history[4].config, + }), + expect.objectContaining({ + values: { inbox: 3, output: 4, input: 20 }, + tasks: [{ id: expect.any(String), name: "one" }], + next: ["one"], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + metadata: { source: "input", step: 2, writes: 20 }, + createdAt: expect.any(String), + parentConfig: history[5].config, + }), + expect.objectContaining({ + values: { inbox: 3, output: 4, input: 2 }, + tasks: [], + next: [], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + metadata: { source: "loop", step: 1, writes: 4 }, + createdAt: expect.any(String), + parentConfig: history[6].config, + }), + expect.objectContaining({ + values: { inbox: 3, input: 2 }, + tasks: [{ id: expect.any(String), name: "two" }], + next: ["two"], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + metadata: { source: "loop", step: 0 }, + createdAt: expect.any(String), + parentConfig: history[7].config, + }), + expect.objectContaining({ + values: { input: 2 }, + tasks: [{ id: expect.any(String), name: "one" }], + next: ["one"], + config: { + configurable: { + thread_id: "1", + checkpoint_ns: "", + checkpoint_id: 
expect.any(String), + }, + }, + metadata: { source: "input", step: -1, writes: 2 }, + createdAt: expect.any(String), + parentConfig: undefined, + }), + ]); + + // re-running from any previous checkpoint w/out forking should do nothing + expect( + await gatherIterator( + app.stream(null, { ...history[0].config, streamMode: "updates" }) + ) + ).toEqual([]); + expect( + await gatherIterator( + app.stream(null, { ...history[1].config, streamMode: "updates" }) + ) + ).toEqual([]); + expect( + await gatherIterator( + app.stream(null, { ...history[2].config, streamMode: "updates" }) + ) + ).toEqual([]); + + // forking and re-running from any prev checkpoint should re-run nodes + let forkConfig = await app.updateState(history[0].config, null); + expect( + await gatherIterator( + app.stream(null, { ...forkConfig, streamMode: "updates" }) + ) + ).toEqual([]); + + forkConfig = await app.updateState(history[1].config, null); + expect( + await gatherIterator( + app.stream(null, { ...forkConfig, streamMode: "updates" }) + ) + ).toEqual([{ two: { output: 5 } }]); + + forkConfig = await app.updateState(history[2].config, null); + expect( + await gatherIterator( + app.stream(null, { ...forkConfig, streamMode: "updates" }) + ) + ).toEqual([{ one: { inbox: 4 } }]); +}); + it("should batch many processes with input and output", async () => { const testSize = 100; const addOne = jest.fn((x: number) => x + 1); @@ -2274,6 +2520,7 @@ describe("StateGraph", () => { values: { messages: expectedOutputMessages.slice(0, 2), }, + tasks: [{ id: expect.any(String), name: "tools" }], next: ["tools"], metadata: { source: "loop", @@ -2323,6 +2570,7 @@ describe("StateGraph", () => { ], }, next: ["tools"], + tasks: [{ id: expect.any(String), name: "tools" }], metadata: { source: "update", step: 2, @@ -2399,6 +2647,10 @@ describe("StateGraph", () => { ], }, next: ["tools", "tools"], + tasks: [ + { id: expect.any(String), name: "tools" }, + { id: expect.any(String), name: "tools" }, + ], metadata: { 
source: "loop", step: 4, @@ -2456,6 +2708,7 @@ describe("StateGraph", () => { ], }, next: [], + tasks: [], metadata: { source: "update", step: 5, @@ -3195,6 +3448,7 @@ it("StateGraph start branch then end", async () => { expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({ values: { my_key: "value ⛰️", market: "DE" }, next: ["tool_two_slow"], + tasks: [{ id: expect.any(String), name: "tool_two_slow" }], config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))! .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))! @@ -3214,6 +3468,7 @@ it("StateGraph start branch then end", async () => { expect(await toolTwoWithCheckpointer.getState(thread1)).toEqual({ values: { my_key: "value ⛰️ slow", market: "DE" }, next: [], + tasks: [], config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))! .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))! @@ -3243,6 +3498,7 @@ it("StateGraph start branch then end", async () => { expect(await toolTwoWithCheckpointer.getState(thread2)).toEqual({ values: { my_key: "value", market: "US" }, next: ["tool_two_fast"], + tasks: [{ id: expect.any(String), name: "tool_two_fast" }], config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread2))! .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread2))! @@ -3262,6 +3518,7 @@ it("StateGraph start branch then end", async () => { expect(await toolTwoWithCheckpointer.getState(thread2)).toEqual({ values: { my_key: "value fast", market: "US" }, next: [], + tasks: [], config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread2))! .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread2))! 
@@ -3291,6 +3548,7 @@ it("StateGraph start branch then end", async () => { expect(await toolTwoWithCheckpointer.getState(thread3)).toEqual({ values: { my_key: "value", market: "US" }, next: ["tool_two_fast"], + tasks: [{ id: expect.any(String), name: "tool_two_fast" }], config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread3))! .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread3))! @@ -3307,6 +3565,7 @@ it("StateGraph start branch then end", async () => { expect(await toolTwoWithCheckpointer.getState(thread3)).toEqual({ values: { my_key: "valuekey", market: "US" }, next: ["tool_two_fast"], + tasks: [{ id: expect.any(String), name: "tool_two_fast" }], config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread3))! .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread3))! @@ -3330,6 +3589,7 @@ it("StateGraph start branch then end", async () => { expect(await toolTwoWithCheckpointer.getState(thread3)).toEqual({ values: { my_key: "valuekey fast", market: "US" }, next: [], + tasks: [], config: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread3))! .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread3))!