Skip to content

Commit

Permalink
Adds support for retry policy for nodes (#377)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Aug 26, 2024
1 parent 5f8a645 commit 6cebd17
Show file tree
Hide file tree
Showing 13 changed files with 446 additions and 129 deletions.
1 change: 0 additions & 1 deletion libs/checkpoint-mongodb/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ const checkpoint = {
}
},
pending_sends: [],
current_tasks: {}
}

// store checkpoint
Expand Down
1 change: 0 additions & 1 deletion libs/checkpoint-sqlite/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ const checkpoint = {
}
},
pending_sends: [],
current_tasks: {}
}

// store checkpoint
Expand Down
1 change: 0 additions & 1 deletion libs/checkpoint/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ const checkpoint = {
}
},
pending_sends: [],
current_tasks: {}
}

// store checkpoint
Expand Down
52 changes: 35 additions & 17 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,25 @@ export class Branch<IO, N extends string> {
}
}

export type NodeSpec<RunInput, RunOutput> = {
runnable: Runnable<RunInput, RunOutput>;
metadata?: Record<string, unknown>;
};

export type AddNodeOptions = { metadata?: Record<string, unknown> };

export class Graph<
N extends string = typeof END,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunInput = any,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput = any
RunOutput = any,
NodeSpecType extends NodeSpec<RunInput, RunOutput> = NodeSpec<
RunInput,
RunOutput
>
> {
nodes: Record<N, Runnable<RunInput, RunOutput>>;
nodes: Record<N, NodeSpecType>;

edges: Set<[N | typeof START, N | typeof END]>;

Expand All @@ -120,12 +131,12 @@ export class Graph<
supportMultipleEdges = false;

constructor() {
this.nodes = {} as Record<N, Runnable<RunInput, RunOutput>>;
this.nodes = {} as Record<N, NodeSpecType>;
this.edges = new Set();
this.branches = {};
}

private warnIfCompiled(message: string): void {
protected warnIfCompiled(message: string): void {
if (this.compiled) {
console.warn(message);
}
Expand All @@ -137,7 +148,8 @@ export class Graph<

addNode<K extends string, NodeInput = RunInput>(
key: K,
action: RunnableLike<NodeInput, RunOutput>
action: RunnableLike<NodeInput, RunOutput>,
options?: AddNodeOptions
): Graph<N | K, RunInput, RunOutput> {
if (key.includes(CHECKPOINT_NAMESPACE_SEPARATOR)) {
throw new Error(
Expand All @@ -155,12 +167,15 @@ export class Graph<
throw new Error(`Node \`${key}\` is reserved.`);
}

this.nodes[key as unknown as N] = _coerceToRunnable<RunInput, RunOutput>(
// Account for arbitrary state due to Send API
action as RunnableLike<RunInput, RunOutput>
);
this.nodes[key as unknown as N] = {
runnable: _coerceToRunnable<RunInput, RunOutput>(
// Account for arbitrary state due to Send API
action as RunnableLike<RunInput, RunOutput>
),
metadata: options?.metadata,
} as NodeSpecType;

return this as Graph<N | K, RunInput, RunOutput>;
return this as Graph<N | K, RunInput, RunOutput, NodeSpecType>;
}

addEdge(startKey: N | typeof START, endKey: N | typeof END): this {
Expand Down Expand Up @@ -278,7 +293,7 @@ export class Graph<
});

// attach nodes, edges and branches
for (const [key, node] of Object.entries<Runnable<RunInput, RunOutput>>(
for (const [key, node] of Object.entries<NodeSpec<RunInput, RunOutput>>(
this.nodes
)) {
compiled.attachNode(key as N, node);
Expand Down Expand Up @@ -385,13 +400,14 @@ export class CompiledGraph<
this.builder = builder;
}

attachNode(key: N, node: Runnable<RunInput, RunOutput>): void {
attachNode(key: N, node: NodeSpec<RunInput, RunOutput>): void {
this.channels[key] = new EphemeralValue();
this.nodes[key] = new PregelNode({
channels: [],
triggers: [],
metadata: node.metadata,
})
.pipe(node)
.pipe(node.runnable)
.pipe(
new ChannelWrite([{ channel: key, value: PASSTHROUGH }], [TAG_HIDDEN])
);
Expand Down Expand Up @@ -477,14 +493,16 @@ export class CompiledGraph<
END
),
};
for (const [key, node] of Object.entries<Runnable>(this.builder.nodes)) {
for (const [key, node] of Object.entries<NodeSpec<unknown, unknown>>(
this.builder.nodes
)) {
if (config?.xray) {
const subgraph = isCompiledGraph(node)
? node.getGraph({
...config,
xray: typeof xray === "number" && xray > 0 ? xray - 1 : xray,
})
: node.getGraph(config);
: node.runnable.getGraph(config);
subgraph.trimFirstNode();
subgraph.trimLastNode();
if (Object.keys(subgraph.nodes).length > 1) {
Expand All @@ -496,12 +514,12 @@ export class CompiledGraph<
startNodes[key] = newStartNode;
}
} else {
const newNode = graph.addNode(node, key);
const newNode = graph.addNode(node.runnable, key);
startNodes[key] = newNode;
endNodes[key] = newNode;
}
} else {
const newNode = graph.addNode(node, key);
const newNode = graph.addNode(node.runnable, key);
startNodes[key] = newNode;
endNodes[key] = newNode;
}
Expand Down
79 changes: 66 additions & 13 deletions libs/langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
/* eslint-disable @typescript-eslint/no-use-before-define */
import {
_coerceToRunnable,
Runnable,
RunnableConfig,
RunnableLike,
} from "@langchain/core/runnables";
import { BaseCheckpointSaver } from "@langchain/langgraph-checkpoint";
import { BaseChannel } from "../channels/base.js";
import { END, CompiledGraph, Graph, START, Branch } from "./graph.js";
import {
END,
CompiledGraph,
Graph,
START,
Branch,
AddNodeOptions,
NodeSpec,
} from "./graph.js";
import {
ChannelWrite,
ChannelWriteEntry,
Expand All @@ -18,7 +26,12 @@ import { NamedBarrierValue } from "../channels/named_barrier_value.js";
import { EphemeralValue } from "../channels/ephemeral_value.js";
import { RunnableCallable } from "../utils.js";
import { All } from "../pregel/types.js";
import { _isSend, Send, TAG_HIDDEN } from "../constants.js";
import {
_isSend,
CHECKPOINT_NAMESPACE_SEPARATOR,
Send,
TAG_HIDDEN,
} from "../constants.js";
import { InvalidUpdateError } from "../errors.js";
import {
AnnotationRoot,
Expand All @@ -28,6 +41,7 @@ import {
StateType,
UpdateType,
} from "./annotation.js";
import type { RetryPolicy } from "../pregel/utils.js";

const ROOT = "__root__";

Expand All @@ -44,6 +58,19 @@ export interface StateGraphArgs<Channels extends object | unknown> {
: ChannelReducers<{ __root__: Channels }>;
}

export type StateGraphNodeSpec<RunInput, RunOutput> = NodeSpec<
RunInput,
RunOutput
> & {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
input?: any;
retryPolicy?: RetryPolicy;
};

export type StateGraphAddNodeOptions = {
retryPolicy?: RetryPolicy;
} & AddNodeOptions;

/**
* A graph whose nodes communicate by reading and writing to a shared state.
* Each node takes a defined `State` as input and returns a `Partial<State>`.
Expand Down Expand Up @@ -111,7 +138,7 @@ export class StateGraph<
S = SD extends StateDefinition ? StateType<SD> : SD,
U = SD extends StateDefinition ? UpdateType<SD> : Partial<S>,
N extends string = typeof START
> extends Graph<N, S, U> {
> extends Graph<N, S, U, StateGraphNodeSpec<S, U>> {
channels: Record<string, BaseChannel>;

// TODO: this doesn't dedupe edges as in py, so worth fixing at some point
Expand Down Expand Up @@ -155,14 +182,40 @@ export class StateGraph<

addNode<K extends string, NodeInput = S>(
key: K,
action: RunnableLike<NodeInput, U>
action: RunnableLike<NodeInput, U>,
options?: StateGraphAddNodeOptions
): StateGraph<SD, S, U, N | K> {
if (key in this.channels) {
throw new Error(
`${key} is already being used as a state attribute (a.k.a. a channel), cannot also be used as a node name.`
);
}
return super.addNode(key, action) as StateGraph<SD, S, U, N | K>;

if (key.includes(CHECKPOINT_NAMESPACE_SEPARATOR)) {
throw new Error(
`"${CHECKPOINT_NAMESPACE_SEPARATOR}" is a reserved character and is not allowed in node names.`
);
}
this.warnIfCompiled(
`Adding a node to a graph that has already been compiled. This will not be reflected in the compiled graph.`
);

if (key in this.nodes) {
throw new Error(`Node \`${key}\` already present.`);
}
if (key === END || key === START) {
throw new Error(`Node \`${key}\` is reserved.`);
}

const nodeSpec: StateGraphNodeSpec<S, U> = {
runnable: _coerceToRunnable(action) as unknown as Runnable<S, U>,
retryPolicy: options?.retryPolicy,
metadata: options?.metadata,
};

this.nodes[key as unknown as N] = nodeSpec;

return this as StateGraph<SD, S, U, N | K>;
}

addEdge(startKey: typeof START | N | N[], endKey: N | typeof END): this {
Expand Down Expand Up @@ -239,7 +292,9 @@ export class StateGraph<

// attach nodes, edges and branches
compiled.attachNode(START);
for (const [key, node] of Object.entries<Runnable<S, U>>(this.nodes)) {
for (const [key, node] of Object.entries<StateGraphNodeSpec<S, U>>(
this.nodes
)) {
compiled.attachNode(key as N, node);
}
for (const [start, end] of this.edges) {
Expand Down Expand Up @@ -284,12 +339,9 @@ export class CompiledStateGraph<

attachNode(key: typeof START, node?: never): void;

attachNode(key: N, node: Runnable<S, U, RunnableConfig>): void;
attachNode(key: N, node: StateGraphNodeSpec<S, U>): void;

attachNode(
key: N | typeof START,
node?: Runnable<S, U, RunnableConfig>
): void {
attachNode(key: N | typeof START, node?: StateGraphNodeSpec<S, U>): void {
const stateKeys = Object.keys(this.builder.channels);

function getStateKey(key: keyof U, input: U) {
Expand Down Expand Up @@ -345,7 +397,8 @@ export class CompiledStateGraph<
[TAG_HIDDEN]
),
],
bound: node,
bound: node?.runnable,
retryPolicy: node?.retryPolicy,
});
}
}
Expand Down
58 changes: 2 additions & 56 deletions libs/langgraph/src/pregel/algo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,62 +56,6 @@ export const increment = (current?: number) => {
return current !== undefined ? current + 1 : 1;
};

export async function* executeTasks(
tasks: Record<
string,
() => Promise<{
// eslint-disable-next-line @typescript-eslint/no-explicit-any
task: PregelExecutableTask<any, any>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
result: any;
error: Error;
}>
>,
stepTimeout?: number,
signal?: AbortSignal
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): AsyncGenerator<PregelExecutableTask<any, any>> {
if (stepTimeout && signal) {
if ("any" in AbortSignal) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
signal = (AbortSignal as any).any([
signal,
AbortSignal.timeout(stepTimeout),
]);
}
} else if (stepTimeout) {
signal = AbortSignal.timeout(stepTimeout);
}

// Abort if signal is aborted
signal?.throwIfAborted();

// Start all tasks
const executingTasks = Object.fromEntries(
Object.entries(tasks).map(([taskId, task]) => {
return [taskId, task()];
})
);
let listener: () => void;
const signalPromise = new Promise<never>((_resolve, reject) => {
listener = () => reject(new Error("Abort"));
signal?.addEventListener("abort", listener);
}).finally(() => signal?.removeEventListener("abort", listener));

while (Object.keys(executingTasks).length > 0) {
const { task, error } = await Promise.race([
...Object.values(executingTasks),
signalPromise,
]);
if (error !== undefined) {
// TODO: don't stop others if exception is interrupt
throw error;
}
yield task;
delete executingTasks[task.id];
}
}

export function shouldInterrupt<N extends PropertyKey, C extends PropertyKey>(
checkpoint: Checkpoint,
interruptNodes: All | N[],
Expand Down Expand Up @@ -426,6 +370,7 @@ export function _prepareNextTasks<
}
),
id: taskId,
retry_policy: proc.retryPolicy,
});
}
} else {
Expand Down Expand Up @@ -521,6 +466,7 @@ export function _prepareNextTasks<
}
),
id: taskId,
retry_policy: proc.retryPolicy,
});
}
} else {
Expand Down
Loading

0 comments on commit 6cebd17

Please sign in to comment.