Skip to content

Commit

Permalink
fix(checkpoint-sqlite): populate pending_sends and pendingWrites
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamincburns committed Oct 25, 2024
1 parent cc6625c commit 3562068
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 61 deletions.
236 changes: 180 additions & 56 deletions libs/checkpoint-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import {
type SerializerProtocol,
type PendingWrite,
type CheckpointMetadata,
TASKS,
copyCheckpoint,
} from "@langchain/langgraph-checkpoint";

interface CheckpointRow {
Expand All @@ -18,17 +20,20 @@ interface CheckpointRow {
checkpoint_id: string;
checkpoint_ns?: string;
type?: string;
pending_writes: string;
pending_sends: string;
}

interface WritesRow {
thread_id: string;
checkpoint_ns: string;
checkpoint_id: string;
interface PendingWriteColumn {
task_id: string;
idx: number;
channel: string;
type?: string;
value?: string;
type: string;
value: string;
}

interface PendingSendColumn {
type: string;
value: string;
}

// In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys
Expand Down Expand Up @@ -113,20 +118,57 @@ CREATE TABLE IF NOT EXISTS writes (
checkpoint_ns = "",
checkpoint_id,
} = config.configurable ?? {};
let row: CheckpointRow;
const sql = `
SELECT
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
(
SELECT
json_group_array(
json_object(
'task_id', pw.task_id,
'channel', pw.channel,
'type', pw.type,
'value', CAST(pw.value AS TEXT)
)
)
FROM writes as pw
WHERE pw.thread_id = checkpoints.thread_id
AND pw.checkpoint_ns = checkpoints.checkpoint_ns
AND pw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
SELECT
json_group_array(
json_object(
'type', ps.type,
'value', CAST(ps.value AS TEXT)
)
)
FROM writes as ps
WHERE ps.thread_id = checkpoints.thread_id
AND ps.checkpoint_ns = checkpoints.checkpoint_ns
AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
AND ps.channel = '${TASKS}'
ORDER BY ps.idx
) as pending_sends
FROM checkpoints
WHERE thread_id = ? AND checkpoint_ns = ? ${
checkpoint_id
? "AND checkpoint_id = ?"
: "ORDER BY checkpoint_id DESC LIMIT 1"
}`;

const args = [thread_id, checkpoint_ns];
if (checkpoint_id) {
row = this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`
)
.get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow;
} else {
row = this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1`
)
.get(thread_id, checkpoint_ns) as CheckpointRow;
args.push(checkpoint_id);
}
const row = this.db.prepare(sql).get(...args) as CheckpointRow;
if (row === undefined) {
return undefined;
}
Expand All @@ -146,31 +188,36 @@ CREATE TABLE IF NOT EXISTS writes (
) {
throw new Error("Missing thread_id or checkpoint_id");
}
// find any pending writes
const pendingWritesRows = this.db
.prepare(
`SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`
)
.all(
finalConfig.configurable.thread_id.toString(),
checkpoint_ns,
finalConfig.configurable.checkpoint_id.toString()
) as WritesRow[];

const pendingWrites = await Promise.all(
pendingWritesRows.map(async (row) => {
return [
row.task_id,
row.channel,
await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""),
] as [string, string, unknown];
})
(JSON.parse(row.pending_writes) as PendingWriteColumn[]).map(
async (write) => {
return [
write.task_id,
write.channel,
await this.serde.loadsTyped(
write.type ?? "json",
write.value ?? ""
),
] as [string, string, unknown];
}
)
);

const pending_sends = await Promise.all(
(JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) =>
this.serde.loadsTyped(send.type ?? "json", send.value ?? "")
)
);

const checkpoint = {
...(await this.serde.loadsTyped(row.type ?? "json", row.checkpoint)),
pending_sends,
} as Checkpoint;

