Skip to content

Commit

Permalink
feat(langgraph): add subgraph add node option to explicitly specify s…
Browse files Browse the repository at this point in the history
…ubgraphs (langchain-ai#620)
  • Loading branch information
dqbd authored Oct 22, 2024
1 parent 9d10d27 commit 90b39ce
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 111 deletions.
21 changes: 16 additions & 5 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import { gatherIteratorSync, RunnableCallable } from "../utils.js";
import { InvalidUpdateError, NodeInterrupt } from "../errors.js";
import { StateDefinition, StateType } from "./annotation.js";
import type { LangGraphRunnableConfig } from "../pregel/runnable_types.js";
import { isPregelLike } from "../pregel/utils/subgraph.js";

/** Special reserved node name denoting the start of a graph. */
export const START = "__start__";
Expand Down Expand Up @@ -128,9 +129,15 @@ export class Branch<
export type NodeSpec<RunInput, RunOutput> = {
runnable: Runnable<RunInput, RunOutput>;
metadata?: Record<string, unknown>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
subgraphs?: Pregel<any, any>[];
};

export type AddNodeOptions = { metadata?: Record<string, unknown> };
export type AddNodeOptions = {
metadata?: Record<string, unknown>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
subgraphs?: Pregel<any, any>[];
};

export class Graph<
N extends string = typeof END,
Expand Down Expand Up @@ -202,12 +209,15 @@ export class Graph<
throw new Error(`Node \`${key}\` is reserved.`);
}

const runnable = _coerceToRunnable<RunInput, RunOutput>(
// Account for arbitrary state due to Send API
action as RunnableLike<RunInput, RunOutput>
);

this.nodes[key as unknown as N] = {
runnable: _coerceToRunnable<RunInput, RunOutput>(
// Account for arbitrary state due to Send API
action as RunnableLike<RunInput, RunOutput>
),
runnable,
metadata: options?.metadata,
subgraphs: isPregelLike(runnable) ? [runnable] : options?.subgraphs,
} as NodeSpecType;

