Skip to content

Commit

Permalink
fix(checkpoint-mongodb): apply filters correctly in list method
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamincburns committed Oct 29, 2024
1 parent bd6da8a commit e8b90f3
Show file tree
Hide file tree
Showing 12 changed files with 563 additions and 59 deletions.
2 changes: 2 additions & 0 deletions libs/checkpoint-mongodb/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"@langchain/scripts": ">=0.1.3 <0.2.0",
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
"@testcontainers/mongodb": "^10.13.2",
"@tsconfig/recommended": "^1.0.3",
"@types/better-sqlite3": "^7.6.9",
"@types/uuid": "^10",
Expand All @@ -62,6 +63,7 @@
"prettier": "^2.8.3",
"release-it": "^17.6.0",
"rollup": "^4.23.0",
"testcontainers": "^10.13.2",
"ts-jest": "^29.1.0",
"tsx": "^4.7.0",
"typescript": "^4.9.5 || ^5.4.5"
Expand Down
162 changes: 146 additions & 16 deletions libs/checkpoint-mongodb/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,21 @@ import {
type PendingWrite,
type CheckpointMetadata,
CheckpointPendingWrite,
validCheckpointMetadataKeys,
} from "@langchain/langgraph-checkpoint";
import { applyMigrations, needsMigration } from "./migrations/index.js";

export * from "./migrations/index.js";

// increment this whenever the structure of the database changes in a way that would require a migration
const CURRENT_SCHEMA_VERSION = 1;

export type MongoDBSaverParams = {
client: MongoClient;
dbName?: string;
checkpointCollectionName?: string;
checkpointWritesCollectionName?: string;
schemaVersionCollectionName?: string;
};

/**
Expand All @@ -26,16 +34,21 @@ export class MongoDBSaver extends BaseCheckpointSaver {

protected db: MongoDatabase;

private setupPromise: Promise<void> | undefined;

checkpointCollectionName = "checkpoints";

checkpointWritesCollectionName = "checkpoint_writes";

schemaVersionCollectionName = "schema_version";

constructor(
{
client,
dbName,
checkpointCollectionName,
checkpointWritesCollectionName,
schemaVersionCollectionName,
}: MongoDBSaverParams,
serde?: SerializerProtocol
) {
Expand All @@ -46,6 +59,118 @@ export class MongoDBSaver extends BaseCheckpointSaver {
checkpointCollectionName ?? this.checkpointCollectionName;
this.checkpointWritesCollectionName =
checkpointWritesCollectionName ?? this.checkpointWritesCollectionName;
this.schemaVersionCollectionName =
schemaVersionCollectionName ?? this.schemaVersionCollectionName;
}

/**
* Runs async setup tasks if they haven't been run yet.
*/
async setup(): Promise<void> {
if (this.setupPromise) {
return this.setupPromise;
}
this.setupPromise = this.initializeSchemaVersion();
return this.setupPromise;
}

private async isDatabaseEmpty(): Promise<boolean> {
const results = await Promise.all(
[this.checkpointCollectionName, this.checkpointWritesCollectionName].map(
async (collectionName) => {
const collection = this.db.collection(collectionName);
// set a limit of 1 to stop scanning if any documents are found
const count = await collection.countDocuments({}, { limit: 1 });
return count === 0;
}
)
);

return results.every((result) => result);
}

