Skip to content

Commit

Permalink
Split ObjectiveForm component
Browse files Browse the repository at this point in the history
  • Loading branch information
c-bata committed Jan 27, 2023
1 parent c51ceef commit 935dd9d
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 261 deletions.
19 changes: 12 additions & 7 deletions optuna_dashboard/_objective_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
"values": list[float],
},
)
SliderWidgetLabels = TypedDict(
"SliderWidgetLabels",
SliderWidgetLabel = TypedDict(
"SliderWidgetLabel",
{"value": float, "label": str},
)
SliderWidgetJSON = TypedDict(
Expand All @@ -35,7 +35,7 @@
"min": float,
"max": float,
"step": Optional[float],
"labels": Optional[list[SliderWidgetLabels]],
"labels": Optional[list[SliderWidgetLabel]],
},
)
TextInputWidgetJSON = TypedDict(
Expand Down Expand Up @@ -67,18 +67,21 @@ def to_dict(self) -> ChoiceWidgetJSON:
class ObjectiveSliderWidget:
min: float
max: float
step: Optional[float]
labels: Optional[list[tuple[float, str]]]
step: Optional[float] = None
labels: Optional[list[tuple[float, str]]] = None
description: Optional[str] = None

def to_dict(self) -> SliderWidgetJSON:
labels: Optional[list[SliderWidgetLabel]] = None
if self.labels is not None:
labels = [{"value": value, "label": label} for value, label in self.labels]
return {
"type": "slider",
"description": self.description,
"min": self.min,
"max": self.max,
"step": self.step,
"labels": [{"value": value, "label": label} for value, label in self.labels],
"labels": labels,
}


Expand Down Expand Up @@ -110,7 +113,9 @@ def to_dict(self) -> UserAttrRefJSON:
SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets"


def register_objective_form_widgets(study: optuna.Study, widgets: list[ObjectiveFormWidget]):
def register_objective_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if len(study.directions) != len(widgets):
raise ValueError("The length of actions must be the same with the number of objectives.")
widget_dicts = [w.to_dict() for w in widgets]
Expand Down
261 changes: 261 additions & 0 deletions optuna_dashboard/ts/components/ObjectiveForm.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
import React, { FC, useMemo, useState } from "react"
import {
Typography,
Box,
useTheme,
Card,
FormControlLabel,
FormControl,
FormLabel,
Button,
RadioGroup,
Radio,
Slider,
TextField,
} from "@mui/material"
import { DebouncedInputTextField } from "./Debounce"
import { actionCreator } from "../action"

export const ObjectiveForm: FC<{
trial: Trial
directions: StudyDirection[]
names: string[]
widgets: ObjectiveFormWidget[]
}> = ({ trial, directions, names, widgets }) => {
const theme = useTheme()
const action = actionCreator()
const [values, setValues] = useState<(number | null)[]>(
directions.map((d, i) => {
const widget = widgets.at(i)
if (widget === undefined) {
return null
} else if (widget.type === "text") {
return null
} else if (widget.type === "choice") {
return widget.values.at(0) || null
} else if (widget.type === "slider") {
return widget.min
} else if (widget.type === "user_attr") {
const attr = trial.user_attrs.find((attr) => attr.key == widget.key)
if (attr === undefined) {
return null
} else {
const n = Number(attr.value)
return isNaN(n) ? null : n
}
} else {
return null
}
})
)

const setValue = (objectiveId: number, value: number | null) => {
const newValues = [...values]
if (newValues.length <= objectiveId) {
return
}
newValues[objectiveId] = value
setValues(newValues)
}

const disableSubmit = useMemo<boolean>(
() => values.findIndex((v) => v === null) >= 0,
[values]
)

const handleSubmit = (e: React.MouseEvent<HTMLButtonElement>): void => {
e.preventDefault()
const filtered = values.filter<number>((v): v is number => v !== null)
if (filtered.length !== directions.length) {
return
}
action.tellTrial(trial.study_id, trial.trial_id, "Complete", filtered)
}

const getObjectiveName = (i: number): string => {
const n = names.at(i)
if (n !== undefined) {
return n
}
if (directions.length == 1) {
return `Objective`
} else {
return `Objective ${i}`
}
}

return (
<>
<Typography
variant="h5"
sx={{ fontWeight: theme.typography.fontWeightBold }}
>
{directions.length > 1 ? "Set Objective Values" : "Set Objective Value"}
</Typography>
<Box sx={{ p: theme.spacing(1, 0) }}>
<Card
sx={{
display: "flex",
flexDirection: "column",
marginBottom: theme.spacing(2),
margin: theme.spacing(0, 1, 1, 0),
p: theme.spacing(1),
}}
>
{directions.map((d, i) => {
const widget = widgets.at(i)
const value = values.at(i)
const key = `objective-${i}`
if (widget === undefined) {
return (
<FormControl key={key} sx={{ margin: theme.spacing(1, 2) }}>
<FormLabel>{getObjectiveName(i)}</FormLabel>
<DebouncedInputTextField
onChange={(s, valid) => {
const n = Number(s)
if (s.length > 0 && valid && !isNaN(n)) {
setValue(i, n)
return
} else if (values.at(i) !== null) {
setValue(i, null)
}
}}
delay={500}
textFieldProps={{
required: true,
autoFocus: true,
fullWidth: true,
helperText:
value === null || value === undefined
? `Please input the float number.`
: "",
label: getObjectiveName(i),
type: "text",
}}
/>
</FormControl>
)
} else if (widget.type === "text") {
return (
<FormControl key={key} sx={{ margin: theme.spacing(1, 2) }}>
<FormLabel>
{getObjectiveName(i)} - {widget.description}
</FormLabel>
<DebouncedInputTextField
onChange={(s, valid) => {
const n = Number(s)
if (s.length > 0 && valid && !isNaN(n)) {
setValue(i, n)
return
} else if (values.at(i) !== null) {
setValue(i, null)
}
}}
delay={500}
textFieldProps={{
required: true,
autoFocus: true,
fullWidth: true,
helperText:
value === null || value === undefined
? `Please input the float number.`
: "",
type: "text",
inputProps: {
pattern: "[-+]?[0-9]*.?[0-9]+([eE][-+]?[0-9]+)?",
},
}}
/>
</FormControl>
)
} else if (widget.type === "choice") {
return (
<FormControl key={key} sx={{ margin: theme.spacing(1, 2) }}>
<FormLabel>
{getObjectiveName(i)} - {widget.description}
</FormLabel>
<RadioGroup row defaultValue={widget.values.at(0)}>
{widget.choices.map((c, i) => (
<FormControlLabel
key={c}
value={widget.values.at(i)}
control={<Radio />}
label={c}
/>
))}
</RadioGroup>
</FormControl>
)
} else if (widget.type === "slider") {
return (
<FormControl key={key} sx={{ margin: theme.spacing(1, 2) }}>
<FormLabel>
{getObjectiveName(i)} - {widget.description}
</FormLabel>
<Box sx={{ padding: theme.spacing(0, 2) }}>
<Slider
defaultValue={widget.min}
min={widget.min}
max={widget.max}
step={widget.step}
marks={
widget.labels === null || widget.labels.length == 0
? true
: widget.labels
}
valueLabelDisplay="auto"
/>
</Box>
</FormControl>
)
} else if (widget.type === "user_attr") {
return (
<FormControl key={key} sx={{ margin: theme.spacing(1, 2) }}>
<FormLabel>{getObjectiveName(i)}</FormLabel>
<TextField
inputProps={{ readOnly: true }}
value={value || undefined}
error={value === null}
helperText={
value === null || value === undefined
? `This objective value is referred from trial.user_attrs[${widget.key}].`
: ""
}
/>
</FormControl>
)
}
return null
})}
<Box
sx={{
display: "flex",
flexDirection: "row",
margin: theme.spacing(1, 2),
}}
>
<Button
variant="contained"
type="submit"
sx={{ marginRight: theme.spacing(1) }}
disabled={disableSubmit}
onClick={handleSubmit}
>
Submit
</Button>
<Box sx={{ flexGrow: 1 }} />
<Button
variant="outlined"
color="error"
onClick={() => {
action.tellTrial(trial.study_id, trial.trial_id, "Fail")
}}
>
Fail Trial
</Button>
</Box>
</Card>
</Box>
</>
)
}
Loading

0 comments on commit 935dd9d

Please sign in to comment.