Skip to content

Commit

Permalink
fix(checkpoint-sqlite): list method bug fixes (#582)
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamincburns authored Oct 20, 2024
1 parent d17226c commit b76700e
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 15 deletions.
92 changes: 84 additions & 8 deletions libs/checkpoint-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ interface WritesRow {
value?: string;
}

// In the `SqliteSaver.list` method, we need to sanitize the `options.filter` argument to ensure it only contains keys
// that are part of the `CheckpointMetadata` type. The lines below ensure that we get compile-time errors if the list
// of keys that we use is out of sync with the `CheckpointMetadata` type.
const checkpointMetadataKeys = ["source", "step", "writes", "parents"] as const;

type CheckKeys<T, K extends readonly (keyof T)[]> = [K[number]] extends [
keyof T
]
? [keyof T] extends [K[number]]
? K
: never
: never;

function validateKeys<T, K extends readonly (keyof T)[]>(
keys: CheckKeys<T, K>
): K {
return keys;
}

// If this line fails to compile, the list of keys that we use in the `SqliteSaver.list` method is out of sync with the
// `CheckpointMetadata` type. In that case, just update `checkpointMetadataKeys` to contain all the keys in
// `CheckpointMetadata`
const validCheckpointMetadataKeys = validateKeys<
CheckpointMetadata,
typeof checkpointMetadataKeys
>(checkpointMetadataKeys);

export class SqliteSaver extends BaseCheckpointSaver {
db: DatabaseType;

Expand Down Expand Up @@ -165,19 +192,68 @@ CREATE TABLE IF NOT EXISTS writes (
config: RunnableConfig,
options?: CheckpointListOptions
): AsyncGenerator<CheckpointTuple> {
const { limit, before } = options ?? {};
const { limit, before, filter } = options ?? {};
this.setup();
const thread_id = config.configurable?.thread_id;
let sql = `SELECT thread_id, checkpoint_ns, checkpoint_id, parent_checkpoint_id, type, checkpoint, metadata FROM checkpoints WHERE thread_id = ? ${
before ? "AND checkpoint_id < ?" : ""
} ORDER BY checkpoint_id DESC`;
if (limit) {
sql += ` LIMIT ${limit}`;
const checkpoint_ns = config.configurable?.checkpoint_ns;

let sql =
`SELECT\n` +
" thread_id,\n" +
" checkpoint_ns,\n" +
" checkpoint_id,\n" +
" parent_checkpoint_id,\n" +
" type,\n" +
" checkpoint,\n" +
" metadata\n" +
"FROM checkpoints\n";

const whereClause: string[] = [];

if (thread_id) {
whereClause.push("thread_id = ?");
}

if (checkpoint_ns !== undefined && checkpoint_ns !== null) {
whereClause.push("checkpoint_ns = ?");
}

if (before?.configurable?.checkpoint_id !== undefined) {
whereClause.push("checkpoint_id < ?");
}
const args = [thread_id, before?.configurable?.checkpoint_id].filter(
Boolean

const sanitizedFilter = Object.fromEntries(
Object.entries(filter ?? {}).filter(
([key, value]) =>
value !== undefined &&
validCheckpointMetadataKeys.includes(key as keyof CheckpointMetadata)
)
);

whereClause.push(
...Object.entries(sanitizedFilter).map(
([key]) => `jsonb(CAST(metadata AS TEXT))->'$.${key}' = ?`
)
);

if (whereClause.length > 0) {
sql += `WHERE\n ${whereClause.join(" AND\n ")}\n`;
}

sql += "\nORDER BY checkpoint_id DESC";

if (limit) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
sql += ` LIMIT ${parseInt(limit as any, 10)}`; // parseInt here (with cast to make TS happy) to sanitize input, as limit may be user-provided
}

const args = [
thread_id,
checkpoint_ns,
before?.configurable?.checkpoint_id,
...Object.values(sanitizedFilter).map((value) => JSON.stringify(value)),
].filter((value) => value !== undefined && value !== null);

const rows: CheckpointRow[] = this.db
.prepare(sql)
.all(...args) as CheckpointRow[];
Expand Down
26 changes: 19 additions & 7 deletions libs/checkpoint-sqlite/src/tests/checkpoints.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ describe("SqliteSaver", () => {
},
},
checkpoint2,
{ source: "update", step: -1, writes: null, parents: {} }
{
source: "update",
step: -1,
writes: null,
parents: { "": checkpoint1.id },
}
);

// verify that parentTs is set and retrieved correctly for second checkpoint
Expand All @@ -119,18 +124,25 @@ describe("SqliteSaver", () => {
});

// list checkpoints
const checkpointTupleGenerator = await sqliteSaver.list({
configurable: { thread_id: "1" },
});
const checkpointTupleGenerator = await sqliteSaver.list(
{
configurable: { thread_id: "1" },
},
{
filter: {
source: "update",
step: -1,
parents: { "": checkpoint1.id },
},
}
);
const checkpointTuples: CheckpointTuple[] = [];
for await (const checkpoint of checkpointTupleGenerator) {
checkpointTuples.push(checkpoint);
}
expect(checkpointTuples.length).toBe(2);
expect(checkpointTuples.length).toBe(1);

const checkpointTuple1 = checkpointTuples[0];
const checkpointTuple2 = checkpointTuples[1];
expect(checkpointTuple1.checkpoint.ts).toBe("2024-04-20T17:19:07.952Z");
expect(checkpointTuple2.checkpoint.ts).toBe("2024-04-19T17:19:07.952Z");
});
});

0 comments on commit b76700e

Please sign in to comment.