diff --git a/langchain-core/src/runnables/graph.ts b/langchain-core/src/runnables/graph.ts index 854688e95ecf..8521dd334a79 100644 --- a/langchain-core/src/runnables/graph.ts +++ b/langchain-core/src/runnables/graph.ts @@ -1,19 +1,13 @@ import { zodToJsonSchema } from "zod-to-json-schema"; import { v4 as uuidv4, validate as isUuid } from "uuid"; -import type { RunnableInterface, RunnableIOSchema } from "./types.js"; +import type { + RunnableInterface, + RunnableIOSchema, + Node, + Edge, +} from "./types.js"; import { isRunnableInterface } from "./utils.js"; - -interface Edge { - source: string; - target: string; - data?: string; -} - -interface Node { - id: string; - - data: RunnableIOSchema | RunnableInterface; -} +import { drawMermaid, drawMermaidPng } from "./graph_mermaid.js"; const MAX_DATA_DISPLAY_NAME_LENGTH = 42; @@ -22,17 +16,12 @@ export function nodeDataStr(node: Node): string { return node.id; } else if (isRunnableInterface(node.data)) { try { - let data = node.data.toString(); - if ( - data.startsWith("<") || - data[0] !== data[0].toUpperCase() || - data.split("\n").length > 1 - ) { - data = node.data.getName(); - } else if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) { + let data = node.data.getName(); + data = data.startsWith("Runnable") ? data.slice("Runnable".length) : data; + if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) { data = `${data.substring(0, MAX_DATA_DISPLAY_NAME_LENGTH)}...`; } - return data.startsWith("Runnable") ? data.slice("Runnable".length) : data; + return data; } catch (error) { return node.data.getName(); } @@ -179,4 +168,50 @@ export class Graph { } } } + + drawMermaid(params?: { + withStyles?: boolean; + curveStyle?: string; + nodeColors?: Record; + wrapLabelNWords?: number; + }): string { + const { + withStyles, + curveStyle, + nodeColors = { start: "#ffdfba", end: "#baffc9", other: "#fad7de" }, + wrapLabelNWords, + } = params ?? {}; + const nodes: Record = {}; + for (const node of Object.values(this.nodes)) { + nodes[node.id] = nodeDataStr(node); + } + + const firstNode = this.firstNode(); + const firstNodeLabel = firstNode ? nodeDataStr(firstNode) : undefined; + + const lastNode = this.lastNode(); + const lastNodeLabel = lastNode ? nodeDataStr(lastNode) : undefined; + + return drawMermaid(nodes, this.edges, { + firstNodeLabel, + lastNodeLabel, + withStyles, + curveStyle, + nodeColors, + wrapLabelNWords, + }); + } + + async drawMermaidPng(params?: { + withStyles?: boolean; + curveStyle?: string; + nodeColors?: Record; + wrapLabelNWords?: number; + backgroundColor?: string; + }): Promise { + const mermaidSyntax = this.drawMermaid(params); + return drawMermaidPng(mermaidSyntax, { + backgroundColor: params?.backgroundColor, + }); + } } diff --git a/langchain-core/src/runnables/graph_mermaid.ts b/langchain-core/src/runnables/graph_mermaid.ts new file mode 100644 index 000000000000..bd7a7a5c5026 --- /dev/null +++ b/langchain-core/src/runnables/graph_mermaid.ts @@ -0,0 +1,177 @@ +import { Edge } from "./types.js"; + +function _escapeNodeLabel(nodeLabel: string): string { + // Escapes the node label for Mermaid syntax. + return nodeLabel.replace(/[^a-zA-Z-_0-9]/g, "_"); +} + +// Adjusts Mermaid edge to map conditional nodes to pure nodes. +function _adjustMermaidEdge(edge: Edge, nodes: Record) { + const sourceNodeLabel = nodes[edge.source] ?? edge.source; + const targetNodeLabel = nodes[edge.target] ?? edge.target; + return [sourceNodeLabel, targetNodeLabel]; +} + +function _generateMermaidGraphStyles( + nodeColors: Record +): string { + let styles = ""; + for (const [className, color] of Object.entries(nodeColors)) { + styles += `\tclassDef ${className}class fill:${color};\n`; + } + return styles; +} + +/** + * Draws a Mermaid graph using the provided graph data + */ +export function drawMermaid( + nodes: Record, + edges: Edge[], + config?: { + firstNodeLabel?: string; + lastNodeLabel?: string; + curveStyle?: string; + withStyles?: boolean; + nodeColors?: Record; + wrapLabelNWords?: number; + } +): string { + const { + firstNodeLabel, + lastNodeLabel, + nodeColors, + withStyles = true, + curveStyle = "linear", + wrapLabelNWords = 9, + } = config ?? {}; + // Initialize Mermaid graph configuration + let mermaidGraph = withStyles + ? `%%{init: {'flowchart': {'curve': '${curveStyle}'}}}%%\ngraph TD;\n` + : "graph TD;\n"; + if (withStyles) { + // Node formatting templates + const defaultClassLabel = "default"; + const formatDict: Record = { + [defaultClassLabel]: "{0}([{1}]):::otherclass", + }; + if (firstNodeLabel !== undefined) { + formatDict[firstNodeLabel] = "{0}[{0}]:::startclass"; + } + if (lastNodeLabel !== undefined) { + formatDict[lastNodeLabel] = "{0}[{0}]:::endclass"; + } + + // Add nodes to the graph + for (const node of Object.values(nodes)) { + const nodeLabel = formatDict[node] ?? formatDict[defaultClassLabel]; + const escapedNodeLabel = _escapeNodeLabel(node); + const nodeParts = node.split(":"); + const nodeSplit = nodeParts[nodeParts.length - 1]; + mermaidGraph += `\t${nodeLabel + .replace(/\{0\}/g, escapedNodeLabel) + .replace(/\{1\}/g, nodeSplit)};\n`; + } + } + let subgraph = ""; + // Add edges to the graph + for (const edge of edges) { + const sourcePrefix = edge.source.includes(":") + ? edge.source.split(":")[0] + : undefined; + const targetPrefix = edge.target.includes(":") + ? edge.target.split(":")[0] + : undefined; + // Exit subgraph if source or target is not in the same subgraph + if ( + subgraph !== "" && + (subgraph !== sourcePrefix || subgraph !== targetPrefix) + ) { + mermaidGraph += "\tend\n"; + subgraph = ""; + } + // Enter subgraph if source and target are in the same subgraph + if ( + subgraph === "" && + sourcePrefix !== undefined && + sourcePrefix === targetPrefix + ) { + mermaidGraph = `\tsubgraph ${sourcePrefix}\n`; + subgraph = sourcePrefix; + } + const [source, target] = _adjustMermaidEdge(edge, nodes); + let edgeLabel = ""; + // Add BR every wrapLabelNWords words + if (edge.data !== undefined) { + let edgeData = edge.data; + const words = edgeData.split(" "); + // Group words into chunks of wrapLabelNWords size + if (words.length > wrapLabelNWords) { + edgeData = words + .reduce((acc: string[], word: string, i: number) => { + if (i % wrapLabelNWords === 0) acc.push(""); + acc[acc.length - 1] += ` ${word}`; + return acc; + }, []) + .join("
"); + if (edge.conditional) { + edgeLabel = ` -. ${edgeData} .-> `; + } else { + edgeLabel = ` -- ${edgeData} --> `; + } + } + } else { + if (edge.conditional) { + edgeLabel = ` -.-> `; + } else { + edgeLabel = ` --> `; + } + } + mermaidGraph += `\t${_escapeNodeLabel( + source + )}${edgeLabel}${_escapeNodeLabel(target)};\n`; + } + if (subgraph !== "") { + mermaidGraph += "end\n"; + } + + // Add custom styles for nodes + if (withStyles && nodeColors !== undefined) { + mermaidGraph += _generateMermaidGraphStyles(nodeColors); + } + return mermaidGraph; +} + +/** + * Renders Mermaid graph using the Mermaid.INK API. + */ +export async function drawMermaidPng( + mermaidSyntax: string, + config?: { + backgroundColor?: string; + } +) { + let { backgroundColor = "white" } = config ?? {}; + // Use btoa for compatibility, assume ASCII + const mermaidSyntaxEncoded = btoa(mermaidSyntax); + // Check if the background color is a hexadecimal color code using regex + if (backgroundColor !== undefined) { + const hexColorPattern = /^#(?:[0-9a-fA-F]{3}){1,2}$/; + if (!hexColorPattern.test(backgroundColor)) { + backgroundColor = `!${backgroundColor}`; + } + } + const imageUrl = `https://mermaid.ink/img/${mermaidSyntaxEncoded}?bgColor=${backgroundColor}`; + const res = await fetch(imageUrl); + if (!res.ok) { + throw new Error( + [ + `Failed to render the graph using the Mermaid.INK API.`, + `Status code: ${res.status}`, + `Status text: ${res.statusText}`, + ].join("\n") + ); + } + const content = await res.blob(); + return content; +} diff --git a/langchain-core/src/runnables/tests/data/mermaid.png b/langchain-core/src/runnables/tests/data/mermaid.png new file mode 100644 index 000000000000..c0db4875bff3 Binary files /dev/null and b/langchain-core/src/runnables/tests/data/mermaid.png differ diff --git a/langchain-core/src/runnables/tests/runnable_graph.test.ts b/langchain-core/src/runnables/tests/runnable_graph.test.ts index 352eacc618dd..909792cab626 100644 --- a/langchain-core/src/runnables/tests/runnable_graph.test.ts +++ b/langchain-core/src/runnables/tests/runnable_graph.test.ts @@ -87,4 +87,20 @@ test("Test graph sequence", async () => { { source: 2, target: 3 }, ], }); + expect(graph.drawMermaid()) + .toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%% +graph TD; +\tPromptTemplateInput[PromptTemplateInput]:::startclass; +\tPromptTemplate([PromptTemplate]):::otherclass; +\tFakeLLM([FakeLLM]):::otherclass; +\tCommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass; +\tCommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass; +\tPromptTemplateInput --> PromptTemplate; +\tPromptTemplate --> FakeLLM; +\tCommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput; +\tFakeLLM --> CommaSeparatedListOutputParser; +\tclassDef startclass fill:#ffdfba; +\tclassDef endclass fill:#baffc9; +\tclassDef otherclass fill:#fad7de; +`); }); diff --git a/langchain-core/src/runnables/types.ts b/langchain-core/src/runnables/types.ts index 0e7e319ddd8e..0050a955f5b8 100644 --- a/langchain-core/src/runnables/types.ts +++ b/langchain-core/src/runnables/types.ts @@ -61,3 +61,15 @@ export interface RunnableInterface< getName(suffix?: string): string; } + +export interface Edge { + source: string; + target: string; + data?: string; + conditional?: boolean; +} + +export interface Node { + id: string; + data: RunnableIOSchema | RunnableInterface; +}