Skip to content

Commit

Permalink
Merge pull request optuna#615 from moririn2528/settting-api
Browse files Browse the repository at this point in the history
add preference feedback component api
  • Loading branch information
c-bata authored Sep 14, 2023
2 parents cbfff61 + 9f17603 commit 76356eb
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ Preferential Optimization
optuna_dashboard.preferential.create_study
optuna_dashboard.preferential.load_study
optuna_dashboard.preferential.PreferentialStudy
optuna_dashboard.register_preference_feedback_component

Streamlit
-----------------
Expand Down
1 change: 1 addition & 0 deletions optuna_dashboard/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ._named_objectives import set_objective_names # noqa
from ._note import get_note # noqa
from ._note import save_note # noqa
from ._preference_setting import register_preference_feedback_component # noqa


__version__ = "0.13.0b1"
23 changes: 23 additions & 0 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ._custom_plot_data import get_plotly_graph_objects
from ._importance import get_param_importance_from_trials_cache
from ._pareto_front import get_pareto_front_trials
from ._preference_setting import _register_preference_feedback_component
from ._preferential_history import NewHistory
from ._preferential_history import PreferenceHistoryNotFound
from ._preferential_history import remove_history
Expand Down Expand Up @@ -315,6 +316,28 @@ def post_preference(study_id: int) -> dict[str, Any]:
response.status = 204
return {}

@app.put("/api/studies/<study_id:int>/preference_feedback_component_type")
@json_api_view
def put_preference_feedback_component_type(study_id: int) -> dict[str, Any]:
try:
component_type = request.json.get("type", "")
artifact_key = request.json.get("artifact_key", None)
except ValueError:
response.status = 400
return {"reason": "invalid request."}
if component_type not in ["note", "artifact"]:
response.status = 400
return {"reason": "component_type must be either 'Note' or 'Artifact'."}

_register_preference_feedback_component(
study_id=study_id,
storage=storage,
component_type=component_type,
artifact_key=artifact_key,
)
response.status = 204
return {}

@app.delete("/api/studies/<study_id:int>/preference/<history_id>")
@json_api_view
def remove_preference(study_id: int, history_id: str) -> dict[str, Any]:
Expand Down
67 changes: 67 additions & 0 deletions optuna_dashboard/_preference_setting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING

from optuna.storages import BaseStorage

from .preferential._study import PreferentialStudy


if TYPE_CHECKING:
from typing import Literal

OUTPUT_COMPONENT_TYPE = Literal["note", "artifact"]

_SYSTEM_ATTR_FEEDBACK_COMPONENT = "preference:component"


def _register_preference_feedback_component(
study_id: int,
storage: BaseStorage,
component_type: OUTPUT_COMPONENT_TYPE,
artifact_key: str | None = None,
) -> None:
value: dict[str, Any] = {"type": component_type}
if artifact_key is not None:
value["artifact_key"] = artifact_key
storage.set_study_system_attr(
study_id=study_id,
key=_SYSTEM_ATTR_FEEDBACK_COMPONENT,
value=value,
)


def register_preference_feedback_component(
study: PreferentialStudy,
component_type: OUTPUT_COMPONENT_TYPE,
artifact_key: str | None = None,
) -> None:
"""Register a preference feedback component to the study.
With this feature, you can change the component, displayed on the
human feedback pages. By default, the Markdown note (``component_type="note"``)
is displayed. If you specify ``component_type="artifact"``, the viewer for the
specified artifact file will be displayed.
Args:
study:
The study to register the preference feedback component.
component_type:
The component type, displayed on the human feedback pages
(default: ``"note"``).
user_attr_artifact_key:
This option is required when the ``component_type`` is ``"artifact"``.
The user attribute, which is specified this field, must contain the
``artifact``id you want to display on the human feedback page.
"""
if component_type == "artifact":
assert (
artifact_key is not None
), "artifact_key must be specified when component_type is Artifact"

_register_preference_feedback_component(
study_id=study._study._study_id,
storage=study._study._storage,
component_type=component_type,
artifact_key=artifact_key,
)
4 changes: 4 additions & 0 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from . import _note as note
from ._form_widget import get_form_widgets_json
from ._named_objectives import get_objective_names
from ._preference_setting import _SYSTEM_ATTR_FEEDBACK_COMPONENT
from ._preferential_history import _SYSTEM_ATTR_PREFIX_HISTORY
from .artifact._backend import list_trial_artifacts
from .preferential._study import _SYSTEM_ATTR_PREFERENTIAL_STUDY
Expand Down Expand Up @@ -165,6 +166,9 @@ def serialize_study_detail(
if form_widgets:
serialized["form_widgets"] = form_widgets
if serialized["is_preferential"]:
serialized["feedback_component_type"] = system_attrs.get(
_SYSTEM_ATTR_FEEDBACK_COMPONENT, {}
)
serialized["preference_history"] = serialize_preference_history(system_attrs)
serialized["preferences"] = get_preferences(system_attrs)
serialized["skipped_trials"] = skipped_trials
Expand Down
31 changes: 31 additions & 0 deletions python_tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from optuna.study import StudyDirection
from optuna_dashboard._app import create_app
from optuna_dashboard._app import create_new_study
from optuna_dashboard._preference_setting import register_preference_feedback_component
from optuna_dashboard._preferential_history import NewHistory
from optuna_dashboard._preferential_history import remove_history
from optuna_dashboard._preferential_history import report_history
Expand Down Expand Up @@ -183,6 +184,36 @@ def test_report_preference_when_typo_mode(self) -> None:
)
self.assertEqual(status, 400)

def test_change_component(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = create_study(storage=storage, n_generate=3)
register_preference_feedback_component(study, "note")
for _ in range(3):
study.ask()

app = create_app(storage)
study_id = study._study._study_id
status, _, _ = send_request(
app,
f"/api/studies/{study_id}/preference_feedback_component_type",
"PUT",
body=json.dumps({"type": "artifact", "artifact_key": "image"}),
content_type="application/json",
)
self.assertEqual(status, 204)

status, _, body = send_request(
app,
f"/api/studies/{study_id}",
"GET",
content_type="application/json",
)
self.assertEqual(status, 200)

study_detail = json.loads(body)
assert study_detail["feedback_component_type"]["type"] == "artifact"
assert study_detail["feedback_component_type"]["artifact_key"] == "image"

def test_skip_trial(self) -> None:
storage = optuna.storages.InMemoryStorage()
study = create_study(n_generate=4, storage=storage)
Expand Down
20 changes: 20 additions & 0 deletions python_tests/test_preference_setting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

from unittest import TestCase

import optuna
from optuna_dashboard._preference_setting import _SYSTEM_ATTR_FEEDBACK_COMPONENT
from optuna_dashboard._preference_setting import register_preference_feedback_component
from optuna_dashboard.preferential._study import PreferentialStudy


class FeedbackSettingTestCase(TestCase):
def test_widget_to_dict_from_dict(self) -> None:
study = PreferentialStudy(optuna.create_study())
register_preference_feedback_component(study, "artifact", "image_key")
system_attrs = study._study.system_attrs
feedback_type = system_attrs.get(_SYSTEM_ATTR_FEEDBACK_COMPONENT, {})
assert "type" in feedback_type
assert feedback_type["type"] == "artifact"
assert "artifact_key" in feedback_type
assert feedback_type["artifact_key"] == "image_key"

0 comments on commit 76356eb

Please sign in to comment.