Skip to content

Commit

Permalink
fix(checkpoint-mongodb): apply filters correctly in list method
Browse files Browse the repository at this point in the history
  • Loading branch information
benjamincburns committed Oct 23, 2024
1 parent 06b546a commit 6ff5b9f
Show file tree
Hide file tree
Showing 12 changed files with 618 additions and 85 deletions.
5 changes: 4 additions & 1 deletion libs/checkpoint-mongodb/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
"author": "LangChain",
"license": "MIT",
"dependencies": {
"mongodb": "^6.8.0"
"mongodb": "^6.8.0",
"zod": "^3.23.8"
},
"peerDependencies": {
"@langchain/core": ">=0.2.31 <0.4.0",
Expand All @@ -43,6 +44,7 @@
"@langchain/scripts": ">=0.1.3 <0.2.0",
"@swc/core": "^1.3.90",
"@swc/jest": "^0.2.29",
"@testcontainers/mongodb": "^10.13.2",
"@tsconfig/recommended": "^1.0.3",
"@types/better-sqlite3": "^7.6.9",
"@types/uuid": "^10",
Expand All @@ -62,6 +64,7 @@
"prettier": "^2.8.3",
"release-it": "^17.6.0",
"rollup": "^4.23.0",
"testcontainers": "^10.13.2",
"ts-jest": "^29.1.0",
"tsx": "^4.7.0",
"typescript": "^4.9.5 || ^5.4.5"
Expand Down
204 changes: 163 additions & 41 deletions libs/checkpoint-mongodb/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,25 @@ import {
type PendingWrite,
type CheckpointMetadata,
CheckpointPendingWrite,
validCheckpointMetadataKeys,
} from "@langchain/langgraph-checkpoint";
import { applyMigrations, needsMigration } from "./migrations/index.js";
import {
checkpointCollectionName,
checkpointWritesCollectionName,
CollectionNameMapping,
schemaVersionCollectionName,
} from "./migrations/base.js";

export * from "./migrations/index.js";

// increment this whenever the structure of the database changes in a way that would require a migration
const CURRENT_SCHEMA_VERSION = 1;

export type MongoDBSaverParams = {
client: MongoClient;
dbName?: string;
checkpointCollectionName?: string;
checkpointWritesCollectionName?: string;
collectionNameMapping?: Partial<CollectionNameMapping>;
};

/**
Expand All @@ -26,26 +38,129 @@ export class MongoDBSaver extends BaseCheckpointSaver {

protected db: MongoDatabase;

checkpointCollectionName = "checkpoints";
private setupPromise: Promise<void> | undefined;

checkpointWritesCollectionName = "checkpoint_writes";
private collectionNameMapping: CollectionNameMapping;

constructor(
{
client,
dbName,
checkpointCollectionName,
checkpointWritesCollectionName,
}: MongoDBSaverParams,
{ client, dbName, collectionNameMapping }: MongoDBSaverParams,
serde?: SerializerProtocol
) {
super(serde);
this.client = client;
this.db = this.client.db(dbName);
this.checkpointCollectionName =
checkpointCollectionName ?? this.checkpointCollectionName;
this.checkpointWritesCollectionName =
checkpointWritesCollectionName ?? this.checkpointWritesCollectionName;
this.collectionNameMapping = {
[checkpointCollectionName]: checkpointCollectionName,
[checkpointWritesCollectionName]: checkpointWritesCollectionName,
[schemaVersionCollectionName]: schemaVersionCollectionName,
...collectionNameMapping,
};
}

/**
* Runs async setup tasks if they haven't been run yet.
*/
async setup(): Promise<void> {
if (this.setupPromise) {
return this.setupPromise;
}
this.setupPromise = this.initializeSchemaVersion();
return this.setupPromise;
}

private async isDatabaseEmpty(): Promise<boolean> {
const results = await Promise.all(
Object.values(this.collectionNameMapping)
.filter(
(collectionName) => collectionName !== schemaVersionCollectionName
)
.map(async (collectionName) => {
const collection = this.db.collection(collectionName);
// set a limit of 1 to stop scanning if any documents are found
const count = await collection.countDocuments({}, { limit: 1 });
return count === 0;
})
);

return results.every((result) => result);
}

private async initializeSchemaVersion(): Promise<void> {
const schemaVersionCollection = this.db.collection(
this.collectionNameMapping[schemaVersionCollectionName]
);

// empty database, no migrations needed - just set the schema version and move on
if (await this.isDatabaseEmpty()) {
const schemaVersionCollection = this.db.collection(
this.collectionNameMapping[schemaVersionCollectionName]
);

const versionDoc = await schemaVersionCollection.findOne({});
if (!versionDoc) {
await schemaVersionCollection.insertOne({
version: CURRENT_SCHEMA_VERSION,
});
}
} else {
// non-empty database, check if migrations are needed
const dbNeedsMigration = await needsMigration({
client: this.client,
dbName: this.db.databaseName,
collectionNameMapping: this.collectionNameMapping,
serializer: this.serde,
currentSchemaVersion: CURRENT_SCHEMA_VERSION,
});

if (dbNeedsMigration) {
throw new Error(
`Database needs migration. Call the migrate() method to migrate the database.`
);
}

// always defined if dbNeedsMigration is false
const versionDoc = (await schemaVersionCollection.findOne({}))!;

if (versionDoc.version == null) {
throw new Error(
`BUG: Database schema version is corrupt. Manual intervention required.`
);
}

if (versionDoc.version > CURRENT_SCHEMA_VERSION) {
throw new Error(
`Database created with newer version of checkpoint-mongodb. This version supports schema version ` +
`${CURRENT_SCHEMA_VERSION} but the database was created with schema version ${versionDoc.version}.`
);
}

if (versionDoc.version < CURRENT_SCHEMA_VERSION) {
throw new Error(
`BUG: Schema version ${versionDoc.version} is outdated (should be >= ${CURRENT_SCHEMA_VERSION}), but no ` +
`migration wants to execute.`
);
}
}
}

async migrate() {
if (
await needsMigration({
client: this.client,
dbName: this.db.databaseName,
collectionNameMapping: this.collectionNameMapping,
serializer: this.serde,
currentSchemaVersion: CURRENT_SCHEMA_VERSION,
})
) {
await applyMigrations({
client: this.client,
dbName: this.db.databaseName,
collectionNameMapping: this.collectionNameMapping,
serializer: this.serde,
currentSchemaVersion: CURRENT_SCHEMA_VERSION,
});
}
}

/**
Expand All @@ -55,6 +170,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
* for the given thread ID is retrieved.
*/
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup();

const {
thread_id,
checkpoint_ns = "",
Expand All @@ -71,7 +188,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
query = { thread_id, checkpoint_ns };
}
const result = await this.db
.collection(this.checkpointCollectionName)
.collection(this.collectionNameMapping[checkpointCollectionName])
.find(query)
.sort("checkpoint_id", -1)
.limit(1)
Expand All @@ -90,7 +207,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
doc.checkpoint.value()
)) as Checkpoint;
const serializedWrites = await this.db
.collection(this.checkpointWritesCollectionName)
.collection(this.collectionNameMapping[checkpointWritesCollectionName])
.find(configurableValues)
.toArray();
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
Expand All @@ -109,10 +226,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
config: { configurable: configurableValues },
checkpoint,
pendingWrites,
metadata: (await this.serde.loadsTyped(
doc.type,
doc.metadata.value()
)) as CheckpointMetadata,
metadata: doc.metadata as CheckpointMetadata,
parentConfig:
doc.parent_checkpoint_id != null
? {
Expand All @@ -135,6 +249,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
config: RunnableConfig,
options?: CheckpointListOptions
): AsyncGenerator<CheckpointTuple> {
await this.setup();

const { limit, before, filter } = options ?? {};
const query: Record<string, unknown> = {};

Expand All @@ -150,17 +266,24 @@ export class MongoDBSaver extends BaseCheckpointSaver {
}

if (filter) {
Object.entries(filter).forEach(([key, value]) => {
query[`metadata.${key}`] = value;
});
Object.entries(filter)
.filter(
([key, value]) =>
validCheckpointMetadataKeys.includes(
key as keyof CheckpointMetadata
) && value !== undefined
)
.forEach(([key, value]) => {
query[`metadata.${key}`] = value;
});
}

if (before) {
query.checkpoint_id = { $lt: before.configurable?.checkpoint_id };
}

let result = this.db
.collection(this.checkpointCollectionName)
.collection(this.collectionNameMapping[checkpointCollectionName])
.find(query)
.sort("checkpoint_id", -1);

Expand All @@ -173,10 +296,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
doc.type,
doc.checkpoint.value()
)) as Checkpoint;
const metadata = (await this.serde.loadsTyped(
doc.type,
doc.metadata.value()
)) as CheckpointMetadata;
const metadata = doc.metadata as CheckpointMetadata;

