diff --git a/optuna_dashboard/_objective_form_widget.py b/optuna_dashboard/_objective_form_widget.py index 0b89e6186..e4e100752 100644 --- a/optuna_dashboard/_objective_form_widget.py +++ b/optuna_dashboard/_objective_form_widget.py @@ -23,8 +23,8 @@ "values": list[float], }, ) - SliderWidgetLabels = TypedDict( - "SliderWidgetLabels", + SliderWidgetLabel = TypedDict( + "SliderWidgetLabel", {"value": float, "label": str}, ) SliderWidgetJSON = TypedDict( @@ -35,7 +35,7 @@ "min": float, "max": float, "step": Optional[float], - "labels": Optional[list[SliderWidgetLabels]], + "labels": Optional[list[SliderWidgetLabel]], }, ) TextInputWidgetJSON = TypedDict( @@ -67,18 +67,21 @@ def to_dict(self) -> ChoiceWidgetJSON: class ObjectiveSliderWidget: min: float max: float - step: Optional[float] - labels: Optional[list[tuple[float, str]]] + step: Optional[float] = None + labels: Optional[list[tuple[float, str]]] = None description: Optional[str] = None def to_dict(self) -> SliderWidgetJSON: + labels: Optional[list[SliderWidgetLabel]] = None + if self.labels is not None: + labels = [{"value": value, "label": label} for value, label in self.labels] return { "type": "slider", "description": self.description, "min": self.min, "max": self.max, "step": self.step, - "labels": [{"value": value, "label": label} for value, label in self.labels], + "labels": labels, } @@ -110,7 +113,9 @@ def to_dict(self) -> UserAttrRefJSON: SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets" -def register_objective_form_widgets(study: optuna.Study, widgets: list[ObjectiveFormWidget]): +def register_objective_form_widgets( + study: optuna.Study, widgets: list[ObjectiveFormWidget] +) -> None: if len(study.directions) != len(widgets): raise ValueError("The length of actions must be the same with the number of objectives.") widget_dicts = [w.to_dict() for w in widgets] diff --git a/optuna_dashboard/ts/components/ObjectiveForm.tsx b/optuna_dashboard/ts/components/ObjectiveForm.tsx new file mode 100644 index 000000000..71c62e523 --- /dev/null +++ b/optuna_dashboard/ts/components/ObjectiveForm.tsx @@ -0,0 +1,261 @@ +import React, { FC, useMemo, useState } from "react" +import { + Typography, + Box, + useTheme, + Card, + FormControlLabel, + FormControl, + FormLabel, + Button, + RadioGroup, + Radio, + Slider, + TextField, +} from "@mui/material" +import { DebouncedInputTextField } from "./Debounce" +import { actionCreator } from "../action" + +export const ObjectiveForm: FC<{ + trial: Trial + directions: StudyDirection[] + names: string[] + widgets: ObjectiveFormWidget[] +}> = ({ trial, directions, names, widgets }) => { + const theme = useTheme() + const action = actionCreator() + const [values, setValues] = useState<(number | null)[]>( + directions.map((d, i) => { + const widget = widgets.at(i) + if (widget === undefined) { + return null + } else if (widget.type === "text") { + return null + } else if (widget.type === "choice") { + return widget.values.at(0) || null + } else if (widget.type === "slider") { + return widget.min + } else if (widget.type === "user_attr") { + const attr = trial.user_attrs.find((attr) => attr.key == widget.key) + if (attr === undefined) { + return null + } else { + const n = Number(attr.value) + return isNaN(n) ? null : n + } + } else { + return null + } + }) + ) + + const setValue = (objectiveId: number, value: number | null) => { + const newValues = [...values] + if (newValues.length <= objectiveId) { + return + } + newValues[objectiveId] = value + setValues(newValues) + } + + const disableSubmit = useMemo( + () => values.findIndex((v) => v === null) >= 0, + [values] + ) + + const handleSubmit = (e: React.MouseEvent): void => { + e.preventDefault() + const filtered = values.filter((v): v is number => v !== null) + if (filtered.length !== directions.length) { + return + } + action.tellTrial(trial.study_id, trial.trial_id, "Complete", filtered) + } + + const getObjectiveName = (i: number): string => { + const n = names.at(i) + if (n !== undefined) { + return n + } + if (directions.length == 1) { + return `Objective` + } else { + return `Objective ${i}` + } + } + + return ( + <> + + {directions.length > 1 ? "Set Objective Values" : "Set Objective Value"} + + + + {directions.map((d, i) => { + const widget = widgets.at(i) + const value = values.at(i) + const key = `objective-${i}` + if (widget === undefined) { + return ( + + {getObjectiveName(i)} + { + const n = Number(s) + if (s.length > 0 && valid && !isNaN(n)) { + setValue(i, n) + return + } else if (values.at(i) !== null) { + setValue(i, null) + } + }} + delay={500} + textFieldProps={{ + required: true, + autoFocus: true, + fullWidth: true, + helperText: + value === null || value === undefined + ? `Please input the float number.` + : "", + label: getObjectiveName(i), + type: "text", + }} + /> + + ) + } else if (widget.type === "text") { + return ( + + + {getObjectiveName(i)} - {widget.description} + + { + const n = Number(s) + if (s.length > 0 && valid && !isNaN(n)) { + setValue(i, n) + return + } else if (values.at(i) !== null) { + setValue(i, null) + } + }} + delay={500} + textFieldProps={{ + required: true, + autoFocus: true, + fullWidth: true, + helperText: + value === null || value === undefined + ? `Please input the float number.` + : "", + type: "text", + inputProps: { + pattern: "[-+]?[0-9]*.?[0-9]+([eE][-+]?[0-9]+)?", + }, + }} + /> + + ) + } else if (widget.type === "choice") { + return ( + + + {getObjectiveName(i)} - {widget.description} + + + {widget.choices.map((c, i) => ( + } + label={c} + /> + ))} + + + ) + } else if (widget.type === "slider") { + return ( + + + {getObjectiveName(i)} - {widget.description} + + + + + + ) + } else if (widget.type === "user_attr") { + return ( + + {getObjectiveName(i)} + + + ) + } + return null + })} + + + + + + + + + ) +} diff --git a/optuna_dashboard/ts/components/TrialList.tsx b/optuna_dashboard/ts/components/TrialList.tsx index b59ad462a..070fe2dfa 100644 --- a/optuna_dashboard/ts/components/TrialList.tsx +++ b/optuna_dashboard/ts/components/TrialList.tsx @@ -19,13 +19,6 @@ import { CardContent, CardMedia, CardActionArea, - TextField, - Button, - Slider, - FormControl, - FormLabel, - RadioGroup, - Radio, } from "@mui/material" import Chip from "@mui/material/Chip" import Divider from "@mui/material/Divider" @@ -49,8 +42,7 @@ import { useRecoilValue } from "recoil" import { artifactIsAvailable } from "../state" import { actionCreator } from "../action" import { useDeleteArtifactDialog } from "./DeleteArtifactDialog" -import FormControlLabel from "@mui/material/FormControlLabel" -import { DebouncedInputTextField } from "./Debounce" +import { ObjectiveForm } from "./ObjectiveForm" const states: TrialState[] = [ "Complete", @@ -303,7 +295,7 @@ const TrialListDetail: FC<{ cardSx={{ marginBottom: theme.spacing(2) }} /> {trial.state === "Running" && directions.length > 0 && ( - = ({ trial, directions, names, widgets }) => { - const theme = useTheme() - const action = actionCreator() - const [values, setValues] = useState<(number | null)[]>( - directions.map((d, i) => { - const widget = widgets.at(i) - if (widget === undefined) { - return null - } else if (widget.type === "text") { - return null - } else if (widget.type === "choice") { - return widget.values.at(0) || null - } else if (widget.type === "slider") { - return widget.min - } else if (widget.type === "user_attr") { - const attr = trial.user_attrs.find((attr) => attr.key == widget.key) - if (attr === undefined) { - return null - } else { - const n = Number(attr.value) - return isNaN(n) ? null : n - } - } else { - return null - } - }) - ) - - const setValue = (objectiveId: number, value: number | null) => { - const newValues = [...values] - if (newValues.length <= objectiveId) { - return - } - newValues[objectiveId] = value - setValues(newValues) - } - - const disableSubmit = useMemo( - () => values.findIndex((v) => v === null) >= 0, - [values] - ) - - const handleSubmit = (e: React.MouseEvent): void => { - e.preventDefault() - const filtered = values.filter((v): v is number => v !== null) - if (filtered.length !== directions.length) { - return - } - action.tellTrial(trial.study_id, trial.trial_id, "Complete", filtered) - } - - const getObjectiveName = (i: number): string => { - const n = names.at(i) - if (n !== undefined) { - return n - } - if (directions.length == 1) { - return `Objective` - } else { - return `Objective ${i}` - } - } - - return ( - <> - - {directions.length > 1 ? "Set Objective Values" : "Set Objective Value"} - - - - {directions.map((d, i) => { - const widget = widgets.at(i) - const value = values.at(i) - const key = `objective-${i}` - if (widget === undefined) { - return ( - - {getObjectiveName(i)} - { - const n = Number(s) - if (s.length > 0 && valid && !isNaN(n)) { - setValue(i, n) - return - } else if (values.at(i) !== null) { - setValue(i, null) - } - }} - delay={500} - textFieldProps={{ - required: true, - autoFocus: true, - fullWidth: true, - helperText: - value === null || value === undefined - ? `Please input the float number.` - : "", - label: getObjectiveName(i), - type: "text", - }} - /> - - ) - } else if (widget.type === "text") { - return ( - - - {getObjectiveName(i)} - {widget.description} - - { - const n = Number(s) - if (s.length > 0 && valid && !isNaN(n)) { - setValue(i, n) - return - } else if (values.at(i) !== null) { - setValue(i, null) - } - }} - delay={500} - textFieldProps={{ - required: true, - autoFocus: true, - fullWidth: true, - helperText: - value === null || value === undefined - ? `Please input the float number.` - : "", - type: "text", - inputProps: { - pattern: "[-+]?[0-9]*.?[0-9]+([eE][-+]?[0-9]+)?", - }, - }} - /> - - ) - } else if (widget.type === "choice") { - return ( - - - {getObjectiveName(i)} - {widget.description} - - - {widget.choices.map((c, i) => ( - } - label={c} - /> - ))} - - - ) - } else if (widget.type === "slider") { - return ( - - - {getObjectiveName(i)} - {widget.description} - - - 0 ? widget.labels : true} - valueLabelDisplay="auto" - /> - - - ) - } else if (widget.type === "user_attr") { - return ( - - {getObjectiveName(i)} - - - ) - } - return null - })} - - - - - - - - - ) -} - const TrialArtifact: FC<{ trial: Trial }> = ({ trial }) => { const theme = useTheme() const action = actionCreator() diff --git a/optuna_dashboard/ts/types/index.d.ts b/optuna_dashboard/ts/types/index.d.ts index ee12f4fd3..1d30bc574 100644 --- a/optuna_dashboard/ts/types/index.d.ts +++ b/optuna_dashboard/ts/types/index.d.ts @@ -138,10 +138,12 @@ type ObjectiveSliderWidget = { min: number max: number step: number - labels: { - value: number - label: string - }[] + labels: + | { + value: number + label: string + }[] + | null } type ObjectiveTextInputWidget = {