From defb04fec935a4535beae4907dcffea474021dad Mon Sep 17 00:00:00 2001 From: Google AI Edge Date: Fri, 11 Oct 2024 12:43:42 -0700 Subject: [PATCH] Add support for edge overlays PiperOrigin-RevId: 684925525 --- .../visualizer/common/edge_overlays.ts | 82 ++++++++ .../visualizer/common/model_graph.ts | 5 + .../visualizer/common/sync_navigation.ts | 2 +- .../src/components/visualizer/common/task.ts | 1 + .../src/components/visualizer/common/utils.ts | 79 +++++++ .../visualizer/common/visualizer_config.ts | 7 + .../visualizer/edge_overlays_dropdown.ng.html | 97 +++++++++ .../visualizer/edge_overlays_dropdown.scss | 193 ++++++++++++++++++ .../visualizer/edge_overlays_dropdown.ts | 165 +++++++++++++++ .../visualizer/edge_overlays_service.ts | 155 ++++++++++++++ .../node_data_provider_dropdown.scss | 2 +- .../visualizer/renderer_wrapper.ng.html | 8 + .../visualizer/renderer_wrapper.scss | 4 + .../components/visualizer/renderer_wrapper.ts | 6 + .../src/components/visualizer/split_pane.ts | 30 ++- .../visualizer/sync_navigation_service.ts | 1 + .../visualizer/view_on_node.ng.html | 1 + .../src/components/visualizer/view_on_node.ts | 1 - .../src/components/visualizer/webgl_edges.ts | 23 ++- .../components/visualizer/webgl_renderer.ts | 59 ++++++ .../webgl_renderer_edge_overlays_service.ts | 193 ++++++++++++++++++ .../webgl_renderer_edge_texts_service.ts | 112 ++++++---- 22 files changed, 1171 insertions(+), 55 deletions(-) create mode 100644 src/ui/src/components/visualizer/common/edge_overlays.ts create mode 100644 src/ui/src/components/visualizer/edge_overlays_dropdown.ng.html create mode 100644 src/ui/src/components/visualizer/edge_overlays_dropdown.scss create mode 100644 src/ui/src/components/visualizer/edge_overlays_dropdown.ts create mode 100644 src/ui/src/components/visualizer/edge_overlays_service.ts create mode 100644 src/ui/src/components/visualizer/webgl_renderer_edge_overlays_service.ts diff --git a/src/ui/src/components/visualizer/common/edge_overlays.ts b/src/ui/src/components/visualizer/common/edge_overlays.ts new file mode 100644 index 00000000..06d66659 --- /dev/null +++ b/src/ui/src/components/visualizer/common/edge_overlays.ts @@ -0,0 +1,82 @@ +/** + * @license + * Copyright 2024 The Model Explorer Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================== + */ + +import {TaskData, TaskType} from './task'; + +/** The data for edge overlays. */ +export declare interface EdgeOverlaysData extends TaskData { + type: TaskType.EDGE_OVERLAYS; + + /** The name of this set of overlays, for UI display purposes. */ + name: string; + + /** A list of edge overlays. */ + overlays: EdgeOverlay[]; +} + +/** An edge overlay. */ +export declare interface EdgeOverlay { + /** The name displayed in the UI to identify this overlay. */ + name: string; + + /** The edges that define the overlay. */ + edges: Edge[]; + + /** + * The color of the overlay edges. + * + * They are rendered in this color when any of the nodes in this overlay is + * selected. + */ + edgeColor: string; + + /** The width of the overlay edges. Default to 2. */ + edgeWidth?: number; + + /** The font size of the edge labels. Default to 7.5. */ + edgeLabelFontSize?: number; +} + +/** An edge in the overlay. */ +export declare interface Edge { + /** The id of the source node. Op node only. */ + sourceNodeId: string; + + /** The id of the target node. Op node only. */ + targetNodeId: string; + + /** Label shown on the edge. */ + label?: string; +} + +/** The processed edge overlays data. */ +export declare interface ProcessedEdgeOverlaysData extends EdgeOverlaysData { + /** A random id. */ + id: string; + + processedOverlays: ProcessedEdgeOverlay[]; +} + +/** The processed edge overlay. */ +export declare interface ProcessedEdgeOverlay extends EdgeOverlay { + /** A random id. */ + id: string; + + /** The set of node ids that are in this overlay. */ + nodeIds: Set; +} diff --git a/src/ui/src/components/visualizer/common/model_graph.ts b/src/ui/src/components/visualizer/common/model_graph.ts index b79f9b05..e9c1748d 100644 --- a/src/ui/src/components/visualizer/common/model_graph.ts +++ b/src/ui/src/components/visualizer/common/model_graph.ts @@ -271,4 +271,9 @@ export declare interface ModelEdge { // The following are for webgl rendering. curvePoints?: Point[]; + + // The label of the edge. + // + // If set, it will be rendered on edge instead of tensor shape. + label?: string; } diff --git a/src/ui/src/components/visualizer/common/sync_navigation.ts b/src/ui/src/components/visualizer/common/sync_navigation.ts index d9ac981f..b82a08e5 100644 --- a/src/ui/src/components/visualizer/common/sync_navigation.ts +++ b/src/ui/src/components/visualizer/common/sync_navigation.ts @@ -19,7 +19,7 @@ import {TaskData, TaskType} from './task'; /** The data for navigation syncing. */ -export interface SyncNavigationData extends TaskData { +export declare interface SyncNavigationData extends TaskData { type: TaskType.SYNC_NAVIGATION; mapping: SyncNavigationMapping; diff --git a/src/ui/src/components/visualizer/common/task.ts b/src/ui/src/components/visualizer/common/task.ts index ec2bcbd1..97e03fcd 100644 --- a/src/ui/src/components/visualizer/common/task.ts +++ b/src/ui/src/components/visualizer/common/task.ts @@ -24,4 +24,5 @@ export declare interface TaskData { /** The type of a task. */ export enum TaskType { SYNC_NAVIGATION = 'sync_navigation', + EDGE_OVERLAYS = 'edge_overlays', } diff --git a/src/ui/src/components/visualizer/common/utils.ts b/src/ui/src/components/visualizer/common/utils.ts index 1f1291b6..6c2a9963 100644 --- a/src/ui/src/components/visualizer/common/utils.ts +++ b/src/ui/src/components/visualizer/common/utils.ts @@ -48,6 +48,7 @@ import { ProcessedNodeQuery, ProcessedNodeRegexQuery, ProcessedNodeStylerRule, + Rect, SearchMatch, SearchMatchType, SearchNodeType, @@ -974,3 +975,81 @@ export function splitLabel(label: string): string[] { export function getMultiLineLabelExtraHeight(label: string): number { return (splitLabel(label).length - 1) * NODE_LABEL_LINE_HEIGHT; } + +/** + * Calculates the closest intersection points of a line (L) connecting + * the centers of two rectangles (rect1 and rect2) with the sides of these + * rectangles. + */ +export function getIntersectionPoints(rect1: Rect, rect2: Rect) { + // Function to calculate the center of a rectangle + function getCenter(rect: Rect) { + return { + x: rect.x + rect.width / 2, + y: rect.y + rect.height / 2, + }; + } + + // Function to calculate intersection between a line and a rectangle + function getIntersection(rect: Rect, center1: Point, center2: Point) { + // Line parameters + const dx = center2.x - center1.x; + const dy = center2.y - center1.y; + + // Check for intersection with each of the four sides of the rectangle + let tMin = Number.MAX_VALUE; + let intersection: Point = {x: 0, y: 0}; + + // Left side (x = rect.x) + if (dx !== 0) { + const t = (rect.x - center1.x) / dx; + const y = center1.y + t * dy; + if (t >= 0 && y >= rect.y && y <= rect.y + rect.height && t < tMin) { + tMin = t; + intersection = {x: rect.x, y}; + } + } + + // Right side (x = rect.x + rect.width) + if (dx !== 0) { + const t = (rect.x + rect.width - center1.x) / dx; + const y = center1.y + t * dy; + if (t >= 0 && y >= rect.y && y <= rect.y + rect.height && t < tMin) { + tMin = t; + intersection = {x: rect.x + rect.width, y}; + } + } + + // Top side (y = rect.y) + if (dy !== 0) { + const t = (rect.y - center1.y) / dy; + const x = center1.x + t * dx; + if (t >= 0 && x >= rect.x && x <= rect.x + rect.width && t < tMin) { + tMin = t; + intersection = {x, y: rect.y}; + } + } + + // Bottom side (y = rect.y + rect.height) + if (dy !== 0) { + const t = (rect.y + rect.height - center1.y) / dy; + const x = center1.x + t * dx; + if (t >= 0 && x >= rect.x && x <= rect.x + rect.width && t < tMin) { + tMin = t; + intersection = {x, y: rect.y + rect.height}; + } + } + + return intersection; + } + + // Get the centers of the rectangles + const center1 = getCenter(rect1); + const center2 = getCenter(rect2); + + // Find the closest intersection point of the line with rect1 and rect2 + const intersection1 = getIntersection(rect1, center1, center2); + const intersection2 = getIntersection(rect2, center2, center1); + + return {intersection1, intersection2}; +} diff --git a/src/ui/src/components/visualizer/common/visualizer_config.ts b/src/ui/src/components/visualizer/common/visualizer_config.ts index db47abd7..b0885a8d 100644 --- a/src/ui/src/components/visualizer/common/visualizer_config.ts +++ b/src/ui/src/components/visualizer/common/visualizer_config.ts @@ -16,6 +16,7 @@ * ============================================================================== */ +import {EdgeOverlaysData} from './edge_overlays'; import {SyncNavigationData} from './sync_navigation'; import {NodeStylerRule, RendererType} from './types'; @@ -63,6 +64,12 @@ export declare interface VisualizerConfig { /** The data for navigation syncing. */ syncNavigationData?: SyncNavigationData; + /** List of data for edge overlays that will be applied to the left pane. */ + edgeOverlaysDataListLeftPane?: EdgeOverlaysData[]; + + /** List of data for edge overlays that will be applied to the right pane. */ + edgeOverlaysDataListRightPane?: EdgeOverlaysData[]; + /** * Default graph renderer. * diff --git a/src/ui/src/components/visualizer/edge_overlays_dropdown.ng.html b/src/ui/src/components/visualizer/edge_overlays_dropdown.ng.html new file mode 100644 index 00000000..4c3c7eb7 --- /dev/null +++ b/src/ui/src/components/visualizer/edge_overlays_dropdown.ng.html @@ -0,0 +1,97 @@ + + +
+
+ polyline +
+
+ + +
+ Show custom edge overlays on graph +
+
+ + +
+
+
Edge overlays
+
+ close +
+
+ + +
+ @if (overlaysSets().length === 0) { +
+ No loaded edge overlays +
+ } @else { + @for (overlaySet of overlaysSets(); track overlaySet.id) { +
+
+ {{overlaySet.name}} +
+ delete +
+
+ @for (overlay of overlaySet.overlays; track overlay.id) { +
+ + @if (overlay.selected) { +
+ View +
+ } +
+ } +
+ } + } +
+ + +
+
Load from computer
+ +
+ +
+
\ No newline at end of file diff --git a/src/ui/src/components/visualizer/edge_overlays_dropdown.scss b/src/ui/src/components/visualizer/edge_overlays_dropdown.scss new file mode 100644 index 00000000..997bd096 --- /dev/null +++ b/src/ui/src/components/visualizer/edge_overlays_dropdown.scss @@ -0,0 +1,193 @@ +/** + * @license + * Copyright 2024 The Model Explorer Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================== + */ + +.container { + .mat-icon-container { + width: 20px; + height: 20px; + display: flex; + align-items: center; + justify-content: center; + cursor: pointer; + opacity: 0.6; + + &:hover { + opacity: 0.9; + } + + mat-icon { + font-size: 20px; + width: 20px; + height: 20px; + } + } +} + +::ng-deep bubble-container:has(.model-explorer-edge-overlays-popup) { + border-top-left-radius: 0; + border-top-right-radius: 0; +} + +::ng-deep .model-explorer-edge-overlays-popup { + padding: 12px; + padding-top: 10px; + font-size: 12px; + background-color: white; + display: flex; + flex-direction: column; + + .icon-container { + cursor: pointer; + opacity: 0.8; + display: flex; + align-items: center; + justify-content: center; + + &:hover { + opacity: 1; + } + + mat-icon { + font-size: 16px; + width: 16px; + height: 16px; + color: #777; + } + } + + .label { + font-weight: 500; + font-size: 11px; + text-transform: uppercase; + letter-spacing: 0.0727em; + margin-bottom: 6px; + display: flex; + align-items: center; + justify-content: space-between; + + &:not(:first-child) { + margin-top: 12px; + } + } + + .loaded-overlays-container { + display: flex; + flex-direction: column; + padding-bottom: 8px; + border-bottom: 1px solid #ccc; + gap: 8px; + + .no-overlays-label { + color: #999; + } + + .overlay-set-label { + display: flex; + align-items: center; + justify-content: space-between; + font-weight: 700; + line-height: 15px; + word-break: break-all; + margin-bottom: 4px; + } + + .overlay-item { + display: flex; + align-items: center; + justify-content: space-between; + + label { + display: flex; + align-items: center; + cursor: pointer; + line-height: 15px; + word-break: break-all; + gap: 4px; + user-select: none; + + input { + cursor: pointer; + } + } + + .view-label { + cursor: pointer; + color: #00639b; + opacity: .8; + user-select: none; + line-height: 15px; + + &:hover { + opacity: 1; + } + } + } + } + + .upload-container { + display: flex; + flex-direction: column; + align-items: flex-start; + padding: 0 16px 0 0; + margin-top: 12px; + } + + .upload-json-file-button { + margin: 4px 0; + width: 90px; + height: 30px; + /* stylelint-disable-next-line declaration-no-important -- override MDC */ + font-size: 12px !important; + /* stylelint-disable-next-line declaration-no-important -- override MDC */ + letter-spacing: normal !important; + + &.upload { + margin-top: 2px; + } + + ::ng-deep .mat-mdc-button-touch-target { + display: none; + } + } + + .or-divider { + height: 1px; + border-top: 1px solid #eee; + position: relative; + margin-top: 12px; + + .or-label { + font-size: 10px; + top: -12px; + color: #aaa; + position: absolute; + padding: 2px; + background-color: white; + display: flex; + align-items: center; + justify-content: center; + width: 16px; + left: calc(50% - 8px); + } + } + + .upload-json-file-input { + display: none; + } + +} \ No newline at end of file diff --git a/src/ui/src/components/visualizer/edge_overlays_dropdown.ts b/src/ui/src/components/visualizer/edge_overlays_dropdown.ts new file mode 100644 index 00000000..3e882182 --- /dev/null +++ b/src/ui/src/components/visualizer/edge_overlays_dropdown.ts @@ -0,0 +1,165 @@ +/** + * @license + * Copyright 2024 The Model Explorer Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================== + */ + +import {OverlaySizeConfig} from '@angular/cdk/overlay'; +import {CommonModule} from '@angular/common'; +import { + ChangeDetectionStrategy, + ChangeDetectorRef, + Component, + computed, + inject, + Input, + Signal, + ViewChild, +} from '@angular/core'; +import {MatButtonModule} from '@angular/material/button'; +import {MatIconModule} from '@angular/material/icon'; +import {MatSnackBar} from '@angular/material/snack-bar'; +import {MatTooltipModule} from '@angular/material/tooltip'; +import {Bubble} from '../bubble/bubble'; +import {BubbleClick} from '../bubble/bubble_click'; +import {AppService} from './app_service'; +import {ProcessedEdgeOverlay} from './common/edge_overlays'; +import {EdgeOverlaysService} from './edge_overlays_service'; +import {LocalStorageService} from './local_storage_service'; + +interface OverlaysSet { + id: string; + name: string; + overlays: OverlayItem[]; +} + +interface OverlayItem { + id: string; + name: string; + selected: boolean; + processedOverlay: ProcessedEdgeOverlay; +} + +/** The edge overlays dropdown panel with the trigger button. */ +@Component({ + standalone: true, + selector: 'edge-overlays-dropdown', + imports: [ + Bubble, + BubbleClick, + CommonModule, + MatButtonModule, + MatIconModule, + MatTooltipModule, + ], + templateUrl: './edge_overlays_dropdown.ng.html', + styleUrls: ['./edge_overlays_dropdown.scss'], + changeDetection: ChangeDetectionStrategy.OnPush, +}) +export class EdgeOverlaysDropdown { + @Input({required: true}) paneId!: string; + @Input({required: true}) rendererId!: string; + @ViewChild(BubbleClick) popup!: BubbleClick; + + private readonly appService = inject(AppService); + private readonly localStorageService = inject(LocalStorageService); + private readonly changeDetectorRef = inject(ChangeDetectorRef); + private readonly edgeOverlaysService = inject(EdgeOverlaysService); + private readonly snackBar = inject(MatSnackBar); + + readonly overlaysSets: Signal = computed(() => { + const overlays = this.edgeOverlaysService.loadedEdgeOverlays(); + return overlays.map((overlay) => ({ + id: overlay.id, + name: overlay.name, + overlays: overlay.processedOverlays.map((overlay) => ({ + id: overlay.id, + name: overlay.name, + selected: this.edgeOverlaysService + .selectedOverlayIds() + .includes(overlay.id), + processedOverlay: overlay, + })), + })); + }); + + readonly helpPopupSize: OverlaySizeConfig = { + minWidth: 0, + minHeight: 0, + }; + + readonly edgeOverlaysPopupSize: OverlaySizeConfig = { + minWidth: 280, + minHeight: 0, + }; + + readonly remoteSourceLoading = this.edgeOverlaysService.remoteSourceLoading; + opened = false; + + constructor() { + } + + handleClickOnEdgeOverlaysButton() { + if (this.opened) { + this.popup.closeDialog(); + } + } + + handleClickUpload(input: HTMLInputElement) { + const files = input.files; + if (!files || files.length === 0) { + return; + } + const file = files[0]; + const fileReader = new FileReader(); + fileReader.onload = (event) => { + const error = this.edgeOverlaysService.addEdgeOverlayDataFromJsonData( + event.target?.result as string, + ); + if (error) { + this.showError(error); + } + }; + fileReader.readAsText(file); + input.value = ''; + } + + handleDeleteOverlaySet(overlaySet: OverlaysSet) { + this.edgeOverlaysService.deleteOverlayData(overlaySet.id); + } + + toggleOverlaySelection(overlay: OverlayItem) { + this.edgeOverlaysService.toggleOverlaySelection(overlay.id); + } + + handleClickViewOverlay(overlay: OverlayItem) { + // Get the first node of the overlay. + const edges = overlay.processedOverlay.edges; + if (edges.length === 0) { + return; + } + const firstNodeId = edges[0].sourceNodeId; + + // Reveal it. + this.appService.setNodeToReveal(this.paneId, firstNodeId); + } + + private showError(message: string) { + console.error(message); + this.snackBar.open(message, 'Dismiss', { + duration: 5000, + }); + } +} diff --git a/src/ui/src/components/visualizer/edge_overlays_service.ts b/src/ui/src/components/visualizer/edge_overlays_service.ts new file mode 100644 index 00000000..ed5d1721 --- /dev/null +++ b/src/ui/src/components/visualizer/edge_overlays_service.ts @@ -0,0 +1,155 @@ +/** + * @license + * Copyright 2024 The Model Explorer Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================== + */ + +import {Injectable, computed, signal} from '@angular/core'; +import { + EdgeOverlaysData, + ProcessedEdgeOverlay, + ProcessedEdgeOverlaysData, +} from './common/edge_overlays'; +import {ReadFileResp} from './common/types'; +import {genUid} from './common/utils'; + +/** A service for managing edge overlays. */ +@Injectable() +export class EdgeOverlaysService { + readonly remoteSourceLoading = signal(false); + + readonly loadedEdgeOverlays = signal([]); + + readonly selectedOverlayIds = signal([]); + + readonly selectedOverlays = computed(() => { + const overlays: ProcessedEdgeOverlay[] = []; + for (const overlayData of this.loadedEdgeOverlays()) { + for (const overlay of overlayData.processedOverlays) { + if (this.selectedOverlayIds().includes(overlay.id)) { + overlays.push(overlay); + } + } + } + return overlays; + }); + + addOverlay(overlay: EdgeOverlaysData) { + this.loadedEdgeOverlays.update((loadedOverlays) => { + return [...loadedOverlays, processOverlay(overlay)]; + }); + } + + deleteOverlayData(id: string) { + const overlaysDataToDelete = this.loadedEdgeOverlays().find( + (overlaysData) => overlaysData.id === id, + ); + this.loadedEdgeOverlays.update((overlayDataList) => { + return overlayDataList.filter((overlayData) => overlayData.id !== id); + }); + + // Update selected overlays. + if (overlaysDataToDelete) { + const overlayIdsToDelete = new Set( + overlaysDataToDelete.processedOverlays.map((overlay) => overlay.id), + ); + this.selectedOverlayIds.update((selectedOverlayIds) => { + return selectedOverlayIds.filter((id) => !overlayIdsToDelete.has(id)); + }); + } + } + + toggleOverlaySelection(idToToggle: string) { + this.selectedOverlayIds.update((selectedOverlayIds) => { + let ids = [...selectedOverlayIds]; + if (selectedOverlayIds.includes(idToToggle)) { + ids = ids.filter((id) => id !== idToToggle); + } else { + ids.push(idToToggle); + } + return ids; + }); + } + + addEdgeOverlayData(data: EdgeOverlaysData) { + this.addOverlay(data); + + // Select all newly-added overlays. + this.selectedOverlayIds.update((selectedOverlayIds) => { + const loadedOverlaysDataList = this.loadedEdgeOverlays(); + const newOverlayData = + loadedOverlaysDataList[loadedOverlaysDataList.length - 1]; + const newIds = newOverlayData.processedOverlays.map( + (overlay) => overlay.id, + ); + return [...selectedOverlayIds, ...newIds]; + }); + } + + addEdgeOverlayDataFromJsonData(str: string): string { + try { + const data = JSON.parse(str) as EdgeOverlaysData; + this.addEdgeOverlayData(data); + } catch (e) { + return `Failed to parse JSON file. ${e}`; + } + return ''; + } + + async loadFromCns(path: string): Promise { + // Call API to read file content. + this.remoteSourceLoading.set(true); + const url = `/read_file?path=${path}`; + const resp = await fetch(url); + if (!resp.ok) { + this.remoteSourceLoading.set(false); + return `Failed to load JSON file "${path}"`; + } + + // Parse response. + const json = JSON.parse( + (await resp.text()).replace(")]}'\n", ''), + ) as ReadFileResp; + + const error = this.addEdgeOverlayDataFromJsonData(json.content); + + this.remoteSourceLoading.set(false); + + return error; + } +} + +function processOverlay( + overlayData: EdgeOverlaysData, +): ProcessedEdgeOverlaysData { + const processedOverlayData: ProcessedEdgeOverlaysData = { + id: genUid(), + processedOverlays: [], + ...overlayData, + }; + for (const overlay of overlayData.overlays) { + const processedOverlay: ProcessedEdgeOverlay = { + id: genUid(), + nodeIds: new Set(), + ...overlay, + }; + processedOverlayData.processedOverlays.push(processedOverlay); + for (const edge of overlay.edges) { + processedOverlay.nodeIds.add(edge.sourceNodeId); + processedOverlay.nodeIds.add(edge.targetNodeId); + } + } + return processedOverlayData; +} diff --git a/src/ui/src/components/visualizer/node_data_provider_dropdown.scss b/src/ui/src/components/visualizer/node_data_provider_dropdown.scss index 03c92774..466d4fec 100644 --- a/src/ui/src/components/visualizer/node_data_provider_dropdown.scss +++ b/src/ui/src/components/visualizer/node_data_provider_dropdown.scss @@ -157,7 +157,7 @@ .or-label { font-size: 10px; - top: -9px; + top: -12px; color: #aaa; position: absolute; padding: 2px; diff --git a/src/ui/src/components/visualizer/renderer_wrapper.ng.html b/src/ui/src/components/visualizer/renderer_wrapper.ng.html index 6d20e14f..4f9793db 100644 --- a/src/ui/src/components/visualizer/renderer_wrapper.ng.html +++ b/src/ui/src/components/visualizer/renderer_wrapper.ng.html @@ -142,6 +142,14 @@ + + @if (showEdgeOverlaysDropdown) { + + + } + @if (showDownloadPng) {
diff --git a/src/ui/src/components/visualizer/renderer_wrapper.scss b/src/ui/src/components/visualizer/renderer_wrapper.scss index 46c7ff2a..0a08ba35 100644 --- a/src/ui/src/components/visualizer/renderer_wrapper.scss +++ b/src/ui/src/components/visualizer/renderer_wrapper.scss @@ -113,6 +113,10 @@ margin: 2px 5px; height : 20px; } + + edge-overlays-dropdown { + margin-left: 4px; + } } subgraph-breadcrumbs { diff --git a/src/ui/src/components/visualizer/renderer_wrapper.ts b/src/ui/src/components/visualizer/renderer_wrapper.ts index 216545b0..870a481e 100644 --- a/src/ui/src/components/visualizer/renderer_wrapper.ts +++ b/src/ui/src/components/visualizer/renderer_wrapper.ts @@ -46,6 +46,7 @@ import { SubgraphBreadcrumbItem, } from './common/types'; import {isGroupNode} from './common/utils'; +import {EdgeOverlaysDropdown} from './edge_overlays_dropdown'; import {SearchBar} from './search_bar'; import {SnapshotManager} from './snapshot_manager'; import {SubgraphBreadcrumbs} from './subgraph_breadcrumbs'; @@ -60,6 +61,7 @@ import {WebglRenderer} from './webgl_renderer'; Bubble, BubbleClick, CommonModule, + EdgeOverlaysDropdown, MatButtonModule, MatIconModule, MatMenuModule, @@ -204,6 +206,10 @@ export class RendererWrapper { return !this.inPopup && this.curSubgraphBreadcrumbs.length > 1; } + get showEdgeOverlaysDropdown(): boolean { + return !this.inPopup; + } + get disableExpandCollapseAllButton(): boolean { return this.appService.getFlattenLayers(this.paneId); } diff --git a/src/ui/src/components/visualizer/split_pane.ts b/src/ui/src/components/visualizer/split_pane.ts index cc9cc730..bb17673e 100644 --- a/src/ui/src/components/visualizer/split_pane.ts +++ b/src/ui/src/components/visualizer/split_pane.ts @@ -23,9 +23,11 @@ import { ChangeDetectorRef, Component, Input, + OnInit, } from '@angular/core'; import {AppService} from './app_service'; import type {Pane} from './common/types'; +import {EdgeOverlaysService} from './edge_overlays_service'; import {GraphPanel} from './graph_panel'; import {InfoPanel} from './info_panel'; import {SplitPaneService} from './split_pane_service'; @@ -36,7 +38,7 @@ import {SubgraphSelectionService} from './subgraph_selection_service'; standalone: true, selector: 'split-pane', imports: [CommonModule, GraphPanel, InfoPanel], - providers: [SubgraphSelectionService, SplitPaneService], + providers: [EdgeOverlaysService, SubgraphSelectionService, SplitPaneService], templateUrl: './split_pane.ng.html', styleUrls: ['./split_pane.scss'], animations: [ @@ -59,14 +61,38 @@ import {SubgraphSelectionService} from './subgraph_selection_service'; ], changeDetection: ChangeDetectionStrategy.OnPush, }) -export class SplitPane { +export class SplitPane implements OnInit { @Input({required: true}) pane!: Pane; constructor( private readonly appService: AppService, private readonly changeDetectorRef: ChangeDetectorRef, + private readonly edgeOverlaysService: EdgeOverlaysService, ) {} + ngOnInit() { + // Load edge overlays stored in config. + const config = this.appService.config(); + const panes = this.appService.panes(); + if ( + panes.length > 0 && + panes[0].id === this.pane.id && + config?.edgeOverlaysDataListLeftPane + ) { + for (const data of config.edgeOverlaysDataListLeftPane) { + this.edgeOverlaysService.addEdgeOverlayData(data); + } + } else if ( + panes.length > 1 && + panes[1].id === this.pane.id && + config?.edgeOverlaysDataListRightPane + ) { + for (const data of config.edgeOverlaysDataListRightPane) { + this.edgeOverlaysService.addEdgeOverlayData(data); + } + } + } + refresh() { this.changeDetectorRef.markForCheck(); } diff --git a/src/ui/src/components/visualizer/sync_navigation_service.ts b/src/ui/src/components/visualizer/sync_navigation_service.ts index b3d66fa2..a9d5dcf8 100644 --- a/src/ui/src/components/visualizer/sync_navigation_service.ts +++ b/src/ui/src/components/visualizer/sync_navigation_service.ts @@ -96,6 +96,7 @@ export class SyncNavigationService { const url = `/read_file?path=${path}`; const resp = await fetch(url); if (!resp.ok) { + this.loadingFromCns.set(false); return `Failed to load JSON file "${path}"`; } diff --git a/src/ui/src/components/visualizer/view_on_node.ng.html b/src/ui/src/components/visualizer/view_on_node.ng.html index c6755051..b02055fa 100644 --- a/src/ui/src/components/visualizer/view_on_node.ng.html +++ b/src/ui/src/components/visualizer/view_on_node.ng.html @@ -69,6 +69,7 @@ } } +
View on edges
diff --git a/src/ui/src/components/visualizer/view_on_node.ts b/src/ui/src/components/visualizer/view_on_node.ts index 39614cda..b58cf3e2 100644 --- a/src/ui/src/components/visualizer/view_on_node.ts +++ b/src/ui/src/components/visualizer/view_on_node.ts @@ -33,7 +33,6 @@ import {MatTooltipModule} from '@angular/material/tooltip'; import {Bubble} from '../bubble/bubble'; import {BubbleClick} from '../bubble/bubble_click'; - import {AppService} from './app_service'; import { LOCAL_STORAGE_KEY_SHOW_ON_EDGE_ITEM_TYPES, diff --git a/src/ui/src/components/visualizer/webgl_edges.ts b/src/ui/src/components/visualizer/webgl_edges.ts index 94c81214..161d4349 100644 --- a/src/ui/src/components/visualizer/webgl_edges.ts +++ b/src/ui/src/components/visualizer/webgl_edges.ts @@ -201,6 +201,7 @@ export class WebglEdges { constructor( private readonly color: WebglColor, private readonly edgeWidth: number, + private readonly arrowScale = 1, ) { this.planeGeo = new THREE.PlaneGeometry(1, 1); this.planeGeo.rotateX(-Math.PI / 2); @@ -217,12 +218,15 @@ export class WebglEdges { // Create arrow head geo. const triangle = new THREE.Shape(); + const arrowBaseSize = ARROW_BASE_SIZE * arrowScale; + const arrowHeight = ARROW_HEIGHT * arrowScale; + const arrowThickness = ARROW_THICKNESS * arrowScale; triangle - .moveTo(-ARROW_BASE_SIZE / 2, -ARROW_HEIGHT) - .lineTo(0, -ARROW_THICKNESS) - .lineTo(ARROW_BASE_SIZE / 2, -ARROW_HEIGHT) + .moveTo(-arrowBaseSize / 2, -arrowHeight) + .lineTo(0, -arrowThickness) + .lineTo(arrowBaseSize / 2, -arrowHeight) .lineTo(0, 0) - .lineTo(-ARROW_BASE_SIZE / 2, -ARROW_HEIGHT); + .lineTo(-arrowBaseSize / 2, -arrowHeight); this.arrowHeadGeometry = new THREE.ShapeGeometry(triangle); this.arrowHeadGeometry.rotateX(-Math.PI / 2); @@ -280,6 +284,15 @@ export class WebglEdges { endPt.x + nodeGlobalX, endPt.y + nodeGlobalY, ]; + const savedCurEndpoints = [...curEndpoints]; + + // Move the last segment inward a little bit so that it doesn't go out + // of the arrowhead. + if (i === points.length - 2 && points.length >= 2) { + const f = Math.atan2(endPt.y - startPt.y, endPt.x - startPt.x); + curEndpoints[2] -= (Math.cos(f) * ARROW_HEIGHT * this.arrowScale) / 2; + curEndpoints[3] -= (Math.sin(f) * ARROW_HEIGHT * this.arrowScale) / 2; + } const savedSegment = this.savedEdgeSegments[segmentId]; if (forceNoAnimation) { @@ -308,7 +321,7 @@ export class WebglEdges { // Arrowheads. if (i === points.length - 2) { const arrowHeadId = edge.id; - const curLastSegmentEndpoints = curEndpoints; + const curLastSegmentEndpoints = savedCurEndpoints; const savedArrowHead = this.savedArrowHeads[arrowHeadId]; if (forceNoAnimation) { lastSegmentEndPoints.push(...curLastSegmentEndpoints); diff --git a/src/ui/src/components/visualizer/webgl_renderer.ts b/src/ui/src/components/visualizer/webgl_renderer.ts index 91e66c60..689840ef 100644 --- a/src/ui/src/components/visualizer/webgl_renderer.ts +++ b/src/ui/src/components/visualizer/webgl_renderer.ts @@ -110,6 +110,7 @@ import {ThreejsService} from './threejs_service'; import {UiStateService} from './ui_state_service'; import {WebglEdges} from './webgl_edges'; import {WebglRendererAttrsTableService} from './webgl_renderer_attrs_table_service'; +import {WebglRendererEdgeOverlaysService} from './webgl_renderer_edge_overlays_service'; import {WebglRendererEdgeTextsService} from './webgl_renderer_edge_texts_service'; import {WebglRendererIdenticalLayerService} from './webgl_renderer_identical_layer_service'; import { @@ -198,6 +199,7 @@ type RenderElement = RenderElementNode | RenderElementEdge; providers: [ WebglRendererAttrsTableService, WebglRendererEdgeTextsService, + WebglRendererEdgeOverlaysService, WebglRendererIdenticalLayerService, WebglRendererIoHighlightService, WebglRendererIoTracingService, @@ -447,6 +449,7 @@ export class WebglRenderer implements OnInit, OnDestroy { private readonly viewContainerRef: ViewContainerRef, private readonly webglRendererAttrsTableService: WebglRendererAttrsTableService, readonly webglRendererEdgeTextsService: WebglRendererEdgeTextsService, + private readonly webglRendererEdgeOverlaysService: WebglRendererEdgeOverlaysService, private readonly webglRendererIdenticalLayerService: WebglRendererIdenticalLayerService, private readonly webglRendererIoHighlightService: WebglRendererIoHighlightService, private readonly webglRendererIoTracingService: WebglRendererIoTracingService, @@ -459,6 +462,7 @@ export class WebglRenderer implements OnInit, OnDestroy { ) { this.webglRendererAttrsTableService.init(this); this.webglRendererEdgeTextsService.init(this); + this.webglRendererEdgeOverlaysService.init(this); this.webglRendererIdenticalLayerService.init(this); this.webglRendererIoHighlightService.init(this); this.webglRendererIoTracingService.init(this); @@ -601,6 +605,7 @@ export class WebglRenderer implements OnInit, OnDestroy { // data needed to update nodes styles correctly. this.webglRendererIoHighlightService.updateIncomingAndOutgoingHighlights(); this.webglRendererIdenticalLayerService.updateIdenticalLayerIndicators(); + this.webglRendererEdgeOverlaysService.updateOverlaysData(); this.updateNodesStyles(); this.webglRendererThreejsService.render(); @@ -611,6 +616,50 @@ export class WebglRenderer implements OnInit, OnDestroy { nodeId: this.selectedNodeId, }); } + + // Automatically reveal all nodes in the edge overlays (if existed). + if (this.webglRendererEdgeOverlaysService.curOverlays.length > 0) { + const deepestExpandedGroupNodeIds = + this.webglRendererEdgeOverlaysService.getDeepestExpandedGroupNodeIds(); + if (deepestExpandedGroupNodeIds.length > 0) { + this.sendRelayoutGraphRequest( + this.selectedNodeId, + deepestExpandedGroupNodeIds, + ); + } else { + this.webglRendererEdgeOverlaysService.updateOverlaysEdges(); + this.webglRendererThreejsService.render(); + } + } else { + this.webglRendererEdgeOverlaysService.clearOverlaysEdges(); + this.webglRendererThreejsService.render(); + } + }); + + // Handle selected edge overlays changes. + effect(() => { + this.webglRendererEdgeOverlaysService.edgeOverlaysService.selectedOverlayIds(); + this.webglRendererEdgeOverlaysService.updateOverlaysData(); + + // Automatically reveal all nodes in the edge overlays (if existed). + if (this.selectedNodeId !== '') { + if (this.webglRendererEdgeOverlaysService.curOverlays.length > 0) { + const deepestExpandedGroupNodeIds = + this.webglRendererEdgeOverlaysService.getDeepestExpandedGroupNodeIds(); + if (deepestExpandedGroupNodeIds.length > 0) { + this.sendRelayoutGraphRequest( + this.selectedNodeId, + deepestExpandedGroupNodeIds, + ); + } else { + this.webglRendererEdgeOverlaysService.updateOverlaysEdges(); + this.webglRendererThreejsService.render(); + } + } else { + this.webglRendererEdgeOverlaysService.clearOverlaysEdges(); + this.webglRendererThreejsService.render(); + } + } }); // Handle "download as png". @@ -1431,6 +1480,15 @@ export class WebglRenderer implements OnInit, OnDestroy { return node.height || 0; } + getNodeRect(node: ModelNode): Rect { + return { + x: this.getNodeX(node), + y: this.getNodeY(node), + width: this.getNodeWidth(node), + height: this.getNodeHeight(node), + }; + } + getNodeLabelRelativeY(node: ModelNode): number { return 14; } @@ -1673,6 +1731,7 @@ export class WebglRenderer implements OnInit, OnDestroy { this.renderGraph(); this.webglRendererIoHighlightService.updateIncomingAndOutgoingHighlights(); this.webglRendererIdenticalLayerService.updateIdenticalLayerIndicators(); + this.webglRendererEdgeOverlaysService.updateOverlaysEdges(); this.updateNodesStyles(); if (rectToZoomFit) { const zoomFitFn = () => { diff --git a/src/ui/src/components/visualizer/webgl_renderer_edge_overlays_service.ts b/src/ui/src/components/visualizer/webgl_renderer_edge_overlays_service.ts new file mode 100644 index 00000000..cb002239 --- /dev/null +++ b/src/ui/src/components/visualizer/webgl_renderer_edge_overlays_service.ts @@ -0,0 +1,193 @@ +/** + * @license + * Copyright 2024 The Model Explorer Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================== + */ + +import {Injectable, inject} from '@angular/core'; +import * as three from 'three'; +import {WEBGL_ELEMENT_Y_FACTOR} from './common/consts'; +import {EdgeOverlay} from './common/edge_overlays'; +import {GroupNode, ModelEdge, OpNode} from './common/model_graph'; +import {getIntersectionPoints} from './common/utils'; +import {EdgeOverlaysService} from './edge_overlays_service'; +import {ThreejsService} from './threejs_service'; +import {WebglEdges} from './webgl_edges'; +import {WebglRenderer} from './webgl_renderer'; +import {WebglRendererThreejsService} from './webgl_renderer_threejs_service'; +import {WebglTexts} from './webgl_texts'; + +const THREE = three; + +const DEFAULT_EDGE_WIDTH = 1.5; + +/** + * Service for managing edge overlays related tasks in webgl renderer. + */ +@Injectable() +export class WebglRendererEdgeOverlaysService { + private readonly threejsService: ThreejsService = inject(ThreejsService); + + private webglRenderer!: WebglRenderer; + private webglRendererThreejsService!: WebglRendererThreejsService; + private overlaysEdgesList: WebglEdges[] = []; + private overlaysEdgeTextsList: WebglTexts[] = []; + + readonly edgeOverlaysService = inject(EdgeOverlaysService); + curOverlays: EdgeOverlay[] = []; + + init(webglRenderer: WebglRenderer) { + this.webglRenderer = webglRenderer; + this.webglRendererThreejsService = + webglRenderer.webglRendererThreejsService; + } + + updateOverlaysData() { + this.clearOverlaysData(); + + const selectedNodeId = this.webglRenderer.selectedNodeId; + if (!selectedNodeId) { + return; + } + + // Find overlays that contain the node from the selected overlays. + const selectedOverlays = this.edgeOverlaysService.selectedOverlays(); + for (const selectedOverlay of selectedOverlays) { + if (selectedOverlay.nodeIds.has(selectedNodeId)) { + this.curOverlays.push(selectedOverlay); + } + } + } + + clearOverlaysData() { + this.curOverlays = []; + } + + updateOverlaysEdges() { + this.clearOverlaysEdges(); + + if (this.curOverlays.length === 0) { + return; + } + + for (let i = 0; i < this.curOverlays.length; i++) { + const subgraph = this.curOverlays[i]; + const edgeWidth = subgraph.edgeWidth ?? DEFAULT_EDGE_WIDTH; + const edges: Array<{edge: ModelEdge; index: number}> = []; + const curWebglEdges = new WebglEdges( + new THREE.Color(subgraph.edgeColor), + edgeWidth, + edgeWidth / DEFAULT_EDGE_WIDTH, + ); + for (const {sourceNodeId, targetNodeId, label} of subgraph.edges) { + const sourceNode = this.webglRenderer.curModelGraph.nodesById[ + sourceNodeId + ] as OpNode; + const targetNode = this.webglRenderer.curModelGraph.nodesById[ + targetNodeId + ] as OpNode; + const {intersection1, intersection2} = getIntersectionPoints( + this.webglRenderer.getNodeRect(sourceNode), + this.webglRenderer.getNodeRect(targetNode), + ); + // Edge. + edges.push({ + edge: { + id: `overlay_edge_${i}_${sourceNodeId}_${targetNodeId}`, + fromNodeId: sourceNodeId, + toNodeId: targetNodeId, + label: label ?? '', + points: [], + curvePoints: [ + { + x: intersection1.x - (sourceNode?.globalX || 0), + y: intersection1.y - (sourceNode?.globalY || 0), + }, + { + x: intersection2.x - (sourceNode.globalX || 0), + y: intersection2.y - (sourceNode.globalY || 0), + }, + ], + }, + // Use anything > 95 which is used for rendering io highlight edges. + index: 96 / WEBGL_ELEMENT_Y_FACTOR, + }); + } + curWebglEdges.generateMesh(edges, this.webglRenderer.curModelGraph); + this.webglRendererThreejsService.addToScene(curWebglEdges.edgesMesh); + this.webglRendererThreejsService.addToScene(curWebglEdges.arrowHeadsMesh); + this.overlaysEdgesList.push(curWebglEdges); + + // Edge labels. + const labels = + this.webglRenderer.webglRendererEdgeTextsService.genLabelsOnEdges( + edges, + new THREE.Color(subgraph.edgeColor), + edgeWidth / 2, + 96.5, + subgraph.edgeLabelFontSize ?? 7.5, + ); + const curWebglTexts = new WebglTexts(this.threejsService); + curWebglTexts.generateMesh(labels, true, false, true); + this.webglRendererThreejsService.addToScene(curWebglTexts.mesh); + this.overlaysEdgeTextsList.push(curWebglTexts); + } + } + + clearOverlaysEdges() { + for (const webglEdges of this.overlaysEdgesList) { + webglEdges.clear(); + } + for (const webglTexts of this.overlaysEdgeTextsList) { + if (webglTexts.mesh && webglTexts.mesh.geometry) { + webglTexts.mesh.geometry.dispose(); + this.webglRendererThreejsService.removeFromScene(webglTexts.mesh); + } + } + + this.overlaysEdgesList = []; + this.overlaysEdgeTextsList = []; + } + + getDeepestExpandedGroupNodeIds(): string[] { + if (this.curOverlays.length === 0) { + return []; + } + + const ids = new Set(); + + const addNsParentId = (nodeId: string) => { + const node = this.webglRenderer.curModelGraph.nodesById[nodeId]; + if (node.nsParentId) { + const parentNode = this.webglRenderer.curModelGraph.nodesById[ + node.nsParentId + ] as GroupNode; + if ( + !parentNode.expanded || + !this.webglRenderer.isNodeRendered(parentNode.id) + ) { + ids.add(node.nsParentId); + } + } + }; + for (const subgraph of this.curOverlays) { + for (const {sourceNodeId, targetNodeId} of subgraph.edges) { + addNsParentId(sourceNodeId); + addNsParentId(targetNodeId); + } + } + return [...ids]; + } +} diff --git a/src/ui/src/components/visualizer/webgl_renderer_edge_texts_service.ts b/src/ui/src/components/visualizer/webgl_renderer_edge_texts_service.ts index fb5b955e..67508bf5 100644 --- a/src/ui/src/components/visualizer/webgl_renderer_edge_texts_service.ts +++ b/src/ui/src/components/visualizer/webgl_renderer_edge_texts_service.ts @@ -62,9 +62,13 @@ export class WebglRendererEdgeTextsService { genLabelsOnEdges( edges: Array<{index: number; edge: ModelEdge}>, color: three.Color, + extraOffsetToEdge = 0, + y = 95, + fontSize?: number, ): LabelData[] { const edgeLabelFontSize = - this.appService.config()?.edgeLabelFontSize || + fontSize ?? + this.appService.config()?.edgeLabelFontSize ?? DEFAULT_EDGE_LABEL_FONT_SIZE; const disallowVerticalEdgeLabels = this.appService.config()?.disallowVerticalEdgeLabels || false; @@ -80,32 +84,39 @@ export class WebglRendererEdgeTextsService { } // Find the tensor shape. - let tensorShape = '?'; - const outputsMetadata = fromNode.outputsMetadata || {}; - for (const outputId of Object.keys(outputsMetadata)) { - const outgoingEdge = (fromNode.outgoingEdges || []).find( - (curEdge) => - curEdge.sourceNodeOutputId === outputId && - curEdge.targetNodeId === edge.toNodeId, - ); - if (outgoingEdge != null) { - tensorShape = outputsMetadata[outputId]['shape'] || '?'; - tensorShape = tensorShape - .split('') - .map((char) => { - if (char === 'x') { - char = 'x'; - } - if (char === '∗') { - char = '*'; - } - if (char === '') { - char = ''; - } - return charsInfo[char] == null ? '?' : char; - }) - .join(''); - break; + let edgeLabel = '?'; + if (edge.label != null) { + edgeLabel = edge.label; + if (edgeLabel === '') { + continue; + } + } else { + const outputsMetadata = fromNode.outputsMetadata || {}; + for (const outputId of Object.keys(outputsMetadata)) { + const outgoingEdge = (fromNode.outgoingEdges || []).find( + (curEdge) => + curEdge.sourceNodeOutputId === outputId && + curEdge.targetNodeId === edge.toNodeId, + ); + if (outgoingEdge != null) { + edgeLabel = outputsMetadata[outputId]['shape'] || '?'; + edgeLabel = edgeLabel + .split('') + .map((char) => { + if (char === 'x') { + char = 'x'; + } + if (char === '∗') { + char = '*'; + } + if (char === '') { + char = ''; + } + return charsInfo[char] == null ? '?' : char; + }) + .join(''); + break; + } } } @@ -137,20 +148,25 @@ export class WebglRendererEdgeTextsService { // Use '3' to take some padding into account when calculating text length. const curveLength = curvePath.getLength(); const space = edgeLabelFontSize / 2 / curveLength; - const textLongerThanCurve = space * (tensorShape.length + 3) > 1; + const textLongerThanCurve = space * (edgeLabel.length + 3) > 1; const renderWholeTextFn = () => { const pos = curvePath.getPointAt(0.5) as three.Vector2; + const posX = pos.x; + const posY = + curvePoints[0].y === curvePoints[curvePoints.length - 1].y + ? pos.y - 10 - extraOffsetToEdge + : pos.y; labels.push({ - id: `${edge.id}_${tensorShape}`, + id: `${edge.id}_${edgeLabel}`, nodeId: edge.toNodeId, - label: tensorShape, + label: edgeLabel, height: edgeLabelFontSize, hAlign: 'center', vAlign: 'center', weight: FontWeight.MEDIUM, - x: pos.x, - y: 95, - z: pos.y, + x: posX, + y, + z: posY, color, borderColor: {r: 1, g: 1, b: 1}, }); @@ -176,12 +192,12 @@ export class WebglRendererEdgeTextsService { const startPosition = Math.max( 0, // 5 is the estimated height of the arrow head. - Math.min(0.25, 1 - tensorShape.length * space - 5 / curveLength), + Math.min(0.25, 1 - edgeLabel.length * space - 5 / curveLength), ); const maxOffset = Math.max( 0.05, // 5 is the estimated height of the arrow head. - 1 - 5 / curveLength - startPosition - space * tensorShape.length, + 1 - 5 / curveLength - startPosition - space * edgeLabel.length, ); // const step = 10 / curveLength; const step = 0.05; @@ -193,8 +209,8 @@ export class WebglRendererEdgeTextsService { let prevAngle: number | undefined = undefined; charInfoList = []; let curPosition = curStartPosition; - for (let i = 0; i < tensorShape.length; i++) { - const char = tensorShape[i]; + for (let i = 0; i < edgeLabel.length; i++) { + const char = edgeLabel[i]; const pos = curvePath.getPointAt( Math.min(curPosition, 1), ) as three.Vector2; @@ -237,8 +253,8 @@ export class WebglRendererEdgeTextsService { const charInfo = charsInfo[char]; let nextCharXadvance = 0; - if (i !== tensorShape.length - 1) { - const nextChar = tensorShape[i + 1]; + if (i !== edgeLabel.length - 1) { + const nextChar = edgeLabel[i + 1]; nextCharXadvance = charsInfo[nextChar].xadvance; } const delta = @@ -271,8 +287,8 @@ export class WebglRendererEdgeTextsService { char: string; }> = []; let curPosition = charInfoList[0].position; - for (let i = tensorShape.length - 1; i >= 0; i--) { - const char = tensorShape[i]; + for (let i = edgeLabel.length - 1; i >= 0; i--) { + const char = edgeLabel[i]; const pos = curvePath.getPointAt( Math.min(1, curPosition), ) as three.Vector2; @@ -294,7 +310,7 @@ export class WebglRendererEdgeTextsService { const charInfo = charsInfo[char]; let nextCharXadvance = 0; if (i >= 1) { - const nextCharInfo = charsInfo[tensorShape[i - 1]]; + const nextCharInfo = charsInfo[edgeLabel[i - 1]]; nextCharXadvance = nextCharInfo.xadvance; } const delta = @@ -323,9 +339,15 @@ export class WebglRendererEdgeTextsService { hAlign: '', vAlign: '', weight: FontWeight.MEDIUM, - x: pos.x + Math.sin(angle) * (-edgeLabelFontSize * 1.5), - y: 95, - z: pos.y + Math.cos(angle) * (-edgeLabelFontSize * 1.5), + x: + pos.x + + Math.sin(angle) * + (-edgeLabelFontSize * 1.5 - extraOffsetToEdge), + y, + z: + pos.y + + Math.cos(angle) * + (-edgeLabelFontSize * 1.5 - extraOffsetToEdge), color, angle, edgeTextMode: true,