Skip to content

Commit

Permalink
Refactor to use PregelLoop (#330)
Browse files Browse the repository at this point in the history
* Refactor checkpoint

* Fix lint and format

* Fix tests, fix UUIDv5 issue

* Fix build

* Refactor _applyWrites

* Finish initial loop implementation

* Fix build

* Continued refactor

* Fix test
  • Loading branch information
jacoblee93 authored Aug 20, 2024
1 parent fe3e936 commit 1b701f2
Show file tree
Hide file tree
Showing 26 changed files with 1,911 additions and 898 deletions.
2 changes: 2 additions & 0 deletions langgraph/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"license": "MIT",
"dependencies": {
"@langchain/core": ">=0.2.20 <0.3.0",
"double-ended-queue": "^2.1.0-0",
"uuid": "^10.0.0",
"zod": "^3.23.8"
},
Expand All @@ -44,6 +45,7 @@
"@swc/jest": "^0.2.29",
"@tsconfig/recommended": "^1.0.3",
"@types/better-sqlite3": "^7.6.9",
"@types/double-ended-queue": "^2",
"@types/uuid": "^10",
"@typescript-eslint/eslint-plugin": "^6.12.0",
"@typescript-eslint/parser": "^6.12.0",
Expand Down
9 changes: 9 additions & 0 deletions langgraph/src/channels/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,15 @@ export abstract class BaseChannel<
* @returns {CheckpointType | undefined}
*/
abstract checkpoint(): CheckpointType | undefined;

/**
* Mark the current value of the channel as consumed. By default, no-op.
* This is called by Pregel before the start of the next step, for all
* channels that triggered a node. If the channel was updated, return true.
*/
consume(): boolean {
return true;
}
}

export function emptyChannels<Cc extends Record<string, BaseChannel>>(
Expand Down
36 changes: 28 additions & 8 deletions langgraph/src/checkpoint/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ import { RunnableConfig } from "@langchain/core/runnables";
import { DefaultSerializer, SerializerProtocol } from "../serde/base.js";
import { uuid6 } from "./id.js";
import { SendInterface } from "../constants.js";
import {
CheckpointMetadata,
CheckpointPendingWrite,
import type {
PendingWrite,
} from "../pregel/types.js";
CheckpointPendingWrite,
CheckpointMetadata,
} from "./types.js";

export type { CheckpointMetadata };
export type ChannelVersions = Record<string, string | number>;

export interface Checkpoint<
N extends string = string,
Expand Down Expand Up @@ -118,6 +118,13 @@ export interface CheckpointTuple {
pendingWrites?: CheckpointPendingWrite[];
}

export type CheckpointListOptions = {
limit?: number;
before?: RunnableConfig;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
filter?: Record<string, any>;
};

export abstract class BaseCheckpointSaver {
serde: SerializerProtocol<unknown> = DefaultSerializer;

Expand All @@ -136,14 +143,14 @@ export abstract class BaseCheckpointSaver {

abstract list(
config: RunnableConfig,
limit?: number,
before?: RunnableConfig
options?: CheckpointListOptions
): AsyncGenerator<CheckpointTuple>;

abstract put(
config: RunnableConfig,
checkpoint: Checkpoint,
metadata: CheckpointMetadata
metadata: CheckpointMetadata,
newVersions: ChannelVersions
): Promise<RunnableConfig>;

/**
Expand All @@ -154,4 +161,17 @@ export abstract class BaseCheckpointSaver {
writes: PendingWrite[],
taskId: string
): Promise<void>;

/**
* Generate the next version ID for a channel.
*
* Default is to use integer versions, incrementing by 1. If you override, you can use str/int/float versions,
* as long as they are monotonically increasing.
*
* TODO: Fix type
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
getNextVersion(current: number | undefined, _channel: any) {
return current !== undefined ? current + 1 : 1;
}
}
2 changes: 1 addition & 1 deletion langgraph/src/checkpoint/index.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
export { MemorySaver } from "./memory.js";
export {
type Checkpoint,
type CheckpointMetadata,
copyCheckpoint,
emptyCheckpoint,
BaseCheckpointSaver,
} from "./base.js";
export { type CheckpointMetadata } from "./types.js";
13 changes: 9 additions & 4 deletions langgraph/src/checkpoint/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ import { RunnableConfig } from "@langchain/core/runnables";
import {
BaseCheckpointSaver,
Checkpoint,
CheckpointMetadata,
CheckpointListOptions,
CheckpointTuple,
} from "./base.js";
import { SerializerProtocol } from "../serde/base.js";
import { CheckpointPendingWrite, PendingWrite } from "../pregel/types.js";
import {
CheckpointMetadata,
CheckpointPendingWrite,
PendingWrite,
} from "../checkpoint/types.js";

function _generateKey(
threadId: string,
Expand Down Expand Up @@ -111,9 +115,10 @@ export class MemorySaver extends BaseCheckpointSaver {

async *list(
config: RunnableConfig,
limit?: number,
before?: RunnableConfig
options?: CheckpointListOptions
): AsyncGenerator<CheckpointTuple> {
// eslint-disable-next-line prefer-const
let { before, limit } = options ?? {};
const threadIds = config.configurable?.thread_id
? [config.configurable?.thread_id]
: Object.keys(this.storage);
Expand Down
Loading

0 comments on commit 1b701f2

Please sign in to comment.