Skip to content

Commit

Permalink
Merge pull request #370 from c-bata/custom-widget
Browse files Browse the repository at this point in the history
Introducing Custom User Widget
  • Loading branch information
c-bata authored Jan 27, 2023
2 parents 7b1045d + 935dd9d commit 5bf6c38
Show file tree
Hide file tree
Showing 9 changed files with 482 additions and 11 deletions.
5 changes: 5 additions & 0 deletions optuna_dashboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from ._app import wsgi # noqa
from ._named_objectives import set_objective_names # noqa
from ._note import save_note # noqa
from ._objective_form_widget import ObjectiveChoiceWidget # noqa
from ._objective_form_widget import ObjectiveSliderWidget # noqa
from ._objective_form_widget import ObjectiveTextInputWidget # noqa
from ._objective_form_widget import ObjectiveUserAttrRef # noqa
from ._objective_form_widget import register_objective_form_widgets # noqa


__version__ = "0.9.0b2"
133 changes: 133 additions & 0 deletions optuna_dashboard/_objective_form_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from __future__ import annotations

from dataclasses import dataclass
import json
from typing import TYPE_CHECKING
from typing import Union

import optuna


if TYPE_CHECKING:
from typing import Any
from typing import Literal
from typing import Optional
from typing import TypedDict

ChoiceWidgetJSON = TypedDict(
"ChoiceWidgetJSON",
{
"type": Literal["choice"],
"description": Optional[str],
"choices": list[str],
"values": list[float],
},
)
SliderWidgetLabel = TypedDict(
"SliderWidgetLabel",
{"value": float, "label": str},
)
SliderWidgetJSON = TypedDict(
"SliderWidgetJSON",
{
"type": Literal["slider"],
"description": Optional[str],
"min": float,
"max": float,
"step": Optional[float],
"labels": Optional[list[SliderWidgetLabel]],
},
)
TextInputWidgetJSON = TypedDict(
"TextInputWidgetJSON",
{"type": Literal["text"], "description": Optional[str]},
)
UserAttrRefJSON = TypedDict("UserAttrRefJSON", {"type": Literal["user_attr"], "key": str})
ObjectiveFormWidgetJSON = Union[
ChoiceWidgetJSON, SliderWidgetJSON, TextInputWidgetJSON, UserAttrRefJSON
]


@dataclass
class ObjectiveChoiceWidget:
choices: list[str]
values: list[float]
description: Optional[str] = None

def to_dict(self) -> ChoiceWidgetJSON:
return {
"type": "choice",
"description": self.description,
"choices": self.choices,
"values": self.values,
}


@dataclass
class ObjectiveSliderWidget:
min: float
max: float
step: Optional[float] = None
labels: Optional[list[tuple[float, str]]] = None
description: Optional[str] = None

def to_dict(self) -> SliderWidgetJSON:
labels: Optional[list[SliderWidgetLabel]] = None
if self.labels is not None:
labels = [{"value": value, "label": label} for value, label in self.labels]
return {
"type": "slider",
"description": self.description,
"min": self.min,
"max": self.max,
"step": self.step,
"labels": labels,
}


@dataclass
class ObjectiveTextInputWidget:
description: Optional[str] = None

def to_dict(self) -> TextInputWidgetJSON:
return {
"type": "text",
"description": self.description,
}


@dataclass
class ObjectiveUserAttrRef:
key: str

def to_dict(self) -> UserAttrRefJSON:
return {
"type": "user_attr",
"key": self.key,
}


ObjectiveFormWidget = Union[
ObjectiveChoiceWidget, ObjectiveSliderWidget, ObjectiveTextInputWidget, ObjectiveUserAttrRef
]
SYSTEM_ATTR_KEY = "dashboard:objective_form_widgets"


def register_objective_form_widgets(
study: optuna.Study, widgets: list[ObjectiveFormWidget]
) -> None:
if len(study.directions) != len(widgets):
raise ValueError("The length of actions must be the same with the number of objectives.")
widget_dicts = [w.to_dict() for w in widgets]
study._storage.set_study_system_attr(
study._study_id, SYSTEM_ATTR_KEY, json.dumps(widget_dicts)
)


def get_objective_form_widgets_json(
study_system_attr: dict[str, Any]
) -> Optional[list[ObjectiveFormWidgetJSON]]:
widgets_json = study_system_attr.get(SYSTEM_ATTR_KEY)
if widgets_json is None:
return None
return json.loads(widgets_json)
4 changes: 4 additions & 0 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from . import _note as note
from ._named_objectives import get_objective_names
from ._objective_form_widget import get_objective_form_widgets_json
from .artifact._backend import list_trial_artifacts


Expand Down Expand Up @@ -143,6 +144,9 @@ def serialize_study_detail(
objective_names = get_objective_names(system_attrs)
if objective_names:
serialized["objective_names"] = objective_names
objective_form_widgets = get_objective_form_widgets_json(system_attrs)
if objective_form_widgets:
serialized["objective_form_widgets"] = objective_form_widgets
return serialized


Expand Down
4 changes: 2 additions & 2 deletions optuna_dashboard/ts/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -394,12 +394,12 @@ export const actionCreator = () => {
trialId: number,
state: TrialStateFinished,
values?: number[]
): Promise<void> => {
): void => {
const message =
values === undefined
? `id=${trialId}, state=${state}`
: `id=${trialId}, state=${state}, values=${values}`
return tellTrialAPI(trialId, state, values)
tellTrialAPI(trialId, state, values)
.then(() => {
const index = studyDetails[studyId].trials.findIndex(
(t) => t.trial_id === trialId
Expand Down
2 changes: 2 additions & 0 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ interface StudyDetailResponse {
has_intermediate_values: boolean
note: Note
objective_names?: string[]
objective_form_widgets?: ObjectiveFormWidget[]
}

export const getStudyDetailAPI = (
Expand Down Expand Up @@ -99,6 +100,7 @@ export const getStudyDetailAPI = (
has_intermediate_values: res.data.has_intermediate_values,
note: res.data.note,
objective_names: res.data.objective_names,
objective_form_widgets: res.data.objective_form_widgets,
}
})
}
Expand Down
6 changes: 4 additions & 2 deletions optuna_dashboard/ts/components/Debounce.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@ import React, { FC, useEffect } from "react"
import { TextField, TextFieldProps } from "@mui/material"

export const DebouncedInputTextField: FC<{
onChange: (s: string) => void
onChange: (s: string, valid: boolean) => void
delay: number
textFieldProps: TextFieldProps
}> = ({ onChange, delay, textFieldProps }) => {
const [text, setText] = React.useState<string>("")
const [valid, setValidity] = React.useState<boolean>(true)
useEffect(() => {
const timer = setTimeout(() => {
onChange(text)
onChange(text, valid)
}, delay)
return () => {
clearTimeout(timer)
Expand All @@ -19,6 +20,7 @@ export const DebouncedInputTextField: FC<{
<TextField
onChange={(e) => {
setText(e.target.value)
setValidity(e.target.validity.valid)
}}
{...textFieldProps}
/>
Expand Down
Loading

0 comments on commit 5bf6c38

Please sign in to comment.