return {
checkpoint,
config: finalConfig,
checkpoint: (await this.serde.loadsTyped(
row.type ?? "json",
row.checkpoint
)) as Checkpoint,
metadata: (await this.serde.loadsTyped(
row.type ?? "json",
row.metadata
Expand All @@ -196,17 +243,46 @@ CREATE TABLE IF NOT EXISTS writes (
this.setup();
const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns;

let sql =
`SELECT\n` +
" thread_id,\n" +
" checkpoint_ns,\n" +
" checkpoint_id,\n" +
" parent_checkpoint_id,\n" +
" type,\n" +
" checkpoint,\n" +
" metadata\n" +
"FROM checkpoints\n";
let sql = `
SELECT
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
(
SELECT
json_group_array(
json_object(
'task_id', pw.task_id,
'channel', pw.channel,
'type', pw.type,
'value', CAST(pw.value AS TEXT)
)
)
FROM writes as pw
WHERE pw.thread_id = checkpoints.thread_id
AND pw.checkpoint_ns = checkpoints.checkpoint_ns
AND pw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
SELECT
json_group_array(
json_object(
'type', ps.type,
'value', CAST(ps.value AS TEXT)
)
)
FROM writes as ps
WHERE ps.thread_id = checkpoints.thread_id
AND ps.checkpoint_ns = checkpoints.checkpoint_ns
AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
AND ps.channel = '${TASKS}'
ORDER BY ps.idx
) as pending_sends
FROM checkpoints\n`;

const whereClause: string[] = [];

Expand Down Expand Up @@ -260,6 +336,32 @@ CREATE TABLE IF NOT EXISTS writes (

if (rows) {
for (const row of rows) {
const pendingWrites = await Promise.all(
(JSON.parse(row.pending_writes) as PendingWriteColumn[]).map(
async (write) => {
return [
write.task_id,
write.channel,
await this.serde.loadsTyped(
write.type ?? "json",
write.value ?? ""
),
] as [string, string, unknown];
}
)
);

const pending_sends = await Promise.all(
(JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) =>
this.serde.loadsTyped(send.type ?? "json", send.value ?? "")
)
);

const checkpoint = {
...(await this.serde.loadsTyped(row.type ?? "json", row.checkpoint)),
pending_sends,
} as Checkpoint;

yield {
config: {
configurable: {
Expand All @@ -268,10 +370,7 @@ CREATE TABLE IF NOT EXISTS writes (
checkpoint_id: row.checkpoint_id,
},
},
checkpoint: (await this.serde.loadsTyped(
row.type ?? "json",
row.checkpoint
)) as Checkpoint,
checkpoint,
metadata: (await this.serde.loadsTyped(
row.type ?? "json",
row.metadata
Expand All @@ -285,6 +384,7 @@ CREATE TABLE IF NOT EXISTS writes (
},
}
: undefined,
pendingWrites,
};
}
}
Expand All @@ -297,7 +397,19 @@ CREATE TABLE IF NOT EXISTS writes (
): Promise<RunnableConfig> {
this.setup();

const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint);
if (!config.configurable) {
throw new Error("Empty configuration supplied.");
}

if (!config.configurable?.thread_id) {
throw new Error("Missing thread_id field in config.configurable.");
}

const preparedCheckpoint: Partial<Checkpoint> = copyCheckpoint(checkpoint);
delete preparedCheckpoint.pending_sends;

const [type1, serializedCheckpoint] =
this.serde.dumpsTyped(preparedCheckpoint);
const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata);
if (type1 !== type2) {
throw new Error(
Expand Down Expand Up @@ -336,6 +448,18 @@ CREATE TABLE IF NOT EXISTS writes (
): Promise<void> {
this.setup();

if (!config.configurable) {
throw new Error("Empty configuration supplied.");
}

if (!config.configurable?.thread_id) {
throw new Error("Missing thread_id field in config.configurable.");
}

if (!config.configurable?.checkpoint_id) {
throw new Error("Missing checkpoint_id field in config.configurable.");
}

const stmt = this.db.prepare(`
INSERT OR REPLACE INTO writes
(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
Expand Down
7 changes: 2 additions & 5 deletions libs/checkpoint-validation/src/spec/list.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,11 @@ export function listTests<T extends BaseCheckpointSaver>(
} else {
expect(actualTuplesMap.size).toEqual(expectedTuplesMap.size);
for (const [key, value] of actualTuplesMap.entries()) {
// TODO: MongoDBSaver and SQLiteSaver don't return pendingWrites on list, so we need to special case them
// TODO: MongoDBSaver doesn't return pendingWrites on list, so we need to special case them
// see: https://github.com/langchain-ai/langgraphjs/issues/589
// see: https://github.com/langchain-ai/langgraphjs/issues/590
const checkpointerIncludesPendingWritesOnList =
initializer.checkpointerName !==
"@langchain/langgraph-checkpoint-mongodb" &&
initializer.checkpointerName !==
"@langchain/langgraph-checkpoint-sqlite";
"@langchain/langgraph-checkpoint-mongodb";

const expectedTuple = expectedTuplesMap.get(key);
if (!checkpointerIncludesPendingWritesOnList) {
Expand Down

0 comments on commit 3562068

Please sign in to comment.