diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index dfa7338c5..ac0e1cf0e 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -288,6 +288,21 @@ def get_plot(study_id: int, plot_type: str) -> dict[str, Any]: return {"reason": f"plot_type={plot_type} is not supported."} return fig.to_json() + @app.get("/api/compare-studies/plot/") + @json_api_view + def get_compare_studies_plot(plot_type: str) -> dict[str, Any]: + study_ids = map(int, request.query.getall("study_ids[]")) + studies = [ + optuna.load_study(study_name=storage.get_study_name_from_id(study_id), storage=storage) + for study_id in study_ids + ] + if plot_type == "edf": + fig = optuna.visualization.plot_edf(studies) + else: + response.status = 404 # Not found + return {"reason": f"plot_type={plot_type} is not supported."} + return fig.to_json() + @app.put("/api/studies//note") @json_api_view def save_study_note(study_id: int) -> dict[str, Any]: diff --git a/optuna_dashboard/ts/apiClient.ts b/optuna_dashboard/ts/apiClient.ts index 46fe3d0e4..14566a07b 100644 --- a/optuna_dashboard/ts/apiClient.ts +++ b/optuna_dashboard/ts/apiClient.ts @@ -461,3 +461,17 @@ export const getPlotAPI = ( .get(`/api/studies/${studyId}/plot/${plotType}`) .then((res) => res.data) } + +export enum CompareStudiesPlotType { + EDF = "edf", +} +export const getCompareStudiesPlotAPI = ( + studyIds: number[], + plotType: CompareStudiesPlotType +): Promise => { + return axiosInstance + .get(`/api/compare-studies/plot/${plotType}`, { + params: { study_ids: studyIds }, + }) + .then((res) => res.data) +} diff --git a/optuna_dashboard/ts/components/GraphEdf.tsx b/optuna_dashboard/ts/components/GraphEdf.tsx index ca85ee54c..dffa72745 100644 --- a/optuna_dashboard/ts/components/GraphEdf.tsx +++ b/optuna_dashboard/ts/components/GraphEdf.tsx @@ -3,6 +3,8 @@ import React, { FC, useEffect, useMemo } from "react" import { Typography, useTheme, Box } from "@mui/material" import { plotlyDarkTemplate } from "./PlotlyDarkMode" import { Target, useFilteredTrialsFromStudies } from "../trialFilter" +import { getCompareStudiesPlotAPI, CompareStudiesPlotType } from "../apiClient" +import { useBackendRender } from "../state" const getPlotDomId = (objectiveId: number) => `graph-edf-${objectiveId}` @@ -14,6 +16,42 @@ interface EdfPlotInfo { export const GraphEdf: FC<{ studies: StudyDetail[] objectiveId: number +}> = ({ studies, objectiveId }) => { + if (useBackendRender()) { + return + } else { + return + } +} + +const GraphEdfBackend: FC<{ + studies: StudyDetail[] +}> = ({ studies }) => { + const studyIds = studies.map((s) => s.id) + const domId = getPlotDomId(-1) + const numCompletedTrials = studies.reduce( + (acc, study) => + acc + study?.trials.filter((t) => t.state === "Complete").length, + 0 + ) + useEffect(() => { + if (studyIds.length === 0) { + return + } + getCompareStudiesPlotAPI(studyIds, CompareStudiesPlotType.EDF) + .then(({ data, layout }) => { + plotly.react(domId, data, layout) + }) + .catch((err) => { + console.error(err) + }) + }, [studyIds, numCompletedTrials]) + return +} + +const GraphEdfFrontend: FC<{ + studies: StudyDetail[] + objectiveId: number }> = ({ studies, objectiveId }) => { const theme = useTheme() const domId = getPlotDomId(objectiveId)