Skip to content

Commit

Permalink
Add pending writes to debug state, fix forking state bug (#338)
Browse files Browse the repository at this point in the history
* Add pending writes to debug state

* Fix checkpoint forking bug

* Fix test
  • Loading branch information
jacoblee93 authored Aug 21, 2024
1 parent 3598cc1 commit a226918
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 20 deletions.
27 changes: 16 additions & 11 deletions langgraph/src/channels/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,25 @@ export function emptyChannels<Cc extends Record<string, BaseChannel>>(

export function createCheckpoint<ValueType>(
checkpoint: ReadonlyCheckpoint,
channels: Record<string, BaseChannel<ValueType>>,
channels: Record<string, BaseChannel<ValueType>> | undefined,
step: number
): Checkpoint {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const values: Record<string, any> = {};
for (const k of Object.keys(channels)) {
try {
values[k] = channels[k].checkpoint();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (error: any) {
if (error.name === EmptyChannelError.unminifiable_name) {
// no-op
} else {
throw error; // Rethrow unexpected errors
let values: Record<string, any>;
if (channels === undefined) {
values = checkpoint.channel_values;
} else {
values = {};
for (const k of Object.keys(channels)) {
try {
values[k] = channels[k].checkpoint();
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} catch (error: any) {
if (error.name === EmptyChannelError.unminifiable_name) {
// no-op
} else {
throw error; // Rethrow unexpected errors
}
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions langgraph/src/pregel/debug.ts
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ export function* mapDebugCheckpoint<
};
}

function tasksWithWrites<N extends PropertyKey, C extends PropertyKey>(
tasks: readonly PregelExecutableTask<N, C>[],
export function tasksWithWrites<N extends PropertyKey, C extends PropertyKey>(
tasks: PregelTaskDescription[] | readonly PregelExecutableTask<N, C>[],
pendingWrites: CheckpointPendingWrite[]
): PregelTaskDescription[] {
return tasks.map((task): PregelTaskDescription => {
Expand Down
61 changes: 54 additions & 7 deletions langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
getCallbackManagerForConfig,
patchConfig,
} from "@langchain/core/runnables";
import { IterableReadableStream } from "@langchain/core/utils/stream";
import {
BaseChannel,
createCheckpoint,
Expand All @@ -29,6 +30,7 @@ import {
printStepCheckpoint,
printStepTasks,
printStepWrites,
tasksWithWrites,
} from "./debug.js";
import { ChannelWrite, ChannelWriteEntry, PASSTHROUGH } from "./write.js";
import {
Expand Down Expand Up @@ -278,6 +280,9 @@ export class Pregel<
}
}

/**
* Get the current state of the graph.
*/
async getState(config: RunnableConfig): Promise<StateSnapshot> {
if (!this.checkpointer) {
throw new GraphValueError("No checkpointer set");
Expand All @@ -297,13 +302,17 @@ export class Pregel<
return {
values: readChannels(channels, this.streamChannelsAsIs),
next: nextTasks.map((task) => task.name),
tasks: tasksWithWrites(nextTasks, saved?.pendingWrites ?? []),
metadata: saved?.metadata,
config: saved ? saved.config : config,
createdAt: saved?.checkpoint.ts,
parentConfig: saved?.parentConfig,
};
}

/**
* Get the history of the state of the graph.
*/
async *getStateHistory(
config: RunnableConfig,
options?: CheckpointListOptions
Expand All @@ -324,6 +333,7 @@ export class Pregel<
yield {
values: readChannels(channels, this.streamChannelsAsIs),
next: nextTasks.map((task) => task.name),
tasks: tasksWithWrites(nextTasks, saved.pendingWrites ?? []),
metadata: saved.metadata,
config: saved.config,
createdAt: saved.checkpoint.ts,
Expand All @@ -332,6 +342,11 @@ export class Pregel<
}
}

/**
* Update the state of the graph with the given values, as if they came from
* node `as_node`. If `as_node` is not provided, it will be set to the last node
* that updated the state, if not ambiguous.
*/
async updateState(
config: RunnableConfig,
values: Record<string, unknown> | unknown,
Expand Down Expand Up @@ -361,10 +376,10 @@ export class Pregel<
};

// Find last node that updated the state, if not provided
if (values === undefined && asNode === undefined) {
if (values == null && asNode === undefined) {
return await this.checkpointer.put(
checkpointConfig,
createCheckpoint(checkpoint, {}, step),
createCheckpoint(checkpoint, undefined, step),
{
source: "update",
step,
Expand All @@ -380,7 +395,7 @@ export class Pregel<
})
.flat()
.find((v) => !!v);
if (asNode === undefined && !nonNullVersion) {
if (asNode === undefined && nonNullVersion === undefined) {
if (
typeof this.inputChannels === "string" &&
this.nodes[this.inputChannels] !== undefined
Expand Down Expand Up @@ -562,7 +577,29 @@ export class Pregel<
];
}

async *_streamIterator(
/**
* Stream graph steps for a single input.
* @param input The input to the graph.
* @param options The configuration to use for the run.
* @param options.streamMode The mode to stream output. Defaults to value set on initialization.
* Options are "values", "updates", and "debug". Default is "values".
* values: Emit the current values of the state for each step.
* updates: Emit only the updates to the state for each step.
* Output is a dict with the node name as key and the updated values as value.
* debug: Emit debug events for each step.
* @param options.outputKeys The keys to stream. Defaults to all non-context channels.
* @param options.interruptBefore Nodes to interrupt before.
* @param options.interruptAfter Nodes to interrupt after.
* @param options.debug Whether to print debug information during execution.
*/
override async stream(
input: PregelInputType,
options?: Partial<PregelOptions<Nn, Cc>>
): Promise<IterableReadableStream<PregelOutputType>> {
return super.stream(input, options);
}

override async *_streamIterator(
input: PregelInputType,
options?: Partial<PregelOptions<Nn, Cc>>
): AsyncGenerator<PregelOutputType> {
Expand Down Expand Up @@ -722,10 +759,20 @@ export class Pregel<

/**
* Run the graph with a single input and config.
* @param input
* @param options
* @param input The input to the graph.
* @param options The configuration to use for the run.
* @param options.streamMode The mode to stream output. Defaults to value set on initialization.
* Options are "values", "updates", and "debug". Default is "values".
* values: Emit the current values of the state for each step.
* updates: Emit only the updates to the state for each step.
* Output is a dict with the node name as key and the updated values as value.
* debug: Emit debug events for each step.
* @param options.outputKeys The keys to stream. Defaults to all non-context channels.
* @param options.interruptBefore Nodes to interrupt before.
* @param options.interruptAfter Nodes to interrupt after.
* @param options.debug Whether to print debug information during execution.
*/
async invoke(
override async invoke(
input: PregelInputType,
options?: Partial<PregelOptions<Nn, Cc>>
): Promise<PregelOutputType> {
Expand Down
4 changes: 4 additions & 0 deletions langgraph/src/pregel/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ export interface StateSnapshot {
* @default undefined
*/
readonly parentConfig?: RunnableConfig | undefined;
/**
* Tasks to execute in this step. If already attempted, may contain an error.
*/
readonly tasks: PregelTaskDescription[];
}

export type All = "*";
Loading

0 comments on commit a226918

Please sign in to comment.