diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index ac0e1cf0e..1d4c9f7c0 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -283,6 +283,12 @@ def get_plot(study_id: int, plot_type: str) -> dict[str, Any]: fig = optuna.visualization.plot_rank(study) elif plot_type == "edf": fig = optuna.visualization.plot_edf(study) + elif plot_type == "timeline": + fig = optuna.visualization.plot_timeline(study) + elif plot_type == "param_importances": + fig = optuna.visualization.plot_param_importances(study) + elif plot_type == "pareto_front": + fig = optuna.visualization.plot_pareto_front(study) else: response.status = 404 # Not found return {"reason": f"plot_type={plot_type} is not supported."} diff --git a/optuna_dashboard/ts/apiClient.ts b/optuna_dashboard/ts/apiClient.ts index 14566a07b..cf775fd61 100644 --- a/optuna_dashboard/ts/apiClient.ts +++ b/optuna_dashboard/ts/apiClient.ts @@ -452,6 +452,9 @@ export enum PlotType { ParallelCoordinate = "parallel_coordinate", Rank = "rank", EDF = "edf", + Timeline = "timeline", + ParamImportances = "param_importances", + ParetoFront = "pareto_front", } export const getPlotAPI = ( studyId: number, diff --git a/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx b/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx index 88cce45f8..12ef8d2c9 100644 --- a/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx +++ b/optuna_dashboard/ts/components/GraphHyperparameterImportances.tsx @@ -3,7 +3,13 @@ import React, { FC, useEffect } from "react" import { Typography, useTheme, Box, Card, CardContent } from "@mui/material" import { useParamImportance } from "../hooks/useParamImportance" -import { useStudyDirections, usePlotlyColorTheme } from "../state" +import { + useStudyDirections, + usePlotlyColorTheme, + useBackendRender, +} from "../state" +import { PlotType } from "../apiClient" +import { usePlot } from "../hooks/usePlot" const plotDomId = "graph-hyperparameter-importances" @@ -11,6 +17,57 @@ export const GraphHyperparameterImportance: FC<{ studyId: number study: StudyDetail | null graphHeight: string +}> = ({ studyId, study = null, graphHeight }) => { + if (useBackendRender()) { + return ( + + ) + } else { + return ( + + ) + } +} + +const GraphHyperparameterImportanceBackend: FC<{ + studyId: number + study: StudyDetail | null + graphHeight: string +}> = ({ studyId, study = null, graphHeight }) => { + const numCompletedTrials = + study?.trials.filter((t) => t.state === "Complete").length || 0 + const { data, layout, error } = usePlot({ + numCompletedTrials, + studyId, + plotType: PlotType.ParamImportances, + }) + + useEffect(() => { + if (data && layout) { + plotly.react(plotDomId, data, layout) + } + }, [data, layout]) + useEffect(() => { + if (error) { + console.error(error) + } + }, [error]) + + return +} + +const GraphHyperparameterImportanceFrontend: FC<{ + studyId: number + study: StudyDetail | null + graphHeight: string }> = ({ studyId, study = null, graphHeight }) => { const theme = useTheme() const colorTheme = usePlotlyColorTheme(theme.palette.mode) diff --git a/optuna_dashboard/ts/components/GraphParetoFront.tsx b/optuna_dashboard/ts/components/GraphParetoFront.tsx index 61ad7b9e3..9ddec25be 100644 --- a/optuna_dashboard/ts/components/GraphParetoFront.tsx +++ b/optuna_dashboard/ts/components/GraphParetoFront.tsx @@ -14,11 +14,50 @@ import { import { makeHovertext } from "../graphUtil" import { usePlotlyColorTheme } from "../state" import { useNavigate } from "react-router-dom" +import { PlotType } from "../apiClient" +import { useBackendRender } from "../state" +import { usePlot } from "../hooks/usePlot" const plotDomId = "graph-pareto-front" export const GraphParetoFront: FC<{ study: StudyDetail | null +}> = ({ study = null }) => { + if (useBackendRender()) { + return + } else { + return + } +} + +const GraphParetoFrontBackend: FC<{ + study: StudyDetail | null +}> = ({ study = null }) => { + const studyId = study?.id + const numCompletedTrials = + study?.trials.filter((t) => t.state === "Complete").length || 0 + const { data, layout, error } = usePlot({ + numCompletedTrials, + studyId, + plotType: PlotType.ParetoFront, + }) + + useEffect(() => { + if (data && layout) { + plotly.react(plotDomId, data, layout) + } + }, [data, layout]) + useEffect(() => { + if (error) { + console.error(error) + } + }, [error]) + + return +} + +const GraphParetoFrontFrontend: FC<{ + study: StudyDetail | null }> = ({ study = null }) => { const theme = useTheme() const colorTheme = usePlotlyColorTheme(theme.palette.mode) diff --git a/optuna_dashboard/ts/components/GraphTimeline.tsx b/optuna_dashboard/ts/components/GraphTimeline.tsx index e9de2c714..519399bb2 100644 --- a/optuna_dashboard/ts/components/GraphTimeline.tsx +++ b/optuna_dashboard/ts/components/GraphTimeline.tsx @@ -3,12 +3,55 @@ import React, { FC, useEffect } from "react" import { Card, CardContent, Grid, Typography, useTheme } from "@mui/material" import { makeHovertext } from "../graphUtil" import { usePlotlyColorTheme } from "../state" +import { PlotType } from "../apiClient" +import { useBackendRender } from "../state" +import { usePlot } from "../hooks/usePlot" const plotDomId = "graph-timeline" const maxBars = 100 export const GraphTimeline: FC<{ study: StudyDetail | null +}> = ({ study }) => { + if (useBackendRender()) { + return + } else { + return + } +} + +const GraphTimelineBackend: FC<{ + study: StudyDetail | null +}> = ({ study }) => { + const studyId = study?.id + const numCompletedTrials = + study?.trials.filter((t) => t.state === "Complete").length || 0 + const { data, layout, error } = usePlot({ + numCompletedTrials, + studyId, + plotType: PlotType.Timeline, + }) + + useEffect(() => { + if (data && layout) { + plotly.react(plotDomId, data, layout) + } + }, [data, layout]) + useEffect(() => { + if (error) { + console.error(error) + } + }, [error]) + + return ( + +
+ + ) +} + +const GraphTimelineFrontend: FC<{ + study: StudyDetail | null }> = ({ study }) => { const theme = useTheme() const colorTheme = usePlotlyColorTheme(theme.palette.mode)