Skip to content

Commit

Permalink
Merge pull request optuna#792 from knshnb/edf-backend
Browse files Browse the repository at this point in the history
Render EDF in backend
  • Loading branch information
c-bata authored Feb 6, 2024
2 parents e5c5ae2 + b5c1260 commit cfadf70
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 0 deletions.
15 changes: 15 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,21 @@ def get_plot(study_id: int, plot_type: str) -> dict[str, Any]:
return {"reason": f"plot_type={plot_type} is not supported."}
return fig.to_json()

@app.get("/api/compare-studies/plot/<plot_type>")
@json_api_view
def get_compare_studies_plot(plot_type: str) -> dict[str, Any]:
study_ids = map(int, request.query.getall("study_ids[]"))
studies = [
optuna.load_study(study_name=storage.get_study_name_from_id(study_id), storage=storage)
for study_id in study_ids
]
if plot_type == "edf":
fig = optuna.visualization.plot_edf(studies)
else:
response.status = 404 # Not found
return {"reason": f"plot_type={plot_type} is not supported."}
return fig.to_json()

@app.put("/api/studies/<study_id:int>/note")
@json_api_view
def save_study_note(study_id: int) -> dict[str, Any]:
Expand Down
14 changes: 14 additions & 0 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -461,3 +461,17 @@ export const getPlotAPI = (
.get<PlotResponse>(`/api/studies/${studyId}/plot/${plotType}`)
.then<PlotResponse>((res) => res.data)
}

export enum CompareStudiesPlotType {
EDF = "edf",
}
export const getCompareStudiesPlotAPI = (
studyIds: number[],
plotType: CompareStudiesPlotType
): Promise<PlotResponse> => {
return axiosInstance
.get<PlotResponse>(`/api/compare-studies/plot/${plotType}`, {
params: { study_ids: studyIds },
})
.then<PlotResponse>((res) => res.data)
}
38 changes: 38 additions & 0 deletions optuna_dashboard/ts/components/GraphEdf.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import React, { FC, useEffect, useMemo } from "react"
import { Typography, useTheme, Box } from "@mui/material"
import { plotlyDarkTemplate } from "./PlotlyDarkMode"
import { Target, useFilteredTrialsFromStudies } from "../trialFilter"
import { getCompareStudiesPlotAPI, CompareStudiesPlotType } from "../apiClient"
import { useBackendRender } from "../state"

const getPlotDomId = (objectiveId: number) => `graph-edf-${objectiveId}`

Expand All @@ -14,6 +16,42 @@ interface EdfPlotInfo {
export const GraphEdf: FC<{
studies: StudyDetail[]
objectiveId: number
}> = ({ studies, objectiveId }) => {
if (useBackendRender()) {
return <GraphEdfBackend studies={studies} />
} else {
return <GraphEdfFrontend studies={studies} objectiveId={objectiveId} />
}
}

const GraphEdfBackend: FC<{
studies: StudyDetail[]
}> = ({ studies }) => {
const studyIds = studies.map((s) => s.id)
const domId = getPlotDomId(-1)
const numCompletedTrials = studies.reduce(
(acc, study) =>
acc + study?.trials.filter((t) => t.state === "Complete").length,
0
)
useEffect(() => {
if (studyIds.length === 0) {
return
}
getCompareStudiesPlotAPI(studyIds, CompareStudiesPlotType.EDF)
.then(({ data, layout }) => {
plotly.react(domId, data, layout)
})
.catch((err) => {
console.error(err)
})
}, [studyIds, numCompletedTrials])
return <Box id={domId} sx={{ height: "450px" }} />
}

const GraphEdfFrontend: FC<{
studies: StudyDetail[]
objectiveId: number
}> = ({ studies, objectiveId }) => {
const theme = useTheme()
const domId = getPlotDomId(objectiveId)
Expand Down

0 comments on commit cfadf70

Please sign in to comment.