From 5f8a645ab027190b72017239dbdd608943069fde Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Sun, 25 Aug 2024 19:50:17 -0700 Subject: [PATCH] Handle explicit undefined values passed into serde (#374) --- libs/checkpoint/src/memory.ts | 67 +++--- libs/checkpoint/src/serde/jsonplus.ts | 13 +- .../src/serde/tests/jsonplus.test.ts | 30 ++- libs/langgraph/src/pregel/index.ts | 1 - libs/langgraph/src/pregel/loop.ts | 7 +- libs/langgraph/src/pregel/types.ts | 2 +- libs/langgraph/src/tests/pregel.test.ts | 209 +++++++++++++++++- 7 files changed, 273 insertions(+), 56 deletions(-) diff --git a/libs/checkpoint/src/memory.ts b/libs/checkpoint/src/memory.ts index 774917ce..fa363a46 100644 --- a/libs/checkpoint/src/memory.ts +++ b/libs/checkpoint/src/memory.ts @@ -54,17 +54,7 @@ export class MemorySaver extends BaseCheckpointSaver { ]; }) ); - const parentConfig = - parentCheckpointId !== undefined - ? { - configurable: { - thread_id, - checkpoint_ns, - checkpoint_id, - }, - } - : undefined; - return { + const checkpointTuple: CheckpointTuple = { config, checkpoint: (await this.serde.loadsTyped( "json", @@ -75,8 +65,17 @@ export class MemorySaver extends BaseCheckpointSaver { metadata )) as CheckpointMetadata, pendingWrites, - parentConfig, }; + if (parentCheckpointId !== undefined) { + checkpointTuple.parentConfig = { + configurable: { + thread_id, + checkpoint_ns, + checkpoint_id, + }, + }; + } + return checkpointTuple; } } else { const checkpoints = this.storage[thread_id]?.[checkpoint_ns]; @@ -99,17 +98,7 @@ export class MemorySaver extends BaseCheckpointSaver { ]; }) ); - const parentConfig = - parentCheckpointId !== undefined - ? { - configurable: { - thread_id, - checkpoint_ns, - checkpoint_id: parentCheckpointId, - }, - } - : undefined; - return { + const checkpointTuple: CheckpointTuple = { config: { configurable: { thread_id, @@ -126,8 +115,17 @@ export class MemorySaver extends BaseCheckpointSaver { metadata )) as CheckpointMetadata, pendingWrites, - parentConfig, }; + if (parentCheckpointId !== undefined) { + checkpointTuple.parentConfig = { + configurable: { + thread_id, + checkpoint_ns, + checkpoint_id: parentCheckpointId, + }, + }; + } + return checkpointTuple; } } @@ -191,7 +189,7 @@ export class MemorySaver extends BaseCheckpointSaver { }) ); - yield { + const checkpointTuple: CheckpointTuple = { config: { configurable: { thread_id: threadId, @@ -205,16 +203,17 @@ export class MemorySaver extends BaseCheckpointSaver { )) as Checkpoint, metadata, pendingWrites, - parentConfig: parentCheckpointId - ? { - configurable: { - thread_id: threadId, - checkpoint_ns: checkpointNamespace, - checkpoint_id: parentCheckpointId, - }, - } - : undefined, }; + if (parentCheckpointId !== undefined) { + checkpointTuple.parentConfig = { + configurable: { + thread_id: threadId, + checkpoint_ns: checkpointNamespace, + checkpoint_id: parentCheckpointId, + }, + }; + } + yield checkpointTuple; } } } diff --git a/libs/checkpoint/src/serde/jsonplus.ts b/libs/checkpoint/src/serde/jsonplus.ts index 1e4b7c20..eaa82c52 100644 --- a/libs/checkpoint/src/serde/jsonplus.ts +++ b/libs/checkpoint/src/serde/jsonplus.ts @@ -5,7 +5,9 @@ import { SerializerProtocol } from "./base.js"; async function _reviver(value: any): Promise { if (value && typeof value === "object") { - if ( + if (value.lc === 2 && value.type === "undefined") { + return undefined; + } else if ( value.lc === 2 && value.type === "constructor" && Array.isArray(value.id) @@ -64,14 +66,19 @@ function _encodeConstructorArgs( lc: 2, type: "constructor", id: [constructor.name], - method, + method: method ?? null, args: args ?? [], kwargs: kwargs ?? {}, }; } function _default(_key: string, obj: any): any { - if (obj instanceof Set || obj instanceof Map) { + if (obj === undefined) { + return { + lc: 2, + type: "undefined", + }; + } else if (obj instanceof Set || obj instanceof Map) { return _encodeConstructorArgs(obj.constructor, undefined, [ Array.from(obj), ]); diff --git a/libs/checkpoint/src/serde/tests/jsonplus.test.ts b/libs/checkpoint/src/serde/tests/jsonplus.test.ts index 9fc9578d..70bb4ee5 100644 --- a/libs/checkpoint/src/serde/tests/jsonplus.test.ts +++ b/libs/checkpoint/src/serde/tests/jsonplus.test.ts @@ -3,7 +3,7 @@ import { AIMessage, HumanMessage } from "@langchain/core/messages"; import { uuid6 } from "../../id.js"; import { JsonPlusSerializer } from "../jsonplus.js"; -const value = { +const complexValue = { number: 1, id: uuid6(-1), error: new Error("test error"), @@ -18,6 +18,7 @@ const value = { new Error("nestedfoo"), 5, true, + null, false, { a: "b", @@ -26,12 +27,27 @@ const value = { ], object: { messages: [new HumanMessage("hey there"), new AIMessage("hi how are you")], + nestedNullVal: null, + emptyString: "", }, + emptyString: "", + nullVal: null, }; -it("should serialize and deserialize various data types", async () => { - const serde = new JsonPlusSerializer(); - const [type, serialized] = serde.dumpsTyped(value); - const deserialized = await serde.loadsTyped(type, serialized); - expect(deserialized).toEqual(value); -}); +const VALUES = [ + ["undefined", undefined], + ["null", null], + ["empty string", ""], + ["simple string", "foobar"], + ["various data types", complexValue], +] satisfies [string, unknown][]; + +it.each(VALUES)( + "should serialize and deserialize %s", + async (_description, value) => { + const serde = new JsonPlusSerializer(); + const [type, serialized] = serde.dumpsTyped(value); + const deserialized = await serde.loadsTyped(type, serialized); + expect(deserialized).toEqual(value); + } +); diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 90d3992a..3a1c5d4b 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -454,7 +454,6 @@ export class Pregel< writers.length > 1 ? RunnableSequence.from(writers as any) : writers[0], writes: [], triggers: [INTERRUPT], - config: undefined, id: uuid5(INTERRUPT, checkpoint.id), }; diff --git a/libs/langgraph/src/pregel/loop.ts b/libs/langgraph/src/pregel/loop.ts index ea9795eb..c61800d7 100644 --- a/libs/langgraph/src/pregel/loop.ts +++ b/libs/langgraph/src/pregel/loop.ts @@ -302,7 +302,7 @@ export class PregelLoop { : mapOutputValues(outputKeys, writes, this.channels).next().value; await this._putCheckpoint({ source: "loop", - writes: metadataWrites, + writes: metadataWrites ?? null, }); // after execution, check if we should interrupt if (shouldInterrupt(this.checkpoint, interruptAfter, this.tasks)) { @@ -456,7 +456,10 @@ export class PregelLoop { this.checkpointerGetNextVersion ); // save input checkpoint - await this._putCheckpoint({ source: "input", writes: this.input }); + await this._putCheckpoint({ + source: "input", + writes: this.input ?? null, + }); } // done with input this.input = isResuming ? INPUT_RESUMING : INPUT_DONE; diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 23ed0ffb..8dcdb0c4 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -84,7 +84,7 @@ export interface PregelExecutableTask< readonly input: unknown; readonly proc: Runnable; readonly writes: PendingWrite[]; - readonly config: RunnableConfig | undefined; + readonly config?: RunnableConfig; readonly triggers: Array; readonly retry_policy?: string; readonly id: string; diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index c7904256..92648f23 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -42,6 +42,7 @@ import { Graph, START, StateGraph, + StateGraphArgs, StateType, } from "../graph/index.js"; import { Topic } from "../channels/topic.js"; @@ -1340,7 +1341,7 @@ it("should invoke two processes with input/output and interrupt", async () => { checkpoint_id: expect.any(String), }, }, - metadata: { source: "loop", step: 5 }, + metadata: { source: "loop", step: 5, writes: null }, createdAt: expect.any(String), parentConfig: history[2].config, }), @@ -1370,7 +1371,7 @@ it("should invoke two processes with input/output and interrupt", async () => { checkpoint_id: expect.any(String), }, }, - metadata: { source: "loop", step: 3 }, + metadata: { source: "loop", step: 3, writes: null }, createdAt: expect.any(String), parentConfig: history[4].config, }), @@ -1415,7 +1416,7 @@ it("should invoke two processes with input/output and interrupt", async () => { checkpoint_id: expect.any(String), }, }, - metadata: { source: "loop", step: 0 }, + metadata: { source: "loop", step: 0, writes: null }, createdAt: expect.any(String), parentConfig: history[7].config, }), @@ -1657,7 +1658,7 @@ it("pending writes resume", async () => { }), }, ]); - expect(state.metadata).toEqual({ source: "loop", step: 0 }); + expect(state.metadata).toEqual({ source: "loop", step: 0, writes: null }); // should contain pending write of "one" and should contain error from "two" const checkpoint = await checkpointer.getTuple(thread1); @@ -2931,6 +2932,197 @@ describe("StateGraph", () => { ["values", { value: 6 }], ]); }); + + it("should allow undefined values returned in a node update", async () => { + interface GraphState { + test?: string; + reducerField?: string; + } + + const graphState: StateGraphArgs["channels"] = { + test: null, + reducerField: { + default: () => "", + reducer: (x, y?: string) => y ?? x, + }, + }; + + const workflow = new StateGraph({ channels: graphState }); + + async function updateTest( + _state: GraphState + ): Promise> { + return { + test: "test", + reducerField: "should not be wiped", + }; + } + + async function wipeFields( + _state: GraphState + ): Promise> { + return { + test: undefined, + reducerField: undefined, + }; + } + + workflow + .addNode("updateTest", updateTest) + .addNode("wipeFields", wipeFields) + .addEdge(START, "updateTest") + .addEdge("updateTest", "wipeFields") + .addEdge("wipeFields", END); + + const checkpointer = new MemorySaver(); + + const app = workflow.compile({ checkpointer }); + const config: RunnableConfig = { + configurable: { thread_id: "102" }, + }; + const res = await app.invoke( + { + messages: ["initial input"], + }, + config + ); + expect(res).toEqual({ + reducerField: "should not be wiped", + }); + const history = await gatherIterator(app.getStateHistory(config)); + expect(history).toEqual([ + { + values: { + reducerField: "should not be wiped", + }, + next: [], + tasks: [], + metadata: { + source: "loop", + writes: { + wipeFields: { + test: undefined, + reducerField: undefined, + }, + }, + step: 2, + }, + config: { + configurable: { + thread_id: "102", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + createdAt: expect.any(String), + parentConfig: { + configurable: { + thread_id: "102", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + }, + { + values: { + test: "test", + reducerField: "should not be wiped", + }, + next: ["wipeFields"], + tasks: [ + { + id: expect.any(String), + name: "wipeFields", + }, + ], + metadata: { + source: "loop", + writes: { + updateTest: { + test: "test", + reducerField: "should not be wiped", + }, + }, + step: 1, + }, + config: { + configurable: { + thread_id: "102", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + createdAt: expect.any(String), + parentConfig: { + configurable: { + thread_id: "102", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + }, + { + values: { + reducerField: "", + }, + next: ["updateTest"], + tasks: [ + { + id: expect.any(String), + name: "updateTest", + }, + ], + metadata: { + source: "loop", + writes: null, + step: 0, + }, + config: { + configurable: { + thread_id: "102", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + createdAt: expect.any(String), + parentConfig: { + configurable: { + thread_id: "102", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + }, + { + values: { + reducerField: "", + }, + next: ["__start__"], + tasks: [ + { + id: expect.any(String), + name: "__start__", + }, + ], + metadata: { + source: "input", + writes: { + messages: ["initial input"], + }, + step: -1, + }, + config: { + configurable: { + thread_id: "102", + checkpoint_ns: "", + checkpoint_id: expect.any(String), + }, + }, + createdAt: expect.any(String), + parentConfig: undefined, + }, + ]); + }); }); describe("PreBuilt", () => { @@ -3327,7 +3519,7 @@ it("checkpoint events", async () => { metadata: { source: "loop", step: 0, - writes: undefined, + writes: null, }, next: ["prepare"], tasks: [{ id: expect.any(String), name: "prepare" }], @@ -3580,6 +3772,7 @@ it("StateGraph start branch then end", async () => { { source: "loop", step: 0, + writes: null, }, { source: "input", @@ -3595,7 +3788,7 @@ it("StateGraph start branch then end", async () => { .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread1))! .checkpoint.ts, - metadata: { source: "loop", step: 0 }, + metadata: { source: "loop", step: 0, writes: null }, parentConfig: ( await last( toolTwoWithCheckpointer.checkpointer!.list(thread1, { limit: 2 }) @@ -3645,7 +3838,7 @@ it("StateGraph start branch then end", async () => { .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread2))! .checkpoint.ts, - metadata: { source: "loop", step: 0 }, + metadata: { source: "loop", step: 0, writes: null }, parentConfig: ( await last( toolTwoWithCheckpointer.checkpointer!.list(thread2, { limit: 2 }) @@ -3695,7 +3888,7 @@ it("StateGraph start branch then end", async () => { .config, createdAt: (await toolTwoWithCheckpointer.checkpointer!.getTuple(thread3))! .checkpoint.ts, - metadata: { source: "loop", step: 0 }, + metadata: { source: "loop", step: 0, writes: null }, parentConfig: ( await last( toolTwoWithCheckpointer.checkpointer!.list(thread3, { limit: 2 })