From 7a1522fb2dff8956a533ad510089d64f66b36440 Mon Sep 17 00:00:00 2001 From: xadrianzetx Date: Sun, 17 Sep 2023 10:55:20 +0200 Subject: [PATCH 1/2] Add `metric_names` getter to complement `get_metric_names` --- optuna/study/study.py | 13 +++++++++++++ tests/study_tests/test_study.py | 15 +++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/optuna/study/study.py b/optuna/study/study.py index 59aab1cd1e..07140c14ce 100644 --- a/optuna/study/study.py +++ b/optuna/study/study.py @@ -333,6 +333,19 @@ def system_attrs(self) -> dict[str, Any]: return copy.deepcopy(self._storage.get_study_system_attrs(self._study_id)) + @property + @experimental_func("3.4.0") + def metric_names(self) -> list[str] | None: + """Return metric names. + + .. note:: + Use :meth:`~optuna.study.Study.set_metric_names` to set the metric names first. + + Returns: + A list with names for each dimension of the returned values of the objective function. + """ + return self._storage.get_study_system_attrs(self._study_id).get(_SYSTEM_ATTR_METRIC_NAMES) + def optimize( self, func: ObjectiveFuncType, diff --git a/tests/study_tests/test_study.py b/tests/study_tests/test_study.py index 0aba2e13ef..51411e54d8 100644 --- a/tests/study_tests/test_study.py +++ b/tests/study_tests/test_study.py @@ -1601,3 +1601,18 @@ def test_set_invalid_metric_names() -> None: study = create_study(directions=["minimize", "minimize"]) with pytest.raises(ValueError): study.set_metric_names(metric_names) + + +def test_get_metric_names() -> None: + study = create_study() + assert study.metric_names is None + study.set_metric_names(["v0"]) + assert study.metric_names == ["v0"] + study.set_metric_names(["v1"]) + assert study.metric_names == ["v1"] + + +def test_get_metric_names_experimental_warning() -> None: + study = create_study() + with pytest.warns(ExperimentalWarning): + study.metric_names From 60a3552f05a884083d52869a013f5975eb12a115 Mon Sep 17 00:00:00 2001 From: xadrianzetx Date: Sun, 17 Sep 2023 10:56:17 +0200 Subject: [PATCH 2/2] Use `metric_names` attribute instead of querying the storage directly --- optuna/study/_dataframe.py | 5 +---- optuna/study/study.py | 4 +--- optuna/visualization/_pareto_front.py | 5 +---- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/optuna/study/_dataframe.py b/optuna/study/_dataframe.py index b311443055..84131526e9 100644 --- a/optuna/study/_dataframe.py +++ b/optuna/study/_dataframe.py @@ -8,7 +8,6 @@ import optuna from optuna._imports import try_import -from optuna.study.study import _SYSTEM_ATTR_METRIC_NAMES from optuna.trial._state import TrialState @@ -41,9 +40,7 @@ def _create_records_and_aggregate_column( column_agg: DefaultDict[str, Set] = collections.defaultdict(set) non_nested_attr = "" - metric_names = study._storage.get_study_system_attrs(study._study_id).get( - _SYSTEM_ATTR_METRIC_NAMES - ) + metric_names = study.metric_names records = [] for trial in study.get_trials(deepcopy=False): diff --git a/optuna/study/study.py b/optuna/study/study.py index 07140c14ce..eae8fd0993 100644 --- a/optuna/study/study.py +++ b/optuna/study/study.py @@ -1088,9 +1088,7 @@ def _log_completed_trial(self, trial: trial_module.FrozenTrial) -> None: if not _logger.isEnabledFor(logging.INFO): return - metric_names = self._storage.get_study_system_attrs(self._study_id).get( - _SYSTEM_ATTR_METRIC_NAMES - ) + metric_names = self.metric_names if len(trial.values) > 1: trial_values: list[float] | dict[str, float] diff --git a/optuna/visualization/_pareto_front.py b/optuna/visualization/_pareto_front.py index 05365a5583..c7f8a0f809 100644 --- a/optuna/visualization/_pareto_front.py +++ b/optuna/visualization/_pareto_front.py @@ -11,7 +11,6 @@ from optuna.exceptions import ExperimentalWarning from optuna.study import Study from optuna.study._multi_objective import _get_pareto_front_trials_by_trials -from optuna.study.study import _SYSTEM_ATTR_METRIC_NAMES from optuna.trial import FrozenTrial from optuna.trial import TrialState from optuna.visualization._plotly_imports import _imports @@ -343,9 +342,7 @@ def _infer_n_targets( ) if target_names is None: - metric_names = study._storage.get_study_system_attrs(study._study_id).get( - _SYSTEM_ATTR_METRIC_NAMES - ) + metric_names = study.metric_names if metric_names is None: target_names = [f"Objective {i}" for i in range(n_targets)] else: