From 8efb555d0d19243eb5543471650c474a04873b6c Mon Sep 17 00:00:00 2001 From: Google AI Edge Date: Thu, 14 Nov 2024 14:52:19 -0800 Subject: [PATCH] Add support for showing inputs/outputs for selected layer node PiperOrigin-RevId: 696659244 --- src/ui/src/components/home_page/home_page.ts | 4 + .../visualizer/common/visualizer_config.ts | 3 + .../components/visualizer/info_panel.ng.html | 327 +++++--- .../src/components/visualizer/info_panel.scss | 36 +- .../src/components/visualizer/info_panel.ts | 391 ++++++--- .../visualizer/split_pane_service.ts | 40 +- .../webgl_renderer_io_highlight_service.ts | 747 ++++++++++-------- src/ui/src/services/settings_service.ts | 17 +- 8 files changed, 947 insertions(+), 618 deletions(-) diff --git a/src/ui/src/components/home_page/home_page.ts b/src/ui/src/components/home_page/home_page.ts index 9c334ff8..50f3ab28 100644 --- a/src/ui/src/components/home_page/home_page.ts +++ b/src/ui/src/components/home_page/home_page.ts @@ -47,6 +47,7 @@ import { SETTING_EDGE_COLOR, SETTING_EDGE_LABEL_FONT_SIZE, SETTING_HIDE_OP_NODES_WITH_LABELS, + SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS, SETTING_KEEP_LAYERS_WITH_A_SINGLE_CHILD, SETTING_MAX_CONST_ELEMENT_COUNT_LIMIT, SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING, @@ -351,6 +352,9 @@ export class HomePage implements AfterViewInit { this.settingsService.getBooleanValue( SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING, ), + highlightLayerNodeInputsOutputs: this.settingsService.getBooleanValue( + SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS, + ), }; } diff --git a/src/ui/src/components/visualizer/common/visualizer_config.ts b/src/ui/src/components/visualizer/common/visualizer_config.ts index b0885a8d..83ccc290 100644 --- a/src/ui/src/components/visualizer/common/visualizer_config.ts +++ b/src/ui/src/components/visualizer/common/visualizer_config.ts @@ -58,6 +58,9 @@ export declare interface VisualizerConfig { */ showOpNodeOutOfLayerEdgesWithoutSelecting?: boolean; + /** Whether to highlight layer node inputs and outputs. */ + highlightLayerNodeInputsOutputs?: boolean; + /** The default node styler rules. */ nodeStylerRules?: NodeStylerRule[]; diff --git a/src/ui/src/components/visualizer/info_panel.ng.html b/src/ui/src/components/visualizer/info_panel.ng.html index e3cfc1ef..3851712c 100644 --- a/src/ui/src/components/visualizer/info_panel.ng.html +++ b/src/ui/src/components/visualizer/info_panel.ng.html @@ -62,7 +62,7 @@ -
@@ -78,64 +78,8 @@ }
-
- @for (item of flatInputItems; track item.opNode.id; let i = $index) { -
-
-
{{item.index}}
- @if (item.opNode.hideInLayout) { -
{{getInputName(item)}}
- } @else { -
- {{getInputName(item)}} -
- my_location -
-
- @if (!item.opNode.hideInLayout) { -
- - {{getInputOpNodeToggleVisibilityIcon(item.opNode.id)}} - -
- } -
- } -
- - @for (metadataItem of item.metadataList; track metadataItem.key) { - @if (getShowMetadata(metadataItem)) { - - - - - } - } - -
- } -
+ +
@@ -155,75 +99,8 @@ }
-
- @for (item of outputItemsForCurPage; track item.outputId; let i = $index; let last = $last) { -
-
-
{{item.index}}
-
{{getOutputName(item)}}
-
- @if (getHasConnectedToNodes(item)) { -
- - {{getOutputToggleVisibilityIcon(item.outputId)}} - -
- } -
- @if (item.metadataList.length > 0) { - - @for (metadataItem of item.metadataList; track metadataItem.key) { - - - - - } - - } -
- } -
+ + @@ -246,6 +123,48 @@ + + +
+
+
+ + layer inputs ({{curGroupInputsCount}}) +
+ @if (showGroupInputPaginator) { + + + } +
+ + +
+ + +
+
+
+ + layer outputs ({{curGroupOutputsCount}}) +
+ @if (showGroupOutputPaginator) { + + + } +
+ + +
@@ -270,4 +189,154 @@ } - \ No newline at end of file + + + +
+ @for (item of items; track item.opNode.id; let i = $index) { +
+
+
{{item.index}}
+ @if (item.opNode.hideInLayout) { +
{{getInputName(item)}}
+ @if (item.targetOpNode) { +
+ arrow_forward +
{{item.targetOpNode.label}}
+
+ } + } @else { +
+ {{getInputName(item)}} +
+ my_location +
+ @if (item.targetOpNode) { +
+ arrow_forward +
{{item.targetOpNode.label}}
+
+ } +
+ @if (!item.opNode.hideInLayout) { +
+ + {{getInputOpNodeToggleVisibilityIcon(item.opNode.id)}} + +
+ } +
+ } +
+ + @for (metadataItem of item.metadataList; track metadataItem.key) { + @if (getShowMetadata(metadataItem)) { + + + + + } + } + +
+ } +
+
+ + +
+ @for (item of items; track $index; let i = $index; let last = $last) { +
+
+
{{item.index}}
+
{{getOutputName(item)}}
+ @if (item.showSourceOpNode) { +
+ ({{item.sourceOpNode.label}}) +
+ } +
+ @if (getHasConnectedToNodes(item)) { +
+ + {{getOutputToggleVisibilityIcon(item)}} + +
+ } +
+ @if (item.metadataList.length > 0) { + + @for (metadataItem of item.metadataList; track metadataItem.key) { + + + + + } + + } +
+ } +
+
diff --git a/src/ui/src/components/visualizer/info_panel.scss b/src/ui/src/components/visualizer/info_panel.scss index 0906f6f6..874b7584 100644 --- a/src/ui/src/components/visualizer/info_panel.scss +++ b/src/ui/src/components/visualizer/info_panel.scss @@ -92,13 +92,33 @@ background-color: #f6f6f6; .locator-icon-container { - opacity: .8; + opacity: 0.8; } } } } } + .target-op-container { + display: flex; + align-items: center; + color: #999; + font-weight: normal; + + mat-icon.arrow { + font-size: 12px; + height: 12px; + width: 12px; + margin: 0 4px; + } + } + + .source-op-node-label { + color: #999; + font-weight: normal; + margin-left: 6px; + } + .metadata-table { margin-top: 3px; margin-left: 18px; @@ -130,7 +150,7 @@ box-sizing: border-box; background-color: white; user-select: none; - color: rgba(0, 0, 0, .87); + color: rgba(0, 0, 0, 0.87); &.input, &.output, @@ -254,7 +274,7 @@ cursor: pointer; &:hover .locator-icon-container { - opacity: .8; + opacity: 0.8; } &.search-match { @@ -341,7 +361,7 @@ display: flex; align-items: center; justify-content: center; - opacity: .5; + opacity: 0.5; margin-left: 4px; &.left { @@ -355,7 +375,7 @@ } &:hover { - opacity: .8; + opacity: 0.8; } mat-icon { @@ -370,12 +390,12 @@ display: flex; align-items: center; justify-content: center; - opacity: .5; + opacity: 0.5; padding: 0 11px 0 20px; cursor: pointer; &:hover { - opacity: .8; + opacity: 0.8; } mat-icon { @@ -467,4 +487,4 @@ font-size: 12px; padding: 3px 0; } -} \ No newline at end of file +} diff --git a/src/ui/src/components/visualizer/info_panel.ts b/src/ui/src/components/visualizer/info_panel.ts index e3769570..9aacfff9 100644 --- a/src/ui/src/components/visualizer/info_panel.ts +++ b/src/ui/src/components/visualizer/info_panel.ts @@ -43,10 +43,12 @@ import {AppService} from './app_service'; import {TENSOR_TAG_METADATA_KEY, TENSOR_VALUES_KEY} from './common/consts'; import {GroupNode, ModelGraph, ModelNode, OpNode} from './common/model_graph'; import { + IncomingEdge, KeyValue, KeyValueList, KeyValuePairs, NodeDataProviderRunInfo, + OutgoingEdge, SearchMatchAttr, SearchMatchInputMetadata, SearchMatchOutputMetadata, @@ -84,6 +86,8 @@ enum SectionLabel { IDENTICAL_GROUPS = 'Identical groups', INPUTS = 'inputs', OUTPUTS = 'outputs', + GROUP_INPUTS = 'layer inputs', + GROUP_OUTPUTS = 'layer outputs', } interface InfoItem { @@ -102,18 +106,21 @@ interface InfoItem { interface OutputItem { index: number; tensorTag: string; + sourceOpNode: OpNode; outputId: string; metadataList: OutputItemMetadata[]; + showSourceOpNode?: boolean; } interface OutputItemMetadata extends KeyValue { connectedNodes?: OpNode[]; } -interface FlatInputItem { +interface InputItem { index: number; opNode: OpNode; metadataList: KeyValueList; + targetOpNode?: OpNode; } const MIN_WIDTH = 64; @@ -163,14 +170,21 @@ export class InfoPanel { @HostBinding('style.min-width.px') minWidth = DEFAULT_WIDTH; sections: InfoSection[] = []; - flatInputItems: FlatInputItem[] = []; + inputItems: InputItem[] = []; + inputItemsForCurPage: InputItem[] = []; outputItems: OutputItem[] = []; outputItemsForCurPage: OutputItem[] = []; + groupInputItems: InputItem[] = []; + groupInputItemsForCurPage: InputItem[] = []; + groupOutputItems: OutputItem[] = []; + groupOutputItemsForCurPage: OutputItem[] = []; identicalGroupNodes: GroupNode[] = []; identicalGroupsData?: TreeNode[]; curRendererId = ''; curInputsCount = 0; curOutputsCount = 0; + curGroupInputsCount = 0; + curGroupOutputsCount = 0; resizing = false; hide = false; @@ -211,8 +225,6 @@ export class InfoPanel { private curSearchAttrMatches: SearchMatchAttr[] = []; private curSearchInputMatches: SearchMatchInputMetadata[] = []; private curSearchOutputMatches: SearchMatchOutputMetadata[] = []; - private inputSourceNodes: OpNode[] = []; - private inputMetadataList: KeyValuePairs[] = []; private savedWidth = 0; constructor( @@ -329,18 +341,29 @@ export class InfoPanel { } handleInputPaginatorChanged(curPageIndex: number) { - const curInputSourceNodes = this.inputSourceNodes.slice( + this.inputItemsForCurPage = this.inputItems.slice( curPageIndex * this.ioPageSize, (curPageIndex + 1) * this.ioPageSize, ); - const curInputMetadataList = this.inputMetadataList.slice( + this.changeDetectorRef.markForCheck(); + + setTimeout(() => { + this.updateInputValueContentsExpandable(); + }); + } + + handleOutputPaginatorChanged(curPageIndex: number) { + this.outputItemsForCurPage = this.outputItems.slice( curPageIndex * this.ioPageSize, (curPageIndex + 1) * this.ioPageSize, ); - this.flatInputItems = this.genInputFlatItems( + this.changeDetectorRef.markForCheck(); + } + + handleGroupInputPaginatorChanged(curPageIndex: number) { + this.groupInputItemsForCurPage = this.groupInputItems.slice( curPageIndex * this.ioPageSize, - curInputSourceNodes, - curInputMetadataList, + (curPageIndex + 1) * this.ioPageSize, ); this.changeDetectorRef.markForCheck(); @@ -349,8 +372,8 @@ export class InfoPanel { }); } - handleOutputPaginatorChanged(curPageIndex: number) { - this.outputItemsForCurPage = this.outputItems.slice( + handleGroupOutputPaginatorChanged(curPageIndex: number) { + this.groupOutputItemsForCurPage = this.groupOutputItems.slice( curPageIndex * this.ioPageSize, (curPageIndex + 1) * this.ioPageSize, ); @@ -432,7 +455,7 @@ export class InfoPanel { handleToggleInputOpNodeVisibility( nodeId: string, - allItems: FlatInputItem[], + allItems: InputItem[], event: MouseEvent, ) { event.stopPropagation(); @@ -464,7 +487,7 @@ export class InfoPanel { } handleToggleOutputVisibility( - outputId: string, + item: OutputItem, items: OutputItem[], event: MouseEvent, ) { @@ -472,31 +495,39 @@ export class InfoPanel { if (event.altKey) { this.splitPaneService.setOutputVisible( - outputId, - items.map((item) => item.outputId), + item.sourceOpNode.id, + item.outputId, + items.map((item) => ({ + nodeId: item.sourceOpNode.id, + outputId: item.outputId, + })), ); } else { - this.splitPaneService.toggleOutputVisibility(outputId); + this.splitPaneService.toggleOutputVisibility( + item.sourceOpNode.id, + item.outputId, + ); } } - getOutputToggleVisible(outputId: string): boolean { - return this.splitPaneService.getOutputVisible(outputId); + getOutputToggleVisible(item: OutputItem): boolean { + return this.splitPaneService.getOutputVisible( + item.sourceOpNode.id, + item.outputId, + ); } - getOutputToggleVisibilityIcon(outputId: string): string { - return this.getOutputToggleVisible(outputId) - ? 'visibility' - : 'visibility_off'; + getOutputToggleVisibilityIcon(item: OutputItem): string { + return this.getOutputToggleVisible(item) ? 'visibility' : 'visibility_off'; } - getOutputToggleVisibilityTooltip(outputId: string): string { - return this.getOutputToggleVisible(outputId) + getOutputToggleVisibilityTooltip(item: OutputItem): string { + return this.getOutputToggleVisible(item) ? 'Click to hide highlight' : 'Click to show highlight'; } - getInputName(item: FlatInputItem): string { + getInputName(item: InputItem): string { const tensorTagItem = item.metadataList.find( (item) => item.key === TENSOR_TAG_METADATA_KEY, ); @@ -506,7 +537,7 @@ export class InfoPanel { : item.opNode.label; } - getInputTensorTag(item: FlatInputItem): string { + getInputTensorTag(item: InputItem): string { const tensorTagItem = item.metadataList.find( (item) => item.key === TENSOR_TAG_METADATA_KEY, ); @@ -559,7 +590,7 @@ export class InfoPanel { get showInputPaginator(): boolean { return ( - this.inputSourceNodes.length > this.ioPageSize && + this.inputItems.length > this.ioPageSize && !this.isSectionCollapsed(SectionLabel.INPUTS) ); } @@ -571,6 +602,20 @@ export class InfoPanel { ); } + get showGroupInputPaginator(): boolean { + return ( + this.groupInputItems.length > this.ioPageSize && + !this.isSectionCollapsed(SectionLabel.GROUP_INPUTS) + ); + } + + get showGroupOutputPaginator(): boolean { + return ( + this.groupOutputItems.length > this.ioPageSize && + !this.isSectionCollapsed(SectionLabel.GROUP_OUTPUTS) + ); + } + get showIdenticalGroupsPaginator(): boolean { return ( this.identicalGroupNodes.length > this.ioPageSize && @@ -602,10 +647,10 @@ export class InfoPanel { private genInfoData() { this.sections = []; - this.flatInputItems = []; - this.inputSourceNodes = []; - this.inputMetadataList = []; + this.inputItems = []; this.outputItems = []; + this.groupInputItems = []; + this.groupOutputItems = []; this.identicalGroupNodes = []; this.identicalGroupsData = undefined; @@ -617,6 +662,9 @@ export class InfoPanel { this.genInputsOutputsData(); } else if (isGroupNode(this.curSelectedNode)) { this.genInfoDataForSelectedGroupNode(); + if (this.appService.config()?.highlightLayerNodeInputsOutputs) { + this.genGroupInputsOutputsData(); + } } } } @@ -777,86 +825,48 @@ export class InfoPanel { const selectedOpNode = this.curSelectedNode as OpNode; const incomingEdges = selectedOpNode.incomingEdges || []; - this.inputMetadataList = []; - this.inputSourceNodes = []; - this.flatInputItems = []; - for (const edge of incomingEdges) { + this.inputItems = []; + for (let i = 0; i < incomingEdges.length; i++) { + const edge = incomingEdges[i]; + const metadataList = this.genInputMetadataList(selectedOpNode, edge); const sourceOpNode = this.curModelGraph?.nodesById[ edge.sourceNodeId ] as OpNode; - this.inputSourceNodes.push(sourceOpNode); - const metadata = - (selectedOpNode.inputsMetadata || {})[edge.targetNodeInputId] || {}; - // Merge the corresponding output metadata with the current input - // metadata. - const sourceNodeOutputMetadata = { - ...((sourceOpNode.outputsMetadata || {})[edge.sourceNodeOutputId] || - {}), - }; - for (const key of Object.keys(sourceNodeOutputMetadata)) { - if (metadata[key] == null && key !== TENSOR_TAG_METADATA_KEY) { - metadata[key] = sourceNodeOutputMetadata[key]; - } - } - this.inputMetadataList.push(metadata); - } - this.curInputsCount = this.inputSourceNodes.length; - if (incomingEdges.length > 0) { - const curInputSourceNodes = this.inputSourceNodes.slice( - 0, - this.ioPageSize, - ); - const curInputMetadataList = this.inputMetadataList.slice( - 0, - this.ioPageSize, - ); - this.flatInputItems = this.genInputFlatItems( - 0, - curInputSourceNodes, - curInputMetadataList, - ); + this.inputItems.push({ + index: i, + opNode: sourceOpNode, + metadataList, + }); } + this.curInputsCount = this.inputItems.length; + this.inputItemsForCurPage = this.inputItems.slice(0, this.ioPageSize); + // Outputs. this.outputItems = []; const outputsMetadata = selectedOpNode.outputsMetadata || {}; const outgoingEdges = selectedOpNode.outgoingEdges || []; let index = 0; for (const outputId of Object.keys(outputsMetadata)) { - // The metadata for the current output tensor. - const metadataList: OutputItemMetadata[] = []; - let tensorTag = ''; - for (const metadataKey of Object.keys(outputsMetadata[outputId])) { - const value = outputsMetadata[outputId][metadataKey]; - if (metadataKey === TENSOR_TAG_METADATA_KEY) { - tensorTag = value; - } - // Hide all metadata keys that start with '__'. - if (metadataKey.startsWith('__')) { - continue; - } - metadataList.push({ - key: metadataKey, - value, - }); - } - metadataList.sort((a, b) => a.key.localeCompare(b.key)); - // The connected nodes. const connectedNodes = outgoingEdges .filter((edge) => edge.sourceNodeOutputId === outputId) .map( (edge) => this.curModelGraph!.nodesById[edge.targetNodeId], ) as OpNode[]; - metadataList.push({ - key: this.outputMetadataConnectedTo, - value: '', + + // Metadata list. + const {metadataList, tensorTag} = this.genOutputMetadataList( + outgoingEdges, + outputsMetadata[outputId], connectedNodes, - }); + ); + this.outputItems.push({ index, tensorTag, outputId, + sourceOpNode: selectedOpNode, metadataList, }); index++; @@ -865,6 +875,178 @@ export class InfoPanel { this.outputItemsForCurPage = this.outputItems.slice(0, this.ioPageSize); } + private genGroupInputsOutputsData() { + if (!this.curModelGraph || !this.curSelectedNode) { + return; + } + + // Inputs. + const selectedGroupNode = this.curSelectedNode as GroupNode; + const seenInputNodeIds = new Set(); + + this.groupInputItems = []; + let index = 0; + for (const nodeId of selectedGroupNode.descendantsOpNodeIds || []) { + const descendantIds = new Set( + selectedGroupNode.descendantsOpNodeIds || [], + ); + const opNode = this.curModelGraph?.nodesById[nodeId] as OpNode; + const incomingEdges = opNode.incomingEdges || []; + + for (const edge of incomingEdges) { + const sourceOpNode = this.curModelGraph?.nodesById[ + edge.sourceNodeId + ] as OpNode; + + // Ignore if the source op node is within the layer. + if (descendantIds.has(sourceOpNode.id)) { + continue; + } + + // Dedup. + if (seenInputNodeIds.has(sourceOpNode.id)) { + continue; + } + seenInputNodeIds.add(sourceOpNode.id); + + const metadataList = this.genInputMetadataList(opNode, edge); + this.groupInputItems.push({ + index: index++, + opNode: sourceOpNode, + metadataList, + targetOpNode: opNode, + }); + } + } + + this.curGroupInputsCount = this.groupInputItems.length; + this.groupInputItemsForCurPage = this.groupInputItems.slice( + 0, + this.ioPageSize, + ); + + // Outputs. + this.groupOutputItems = []; + index = 0; + for (const nodeId of selectedGroupNode.descendantsOpNodeIds || []) { + const descendantIds = new Set( + selectedGroupNode.descendantsOpNodeIds || [], + ); + const opNode = this.curModelGraph?.nodesById[nodeId] as OpNode; + + const outputsMetadata = opNode.outputsMetadata || {}; + const outgoingEdges = opNode.outgoingEdges || []; + for (const outputId of Object.keys(outputsMetadata)) { + // The connected nodes. + const connectedNodes = outgoingEdges + .filter((edge) => !descendantIds.has(edge.targetNodeId)) + .filter((edge) => edge.sourceNodeOutputId === outputId) + .map( + (edge) => this.curModelGraph!.nodesById[edge.targetNodeId], + ) as OpNode[]; + if (connectedNodes.length === 0) { + continue; + } + + // Metadata list. + const {metadataList, tensorTag} = this.genOutputMetadataList( + outgoingEdges, + outputsMetadata[outputId], + connectedNodes, + ); + + this.groupOutputItems.push({ + index, + tensorTag, + outputId, + sourceOpNode: opNode, + metadataList, + showSourceOpNode: true, + }); + index++; + } + } + + this.curGroupOutputsCount = this.groupOutputItems.length; + this.groupOutputItemsForCurPage = this.groupOutputItems.slice( + 0, + this.ioPageSize, + ); + } + + private genInputMetadataList( + opNode: OpNode, + edge: IncomingEdge, + ): KeyValueList { + const sourceOpNode = this.curModelGraph?.nodesById[ + edge.sourceNodeId + ] as OpNode; + const metadata = + (opNode.inputsMetadata || {})[edge.targetNodeInputId] || {}; + // Merge the corresponding output metadata with the current input + // metadata. + const sourceNodeOutputMetadata = { + ...((sourceOpNode.outputsMetadata || {})[edge.sourceNodeOutputId] || {}), + }; + for (const key of Object.keys(sourceNodeOutputMetadata)) { + if (metadata[key] == null && key !== TENSOR_TAG_METADATA_KEY) { + metadata[key] = sourceNodeOutputMetadata[key]; + } + } + // Sort by key. + const metadataList: KeyValueList = []; + Object.entries(metadata).forEach(([key, value]) => { + metadataList.push({key, value}); + }); + metadataList.sort((a, b) => a.key.localeCompare(b.key)); + // Add namespace to metadata. + metadataList.push({ + key: this.inputMetadataNamespaceKey, + value: getNamespaceLabel(sourceOpNode), + }); + // Add tensor values to metadata if existed. + const attrs = sourceOpNode.attrs || {}; + if (attrs[TENSOR_VALUES_KEY]) { + metadataList.push({ + key: this.inputMetadataValuesKey, + value: attrs[TENSOR_VALUES_KEY], + }); + } + return metadataList; + } + + private genOutputMetadataList( + outgoingEdges: OutgoingEdge[], + outputMetadata: KeyValuePairs, + connectedNodes: OpNode[], + ) { + const metadataList: OutputItemMetadata[] = []; + let tensorTag = ''; + for (const metadataKey of Object.keys(outputMetadata)) { + const value = outputMetadata[metadataKey]; + if (metadataKey === TENSOR_TAG_METADATA_KEY) { + tensorTag = value; + } + // Hide all metadata keys that start with '__'. + if (metadataKey.startsWith('__')) { + continue; + } + metadataList.push({ + key: metadataKey, + value, + }); + } + metadataList.sort((a, b) => a.key.localeCompare(b.key)); + + metadataList.push({ + key: this.outputMetadataConnectedTo, + value: '', + connectedNodes, + }); + + return {metadataList, tensorTag}; + } + private genInfoDataForSelectedGroupNode() { if (!this.curModelGraph || !this.curSelectedNode) { return; @@ -1011,39 +1193,6 @@ export class InfoPanel { animate(); } - private genInputFlatItems( - startIndex: number, - inputSourceNodes: OpNode[], - inputMetadataList: KeyValuePairs[], - ): FlatInputItem[] { - const flatInputItems: FlatInputItem[] = []; - for (let i = 0; i < inputSourceNodes.length; i++) { - const sourceNode = inputSourceNodes[i]; - const metadataList: KeyValueList = []; - Object.entries(inputMetadataList[i]).forEach(([key, value]) => { - metadataList.push({key, value}); - }); - metadataList.sort((a, b) => a.key.localeCompare(b.key)); - metadataList.push({ - key: this.inputMetadataNamespaceKey, - value: getNamespaceLabel(inputSourceNodes[i]), - }); - const attrs = sourceNode.attrs || {}; - if (attrs[TENSOR_VALUES_KEY]) { - metadataList.push({ - key: this.inputMetadataValuesKey, - value: attrs[TENSOR_VALUES_KEY], - }); - } - flatInputItems.push({ - index: i + startIndex, - opNode: sourceNode, - metadataList, - }); - } - return flatInputItems; - } - private updateInputValueContentsExpandable() { for (let i = 0; i < this.inputValueContents.length; i++) { const valueContent = this.inputValueContents.get(i)?.nativeElement; diff --git a/src/ui/src/components/visualizer/split_pane_service.ts b/src/ui/src/components/visualizer/split_pane_service.ts index d2825a6b..69f40528 100644 --- a/src/ui/src/components/visualizer/split_pane_service.ts +++ b/src/ui/src/components/visualizer/split_pane_service.ts @@ -28,7 +28,8 @@ export class SplitPaneService { readonly hiddenInputOpNodeIds = signal>({}); /** - * The output ids that are hidden (i.e. not highlighted) in the model graph. + * The {nodeId}___{outputId} that are hidden (i.e. not highlighted) in the + * model graph. */ readonly hiddenOutputIds = signal>({}); @@ -71,24 +72,31 @@ export class SplitPaneService { } } - toggleOutputVisibility(outputId: string) { + toggleOutputVisibility(nodeId: string, outputId: string) { this.hiddenOutputIds.update((ids) => { - const visible = ids[outputId] === true; + const key = `${nodeId}___${outputId}`; + const visible = ids[key] === true; if (!visible) { - ids[outputId] = true; + ids[key] = true; } else { - delete ids[outputId]; + delete ids[key]; } return {...ids}; }); } - setOutputVisible(outputId: string, allOutputIds: string[]) { + setOutputVisible( + nodeId: string, + outputId: string, + allNodeAndOutputIds: Array<{nodeId: string; outputId: string}>, + ) { // Check if the output is the only visible one. - let isNodeTheOnlyVisibleOne = this.hiddenOutputIds()[outputId] !== true; - for (const id of allOutputIds) { - if (id !== outputId) { - if (!this.hiddenOutputIds()[id]) { + const key = `${nodeId}___${outputId}`; + let isNodeTheOnlyVisibleOne = this.hiddenOutputIds()[key] !== true; + for (const {nodeId, outputId} of allNodeAndOutputIds) { + const curKey = `${nodeId}___${outputId}`; + if (curKey !== key) { + if (!this.hiddenOutputIds()[curKey]) { isNodeTheOnlyVisibleOne = false; } } @@ -101,9 +109,10 @@ export class SplitPaneService { // If not, hide the other outputs. else { const ids: Record = {}; - for (const id of allOutputIds) { - if (id !== outputId) { - ids[id] = true; + for (const {nodeId, outputId} of allNodeAndOutputIds) { + const curKey = `${nodeId}___${outputId}`; + if (curKey !== key) { + ids[curKey] = true; } } this.hiddenOutputIds.set(ids); @@ -114,8 +123,9 @@ export class SplitPaneService { return !this.hiddenInputOpNodeIds()[nodeId]; } - getOutputVisible(outputId: string): boolean { - return !this.hiddenOutputIds()[outputId]; + getOutputVisible(nodeId: string, outputId: string): boolean { + const key = `${nodeId}___${outputId}`; + return !this.hiddenOutputIds()[key]; } resetInputOutputHiddenIds() { diff --git a/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts b/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts index fa370ec0..b385b0f0 100644 --- a/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts +++ b/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts @@ -352,7 +352,7 @@ export class WebglRendererIoHighlightService { getHighlightedIncomingNodesAndEdges( hiddenInputNodeIds: Record, - selectedNode?: OpNode, + selectedNode?: ModelNode, options?: { ignoreEdgesWithinSameNamespace?: boolean; reuseRenderedEdgeCurvePoints?: boolean; @@ -364,202 +364,227 @@ export class WebglRendererIoHighlightService { options?.reuseRenderedEdgeCurvePoints ?? false; if (!selectedNode) { - selectedNode = this.webglRenderer.curModelGraph.nodesById[ - this.webglRenderer.selectedNodeId - ] as OpNode; + selectedNode = + this.webglRenderer.curModelGraph.nodesById[ + this.webglRenderer.selectedNodeId + ]; } const renderedEdges: ModelEdge[] = []; const highlightedNodes: ModelNode[] = []; const inputsByHighlightedNode: Record = {}; const overlayEdges: OverlayModelEdge[] = []; - for (const incomingEdge of selectedNode.incomingEdges || []) { - if (hiddenInputNodeIds[incomingEdge.sourceNodeId]) { - continue; - } - const sourceNode = this.webglRenderer.curModelGraph.nodesById[ - incomingEdge.sourceNodeId - ] as OpNode; - if (!sourceNode) { - continue; - } - - if ( - ignoreEdgesWithinSameNamespace && - sourceNode.namespace === selectedNode.namespace - ) { - continue; + const opNodes: OpNode[] = []; + const ignoredIncomingNodesIds = new Set(); + const seenIncomingNodesIds = new Set(); + if (isOpNode(selectedNode)) { + opNodes.push(selectedNode); + } else if (isGroupNode(selectedNode)) { + for (const id of selectedNode.descendantsOpNodeIds || []) { + const node = this.webglRenderer.curModelGraph.nodesById[id] as OpNode; + opNodes.push(node); + ignoredIncomingNodesIds.add(id); } + } - // Find the common namespace prefix. - const commonNamespace = findCommonNamespace( - sourceNode.namespace, - selectedNode.namespace, - ); + for (const opNode of opNodes) { + for (const incomingEdge of opNode.incomingEdges || []) { + if (hiddenInputNodeIds[incomingEdge.sourceNodeId]) { + continue; + } + const sourceNode = this.webglRenderer.curModelGraph.nodesById[ + incomingEdge.sourceNodeId + ] as OpNode; + if (!sourceNode) { + continue; + } - // Go from the given node to all its ns ancestors, find the last collapsed - // node before reaching the given namespace. If all ancestor nodes are - // expanded, return the given node. - const highlightedNode = this.getLastCollapsedAncestorNode( - sourceNode, - commonNamespace, - ); - highlightedNodes.push(highlightedNode); + if (ignoredIncomingNodesIds.has(sourceNode.id)) { + continue; + } - // Update inputsByHighlighedNode. - if (inputsByHighlightedNode[highlightedNode.id] == null) { - inputsByHighlightedNode[highlightedNode.id] = []; - } - inputsByHighlightedNode[highlightedNode.id].push(sourceNode); - - // Find the existing edge in the common namespace that connects two - // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains - // `node`. - const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace( - commonNamespace, - sourceNode.id, - selectedNode.id, - ); + if (seenIncomingNodesIds.has(sourceNode.id)) { + continue; + } + seenIncomingNodesIds.add(sourceNode.id); - // Start to construct an edge from the source node to the selected node. - // - const points: Point[] = []; - const curvePoints: Point[] = []; - - if (renderedEdge) { - renderedEdges.push(renderedEdge); - const renderedEdgeCurvePoints = renderedEdge.curvePoints || []; - - // Add a point from the highlighted node that connects to the first - // point of the rendered edge above. - const renderedEdgeFromNode = - this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId]; - if (renderedEdge.fromNodeId !== highlightedNode.id) { - const renderedEdgeStartX = - renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0); - const renderedEdgeStartY = - renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0); - const startPt = this.getBestAnchorPointOnNode( - renderedEdgeStartX, - renderedEdgeStartY, - highlightedNode, - ); - points.push({ - x: startPt.x - (highlightedNode.globalX || 0), - y: startPt.y - (highlightedNode.globalY || 0), - }); - if (reuseRenderedEdgeCurvePoints) { - curvePoints.push( - { - x: startPt.x - (highlightedNode.globalX || 0), - y: startPt.y - (highlightedNode.globalY || 0), - }, - { - x: - renderedEdgeCurvePoints[0].x - - (highlightedNode.globalX || 0) + - (renderedEdgeFromNode.globalX || 0), - y: - renderedEdgeCurvePoints[0].y - - (highlightedNode.globalY || 0) + - (renderedEdgeFromNode.globalY || 0), - }, - ); - } + if ( + ignoreEdgesWithinSameNamespace && + sourceNode.namespace === opNode.namespace + ) { + continue; } - // Add the points in rendered edge. - let targetPoints: Point[] = points; - let sourcePoints: Point[] = renderedEdge.points; - if (reuseRenderedEdgeCurvePoints) { - targetPoints = curvePoints; - sourcePoints = renderedEdgeCurvePoints; + // Find the common namespace prefix. + const commonNamespace = findCommonNamespace( + sourceNode.namespace, + opNode.namespace, + ); + + // Go from the given node to all its ns ancestors, find the last collapsed + // node before reaching the given namespace. If all ancestor nodes are + // expanded, return the given node. + const highlightedNode = this.getLastCollapsedAncestorNode( + sourceNode, + commonNamespace, + ); + highlightedNodes.push(highlightedNode); + + // Update inputsByHighlighedNode. + if (inputsByHighlightedNode[highlightedNode.id] == null) { + inputsByHighlightedNode[highlightedNode.id] = []; } - targetPoints.push( - ...sourcePoints.map((pt) => { - return { - x: - pt.x - - (highlightedNode.globalX || 0) + - (renderedEdgeFromNode.globalX || 0), - y: - pt.y - - (highlightedNode.globalY || 0) + - (renderedEdgeFromNode.globalY || 0), - }; - }), + inputsByHighlightedNode[highlightedNode.id].push(sourceNode); + + // Find the existing edge in the common namespace that connects two + // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains + // `node`. + const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace( + commonNamespace, + sourceNode.id, + opNode.id, ); - // Add a point from the selected node that connects to the last point of - // the rendered edge. - if (renderedEdge.toNodeId !== selectedNode?.id) { - const renderedEdgeLastX = - renderedEdge.points[renderedEdge.points.length - 1].x + - (renderedEdgeFromNode.globalX || 0); - const renderedEdgeLastY = - renderedEdge.points[renderedEdge.points.length - 1].y + - (renderedEdgeFromNode.globalY || 0); - const endPt = this.getBestAnchorPointOnNode( - renderedEdgeLastX, - renderedEdgeLastY, - selectedNode, - ); - if (!reuseRenderedEdgeCurvePoints) { + // Start to construct an edge from the source node to the selected node. + // + const points: Point[] = []; + const curvePoints: Point[] = []; + + if (renderedEdge) { + renderedEdges.push(renderedEdge); + const renderedEdgeCurvePoints = renderedEdge.curvePoints || []; + + // Add a point from the highlighted node that connects to the first + // point of the rendered edge above. + const renderedEdgeFromNode = + this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId]; + if (renderedEdge.fromNodeId !== highlightedNode.id) { + const renderedEdgeStartX = + renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0); + const renderedEdgeStartY = + renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0); + const startPt = this.getBestAnchorPointOnNode( + renderedEdgeStartX, + renderedEdgeStartY, + highlightedNode, + ); points.push({ - x: endPt.x - (highlightedNode.globalX || 0), - y: endPt.y - (highlightedNode.globalY || 0), + x: startPt.x - (highlightedNode.globalX || 0), + y: startPt.y - (highlightedNode.globalY || 0), }); - } else { - curvePoints.push( - { + if (reuseRenderedEdgeCurvePoints) { + curvePoints.push( + { + x: startPt.x - (highlightedNode.globalX || 0), + y: startPt.y - (highlightedNode.globalY || 0), + }, + { + x: + renderedEdgeCurvePoints[0].x - + (highlightedNode.globalX || 0) + + (renderedEdgeFromNode.globalX || 0), + y: + renderedEdgeCurvePoints[0].y - + (highlightedNode.globalY || 0) + + (renderedEdgeFromNode.globalY || 0), + }, + ); + } + } + + // Add the points in rendered edge. + let targetPoints: Point[] = points; + let sourcePoints: Point[] = renderedEdge.points; + if (reuseRenderedEdgeCurvePoints) { + targetPoints = curvePoints; + sourcePoints = renderedEdgeCurvePoints; + } + targetPoints.push( + ...sourcePoints.map((pt) => { + return { x: - renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] - .x - + pt.x - (highlightedNode.globalX || 0) + (renderedEdgeFromNode.globalX || 0), y: - renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] - .y - + pt.y - (highlightedNode.globalY || 0) + (renderedEdgeFromNode.globalY || 0), - }, - { + }; + }), + ); + + // Add a point from the selected node that connects to the last point of + // the rendered edge. + if (renderedEdge.toNodeId !== opNode?.id && isOpNode(selectedNode)) { + const renderedEdgeLastX = + renderedEdge.points[renderedEdge.points.length - 1].x + + (renderedEdgeFromNode.globalX || 0); + const renderedEdgeLastY = + renderedEdge.points[renderedEdge.points.length - 1].y + + (renderedEdgeFromNode.globalY || 0); + const endPt = this.getBestAnchorPointOnNode( + renderedEdgeLastX, + renderedEdgeLastY, + opNode, + ); + if (!reuseRenderedEdgeCurvePoints) { + points.push({ x: endPt.x - (highlightedNode.globalX || 0), y: endPt.y - (highlightedNode.globalY || 0), - }, - ); + }); + } else { + curvePoints.push( + { + x: + renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] + .x - + (highlightedNode.globalX || 0) + + (renderedEdgeFromNode.globalX || 0), + y: + renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] + .y - + (highlightedNode.globalY || 0) + + (renderedEdgeFromNode.globalY || 0), + }, + { + x: endPt.x - (highlightedNode.globalX || 0), + y: endPt.y - (highlightedNode.globalY || 0), + }, + ); + } } + } else if ( + isGroupNode(highlightedNode) || + (isOpNode(highlightedNode) && !highlightedNode.hideInLayout) + ) { + (reuseRenderedEdgeCurvePoints ? curvePoints : points).push( + ...this.getDirectEdgeBetweenNodes(highlightedNode, opNode), + ); } - } else if ( - isGroupNode(highlightedNode) || - (isOpNode(highlightedNode) && !highlightedNode.hideInLayout) - ) { - (reuseRenderedEdgeCurvePoints ? curvePoints : points).push( - ...this.getDirectEdgeBetweenNodes(highlightedNode, selectedNode), - ); - } - // Use these points to form an edge and add it as an overlay edge. - if (!reuseRenderedEdgeCurvePoints) { - if (points.length > 0) { - overlayEdges.push({ - id: `overlay_${highlightedNode.id}___${selectedNode.id}`, - fromNodeId: highlightedNode.id, - toNodeId: selectedNode.id, - points, - type: 'incoming', - }); - } - } else { - if (curvePoints.length > 0) { - overlayEdges.push({ - id: `overlay_${highlightedNode.id}___${selectedNode.id}`, - fromNodeId: highlightedNode.id, - toNodeId: selectedNode.id, - points: [], - curvePoints, - type: 'incoming', - }); + // Use these points to form an edge and add it as an overlay edge. + if (!reuseRenderedEdgeCurvePoints) { + if (points.length > 0) { + overlayEdges.push({ + id: `overlay_${highlightedNode.id}___${opNode.id}`, + fromNodeId: highlightedNode.id, + toNodeId: opNode.id, + points, + type: 'incoming', + }); + } + } else { + if (curvePoints.length > 0) { + overlayEdges.push({ + id: `overlay_${highlightedNode.id}___${opNode.id}`, + fromNodeId: highlightedNode.id, + toNodeId: opNode.id, + points: [], + curvePoints, + type: 'incoming', + }); + } } } } @@ -574,7 +599,7 @@ export class WebglRendererIoHighlightService { getHighlightedOutgoingNodesAndEdges( hiddenOutputIds: Record, - selectedNode?: OpNode, + selectedNode?: ModelNode, options?: { ignoreEdgesWithinSameNamespace?: boolean; reuseRenderedEdgeCurvePoints?: boolean; @@ -586,204 +611,234 @@ export class WebglRendererIoHighlightService { options?.reuseRenderedEdgeCurvePoints ?? false; if (!selectedNode) { - selectedNode = this.webglRenderer.curModelGraph.nodesById[ - this.webglRenderer.selectedNodeId - ] as OpNode; + selectedNode = + this.webglRenderer.curModelGraph.nodesById[ + this.webglRenderer.selectedNodeId + ]; } const renderedEdges: ModelEdge[] = []; const highlightedNodes: ModelNode[] = []; const outputsByHighlightedNode: Record = {}; const overlayEdges: OverlayModelEdge[] = []; - for (const outgoingEdges of selectedNode.outgoingEdges || []) { - if (hiddenOutputIds[outgoingEdges.sourceNodeOutputId]) { - continue; - } - - const targetNode = this.webglRenderer.curModelGraph.nodesById[ - outgoingEdges.targetNodeId - ] as OpNode; - if (!targetNode) { - continue; + const opNodes: OpNode[] = []; + const ignoredOutgoingNodesIds = new Set(); + const seenOutgoingNodesIds = new Set(); + if (isOpNode(selectedNode)) { + opNodes.push(selectedNode); + } else if (isGroupNode(selectedNode)) { + for (const id of selectedNode.descendantsOpNodeIds || []) { + const node = this.webglRenderer.curModelGraph.nodesById[id] as OpNode; + opNodes.push(node); + ignoredOutgoingNodesIds.add(id); } + } - if ( - ignoreEdgesWithinSameNamespace && - targetNode.namespace === selectedNode.namespace - ) { - continue; - } + for (const opNode of opNodes) { + for (const outgoingEdges of opNode.outgoingEdges || []) { + if ( + hiddenOutputIds[`${opNode.id}___${outgoingEdges.sourceNodeOutputId}`] + ) { + continue; + } - // Find the common namespace prefix. - const commonNamespace = findCommonNamespace( - targetNode.namespace, - selectedNode.namespace, - ); + const targetNode = this.webglRenderer.curModelGraph.nodesById[ + outgoingEdges.targetNodeId + ] as OpNode; + if (!targetNode) { + continue; + } - // Go from the given node to all its ns ancestors, find the last - // collapsed node before reaching the given namespace, and style it with - // the given class. If all ancestor nodes are expanded, style the given - // node. - const highlightedNode = this.getLastCollapsedAncestorNode( - targetNode, - commonNamespace, - ); - highlightedNodes.push(highlightedNode); + if (ignoredOutgoingNodesIds.has(targetNode.id)) { + continue; + } - // Update outputsByHighlighedNode. - if (outputsByHighlightedNode[highlightedNode.id] == null) { - outputsByHighlightedNode[highlightedNode.id] = []; - } - outputsByHighlightedNode[highlightedNode.id].push(targetNode); - - // Find the existing edge in the common namespace that connects two - // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains - // `node`. - const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace( - commonNamespace, - selectedNode.id, - targetNode.id, - ); + if (seenOutgoingNodesIds.has(targetNode.id)) { + continue; + } + seenOutgoingNodesIds.add(targetNode.id); - // Start to construct an edge from the selected node to target node. - // - const points: Point[] = []; - const curvePoints: Point[] = []; - - if (renderedEdge) { - renderedEdges.push(renderedEdge); - const renderedEdgeCurvePoints = renderedEdge.curvePoints || []; - - // Add a point from the selected node that connects to the first point - // of the rendered edge. - const renderedEdgeFromNode = - this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId]; - if (renderedEdge.fromNodeId !== selectedNode?.id) { - const renderedEdgeStartX = - renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0); - const renderedEdgeStartY = - renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0); - const endPt = this.getBestAnchorPointOnNode( - renderedEdgeStartX, - renderedEdgeStartY, - selectedNode, - ); - points.push({ - x: endPt.x - (selectedNode.globalX || 0), - y: endPt.y - (selectedNode.globalY || 0), - }); - if (reuseRenderedEdgeCurvePoints) { - curvePoints.push( - { - x: endPt.x - (selectedNode.globalX || 0), - y: endPt.y - (selectedNode.globalY || 0), - }, - { - x: - renderedEdgeCurvePoints[0].x - - (selectedNode.globalX || 0) + - (renderedEdgeFromNode.globalX || 0), - y: - renderedEdgeCurvePoints[0].y - - (selectedNode.globalY || 0) + - (renderedEdgeFromNode.globalY || 0), - }, - ); - } + if ( + ignoreEdgesWithinSameNamespace && + targetNode.namespace === opNode.namespace + ) { + continue; } - // Add the points in rendered edge. - let targetPoints: Point[] = points; - let sourcePoints: Point[] = renderedEdge.points; - if (reuseRenderedEdgeCurvePoints) { - targetPoints = curvePoints; - sourcePoints = renderedEdgeCurvePoints; + // Find the common namespace prefix. + const commonNamespace = findCommonNamespace( + targetNode.namespace, + opNode.namespace, + ); + + // Go from the given node to all its ns ancestors, find the last + // collapsed node before reaching the given namespace, and style it with + // the given class. If all ancestor nodes are expanded, style the given + // node. + const highlightedNode = this.getLastCollapsedAncestorNode( + targetNode, + commonNamespace, + ); + highlightedNodes.push(highlightedNode); + + // Update outputsByHighlighedNode. + if (outputsByHighlightedNode[highlightedNode.id] == null) { + outputsByHighlightedNode[highlightedNode.id] = []; } - targetPoints.push( - ...sourcePoints.map((pt) => { - return { - x: - pt.x - - (selectedNode.globalX || 0) + - (renderedEdgeFromNode.globalX || 0), - y: - pt.y - - (selectedNode.globalY || 0) + - (renderedEdgeFromNode.globalY || 0), - }; - }), + outputsByHighlightedNode[highlightedNode.id].push(targetNode); + + // Find the existing edge in the common namespace that connects two + // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains + // `node`. + const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace( + commonNamespace, + opNode.id, + targetNode.id, ); - // Add a point from the highlighted node that connects to the first - // point of the rendered edge above. - if (renderedEdge.toNodeId !== highlightedNode.id) { - const renderedEdgeLastX = - renderedEdge.points[renderedEdge.points.length - 1].x + - (renderedEdgeFromNode.globalX || 0); - const renderedEdgeLastY = - renderedEdge.points[renderedEdge.points.length - 1].y + - (renderedEdgeFromNode.globalY || 0); - const startPt = this.getBestAnchorPointOnNode( - renderedEdgeLastX, - renderedEdgeLastY, - highlightedNode, - ); - if (!reuseRenderedEdgeCurvePoints) { - points.push({ - x: startPt.x - (selectedNode.globalX || 0), - y: startPt.y - (selectedNode.globalY || 0), - }); - } else { - curvePoints.push( - { + // Start to construct an edge from the selected node to target node. + // + const points: Point[] = []; + const curvePoints: Point[] = []; + + if (renderedEdge) { + renderedEdges.push(renderedEdge); + const renderedEdgeCurvePoints = renderedEdge.curvePoints || []; + + const renderedEdgeFromNode = + this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId]; + + // Add a point from the selected node that connects to the first point + // of the rendered edge. + if (isOpNode(selectedNode)) { + if (renderedEdge.fromNodeId !== opNode?.id) { + const renderedEdgeStartX = + renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0); + const renderedEdgeStartY = + renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0); + const endPt = this.getBestAnchorPointOnNode( + renderedEdgeStartX, + renderedEdgeStartY, + opNode, + ); + points.push({ + x: endPt.x - (opNode.globalX || 0), + y: endPt.y - (opNode.globalY || 0), + }); + if (reuseRenderedEdgeCurvePoints) { + curvePoints.push( + { + x: endPt.x - (opNode.globalX || 0), + y: endPt.y - (opNode.globalY || 0), + }, + { + x: + renderedEdgeCurvePoints[0].x - + (opNode.globalX || 0) + + (renderedEdgeFromNode.globalX || 0), + y: + renderedEdgeCurvePoints[0].y - + (opNode.globalY || 0) + + (renderedEdgeFromNode.globalY || 0), + }, + ); + } + } + } + + // Add the points in rendered edge. + let targetPoints: Point[] = points; + let sourcePoints: Point[] = renderedEdge.points; + if (reuseRenderedEdgeCurvePoints) { + targetPoints = curvePoints; + sourcePoints = renderedEdgeCurvePoints; + } + targetPoints.push( + ...sourcePoints.map((pt) => { + return { x: - renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] - .x - - (selectedNode.globalX || 0) + + pt.x - + (opNode.globalX || 0) + (renderedEdgeFromNode.globalX || 0), y: - renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] - .y - - (selectedNode.globalY || 0) + + pt.y - + (opNode.globalY || 0) + (renderedEdgeFromNode.globalY || 0), - }, - { - x: startPt.x - (selectedNode.globalX || 0), - y: startPt.y - (selectedNode.globalY || 0), - }, + }; + }), + ); + + // Add a point from the highlighted node that connects to the first + // point of the rendered edge above. + if (renderedEdge.toNodeId !== highlightedNode.id) { + const renderedEdgeLastX = + renderedEdge.points[renderedEdge.points.length - 1].x + + (renderedEdgeFromNode.globalX || 0); + const renderedEdgeLastY = + renderedEdge.points[renderedEdge.points.length - 1].y + + (renderedEdgeFromNode.globalY || 0); + const startPt = this.getBestAnchorPointOnNode( + renderedEdgeLastX, + renderedEdgeLastY, + highlightedNode, ); + if (!reuseRenderedEdgeCurvePoints) { + points.push({ + x: startPt.x - (opNode.globalX || 0), + y: startPt.y - (opNode.globalY || 0), + }); + } else { + curvePoints.push( + { + x: + renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] + .x - + (opNode.globalX || 0) + + (renderedEdgeFromNode.globalX || 0), + y: + renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1] + .y - + (opNode.globalY || 0) + + (renderedEdgeFromNode.globalY || 0), + }, + { + x: startPt.x - (opNode.globalX || 0), + y: startPt.y - (opNode.globalY || 0), + }, + ); + } } + } else if ( + isGroupNode(highlightedNode) || + (isOpNode(highlightedNode) && !highlightedNode.hideInLayout) + ) { + (reuseRenderedEdgeCurvePoints ? curvePoints : points).push( + ...this.getDirectEdgeBetweenNodes(opNode, highlightedNode), + ); } - } else if ( - isGroupNode(highlightedNode) || - (isOpNode(highlightedNode) && !highlightedNode.hideInLayout) - ) { - (reuseRenderedEdgeCurvePoints ? curvePoints : points).push( - ...this.getDirectEdgeBetweenNodes(selectedNode, highlightedNode), - ); - } - // Use these points to form an edge and add it as an overlay edge. - if (!reuseRenderedEdgeCurvePoints) { - if (points.length > 0) { - overlayEdges.push({ - id: `overlay_${selectedNode.id}___${highlightedNode.id}`, - fromNodeId: selectedNode.id, - toNodeId: highlightedNode.id, - points, - type: 'outgoing', - }); - } - } else { - if (curvePoints.length > 0) { - overlayEdges.push({ - id: `overlay_${selectedNode.id}___${highlightedNode.id}`, - fromNodeId: selectedNode.id, - toNodeId: highlightedNode.id, - points: [], - curvePoints, - type: 'outgoing', - }); + // Use these points to form an edge and add it as an overlay edge. + if (!reuseRenderedEdgeCurvePoints) { + if (points.length > 0) { + overlayEdges.push({ + id: `overlay_${opNode.id}___${highlightedNode.id}`, + fromNodeId: opNode.id, + toNodeId: highlightedNode.id, + points, + type: 'outgoing', + }); + } + } else { + if (curvePoints.length > 0) { + overlayEdges.push({ + id: `overlay_${opNode.id}___${highlightedNode.id}`, + fromNodeId: opNode.id, + toNodeId: highlightedNode.id, + points: [], + curvePoints, + type: 'outgoing', + }); + } } } } @@ -830,12 +885,16 @@ export class WebglRendererIoHighlightService { return false; } - // Ignore when clicking on a group node. + // Ignore when clicking on a group node and the corresponding config option + // is not enabled. const selectedNode = this.webglRenderer.curModelGraph.nodesById[ this.webglRenderer.selectedNodeId ]; - if (isGroupNode(selectedNode)) { + if ( + isGroupNode(selectedNode) && + !this.webglRenderer.appService.config()?.highlightLayerNodeInputsOutputs + ) { return false; } diff --git a/src/ui/src/services/settings_service.ts b/src/ui/src/services/settings_service.ts index 348aa01c..0206eb0f 100644 --- a/src/ui/src/services/settings_service.ts +++ b/src/ui/src/services/settings_service.ts @@ -35,6 +35,7 @@ export enum SettingKey { DISALLOW_VERTICAL_EDGE_LABELS = 'disallow_vertical_edge_labels', KEEP_LAYERS_WITH_A_SINGLE_CHILD = 'keep_layers_with_a_single_child', SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING = 'show_op_node_out_of_layer_edges_without_selecting', + HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS = 'highlight_layer_node_inputs_outputs', } /** Setting types. */ @@ -156,6 +157,19 @@ export const SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING: Setting 'especially for larger models.', }; +/** Setting for highlighting layer node inputs and outputs. */ +export const SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS: Setting = { + label: 'Highlight inputs and outputs of the selected layer node', + key: SettingKey.HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS, + type: SettingType.BOOLEAN, + defaultValue: false, + help: + 'By default, inputs and outputs are highlighted only when an op node ' + + 'is selected. Enable this setting to see inputs and outputs for a ' + + 'selected layer node, including all its descendant op nodes within ' + + 'that layer.', +}; + const SETTINGS_LOCAL_STORAGE_KEY = 'model_explorer_settings'; /** All settings. */ @@ -168,7 +182,8 @@ export const ALL_SETTINGS = [ SETTING_KEEP_LAYERS_WITH_A_SINGLE_CHILD, SETTING_SHOW_WELCOME_CARD, SETTING_DISALLOW_VERTICAL_EDGE_LABELS, - SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING, + SETTING_MAX_CONST_ELEMENT_COUNT_LIMIT, + SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS, ]; /**