yield {
config: {
Expand Down Expand Up @@ -210,6 +330,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
checkpoint: Checkpoint,
metadata: CheckpointMetadata
): Promise<RunnableConfig> {
await this.setup();

const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
const checkpoint_id = checkpoint.id;
Expand All @@ -220,28 +342,26 @@ export class MongoDBSaver extends BaseCheckpointSaver {
}
const [checkpointType, serializedCheckpoint] =
this.serde.dumpsTyped(checkpoint);
const [metadataType, serializedMetadata] = this.serde.dumpsTyped(metadata);
if (checkpointType !== metadataType) {
throw new Error("Mismatched checkpoint and metadata types.");
}
const doc = {
parent_checkpoint_id: config.configurable?.checkpoint_id,
type: checkpointType,
checkpoint: serializedCheckpoint,
metadata: serializedMetadata,
metadata,
};
const upsertQuery = {
thread_id,
checkpoint_ns,
checkpoint_id,
};
await this.db.collection(this.checkpointCollectionName).updateOne(
upsertQuery,
{
$set: doc,
},
{ upsert: true }
);
await this.db
.collection(this.collectionNameMapping[checkpointCollectionName])
.updateOne(
upsertQuery,
{
$set: doc,
},
{ upsert: true }
);
return {
configurable: {
thread_id,
Expand All @@ -259,6 +379,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
writes: PendingWrite[],
taskId: string
): Promise<void> {
await this.setup();

const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns;
const checkpoint_id = config.configurable?.checkpoint_id;
Expand Down Expand Up @@ -299,7 +421,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
});

await this.db
.collection(this.checkpointWritesCollectionName)
.collection(this.collectionNameMapping[checkpointWritesCollectionName])
.bulkWrite(operations);
}
}
Loading

0 comments on commit 6ff5b9f

Please sign in to comment.