diff --git a/src/ui/src/components/home_page/home_page.ts b/src/ui/src/components/home_page/home_page.ts
index 9c334ff8..50f3ab28 100644
--- a/src/ui/src/components/home_page/home_page.ts
+++ b/src/ui/src/components/home_page/home_page.ts
@@ -47,6 +47,7 @@ import {
SETTING_EDGE_COLOR,
SETTING_EDGE_LABEL_FONT_SIZE,
SETTING_HIDE_OP_NODES_WITH_LABELS,
+ SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS,
SETTING_KEEP_LAYERS_WITH_A_SINGLE_CHILD,
SETTING_MAX_CONST_ELEMENT_COUNT_LIMIT,
SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING,
@@ -351,6 +352,9 @@ export class HomePage implements AfterViewInit {
this.settingsService.getBooleanValue(
SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING,
),
+ highlightLayerNodeInputsOutputs: this.settingsService.getBooleanValue(
+ SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS,
+ ),
};
}
diff --git a/src/ui/src/components/visualizer/common/visualizer_config.ts b/src/ui/src/components/visualizer/common/visualizer_config.ts
index b0885a8d..83ccc290 100644
--- a/src/ui/src/components/visualizer/common/visualizer_config.ts
+++ b/src/ui/src/components/visualizer/common/visualizer_config.ts
@@ -58,6 +58,9 @@ export declare interface VisualizerConfig {
*/
showOpNodeOutOfLayerEdgesWithoutSelecting?: boolean;
+ /** Whether to highlight layer node inputs and outputs. */
+ highlightLayerNodeInputsOutputs?: boolean;
+
/** The default node styler rules. */
nodeStylerRules?: NodeStylerRule[];
diff --git a/src/ui/src/components/visualizer/info_panel.ng.html b/src/ui/src/components/visualizer/info_panel.ng.html
index e3cfc1ef..3851712c 100644
--- a/src/ui/src/components/visualizer/info_panel.ng.html
+++ b/src/ui/src/components/visualizer/info_panel.ng.html
@@ -62,7 +62,7 @@
-
@@ -246,6 +123,48 @@
+
+
+ 0" #groupInputsSectionEle
+ [class.collapsed]="isSectionCollapsed(SectionLabel.GROUP_INPUTS)">
+
+
+
+
+
+
+ 0" #groupOutputsSectionEle
+ [class.collapsed]="isSectionCollapsed(SectionLabel.GROUP_OUTPUTS)">
+
+
+
+
@@ -270,4 +189,154 @@
}
-
\ No newline at end of file
+
+
+
+
+ @for (item of items; track item.opNode.id; let i = $index) {
+
+
+
{{item.index}}
+ @if (item.opNode.hideInLayout) {
+
{{getInputName(item)}}
+ @if (item.targetOpNode) {
+
+
arrow_forward
+
{{item.targetOpNode.label}}
+
+ }
+ } @else {
+
+ {{getInputName(item)}}
+
+ my_location
+
+ @if (item.targetOpNode) {
+
+
arrow_forward
+
{{item.targetOpNode.label}}
+
+ }
+
+ @if (!item.opNode.hideInLayout) {
+
+
+ {{getInputOpNodeToggleVisibilityIcon(item.opNode.id)}}
+
+
+ }
+
+ }
+
+
+
+ }
+
+
+
+
+
+ @for (item of items; track $index; let i = $index; let last = $last) {
+
+
+
{{item.index}}
+
{{getOutputName(item)}}
+ @if (item.showSourceOpNode) {
+
+ ({{item.sourceOpNode.label}})
+
+ }
+
+ @if (getHasConnectedToNodes(item)) {
+
+
+ {{getOutputToggleVisibilityIcon(item)}}
+
+
+ }
+
+ @if (item.metadataList.length > 0) {
+
+ }
+
+ }
+
+
diff --git a/src/ui/src/components/visualizer/info_panel.scss b/src/ui/src/components/visualizer/info_panel.scss
index 0906f6f6..874b7584 100644
--- a/src/ui/src/components/visualizer/info_panel.scss
+++ b/src/ui/src/components/visualizer/info_panel.scss
@@ -92,13 +92,33 @@
background-color: #f6f6f6;
.locator-icon-container {
- opacity: .8;
+ opacity: 0.8;
}
}
}
}
}
+ .target-op-container {
+ display: flex;
+ align-items: center;
+ color: #999;
+ font-weight: normal;
+
+ mat-icon.arrow {
+ font-size: 12px;
+ height: 12px;
+ width: 12px;
+ margin: 0 4px;
+ }
+ }
+
+ .source-op-node-label {
+ color: #999;
+ font-weight: normal;
+ margin-left: 6px;
+ }
+
.metadata-table {
margin-top: 3px;
margin-left: 18px;
@@ -130,7 +150,7 @@
box-sizing: border-box;
background-color: white;
user-select: none;
- color: rgba(0, 0, 0, .87);
+ color: rgba(0, 0, 0, 0.87);
&.input,
&.output,
@@ -254,7 +274,7 @@
cursor: pointer;
&:hover .locator-icon-container {
- opacity: .8;
+ opacity: 0.8;
}
&.search-match {
@@ -341,7 +361,7 @@
display: flex;
align-items: center;
justify-content: center;
- opacity: .5;
+ opacity: 0.5;
margin-left: 4px;
&.left {
@@ -355,7 +375,7 @@
}
&:hover {
- opacity: .8;
+ opacity: 0.8;
}
mat-icon {
@@ -370,12 +390,12 @@
display: flex;
align-items: center;
justify-content: center;
- opacity: .5;
+ opacity: 0.5;
padding: 0 11px 0 20px;
cursor: pointer;
&:hover {
- opacity: .8;
+ opacity: 0.8;
}
mat-icon {
@@ -467,4 +487,4 @@
font-size: 12px;
padding: 3px 0;
}
-}
\ No newline at end of file
+}
diff --git a/src/ui/src/components/visualizer/info_panel.ts b/src/ui/src/components/visualizer/info_panel.ts
index e3769570..9aacfff9 100644
--- a/src/ui/src/components/visualizer/info_panel.ts
+++ b/src/ui/src/components/visualizer/info_panel.ts
@@ -43,10 +43,12 @@ import {AppService} from './app_service';
import {TENSOR_TAG_METADATA_KEY, TENSOR_VALUES_KEY} from './common/consts';
import {GroupNode, ModelGraph, ModelNode, OpNode} from './common/model_graph';
import {
+ IncomingEdge,
KeyValue,
KeyValueList,
KeyValuePairs,
NodeDataProviderRunInfo,
+ OutgoingEdge,
SearchMatchAttr,
SearchMatchInputMetadata,
SearchMatchOutputMetadata,
@@ -84,6 +86,8 @@ enum SectionLabel {
IDENTICAL_GROUPS = 'Identical groups',
INPUTS = 'inputs',
OUTPUTS = 'outputs',
+ GROUP_INPUTS = 'layer inputs',
+ GROUP_OUTPUTS = 'layer outputs',
}
interface InfoItem {
@@ -102,18 +106,21 @@ interface InfoItem {
interface OutputItem {
index: number;
tensorTag: string;
+ sourceOpNode: OpNode;
outputId: string;
metadataList: OutputItemMetadata[];
+ showSourceOpNode?: boolean;
}
interface OutputItemMetadata extends KeyValue {
connectedNodes?: OpNode[];
}
-interface FlatInputItem {
+interface InputItem {
index: number;
opNode: OpNode;
metadataList: KeyValueList;
+ targetOpNode?: OpNode;
}
const MIN_WIDTH = 64;
@@ -163,14 +170,21 @@ export class InfoPanel {
@HostBinding('style.min-width.px') minWidth = DEFAULT_WIDTH;
sections: InfoSection[] = [];
- flatInputItems: FlatInputItem[] = [];
+ inputItems: InputItem[] = [];
+ inputItemsForCurPage: InputItem[] = [];
outputItems: OutputItem[] = [];
outputItemsForCurPage: OutputItem[] = [];
+ groupInputItems: InputItem[] = [];
+ groupInputItemsForCurPage: InputItem[] = [];
+ groupOutputItems: OutputItem[] = [];
+ groupOutputItemsForCurPage: OutputItem[] = [];
identicalGroupNodes: GroupNode[] = [];
identicalGroupsData?: TreeNode[];
curRendererId = '';
curInputsCount = 0;
curOutputsCount = 0;
+ curGroupInputsCount = 0;
+ curGroupOutputsCount = 0;
resizing = false;
hide = false;
@@ -211,8 +225,6 @@ export class InfoPanel {
private curSearchAttrMatches: SearchMatchAttr[] = [];
private curSearchInputMatches: SearchMatchInputMetadata[] = [];
private curSearchOutputMatches: SearchMatchOutputMetadata[] = [];
- private inputSourceNodes: OpNode[] = [];
- private inputMetadataList: KeyValuePairs[] = [];
private savedWidth = 0;
constructor(
@@ -329,18 +341,29 @@ export class InfoPanel {
}
handleInputPaginatorChanged(curPageIndex: number) {
- const curInputSourceNodes = this.inputSourceNodes.slice(
+ this.inputItemsForCurPage = this.inputItems.slice(
curPageIndex * this.ioPageSize,
(curPageIndex + 1) * this.ioPageSize,
);
- const curInputMetadataList = this.inputMetadataList.slice(
+ this.changeDetectorRef.markForCheck();
+
+ setTimeout(() => {
+ this.updateInputValueContentsExpandable();
+ });
+ }
+
+ handleOutputPaginatorChanged(curPageIndex: number) {
+ this.outputItemsForCurPage = this.outputItems.slice(
curPageIndex * this.ioPageSize,
(curPageIndex + 1) * this.ioPageSize,
);
- this.flatInputItems = this.genInputFlatItems(
+ this.changeDetectorRef.markForCheck();
+ }
+
+ handleGroupInputPaginatorChanged(curPageIndex: number) {
+ this.groupInputItemsForCurPage = this.groupInputItems.slice(
curPageIndex * this.ioPageSize,
- curInputSourceNodes,
- curInputMetadataList,
+ (curPageIndex + 1) * this.ioPageSize,
);
this.changeDetectorRef.markForCheck();
@@ -349,8 +372,8 @@ export class InfoPanel {
});
}
- handleOutputPaginatorChanged(curPageIndex: number) {
- this.outputItemsForCurPage = this.outputItems.slice(
+ handleGroupOutputPaginatorChanged(curPageIndex: number) {
+ this.groupOutputItemsForCurPage = this.groupOutputItems.slice(
curPageIndex * this.ioPageSize,
(curPageIndex + 1) * this.ioPageSize,
);
@@ -432,7 +455,7 @@ export class InfoPanel {
handleToggleInputOpNodeVisibility(
nodeId: string,
- allItems: FlatInputItem[],
+ allItems: InputItem[],
event: MouseEvent,
) {
event.stopPropagation();
@@ -464,7 +487,7 @@ export class InfoPanel {
}
handleToggleOutputVisibility(
- outputId: string,
+ item: OutputItem,
items: OutputItem[],
event: MouseEvent,
) {
@@ -472,31 +495,39 @@ export class InfoPanel {
if (event.altKey) {
this.splitPaneService.setOutputVisible(
- outputId,
- items.map((item) => item.outputId),
+ item.sourceOpNode.id,
+ item.outputId,
+ items.map((item) => ({
+ nodeId: item.sourceOpNode.id,
+ outputId: item.outputId,
+ })),
);
} else {
- this.splitPaneService.toggleOutputVisibility(outputId);
+ this.splitPaneService.toggleOutputVisibility(
+ item.sourceOpNode.id,
+ item.outputId,
+ );
}
}
- getOutputToggleVisible(outputId: string): boolean {
- return this.splitPaneService.getOutputVisible(outputId);
+ getOutputToggleVisible(item: OutputItem): boolean {
+ return this.splitPaneService.getOutputVisible(
+ item.sourceOpNode.id,
+ item.outputId,
+ );
}
- getOutputToggleVisibilityIcon(outputId: string): string {
- return this.getOutputToggleVisible(outputId)
- ? 'visibility'
- : 'visibility_off';
+ getOutputToggleVisibilityIcon(item: OutputItem): string {
+ return this.getOutputToggleVisible(item) ? 'visibility' : 'visibility_off';
}
- getOutputToggleVisibilityTooltip(outputId: string): string {
- return this.getOutputToggleVisible(outputId)
+ getOutputToggleVisibilityTooltip(item: OutputItem): string {
+ return this.getOutputToggleVisible(item)
? 'Click to hide highlight'
: 'Click to show highlight';
}
- getInputName(item: FlatInputItem): string {
+ getInputName(item: InputItem): string {
const tensorTagItem = item.metadataList.find(
(item) => item.key === TENSOR_TAG_METADATA_KEY,
);
@@ -506,7 +537,7 @@ export class InfoPanel {
: item.opNode.label;
}
- getInputTensorTag(item: FlatInputItem): string {
+ getInputTensorTag(item: InputItem): string {
const tensorTagItem = item.metadataList.find(
(item) => item.key === TENSOR_TAG_METADATA_KEY,
);
@@ -559,7 +590,7 @@ export class InfoPanel {
get showInputPaginator(): boolean {
return (
- this.inputSourceNodes.length > this.ioPageSize &&
+ this.inputItems.length > this.ioPageSize &&
!this.isSectionCollapsed(SectionLabel.INPUTS)
);
}
@@ -571,6 +602,20 @@ export class InfoPanel {
);
}
+ get showGroupInputPaginator(): boolean {
+ return (
+ this.groupInputItems.length > this.ioPageSize &&
+ !this.isSectionCollapsed(SectionLabel.GROUP_INPUTS)
+ );
+ }
+
+ get showGroupOutputPaginator(): boolean {
+ return (
+ this.groupOutputItems.length > this.ioPageSize &&
+ !this.isSectionCollapsed(SectionLabel.GROUP_OUTPUTS)
+ );
+ }
+
get showIdenticalGroupsPaginator(): boolean {
return (
this.identicalGroupNodes.length > this.ioPageSize &&
@@ -602,10 +647,10 @@ export class InfoPanel {
private genInfoData() {
this.sections = [];
- this.flatInputItems = [];
- this.inputSourceNodes = [];
- this.inputMetadataList = [];
+ this.inputItems = [];
this.outputItems = [];
+ this.groupInputItems = [];
+ this.groupOutputItems = [];
this.identicalGroupNodes = [];
this.identicalGroupsData = undefined;
@@ -617,6 +662,9 @@ export class InfoPanel {
this.genInputsOutputsData();
} else if (isGroupNode(this.curSelectedNode)) {
this.genInfoDataForSelectedGroupNode();
+ if (this.appService.config()?.highlightLayerNodeInputsOutputs) {
+ this.genGroupInputsOutputsData();
+ }
}
}
}
@@ -777,86 +825,48 @@ export class InfoPanel {
const selectedOpNode = this.curSelectedNode as OpNode;
const incomingEdges = selectedOpNode.incomingEdges || [];
- this.inputMetadataList = [];
- this.inputSourceNodes = [];
- this.flatInputItems = [];
- for (const edge of incomingEdges) {
+ this.inputItems = [];
+ for (let i = 0; i < incomingEdges.length; i++) {
+ const edge = incomingEdges[i];
+ const metadataList = this.genInputMetadataList(selectedOpNode, edge);
const sourceOpNode = this.curModelGraph?.nodesById[
edge.sourceNodeId
] as OpNode;
- this.inputSourceNodes.push(sourceOpNode);
- const metadata =
- (selectedOpNode.inputsMetadata || {})[edge.targetNodeInputId] || {};
- // Merge the corresponding output metadata with the current input
- // metadata.
- const sourceNodeOutputMetadata = {
- ...((sourceOpNode.outputsMetadata || {})[edge.sourceNodeOutputId] ||
- {}),
- };
- for (const key of Object.keys(sourceNodeOutputMetadata)) {
- if (metadata[key] == null && key !== TENSOR_TAG_METADATA_KEY) {
- metadata[key] = sourceNodeOutputMetadata[key];
- }
- }
- this.inputMetadataList.push(metadata);
- }
- this.curInputsCount = this.inputSourceNodes.length;
- if (incomingEdges.length > 0) {
- const curInputSourceNodes = this.inputSourceNodes.slice(
- 0,
- this.ioPageSize,
- );
- const curInputMetadataList = this.inputMetadataList.slice(
- 0,
- this.ioPageSize,
- );
- this.flatInputItems = this.genInputFlatItems(
- 0,
- curInputSourceNodes,
- curInputMetadataList,
- );
+ this.inputItems.push({
+ index: i,
+ opNode: sourceOpNode,
+ metadataList,
+ });
}
+ this.curInputsCount = this.inputItems.length;
+ this.inputItemsForCurPage = this.inputItems.slice(0, this.ioPageSize);
+
// Outputs.
this.outputItems = [];
const outputsMetadata = selectedOpNode.outputsMetadata || {};
const outgoingEdges = selectedOpNode.outgoingEdges || [];
let index = 0;
for (const outputId of Object.keys(outputsMetadata)) {
- // The metadata for the current output tensor.
- const metadataList: OutputItemMetadata[] = [];
- let tensorTag = '';
- for (const metadataKey of Object.keys(outputsMetadata[outputId])) {
- const value = outputsMetadata[outputId][metadataKey];
- if (metadataKey === TENSOR_TAG_METADATA_KEY) {
- tensorTag = value;
- }
- // Hide all metadata keys that start with '__'.
- if (metadataKey.startsWith('__')) {
- continue;
- }
- metadataList.push({
- key: metadataKey,
- value,
- });
- }
- metadataList.sort((a, b) => a.key.localeCompare(b.key));
-
// The connected nodes.
const connectedNodes = outgoingEdges
.filter((edge) => edge.sourceNodeOutputId === outputId)
.map(
(edge) => this.curModelGraph!.nodesById[edge.targetNodeId],
) as OpNode[];
- metadataList.push({
- key: this.outputMetadataConnectedTo,
- value: '',
+
+ // Metadata list.
+ const {metadataList, tensorTag} = this.genOutputMetadataList(
+ outgoingEdges,
+ outputsMetadata[outputId],
connectedNodes,
- });
+ );
+
this.outputItems.push({
index,
tensorTag,
outputId,
+ sourceOpNode: selectedOpNode,
metadataList,
});
index++;
@@ -865,6 +875,178 @@ export class InfoPanel {
this.outputItemsForCurPage = this.outputItems.slice(0, this.ioPageSize);
}
+ private genGroupInputsOutputsData() {
+ if (!this.curModelGraph || !this.curSelectedNode) {
+ return;
+ }
+
+ // Inputs.
+ const selectedGroupNode = this.curSelectedNode as GroupNode;
+ const seenInputNodeIds = new Set();
+
+ this.groupInputItems = [];
+ let index = 0;
+ for (const nodeId of selectedGroupNode.descendantsOpNodeIds || []) {
+ const descendantIds = new Set(
+ selectedGroupNode.descendantsOpNodeIds || [],
+ );
+ const opNode = this.curModelGraph?.nodesById[nodeId] as OpNode;
+ const incomingEdges = opNode.incomingEdges || [];
+
+ for (const edge of incomingEdges) {
+ const sourceOpNode = this.curModelGraph?.nodesById[
+ edge.sourceNodeId
+ ] as OpNode;
+
+ // Ignore if the source op node is within the layer.
+ if (descendantIds.has(sourceOpNode.id)) {
+ continue;
+ }
+
+ // Dedup.
+ if (seenInputNodeIds.has(sourceOpNode.id)) {
+ continue;
+ }
+ seenInputNodeIds.add(sourceOpNode.id);
+
+ const metadataList = this.genInputMetadataList(opNode, edge);
+ this.groupInputItems.push({
+ index: index++,
+ opNode: sourceOpNode,
+ metadataList,
+ targetOpNode: opNode,
+ });
+ }
+ }
+
+ this.curGroupInputsCount = this.groupInputItems.length;
+ this.groupInputItemsForCurPage = this.groupInputItems.slice(
+ 0,
+ this.ioPageSize,
+ );
+
+ // Outputs.
+ this.groupOutputItems = [];
+ index = 0;
+ for (const nodeId of selectedGroupNode.descendantsOpNodeIds || []) {
+ const descendantIds = new Set(
+ selectedGroupNode.descendantsOpNodeIds || [],
+ );
+ const opNode = this.curModelGraph?.nodesById[nodeId] as OpNode;
+
+ const outputsMetadata = opNode.outputsMetadata || {};
+ const outgoingEdges = opNode.outgoingEdges || [];
+ for (const outputId of Object.keys(outputsMetadata)) {
+ // The connected nodes.
+ const connectedNodes = outgoingEdges
+ .filter((edge) => !descendantIds.has(edge.targetNodeId))
+ .filter((edge) => edge.sourceNodeOutputId === outputId)
+ .map(
+ (edge) => this.curModelGraph!.nodesById[edge.targetNodeId],
+ ) as OpNode[];
+ if (connectedNodes.length === 0) {
+ continue;
+ }
+
+ // Metadata list.
+ const {metadataList, tensorTag} = this.genOutputMetadataList(
+ outgoingEdges,
+ outputsMetadata[outputId],
+ connectedNodes,
+ );
+
+ this.groupOutputItems.push({
+ index,
+ tensorTag,
+ outputId,
+ sourceOpNode: opNode,
+ metadataList,
+ showSourceOpNode: true,
+ });
+ index++;
+ }
+ }
+
+ this.curGroupOutputsCount = this.groupOutputItems.length;
+ this.groupOutputItemsForCurPage = this.groupOutputItems.slice(
+ 0,
+ this.ioPageSize,
+ );
+ }
+
+ private genInputMetadataList(
+ opNode: OpNode,
+ edge: IncomingEdge,
+ ): KeyValueList {
+ const sourceOpNode = this.curModelGraph?.nodesById[
+ edge.sourceNodeId
+ ] as OpNode;
+ const metadata =
+ (opNode.inputsMetadata || {})[edge.targetNodeInputId] || {};
+ // Merge the corresponding output metadata with the current input
+ // metadata.
+ const sourceNodeOutputMetadata = {
+ ...((sourceOpNode.outputsMetadata || {})[edge.sourceNodeOutputId] || {}),
+ };
+ for (const key of Object.keys(sourceNodeOutputMetadata)) {
+ if (metadata[key] == null && key !== TENSOR_TAG_METADATA_KEY) {
+ metadata[key] = sourceNodeOutputMetadata[key];
+ }
+ }
+ // Sort by key.
+ const metadataList: KeyValueList = [];
+ Object.entries(metadata).forEach(([key, value]) => {
+ metadataList.push({key, value});
+ });
+ metadataList.sort((a, b) => a.key.localeCompare(b.key));
+ // Add namespace to metadata.
+ metadataList.push({
+ key: this.inputMetadataNamespaceKey,
+ value: getNamespaceLabel(sourceOpNode),
+ });
+ // Add tensor values to metadata if existed.
+ const attrs = sourceOpNode.attrs || {};
+ if (attrs[TENSOR_VALUES_KEY]) {
+ metadataList.push({
+ key: this.inputMetadataValuesKey,
+ value: attrs[TENSOR_VALUES_KEY],
+ });
+ }
+ return metadataList;
+ }
+
+ private genOutputMetadataList(
+ outgoingEdges: OutgoingEdge[],
+ outputMetadata: KeyValuePairs,
+ connectedNodes: OpNode[],
+ ) {
+ const metadataList: OutputItemMetadata[] = [];
+ let tensorTag = '';
+ for (const metadataKey of Object.keys(outputMetadata)) {
+ const value = outputMetadata[metadataKey];
+ if (metadataKey === TENSOR_TAG_METADATA_KEY) {
+ tensorTag = value;
+ }
+ // Hide all metadata keys that start with '__'.
+ if (metadataKey.startsWith('__')) {
+ continue;
+ }
+ metadataList.push({
+ key: metadataKey,
+ value,
+ });
+ }
+ metadataList.sort((a, b) => a.key.localeCompare(b.key));
+
+ metadataList.push({
+ key: this.outputMetadataConnectedTo,
+ value: '',
+ connectedNodes,
+ });
+
+ return {metadataList, tensorTag};
+ }
+
private genInfoDataForSelectedGroupNode() {
if (!this.curModelGraph || !this.curSelectedNode) {
return;
@@ -1011,39 +1193,6 @@ export class InfoPanel {
animate();
}
- private genInputFlatItems(
- startIndex: number,
- inputSourceNodes: OpNode[],
- inputMetadataList: KeyValuePairs[],
- ): FlatInputItem[] {
- const flatInputItems: FlatInputItem[] = [];
- for (let i = 0; i < inputSourceNodes.length; i++) {
- const sourceNode = inputSourceNodes[i];
- const metadataList: KeyValueList = [];
- Object.entries(inputMetadataList[i]).forEach(([key, value]) => {
- metadataList.push({key, value});
- });
- metadataList.sort((a, b) => a.key.localeCompare(b.key));
- metadataList.push({
- key: this.inputMetadataNamespaceKey,
- value: getNamespaceLabel(inputSourceNodes[i]),
- });
- const attrs = sourceNode.attrs || {};
- if (attrs[TENSOR_VALUES_KEY]) {
- metadataList.push({
- key: this.inputMetadataValuesKey,
- value: attrs[TENSOR_VALUES_KEY],
- });
- }
- flatInputItems.push({
- index: i + startIndex,
- opNode: sourceNode,
- metadataList,
- });
- }
- return flatInputItems;
- }
-
private updateInputValueContentsExpandable() {
for (let i = 0; i < this.inputValueContents.length; i++) {
const valueContent = this.inputValueContents.get(i)?.nativeElement;
diff --git a/src/ui/src/components/visualizer/split_pane_service.ts b/src/ui/src/components/visualizer/split_pane_service.ts
index d2825a6b..69f40528 100644
--- a/src/ui/src/components/visualizer/split_pane_service.ts
+++ b/src/ui/src/components/visualizer/split_pane_service.ts
@@ -28,7 +28,8 @@ export class SplitPaneService {
readonly hiddenInputOpNodeIds = signal>({});
/**
- * The output ids that are hidden (i.e. not highlighted) in the model graph.
+ * The {nodeId}___{outputId} that are hidden (i.e. not highlighted) in the
+ * model graph.
*/
readonly hiddenOutputIds = signal>({});
@@ -71,24 +72,31 @@ export class SplitPaneService {
}
}
- toggleOutputVisibility(outputId: string) {
+ toggleOutputVisibility(nodeId: string, outputId: string) {
this.hiddenOutputIds.update((ids) => {
- const visible = ids[outputId] === true;
+ const key = `${nodeId}___${outputId}`;
+ const visible = ids[key] === true;
if (!visible) {
- ids[outputId] = true;
+ ids[key] = true;
} else {
- delete ids[outputId];
+ delete ids[key];
}
return {...ids};
});
}
- setOutputVisible(outputId: string, allOutputIds: string[]) {
+ setOutputVisible(
+ nodeId: string,
+ outputId: string,
+ allNodeAndOutputIds: Array<{nodeId: string; outputId: string}>,
+ ) {
// Check if the output is the only visible one.
- let isNodeTheOnlyVisibleOne = this.hiddenOutputIds()[outputId] !== true;
- for (const id of allOutputIds) {
- if (id !== outputId) {
- if (!this.hiddenOutputIds()[id]) {
+ const key = `${nodeId}___${outputId}`;
+ let isNodeTheOnlyVisibleOne = this.hiddenOutputIds()[key] !== true;
+ for (const {nodeId, outputId} of allNodeAndOutputIds) {
+ const curKey = `${nodeId}___${outputId}`;
+ if (curKey !== key) {
+ if (!this.hiddenOutputIds()[curKey]) {
isNodeTheOnlyVisibleOne = false;
}
}
@@ -101,9 +109,10 @@ export class SplitPaneService {
// If not, hide the other outputs.
else {
const ids: Record = {};
- for (const id of allOutputIds) {
- if (id !== outputId) {
- ids[id] = true;
+ for (const {nodeId, outputId} of allNodeAndOutputIds) {
+ const curKey = `${nodeId}___${outputId}`;
+ if (curKey !== key) {
+ ids[curKey] = true;
}
}
this.hiddenOutputIds.set(ids);
@@ -114,8 +123,9 @@ export class SplitPaneService {
return !this.hiddenInputOpNodeIds()[nodeId];
}
- getOutputVisible(outputId: string): boolean {
- return !this.hiddenOutputIds()[outputId];
+ getOutputVisible(nodeId: string, outputId: string): boolean {
+ const key = `${nodeId}___${outputId}`;
+ return !this.hiddenOutputIds()[key];
}
resetInputOutputHiddenIds() {
diff --git a/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts b/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts
index fa370ec0..b385b0f0 100644
--- a/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts
+++ b/src/ui/src/components/visualizer/webgl_renderer_io_highlight_service.ts
@@ -352,7 +352,7 @@ export class WebglRendererIoHighlightService {
getHighlightedIncomingNodesAndEdges(
hiddenInputNodeIds: Record,
- selectedNode?: OpNode,
+ selectedNode?: ModelNode,
options?: {
ignoreEdgesWithinSameNamespace?: boolean;
reuseRenderedEdgeCurvePoints?: boolean;
@@ -364,202 +364,227 @@ export class WebglRendererIoHighlightService {
options?.reuseRenderedEdgeCurvePoints ?? false;
if (!selectedNode) {
- selectedNode = this.webglRenderer.curModelGraph.nodesById[
- this.webglRenderer.selectedNodeId
- ] as OpNode;
+ selectedNode =
+ this.webglRenderer.curModelGraph.nodesById[
+ this.webglRenderer.selectedNodeId
+ ];
}
const renderedEdges: ModelEdge[] = [];
const highlightedNodes: ModelNode[] = [];
const inputsByHighlightedNode: Record = {};
const overlayEdges: OverlayModelEdge[] = [];
- for (const incomingEdge of selectedNode.incomingEdges || []) {
- if (hiddenInputNodeIds[incomingEdge.sourceNodeId]) {
- continue;
- }
- const sourceNode = this.webglRenderer.curModelGraph.nodesById[
- incomingEdge.sourceNodeId
- ] as OpNode;
- if (!sourceNode) {
- continue;
- }
-
- if (
- ignoreEdgesWithinSameNamespace &&
- sourceNode.namespace === selectedNode.namespace
- ) {
- continue;
+ const opNodes: OpNode[] = [];
+ const ignoredIncomingNodesIds = new Set();
+ const seenIncomingNodesIds = new Set();
+ if (isOpNode(selectedNode)) {
+ opNodes.push(selectedNode);
+ } else if (isGroupNode(selectedNode)) {
+ for (const id of selectedNode.descendantsOpNodeIds || []) {
+ const node = this.webglRenderer.curModelGraph.nodesById[id] as OpNode;
+ opNodes.push(node);
+ ignoredIncomingNodesIds.add(id);
}
+ }
- // Find the common namespace prefix.
- const commonNamespace = findCommonNamespace(
- sourceNode.namespace,
- selectedNode.namespace,
- );
+ for (const opNode of opNodes) {
+ for (const incomingEdge of opNode.incomingEdges || []) {
+ if (hiddenInputNodeIds[incomingEdge.sourceNodeId]) {
+ continue;
+ }
+ const sourceNode = this.webglRenderer.curModelGraph.nodesById[
+ incomingEdge.sourceNodeId
+ ] as OpNode;
+ if (!sourceNode) {
+ continue;
+ }
- // Go from the given node to all its ns ancestors, find the last collapsed
- // node before reaching the given namespace. If all ancestor nodes are
- // expanded, return the given node.
- const highlightedNode = this.getLastCollapsedAncestorNode(
- sourceNode,
- commonNamespace,
- );
- highlightedNodes.push(highlightedNode);
+ if (ignoredIncomingNodesIds.has(sourceNode.id)) {
+ continue;
+ }
- // Update inputsByHighlighedNode.
- if (inputsByHighlightedNode[highlightedNode.id] == null) {
- inputsByHighlightedNode[highlightedNode.id] = [];
- }
- inputsByHighlightedNode[highlightedNode.id].push(sourceNode);
-
- // Find the existing edge in the common namespace that connects two
- // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
- // `node`.
- const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
- commonNamespace,
- sourceNode.id,
- selectedNode.id,
- );
+ if (seenIncomingNodesIds.has(sourceNode.id)) {
+ continue;
+ }
+ seenIncomingNodesIds.add(sourceNode.id);
- // Start to construct an edge from the source node to the selected node.
- //
- const points: Point[] = [];
- const curvePoints: Point[] = [];
-
- if (renderedEdge) {
- renderedEdges.push(renderedEdge);
- const renderedEdgeCurvePoints = renderedEdge.curvePoints || [];
-
- // Add a point from the highlighted node that connects to the first
- // point of the rendered edge above.
- const renderedEdgeFromNode =
- this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
- if (renderedEdge.fromNodeId !== highlightedNode.id) {
- const renderedEdgeStartX =
- renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
- const renderedEdgeStartY =
- renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
- const startPt = this.getBestAnchorPointOnNode(
- renderedEdgeStartX,
- renderedEdgeStartY,
- highlightedNode,
- );
- points.push({
- x: startPt.x - (highlightedNode.globalX || 0),
- y: startPt.y - (highlightedNode.globalY || 0),
- });
- if (reuseRenderedEdgeCurvePoints) {
- curvePoints.push(
- {
- x: startPt.x - (highlightedNode.globalX || 0),
- y: startPt.y - (highlightedNode.globalY || 0),
- },
- {
- x:
- renderedEdgeCurvePoints[0].x -
- (highlightedNode.globalX || 0) +
- (renderedEdgeFromNode.globalX || 0),
- y:
- renderedEdgeCurvePoints[0].y -
- (highlightedNode.globalY || 0) +
- (renderedEdgeFromNode.globalY || 0),
- },
- );
- }
+ if (
+ ignoreEdgesWithinSameNamespace &&
+ sourceNode.namespace === opNode.namespace
+ ) {
+ continue;
}
- // Add the points in rendered edge.
- let targetPoints: Point[] = points;
- let sourcePoints: Point[] = renderedEdge.points;
- if (reuseRenderedEdgeCurvePoints) {
- targetPoints = curvePoints;
- sourcePoints = renderedEdgeCurvePoints;
+ // Find the common namespace prefix.
+ const commonNamespace = findCommonNamespace(
+ sourceNode.namespace,
+ opNode.namespace,
+ );
+
+ // Go from the given node to all its ns ancestors, find the last collapsed
+ // node before reaching the given namespace. If all ancestor nodes are
+ // expanded, return the given node.
+ const highlightedNode = this.getLastCollapsedAncestorNode(
+ sourceNode,
+ commonNamespace,
+ );
+ highlightedNodes.push(highlightedNode);
+
+ // Update inputsByHighlighedNode.
+ if (inputsByHighlightedNode[highlightedNode.id] == null) {
+ inputsByHighlightedNode[highlightedNode.id] = [];
}
- targetPoints.push(
- ...sourcePoints.map((pt) => {
- return {
- x:
- pt.x -
- (highlightedNode.globalX || 0) +
- (renderedEdgeFromNode.globalX || 0),
- y:
- pt.y -
- (highlightedNode.globalY || 0) +
- (renderedEdgeFromNode.globalY || 0),
- };
- }),
+ inputsByHighlightedNode[highlightedNode.id].push(sourceNode);
+
+ // Find the existing edge in the common namespace that connects two
+ // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
+ // `node`.
+ const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
+ commonNamespace,
+ sourceNode.id,
+ opNode.id,
);
- // Add a point from the selected node that connects to the last point of
- // the rendered edge.
- if (renderedEdge.toNodeId !== selectedNode?.id) {
- const renderedEdgeLastX =
- renderedEdge.points[renderedEdge.points.length - 1].x +
- (renderedEdgeFromNode.globalX || 0);
- const renderedEdgeLastY =
- renderedEdge.points[renderedEdge.points.length - 1].y +
- (renderedEdgeFromNode.globalY || 0);
- const endPt = this.getBestAnchorPointOnNode(
- renderedEdgeLastX,
- renderedEdgeLastY,
- selectedNode,
- );
- if (!reuseRenderedEdgeCurvePoints) {
+ // Start to construct an edge from the source node to the selected node.
+ //
+ const points: Point[] = [];
+ const curvePoints: Point[] = [];
+
+ if (renderedEdge) {
+ renderedEdges.push(renderedEdge);
+ const renderedEdgeCurvePoints = renderedEdge.curvePoints || [];
+
+ // Add a point from the highlighted node that connects to the first
+ // point of the rendered edge above.
+ const renderedEdgeFromNode =
+ this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
+ if (renderedEdge.fromNodeId !== highlightedNode.id) {
+ const renderedEdgeStartX =
+ renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
+ const renderedEdgeStartY =
+ renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
+ const startPt = this.getBestAnchorPointOnNode(
+ renderedEdgeStartX,
+ renderedEdgeStartY,
+ highlightedNode,
+ );
points.push({
- x: endPt.x - (highlightedNode.globalX || 0),
- y: endPt.y - (highlightedNode.globalY || 0),
+ x: startPt.x - (highlightedNode.globalX || 0),
+ y: startPt.y - (highlightedNode.globalY || 0),
});
- } else {
- curvePoints.push(
- {
+ if (reuseRenderedEdgeCurvePoints) {
+ curvePoints.push(
+ {
+ x: startPt.x - (highlightedNode.globalX || 0),
+ y: startPt.y - (highlightedNode.globalY || 0),
+ },
+ {
+ x:
+ renderedEdgeCurvePoints[0].x -
+ (highlightedNode.globalX || 0) +
+ (renderedEdgeFromNode.globalX || 0),
+ y:
+ renderedEdgeCurvePoints[0].y -
+ (highlightedNode.globalY || 0) +
+ (renderedEdgeFromNode.globalY || 0),
+ },
+ );
+ }
+ }
+
+ // Add the points in rendered edge.
+ let targetPoints: Point[] = points;
+ let sourcePoints: Point[] = renderedEdge.points;
+ if (reuseRenderedEdgeCurvePoints) {
+ targetPoints = curvePoints;
+ sourcePoints = renderedEdgeCurvePoints;
+ }
+ targetPoints.push(
+ ...sourcePoints.map((pt) => {
+ return {
x:
- renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
- .x -
+ pt.x -
(highlightedNode.globalX || 0) +
(renderedEdgeFromNode.globalX || 0),
y:
- renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
- .y -
+ pt.y -
(highlightedNode.globalY || 0) +
(renderedEdgeFromNode.globalY || 0),
- },
- {
+ };
+ }),
+ );
+
+ // Add a point from the selected node that connects to the last point of
+ // the rendered edge.
+ if (renderedEdge.toNodeId !== opNode?.id && isOpNode(selectedNode)) {
+ const renderedEdgeLastX =
+ renderedEdge.points[renderedEdge.points.length - 1].x +
+ (renderedEdgeFromNode.globalX || 0);
+ const renderedEdgeLastY =
+ renderedEdge.points[renderedEdge.points.length - 1].y +
+ (renderedEdgeFromNode.globalY || 0);
+ const endPt = this.getBestAnchorPointOnNode(
+ renderedEdgeLastX,
+ renderedEdgeLastY,
+ opNode,
+ );
+ if (!reuseRenderedEdgeCurvePoints) {
+ points.push({
x: endPt.x - (highlightedNode.globalX || 0),
y: endPt.y - (highlightedNode.globalY || 0),
- },
- );
+ });
+ } else {
+ curvePoints.push(
+ {
+ x:
+ renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
+ .x -
+ (highlightedNode.globalX || 0) +
+ (renderedEdgeFromNode.globalX || 0),
+ y:
+ renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
+ .y -
+ (highlightedNode.globalY || 0) +
+ (renderedEdgeFromNode.globalY || 0),
+ },
+ {
+ x: endPt.x - (highlightedNode.globalX || 0),
+ y: endPt.y - (highlightedNode.globalY || 0),
+ },
+ );
+ }
}
+ } else if (
+ isGroupNode(highlightedNode) ||
+ (isOpNode(highlightedNode) && !highlightedNode.hideInLayout)
+ ) {
+ (reuseRenderedEdgeCurvePoints ? curvePoints : points).push(
+ ...this.getDirectEdgeBetweenNodes(highlightedNode, opNode),
+ );
}
- } else if (
- isGroupNode(highlightedNode) ||
- (isOpNode(highlightedNode) && !highlightedNode.hideInLayout)
- ) {
- (reuseRenderedEdgeCurvePoints ? curvePoints : points).push(
- ...this.getDirectEdgeBetweenNodes(highlightedNode, selectedNode),
- );
- }
- // Use these points to form an edge and add it as an overlay edge.
- if (!reuseRenderedEdgeCurvePoints) {
- if (points.length > 0) {
- overlayEdges.push({
- id: `overlay_${highlightedNode.id}___${selectedNode.id}`,
- fromNodeId: highlightedNode.id,
- toNodeId: selectedNode.id,
- points,
- type: 'incoming',
- });
- }
- } else {
- if (curvePoints.length > 0) {
- overlayEdges.push({
- id: `overlay_${highlightedNode.id}___${selectedNode.id}`,
- fromNodeId: highlightedNode.id,
- toNodeId: selectedNode.id,
- points: [],
- curvePoints,
- type: 'incoming',
- });
+ // Use these points to form an edge and add it as an overlay edge.
+ if (!reuseRenderedEdgeCurvePoints) {
+ if (points.length > 0) {
+ overlayEdges.push({
+ id: `overlay_${highlightedNode.id}___${opNode.id}`,
+ fromNodeId: highlightedNode.id,
+ toNodeId: opNode.id,
+ points,
+ type: 'incoming',
+ });
+ }
+ } else {
+ if (curvePoints.length > 0) {
+ overlayEdges.push({
+ id: `overlay_${highlightedNode.id}___${opNode.id}`,
+ fromNodeId: highlightedNode.id,
+ toNodeId: opNode.id,
+ points: [],
+ curvePoints,
+ type: 'incoming',
+ });
+ }
}
}
}
@@ -574,7 +599,7 @@ export class WebglRendererIoHighlightService {
getHighlightedOutgoingNodesAndEdges(
hiddenOutputIds: Record,
- selectedNode?: OpNode,
+ selectedNode?: ModelNode,
options?: {
ignoreEdgesWithinSameNamespace?: boolean;
reuseRenderedEdgeCurvePoints?: boolean;
@@ -586,204 +611,234 @@ export class WebglRendererIoHighlightService {
options?.reuseRenderedEdgeCurvePoints ?? false;
if (!selectedNode) {
- selectedNode = this.webglRenderer.curModelGraph.nodesById[
- this.webglRenderer.selectedNodeId
- ] as OpNode;
+ selectedNode =
+ this.webglRenderer.curModelGraph.nodesById[
+ this.webglRenderer.selectedNodeId
+ ];
}
const renderedEdges: ModelEdge[] = [];
const highlightedNodes: ModelNode[] = [];
const outputsByHighlightedNode: Record = {};
const overlayEdges: OverlayModelEdge[] = [];
- for (const outgoingEdges of selectedNode.outgoingEdges || []) {
- if (hiddenOutputIds[outgoingEdges.sourceNodeOutputId]) {
- continue;
- }
-
- const targetNode = this.webglRenderer.curModelGraph.nodesById[
- outgoingEdges.targetNodeId
- ] as OpNode;
- if (!targetNode) {
- continue;
+ const opNodes: OpNode[] = [];
+ const ignoredOutgoingNodesIds = new Set();
+ const seenOutgoingNodesIds = new Set();
+ if (isOpNode(selectedNode)) {
+ opNodes.push(selectedNode);
+ } else if (isGroupNode(selectedNode)) {
+ for (const id of selectedNode.descendantsOpNodeIds || []) {
+ const node = this.webglRenderer.curModelGraph.nodesById[id] as OpNode;
+ opNodes.push(node);
+ ignoredOutgoingNodesIds.add(id);
}
+ }
- if (
- ignoreEdgesWithinSameNamespace &&
- targetNode.namespace === selectedNode.namespace
- ) {
- continue;
- }
+ for (const opNode of opNodes) {
+ for (const outgoingEdges of opNode.outgoingEdges || []) {
+ if (
+ hiddenOutputIds[`${opNode.id}___${outgoingEdges.sourceNodeOutputId}`]
+ ) {
+ continue;
+ }
- // Find the common namespace prefix.
- const commonNamespace = findCommonNamespace(
- targetNode.namespace,
- selectedNode.namespace,
- );
+ const targetNode = this.webglRenderer.curModelGraph.nodesById[
+ outgoingEdges.targetNodeId
+ ] as OpNode;
+ if (!targetNode) {
+ continue;
+ }
- // Go from the given node to all its ns ancestors, find the last
- // collapsed node before reaching the given namespace, and style it with
- // the given class. If all ancestor nodes are expanded, style the given
- // node.
- const highlightedNode = this.getLastCollapsedAncestorNode(
- targetNode,
- commonNamespace,
- );
- highlightedNodes.push(highlightedNode);
+ if (ignoredOutgoingNodesIds.has(targetNode.id)) {
+ continue;
+ }
- // Update outputsByHighlighedNode.
- if (outputsByHighlightedNode[highlightedNode.id] == null) {
- outputsByHighlightedNode[highlightedNode.id] = [];
- }
- outputsByHighlightedNode[highlightedNode.id].push(targetNode);
-
- // Find the existing edge in the common namespace that connects two
- // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
- // `node`.
- const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
- commonNamespace,
- selectedNode.id,
- targetNode.id,
- );
+ if (seenOutgoingNodesIds.has(targetNode.id)) {
+ continue;
+ }
+ seenOutgoingNodesIds.add(targetNode.id);
- // Start to construct an edge from the selected node to target node.
- //
- const points: Point[] = [];
- const curvePoints: Point[] = [];
-
- if (renderedEdge) {
- renderedEdges.push(renderedEdge);
- const renderedEdgeCurvePoints = renderedEdge.curvePoints || [];
-
- // Add a point from the selected node that connects to the first point
- // of the rendered edge.
- const renderedEdgeFromNode =
- this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
- if (renderedEdge.fromNodeId !== selectedNode?.id) {
- const renderedEdgeStartX =
- renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
- const renderedEdgeStartY =
- renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
- const endPt = this.getBestAnchorPointOnNode(
- renderedEdgeStartX,
- renderedEdgeStartY,
- selectedNode,
- );
- points.push({
- x: endPt.x - (selectedNode.globalX || 0),
- y: endPt.y - (selectedNode.globalY || 0),
- });
- if (reuseRenderedEdgeCurvePoints) {
- curvePoints.push(
- {
- x: endPt.x - (selectedNode.globalX || 0),
- y: endPt.y - (selectedNode.globalY || 0),
- },
- {
- x:
- renderedEdgeCurvePoints[0].x -
- (selectedNode.globalX || 0) +
- (renderedEdgeFromNode.globalX || 0),
- y:
- renderedEdgeCurvePoints[0].y -
- (selectedNode.globalY || 0) +
- (renderedEdgeFromNode.globalY || 0),
- },
- );
- }
+ if (
+ ignoreEdgesWithinSameNamespace &&
+ targetNode.namespace === opNode.namespace
+ ) {
+ continue;
}
- // Add the points in rendered edge.
- let targetPoints: Point[] = points;
- let sourcePoints: Point[] = renderedEdge.points;
- if (reuseRenderedEdgeCurvePoints) {
- targetPoints = curvePoints;
- sourcePoints = renderedEdgeCurvePoints;
+ // Find the common namespace prefix.
+ const commonNamespace = findCommonNamespace(
+ targetNode.namespace,
+ opNode.namespace,
+ );
+
+ // Go from the given node to all its ns ancestors, find the last
+ // collapsed node before reaching the given namespace, and style it with
+ // the given class. If all ancestor nodes are expanded, style the given
+ // node.
+ const highlightedNode = this.getLastCollapsedAncestorNode(
+ targetNode,
+ commonNamespace,
+ );
+ highlightedNodes.push(highlightedNode);
+
+ // Update outputsByHighlighedNode.
+ if (outputsByHighlightedNode[highlightedNode.id] == null) {
+ outputsByHighlightedNode[highlightedNode.id] = [];
}
- targetPoints.push(
- ...sourcePoints.map((pt) => {
- return {
- x:
- pt.x -
- (selectedNode.globalX || 0) +
- (renderedEdgeFromNode.globalX || 0),
- y:
- pt.y -
- (selectedNode.globalY || 0) +
- (renderedEdgeFromNode.globalY || 0),
- };
- }),
+ outputsByHighlightedNode[highlightedNode.id].push(targetNode);
+
+ // Find the existing edge in the common namespace that connects two
+ // nodes n1 and n2 where n1 contains `sourceNode` and n2 contains
+ // `node`.
+ const renderedEdge = this.findEdgeConnectingTwoNodesInNamespace(
+ commonNamespace,
+ opNode.id,
+ targetNode.id,
);
- // Add a point from the highlighted node that connects to the first
- // point of the rendered edge above.
- if (renderedEdge.toNodeId !== highlightedNode.id) {
- const renderedEdgeLastX =
- renderedEdge.points[renderedEdge.points.length - 1].x +
- (renderedEdgeFromNode.globalX || 0);
- const renderedEdgeLastY =
- renderedEdge.points[renderedEdge.points.length - 1].y +
- (renderedEdgeFromNode.globalY || 0);
- const startPt = this.getBestAnchorPointOnNode(
- renderedEdgeLastX,
- renderedEdgeLastY,
- highlightedNode,
- );
- if (!reuseRenderedEdgeCurvePoints) {
- points.push({
- x: startPt.x - (selectedNode.globalX || 0),
- y: startPt.y - (selectedNode.globalY || 0),
- });
- } else {
- curvePoints.push(
- {
+ // Start to construct an edge from the selected node to target node.
+ //
+ const points: Point[] = [];
+ const curvePoints: Point[] = [];
+
+ if (renderedEdge) {
+ renderedEdges.push(renderedEdge);
+ const renderedEdgeCurvePoints = renderedEdge.curvePoints || [];
+
+ const renderedEdgeFromNode =
+ this.webglRenderer.curModelGraph.nodesById[renderedEdge.fromNodeId];
+
+ // Add a point from the selected node that connects to the first point
+ // of the rendered edge.
+ if (isOpNode(selectedNode)) {
+ if (renderedEdge.fromNodeId !== opNode?.id) {
+ const renderedEdgeStartX =
+ renderedEdge.points[0].x + (renderedEdgeFromNode.globalX || 0);
+ const renderedEdgeStartY =
+ renderedEdge.points[0].y + (renderedEdgeFromNode.globalY || 0);
+ const endPt = this.getBestAnchorPointOnNode(
+ renderedEdgeStartX,
+ renderedEdgeStartY,
+ opNode,
+ );
+ points.push({
+ x: endPt.x - (opNode.globalX || 0),
+ y: endPt.y - (opNode.globalY || 0),
+ });
+ if (reuseRenderedEdgeCurvePoints) {
+ curvePoints.push(
+ {
+ x: endPt.x - (opNode.globalX || 0),
+ y: endPt.y - (opNode.globalY || 0),
+ },
+ {
+ x:
+ renderedEdgeCurvePoints[0].x -
+ (opNode.globalX || 0) +
+ (renderedEdgeFromNode.globalX || 0),
+ y:
+ renderedEdgeCurvePoints[0].y -
+ (opNode.globalY || 0) +
+ (renderedEdgeFromNode.globalY || 0),
+ },
+ );
+ }
+ }
+ }
+
+ // Add the points in rendered edge.
+ let targetPoints: Point[] = points;
+ let sourcePoints: Point[] = renderedEdge.points;
+ if (reuseRenderedEdgeCurvePoints) {
+ targetPoints = curvePoints;
+ sourcePoints = renderedEdgeCurvePoints;
+ }
+ targetPoints.push(
+ ...sourcePoints.map((pt) => {
+ return {
x:
- renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
- .x -
- (selectedNode.globalX || 0) +
+ pt.x -
+ (opNode.globalX || 0) +
(renderedEdgeFromNode.globalX || 0),
y:
- renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
- .y -
- (selectedNode.globalY || 0) +
+ pt.y -
+ (opNode.globalY || 0) +
(renderedEdgeFromNode.globalY || 0),
- },
- {
- x: startPt.x - (selectedNode.globalX || 0),
- y: startPt.y - (selectedNode.globalY || 0),
- },
+ };
+ }),
+ );
+
+ // Add a point from the highlighted node that connects to the first
+ // point of the rendered edge above.
+ if (renderedEdge.toNodeId !== highlightedNode.id) {
+ const renderedEdgeLastX =
+ renderedEdge.points[renderedEdge.points.length - 1].x +
+ (renderedEdgeFromNode.globalX || 0);
+ const renderedEdgeLastY =
+ renderedEdge.points[renderedEdge.points.length - 1].y +
+ (renderedEdgeFromNode.globalY || 0);
+ const startPt = this.getBestAnchorPointOnNode(
+ renderedEdgeLastX,
+ renderedEdgeLastY,
+ highlightedNode,
);
+ if (!reuseRenderedEdgeCurvePoints) {
+ points.push({
+ x: startPt.x - (opNode.globalX || 0),
+ y: startPt.y - (opNode.globalY || 0),
+ });
+ } else {
+ curvePoints.push(
+ {
+ x:
+ renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
+ .x -
+ (opNode.globalX || 0) +
+ (renderedEdgeFromNode.globalX || 0),
+ y:
+ renderedEdgeCurvePoints[renderedEdgeCurvePoints.length - 1]
+ .y -
+ (opNode.globalY || 0) +
+ (renderedEdgeFromNode.globalY || 0),
+ },
+ {
+ x: startPt.x - (opNode.globalX || 0),
+ y: startPt.y - (opNode.globalY || 0),
+ },
+ );
+ }
}
+ } else if (
+ isGroupNode(highlightedNode) ||
+ (isOpNode(highlightedNode) && !highlightedNode.hideInLayout)
+ ) {
+ (reuseRenderedEdgeCurvePoints ? curvePoints : points).push(
+ ...this.getDirectEdgeBetweenNodes(opNode, highlightedNode),
+ );
}
- } else if (
- isGroupNode(highlightedNode) ||
- (isOpNode(highlightedNode) && !highlightedNode.hideInLayout)
- ) {
- (reuseRenderedEdgeCurvePoints ? curvePoints : points).push(
- ...this.getDirectEdgeBetweenNodes(selectedNode, highlightedNode),
- );
- }
- // Use these points to form an edge and add it as an overlay edge.
- if (!reuseRenderedEdgeCurvePoints) {
- if (points.length > 0) {
- overlayEdges.push({
- id: `overlay_${selectedNode.id}___${highlightedNode.id}`,
- fromNodeId: selectedNode.id,
- toNodeId: highlightedNode.id,
- points,
- type: 'outgoing',
- });
- }
- } else {
- if (curvePoints.length > 0) {
- overlayEdges.push({
- id: `overlay_${selectedNode.id}___${highlightedNode.id}`,
- fromNodeId: selectedNode.id,
- toNodeId: highlightedNode.id,
- points: [],
- curvePoints,
- type: 'outgoing',
- });
+ // Use these points to form an edge and add it as an overlay edge.
+ if (!reuseRenderedEdgeCurvePoints) {
+ if (points.length > 0) {
+ overlayEdges.push({
+ id: `overlay_${opNode.id}___${highlightedNode.id}`,
+ fromNodeId: opNode.id,
+ toNodeId: highlightedNode.id,
+ points,
+ type: 'outgoing',
+ });
+ }
+ } else {
+ if (curvePoints.length > 0) {
+ overlayEdges.push({
+ id: `overlay_${opNode.id}___${highlightedNode.id}`,
+ fromNodeId: opNode.id,
+ toNodeId: highlightedNode.id,
+ points: [],
+ curvePoints,
+ type: 'outgoing',
+ });
+ }
}
}
}
@@ -830,12 +885,16 @@ export class WebglRendererIoHighlightService {
return false;
}
- // Ignore when clicking on a group node.
+ // Ignore when clicking on a group node and the corresponding config option
+ // is not enabled.
const selectedNode =
this.webglRenderer.curModelGraph.nodesById[
this.webglRenderer.selectedNodeId
];
- if (isGroupNode(selectedNode)) {
+ if (
+ isGroupNode(selectedNode) &&
+ !this.webglRenderer.appService.config()?.highlightLayerNodeInputsOutputs
+ ) {
return false;
}
diff --git a/src/ui/src/services/settings_service.ts b/src/ui/src/services/settings_service.ts
index 348aa01c..0206eb0f 100644
--- a/src/ui/src/services/settings_service.ts
+++ b/src/ui/src/services/settings_service.ts
@@ -35,6 +35,7 @@ export enum SettingKey {
DISALLOW_VERTICAL_EDGE_LABELS = 'disallow_vertical_edge_labels',
KEEP_LAYERS_WITH_A_SINGLE_CHILD = 'keep_layers_with_a_single_child',
SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING = 'show_op_node_out_of_layer_edges_without_selecting',
+ HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS = 'highlight_layer_node_inputs_outputs',
}
/** Setting types. */
@@ -156,6 +157,19 @@ export const SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING: Setting
'especially for larger models.',
};
+/** Setting for highlighting layer node inputs and outputs. */
+export const SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS: Setting = {
+ label: 'Highlight inputs and outputs of the selected layer node',
+ key: SettingKey.HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS,
+ type: SettingType.BOOLEAN,
+ defaultValue: false,
+ help:
+ 'By default, inputs and outputs are highlighted only when an op node ' +
+ 'is selected. Enable this setting to see inputs and outputs for a ' +
+ 'selected layer node, including all its descendant op nodes within ' +
+ 'that layer.',
+};
+
const SETTINGS_LOCAL_STORAGE_KEY = 'model_explorer_settings';
/** All settings. */
@@ -168,7 +182,8 @@ export const ALL_SETTINGS = [
SETTING_KEEP_LAYERS_WITH_A_SINGLE_CHILD,
SETTING_SHOW_WELCOME_CARD,
SETTING_DISALLOW_VERTICAL_EDGE_LABELS,
- SETTING_SHOW_OP_NODE_OUT_OF_LAYER_EDGES_WITHOUT_SELECTING,
+ SETTING_MAX_CONST_ELEMENT_COUNT_LIMIT,
+ SETTING_HIGHLIGHT_LAYER_NODE_INPUTS_OUTPUTS,
];
/**