From 06a36e34916e3ef2c39858852109dfa90c14b779 Mon Sep 17 00:00:00 2001 From: jacoblee93 Date: Sun, 25 Aug 2024 21:36:50 -0700 Subject: [PATCH] Adds support for retryPolicy --- libs/checkpoint-mongodb/README.md | 1 - libs/checkpoint-sqlite/README.md | 1 - libs/checkpoint/README.md | 1 - libs/langgraph/src/graph/graph.ts | 52 ++++--- libs/langgraph/src/graph/state.ts | 79 +++++++++-- libs/langgraph/src/pregel/algo.ts | 58 +------- libs/langgraph/src/pregel/index.ts | 53 ++----- libs/langgraph/src/pregel/read.ts | 24 +++- libs/langgraph/src/pregel/retry.ts | 176 ++++++++++++++++++++++++ libs/langgraph/src/pregel/types.ts | 5 +- libs/langgraph/src/pregel/utils.ts | 28 ++++ libs/langgraph/src/tests/pregel.test.ts | 94 +++++++++++++ libs/langgraph/src/web.ts | 1 + 13 files changed, 444 insertions(+), 129 deletions(-) create mode 100644 libs/langgraph/src/pregel/retry.ts diff --git a/libs/checkpoint-mongodb/README.md b/libs/checkpoint-mongodb/README.md index 856aba19..b91d9556 100644 --- a/libs/checkpoint-mongodb/README.md +++ b/libs/checkpoint-mongodb/README.md @@ -48,7 +48,6 @@ const checkpoint = { } }, pending_sends: [], - current_tasks: {} } // store checkpoint diff --git a/libs/checkpoint-sqlite/README.md b/libs/checkpoint-sqlite/README.md index 05245b48..ee71a663 100644 --- a/libs/checkpoint-sqlite/README.md +++ b/libs/checkpoint-sqlite/README.md @@ -44,7 +44,6 @@ const checkpoint = { } }, pending_sends: [], - current_tasks: {} } // store checkpoint diff --git a/libs/checkpoint/README.md b/libs/checkpoint/README.md index 246a555e..fec91bed 100644 --- a/libs/checkpoint/README.md +++ b/libs/checkpoint/README.md @@ -81,7 +81,6 @@ const checkpoint = { } }, pending_sends: [], - current_tasks: {} } // store checkpoint diff --git a/libs/langgraph/src/graph/graph.ts b/libs/langgraph/src/graph/graph.ts index f3701f1b..9bc2f4b7 100644 --- a/libs/langgraph/src/graph/graph.ts +++ b/libs/langgraph/src/graph/graph.ts @@ -100,14 +100,25 @@ export class Branch { } } +export type NodeSpec = { + runnable: Runnable; + metadata?: Record; +}; + +export type AddNodeOptions = { metadata?: Record }; + export class Graph< N extends string = typeof END, // eslint-disable-next-line @typescript-eslint/no-explicit-any RunInput = any, // eslint-disable-next-line @typescript-eslint/no-explicit-any - RunOutput = any + RunOutput = any, + NodeSpecType extends NodeSpec = NodeSpec< + RunInput, + RunOutput + > > { - nodes: Record>; + nodes: Record; edges: Set<[N | typeof START, N | typeof END]>; @@ -120,12 +131,12 @@ export class Graph< supportMultipleEdges = false; constructor() { - this.nodes = {} as Record>; + this.nodes = {} as Record; this.edges = new Set(); this.branches = {}; } - private warnIfCompiled(message: string): void { + protected warnIfCompiled(message: string): void { if (this.compiled) { console.warn(message); } @@ -137,7 +148,8 @@ export class Graph< addNode( key: K, - action: RunnableLike + action: RunnableLike, + options?: AddNodeOptions ): Graph { if (key.includes(CHECKPOINT_NAMESPACE_SEPARATOR)) { throw new Error( @@ -155,12 +167,15 @@ export class Graph< throw new Error(`Node \`${key}\` is reserved.`); } - this.nodes[key as unknown as N] = _coerceToRunnable( - // Account for arbitrary state due to Send API - action as RunnableLike - ); + this.nodes[key as unknown as N] = { + runnable: _coerceToRunnable( + // Account for arbitrary state due to Send API + action as RunnableLike + ), + metadata: options?.metadata, + } as NodeSpecType; - return this as Graph; + return this as Graph; } addEdge(startKey: N | typeof START, endKey: N | typeof END): this { @@ -278,7 +293,7 @@ export class Graph< }); // attach nodes, edges and branches - for (const [key, node] of Object.entries>( + for (const [key, node] of Object.entries>( this.nodes )) { compiled.attachNode(key as N, node); @@ -385,13 +400,14 @@ export class CompiledGraph< this.builder = builder; } - attachNode(key: N, node: Runnable): void { + attachNode(key: N, node: NodeSpec): void { this.channels[key] = new EphemeralValue(); this.nodes[key] = new PregelNode({ channels: [], triggers: [], + metadata: node.metadata, }) - .pipe(node) + .pipe(node.runnable) .pipe( new ChannelWrite([{ channel: key, value: PASSTHROUGH }], [TAG_HIDDEN]) ); @@ -477,14 +493,16 @@ export class CompiledGraph< END ), }; - for (const [key, node] of Object.entries(this.builder.nodes)) { + for (const [key, node] of Object.entries>( + this.builder.nodes + )) { if (config?.xray) { const subgraph = isCompiledGraph(node) ? node.getGraph({ ...config, xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray, }) - : node.getGraph(config); + : node.runnable.getGraph(config); subgraph.trimFirstNode(); subgraph.trimLastNode(); if (Object.keys(subgraph.nodes).length > 1) { @@ -496,12 +514,12 @@ export class CompiledGraph< startNodes[key] = newStartNode; } } else { - const newNode = graph.addNode(node, key); + const newNode = graph.addNode(node.runnable, key); startNodes[key] = newNode; endNodes[key] = newNode; } } else { - const newNode = graph.addNode(node, key); + const newNode = graph.addNode(node.runnable, key); startNodes[key] = newNode; endNodes[key] = newNode; } diff --git a/libs/langgraph/src/graph/state.ts b/libs/langgraph/src/graph/state.ts index 218c434a..1efc1e49 100644 --- a/libs/langgraph/src/graph/state.ts +++ b/libs/langgraph/src/graph/state.ts @@ -1,12 +1,20 @@ /* eslint-disable @typescript-eslint/no-use-before-define */ import { + _coerceToRunnable, Runnable, - RunnableConfig, RunnableLike, } from "@langchain/core/runnables"; import { BaseCheckpointSaver } from "@langchain/langgraph-checkpoint"; import { BaseChannel } from "../channels/base.js"; -import { END, CompiledGraph, Graph, START, Branch } from "./graph.js"; +import { + END, + CompiledGraph, + Graph, + START, + Branch, + AddNodeOptions, + NodeSpec, +} from "./graph.js"; import { ChannelWrite, ChannelWriteEntry, @@ -18,7 +26,12 @@ import { NamedBarrierValue } from "../channels/named_barrier_value.js"; import { EphemeralValue } from "../channels/ephemeral_value.js"; import { RunnableCallable } from "../utils.js"; import { All } from "../pregel/types.js"; -import { _isSend, Send, TAG_HIDDEN } from "../constants.js"; +import { + _isSend, + CHECKPOINT_NAMESPACE_SEPARATOR, + Send, + TAG_HIDDEN, +} from "../constants.js"; import { InvalidUpdateError } from "../errors.js"; import { AnnotationRoot, @@ -28,6 +41,7 @@ import { StateType, UpdateType, } from "./annotation.js"; +import type { RetryPolicy } from "../pregel/utils.js"; const ROOT = "__root__"; @@ -44,6 +58,19 @@ export interface StateGraphArgs { : ChannelReducers<{ __root__: Channels }>; } +export type StateGraphNodeSpec = NodeSpec< + RunInput, + RunOutput +> & { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + input?: any; + retryPolicy?: RetryPolicy; +}; + +export type StateGraphAddNodeOptions = { + retryPolicy?: RetryPolicy; +} & AddNodeOptions; + /** * A graph whose nodes communicate by reading and writing to a shared state. * Each node takes a defined `State` as input and returns a `Partial`. @@ -111,7 +138,7 @@ export class StateGraph< S = SD extends StateDefinition ? StateType : SD, U = SD extends StateDefinition ? UpdateType : Partial, N extends string = typeof START -> extends Graph { +> extends Graph> { channels: Record; // TODO: this doesn't dedupe edges as in py, so worth fixing at some point @@ -155,14 +182,40 @@ export class StateGraph< addNode( key: K, - action: RunnableLike + action: RunnableLike, + options?: StateGraphAddNodeOptions ): StateGraph { if (key in this.channels) { throw new Error( `${key} is already being used as a state attribute (a.k.a. a channel), cannot also be used as a node name.` ); } - return super.addNode(key, action) as StateGraph; + + if (key.includes(CHECKPOINT_NAMESPACE_SEPARATOR)) { + throw new Error( + `"${CHECKPOINT_NAMESPACE_SEPARATOR}" is a reserved character and is not allowed in node names.` + ); + } + this.warnIfCompiled( + `Adding a node to a graph that has already been compiled. This will not be reflected in the compiled graph.` + ); + + if (key in this.nodes) { + throw new Error(`Node \`${key}\` already present.`); + } + if (key === END || key === START) { + throw new Error(`Node \`${key}\` is reserved.`); + } + + const nodeSpec: StateGraphNodeSpec = { + runnable: _coerceToRunnable(action) as unknown as Runnable, + retryPolicy: options?.retryPolicy, + metadata: options?.metadata, + }; + + this.nodes[key as unknown as N] = nodeSpec; + + return this as StateGraph; } addEdge(startKey: typeof START | N | N[], endKey: N | typeof END): this { @@ -239,7 +292,9 @@ export class StateGraph< // attach nodes, edges and branches compiled.attachNode(START); - for (const [key, node] of Object.entries>(this.nodes)) { + for (const [key, node] of Object.entries>( + this.nodes + )) { compiled.attachNode(key as N, node); } for (const [start, end] of this.edges) { @@ -284,12 +339,9 @@ export class CompiledStateGraph< attachNode(key: typeof START, node?: never): void; - attachNode(key: N, node: Runnable): void; + attachNode(key: N, node: StateGraphNodeSpec): void; - attachNode( - key: N | typeof START, - node?: Runnable - ): void { + attachNode(key: N | typeof START, node?: StateGraphNodeSpec): void { const stateKeys = Object.keys(this.builder.channels); function getStateKey(key: keyof U, input: U) { @@ -345,7 +397,8 @@ export class CompiledStateGraph< [TAG_HIDDEN] ), ], - bound: node, + bound: node?.runnable, + retryPolicy: node?.retryPolicy, }); } } diff --git a/libs/langgraph/src/pregel/algo.ts b/libs/langgraph/src/pregel/algo.ts index dee52efb..bd067311 100644 --- a/libs/langgraph/src/pregel/algo.ts +++ b/libs/langgraph/src/pregel/algo.ts @@ -56,62 +56,6 @@ export const increment = (current?: number) => { return current !== undefined ? current + 1 : 1; }; -export async function* executeTasks( - tasks: Record< - string, - () => Promise<{ - // eslint-disable-next-line @typescript-eslint/no-explicit-any - task: PregelExecutableTask; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - result: any; - error: Error; - }> - >, - stepTimeout?: number, - signal?: AbortSignal - // eslint-disable-next-line @typescript-eslint/no-explicit-any -): AsyncGenerator> { - if (stepTimeout && signal) { - if ("any" in AbortSignal) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - signal = (AbortSignal as any).any([ - signal, - AbortSignal.timeout(stepTimeout), - ]); - } - } else if (stepTimeout) { - signal = AbortSignal.timeout(stepTimeout); - } - - // Abort if signal is aborted - signal?.throwIfAborted(); - - // Start all tasks - const executingTasks = Object.fromEntries( - Object.entries(tasks).map(([taskId, task]) => { - return [taskId, task()]; - }) - ); - let listener: () => void; - const signalPromise = new Promise((_resolve, reject) => { - listener = () => reject(new Error("Abort")); - signal?.addEventListener("abort", listener); - }).finally(() => signal?.removeEventListener("abort", listener)); - - while (Object.keys(executingTasks).length > 0) { - const { task, error } = await Promise.race([ - ...Object.values(executingTasks), - signalPromise, - ]); - if (error !== undefined) { - // TODO: don't stop others if exception is interrupt - throw error; - } - yield task; - delete executingTasks[task.id]; - } -} - export function shouldInterrupt( checkpoint: Checkpoint, interruptNodes: All | N[], @@ -426,6 +370,7 @@ export function _prepareNextTasks< } ), id: taskId, + retry_policy: proc.retryPolicy, }); } } else { @@ -521,6 +466,7 @@ export function _prepareNextTasks< } ), id: taskId, + retry_policy: proc.retryPolicy, }); } } else { diff --git a/libs/langgraph/src/pregel/index.ts b/libs/langgraph/src/pregel/index.ts index 90d3992a..21c3cc26 100644 --- a/libs/langgraph/src/pregel/index.ts +++ b/libs/langgraph/src/pregel/index.ts @@ -55,15 +55,15 @@ import { InvalidUpdateError, } from "../errors.js"; import { - executeTasks, _prepareNextTasks, _localRead, _applyWrites, StrRecord, } from "./algo.js"; import { prefixGenerator } from "../utils.js"; -import { _coerceToDict, getNewChannelVersions } from "./utils.js"; +import { _coerceToDict, getNewChannelVersions, RetryPolicy } from "./utils.js"; import { PregelLoop } from "./loop.js"; +import { executeTasksWithRetry } from "./retry.js"; type WriteValue = Runnable | RunnableFunc | unknown; @@ -223,6 +223,8 @@ export class Pregel< checkpointer?: BaseCheckpointSaver; + retryPolicy?: RetryPolicy; + constructor(fields: PregelParams) { super(fields); @@ -243,6 +245,7 @@ export class Pregel< this.stepTimeout = fields.stepTimeout ?? this.stepTimeout; this.debug = fields.debug ?? this.debug; this.checkpointer = fields.checkpointer; + this.retryPolicy = fields.retryPolicy; if (this.autoValidate) { this.validate(); @@ -686,43 +689,17 @@ export class Pregel< if (debug) { printStepTasks(loop.step, loop.tasks); } - // execute tasks, and wait for one to fail or all to finish. - // each task is independent from all other concurrent tasks - // yield updates/debug output as each task finishes - const tasks = Object.fromEntries( - loop.tasks - .filter((task) => task.writes.length === 0) - .map((pregelTask) => { - return [ - pregelTask.id, - async () => { - let error; - let result; - try { - result = await pregelTask.proc.invoke( - pregelTask.input, - pregelTask.config - ); - // eslint-disable-next-line @typescript-eslint/no-explicit-any - } catch (e: any) { - error = e; - error.pregelTaskId = pregelTask.id; - } - return { - task: pregelTask, - result, - error, - }; - }, - ]; - }) - ); - try { - for await (const task of executeTasks( - tasks, - this.stepTimeout, - config.signal + // execute tasks, and wait for one to fail or all to finish. + // each task is independent from all other concurrent tasks + // yield updates/debug output as each task finishes + for await (const task of executeTasksWithRetry( + loop.tasks.filter((task) => task.writes.length === 0), + { + stepTimeout: this.stepTimeout, + signal: config.signal, + retryPolicy: this.retryPolicy, + } )) { loop.putWrites(task.id, task.writes); if (streamMode.includes("updates")) { diff --git a/libs/langgraph/src/pregel/read.ts b/libs/langgraph/src/pregel/read.ts index c6654188..0b5aa03b 100644 --- a/libs/langgraph/src/pregel/read.ts +++ b/libs/langgraph/src/pregel/read.ts @@ -11,6 +11,7 @@ import { import { CONFIG_KEY_READ } from "../constants.js"; import { ChannelWrite } from "./write.js"; import { RunnableCallable } from "../utils.js"; +import type { RetryPolicy } from "./utils.js"; export class ChannelRead< // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -80,6 +81,8 @@ interface PregelNodeArgs // eslint-disable-next-line @typescript-eslint/no-explicit-any kwargs?: Record; config?: RunnableConfig; + metadata?: Record; + retryPolicy?: RetryPolicy; } // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -108,8 +111,21 @@ export class PregelNode< // eslint-disable-next-line @typescript-eslint/no-explicit-any kwargs: Record = {}; + metadata: Record = {}; + + retryPolicy?: RetryPolicy; + constructor(fields: PregelNodeArgs) { - const { channels, triggers, mapper, writers, bound, kwargs } = fields; + const { + channels, + triggers, + mapper, + writers, + bound, + kwargs, + metadata, + retryPolicy, + } = fields; const mergedTags = [ ...(fields.config?.tags ? fields.config.tags : []), ...(fields.tags ? fields.tags : []), @@ -132,6 +148,8 @@ export class PregelNode< this.writers = writers ?? this.writers; this.bound = bound ?? this.bound; this.kwargs = kwargs ?? this.kwargs; + this.metadata = metadata ?? this.metadata; + this.retryPolicy = retryPolicy; } getWriters(): Array { @@ -198,6 +216,7 @@ export class PregelNode< bound: this.bound, kwargs: this.kwargs, config: this.config, + retryPolicy: this.retryPolicy, }); } @@ -216,6 +235,7 @@ export class PregelNode< >, config: this.config, kwargs: this.kwargs, + retryPolicy: this.retryPolicy, }); } else if (this.bound === defaultRunnableBound) { return new PregelNode>({ @@ -226,6 +246,7 @@ export class PregelNode< bound: _coerceToRunnable(coerceable), config: this.config, kwargs: this.kwargs, + retryPolicy: this.retryPolicy, }); } else { return new PregelNode>({ @@ -236,6 +257,7 @@ export class PregelNode< bound: this.bound.pipe(coerceable), config: this.config, kwargs: this.kwargs, + retryPolicy: this.retryPolicy, }); } } diff --git a/libs/langgraph/src/pregel/retry.ts b/libs/langgraph/src/pregel/retry.ts new file mode 100644 index 00000000..a0add4e8 --- /dev/null +++ b/libs/langgraph/src/pregel/retry.ts @@ -0,0 +1,176 @@ +import { GraphInterrupt } from "../errors.js"; +import { PregelExecutableTask } from "./types.js"; +import type { RetryPolicy } from "./utils.js"; + +export const DEFAULT_INITIAL_INTERVAL = 500; +export const DEFAULT_BACKOFF_FACTOR = 2; +export const DEFAULT_MAX_INTERVAL = 128000; +export const DEFAULT_MAX_RETRIES = 3; + +const DEFAULT_STATUS_NO_RETRY = [ + 400, // Bad Request + 401, // Unauthorized + 402, // Payment Required + 403, // Forbidden + 404, // Not Found + 405, // Method Not Allowed + 406, // Not Acceptable + 407, // Proxy Authentication Required + 409, // Conflict +]; + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const DEFAULT_RETRY_ON_HANDLER = (error: any) => { + if ( + error.message.startsWith("Cancel") || + error.message.startsWith("AbortError") || + error.name === "AbortError" + ) { + return false; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + if ((error as any)?.code === "ECONNABORTED") { + return false; + } + const status = + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any)?.response?.status ?? (error as any)?.status; + if (status && DEFAULT_STATUS_NO_RETRY.includes(+status)) { + return false; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any + if ((error as any)?.error?.code === "insufficient_quota") { + return false; + } + return true; +}; + +export async function* executeTasksWithRetry( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + tasks: PregelExecutableTask[], + options?: { + stepTimeout?: number; + signal?: AbortSignal; + retryPolicy?: RetryPolicy; + } + // eslint-disable-next-line @typescript-eslint/no-explicit-any +): AsyncGenerator> { + const { stepTimeout, retryPolicy } = options ?? {}; + let signal = options?.signal; + // Start tasks + const executingTasksMap = Object.fromEntries( + tasks.map((pregelTask) => { + return [pregelTask.id, _runWithRetry(pregelTask, retryPolicy)]; + }) + ); + if (stepTimeout && signal) { + if ("any" in AbortSignal) { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + signal = (AbortSignal as any).any([ + signal, + AbortSignal.timeout(stepTimeout), + ]); + } + } else if (stepTimeout) { + signal = AbortSignal.timeout(stepTimeout); + } + + // Abort if signal is aborted + signal?.throwIfAborted(); + + let listener: () => void; + const signalPromise = new Promise((_resolve, reject) => { + listener = () => reject(new Error("Abort")); + signal?.addEventListener("abort", listener); + }).finally(() => signal?.removeEventListener("abort", listener)); + + while (Object.keys(executingTasksMap).length > 0) { + const { task, error } = await Promise.race([ + ...Object.values(executingTasksMap), + signalPromise, + ]); + if (error !== undefined) { + // TODO: don't stop others if exception is interrupt + throw error; + } + yield task; + delete executingTasksMap[task.id]; + } +} + +async function _runWithRetry( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + pregelTask: PregelExecutableTask, + retryPolicy?: RetryPolicy +) { + const resolvedRetryPolicy = pregelTask.retry_policy ?? retryPolicy; + let interval = + resolvedRetryPolicy !== undefined + ? resolvedRetryPolicy.initialInterval ?? DEFAULT_INITIAL_INTERVAL + : 0; + let attempts = 0; + let error; + let result; + // eslint-disable-next-line no-constant-condition + while (true) { + // Modify writes in place to clear any previous retries + while (pregelTask.writes.length > 0) { + pregelTask.writes.pop(); + } + error = undefined; + try { + result = await pregelTask.proc.invoke( + pregelTask.input, + pregelTask.config + ); + break; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + } catch (e: any) { + error = e; + error.pregelTaskId = pregelTask.id; + if (error.name === GraphInterrupt.unminifiable_name) { + break; + } + if (resolvedRetryPolicy === undefined) { + break; + } + attempts += 1; + // check if we should give up + if ( + attempts >= (resolvedRetryPolicy.maxAttempts ?? DEFAULT_MAX_RETRIES) + ) { + break; + } + const retryOn = resolvedRetryPolicy.retryOn ?? DEFAULT_RETRY_ON_HANDLER; + if (!retryOn(error)) { + break; + } + interval = Math.min( + resolvedRetryPolicy.maxInterval ?? DEFAULT_MAX_INTERVAL, + interval * (resolvedRetryPolicy.backoffFactor ?? DEFAULT_BACKOFF_FACTOR) + ); + const intervalWithJitter = resolvedRetryPolicy.jitter + ? Math.floor(interval + Math.random() * 1000) + : interval; + // sleep before retrying + // eslint-disable-next-line no-promise-executor-return + await new Promise((resolve) => setTimeout(resolve, intervalWithJitter)); + // log the retry + const errorName = + error.name ?? + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error.constructor as any).unminifiable_name ?? + error.constructor.name; + console.log( + `Retrying task "${pregelTask.name}" after ${interval.toFixed( + 2 + )} seconds (attempt ${attempts}) after ${errorName}: ${error}` + ); + } + } + return { + task: pregelTask, + result, + error, + }; +} diff --git a/libs/langgraph/src/pregel/types.ts b/libs/langgraph/src/pregel/types.ts index 23ed0ffb..46139b03 100644 --- a/libs/langgraph/src/pregel/types.ts +++ b/libs/langgraph/src/pregel/types.ts @@ -6,6 +6,7 @@ import type { } from "@langchain/langgraph-checkpoint"; import type { BaseChannel } from "../channels/base.js"; import type { PregelNode } from "./read.js"; +import { RetryPolicy } from "./utils.js"; export type StreamMode = "values" | "updates" | "debug"; @@ -63,6 +64,8 @@ export interface PregelInterface< debug?: boolean; checkpointer?: BaseCheckpointSaver; + + retryPolicy?: RetryPolicy; } export type PregelParams< @@ -86,7 +89,7 @@ export interface PregelExecutableTask< readonly writes: PendingWrite[]; readonly config: RunnableConfig | undefined; readonly triggers: Array; - readonly retry_policy?: string; + readonly retry_policy?: RetryPolicy; readonly id: string; } diff --git a/libs/langgraph/src/pregel/utils.ts b/libs/langgraph/src/pregel/utils.ts index 677e723c..f2adc7a8 100644 --- a/libs/langgraph/src/pregel/utils.ts +++ b/libs/langgraph/src/pregel/utils.ts @@ -50,3 +50,31 @@ export function _getIdMetadata(metadata: Record) { langgraph_task_idx: metadata.langgraph_task_idx, }; } + +export type RetryPolicy = { + /** + * Amount of time that must elapse before the first retry occurs in milliseconds. + * @default 500 + */ + initialInterval?: number; + /** + * Multiplier by which the interval increases after each retry. + * @default 2 + */ + backoffFactor?: number; + /** + * Maximum amount of time that may elapse between retries in milliseconds. + * @default 128000 + */ + maxInterval?: number; + /** + * Maximum amount of time that may elapse between retries. + * @default 3 + */ + maxAttempts?: number; + /** Whether to add random jitter to the interval between retries. */ + jitter?: boolean; + /** A function that returns True for exceptions that should trigger a retry. */ + // eslint-disable-next-line @typescript-eslint/no-explicit-any + retryOn?: (e: any) => boolean; +}; diff --git a/libs/langgraph/src/tests/pregel.test.ts b/libs/langgraph/src/tests/pregel.test.ts index c7904256..c3e61f72 100644 --- a/libs/langgraph/src/tests/pregel.test.ts +++ b/libs/langgraph/src/tests/pregel.test.ts @@ -1982,6 +1982,100 @@ it("should type-error when Channel.subscribeTo would throw at runtime", () => { }).toThrow(); }); +it("should invoke checkpoint two", async () => { + const checkpointer = new MemorySaverAssertImmutable(); // Replace with actual checkpointer implementation + const addOne = jest.fn( + (x: { total: number; input: number }) => x.total + x.input + ); + let erroredOnce = false; + let nonRetryableErrorCount = 0; + + const raiseIfAbove10 = (input: number): number => { + if (input > 4) { + if (!erroredOnce) { + erroredOnce = true; + const error = new Error("I will be retried"); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).status = 500; + throw error; + } + } + if (input > 10) { + nonRetryableErrorCount += 1; + const error = new Error("Input is too large"); + // eslint-disable-next-line @typescript-eslint/no-explicit-any + (error as any).status = 400; + throw error; + } + return input; + }; + + const one = Channel.subscribeTo(["input"]) + .join(["total"]) + .pipe(addOne) + .pipe(Channel.writeTo(["output", "total"])) + .pipe(raiseIfAbove10); + + const app = new Pregel({ + nodes: { one }, + channels: { + total: new BinaryOperatorAggregate((a, b) => a + b), + input: new LastValue(), + output: new LastValue(), + }, + inputChannels: "input", + outputChannels: "output", + checkpointer, + // Use the default policy + retryPolicy: {}, + }); + + // total starts out as 0, so output is 0+2=2 + expect(await app.invoke(2, { configurable: { thread_id: "1" } })).toBe(2); + let checkpoint = await checkpointer.get({ configurable: { thread_id: "1" } }); + expect(checkpoint).not.toBeNull(); + expect(checkpoint?.channel_values.total).toBe(2); + expect(erroredOnce).toBe(false); + expect(nonRetryableErrorCount).toBe(0); + + // total is now 2, so output is 2+3=5 + expect(await app.invoke(3, { configurable: { thread_id: "1" } })).toBe(5); + expect(erroredOnce).toBeTruthy(); + let checkpointTuple = await checkpointer.getTuple({ + configurable: { thread_id: "1" }, + }); + expect(checkpointTuple).not.toBeNull(); + expect(checkpointTuple?.checkpoint.channel_values.total).toBe(7); + expect(erroredOnce).toBe(true); + expect(nonRetryableErrorCount).toBe(0); + + // total is now 2+5=7, so output would be 7+4=11, but raises Error + await expect( + app.invoke(4, { configurable: { thread_id: "1" } }) + ).rejects.toThrow("Input is too large"); + + // checkpoint is not updated, error is recorded + checkpointTuple = await checkpointer.getTuple({ + configurable: { thread_id: "1" }, + }); + expect(checkpointTuple).not.toBeNull(); + expect(checkpointTuple?.checkpoint.channel_values.total).toBe(7); + expect(checkpointTuple?.pendingWrites).toEqual([ + [expect.any(String), "__error__", { message: "Input is too large" }], + ]); + expect(nonRetryableErrorCount).toBe(1); + + // on a new thread, total starts out as 0, so output is 0+5=5 + expect(await app.invoke(5, { configurable: { thread_id: "2" } })).toBe(5); + checkpoint = await checkpointer.get({ configurable: { thread_id: "1" } }); + expect(checkpoint).not.toBeNull(); + expect(checkpoint?.channel_values.total).toBe(7); + checkpoint = await checkpointer.get({ configurable: { thread_id: "2" } }); + expect(checkpoint).not.toBeNull(); + expect(checkpoint?.channel_values.total).toBe(5); + expect(nonRetryableErrorCount).toBe(1); +}); + describe("StateGraph", () => { class SearchAPI extends Tool { name = "search_api"; diff --git a/libs/langgraph/src/web.ts b/libs/langgraph/src/web.ts index 8eb3bba1..113d0f5f 100644 --- a/libs/langgraph/src/web.ts +++ b/libs/langgraph/src/web.ts @@ -18,6 +18,7 @@ export { InvalidUpdateError, EmptyChannelError, } from "./errors.js"; +export { type RetryPolicy } from "./pregel/utils.js"; export { Send } from "./constants.js"; export {