Skip to content

Commit

Permalink
Add support to pin a single op node to the top of a layer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 659986999
  • Loading branch information
Google AI Edge authored and copybara-github committed Aug 6, 2024
1 parent 9e6df39 commit 55cd021
Show file tree
Hide file tree
Showing 8 changed files with 132 additions and 16 deletions.
3 changes: 3 additions & 0 deletions src/ui/src/components/visualizer/common/consts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ export const TENSOR_VALUES_KEY = '__value';
/** The key to store the tensor tag in i/o metadata. */
export const TENSOR_TAG_METADATA_KEY = '__tensor_tag';

/** The margin for the left and right side of the layout. */
export const LAYOUT_MARGIN_X = 20;

/** A map from color names to the corresponding hex color. */
export const COLOR_NAME_TO_HEX: Record<string, string> = {
'aliceblue': '#f0f8ff',
Expand Down
4 changes: 4 additions & 0 deletions src/ui/src/components/visualizer/common/input_graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

import {
GraphNodeConfig,
GraphNodeStyle,
IncomingEdge,
KeyValueList,
Expand Down Expand Up @@ -136,4 +137,7 @@ export declare interface GraphNode {

/** The default style of the node. */
style?: GraphNodeStyle;

/** Custom configs for the node. */
config?: GraphNodeConfig;
}
7 changes: 7 additions & 0 deletions src/ui/src/components/visualizer/common/model_graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

import {
GraphNodeConfig,
GraphNodeStyle,
IncomingEdge,
KeyValuePairs,
Expand Down Expand Up @@ -206,6 +207,9 @@ export declare interface OpNode extends ModelNodeBase {

/** The style of the node. */
style?: GraphNodeStyle;

/** Custom configs for the node. */
config?: GraphNodeConfig;
}

/**
Expand Down Expand Up @@ -237,6 +241,9 @@ export declare interface GroupNode extends ModelNodeBase {
* nodes to layout.
*/
sectionContainer?: boolean;

/** The op node that should be pinned to the top of the group. */
pinToTopOpNode?: OpNode;
}

/** A node in a model graph. */
Expand Down
6 changes: 6 additions & 0 deletions src/ui/src/components/visualizer/common/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ export declare interface GraphNodeStyle {
hoveredBorderColor?: string;
}

/** Custom configs for a graph node. */
export declare interface GraphNodeConfig {
/** Whether to pin the node to the top of the group it belongs to. */
pinToGroupTop?: boolean;
}

/** Data to pass along when clicking "open in popup" on a group node. */
export interface PopupPanelData {
id: string;
Expand Down
27 changes: 26 additions & 1 deletion src/ui/src/components/visualizer/webgl_renderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ import * as three from 'three';
import {AppService} from './app_service';
import {
GLOBAL_KEY,
LAYOUT_MARGIN_X,
NODE_LABEL_HEIGHT,
WEBGL_ELEMENT_Y_FACTOR,
} from './common/consts';
Expand Down Expand Up @@ -260,6 +261,7 @@ export class WebglRenderer implements OnInit, OnDestroy {
readonly GROUP_NODE_BORDER_COLOR = new THREE.Color('#aaa');
readonly GROUP_NODE_LABEL_SEPARATOR_COLOR = new THREE.Color('#DADCE0');
readonly GROUP_NODE_ICON_COLOR = new THREE.Color('#444746');
readonly GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR = new THREE.Color('#bbb');
readonly EDGE_COLOR = new THREE.Color(
this.appService.config()?.edgeColor || '#aaa',
);
Expand Down Expand Up @@ -1776,7 +1778,7 @@ export class WebglRenderer implements OnInit, OnDestroy {
}
nodeBodyRectangles.push({
id: node.id,
index: i,
index: nodeBodyRectangles.length,
bound: {
x: x + width / 2,
y: y + height / 2,
Expand All @@ -1797,6 +1799,29 @@ export class WebglRenderer implements OnInit, OnDestroy {
bgColor.b === 1,
});

// Render separator between the pinned node and the rest of the nodes.
if (isGroupNode(node) && node.expanded && node.pinToTopOpNode) {
nodeBodyRectangles.push({
id: `${node.id}_pin_to_top_separator`,
index: nodeBodyRectangles.length,
bound: {
x: x + width / 2,
y:
(node.pinToTopOpNode.globalY || 0) +
(node.pinToTopOpNode.height || 0) / 2 +
12.5,
width: width - LAYOUT_MARGIN_X * 2,
height: 1,
},
yOffset: WEBGL_ELEMENT_Y_FACTOR * nodeIndex + 0.1,
isRounded: true,
borderColor: this.GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR,
bgColor: this.GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR,
borderWidth: 1,
opacity: 1,
});
}

// Subgraph indicators.
if (isOpNode(node) && node.subgraphIds) {
const indicatorWidth = SUBGRAPH_INDICATOR_SIZE;
Expand Down
56 changes: 51 additions & 5 deletions src/ui/src/components/visualizer/worker/graph_expander.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
* ==============================================================================
*/

import {GroupNode, ModelGraph} from '../common/model_graph';
import {LAYOUT_MARGIN_X} from '../common/consts';
import {GroupNode, ModelGraph, OpNode} from '../common/model_graph';
import {NodeDataProviderRunData, ShowOnNodeItemData} from '../common/types';
import {getDeepestExpandedGroupNodeIds, isGroupNode} from '../common/utils';

Expand All @@ -25,7 +26,6 @@ import {
GraphLayout,
LAYOUT_MARGIN_BOTTOM,
LAYOUT_MARGIN_TOP,
LAYOUT_MARGIN_X,
getNodeHeight,
getNodeWidth,
} from './graph_layout';
Expand Down Expand Up @@ -84,8 +84,13 @@ export class GraphExpander {

// Grow size.
const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2;
const curTargetHeight =
let curTargetHeight =
rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM;
if (curGroupNode.pinToTopOpNode) {
curTargetHeight += this.getPinToTopNodeVerticalSpace(
curGroupNode.pinToTopOpNode,
);
}
curGroupNode.width = curTargetWidth;
curGroupNode.height = curTargetHeight;

Expand Down Expand Up @@ -158,8 +163,13 @@ export class GraphExpander {

// Grow size.
const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2;
const curTargetHeight =
let curTargetHeight =
rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM;
if (groupNode.pinToTopOpNode) {
curTargetHeight += this.getPinToTopNodeVerticalSpace(
groupNode.pinToTopOpNode,
);
}
groupNode.width = curTargetWidth;
groupNode.height = curTargetHeight;
}
Expand Down Expand Up @@ -263,8 +273,13 @@ export class GraphExpander {

// Shrink size.
const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2;
const curTargetHeight =
let curTargetHeight =
rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM;
if (curGroupNode.pinToTopOpNode) {
curTargetHeight += this.getPinToTopNodeVerticalSpace(
curGroupNode.pinToTopOpNode,
);
}
curGroupNode.width = curTargetWidth;
curGroupNode.height = curTargetHeight;

Expand Down Expand Up @@ -401,6 +416,33 @@ export class GraphExpander {
(groupNode.y || 0) +
(groupNode.globalY || 0) +
(node.localOffsetY || 0);

// Move the node down if the current group node has a node pinned to
// top.
if (
groupNode.pinToTopOpNode &&
node.id !== groupNode.pinToTopOpNode.id
) {
node.globalY += this.getPinToTopNodeVerticalSpace(
groupNode.pinToTopOpNode,
);
}

// For the pinned-to-top node, move it to the top-middle of the group
// node.
if (groupNode.pinToTopOpNode?.id === node.id) {
node.globalX =
(groupNode.x || 0) +
(groupNode.globalX || 0) +
(groupNode.width || 0) / 2;
node.globalY =
(groupNode.y || 0) +
(groupNode.globalY || 0) +
(node.localOffsetY || 0) +
this.getPinToTopNodeVerticalSpace(node as OpNode) -
(node.height || 0) / 2 +
10;
}
}
if (isGroupNode(node)) {
this.updateNodeOffset(node);
Expand Down Expand Up @@ -434,4 +476,8 @@ export class GraphExpander {
}
}
}

private getPinToTopNodeVerticalSpace(node: OpNode): number {
return (node.height || 0) + 20;
}
}
39 changes: 29 additions & 10 deletions src/ui/src/components/visualizer/worker/graph_layout.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/

import {
LAYOUT_MARGIN_X,
MAX_IO_ROWS_IN_ATTRS_TABLE,
NODE_ATTRS_TABLE_FONT_SIZE,
NODE_ATTRS_TABLE_LABEL_VALUE_PADDING,
Expand All @@ -34,6 +35,7 @@ import {
OpNode,
} from '../common/model_graph';
import {
GraphNodeConfig,
KeyValueList,
NodeDataProviderRunData,
Point,
Expand All @@ -57,9 +59,6 @@ import {

import {Dagre, DagreGraphInstance} from './dagre_types';

/** The margin for the left and right side of the layout. */
export const LAYOUT_MARGIN_X = 20;

/** The margin for the top and bottom side of the layout. */
export const LAYOUT_MARGIN_TOP = 36;

Expand All @@ -85,6 +84,7 @@ export declare interface DagreNode {
height: number;
x?: number;
y?: number;
config?: GraphNodeConfig;
}

interface LayoutGraph {
Expand Down Expand Up @@ -143,7 +143,11 @@ export class GraphLayout {

// Set nodes/edges to dagre.
for (const id of Object.keys(layoutGraph.nodes)) {
this.dagreGraph.setNode(id, layoutGraph.nodes[id]);
const dagreNode = layoutGraph.nodes[id];
if (dagreNode.config?.pinToGroupTop) {
continue;
}
this.dagreGraph.setNode(id, dagreNode);
}
for (const fromNodeId of Object.keys(layoutGraph.outgoingEdges)) {
for (const toNodeId of layoutGraph.outgoingEdges[fromNodeId]) {
Expand All @@ -154,7 +158,8 @@ export class GraphLayout {
// Run the layout algorithm.
this.dagre.layout(this.dagreGraph);

// Set the results back to the original model nodes.
// Set the results back to the original model nodes and calculate the bound
// that contains all the nodes.
let minX = Number.MAX_VALUE;
let minY = Number.MAX_VALUE;
let maxX = Number.NEGATIVE_INFINITY;
Expand All @@ -172,13 +177,17 @@ export class GraphLayout {
node.localOffsetX = 0;
node.localOffsetY = 0;

minX = Math.min(minX, node.x);
minY = Math.min(minY, node.y);
maxX = Math.max(maxX, node.x + node.width);
maxY = Math.max(maxY, node.y + node.height);
// Don't consider the bound of the node if it's pinned to the top of the
// group.
if (!dagreNode.config?.pinToGroupTop) {
minX = Math.min(minX, node.x);
minY = Math.min(minY, node.y);
maxX = Math.max(maxX, node.x + node.width);
maxY = Math.max(maxY, node.y + node.height);
}
}

// Edges.
// Expand the bound to include all the edges.
let minEdgeX = Number.MAX_VALUE;
let minEdgeY = Number.MAX_VALUE;
let maxEdgeX = Number.NEGATIVE_INFINITY;
Expand Down Expand Up @@ -511,6 +520,7 @@ export function getLayoutGraph(
nodeDataProviderRuns,
testMode,
),
config: isOpNode(node) ? node.config : undefined,
};
layoutGraph.nodes[node.id] = dagreNode;
}
Expand All @@ -520,6 +530,15 @@ export function getLayoutGraph(
modelGraph.layoutGraphEdges[rootGroupNodeId] || {};
for (const [fromNodeId, toNodeIds] of Object.entries(curLayoutGraphEdges)) {
for (const toNodeId of Object.keys(toNodeIds)) {
// Ignore edges from/to nodes pinned to group top.
const fromNode = modelGraph.nodesById[fromNodeId];
const toNode = modelGraph.nodesById[toNodeId];
if (fromNode && isOpNode(fromNode) && fromNode.config?.pinToGroupTop) {
continue;
}
if (toNode && isOpNode(toNode) && toNode.config?.pinToGroupTop) {
continue;
}
addLayoutGraphEdge(layoutGraph, fromNodeId, toNodeId);
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/ui/src/components/visualizer/worker/graph_processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,9 @@ export class GraphProcessor {
if (graphNode.style) {
opNode.style = graphNode.style;
}
if (graphNode.config) {
opNode.config = graphNode.config;
}
modelGraph.nodes.push(opNode);
modelGraph.nodesById[opNode.id] = opNode;

Expand Down Expand Up @@ -290,6 +293,9 @@ export class GraphProcessor {
}
if (!parentGroupNode.nsChildrenIds.includes(node.id)) {
parentGroupNode.nsChildrenIds.push(node.id);
if (isOpNode(node) && node.config?.pinToGroupTop) {
parentGroupNode.pinToTopOpNode = node;
}
}
}
}
Expand Down

0 comments on commit 55cd021

Please sign in to comment.