Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(checkpoint-sqlite): populate pending_sends and pendingWrites #631

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 180 additions & 56 deletions libs/checkpoint-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import {
type SerializerProtocol,
type PendingWrite,
type CheckpointMetadata,
TASKS,
copyCheckpoint,
} from "@langchain/langgraph-checkpoint";

interface CheckpointRow {
Expand All @@ -18,17 +20,20 @@ interface CheckpointRow {
checkpoint_id: string;
checkpoint_ns?: string;
type?: string;
pending_writes: string;
pending_sends: string;
}

interface WritesRow {
thread_id: string;
checkpoint_ns: string;
checkpoint_id: string;
interface PendingWriteColumn {
task_id: string;
idx: number;
channel: string;
type?: string;
value?: string;
type: string;
value: string;
}

/**
 * One element of the JSON array that the `pending_sends` sub-select builds
 * with `json_group_array(json_object('type', ..., 'value', ...))`.
 *
 * After `JSON.parse`, each element is handed to `serde.loadsTyped(type, value)`
 * to rebuild the original send payload.
 */
interface PendingSendColumn {
  /** Write payload, selected as `CAST(ps.value AS TEXT)` in the query. */
  value: string;
  /** Serialization type tag, passed as the first argument to `loadsTyped`. */
  type: string;
}

// In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys
Expand Down Expand Up @@ -113,20 +118,57 @@ CREATE TABLE IF NOT EXISTS writes (
checkpoint_ns = "",
checkpoint_id,
} = config.configurable ?? {};
let row: CheckpointRow;
const sql = `
SELECT
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
(
SELECT
json_group_array(
json_object(
'task_id', pw.task_id,
'channel', pw.channel,
'type', pw.type,
'value', CAST(pw.value AS TEXT)
)
)
FROM writes as pw
WHERE pw.thread_id = checkpoints.thread_id
AND pw.checkpoint_ns = checkpoints.checkpoint_ns
AND pw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
SELECT
json_group_array(
json_object(
'type', ps.type,
'value', CAST(ps.value AS TEXT)
)
)
FROM writes as ps
WHERE ps.thread_id = checkpoints.thread_id
AND ps.checkpoint_ns = checkpoints.checkpoint_ns
AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
AND ps.channel = '${TASKS}'
ORDER BY ps.idx
) as pending_sends
FROM checkpoints
WHERE thread_id = ? AND checkpoint_ns = ? ${
checkpoint_id
? "AND checkpoint_id = ?"
: "ORDER BY checkpoint_id DESC LIMIT 1"
}`;

const args = [thread_id, checkpoint_ns];
if (checkpoint_id) {
row = this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`
)
.get(thread_id, checkpoint_ns, checkpoint_id) as CheckpointRow;
} else {
row = this.db
.prepare(
`SELECT thread_id, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? AND checkpoint_ns = ? ORDER BY checkpoint_id DESC LIMIT 1`
)
.get(thread_id, checkpoint_ns) as CheckpointRow;
args.push(checkpoint_id);
}
const row = this.db.prepare(sql).get(...args) as CheckpointRow;
if (row === undefined) {
return undefined;
}
Expand All @@ -146,31 +188,36 @@ CREATE TABLE IF NOT EXISTS writes (
) {
throw new Error("Missing thread_id or checkpoint_id");
}
// find any pending writes
const pendingWritesRows = this.db
.prepare(
`SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?`
)
.all(
finalConfig.configurable.thread_id.toString(),
checkpoint_ns,
finalConfig.configurable.checkpoint_id.toString()
) as WritesRow[];

const pendingWrites = await Promise.all(
pendingWritesRows.map(async (row) => {
return [
row.task_id,
row.channel,
await this.serde.loadsTyped(row.type ?? "json", row.value ?? ""),
] as [string, string, unknown];
})
(JSON.parse(row.pending_writes) as PendingWriteColumn[]).map(
async (write) => {
return [
write.task_id,
write.channel,
await this.serde.loadsTyped(
write.type ?? "json",
write.value ?? ""
),
] as [string, string, unknown];
}
)
);

const pending_sends = await Promise.all(
(JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) =>
this.serde.loadsTyped(send.type ?? "json", send.value ?? "")
)
);

const checkpoint = {
...(await this.serde.loadsTyped(row.type ?? "json", row.checkpoint)),
pending_sends,
} as Checkpoint;

return {
checkpoint,
config: finalConfig,
checkpoint: (await this.serde.loadsTyped(
row.type ?? "json",
row.checkpoint
)) as Checkpoint,
metadata: (await this.serde.loadsTyped(
row.type ?? "json",
row.metadata
Expand All @@ -196,17 +243,46 @@ CREATE TABLE IF NOT EXISTS writes (
this.setup();
const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns;

let sql =
`SELECT\n` +
" thread_id,\n" +
" checkpoint_ns,\n" +
" checkpoint_id,\n" +
" parent_checkpoint_id,\n" +
" type,\n" +
" checkpoint,\n" +
" metadata\n" +
"FROM checkpoints\n";
let sql = `
SELECT
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
(
SELECT
json_group_array(
json_object(
'task_id', pw.task_id,
'channel', pw.channel,
'type', pw.type,
'value', CAST(pw.value AS TEXT)
)
)
FROM writes as pw
WHERE pw.thread_id = checkpoints.thread_id
AND pw.checkpoint_ns = checkpoints.checkpoint_ns
AND pw.checkpoint_id = checkpoints.checkpoint_id
) as pending_writes,
(
SELECT
json_group_array(
json_object(
'type', ps.type,
'value', CAST(ps.value AS TEXT)
)
)
FROM writes as ps
WHERE ps.thread_id = checkpoints.thread_id
AND ps.checkpoint_ns = checkpoints.checkpoint_ns
AND ps.checkpoint_id = checkpoints.parent_checkpoint_id
AND ps.channel = '${TASKS}'
ORDER BY ps.idx
) as pending_sends
FROM checkpoints\n`;

const whereClause: string[] = [];

Expand Down Expand Up @@ -260,6 +336,32 @@ CREATE TABLE IF NOT EXISTS writes (

if (rows) {
for (const row of rows) {
const pendingWrites = await Promise.all(
(JSON.parse(row.pending_writes) as PendingWriteColumn[]).map(
async (write) => {
return [
write.task_id,
write.channel,
await this.serde.loadsTyped(
write.type ?? "json",
write.value ?? ""
),
] as [string, string, unknown];
}
)
);

const pending_sends = await Promise.all(
(JSON.parse(row.pending_sends) as PendingSendColumn[]).map((send) =>
this.serde.loadsTyped(send.type ?? "json", send.value ?? "")
)
);

const checkpoint = {
...(await this.serde.loadsTyped(row.type ?? "json", row.checkpoint)),
pending_sends,
} as Checkpoint;

yield {
config: {
configurable: {
Expand All @@ -268,10 +370,7 @@ CREATE TABLE IF NOT EXISTS writes (
checkpoint_id: row.checkpoint_id,
},
},
checkpoint: (await this.serde.loadsTyped(
row.type ?? "json",
row.checkpoint
)) as Checkpoint,
checkpoint,
metadata: (await this.serde.loadsTyped(
row.type ?? "json",
row.metadata
Expand All @@ -285,6 +384,7 @@ CREATE TABLE IF NOT EXISTS writes (
},
}
: undefined,
pendingWrites,
};
}
}
Expand All @@ -297,7 +397,19 @@ CREATE TABLE IF NOT EXISTS writes (
): Promise<RunnableConfig> {
this.setup();

const [type1, serializedCheckpoint] = this.serde.dumpsTyped(checkpoint);
if (!config.configurable) {
throw new Error("Empty configuration supplied.");
}

if (!config.configurable?.thread_id) {
throw new Error("Missing thread_id field in config.configurable.");
}

const preparedCheckpoint: Partial<Checkpoint> = copyCheckpoint(checkpoint);
delete preparedCheckpoint.pending_sends;

const [type1, serializedCheckpoint] =
this.serde.dumpsTyped(preparedCheckpoint);
const [type2, serializedMetadata] = this.serde.dumpsTyped(metadata);
if (type1 !== type2) {
throw new Error(
Expand Down Expand Up @@ -336,6 +448,18 @@ CREATE TABLE IF NOT EXISTS writes (
): Promise<void> {
this.setup();

if (!config.configurable) {
throw new Error("Empty configuration supplied.");
}

if (!config.configurable?.thread_id) {
throw new Error("Missing thread_id field in config.configurable.");
}

if (!config.configurable?.checkpoint_id) {
throw new Error("Missing checkpoint_id field in config.configurable.");
}

const stmt = this.db.prepare(`
INSERT OR REPLACE INTO writes
(thread_id, checkpoint_ns, checkpoint_id, task_id, idx, channel, type, value)
Expand Down
7 changes: 2 additions & 5 deletions libs/checkpoint-validation/src/spec/list.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,11 @@ export function listTests<T extends BaseCheckpointSaver>(
} else {
expect(actualTuplesMap.size).toEqual(expectedTuplesMap.size);
for (const [key, value] of actualTuplesMap.entries()) {
// TODO: MongoDBSaver and SQLiteSaver don't return pendingWrites on list, so we need to special case them
// TODO: MongoDBSaver doesn't return pendingWrites on list, so we need to special case them
// see: https://github.com/langchain-ai/langgraphjs/issues/589
// see: https://github.com/langchain-ai/langgraphjs/issues/590
const checkpointerIncludesPendingWritesOnList =
initializer.checkpointerName !==
"@langchain/langgraph-checkpoint-mongodb" &&
initializer.checkpointerName !==
"@langchain/langgraph-checkpoint-sqlite";
"@langchain/langgraph-checkpoint-mongodb";

const expectedTuple = expectedTuplesMap.get(key);
if (!checkpointerIncludesPendingWritesOnList) {
Expand Down
Loading