Skip to content

Commit

Permalink
core[patch]: Adds mermaid graph format (#5978)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Jul 4, 2024
1 parent c013136 commit db9e352
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 22 deletions.
79 changes: 57 additions & 22 deletions langchain-core/src/runnables/graph.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import { zodToJsonSchema } from "zod-to-json-schema";
import { v4 as uuidv4, validate as isUuid } from "uuid";
import type { RunnableInterface, RunnableIOSchema } from "./types.js";
import type {
RunnableInterface,
RunnableIOSchema,
Node,
Edge,
} from "./types.js";
import { isRunnableInterface } from "./utils.js";

interface Edge {
source: string;
target: string;
data?: string;
}

interface Node {
id: string;

data: RunnableIOSchema | RunnableInterface;
}
import { drawMermaid, drawMermaidPng } from "./graph_mermaid.js";

const MAX_DATA_DISPLAY_NAME_LENGTH = 42;

Expand All @@ -22,17 +16,12 @@ export function nodeDataStr(node: Node): string {
return node.id;
} else if (isRunnableInterface(node.data)) {
try {
let data = node.data.toString();
if (
data.startsWith("<") ||
data[0] !== data[0].toUpperCase() ||
data.split("\n").length > 1
) {
data = node.data.getName();
} else if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) {
let data = node.data.getName();
data = data.startsWith("Runnable") ? data.slice("Runnable".length) : data;
if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) {
data = `${data.substring(0, MAX_DATA_DISPLAY_NAME_LENGTH)}...`;
}
return data.startsWith("Runnable") ? data.slice("Runnable".length) : data;
return data;
} catch (error) {
return node.data.getName();
}
Expand Down Expand Up @@ -179,4 +168,50 @@ export class Graph {
}
}
}

drawMermaid(params?: {
withStyles?: boolean;
curveStyle?: string;
nodeColors?: Record<string, string>;
wrapLabelNWords?: number;
}): string {
const {
withStyles,
curveStyle,
nodeColors = { start: "#ffdfba", end: "#baffc9", other: "#fad7de" },
wrapLabelNWords,
} = params ?? {};
const nodes: Record<string, string> = {};
for (const node of Object.values(this.nodes)) {
nodes[node.id] = nodeDataStr(node);
}

const firstNode = this.firstNode();
const firstNodeLabel = firstNode ? nodeDataStr(firstNode) : undefined;

const lastNode = this.lastNode();
const lastNodeLabel = lastNode ? nodeDataStr(lastNode) : undefined;

return drawMermaid(nodes, this.edges, {
firstNodeLabel,
lastNodeLabel,
withStyles,
curveStyle,
nodeColors,
wrapLabelNWords,
});
}

async drawMermaidPng(params?: {
withStyles?: boolean;
curveStyle?: string;
nodeColors?: Record<string, string>;
wrapLabelNWords?: number;
backgroundColor?: string;
}): Promise<Blob> {
const mermaidSyntax = this.drawMermaid(params);
return drawMermaidPng(mermaidSyntax, {
backgroundColor: params?.backgroundColor,
});
}
}
177 changes: 177 additions & 0 deletions langchain-core/src/runnables/graph_mermaid.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import { Edge } from "./types.js";

function _escapeNodeLabel(nodeLabel: string): string {
// Escapes the node label for Mermaid syntax.
return nodeLabel.replace(/[^a-zA-Z-_0-9]/g, "_");
}

// Adjusts Mermaid edge to map conditional nodes to pure nodes.
function _adjustMermaidEdge(edge: Edge, nodes: Record<string, string>) {
const sourceNodeLabel = nodes[edge.source] ?? edge.source;
const targetNodeLabel = nodes[edge.target] ?? edge.target;
return [sourceNodeLabel, targetNodeLabel];
}

function _generateMermaidGraphStyles(
nodeColors: Record<string, string>
): string {
let styles = "";
for (const [className, color] of Object.entries(nodeColors)) {
styles += `\tclassDef ${className}class fill:${color};\n`;
}
return styles;
}

