Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(langgraph): Fix validation and drawing for edgeless graphs #723

Merged
merged 2 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,34 @@ export type Interrupt = {
};

export type CommandParams<R> = {
/**
* Value to resume execution with. To be used together with {@link interrupt}.
*/
resume?: R;
/**
* Graph to send the command to. Supported values are:
* - None: the current graph (default)
* - GraphCommand.PARENT: closest parent graph
*/
graph?: string;
/**
* Update to apply to the graph's state.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
update?: Record<string, any>;
/**
* Can be one of the following:
* - name of the node to navigate to next (any node that belongs to the specified `graph`)
* - sequence of node names to navigate to next
* - `Send` object (to execute a node with the input provided)
* - sequence of `Send` objects
*/
goto?: string | Send | (string | Send)[];
};

/**
* One or more commands to update the graph's state and send messages to nodes.
*/
export class Command<R = unknown> {
lg_name = "Command";

Expand Down
24 changes: 24 additions & 0 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ export type NodeSpec<RunInput, RunOutput> = {
metadata?: Record<string, unknown>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
subgraphs?: Pregel<any, any>[];
ends?: string[];
};

export type AddNodeOptions = {
metadata?: Record<string, unknown>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
subgraphs?: Pregel<any, any>[];
ends?: string[];
};

export class Graph<
Expand Down Expand Up @@ -243,6 +245,7 @@ export class Graph<
runnable,
metadata: options?.metadata,
subgraphs: isPregelLike(runnable) ? [runnable] : options?.subgraphs,
ends: options?.ends,
} as NodeSpecType;

return this as Graph<N | K, RunInput, RunOutput, NodeSpecType>;
Expand Down Expand Up @@ -452,6 +455,11 @@ export class Graph<
}
}
}
for (const node of Object.values<NodeSpecType>(this.nodes)) {
for (const target of node.ends ?? []) {
allTargets.add(target);
}
}
// validate targets
for (const node of Object.keys(this.nodes)) {
if (!allTargets.has(node)) {
Expand Down Expand Up @@ -519,6 +527,7 @@ export class CompiledGraph<
triggers: [],
metadata: node.metadata,
subgraphs: node.subgraphs,
ends: node.ends,
})
.pipe(node.runnable)
.pipe(
Expand Down Expand Up @@ -758,6 +767,21 @@ export class CompiledGraph<
}
}
}
for (const [key, node] of Object.entries(this.builder.nodes) as [
N,
NodeSpec<State, Update>
][]) {
if (node.ends !== undefined) {
for (const end of node.ends) {
addEdge(
_escapeMermaidKeywords(key),
_escapeMermaidKeywords(end),
undefined,
true
);
}
}
}
return graph;
}

Expand Down
2 changes: 2 additions & 0 deletions libs/langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ export class StateGraph<
? // eslint-disable-next-line @typescript-eslint/no-explicit-any
[runnable as any]
: options?.subgraphs,
ends: options?.ends,
};

this.nodes[key as unknown as N] = nodeSpec;
Expand Down Expand Up @@ -603,6 +604,7 @@ export class CompiledStateGraph<
metadata: node?.metadata,
retryPolicy: node?.retryPolicy,
subgraphs: node?.subgraphs,
ends: node?.ends,
});
}
}
Expand Down
5 changes: 5 additions & 0 deletions libs/langgraph/src/pregel/read.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ interface PregelNodeArgs<RunInput, RunOutput>
metadata?: Record<string, unknown>;
retryPolicy?: RetryPolicy;
subgraphs?: Runnable[];
ends?: string[];
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down Expand Up @@ -120,6 +121,8 @@ export class PregelNode<

subgraphs?: Runnable[];

ends?: string[];

constructor(fields: PregelNodeArgs<RunInput, RunOutput>) {
const {
channels,
Expand All @@ -132,6 +135,7 @@ export class PregelNode<
retryPolicy,
tags,
subgraphs,
ends,
} = fields;
const mergedTags = [
...(fields.config?.tags ? fields.config.tags : []),
Expand Down Expand Up @@ -159,6 +163,7 @@ export class PregelNode<
this.tags = mergedTags;
this.retryPolicy = retryPolicy;
this.subgraphs = subgraphs;
this.ends = ends;
}

getWriters(): Array<Runnable> {
Expand Down
59 changes: 59 additions & 0 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,65 @@ export function runPregelTests(
expect(res).toEqual({ items: ["0", "1", "2", "2", "3"] });
});

it("should support a simple edgeless graph", async () => {
const StateAnnotation = Annotation.Root({
foo: Annotation<string>,
});

const nodeA = async (state: typeof StateAnnotation.State) => {
console.log("Called A");
const goto = state.foo === "foo" ? "nodeB" : "nodeC";
return new Command({
update: {
foo: "a",
},
goto,
});
};

const nodeB = async (state: typeof StateAnnotation.State) => {
console.log("Called B");
return {
foo: state.foo + "|b",
};
};

const nodeC = async (state: typeof StateAnnotation.State) => {
console.log("Called C");
return {
foo: state.foo + "|c",
};
};

const graph = new StateGraph(StateAnnotation)
.addNode("nodeA", nodeA, {
ends: ["nodeB", "nodeC"],
})
.addNode("nodeB", nodeB)
.addNode("nodeC", nodeC)
.addEdge("__start__", "nodeA")
.compile();

const drawableGraph = await graph.getGraphAsync();
const mermaid = drawableGraph.drawMermaid();
// console.log(mermaid);
expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
__start__([<p>__start__</p>]):::first
nodeA(nodeA)
nodeB(nodeB)
nodeC(nodeC)
__start__ --> nodeA;
nodeA -.-> nodeB;
nodeA -.-> nodeC;
classDef default fill:#f2f0ff,line-height:1.2;
classDef first fill-opacity:0;
classDef last fill:#bfb6fc;
`);
expect(await graph.invoke({ foo: "foo" })).toEqual({ foo: "a|b" });
expect(await graph.invoke({ foo: "" })).toEqual({ foo: "a|c" });
});

it("should handle send sequences correctly", async () => {
const StateAnnotation = Annotation.Root({
items: Annotation<any[]>({
Expand Down
Loading