private async initializeSchemaVersion(): Promise<void> {
const schemaVersionCollection = this.db.collection(
this.schemaVersionCollectionName
);

// empty database, no migrations needed - just set the schema version and move on
if (await this.isDatabaseEmpty()) {
const schemaVersionCollection = this.db.collection(
this.schemaVersionCollectionName
);

const versionDoc = await schemaVersionCollection.findOne({});
if (!versionDoc) {
await schemaVersionCollection.insertOne({
version: CURRENT_SCHEMA_VERSION,
});
}
} else {
// non-empty database, check if migrations are needed
const dbNeedsMigration = await needsMigration({
client: this.client,
dbName: this.db.databaseName,
checkpointCollectionName: this.checkpointCollectionName,
checkpointWritesCollectionName: this.checkpointWritesCollectionName,
schemaVersionCollectionName: this.schemaVersionCollectionName,
serializer: this.serde,
currentSchemaVersion: CURRENT_SCHEMA_VERSION,
});

if (dbNeedsMigration) {
throw new Error(
`Database needs migration. Call the migrate() method to migrate the database.`
);
}

// always defined if dbNeedsMigration is false
const versionDoc = (await schemaVersionCollection.findOne({}))!;

if (versionDoc.version == null) {
throw new Error(
`BUG: Database schema version is corrupt. Manual intervention required.`
);
}

if (versionDoc.version > CURRENT_SCHEMA_VERSION) {
throw new Error(
`Database created with newer version of checkpoint-mongodb. This version supports schema version ` +
`${CURRENT_SCHEMA_VERSION} but the database was created with schema version ${versionDoc.version}.`
);
}

if (versionDoc.version < CURRENT_SCHEMA_VERSION) {
throw new Error(
`BUG: Schema version ${versionDoc.version} is outdated (should be >= ${CURRENT_SCHEMA_VERSION}), but no ` +
`migration wants to execute.`
);
}
}
}

async migrate() {
if (
await needsMigration({
client: this.client,
dbName: this.db.databaseName,
checkpointCollectionName: this.checkpointCollectionName,
checkpointWritesCollectionName: this.checkpointWritesCollectionName,
schemaVersionCollectionName: this.schemaVersionCollectionName,
serializer: this.serde,
currentSchemaVersion: CURRENT_SCHEMA_VERSION,
})
) {
await applyMigrations({
client: this.client,
dbName: this.db.databaseName,
checkpointCollectionName: this.checkpointCollectionName,
checkpointWritesCollectionName: this.checkpointWritesCollectionName,
schemaVersionCollectionName: this.schemaVersionCollectionName,
serializer: this.serde,
currentSchemaVersion: CURRENT_SCHEMA_VERSION,
});
}
}

/**
Expand All @@ -55,6 +180,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
* for the given thread ID is retrieved.
*/
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup();

