diff --git a/optuna_dashboard/__init__.py b/optuna_dashboard/__init__.py index 25bcaf783..45d519a31 100644 --- a/optuna_dashboard/__init__.py +++ b/optuna_dashboard/__init__.py @@ -2,6 +2,11 @@ from ._app import wsgi # noqa from ._named_objectives import set_objective_names # noqa from ._note import save_note # noqa +from ._objective_form_widget import ObjectiveChoiceWidget # noqa +from ._objective_form_widget import ObjectiveSliderWidget # noqa +from ._objective_form_widget import ObjectiveTextInputWidget # noqa +from ._objective_form_widget import ObjectiveUserAttrRef # noqa +from ._objective_form_widget import register_objective_form_widgets # noqa __version__ = "0.9.0b2" diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py new file mode 100644 index 000000000..e4e100752 --- /dev/null +++ b/optuna_dashboard/_objective_form_widget.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from dataclasses import dataclass +import json +from typing import TYPE_CHECKING +from typing import Union + +import optuna + + +if TYPE_CHECKING: + from typing import Any + from typing import Literal + from typing import Optional + from typing import TypedDict + + ChoiceWidgetJSON = TypedDict( + "ChoiceWidgetJSON", + { + "type": Literal["choice"], + "description": Optional[str], + "choices": list[str], + "values": list[float], + }, + ) + SliderWidgetLabel = TypedDict( + "SliderWidgetLabel", + {"value": float, "label": str}, + ) + SliderWidgetJSON = TypedDict( + "SliderWidgetJSON", + { + "type": Literal["slider"], + "description": Optional[str], + "min": float, + "max": float, + "step": Optional[float], + "labels": Optional[list[SliderWidgetLabel]], + }, + ) + TextInputWidgetJSON = TypedDict( + "TextInputWidgetJSON", + {"type": Literal["text"], "description": Optional[str]}, + ) + UserAttrRefJSON = TypedDict("UserAttrRefJSON", {"type": Literal["user_attr"], "key": str}) + ObjectiveFormWidgetJSON = Union[ + ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON + ] + + +@dataclass +class ObjectiveChoiceWidget: + choices: list[str] + values: list[float] + description: Optional[str] = None + + def to_dict(self) -> ChoiceWidgetJSON: + return { + "type": "choice", + "description": self.description, + "choices": self.choices, + "values": self.values, + } + + +@dataclass +class ObjectiveSliderWidget: + min: float + max: float + step: Optional[float] = None + labels: Optional[list[tuple[float, str]]] = None + description: Optional[str] = None + + def to_dict(self) -> SliderWidgetJSON: + labels: Optional[list[SliderWidgetLabel]] = None + if self.labels is not None: + labels = [{"value": value, "label": label} for value, label in self.labels] + return { + "type": "slider", + "description": self.description, + "min": self.min, + "max": self.max, + "step": self.step, + "labels": labels, + } + + +@dataclass +class ObjectiveTextInputWidget: + description: Optional[str] = None + + def to_dict(self) -> TextInputWidgetJSON: + return { + "type": "text", + "description": self.description, + } + + +@dataclass +class ObjectiveUserAttrRef: + key: str + + def to_dict(self) -> UserAttrRefJSON: + return { + "type": "user_attr", + "key": self.key, + } + + +ObjectiveFormWidget = Union[ + ObjectiveChoiceWidget, ObjectiveSliderWidget, ObjectiveTextInputWidget, ObjectiveUserAttrRef +] +SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets" + + +def register_objective_form_widgets( + study: optuna.Study, widgets: list[ObjectiveFormWidget] +) -> None: + if len(study.directions) != len(widgets): + raise ValueError("The length of actions must be the same with the number of objectives.") + widget_dicts = [w.to_dict() for w in widgets] + study._storage.set_study_system_attr( + study._study_id, SYSTEM_ATTR_KEY, json.dumps(widget_dicts) + ) + + +def get_objective_form_widgets_json( + study_system_attr: dict[str, Any] +) -> Optional[list[ObjectiveFormWidgetJSON]]: + widgets_json = study_system_attr.get(SYSTEM_ATTR_KEY) + if widgets_json is None: + return None + return json.loads(widgets_json) diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index 11d3883f9..021705703 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -13,6 +13,7 @@ from . import _note as note from ._named_objectives import get_objective_names +from ._objective_form_widget import get_objective_form_widgets_json from .artifact._backend import list_trial_artifacts @@ -143,6 +144,9 @@ def serialize_study_detail( objective_names = get_objective_names(system_attrs) if objective_names: serialized["objective_names"] = objective_names + objective_form_widgets = get_objective_form_widgets_json(system_attrs) + if objective_form_widgets: + serialized["objective_form_widgets"] = objective_form_widgets return serialized diff --git a/optuna_dashboard/ts/action.ts b/optuna_dashboard/ts/action.ts index 41ebff59b..f7d182071 100644 --- a/optuna_dashboard/ts/action.ts +++ b/optuna_dashboard/ts/action.ts @@ -394,12 +394,12 @@ export const actionCreator = () => { trialId: number, state: TrialStateFinished, values?: number[] - ): Promise => { + ): void => { const message = values === undefined ? `id=${trialId}, state=${state}` : `id=${trialId}, state=${state}, values=${values}` - return tellTrialAPI(trialId, state, values) + tellTrialAPI(trialId, state, values) .then(() => { const index = studyDetails[studyId].trials.findIndex( (t) => t.trial_id === trialId diff --git a/optuna_dashboard/ts/apiClient.ts b/optuna_dashboard/ts/apiClient.ts index cb24f1560..e263117c7 100644 --- a/optuna_dashboard/ts/apiClient.ts +++ b/optuna_dashboard/ts/apiClient.ts @@ -67,6 +67,7 @@ interface StudyDetailResponse { has_intermediate_values: boolean note: Note objective_names?: string[] + objective_form_widgets?: ObjectiveFormWidget[] } export const getStudyDetailAPI = ( @@ -99,6 +100,7 @@ export const getStudyDetailAPI = ( has_intermediate_values: res.data.has_intermediate_values, note: res.data.note, objective_names: res.data.objective_names, + objective_form_widgets: res.data.objective_form_widgets, } }) } diff --git a/optuna_dashboard/ts/components/Debounce.tsx b/optuna_dashboard/ts/components/Debounce.tsx index e08d251e5..31e34739a 100644 --- a/optuna_dashboard/ts/components/Debounce.tsx +++ b/optuna_dashboard/ts/components/Debounce.tsx @@ -2,14 +2,15 @@ import React, { FC, useEffect } from "react" import { TextField, TextFieldProps } from "@mui/material" export const DebouncedInputTextField: FC<{ - onChange: (s: string) => void + onChange: (s: string, valid: boolean) => void delay: number textFieldProps: TextFieldProps }> = ({ onChange, delay, textFieldProps }) => { const [text, setText] = React.useState("") + const [valid, setValidity] = React.useState(true) useEffect(() => { const timer = setTimeout(() => { - onChange(text) + onChange(text, valid) }, delay) return () => { clearTimeout(timer) @@ -19,6 +20,7 @@ export const DebouncedInputTextField: FC<{ { setText(e.target.value) + setValidity(e.target.validity.valid) }} {...textFieldProps} /> diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx new file mode 100644 index 000000000..71c62e523 --- /dev/null +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -0,0 +1,261 @@ +import React, { FC, useMemo, useState } from "react" +import { + Typography, + Box, + useTheme, + Card, + FormControlLabel, + FormControl, + FormLabel, + Button, + RadioGroup, + Radio, + Slider, + TextField, +} from "@mui/material" +import { DebouncedInputTextField } from "./Debounce" +import { actionCreator } from "../action" + +export const ObjectiveForm: FC<{ + trial: Trial + directions: StudyDirection[] + names: string[] + widgets: ObjectiveFormWidget[] +}> = ({ trial, directions, names, widgets }) => { + const theme = useTheme() + const action = actionCreator() + const [values, setValues] = useState<(number | null)[]>( + directions.map((d, i) => { + const widget = widgets.at(i) + if (widget === undefined) { + return null + } else if (widget.type === "text") { + return null + } else if (widget.type === "choice") { + return widget.values.at(0) || null + } else if (widget.type === "slider") { + return widget.min + } else if (widget.type === "user_attr") { + const attr = trial.user_attrs.find((attr) => attr.key == widget.key) + if (attr === undefined) { + return null + } else { + const n = Number(attr.value) + return isNaN(n) ? null : n + } + } else { + return null + } + }) + ) + + const setValue = (objectiveId: number, value: number | null) => { + const newValues = [...values] + if (newValues.length <= objectiveId) { + return + } + newValues[objectiveId] = value + setValues(newValues) + } + + const disableSubmit = useMemo( + () => values.findIndex((v) => v === null) >= 0, + [values] + ) + + const handleSubmit = (e: React.MouseEvent): void => { + e.preventDefault() + const filtered = values.filter((v): v is number => v !== null) + if (filtered.length !== directions.length) { + return + } + action.tellTrial(trial.study_id, trial.trial_id, "Complete", filtered) + } + + const getObjectiveName = (i: number): string => { + const n = names.at(i) + if (n !== undefined) { + return n + } + if (directions.length == 1) { + return `Objective` + } else { + return `Objective ${i}` + } + } + + return ( + <> + + {directions.length > 1 ? "Set Objective Values" : "Set Objective Value"} + + + + {directions.map((d, i) => { + const widget = widgets.at(i) + const value = values.at(i) + const key = `objective-${i}` + if (widget === undefined) { + return ( + + {getObjectiveName(i)} + { + const n = Number(s) + if (s.length > 0 && valid && !isNaN(n)) { + setValue(i, n) + return + } else if (values.at(i) !== null) { + setValue(i, null) + } + }} + delay={500} + textFieldProps={{ + required: true, + autoFocus: true, + fullWidth: true, + helperText: + value === null || value === undefined + ? `Please input the float number.` + : "", + label: getObjectiveName(i), + type: "text", + }} + /> + + ) + } else if (widget.type === "text") { + return ( + + + {getObjectiveName(i)} - {widget.description} + + { + const n = Number(s) + if (s.length > 0 && valid && !isNaN(n)) { + setValue(i, n) + return + } else if (values.at(i) !== null) { + setValue(i, null) + } + }} + delay={500} + textFieldProps={{ + required: true, + autoFocus: true, + fullWidth: true, + helperText: + value === null || value === undefined + ? `Please input the float number.` + : "", + type: "text", + inputProps: { + pattern: "[-+]?[0-9]*.?[0-9]+([eE][-+]?[0-9]+)?", + }, + }} + /> + + ) + } else if (widget.type === "choice") { + return ( + + + {getObjectiveName(i)} - {widget.description} + + + {widget.choices.map((c, i) => ( + } + label={c} + /> + ))} + + + ) + } else if (widget.type === "slider") { + return ( + + + {getObjectiveName(i)} - {widget.description} + + + + + + ) + } else if (widget.type === "user_attr") { + return ( + + {getObjectiveName(i)} + + + ) + } + return null + })} + + + + + + + + + ) +} diff --git a/optuna_dashboard/ts/components/TrialList.tsx b/optuna_dashboard/ts/components/TrialList.tsx index 6e3dec57d..070fe2dfa 100644 --- a/optuna_dashboard/ts/components/TrialList.tsx +++ b/optuna_dashboard/ts/components/TrialList.tsx @@ -42,6 +42,7 @@ import { useRecoilValue } from "recoil" import { artifactIsAvailable } from "../state" import { actionCreator } from "../action" import { useDeleteArtifactDialog } from "./DeleteArtifactDialog" +import { ObjectiveForm } from "./ObjectiveForm" const states: TrialState[] = [ "Complete", @@ -143,15 +144,20 @@ const useIsBestTrial = ( const TrialListDetail: FC<{ trial: Trial isBestTrial: (trialId: number) => boolean -}> = ({ trial, isBestTrial }) => { + directions: StudyDirection[] + objectiveNames: string[] + objectiveFormWidgets: ObjectiveFormWidget[] +}> = ({ + trial, + isBestTrial, + directions, + objectiveNames, + objectiveFormWidgets, +}) => { const theme = useTheme() const artifactEnabled = useRecoilValue(artifactIsAvailable) const startMs = trial.datetime_start?.getTime() const completeMs = trial.datetime_complete?.getTime() - let duration = "" - if (startMs !== undefined && completeMs !== undefined) { - duration = (completeMs - startMs).toString() - } const params = trial.state === "Waiting" ? trial.fixed_params : trial.params const info: [string, string | null | ReactNode][] = [ @@ -184,7 +190,12 @@ const TrialListDetail: FC<{ "Completed At", trial?.datetime_complete ? trial?.datetime_complete.toString() : null, ], - ["Duration", `${duration} ms`], + [ + "Duration (ms)", + startMs !== undefined && completeMs !== undefined + ? (completeMs - startMs).toString() + : null, + ], [ "User Attributes", @@ -264,7 +275,9 @@ const TrialListDetail: FC<{ flexDirection: "column", }} > - {info.map(([key, value]) => renderInfo(key, value))} + {info.map(([key, value]) => + value !== null ? renderInfo(key, value) : null + )} + {trial.state === "Running" && directions.length > 0 && ( + + )} {artifactEnabled && } ) @@ -716,6 +737,11 @@ export const TrialList: FC<{ studyDetail: StudyDetail | null }> = ({ key={t.trial_id} trial={t} isBestTrial={isBestTrial} + directions={studyDetail?.directions || []} + objectiveNames={studyDetail?.objective_names || []} + objectiveFormWidgets={ + studyDetail?.objective_form_widgets || [] + } /> ))} diff --git a/optuna_dashboard/ts/types/index.d.ts b/optuna_dashboard/ts/types/index.d.ts index 20c9fd486..1d30bc574 100644 --- a/optuna_dashboard/ts/types/index.d.ts +++ b/optuna_dashboard/ts/types/index.d.ts @@ -125,6 +125,43 @@ type StudySummary = { datetime_start?: Date } +type ObjectiveChoiceWidget = { + type: "choice" + description: string + choices: string[] + values: number[] +} + +type ObjectiveSliderWidget = { + type: "slider" + description: string + min: number + max: number + step: number + labels: + | { + value: number + label: string + }[] + | null +} + +type ObjectiveTextInputWidget = { + type: "text" + description: string +} + +type ObjectiveUserAttrRef = { + type: "user_attr" + key: string +} + +type ObjectiveFormWidget = + | ObjectiveChoiceWidget + | ObjectiveSliderWidget + | ObjectiveTextInputWidget + | ObjectiveUserAttrRef + type StudyDetail = { id: number name: string @@ -138,6 +175,7 @@ type StudyDetail = { has_intermediate_values: boolean note: Note objective_names?: string[] + objective_form_widgets?: ObjectiveFormWidget[] } type StudyDetails = {