return this as Graph<N | K, RunInput, RunOutput, NodeSpecType>;
Expand Down Expand Up @@ -462,6 +472,7 @@ export class CompiledGraph<
channels: [],
triggers: [],
metadata: node.metadata,
subgraphs: node.subgraphs,
})
.pipe(node.runnable)
.pipe(
Expand Down
10 changes: 9 additions & 1 deletion libs/langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import {
import type { RetryPolicy } from "../pregel/utils/index.js";
import { isConfiguredManagedValue, ManagedValueSpec } from "../managed/base.js";
import type { LangGraphRunnableConfig } from "../pregel/runnable_types.js";
import { isPregelLike } from "../pregel/utils/subgraph.js";

const ROOT = "__root__";

Expand Down Expand Up @@ -314,11 +315,17 @@ export class StateGraph<
if (options?.input !== undefined) {
this._addSchema(options.input.spec);
}

const runnable = _coerceToRunnable(action) as unknown as Runnable<S, U>;
const nodeSpec: StateGraphNodeSpec<S, U> = {
runnable: _coerceToRunnable(action) as unknown as Runnable<S, U>,
runnable,
retryPolicy: options?.retryPolicy,
metadata: options?.metadata,
input: options?.input?.spec ?? this._schemaDefinition,
subgraphs: isPregelLike(runnable)
? // eslint-disable-next-line @typescript-eslint/no-explicit-any
[runnable as any]
: options?.subgraphs,
};

this.nodes[key as unknown as N] = nodeSpec;
Expand Down Expand Up @@ -541,6 +548,7 @@ export class CompiledStateGraph<
bound: node?.runnable,
metadata: node?.metadata,
retryPolicy: node?.retryPolicy,
subgraphs: node?.subgraphs,
});
}
}
Expand Down
2 changes: 2 additions & 0 deletions libs/langgraph/src/pregel/algo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ export function _prepareSingleTask<
name: packet.node,
input: packet.args,
proc: node,
subgraphs: proc.subgraphs,
writes,
config: patchConfig(
mergeConfigs(config, {
Expand Down Expand Up @@ -665,6 +666,7 @@ export function _prepareSingleTask<
name,
input: val,
proc: node,
subgraphs: proc.subgraphs,
writes,
config: patchConfig(
mergeConfigs(config, {
Expand Down
3 changes: 2 additions & 1 deletion libs/langgraph/src/pregel/debug.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ export function* mapDebugCheckpoint<
const taskStates: Record<string, RunnableConfig | StateSnapshot> = {};

for (const task of tasks) {
if (!findSubgraphPregel(task.proc)) continue;
const candidates = task.subgraphs?.length ? task.subgraphs : [task.proc];
if (!candidates.find(findSubgraphPregel)) continue;

let taskNs = `${task.name as string}:${task.id}`;
if (parentNs) taskNs = `${parentNs}|${taskNs}`;
Expand Down
51 changes: 29 additions & 22 deletions libs/langgraph/src/pregel/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -350,29 +350,36 @@ export class Pregel<
// find the subgraph if any
// eslint-disable-next-line @typescript-eslint/no-explicit-any
type SubgraphPregelType = Pregel<any, any> | undefined;
const graph = findSubgraphPregel(node.bound) as SubgraphPregelType;
// if found, yield recursively
if (graph !== undefined) {
if (name === namespace) {
yield [name, graph];
return;
}
if (namespace === undefined) {
yield [name, graph];
}
if (recurse) {
let newNamespace = namespace;
if (namespace !== undefined) {
newNamespace = namespace.slice(name.length + 1);

const candidates = node.subgraphs?.length ? node.subgraphs : [node.bound];

for (const candidate of candidates) {
const graph = findSubgraphPregel(candidate) as SubgraphPregelType;

if (graph !== undefined) {
if (name === namespace) {
yield [name, graph];
return;
}
for (const [subgraphName, subgraph] of graph.getSubgraphs(
newNamespace,
recurse
)) {
yield [
`${name}${CHECKPOINT_NAMESPACE_SEPARATOR}${subgraphName}`,
subgraph,
];

if (namespace === undefined) {
yield [name, graph];
}

if (recurse) {
let newNamespace = namespace;
if (namespace !== undefined) {
newNamespace = namespace.slice(name.length + 1);
}
for (const [subgraphName, subgraph] of graph.getSubgraphs(
newNamespace,
recurse
)) {
yield [
`${name}${CHECKPOINT_NAMESPACE_SEPARATOR}${subgraphName}`,
subgraph,
];
}
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions libs/langgraph/src/pregel/read.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ interface PregelNodeArgs<RunInput, RunOutput>
config?: RunnableConfig;
metadata?: Record<string, unknown>;
retryPolicy?: RetryPolicy;
subgraphs?: Runnable[];
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down Expand Up @@ -117,6 +118,8 @@ export class PregelNode<

retryPolicy?: RetryPolicy;

subgraphs?: Runnable[];

constructor(fields: PregelNodeArgs<RunInput, RunOutput>) {
const {
channels,
Expand All @@ -128,6 +131,7 @@ export class PregelNode<
metadata,
retryPolicy,
tags,
subgraphs,
} = fields;
const mergedTags = [
...(fields.config?.tags ? fields.config.tags : []),
Expand All @@ -154,6 +158,7 @@ export class PregelNode<
this.metadata = metadata ?? this.metadata;
this.tags = mergedTags;
this.retryPolicy = retryPolicy;
this.subgraphs = subgraphs;
}

getWriters(): Array<Runnable> {
Expand Down
1 change: 1 addition & 0 deletions libs/langgraph/src/pregel/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export interface PregelExecutableTask<
readonly retry_policy?: RetryPolicy;
readonly id: string;
readonly path?: [string, ...(string | number)[]];
readonly subgraphs?: Runnable[];
}

export interface StateSnapshot {
Expand Down
11 changes: 8 additions & 3 deletions libs/langgraph/src/pregel/utils/subgraph.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import { RunnableSequence, Runnable } from "@langchain/core/runnables";
import {
RunnableSequence,
Runnable,
RunnableLike,
} from "@langchain/core/runnables";
import type { PregelInterface } from "../types.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand All @@ -7,9 +11,10 @@ function isRunnableSequence(
): x is RunnableSequence {
return "steps" in x && Array.isArray(x.steps);
}
function isPregelLike(

export function isPregelLike(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
x: PregelInterface<any, any> | Runnable
x: PregelInterface<any, any> | RunnableLike<any, any, any>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
): x is PregelInterface<any, any> {
return (
Expand Down
Loading

0 comments on commit 90b39ce

Please sign in to comment.