/**
* Draws a Mermaid graph using the provided graph data
*/
export function drawMermaid(
nodes: Record<string, string>,
edges: Edge[],
config?: {
firstNodeLabel?: string;
lastNodeLabel?: string;
curveStyle?: string;
withStyles?: boolean;
nodeColors?: Record<string, string>;
wrapLabelNWords?: number;
}
): string {
const {
firstNodeLabel,
lastNodeLabel,
nodeColors,
withStyles = true,
curveStyle = "linear",
wrapLabelNWords = 9,
} = config ?? {};
// Initialize Mermaid graph configuration
let mermaidGraph = withStyles
? `%%{init: {'flowchart': {'curve': '${curveStyle}'}}}%%\ngraph TD;\n`
: "graph TD;\n";
if (withStyles) {
// Node formatting templates
const defaultClassLabel = "default";
const formatDict: Record<string, string> = {
[defaultClassLabel]: "{0}([{1}]):::otherclass",
};
if (firstNodeLabel !== undefined) {
formatDict[firstNodeLabel] = "{0}[{0}]:::startclass";
}
if (lastNodeLabel !== undefined) {
formatDict[lastNodeLabel] = "{0}[{0}]:::endclass";
}

// Add nodes to the graph
for (const node of Object.values(nodes)) {
const nodeLabel = formatDict[node] ?? formatDict[defaultClassLabel];
const escapedNodeLabel = _escapeNodeLabel(node);
const nodeParts = node.split(":");
const nodeSplit = nodeParts[nodeParts.length - 1];
mermaidGraph += `\t${nodeLabel
.replace(/\{0\}/g, escapedNodeLabel)
.replace(/\{1\}/g, nodeSplit)};\n`;
}
}
let subgraph = "";
// Add edges to the graph
for (const edge of edges) {
const sourcePrefix = edge.source.includes(":")
? edge.source.split(":")[0]
: undefined;
const targetPrefix = edge.target.includes(":")
? edge.target.split(":")[0]
: undefined;
// Exit subgraph if source or target is not in the same subgraph
if (
subgraph !== "" &&
(subgraph !== sourcePrefix || subgraph !== targetPrefix)
) {
mermaidGraph += "\tend\n";
subgraph = "";
}
// Enter subgraph if source and target are in the same subgraph
if (
subgraph === "" &&
sourcePrefix !== undefined &&
sourcePrefix === targetPrefix
) {
mermaidGraph = `\tsubgraph ${sourcePrefix}\n`;
subgraph = sourcePrefix;
}
const [source, target] = _adjustMermaidEdge(edge, nodes);
let edgeLabel = "";
// Add BR every wrapLabelNWords words
if (edge.data !== undefined) {
let edgeData = edge.data;
const words = edgeData.split(" ");
// Group words into chunks of wrapLabelNWords size
if (words.length > wrapLabelNWords) {
edgeData = words
.reduce((acc: string[], word: string, i: number) => {
if (i % wrapLabelNWords === 0) acc.push("");
acc[acc.length - 1] += ` ${word}`;
return acc;
}, [])
.join("<br>");
if (edge.conditional) {
edgeLabel = ` -. ${edgeData} .-> `;
} else {
edgeLabel = ` -- ${edgeData} --> `;
}
}
} else {
if (edge.conditional) {
edgeLabel = ` -.-> `;
} else {
edgeLabel = ` --> `;
}
}
mermaidGraph += `\t${_escapeNodeLabel(
source
)}${edgeLabel}${_escapeNodeLabel(target)};\n`;
}
if (subgraph !== "") {
mermaidGraph += "end\n";
}

// Add custom styles for nodes
if (withStyles && nodeColors !== undefined) {
mermaidGraph += _generateMermaidGraphStyles(nodeColors);
}
return mermaidGraph;
}

/**
* Renders Mermaid graph using the Mermaid.INK API.
*/
export async function drawMermaidPng(
mermaidSyntax: string,
config?: {
backgroundColor?: string;
}
) {
let { backgroundColor = "white" } = config ?? {};
// Use btoa for compatibility, assume ASCII
const mermaidSyntaxEncoded = btoa(mermaidSyntax);
// Check if the background color is a hexadecimal color code using regex
if (backgroundColor !== undefined) {
const hexColorPattern = /^#(?:[0-9a-fA-F]{3}){1,2}$/;
if (!hexColorPattern.test(backgroundColor)) {
backgroundColor = `!${backgroundColor}`;
}
}
const imageUrl = `https://mermaid.ink/img/${mermaidSyntaxEncoded}?bgColor=${backgroundColor}`;
const res = await fetch(imageUrl);
if (!res.ok) {
throw new Error(
[
`Failed to render the graph using the Mermaid.INK API.`,
`Status code: ${res.status}`,
`Status text: ${res.statusText}`,
].join("\n")
);
}
const content = await res.blob();
return content;
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions langchain-core/src/runnables/tests/runnable_graph.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,20 @@ test("Test graph sequence", async () => {
{ source: 2, target: 3 },
],
});
expect(graph.drawMermaid())
.toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
\tPromptTemplateInput[PromptTemplateInput]:::startclass;
\tPromptTemplate([PromptTemplate]):::otherclass;
\tFakeLLM([FakeLLM]):::otherclass;
\tCommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
\tCommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
\tPromptTemplateInput --> PromptTemplate;
\tPromptTemplate --> FakeLLM;
\tCommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput;
\tFakeLLM --> CommaSeparatedListOutputParser;
\tclassDef startclass fill:#ffdfba;
\tclassDef endclass fill:#baffc9;
\tclassDef otherclass fill:#fad7de;
`);
});
12 changes: 12 additions & 0 deletions langchain-core/src/runnables/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,15 @@ export interface RunnableInterface<

getName(suffix?: string): string;
}

export interface Edge {
source: string;
target: string;
data?: string;
conditional?: boolean;
}

export interface Node {
id: string;
data: RunnableIOSchema | RunnableInterface;
}

0 comments on commit db9e352

Please sign in to comment.