Skip to content

Commit

Permalink
Fix validation and drawing for edgeless graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Dec 10, 2024
1 parent b6063a9 commit 89b0feb
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 1 deletion.
21 changes: 21 additions & 0 deletions libs/langgraph/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,34 @@ export type Interrupt = {
};

export type CommandParams<R> = {
/**
* Value to resume execution with. To be used together with {@link interrupt}.
*/
resume?: R;
/**
* Graph to send the command to. Supported values are:
* - None: the current graph (default)
* - GraphCommand.PARENT: closest parent graph
*/
graph?: string;
/**
* Update to apply to the graph's state.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
update?: Record<string, any>;
/**
* Can be one of the following:
* - name of the node to navigate to next (any node that belongs to the specified `graph`)
* - sequence of node names to navigate to next
* - `Send` object (to execute a node with the input provided)
* - sequence of `Send` objects
*/
goto?: string | Send | (string | Send)[];
};

/**
* One or more commands to update the graph's state and send messages to nodes.
*/
export class Command<R = unknown> {
lg_name = "Command";

Expand Down
16 changes: 16 additions & 0 deletions libs/langgraph/src/graph/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,14 @@ export type NodeSpec<RunInput, RunOutput> = {
metadata?: Record<string, unknown>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
subgraphs?: Pregel<any, any>[];
ends?: string[];
};

export type AddNodeOptions = {
metadata?: Record<string, unknown>;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
subgraphs?: Pregel<any, any>[];
ends?: string[];
};

export class Graph<
Expand Down Expand Up @@ -243,6 +245,7 @@ export class Graph<
runnable,
metadata: options?.metadata,
subgraphs: isPregelLike(runnable) ? [runnable] : options?.subgraphs,
ends: options?.ends,
} as NodeSpecType;

return this as Graph<N | K, RunInput, RunOutput, NodeSpecType>;
Expand Down Expand Up @@ -452,6 +455,11 @@ export class Graph<
}
}
}
for (const node of Object.values<NodeSpecType>(this.nodes)) {
for (const target of node.ends ?? []) {
allTargets.add(target);
}
}
// validate targets
for (const node of Object.keys(this.nodes)) {
if (!allTargets.has(node)) {
Expand Down Expand Up @@ -519,6 +527,7 @@ export class CompiledGraph<
triggers: [],
metadata: node.metadata,
subgraphs: node.subgraphs,
ends: node.ends,
})
.pipe(node.runnable)
.pipe(
Expand Down Expand Up @@ -758,6 +767,13 @@ export class CompiledGraph<
}
}
}
for (const [key, node] of Object.entries(this.builder.nodes) as [N, NodeSpec<State, Update>][]) {
if (node.ends !== undefined) {
for (const end of node.ends) {
addEdge(_escapeMermaidKeywords(key), _escapeMermaidKeywords(end), undefined, true);
}
}
}
return graph;
}

Expand Down
6 changes: 5 additions & 1 deletion libs/langgraph/src/graph/state.ts
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ export class StateGraph<
? // eslint-disable-next-line @typescript-eslint/no-explicit-any
[runnable as any]
: options?.subgraphs,
ends: options?.ends,
};

this.nodes[key as unknown as N] = nodeSpec;
Expand Down Expand Up @@ -439,7 +440,9 @@ export class StateGraph<
compiled.attachBranch(START, SELF, _getControlBranch() as Branch<S, N>, {
withReader: false,
});
for (const [key] of Object.entries<StateGraphNodeSpec<S, U>>(this.nodes)) {
for (const [key] of Object.entries<StateGraphNodeSpec<S, U>>(
this.nodes
)) {
compiled.attachBranch(
key as N,
SELF,
Expand Down Expand Up @@ -603,6 +606,7 @@ export class CompiledStateGraph<
metadata: node?.metadata,
retryPolicy: node?.retryPolicy,
subgraphs: node?.subgraphs,
ends: node?.ends,
});
}
}
Expand Down
5 changes: 5 additions & 0 deletions libs/langgraph/src/pregel/read.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ interface PregelNodeArgs<RunInput, RunOutput>
metadata?: Record<string, unknown>;
retryPolicy?: RetryPolicy;
subgraphs?: Runnable[];
ends?: string[];
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
Expand Down Expand Up @@ -120,6 +121,8 @@ export class PregelNode<

subgraphs?: Runnable[];

ends?: string[];

constructor(fields: PregelNodeArgs<RunInput, RunOutput>) {
const {
channels,
Expand All @@ -132,6 +135,7 @@ export class PregelNode<
retryPolicy,
tags,
subgraphs,
ends,
} = fields;
const mergedTags = [
...(fields.config?.tags ? fields.config.tags : []),
Expand Down Expand Up @@ -159,6 +163,7 @@ export class PregelNode<
this.tags = mergedTags;
this.retryPolicy = retryPolicy;
this.subgraphs = subgraphs;
this.ends = ends;
}

getWriters(): Array<Runnable> {
Expand Down
59 changes: 59 additions & 0 deletions libs/langgraph/src/tests/pregel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,65 @@ export function runPregelTests(
expect(res).toEqual({ items: ["0", "1", "2", "2", "3"] });
});

it("should support a simple edgeless graph", async () => {
const StateAnnotation = Annotation.Root({
foo: Annotation<string>,
});

const nodeA = async (state: typeof StateAnnotation.State) => {
console.log("Called A");
const goto = state.foo === "foo" ? "nodeB" : "nodeC";
return new Command({
update: {
foo: "a",
},
goto,
});
};

const nodeB = async (state: typeof StateAnnotation.State) => {
console.log("Called B");
return {
foo: state.foo + "|b",
};
};

const nodeC = async (state: typeof StateAnnotation.State) => {
console.log("Called C");
return {
foo: state.foo + "|c",
};
};

const graph = new StateGraph(StateAnnotation)
.addNode("nodeA", nodeA, {
ends: ["nodeB", "nodeC"],
})
.addNode("nodeB", nodeB)
.addNode("nodeC", nodeC)
.addEdge("__start__", "nodeA")
.compile();

const drawableGraph = await graph.getGraphAsync();
const mermaid = drawableGraph.drawMermaid();
// console.log(mermaid);
expect(mermaid).toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
__start__([<p>__start__</p>]):::first
nodeA(nodeA)
nodeB(nodeB)
nodeC(nodeC)
__start__ --> nodeA;
nodeA -.-> nodeB;
nodeA -.-> nodeC;
classDef default fill:#f2f0ff,line-height:1.2;
classDef first fill-opacity:0;
classDef last fill:#bfb6fc;
`)
expect(await graph.invoke({ foo: "foo" })).toEqual({ foo: "a|b" });
expect(await graph.invoke({ foo: "" })).toEqual({ foo: "a|c" });
});

it("should handle send sequences correctly", async () => {
const StateAnnotation = Annotation.Root({
items: Annotation<any[]>({
Expand Down

0 comments on commit 89b0feb

Please sign in to comment.