Skip to content

Commit

Permalink
Support checkpoint_ns, refactor (#313)
Browse files Browse the repository at this point in the history
* Refactor checkpoint

* Fix lint and format

* Fix tests, fix UUIDv5 issue

* Fix build
  • Loading branch information
jacoblee93 authored Aug 14, 2024
1 parent 35669d8 commit a89219c
Show file tree
Hide file tree
Showing 13 changed files with 430 additions and 131 deletions.
35 changes: 16 additions & 19 deletions langgraph/src/checkpoint/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,13 @@ import { RunnableConfig } from "@langchain/core/runnables";
import { DefaultSerializer, SerializerProtocol } from "../serde/base.js";
import { uuid6 } from "./id.js";
import { SendInterface } from "../constants.js";
import {
CheckpointMetadata,
CheckpointPendingWrite,
PendingWrite,
} from "../pregel/types.js";

export interface CheckpointMetadata {
source: "input" | "loop" | "update";
/**
* The source of the checkpoint.
* - "input": The checkpoint was created from an input to invoke/stream/batch.
* - "loop": The checkpoint was created from inside the pregel loop.
* - "update": The checkpoint was created from a manual state update. */
step: number;
/**
* The step number of the checkpoint.
* -1 for the first "input" checkpoint.
* 0 for the first "loop" checkpoint.
* ... for the nth checkpoint afterwards. */
writes: Record<string, unknown> | null;
/**
* The writes that were made between the previous checkpoint and this one.
* Mapping from node name to writes emitted by that node.
*/
}
export type { CheckpointMetadata };

export interface Checkpoint<
N extends string = string,
Expand Down Expand Up @@ -128,6 +115,7 @@ export interface CheckpointTuple {
checkpoint: Checkpoint;
metadata?: CheckpointMetadata;
parentConfig?: RunnableConfig;
pendingWrites?: CheckpointPendingWrite[];
}

export abstract class BaseCheckpointSaver {
Expand Down Expand Up @@ -157,4 +145,13 @@ export abstract class BaseCheckpointSaver {
checkpoint: Checkpoint,
metadata: CheckpointMetadata
): Promise<RunnableConfig>;

/**
* Store intermediate writes linked to a checkpoint.
*/
abstract putWrites(
config: RunnableConfig,
writes: PendingWrite[],
taskId: string
): Promise<void>;
}
14 changes: 13 additions & 1 deletion langgraph/src/checkpoint/id.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,17 @@
import { v6 } from "uuid";
import { v5, v6 } from "uuid";

export function uuid6(clockseq: number): string {
return v6({ clockseq });
}

// Skip UUID validation check, since UUID6s
// generated with negative clockseq are not
// technically compliant, but still work.
// See: https://github.com/uuidjs/uuid/issues/511
export function uuid5(name: string, namespace: string): string {
const namespaceBytes = namespace
.replace(/-/g, "")
.match(/.{2}/g)!
.map((byte) => parseInt(byte, 16));
return v5(name, new Uint8Array(namespaceBytes));
}
243 changes: 198 additions & 45 deletions langgraph/src/checkpoint/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,102 @@ import {
CheckpointTuple,
} from "./base.js";
import { SerializerProtocol } from "../serde/base.js";
import { CheckpointPendingWrite, PendingWrite } from "../pregel/types.js";

function _generateKey(
threadId: string,
checkpointNamespace: string,
checkpointId: string
) {
return JSON.stringify([threadId, checkpointNamespace, checkpointId]);
}

export class MemorySaver extends BaseCheckpointSaver {
storage: Record<string, Record<string, [string, string]>>;
// thread ID -> checkpoint namespace -> checkpoint ID -> checkpoint mapping
storage: Record<
string,
Record<string, Record<string, [string, string, string | undefined]>>
> = {};

writes: Record<string, CheckpointPendingWrite[]> = {};

constructor(serde?: SerializerProtocol<unknown>) {
super(serde);
this.storage = {};
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
const checkpoint_id = config.configurable?.checkpoint_id;
const checkpoints = this.storage[thread_id];

if (checkpoint_id) {
const checkpoint = checkpoints[checkpoint_id];
if (checkpoint) {
const saved = this.storage[thread_id]?.[checkpoint_ns]?.[checkpoint_id];
if (saved !== undefined) {
const [checkpoint, metadata, parentCheckpointId] = saved;
const writes =
this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ??
[];
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [taskId, channel, await this.serde.parse(value as string)];
})
);
const parentConfig =
parentCheckpointId !== undefined
? {
configurable: {
thread_id,
checkpoint_ns,
checkpoint_id,
},
}
: undefined;
return {
config,
checkpoint: (await this.serde.parse(checkpoint[0])) as Checkpoint,
metadata: (await this.serde.parse(
checkpoint[1]
)) as CheckpointMetadata,
checkpoint: (await this.serde.parse(checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(metadata)) as CheckpointMetadata,
pendingWrites,
parentConfig,
};
}
} else {
if (checkpoints) {
const checkpoints = this.storage[thread_id]?.[checkpoint_ns];
if (checkpoints !== undefined) {
const maxThreadTs = Object.keys(checkpoints).sort((a, b) =>
b.localeCompare(a)
)[0];
const checkpoint = checkpoints[maxThreadTs];
const saved = checkpoints[maxThreadTs];
const [checkpoint, metadata, parentCheckpointId] = saved;
const writes =
this.writes[_generateKey(thread_id, checkpoint_ns, checkpoint_id)] ??
[];
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [taskId, channel, await this.serde.parse(value as string)];
})
);
const parentConfig =
parentCheckpointId !== undefined
? {
configurable: {
thread_id,
checkpoint_ns,
checkpoint_id: parentCheckpointId,
},
}
: undefined;
return {
config: { configurable: { thread_id, checkpoint_id: maxThreadTs } },
checkpoint: (await this.serde.parse(checkpoint[0])) as Checkpoint,
metadata: (await this.serde.parse(
checkpoint[1]
)) as CheckpointMetadata,
config: {
configurable: {
thread_id,
checkpoint_id: maxThreadTs,
checkpoint_ns,
},
},
checkpoint: (await this.serde.parse(checkpoint)) as Checkpoint,
metadata: (await this.serde.parse(metadata)) as CheckpointMetadata,
pendingWrites,
parentConfig,
};
}
}
Expand All @@ -55,21 +114,74 @@ export class MemorySaver extends BaseCheckpointSaver {
limit?: number,
before?: RunnableConfig
): AsyncGenerator<CheckpointTuple> {
const thread_id = config.configurable?.thread_id;
const checkpoints = this.storage[thread_id] ?? {};

// sort in desc order
for (const [checkpoint_id, checkpoint] of Object.entries(checkpoints)
.filter((c) =>
before ? c[0] < before.configurable?.checkpoint_id : true
)
.sort((a, b) => b[0].localeCompare(a[0]))
.slice(0, limit)) {
yield {
config: { configurable: { thread_id, checkpoint_id } },
checkpoint: (await this.serde.parse(checkpoint[0])) as Checkpoint,
metadata: (await this.serde.parse(checkpoint[1])) as CheckpointMetadata,
};
const threadIds = config.configurable?.thread_id
? [config.configurable?.thread_id]
: Object.keys(this.storage);
const checkpointNamespace = config.configurable?.checkpoint_ns ?? "";

for (const threadId of threadIds) {
const checkpoints = this.storage[threadId]?.[checkpointNamespace] ?? {};
const sortedCheckpoints = Object.entries(checkpoints).sort((a, b) =>
b[0].localeCompare(a[0])
);

for (const [
checkpointId,
[checkpoint, metadataStr, parentCheckpointId],
] of sortedCheckpoints) {
// Filter by checkpoint ID
if (
before &&
before.configurable?.checkpoint_id &&
checkpointId >= before.configurable.checkpoint_id
) {
continue;
}

// Parse metadata
const metadata = (await this.serde.parse(
metadataStr
)) as CheckpointMetadata;

// Limit search results
if (limit !== undefined) {
if (limit <= 0) break;
// eslint-disable-next-line no-param-reassign
limit -= 1;
}

const writes =
this.writes[
_generateKey(threadId, checkpointNamespace, checkpointId)
] ?? [];
const pendingWrites: CheckpointPendingWrite[] = await Promise.all(
writes.map(async ([taskId, channel, value]) => {
return [taskId, channel, await this.serde.parse(value as string)];
})
);

yield {
config: {
configurable: {
thread_id: threadId,
checkpoint_ns: checkpointNamespace,
checkpoint_id: checkpointId,
},
},
checkpoint: (await this.serde.parse(checkpoint)) as Checkpoint,
metadata,
pendingWrites,
parentConfig: parentCheckpointId
? {
configurable: {
thread_id: threadId,
checkpoint_ns: checkpointNamespace,
checkpoint_id: parentCheckpointId,
},
}
: undefined,
};
}
}
}

Expand All @@ -78,27 +190,68 @@ export class MemorySaver extends BaseCheckpointSaver {
checkpoint: Checkpoint,
metadata: CheckpointMetadata
): Promise<RunnableConfig> {
const thread_id = config.configurable?.thread_id;
const threadId = config.configurable?.thread_id;
const checkpointNamespace = config.configurable?.checkpoint_ns;
if (threadId === undefined) {
throw new Error(
`Failed to put checkpoint. The passed RunnableConfig is missing a required "thread_id" field in its "configurable" property.`
);
}
if (checkpointNamespace === undefined) {
throw new Error(
`Failed to put checkpoint. The passed RunnableConfig is missing a required "checkpoint_ns" field in its "configurable" property.`
);
}

if (this.storage[thread_id]) {
this.storage[thread_id][checkpoint.id] = [
this.serde.stringify(checkpoint),
this.serde.stringify(metadata),
];
} else {
this.storage[thread_id] = {
[checkpoint.id]: [
this.serde.stringify(checkpoint),
this.serde.stringify(metadata),
],
};
if (!this.storage[threadId]) {
this.storage[threadId] = {};
}
if (!this.storage[threadId][checkpointNamespace]) {
this.storage[threadId][checkpointNamespace] = {};
}

this.storage[threadId][checkpointNamespace][checkpoint.id] = [
this.serde.stringify(checkpoint),
this.serde.stringify(metadata),
config.configurable?.checkpoint_id, // parent
];

return {
configurable: {
thread_id,
thread_id: threadId,
checkpoint_ns: checkpointNamespace,
checkpoint_id: checkpoint.id,
},
};
}

async putWrites(
config: RunnableConfig,
writes: PendingWrite[],
taskId: string
): Promise<void> {
const threadId = config.configurable?.thread_id;
const checkpointNamespace = config.configurable?.checkpoint_ns;
const checkpointId = config.configurable?.checkpoint_id;
if (threadId === undefined) {
throw new Error(
`Failed to put writes. The passed RunnableConfig is missing a required "thread_id" field in its "configurable" property`
);
}
if (checkpointId === undefined) {
throw new Error(
`Failed to put writes. The passed RunnableConfig is missing a required "checkpoint_id" field in its "configurable" property.`
);
}
const key = _generateKey(threadId, checkpointNamespace, checkpointId);
if (this.writes[key] === undefined) {
this.writes[key] = [];
}
const pendingWrites: CheckpointPendingWrite[] = writes.map(
([channel, value]) => {
return [taskId, channel, this.serde.stringify(value)];
}
);
this.writes[key].push(...pendingWrites);
}
}
10 changes: 10 additions & 0 deletions langgraph/src/checkpoint/sqlite.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
CheckpointTuple,
} from "./base.js";
import { SerializerProtocol } from "../serde/base.js";
import { PendingWrite } from "../pregel/types.js";

// snake_case is used to match Python implementation
interface Row {
Expand Down Expand Up @@ -207,4 +208,13 @@ CREATE TABLE IF NOT EXISTS checkpoints (
},
};
}

// TODO: Implement
putWrites(
_config: RunnableConfig,
_writes: PendingWrite[],
_taskId: string
): Promise<void> {
throw new Error("Not implemented");
}
}
2 changes: 2 additions & 0 deletions langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ export const TAG_HIDDEN = "langsmith:hidden";
export const TASKS = "__pregel_tasks";
export const TASK_NAMESPACE = "6ba7b831-9dad-11d1-80b4-00c04fd430c8";

export const CHECKPOINT_NAMESPACE_SEPARATOR = "|";

export interface SendInterface {
node: string;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down
Loading

0 comments on commit a89219c

Please sign in to comment.