From d17aa376f4c113d7979051f60013518a37a5706b Mon Sep 17 00:00:00 2001 From: Ben Burns <803016+benjamincburns@users.noreply.github.com> Date: Thu, 24 Oct 2024 22:51:58 +1300 Subject: [PATCH] fix(checkpoint-sqlite): populate pending_sends and pendingWrites fixes #590 --- libs/checkpoint-sqlite/src/index.ts | 255 ++++++++++++++++---- libs/checkpoint-validation/src/spec/list.ts | 7 +- 2 files changed, 210 insertions(+), 52 deletions(-) diff --git a/libs/checkpoint-sqlite/src/index.ts b/libs/checkpoint-sqlite/src/index.ts index 6740bb22..b0ca91c7 100644 --- a/libs/checkpoint-sqlite/src/index.ts +++ b/libs/checkpoint-sqlite/src/index.ts @@ -8,6 +8,8 @@ import { type SerializerProtocol, type PendingWrite, type CheckpointMetadata, + TASKS, + copyCheckpoint, } from "@langchain/langgraph-checkpoint"; interface CheckpointRow { @@ -18,17 +20,20 @@ interface CheckpointRow { checkpoint_id: string; checkpoint_ns?: string; type?: string; + pending_writes: string; + pending_sends: string; } -interface WritesRow { - thread_id: string; - checkpoint_ns: string; - checkpoint_id: string; +interface PendingWriteColumn { task_id: string; - idx: number; channel: string; - type?: string; - value?: string; + type: string; + value: string; +} + +interface PendingSendColumn { + type: string; + value: string; } // In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys @@ -113,20 +118,93 @@ CREATE TABLE IF NOT EXISTS writes ( checkpoint_ns = "", checkpoint_id, } = config.configurable ?? {}; - let row: CheckpointRow; + const sql = checkpoint_id + ? 
"SELECT\n" + + " thread_id,\n" + + " checkpoint_ns,\n" + + " checkpoint_id,\n" + + " parent_checkpoint_id,\n" + + " type,\n" + + " checkpoint,\n" + + " metadata,\n" + + " (\n" + + " SELECT\n" + + " json_group_array(\n" + + " json_object(\n" + + " 'task_id', pw.task_id,\n" + + " 'channel', pw.channel,\n" + + " 'type', pw.type,\n" + + " 'value', CAST(pw.value AS TEXT)\n" + + " )\n" + + " )\n" + + " FROM writes as pw\n" + + " WHERE pw.thread_id = checkpoints.thread_id\n" + + " AND pw.checkpoint_ns = checkpoints.checkpoint_ns\n" + + " AND pw.checkpoint_id = checkpoints.checkpoint_id\n" + + " ) as pending_writes,\n" + + " (\n" + + " SELECT\n" + + " json_group_array(\n" + + " json_object(\n" + + " 'type', ps.type,\n" + + " 'value', CAST(ps.value AS TEXT)\n" + + " )\n" + + " )\n" + + " FROM writes as ps\n" + + " WHERE ps.thread_id = checkpoints.thread_id\n" + + " AND ps.checkpoint_ns = checkpoints.checkpoint_ns\n" + + " AND ps.checkpoint_id = checkpoints.parent_checkpoint_id\n" + + ` AND ps.channel = '${TASKS}'\n` + + " ORDER BY ps.idx\n" + + " ) as pending_sends\n" + + "FROM checkpoints\n" + + "WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?" 
+ : "SELECT\n" + + " thread_id,\n" + + " checkpoint_ns,\n" + + " checkpoint_id,\n" + + " parent_checkpoint_id,\n" + + " type,\n" + + " checkpoint,\n" + + " metadata,\n" + + " (\n" + + " SELECT\n" + + " json_group_array(\n" + + " json_object(\n" + + " 'task_id', pw.task_id,\n" + + " 'channel', pw.channel,\n" + + " 'type', pw.type,\n" + + " 'value', CAST(pw.value AS TEXT)\n" + + " )\n" + + " )\n" + + " FROM writes as pw\n" + + " WHERE pw.thread_id = checkpoints.thread_id\n" + + " AND pw.checkpoint_ns = checkpoints.checkpoint_ns\n" + + " AND pw.checkpoint_id = checkpoints.checkpoint_id\n" + + " ) as pending_writes,\n" + + " (\n" + + " SELECT\n" + + " json_group_array(\n" + + " json_object(\n" + + " 'type', ps.type,\n" + + " 'value', CAST(ps.value AS TEXT)\n" + + " )\n" + + " )\n" + + " FROM writes as ps\n" + + " WHERE ps.thread_id = checkpoints.thread_id\n" + + " AND ps.checkpoint_ns = checkpoints.checkpoint_ns\n" + + " AND ps.checkpoint_id = checkpoints.parent_checkpoint_id\n" + + ` AND ps.channel = '${TASKS}'\n` + + " ORDER BY ps.idx\n" + + " ) as pending_sends\n" + + "FROM checkpoints\n" + + "WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1"; + + const args = [thread_id, checkpoint_ns]; if (checkpoint_id) { - row = this.db - .prepare( - `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` - ) - .get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow; - } else { - row = this.db - .prepare( - `SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? 
ORDER BY checkpoint_id DESC LIMIT 1` - ) - .get(thread_id, checkpoint_ns) as CheckpointRow; + args.push(checkpoint_id); } + const row = this.db.prepare(sql).get(...args) as CheckpointRow; if (row === undefined) { return undefined; } @@ -146,31 +224,36 @@ CREATE TABLE IF NOT EXISTS writes ( ) { throw new Error("Missing thread_id or checkpoint_id"); } - // find any pending writes - const pendingWritesRows = this.db - .prepare( - `SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?` - ) - .all( - finalConfig.configurable.thread_id.toString(), - checkpoint_ns, - finalConfig.configurable.checkpoint_id.toString() - ) as WritesRow[]; + const pendingWrites = await Promise.all( - pendingWritesRows.map(async (row) => { - return [ - row.task_id, - row.channel, - await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""), - ] as [string, string, unknown]; - }) + (JSON.parse(row.pending_writes) as PendingWriteColumn[]).map( + async (write) => { + return [ + write.task_id, + write.channel, + await this.serde.loadsTyped( + write.type ?? "json", + write.value ?? "" + ), + ] as [string, string, unknown]; + } + ) + ); + + const pending_sends = await Promise.all( + (JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) => + this.serde.loadsTyped(send.type ?? "json", send.value ?? "") + ) ); + + const checkpoint = { + ...(await this.serde.loadsTyped(row.type ?? "json", row.checkpoint)), + pending_sends, + } as Checkpoint; + return { + checkpoint, config: finalConfig, - checkpoint: (await this.serde.loadsTyped( - row.type ?? "json", - row.checkpoint - )) as Checkpoint, metadata: (await this.serde.loadsTyped( row.type ?? "json", row.metadata @@ -196,7 +279,7 @@ CREATE TABLE IF NOT EXISTS writes ( this.setup(); const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns; - + // SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? 
AND checkpoint_ns = ? AND checkpoint_id = ? let sql = `SELECT\n` + " thread_id,\n" + @@ -205,7 +288,37 @@ CREATE TABLE IF NOT EXISTS writes ( " parent_checkpoint_id,\n" + " type,\n" + " checkpoint,\n" + - " metadata\n" + + " metadata,\n" + + " (\n" + + " SELECT\n" + + " json_group_array(\n" + + " json_object(\n" + + " 'task_id', pw.task_id,\n" + + " 'channel', pw.channel,\n" + + " 'type', pw.type,\n" + + " 'value', CAST(pw.value AS TEXT)\n" + + " )\n" + + " )\n" + + " FROM writes as pw\n" + + " WHERE pw.thread_id = checkpoints.thread_id\n" + + " AND pw.checkpoint_ns = checkpoints.checkpoint_ns\n" + + " AND pw.checkpoint_id = checkpoints.checkpoint_id\n" + + " ) as pending_writes,\n" + + " (\n" + + " SELECT\n" + + " json_group_array(\n" + + " json_object(\n" + + " 'type', ps.type,\n" + + " 'value', CAST(ps.value AS TEXT)\n" + + " )\n" + + " )\n" + + " FROM writes as ps\n" + + " WHERE ps.thread_id = checkpoints.thread_id\n" + + " AND ps.checkpoint_ns = checkpoints.checkpoint_ns\n" + + " AND ps.checkpoint_id = checkpoints.parent_checkpoint_id\n" + + ` AND ps.channel = '${TASKS}'\n` + + " ORDER BY ps.idx\n" + + " ) as pending_sends\n" + "FROM checkpoints\n"; const whereClause: string[] = []; @@ -260,6 +373,32 @@ CREATE TABLE IF NOT EXISTS writes ( if (rows) { for (const row of rows) { + const pendingWrites = await Promise.all( + (JSON.parse(row.pending_writes) as PendingWriteColumn[]).map( + async (write) => { + return [ + write.task_id, + write.channel, + await this.serde.loadsTyped( + write.type ?? "json", + write.value ?? "" + ), + ] as [string, string, unknown]; + } + ) + ); + + const pending_sends = await Promise.all( + (JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) => + this.serde.loadsTyped(send.type ?? "json", send.value ?? "") + ) + ); + + const checkpoint = { + ...(await this.serde.loadsTyped(row.type ?? 
"json", row.checkpoint)), + pending_sends, + } as Checkpoint; + yield { config: { configurable: { @@ -268,10 +407,7 @@ CREATE TABLE IF NOT EXISTS writes ( checkpoint_id: row.checkpoint_id, }, }, - checkpoint: (await this.serde.loadsTyped( - row.type ?? "json", - row.checkpoint - )) as Checkpoint, + checkpoint, metadata: (await this.serde.loadsTyped( row.type ?? "json", row.metadata @@ -285,6 +421,7 @@ CREATE TABLE IF NOT EXISTS writes ( }, } : undefined, + pendingWrites, }; } } @@ -297,7 +434,19 @@ CREATE TABLE IF NOT EXISTS writes ( ): Promise { this.setup(); - const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint); + if (!config.configurable) { + throw new Error("Empty configuration supplied."); + } + + if (!config.configurable?.thread_id) { + throw new Error("Missing thread_id field in config.configurable."); + } + + const preparedCheckpoint: Partial = copyCheckpoint(checkpoint); + delete preparedCheckpoint.pending_sends; + + const [type1, serializedCheckpoint] = + this.serde.dumpsTyped(preparedCheckpoint); const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata); if (type1 !== type2) { throw new Error( @@ -336,6 +485,18 @@ CREATE TABLE IF NOT EXISTS writes ( ): Promise { this.setup(); + if (!config.configurable) { + throw new Error("Empty configuration supplied."); + } + + if (!config.configurable?.thread_id) { + throw new Error("Missing thread_id field in config.configurable."); + } + + if (!config.configurable?.checkpoint_id) { + throw new Error("Missing checkpoint_id field in config.configurable."); + } + const stmt = this.db.prepare(` INSERT OR REPLACE INTO writes (thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value) diff --git a/libs/checkpoint-validation/src/spec/list.ts b/libs/checkpoint-validation/src/spec/list.ts index d9f63097..41995bed 100644 --- a/libs/checkpoint-validation/src/spec/list.ts +++ b/libs/checkpoint-validation/src/spec/list.ts @@ -116,14 +116,11 @@ export function listTests( } else { 
expect(actualTuplesMap.size).toEqual(expectedTuplesMap.size); for (const [key, value] of actualTuplesMap.entries()) { - // TODO: MongoDBSaver and SQLiteSaver don't return pendingWrites on list, so we need to special case them + // TODO: MongoDBSaver doesn't return pendingWrites on list, so we need to special-case it // see: https://github.com/langchain-ai/langgraphjs/issues/589 - // see: https://github.com/langchain-ai/langgraphjs/issues/590 const checkpointerIncludesPendingWritesOnList = initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-mongodb" && - initializer.checkpointerName !== - "@langchain/langgraph-checkpoint-sqlite"; + "@langchain/langgraph-checkpoint-mongodb"; const expectedTuple = expectedTuplesMap.get(key); if (!checkpointerIncludesPendingWritesOnList) {