const {
thread_id,
checkpoint_ns = "",
Expand Down Expand Up @@ -109,10 +236,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
config: { configurable: configurableValues },
checkpoint,
pendingWrites,
metadata: (await this.serde.loadsTyped(
doc.type,
doc.metadata.value()
)) as CheckpointMetadata,
metadata: doc.metadata as CheckpointMetadata,
parentConfig:
doc.parent_checkpoint_id != null
? {
Expand All @@ -135,6 +259,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
config: RunnableConfig,
options?: CheckpointListOptions
): AsyncGenerator<CheckpointTuple> {
await this.setup();

const { limit, before, filter } = options ?? {};
const query: Record<string, unknown> = {};

Expand All @@ -150,9 +276,16 @@ export class MongoDBSaver extends BaseCheckpointSaver {
}

if (filter) {
Object.entries(filter).forEach(([key, value]) => {
query[`metadata.${key}`] = value;
});
Object.entries(filter)
.filter(
([key, value]) =>
validCheckpointMetadataKeys.includes(
key as keyof CheckpointMetadata
) && value !== undefined
)
.forEach(([key, value]) => {
query[`metadata.${key}`] = value;
});
}

if (before) {
Expand All @@ -173,10 +306,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
doc.type,
doc.checkpoint.value()
)) as Checkpoint;
const metadata = (await this.serde.loadsTyped(
doc.type,
doc.metadata.value()
)) as CheckpointMetadata;
const metadata = doc.metadata as CheckpointMetadata;

yield {
config: {
Expand Down Expand Up @@ -210,6 +340,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
checkpoint: Checkpoint,
metadata: CheckpointMetadata
): Promise<RunnableConfig> {
await this.setup();

const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
const checkpoint_id = checkpoint.id;
Expand All @@ -220,15 +352,11 @@ export class MongoDBSaver extends BaseCheckpointSaver {
}
const [checkpointType, serializedCheckpoint] =
this.serde.dumpsTyped(checkpoint);
const [metadataType, serializedMetadata] = this.serde.dumpsTyped(metadata);
if (checkpointType !== metadataType) {
throw new Error("Mismatched checkpoint and metadata types.");
}
const doc = {
parent_checkpoint_id: config.configurable?.checkpoint_id,
type: checkpointType,
checkpoint: serializedCheckpoint,
metadata: serializedMetadata,
metadata,
};
const upsertQuery = {
thread_id,
Expand Down Expand Up @@ -259,6 +387,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
writes: PendingWrite[],
taskId: string
): Promise<void> {
await this.setup();

const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns;
const checkpoint_id = config.configurable?.checkpoint_id;
Expand Down
110 changes: 110 additions & 0 deletions libs/checkpoint-mongodb/src/migrations/1_object_metadata.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import { Binary, ObjectId, Collection, Document, WithId } from "mongodb";
import { CheckpointMetadata } from "@langchain/langgraph-checkpoint";
import { Migration, MigrationParams } from "./base.js";

const BULK_WRITE_SIZE = 100;

interface OldCheckpointDocument {
parent_checkpoint_id: string | undefined;
type: string;
checkpoint: Binary;
metadata: Binary;
thread_id: string;
checkpoint_ns: string | undefined;
checkpoint_id: string;
}

interface NewCheckpointDocument {
parent_checkpoint_id: string | undefined;
type: string;
checkpoint: Binary;
metadata: CheckpointMetadata;
thread_id: string;
checkpoint_ns: string | undefined;
checkpoint_id: string;
}

export class Migration1ObjectMetadata extends Migration {
version = 1;

constructor(params: MigrationParams) {
super(params);
}

override async apply() {
const db = this.client.db(this.dbName);
const checkpointCollection = db.collection(this.checkpointCollectionName);
const schemaVersionCollection = db.collection(
this.schemaVersionCollectionName
);

// Fetch all documents from the checkpoints collection
const cursor = checkpointCollection.find({});

let updateBatch: {
id: string;
newDoc: NewCheckpointDocument;
}[] = [];

for await (const doc of cursor) {
// already migrated
if (!(doc.metadata._bsontype && doc.metadata._bsontype === "Binary")) {
continue;
}

const oldDoc = doc as WithId<OldCheckpointDocument>;

const metadata: CheckpointMetadata = await this.serializer.loadsTyped(
oldDoc.type,
oldDoc.metadata.value()
);

const newDoc: NewCheckpointDocument = {
...oldDoc,
metadata,
};

updateBatch.push({
id: doc._id.toString(),
newDoc,
});

if (updateBatch.length >= BULK_WRITE_SIZE) {
await this.flushBatch(updateBatch, checkpointCollection);
updateBatch = [];
}
}

if (updateBatch.length > 0) {
await this.flushBatch(updateBatch, checkpointCollection);
}

// Update schema version to 1
await schemaVersionCollection.updateOne(
{},
{ $set: { version: 1 } },
{ upsert: true }
);
}

private async flushBatch(
updateBatch: {
id: string;
newDoc: NewCheckpointDocument;
}[],
checkpointCollection: Collection<Document>
) {
if (updateBatch.length === 0) {
throw new Error("No updates to apply");
}

const bulkOps = updateBatch.map(({ id, newDoc: newCheckpoint }) => ({
updateOne: {
filter: { _id: new ObjectId(id) },
update: { $set: newCheckpoint },
},
}));

await checkpointCollection.bulkWrite(bulkOps);
}
}
Loading

0 comments on commit e8b90f3

Please sign in to comment.