From 1c9dbbb7bc7a589bb3ebfd455eea1f4ee6c81bd4 Mon Sep 17 00:00:00 2001 From: keisuke-umezawa Date: Sun, 18 Feb 2024 14:49:27 +0900 Subject: [PATCH 1/4] Replace StudySummary with FrozenStudy --- optuna_dashboard/_app.py | 34 +++++++++++++------------- optuna_dashboard/_serializer.py | 33 +++++++++++-------------- optuna_dashboard/_storage.py | 36 ++++++--------------------- python_tests/test_serializers.py | 42 ++++++++++++++------------------ 4 files changed, 57 insertions(+), 88 deletions(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index 1d4c9f7c0..91cf08d15 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -40,11 +40,11 @@ from ._preferential_history import report_history from ._preferential_history import restore_history from ._rdb_migration import register_rdb_migration_route +from ._serializer import serialize_frozen_study from ._serializer import serialize_study_detail -from ._serializer import serialize_study_summary from ._storage import create_new_study -from ._storage import get_study_summaries -from ._storage import get_study_summary +from ._storage import get_studies +from ._storage import get_study from ._storage import get_trials from ._storage_url import get_storage from .artifact._backend import delete_all_artifacts @@ -102,8 +102,8 @@ def api_meta() -> dict[str, Any]: @app.get("/api/studies") @json_api_view def list_study_summaries() -> dict[str, Any]: - summaries = get_study_summaries(storage) - serialized = [serialize_study_summary(summary) for summary in summaries] + studies = get_studies(storage) + serialized = [serialize_frozen_study(s) for s in studies] return { "study_summaries": serialized, } @@ -131,12 +131,12 @@ def create_study() -> dict[str, Any]: response.status = 400 # Bad request return {"reason": f"'{study_name}' already exists"} - summary = get_study_summary(storage, study_id) - if summary is None: + study = get_study(storage, study_id) + if study is None: response.status = 500 # Internal 
server error return {"reason": "Failed to create study"} response.status = 201 # Created - return {"study_summary": serialize_study_summary(summary)} + return {"study_summary": serialize_frozen_study(study)} @app.post("/api/studies//rename") @json_api_view @@ -167,14 +167,14 @@ def rename_study(study_id: int) -> dict[str, Any]: response.status = 500 storage.delete_study(dst_study._study_id) return {"reason": str(e)} - new_study_summary = get_study_summary(storage, dst_study._study_id) - if new_study_summary is None: + new_study = get_study(storage, dst_study._study_id) + if new_study is None: response.status = 500 return {"reason": "Failed to load the new study"} storage.delete_study(src_study._study_id) response.status = 201 - return serialize_study_summary(new_study_summary) + return serialize_frozen_study(new_study) @app.delete("/api/studies/") @json_api_view @@ -201,24 +201,24 @@ def get_study_detail(study_id: int) -> dict[str, Any]: return {"reason": "`after` should be larger or equal 0."} except KeyError: after = 0 - summary = get_study_summary(storage, study_id) - if summary is None: + study = get_study(storage, study_id) + if study is None: response.status = 404 # Not found return {"reason": f"study_id={study_id} is not found"} trials = get_trials(storage, study_id) - system_attrs = getattr(summary, "system_attrs", {}) + system_attrs = getattr(study, "system_attrs", {}) is_preferential = system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False) # TODO(c-bata): Cache best_trials if is_preferential: best_trials = get_best_preferential_trials(study_id, storage) - elif len(summary.directions) == 1: + elif len(study.directions) == 1: if len([t for t in trials if t.state == TrialState.COMPLETE]) == 0: best_trials = [] else: best_trials = [storage.get_best_trial(study_id)] else: - best_trials = get_pareto_front_trials(trials=trials, directions=summary.directions) + best_trials = get_pareto_front_trials(trials=trials, directions=study.directions) ( # TODO: 
intersection_search_space and union_search_space look more clear since now we # have union_user_attrs. @@ -232,7 +232,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]: skipped_trial_ids = get_skipped_trial_ids(system_attrs) skipped_trial_numbers = [t.number for t in trials if t._trial_id in skipped_trial_ids] return serialize_study_detail( - summary, + study, best_trials, trials[after:], intersection, diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index e3c4da659..bbb94d97c 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -12,7 +12,7 @@ from optuna.distributions import CategoricalDistribution from optuna.distributions import FloatDistribution from optuna.distributions import IntDistribution -from optuna.study import StudySummary +from optuna.study._frozen import FrozenStudy from optuna.trial import FrozenTrial from . import _note as note @@ -116,25 +116,22 @@ def serialize_attrs(attrs: dict[str, Any]) -> list[Attribute]: return serialized -def serialize_study_summary(summary: StudySummary) -> dict[str, Any]: +def serialize_frozen_study(study: FrozenStudy) -> dict[str, Any]: serialized = { - "study_id": summary._study_id, - "study_name": summary.study_name, - "directions": [d.name.lower() for d in summary.directions], - "user_attrs": serialize_attrs(summary.user_attrs), - "is_preferential": getattr(summary, "_system_attrs", {}).get( + "study_id": study._study_id, + "study_name": study.study_name, + "directions": [d.name.lower() for d in study.directions], + "user_attrs": serialize_attrs(study.user_attrs), + "is_preferential": getattr(study, "_system_attrs", {}).get( _SYSTEM_ATTR_PREFERENTIAL_STUDY, False ), } - if summary.datetime_start is not None: - serialized["datetime_start"] = summary.datetime_start.isoformat() - return serialized def serialize_study_detail( - summary: StudySummary, + study: FrozenStudy, best_trials: list[FrozenTrial], trials: list[FrozenTrial], intersection: 
list[tuple[str, BaseDistribution]], @@ -145,20 +142,18 @@ def serialize_study_detail( skipped_trial_numbers: list[int], ) -> dict[str, Any]: serialized: dict[str, Any] = { - "name": summary.study_name, - "directions": [d.name.lower() for d in summary.directions], - "user_attrs": serialize_attrs(summary.user_attrs), + "name": study.study_name, + "directions": [d.name.lower() for d in study.directions], + "user_attrs": serialize_attrs(study.user_attrs), } - system_attrs = getattr(summary, "system_attrs", {}) + system_attrs = getattr(study, "system_attrs", {}) serialized["artifacts"] = list_study_artifacts(system_attrs) - if summary.datetime_start is not None: - serialized["datetime_start"] = summary.datetime_start.isoformat() serialized["trials"] = [ - serialize_frozen_trial(summary._study_id, trial, system_attrs) for trial in trials + serialize_frozen_trial(study._study_id, trial, system_attrs) for trial in trials ] serialized["best_trials"] = [ - serialize_frozen_trial(summary._study_id, trial, system_attrs) for trial in best_trials + serialize_frozen_trial(study._study_id, trial, system_attrs) for trial in best_trials ] serialized["intersection_search_space"] = serialize_search_space(intersection) serialized["union_search_space"] = serialize_search_space(union) diff --git a/optuna_dashboard/_storage.py b/optuna_dashboard/_storage.py index 460d7cb71..fe4a36c00 100644 --- a/optuna_dashboard/_storage.py +++ b/optuna_dashboard/_storage.py @@ -3,19 +3,14 @@ from datetime import datetime from datetime import timedelta import threading -import typing from optuna.storages import BaseStorage from optuna.storages import RDBStorage from optuna.study import StudyDirection -from optuna.study import StudySummary +from optuna.study._frozen import FrozenStudy from optuna.trial import FrozenTrial -if typing.TYPE_CHECKING: - from optuna.study._frozen import FrozenStudy - - # In-memory trials cache trials_cache_lock = threading.Lock() trials_cache: dict[int, list[FrozenTrial]] = {} 
@@ -49,19 +44,19 @@ def get_trials(storage: BaseStorage, study_id: int) -> list[FrozenTrial]: return trials -def get_study_summaries(storage: BaseStorage) -> list[StudySummary]: +def get_studies(storage: BaseStorage) -> list[FrozenStudy]: frozen_studies = storage.get_all_studies() if isinstance(storage, RDBStorage): frozen_studies = sorted(frozen_studies, key=lambda s: s._study_id) - return [_frozen_study_to_study_summary(s) for s in frozen_studies] + return frozen_studies -def get_study_summary(storage: BaseStorage, study_id: int) -> StudySummary | None: - summaries = get_study_summaries(storage) - for summary in summaries: - if summary._study_id != study_id: +def get_study(storage: BaseStorage, study_id: int) -> FrozenStudy: + studies = get_studies(storage) + for s in studies: + if s._study_id != study_id: continue - return summary + return s return None @@ -70,18 +65,3 @@ def create_new_study( ) -> int: study_id = storage.create_new_study(directions, study_name=study_name) return study_id - - -def _frozen_study_to_study_summary(frozen_study: "FrozenStudy") -> StudySummary: - is_single = len(frozen_study.directions) == 1 - return StudySummary( - study_name=frozen_study.study_name, - study_id=frozen_study._study_id, - direction=frozen_study.direction if is_single else None, - directions=frozen_study.directions if not is_single else None, - user_attrs=frozen_study.user_attrs, - system_attrs=frozen_study.system_attrs, - best_trial=None, - n_trials=-1, # This field isn't used by Dashboard. 
- datetime_start=None, - ) diff --git a/python_tests/test_serializers.py b/python_tests/test_serializers.py index a991d5d39..fb43a9c83 100644 --- a/python_tests/test_serializers.py +++ b/python_tests/test_serializers.py @@ -5,9 +5,9 @@ import numpy as np import optuna from optuna_dashboard._serializer import serialize_attrs +from optuna_dashboard._serializer import serialize_frozen_study from optuna_dashboard._serializer import serialize_study_detail -from optuna_dashboard._serializer import serialize_study_summary -from optuna_dashboard._storage import get_study_summaries +from optuna_dashboard._storage import get_studies from optuna_dashboard.preferential import create_study from packaging import version import pytest @@ -60,26 +60,20 @@ def test_serialize_numpy_floating() -> None: def test_get_study_detail_is_preferential() -> None: storage = optuna.storages.InMemoryStorage() study = create_study(n_generate=4, storage=storage) - study_summaries = get_study_summaries(storage) - assert len(study_summaries) == 1 + studies = get_studies(storage) + assert len(studies) == 1 - study_summary = study_summaries[0] - study_detail = serialize_study_detail( - study_summary, [], study.trials, [], [], [], False, {}, [] - ) + study_detail = serialize_study_detail(studies[0], [], study.trials, [], [], [], False, {}, []) assert study_detail["is_preferential"] def test_get_study_detail_is_not_preferential() -> None: storage = optuna.storages.InMemoryStorage() study = optuna.create_study(storage=storage) - study_summaries = get_study_summaries(storage) - assert len(study_summaries) == 1 + studies = get_studies(storage) + assert len(studies) == 1 - study_summary = study_summaries[0] - study_detail = serialize_study_detail( - study_summary, [], study.trials, [], [], [], False, {}, [] - ) + study_detail = serialize_study_detail(studies[0], [], study.trials, [], [], [], False, {}, []) assert not study_detail["is_preferential"] @@ -87,20 +81,20 @@ def 
test_get_study_detail_is_not_preferential() -> None: @pytest.mark.skipif( version.parse(optuna.__version__) < version.parse("3.2.0"), reason="Needs optuna.search_space" ) -def test_get_study_summary_is_preferential() -> None: +def test_get_study_is_preferential() -> None: storage = optuna.storages.InMemoryStorage() create_study(n_generate=4, storage=storage) - study_summaries = get_study_summaries(storage) - assert len(study_summaries) == 1 + studies = get_studies(storage) + assert len(studies) == 1 - study_summary = serialize_study_summary(study_summaries[0]) - assert study_summary["is_preferential"] + serialized = serialize_frozen_study(studies[0]) + assert serialized["is_preferential"] -def test_get_study_summary_is_not_preferential() -> None: +def test_get_study_is_not_preferential() -> None: storage = optuna.storages.InMemoryStorage() optuna.create_study(storage=storage) - study_summaries = get_study_summaries(storage) - assert len(study_summaries) == 1 - study_summary = serialize_study_summary(study_summaries[0]) - assert not study_summary["is_preferential"] + studies = get_studies(storage) + assert len(studies) == 1 + serialized = serialize_frozen_study(studies[0]) + assert not serialized["is_preferential"] From ed6636955d936141b04e908909d8ea709f2e3e9a Mon Sep 17 00:00:00 2001 From: keisuke-umezawa Date: Sun, 18 Feb 2024 14:56:34 +0900 Subject: [PATCH 2/4] Fix mypy --- optuna_dashboard/_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optuna_dashboard/_storage.py b/optuna_dashboard/_storage.py index fe4a36c00..8fe10e0a0 100644 --- a/optuna_dashboard/_storage.py +++ b/optuna_dashboard/_storage.py @@ -51,7 +51,7 @@ def get_studies(storage: BaseStorage) -> list[FrozenStudy]: return frozen_studies -def get_study(storage: BaseStorage, study_id: int) -> FrozenStudy: +def get_study(storage: BaseStorage, study_id: int) -> FrozenStudy | None: studies = get_studies(storage) for s in studies: if s._study_id != study_id: From 
9170b8824f69b08dd94d1dfd26e5c9910f8a93d8 Mon Sep 17 00:00:00 2001 From: keisuke-umezawa Date: Sun, 18 Feb 2024 15:11:24 +0900 Subject: [PATCH 3/4] Fix unittest --- optuna_dashboard/_serializer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/optuna_dashboard/_serializer.py b/optuna_dashboard/_serializer.py index bbb94d97c..ee527ca6b 100644 --- a/optuna_dashboard/_serializer.py +++ b/optuna_dashboard/_serializer.py @@ -122,9 +122,7 @@ def serialize_frozen_study(study: FrozenStudy) -> dict[str, Any]: "study_name": study.study_name, "directions": [d.name.lower() for d in study.directions], "user_attrs": serialize_attrs(study.user_attrs), - "is_preferential": getattr(study, "_system_attrs", {}).get( - _SYSTEM_ATTR_PREFERENTIAL_STUDY, False - ), + "is_preferential": study.system_attrs.get(_SYSTEM_ATTR_PREFERENTIAL_STUDY, False), } return serialized @@ -146,7 +144,7 @@ def serialize_study_detail( "directions": [d.name.lower() for d in study.directions], "user_attrs": serialize_attrs(study.user_attrs), } - system_attrs = getattr(study, "system_attrs", {}) + system_attrs = study.system_attrs serialized["artifacts"] = list_study_artifacts(system_attrs) serialized["trials"] = [ From 8f39413e9c2e00a58687cab3cdae1d19bde33ee0 Mon Sep 17 00:00:00 2001 From: keisuke-umezawa Date: Thu, 22 Feb 2024 10:15:10 +0900 Subject: [PATCH 4/4] Follow review comments --- optuna_dashboard/_app.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/optuna_dashboard/_app.py b/optuna_dashboard/_app.py index 91cf08d15..c2e5099af 100644 --- a/optuna_dashboard/_app.py +++ b/optuna_dashboard/_app.py @@ -101,9 +101,10 @@ def api_meta() -> dict[str, Any]: @app.get("/api/studies") @json_api_view - def list_study_summaries() -> dict[str, Any]: + def list_studies() -> dict[str, Any]: studies = get_studies(storage) serialized = [serialize_frozen_study(s) for s in studies] + # TODO(umezawa): Rename `study_summaries` to `studies`. 
return { "study_summaries": serialized, }