From 55cd021cf4a0e4b9671c3e0234780de9ac7fafff Mon Sep 17 00:00:00 2001 From: Google AI Edge Date: Tue, 6 Aug 2024 09:30:59 -0700 Subject: [PATCH] Add support to pin a single op node to the top of a layer PiperOrigin-RevId: 659986999 --- .../components/visualizer/common/consts.ts | 3 + .../visualizer/common/input_graph.ts | 4 ++ .../visualizer/common/model_graph.ts | 7 +++ .../src/components/visualizer/common/types.ts | 6 ++ .../components/visualizer/webgl_renderer.ts | 27 ++++++++- .../visualizer/worker/graph_expander.ts | 56 +++++++++++++++++-- .../visualizer/worker/graph_layout.ts | 39 +++++++++---- .../visualizer/worker/graph_processor.ts | 6 ++ 8 files changed, 132 insertions(+), 16 deletions(-) diff --git a/src/ui/src/components/visualizer/common/consts.ts b/src/ui/src/components/visualizer/common/consts.ts index e4d12146..4bad6170 100644 --- a/src/ui/src/components/visualizer/common/consts.ts +++ b/src/ui/src/components/visualizer/common/consts.ts @@ -88,6 +88,9 @@ export const TENSOR_VALUES_KEY = '__value'; /** The key to store the tensor tag in i/o metadata. */ export const TENSOR_TAG_METADATA_KEY = '__tensor_tag'; +/** The margin for the left and right side of the layout. */ +export const LAYOUT_MARGIN_X = 20; + /** A map from color names to the corresponding hex color. */ export const COLOR_NAME_TO_HEX: Record = { 'aliceblue': '#f0f8ff', diff --git a/src/ui/src/components/visualizer/common/input_graph.ts b/src/ui/src/components/visualizer/common/input_graph.ts index da441c93..36523bfc 100644 --- a/src/ui/src/components/visualizer/common/input_graph.ts +++ b/src/ui/src/components/visualizer/common/input_graph.ts @@ -17,6 +17,7 @@ */ import { + GraphNodeConfig, GraphNodeStyle, IncomingEdge, KeyValueList, @@ -136,4 +137,7 @@ export declare interface GraphNode { /** The default style of the node. */ style?: GraphNodeStyle; + + /** Custom configs for the node. */ + config?: GraphNodeConfig; } diff --git a/src/ui/src/components/visualizer/common/model_graph.ts b/src/ui/src/components/visualizer/common/model_graph.ts index 92737698..b3c82482 100644 --- a/src/ui/src/components/visualizer/common/model_graph.ts +++ b/src/ui/src/components/visualizer/common/model_graph.ts @@ -17,6 +17,7 @@ */ import { + GraphNodeConfig, GraphNodeStyle, IncomingEdge, KeyValuePairs, @@ -206,6 +207,9 @@ export declare interface OpNode extends ModelNodeBase { /** The style of the node. */ style?: GraphNodeStyle; + + /** Custom configs for the node. */ + config?: GraphNodeConfig; } /** @@ -237,6 +241,9 @@ export declare interface GroupNode extends ModelNodeBase { * nodes to layout. */ sectionContainer?: boolean; + + /** The op node that should be pinned to the top of the group. */ + pinToTopOpNode?: OpNode; } /** A node in a model graph. */ diff --git a/src/ui/src/components/visualizer/common/types.ts b/src/ui/src/components/visualizer/common/types.ts index 51cd67e9..1888f4dc 100644 --- a/src/ui/src/components/visualizer/common/types.ts +++ b/src/ui/src/components/visualizer/common/types.ts @@ -110,6 +110,12 @@ export declare interface GraphNodeStyle { hoveredBorderColor?: string; } +/** Custom configs for a graph node. */ +export declare interface GraphNodeConfig { + /** Whether to pin the node to the top of the group it belongs to. */ + pinToGroupTop?: boolean; +} + /** Data to pass along when clicking "open in popup" on a group node. */ export interface PopupPanelData { id: string; diff --git a/src/ui/src/components/visualizer/webgl_renderer.ts b/src/ui/src/components/visualizer/webgl_renderer.ts index e738e7e2..ac3dfc1e 100644 --- a/src/ui/src/components/visualizer/webgl_renderer.ts +++ b/src/ui/src/components/visualizer/webgl_renderer.ts @@ -48,6 +48,7 @@ import * as three from 'three'; import {AppService} from './app_service'; import { GLOBAL_KEY, + LAYOUT_MARGIN_X, NODE_LABEL_HEIGHT, WEBGL_ELEMENT_Y_FACTOR, } from './common/consts'; @@ -260,6 +261,7 @@ export class WebglRenderer implements OnInit, OnDestroy { readonly GROUP_NODE_BORDER_COLOR = new THREE.Color('#aaa'); readonly GROUP_NODE_LABEL_SEPARATOR_COLOR = new THREE.Color('#DADCE0'); readonly GROUP_NODE_ICON_COLOR = new THREE.Color('#444746'); + readonly GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR = new THREE.Color('#bbb'); readonly EDGE_COLOR = new THREE.Color( this.appService.config()?.edgeColor || '#aaa', ); @@ -1776,7 +1778,7 @@ export class WebglRenderer implements OnInit, OnDestroy { } nodeBodyRectangles.push({ id: node.id, - index: i, + index: nodeBodyRectangles.length, bound: { x: x + width / 2, y: y + height / 2, @@ -1797,6 +1799,29 @@ export class WebglRenderer implements OnInit, OnDestroy { bgColor.b === 1, }); + // Render separator between the pinned node and the rest of the nodes. + if (isGroupNode(node) && node.expanded && node.pinToTopOpNode) { + nodeBodyRectangles.push({ + id: `${node.id}_pin_to_top_separator`, + index: nodeBodyRectangles.length, + bound: { + x: x + width / 2, + y: + (node.pinToTopOpNode.globalY || 0) + + (node.pinToTopOpNode.height || 0) / 2 + + 12.5, + width: width - LAYOUT_MARGIN_X * 2, + height: 1, + }, + yOffset: WEBGL_ELEMENT_Y_FACTOR * nodeIndex + 0.1, + isRounded: true, + borderColor: this.GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR, + bgColor: this.GROUP_NODE_PIN_TO_TOP_SEPARATOR_COLOR, + borderWidth: 1, + opacity: 1, + }); + } + // Subgraph indicators. if (isOpNode(node) && node.subgraphIds) { const indicatorWidth = SUBGRAPH_INDICATOR_SIZE; diff --git a/src/ui/src/components/visualizer/worker/graph_expander.ts b/src/ui/src/components/visualizer/worker/graph_expander.ts index 901c2da0..3e7578cd 100644 --- a/src/ui/src/components/visualizer/worker/graph_expander.ts +++ b/src/ui/src/components/visualizer/worker/graph_expander.ts @@ -16,7 +16,8 @@ * ============================================================================== */ -import {GroupNode, ModelGraph} from '../common/model_graph'; +import {LAYOUT_MARGIN_X} from '../common/consts'; +import {GroupNode, ModelGraph, OpNode} from '../common/model_graph'; import {NodeDataProviderRunData, ShowOnNodeItemData} from '../common/types'; import {getDeepestExpandedGroupNodeIds, isGroupNode} from '../common/utils'; @@ -25,7 +26,6 @@ import { GraphLayout, LAYOUT_MARGIN_BOTTOM, LAYOUT_MARGIN_TOP, - LAYOUT_MARGIN_X, getNodeHeight, getNodeWidth, } from './graph_layout'; @@ -84,8 +84,13 @@ export class GraphExpander { // Grow size. const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2; - const curTargetHeight = + let curTargetHeight = rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM; + if (curGroupNode.pinToTopOpNode) { + curTargetHeight += this.getPinToTopNodeVerticalSpace( + curGroupNode.pinToTopOpNode, + ); + } curGroupNode.width = curTargetWidth; curGroupNode.height = curTargetHeight; @@ -158,8 +163,13 @@ export class GraphExpander { // Grow size. const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2; - const curTargetHeight = + let curTargetHeight = rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM; + if (groupNode.pinToTopOpNode) { + curTargetHeight += this.getPinToTopNodeVerticalSpace( + groupNode.pinToTopOpNode, + ); + } groupNode.width = curTargetWidth; groupNode.height = curTargetHeight; } @@ -263,8 +273,13 @@ export class GraphExpander { // Shrink size. const curTargetWidth = rect.width + LAYOUT_MARGIN_X * 2; - const curTargetHeight = + let curTargetHeight = rect.height + LAYOUT_MARGIN_TOP + LAYOUT_MARGIN_BOTTOM; + if (curGroupNode.pinToTopOpNode) { + curTargetHeight += this.getPinToTopNodeVerticalSpace( + curGroupNode.pinToTopOpNode, + ); + } curGroupNode.width = curTargetWidth; curGroupNode.height = curTargetHeight; @@ -401,6 +416,33 @@ export class GraphExpander { (groupNode.y || 0) + (groupNode.globalY || 0) + (node.localOffsetY || 0); + + // Move the node down if the current group node has a node pinned to + // top. + if ( + groupNode.pinToTopOpNode && + node.id !== groupNode.pinToTopOpNode.id + ) { + node.globalY += this.getPinToTopNodeVerticalSpace( + groupNode.pinToTopOpNode, + ); + } + + // For the pinned-to-top node, move it to the top-middle of the group + // node. + if (groupNode.pinToTopOpNode?.id === node.id) { + node.globalX = + (groupNode.x || 0) + + (groupNode.globalX || 0) + + (groupNode.width || 0) / 2; + node.globalY = + (groupNode.y || 0) + + (groupNode.globalY || 0) + + (node.localOffsetY || 0) + + this.getPinToTopNodeVerticalSpace(node as OpNode) - + (node.height || 0) / 2 + + 10; + } } if (isGroupNode(node)) { this.updateNodeOffset(node); @@ -434,4 +476,8 @@ export class GraphExpander { } } } + + private getPinToTopNodeVerticalSpace(node: OpNode): number { + return (node.height || 0) + 20; + } } diff --git a/src/ui/src/components/visualizer/worker/graph_layout.ts b/src/ui/src/components/visualizer/worker/graph_layout.ts index 039bb9f6..9fd3fcee 100644 --- a/src/ui/src/components/visualizer/worker/graph_layout.ts +++ b/src/ui/src/components/visualizer/worker/graph_layout.ts @@ -17,6 +17,7 @@ */ import { + LAYOUT_MARGIN_X, MAX_IO_ROWS_IN_ATTRS_TABLE, NODE_ATTRS_TABLE_FONT_SIZE, NODE_ATTRS_TABLE_LABEL_VALUE_PADDING, @@ -34,6 +35,7 @@ import { OpNode, } from '../common/model_graph'; import { + GraphNodeConfig, KeyValueList, NodeDataProviderRunData, Point, @@ -57,9 +59,6 @@ import { import {Dagre, DagreGraphInstance} from './dagre_types'; -/** The margin for the left and right side of the layout. */ -export const LAYOUT_MARGIN_X = 20; - /** The margin for the top and bottom side of the layout. */ export const LAYOUT_MARGIN_TOP = 36; @@ -85,6 +84,7 @@ export declare interface DagreNode { height: number; x?: number; y?: number; + config?: GraphNodeConfig; } interface LayoutGraph { @@ -143,7 +143,11 @@ export class GraphLayout { // Set nodes/edges to dagre. for (const id of Object.keys(layoutGraph.nodes)) { - this.dagreGraph.setNode(id, layoutGraph.nodes[id]); + const dagreNode = layoutGraph.nodes[id]; + if (dagreNode.config?.pinToGroupTop) { + continue; + } + this.dagreGraph.setNode(id, dagreNode); } for (const fromNodeId of Object.keys(layoutGraph.outgoingEdges)) { for (const toNodeId of layoutGraph.outgoingEdges[fromNodeId]) { @@ -154,7 +158,8 @@ export class GraphLayout { // Run the layout algorithm. this.dagre.layout(this.dagreGraph); - // Set the results back to the original model nodes. + // Set the results back to the original model nodes and calculate the bound + // that contains all the nodes. let minX = Number.MAX_VALUE; let minY = Number.MAX_VALUE; let maxX = Number.NEGATIVE_INFINITY; @@ -172,13 +177,17 @@ export class GraphLayout { node.localOffsetX = 0; node.localOffsetY = 0; - minX = Math.min(minX, node.x); - minY = Math.min(minY, node.y); - maxX = Math.max(maxX, node.x + node.width); - maxY = Math.max(maxY, node.y + node.height); + // Don't consider the bound of the node if it's pinned to the top of the + // group. + if (!dagreNode.config?.pinToGroupTop) { + minX = Math.min(minX, node.x); + minY = Math.min(minY, node.y); + maxX = Math.max(maxX, node.x + node.width); + maxY = Math.max(maxY, node.y + node.height); + } } - // Edges. + // Expand the bound to include all the edges. let minEdgeX = Number.MAX_VALUE; let minEdgeY = Number.MAX_VALUE; let maxEdgeX = Number.NEGATIVE_INFINITY; @@ -511,6 +520,7 @@ export function getLayoutGraph( nodeDataProviderRuns, testMode, ), + config: isOpNode(node) ? node.config : undefined, }; layoutGraph.nodes[node.id] = dagreNode; } @@ -520,6 +530,15 @@ export function getLayoutGraph( modelGraph.layoutGraphEdges[rootGroupNodeId] || {}; for (const [fromNodeId, toNodeIds] of Object.entries(curLayoutGraphEdges)) { for (const toNodeId of Object.keys(toNodeIds)) { + // Ignore edges from/to nodes pinned to group top. + const fromNode = modelGraph.nodesById[fromNodeId]; + const toNode = modelGraph.nodesById[toNodeId]; + if (fromNode && isOpNode(fromNode) && fromNode.config?.pinToGroupTop) { + continue; + } + if (toNode && isOpNode(toNode) && toNode.config?.pinToGroupTop) { + continue; + } addLayoutGraphEdge(layoutGraph, fromNodeId, toNodeId); } } diff --git a/src/ui/src/components/visualizer/worker/graph_processor.ts b/src/ui/src/components/visualizer/worker/graph_processor.ts index 5d10037e..efcb2aa2 100644 --- a/src/ui/src/components/visualizer/worker/graph_processor.ts +++ b/src/ui/src/components/visualizer/worker/graph_processor.ts @@ -159,6 +159,9 @@ export class GraphProcessor { if (graphNode.style) { opNode.style = graphNode.style; } + if (graphNode.config) { + opNode.config = graphNode.config; + } modelGraph.nodes.push(opNode); modelGraph.nodesById[opNode.id] = opNode; @@ -290,6 +293,9 @@ export class GraphProcessor { } if (!parentGroupNode.nsChildrenIds.includes(node.id)) { parentGroupNode.nsChildrenIds.push(node.id); + if (isOpNode(node) && node.config?.pinToGroupTop) { + parentGroupNode.pinToTopOpNode = node; + } } } }