From 58399daf7603ca58c293d616a2267b3a35336950 Mon Sep 17 00:00:00 2001 From: Google AI Edge Date: Sun, 13 Oct 2024 10:18:23 -0700 Subject: [PATCH] - Allow users to specify a name in node data provider json. - Allow users to hide node data from aggregated stats table and children stats table. - Allow users to hide specific stats in node data from aggregated stats table. - In sync navigation, show a message when the mapped node cannot be found in the other side of the split pane. - Fix a bug in io highlighter. PiperOrigin-RevId: 685452581 --- .../src/components/visualizer/common/types.ts | 39 +++- .../src/components/visualizer/common/utils.ts | 17 +- .../src/components/visualizer/info_panel.ts | 9 +- .../node_data_provider_summary_panel.ng.html | 200 +++++++++--------- .../node_data_provider_summary_panel.scss | 4 + .../node_data_provider_summary_panel.ts | 92 +++++++- .../visualizer/split_panes_container.ng.html | 3 + .../visualizer/split_panes_container.scss | 22 ++ .../visualizer/split_panes_container.ts | 51 +++++ .../visualizer/sync_navigation_service.ts | 3 + .../src/components/visualizer/view_on_node.ts | 3 +- .../components/visualizer/webgl_renderer.ts | 13 ++ .../webgl_renderer_io_highlight_service.ts | 6 +- .../visualizer/worker/graph_layout.ts | 5 +- 14 files changed, 357 insertions(+), 110 deletions(-) diff --git a/src/ui/src/components/visualizer/common/types.ts b/src/ui/src/components/visualizer/common/types.ts index cb1f9c32..b346d184 100644 --- a/src/ui/src/components/visualizer/common/types.ts +++ b/src/ui/src/components/visualizer/common/types.ts @@ -236,8 +236,24 @@ export declare interface NodeInfo { node?: ModelNode; } +/** Supported aggregated stats. */ +export type AggregatedStat = 'min' | 'max' | 'sum' | 'avg'; + /** Node data provider data for a single graph. */ export declare interface NodeDataProviderGraphData { + /** + * The name of the node data. + * + * The node data's name (mainly used for display purposes) is determined by + * the following sources, in order of priority: + * + * 1) an explicitly specified name from this field, which overrides any other + * source; + * 2) the `name` parameter in the API call, if applicable; + * 3) the JSON file name, if the data is loaded from a JSON file. + */ + name?: string; + /** * Node data indexed by node keys. * @@ -282,7 +298,28 @@ export declare interface NodeDataProviderGraphData { */ gradient?: GradientItem[]; - // https://gist.github.com/Myndex/e1025706436736166561d339fd667493 + /** + * Whether to hide the corresponding column in aggregated stats table + * (the first table). + * + * If all columns in that table are hidden, the whole table will be hidden. + */ + hideInAggregatedStatsTable?: boolean; + + /** + * Whether to hide the corresponding column in children stats table + * (the second table). + * + * If all columns in that table are hidden, the whole table will be hidden. + */ + hideInChildrenStatsTable?: boolean; + + /** + * The stats to hide in the aggregated stats table (the first table). + * + * The value for the hidden stat will be displayed as '-'. + */ + hideAggregatedStats?: AggregatedStat[]; } /** The top level node data provider data, indexed by graph id. */ diff --git a/src/ui/src/components/visualizer/common/utils.ts b/src/ui/src/components/visualizer/common/utils.ts index 6c2a9963..f321c464 100644 --- a/src/ui/src/components/visualizer/common/utils.ts +++ b/src/ui/src/components/visualizer/common/utils.ts @@ -541,11 +541,14 @@ export function getOpNodeDataProviderKeyValuePairsForAttrsTable( type.replace(NODE_DATA_PROVIDER_SHOW_ON_NODE_TYPE_PREFIX, ''), ); const runs = Object.values(curNodeDataProviderRuns).filter((run) => - runNames.includes(run.runName), + runNames.includes(getRunName(run, {id: modelGraphId})), ); for (const run of runs) { const value = (run.results || {})?.[modelGraphId][node.id]?.strValue || '-'; - keyValuePairs.push({key: run.runName, value}); + keyValuePairs.push({ + key: getRunName(run, {id: modelGraphId}), + value, + }); } return keyValuePairs; } @@ -1053,3 +1056,13 @@ export function getIntersectionPoints(rect1: Rect, rect2: Rect) { return {intersection1, intersection2}; } + +/** Gets the run name for the given run. */ +export function getRunName( + run: NodeDataProviderRunData, + modelGraphIdLike?: {id: string}, +): string { + return ( + run.nodeDataProviderData?.[modelGraphIdLike?.id || '']?.name ?? run.runName + ); +} diff --git a/src/ui/src/components/visualizer/info_panel.ts b/src/ui/src/components/visualizer/info_panel.ts index ae746074..e3769570 100644 --- a/src/ui/src/components/visualizer/info_panel.ts +++ b/src/ui/src/components/visualizer/info_panel.ts @@ -53,7 +53,12 @@ import { SearchMatchType, SearchResults, } from './common/types'; -import {getNamespaceLabel, isGroupNode, isOpNode} from './common/utils'; +import { + getNamespaceLabel, + getRunName, + isGroupNode, + isOpNode, +} from './common/utils'; import {ExpandableInfoText} from './expandable_info_text'; import {HoverableLabel} from './hoverable_label'; import {InfoPanelService} from './info_panel_service'; @@ -751,7 +756,7 @@ export class InfoPanel { nodeDataProvidersSection.items.push({ id: run.runId, section: nodeDataProvidersSection, - label: run.runName, + label: getRunName(run, this.curModelGraph), value: strValue, canShowOnNode: run.done, showOnNode: this.curShowOnNodeDataProviderRuns[run.runId] != null, diff --git a/src/ui/src/components/visualizer/node_data_provider_summary_panel.ng.html b/src/ui/src/components/visualizer/node_data_provider_summary_panel.ng.html index 0b3d1747..d4e60cbc 100644 --- a/src/ui/src/components/visualizer/node_data_provider_summary_panel.ng.html +++ b/src/ui/src/components/visualizer/node_data_provider_summary_panel.ng.html @@ -51,108 +51,116 @@ -
-
-
- {{statsTableTitleIcon}} - {{statsTableTitle}} + @if (showStatsTable) { +
+
+
+ {{statsTableTitleIcon}} + {{statsTableTitle}} +
+ + + + + + + + + + + + + +
+ Stat + +
+
{{i + 1}}
+
{{runItem.runName}}
+
+
{{row.stat}} + {{getStatValue(value)}} +
- - - - - - - - - - - - - -
- Stat - -
-
{{i + 1}}
-
{{runItem.runName}}
-
-
{{row.stat}} - {{getStatValue(value)}} -
-
+ } -
-
-
- {{childrenStatsTableTitleIcon}} - {{childrenStatsTableTitle}} + @if (showChildrenStatsTable) { +
+
+
+ {{childrenStatsTableTitleIcon}} + {{childrenStatsTableTitle}} +
+ @if (childrenStatRowsCount > tablePageSize && !childrenStatsTableCollapsed) { + + + }
- @if (childrenStatRowsCount > tablePageSize && !childrenStatsTableCollapsed) { - - - } + + + + + + + + + + + + + + + + + +
+
+ # + + {{curChildrenStatSortingDirection === 'asc' ? 'arrow_upward' : 'arrow_downward'}} + +
+
+
+ Node + + {{curChildrenStatSortingDirection === 'asc' ? 'arrow_upward' : 'arrow_downward'}} + +
+
+
+
{{col.runIndex + 1}}
+
{{col.label}}
+ + {{curChildrenStatSortingDirection === 'asc' ? 'arrow_upward' : 'arrow_downward'}} + +
+
{{row.index}} + {{row.label}} + + {{strValue}} +
- - - - - - - - - - - - - - - - - -
-
- # - - {{curChildrenStatSortingDirection === 'asc' ? 'arrow_upward' : 'arrow_downward'}} - -
-
-
- Node - - {{curChildrenStatSortingDirection === 'asc' ? 'arrow_upward' : 'arrow_downward'}} - -
-
-
-
{{col.runIndex + 1}}
-
{{col.label}}
- - {{curChildrenStatSortingDirection === 'asc' ? 'arrow_upward' : 'arrow_downward'}} - -
-
{{row.index}} - {{row.label}} - - {{strValue}} -
-
+ }
stat.sum); this.curStatRows[3].values = stats.map((stat) => stat.sum / stat.count); + // Hide stat values based on hideAggregatedStats. + const allStats: AggregatedStat[] = ['min', 'max', 'sum', 'avg']; + for (let i = 0; i < runs.length; i++) { + const run = runs[i]; + const statsToHide: AggregatedStat[] = + run.nodeDataProviderData?.[this.curModelGraph.id] + ?.hideAggregatedStats ?? []; + for (let j = 0; j < allStats.length; j++) { + const stat = allStats[j]; + if (statsToHide.includes(stat)) { + // Set the value to positive infinity so that it will be displayed as + // '-' in the table. See `getStatValue()`. + this.curStatRows[j].values[i] = Number.POSITIVE_INFINITY; + } + } + } + // Generate children stats columns. this.childrenStatsCols = []; let childrenStatColIndex = 0; @@ -650,7 +721,10 @@ export class NodeDataProviderSummaryPanel implements OnChanges { this.childrenStatsCols.push({ colIndex: childrenStatColIndex, runIndex: i, - label: `${runs[i].runName} • ${childrenStat}`, + label: `${this.getRunName(runs[i])} • ${childrenStat}`, + hideInChildrenStatsTable: + runs[i].nodeDataProviderData?.[this.curModelGraph.id] + ?.hideInChildrenStatsTable, }); childrenStatColIndex++; } @@ -667,6 +741,7 @@ export class NodeDataProviderSummaryPanel implements OnChanges { const node = this.curModelGraph.nodesById[nodeId]; const colValues: number[] = []; const colStrs: string[] = []; + const colHidden: boolean[] = []; for (let runIndex = 0; runIndex < runs.length; runIndex++) { const run = runs[runIndex]; const curResults = run.results || {}; @@ -697,6 +772,10 @@ export class NodeDataProviderSummaryPanel implements OnChanges { } colValues.push(sumPct); colStrs.push(hasValue ? sumPct.toFixed(1) : '-'); + colHidden.push( + run.nodeDataProviderData?.[this.curModelGraph.id] + ?.hideInChildrenStatsTable === true, + ); } this.curChildrenStatRows.push({ id: nodeId, @@ -704,6 +783,7 @@ export class NodeDataProviderSummaryPanel implements OnChanges { index: i, colValues, colStrs, + colHidden, }); } this.savedChildrenStatRows = [...this.curChildrenStatRows]; @@ -850,4 +930,8 @@ export class NodeDataProviderSummaryPanel implements OnChanges { }) .join(','); } + + private getRunName(run: NodeDataProviderRunData): string { + return getRunName(run, this.curModelGraph); + } } diff --git a/src/ui/src/components/visualizer/split_panes_container.ng.html b/src/ui/src/components/visualizer/split_panes_container.ng.html index 516ab161..b4dcd8d5 100644 --- a/src/ui/src/components/visualizer/split_panes_container.ng.html +++ b/src/ui/src/components/visualizer/split_panes_container.ng.html @@ -111,6 +111,9 @@
+
+ No mapped node found +
}
diff --git a/src/ui/src/components/visualizer/split_panes_container.scss b/src/ui/src/components/visualizer/split_panes_container.scss index 633f3351..ee7fd7aa 100644 --- a/src/ui/src/components/visualizer/split_panes_container.scss +++ b/src/ui/src/components/visualizer/split_panes_container.scss @@ -220,6 +220,28 @@ height: 24px; // Over resizer. z-index: 250; + + .no-mapped-node-message { + position: absolute; + top: 28px; + width: 140px; + font-size: 12px; + left: -44px; + background-color: #a00; + color: white; + padding: 2px 4px; + display: flex; + align-items: center; + justify-content: center; + border-radius: 99px; + pointer-events: none; + opacity: 0; + transition: opacity 100ms; + + &.show { + opacity: 1; + } + } } } diff --git a/src/ui/src/components/visualizer/split_panes_container.ts b/src/ui/src/components/visualizer/split_panes_container.ts index 24daf804..644614c1 100644 --- a/src/ui/src/components/visualizer/split_panes_container.ts +++ b/src/ui/src/components/visualizer/split_panes_container.ts @@ -24,6 +24,7 @@ import { ChangeDetectorRef, Component, computed, + DestroyRef, effect, ElementRef, QueryList, @@ -31,6 +32,7 @@ import { ViewChild, ViewChildren, } from '@angular/core'; +import {takeUntilDestroyed} from '@angular/core/rxjs-interop'; import {MatIconModule} from '@angular/material/icon'; import {MatProgressSpinnerModule} from '@angular/material/progress-spinner'; import {MatTooltipModule} from '@angular/material/tooltip'; @@ -51,6 +53,7 @@ import {GraphPanel} from './graph_panel'; import {InfoPanel} from './info_panel'; import {SplitPane} from './split_pane'; import {SyncNavigationButton} from './sync_navigation_button'; +import {SyncNavigationService} from './sync_navigation_service'; import {WorkerService} from './worker_service'; interface ProcessingTask { @@ -91,6 +94,8 @@ interface ProcessingTask { }) export class SplitPanesContainer implements AfterViewInit { @ViewChild('panesContainer') panesContainer!: ElementRef; + @ViewChild('noMappedNodeMessage') + noMappedNodeMessage?: ElementRef; @ViewChildren('splitPane') splitPanes = new QueryList(); readonly processingTasks: Record = {}; @@ -102,9 +107,13 @@ export class SplitPanesContainer implements AfterViewInit { curUpdateProcessingProgressReq?: UpdateProcessingProgressRequest; + private hideNoMappedNodeMessageTimeoutId = -1; + constructor( private readonly changeDetectorRef: ChangeDetectorRef, private readonly appService: AppService, + private readonly destroyRef: DestroyRef, + private readonly syncNavigationService: SyncNavigationService, private readonly workerService: WorkerService, ) { this.panes = this.appService.panes; @@ -142,6 +151,16 @@ export class SplitPanesContainer implements AfterViewInit { break; } }); + + this.syncNavigationService.showNoMappedNodeMessageTrigger$ + .pipe(takeUntilDestroyed(this.destroyRef)) + .subscribe((data) => { + if (data === undefined) { + this.hideNoMappedNodeMessage(); + } else { + this.showNoMappedNodeMessage(); + } + }); } ngAfterViewInit() { @@ -296,4 +315,36 @@ export class SplitPanesContainer implements AfterViewInit { this.changeDetectorRef.detectChanges(); } } + + private hideNoMappedNodeMessage() { + const ele = this.noMappedNodeMessage?.nativeElement; + if (!ele) { + return; + } + + if (this.hideNoMappedNodeMessageTimeoutId >= 0) { + clearTimeout(this.hideNoMappedNodeMessageTimeoutId); + this.hideNoMappedNodeMessageTimeoutId = -1; + } + + ele.classList.remove('show'); + } + + private showNoMappedNodeMessage() { + const ele = this.noMappedNodeMessage?.nativeElement; + if (!ele) { + return; + } + + if (this.hideNoMappedNodeMessageTimeoutId >= 0) { + clearTimeout(this.hideNoMappedNodeMessageTimeoutId); + this.hideNoMappedNodeMessageTimeoutId = -1; + } + + // Hide after 3 seconds. + ele.classList.add('show'); + this.hideNoMappedNodeMessageTimeoutId = setTimeout(() => { + ele.classList.remove('show'); + }, 3000); + } } diff --git a/src/ui/src/components/visualizer/sync_navigation_service.ts b/src/ui/src/components/visualizer/sync_navigation_service.ts index a9d5dcf8..af64156a 100644 --- a/src/ui/src/components/visualizer/sync_navigation_service.ts +++ b/src/ui/src/components/visualizer/sync_navigation_service.ts @@ -41,6 +41,9 @@ export class SyncNavigationService { readonly syncNavigationModeChanged$ = new Subject(); + // {} means showing the message, and undefined means hiding the message. + readonly showNoMappedNodeMessageTrigger$ = new Subject<{} | undefined>(); + private savedProcessedSyncNavigationData: Record< string, ProcessedSyncNavigationData diff --git a/src/ui/src/components/visualizer/view_on_node.ts b/src/ui/src/components/visualizer/view_on_node.ts index b58cf3e2..eaa770c9 100644 --- a/src/ui/src/components/visualizer/view_on_node.ts +++ b/src/ui/src/components/visualizer/view_on_node.ts @@ -45,6 +45,7 @@ import { ShowOnNodeItemData, ShowOnNodeItemType, } from './common/types'; +import {getRunName} from './common/utils'; import {LocalStorageService} from './local_storage_service'; import {NodeDataProviderExtensionService} from './node_data_provider_extension_service'; @@ -101,7 +102,7 @@ export class ViewOnNode { ), ) : []; - return runs.map((run) => run.runName); + return runs.map((run) => getRunName(run, modelGraph)); }); private savedNodeDataProviderRunNames: string[] = []; diff --git a/src/ui/src/components/visualizer/webgl_renderer.ts b/src/ui/src/components/visualizer/webgl_renderer.ts index 689840ef..db63e52d 100644 --- a/src/ui/src/components/visualizer/webgl_renderer.ts +++ b/src/ui/src/components/visualizer/webgl_renderer.ts @@ -752,6 +752,19 @@ export class WebglRenderer implements OnInit, OnDestroy { !hideInLayout ) { this.revealNode(mappedNodeId, false); + this.syncNavigationService.showNoMappedNodeMessageTrigger$.next( + undefined, + ); + } else { + if (mappedNodeId !== '' && (!mappedNode || hideInLayout)) { + this.syncNavigationService.showNoMappedNodeMessageTrigger$.next( + {}, + ); + } else if (mappedNodeId === '') { + this.syncNavigationService.showNoMappedNodeMessageTrigger$.next( + undefined, + ); + } } } }); diff --git a/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts b/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts index 853ffc65..fa370ec0 100644 --- a/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts +++ b/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts @@ -901,9 +901,9 @@ export class WebglRendererIoHighlightService { targetNodeId: string, ): ModelEdge | undefined { const groupNodeId = namespace === '' ? '' : `${namespace}___group___`; - return this.webglRenderer.curModelGraph.edgesByGroupNodeIds[ - groupNodeId - ].find((edge) => { + return ( + this.webglRenderer.curModelGraph.edgesByGroupNodeIds[groupNodeId] ?? [] + ).find((edge) => { const fromNode = this.webglRenderer.curModelGraph.nodesById[edge.fromNodeId]; const toNode = this.webglRenderer.curModelGraph.nodesById[edge.toNodeId]; diff --git a/src/ui/src/components/visualizer/worker/graph_layout.ts b/src/ui/src/components/visualizer/worker/graph_layout.ts index b845226b..dcc6f2ea 100644 --- a/src/ui/src/components/visualizer/worker/graph_layout.ts +++ b/src/ui/src/components/visualizer/worker/graph_layout.ts @@ -55,6 +55,7 @@ import { getOpNodeFieldLabelsFromShowOnNodeItemTypes, getOpNodeInputsKeyValuePairsForAttrsTable, getOpNodeOutputsKeyValuePairsForAttrsTable, + getRunName, isGroupNode, isOpNode, splitLabel, @@ -493,6 +494,7 @@ export function getNodeHeight( showOnNodeItemTypes, node, nodeDataProviderRuns, + modelGraph, ); } else if (isGroupNode(node)) { attrsTableRowCount = getGroupNodeAttrsTableRowCount( @@ -583,6 +585,7 @@ function getOpNodeAttrsTableRowCount( showOnNodeItemTypes: Record, node: OpNode, nodeDataProviderRuns: Record, + modelGraph: ModelGraph, ): number { // Basic info fields. const baiscFieldIds = @@ -623,7 +626,7 @@ function getOpNodeAttrsTableRowCount( ) && Object.values(nodeDataProviderRuns).some( (run) => - run.runName === + getRunName(run, modelGraph) === showOnNodeItemType.replace( NODE_DATA_PROVIDER_SHOW_ON_NODE_TYPE_PREFIX, '',