From 356206893f8716dcac76d9aa07a09aff90f1463c Mon Sep 17 00:00:00 2001 From: Ben Burns <803016+benjamincburns@users.noreply.github.com> Date: Thu, 24 Oct 2024 22:51:58 +1300 Subject: [PATCH] fix(checkpoint-sqlite): populate pending_sends and pendingWrites fixes #590 --- libs/checkpoint-sqlite/src/index.ts | 236 +++++++++++++++----- libs/checkpoint-validation/src/spec/list.ts | 7 +- 2 files changed, 182 insertions(+), 61 deletions(-) diff --git a/libs/checkpoint-sqlite/src/index.ts b/libs/checkpoint-sqlite/src/index.ts index 6740bb22..a75542ed 100644 --- a/libs/checkpoint-sqlite/src/index.ts +++ b/libs/checkpoint-sqlite/src/index.ts @@ -8,6 +8,8 @@ import { type SerializerProtocol, type PendingWrite, type CheckpointMetadata, + TASKS, + copyCheckpoint, } from "@langchain/langgraph-checkpoint"; interface CheckpointRow { @@ -18,17 +20,20 @@ interface CheckpointRow { checkpoint_id: string; checkpoint_ns?: string; type?: string; + pending_writes: string; + pending_sends: string; } -interface WritesRow { - thread_id: string; - checkpoint_ns: string; - checkpoint_id: string; +interface PendingWriteColumn { task_id: string; - idx: number; channel: string; - type?: string; - value?: string; + type: string; + value: string; +} + +interface PendingSendColumn { + type: string; + value: string; } // In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys @@ -113,20 +118,57 @@ CREATE TABLE IF NOT EXISTS writes ( checkpoint_ns = "", checkpoint_id, } = config.configurable ?? {}; - let row: CheckpointRow; + const sql = ` + SELECT + thread_id, + checkpoint_ns, + checkpoint_id, + parent_checkpoint_id, + type, + checkpoint, + metadata, + ( + SELECT + json_group_array( + json_object( + 'task_id', pw.task_id, + 'channel', pw.channel, + 'type', pw.type, + 'value', CAST(pw.value AS TEXT) + ) + ) + FROM writes as pw + WHERE pw.thread_id = checkpoints.thread_id + AND pw.checkpoint_ns = checkpoints.checkpoint_ns + AND pw.checkpoint_id = checkpoints.checkpoint_id + ) as pending_writes, + ( + SELECT + json_group_array( + json_object( + 'type', ps.type, + 'value', CAST(ps.value AS TEXT) + ) + ) + FROM writes as ps + WHERE ps.thread_id = checkpoints.thread_id + AND ps.checkpoint_ns = checkpoints.checkpoint_ns + AND ps.checkpoint_id = checkpoints.parent_checkpoint_id + AND ps.channel = '${TASKS}' + ORDER BY ps.idx + ) as pending_sends + FROM checkpoints + WHERE thread_id = ? AND checkpoint_ns = ? ${ + checkpoint_id + ? "AND checkpoint_id = ?" + : "ORDER BY checkpoint_id DESC LIMIT 1" + }`; + + const args = [thread_id, checkpoint_ns]; if (checkpoint_id) { - row = this.db - .prepare( - `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` - ) - .get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow; - } else { - row = this.db - .prepare( - `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1` - ) - .get(thread_id, checkpoint_ns) as CheckpointRow; + args.push(checkpoint_id); } + const row = this.db.prepare(sql).get(...args) as CheckpointRow; if (row === undefined) { return undefined; } @@ -146,31 +188,36 @@ CREATE TABLE IF NOT EXISTS writes ( ) { throw new Error("Missing thread_id or checkpoint_id"); } - // find any pending writes - const pendingWritesRows = this.db - .prepare( - `SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` - ) - .all( - finalConfig.configurable.thread_id.toString(), - checkpoint_ns, - finalConfig.configurable.checkpoint_id.toString() - ) as WritesRow[]; + const pendingWrites = await Promise.all( - pendingWritesRows.map(async (row) => { - return [ - row.task_id, - row.channel, - await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""), - ] as [string, string, unknown]; - }) + (JSON.parse(row.pending_writes) as PendingWriteColumn[]).map( + async (write) => { + return [ + write.task_id, + write.channel, + await this.serde.loadsTyped( + write.type ?? "json", + write.value ?? "" + ), + ] as [string, string, unknown]; + } + ) + ); + + const pending_sends = await Promise.all( + (JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) => + this.serde.loadsTyped(send.type ?? "json", send.value ?? "") + ) ); + + const checkpoint = { + ...(await this.serde.loadsTyped(row.type ?? "json", row.checkpoint)), + pending_sends, + } as Checkpoint; + return { + checkpoint, config: finalConfig, - checkpoint: (await this.serde.loadsTyped( - row.type ?? "json", - row.checkpoint - )) as Checkpoint, metadata: (await this.serde.loadsTyped( row.type ?? "json", row.metadata @@ -196,17 +243,46 @@ CREATE TABLE IF NOT EXISTS writes ( this.setup(); const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns; - - let sql = - `SELECT\n` + - " thread_id,\n" + - " checkpoint_ns,\n" + - " checkpoint_id,\n" + - " parent_checkpoint_id,\n" + - " type,\n" + - " checkpoint,\n" + - " metadata\n" + - "FROM checkpoints\n"; + let sql = ` + SELECT + thread_id, + checkpoint_ns, + checkpoint_id, + parent_checkpoint_id, + type, + checkpoint, + metadata, + ( + SELECT + json_group_array( + json_object( + 'task_id', pw.task_id, + 'channel', pw.channel, + 'type', pw.type, + 'value', CAST(pw.value AS TEXT) + ) + ) + FROM writes as pw + WHERE pw.thread_id = checkpoints.thread_id + AND pw.checkpoint_ns = checkpoints.checkpoint_ns + AND pw.checkpoint_id = checkpoints.checkpoint_id + ) as pending_writes, + ( + SELECT + json_group_array( + json_object( + 'type', ps.type, + 'value', CAST(ps.value AS TEXT) + ) + ) + FROM writes as ps + WHERE ps.thread_id = checkpoints.thread_id + AND ps.checkpoint_ns = checkpoints.checkpoint_ns + AND ps.checkpoint_id = checkpoints.parent_checkpoint_id + AND ps.channel = '${TASKS}' + ORDER BY ps.idx + ) as pending_sends + FROM checkpoints\n`; const whereClause: string[] = []; @@ -260,6 +336,32 @@ CREATE TABLE IF NOT EXISTS writes ( if (rows) { for (const row of rows) { + const pendingWrites = await Promise.all( + (JSON.parse(row.pending_writes) as PendingWriteColumn[]).map( + async (write) => { + return [ + write.task_id, + write.channel, + await this.serde.loadsTyped( + write.type ?? "json", + write.value ?? "" + ), + ] as [string, string, unknown]; + } + ) + ); + + const pending_sends = await Promise.all( + (JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) => + this.serde.loadsTyped(send.type ?? "json", send.value ?? "") + ) + ); + + const checkpoint = { + ...(await this.serde.loadsTyped(row.type ?? "json", row.checkpoint)), + pending_sends, + } as Checkpoint; + yield { config: { configurable: { @@ -268,10 +370,7 @@ CREATE TABLE IF NOT EXISTS writes ( checkpoint_id: row.checkpoint_id, }, }, - checkpoint: (await this.serde.loadsTyped( - row.type ?? "json", - row.checkpoint - )) as Checkpoint, + checkpoint, metadata: (await this.serde.loadsTyped( row.type ?? "json", row.metadata @@ -285,6 +384,7 @@ CREATE TABLE IF NOT EXISTS writes ( }, } : undefined, + pendingWrites, }; } } @@ -297,7 +397,19 @@ CREATE TABLE IF NOT EXISTS writes ( ): Promise { this.setup(); - const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); + if (!config.configurable) { + throw new Error("Empty configuration supplied."); + } + + if (!config.configurable?.thread_id) { + throw new Error("Missing thread_id field in config.configurable."); + } + + const preparedCheckpoint: Partial = copyCheckpoint(checkpoint); + delete preparedCheckpoint.pending_sends; + + const [type1, serializedCheckpoint] = + this.serde.dumpsTyped(preparedCheckpoint); const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata); if (type1 !== type2) { throw new Error( @@ -336,6 +448,18 @@ CREATE TABLE IF NOT EXISTS writes ( ): Promise { this.setup(); + if (!config.configurable) { + throw new Error("Empty configuration supplied."); + } + + if (!config.configurable?.thread_id) { + throw new Error("Missing thread_id field in config.configurable."); + } + + if (!config.configurable?.checkpoint_id) { + throw new Error("Missing checkpoint_id field in config.configurable."); + } + const stmt = this.db.prepare(` INSERT OR REPLACE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) diff --git a/libs/checkpoint-validation/src/spec/list.ts b/libs/checkpoint-validation/src/spec/list.ts index d9f63097..41995bed 100644 --- a/libs/checkpoint-validation/src/spec/list.ts +++ b/libs/checkpoint-validation/src/spec/list.ts @@ -116,14 +116,11 @@ export function listTests( } else { expect(actualTuplesMap.size).toEqual(expectedTuplesMap.size); for (const [key, value] of actualTuplesMap.entries()) { - // TODO: MongoDBSaver and SQLiteSaver don't return pendingWrites on list, so we need to special case them + // TODO: MongoDBSaver doesn't return pendingWrites on list, so we need to special case them // see: https://github.com/langchain-ai/langgraphjs/issues/589 - // see: https://github.com/langchain-ai/langgraphjs/issues/590 const checkpointerIncludesPendingWritesOnList = initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-mongodb" && - initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-sqlite"; + "@langchain/langgraph-checkpoint-mongodb"; const expectedTuple = expectedTuplesMap.get(key); if (!checkpointerIncludesPendingWritesOnList) {