Skip to content

Commit

Permalink
Merge pull request optuna#4930 from xadrianzetx/metric-names-getter2
Browse files Browse the repository at this point in the history
Add `metric_names` getter to study
  • Loading branch information
HideakiImamura authored Sep 28, 2023
2 parents 52465d9 + 60a3552 commit e06046b
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
5 changes: 1 addition & 4 deletions optuna/study/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import optuna
from optuna._imports import try_import
from optuna.study.study import _SYSTEM_ATTR_METRIC_NAMES
from optuna.trial._state import TrialState


Expand Down Expand Up @@ -41,9 +40,7 @@ def _create_records_and_aggregate_column(
column_agg: DefaultDict[str, Set] = collections.defaultdict(set)
non_nested_attr = ""

metric_names = study._storage.get_study_system_attrs(study._study_id).get(
_SYSTEM_ATTR_METRIC_NAMES
)
metric_names = study.metric_names

records = []
for trial in study.get_trials(deepcopy=False):
Expand Down
17 changes: 14 additions & 3 deletions optuna/study/study.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,19 @@ def system_attrs(self) -> dict[str, Any]:

return copy.deepcopy(self._storage.get_study_system_attrs(self._study_id))

@property
@experimental_func("3.4.0")
def metric_names(self) -> list[str] | None:
"""Return metric names.
.. note::
Use :meth:`~optuna.study.Study.set_metric_names` to set the metric names first.
Returns:
A list with names for each dimension of the returned values of the objective function.
"""
return self._storage.get_study_system_attrs(self._study_id).get(_SYSTEM_ATTR_METRIC_NAMES)

def optimize(
self,
func: ObjectiveFuncType,
Expand Down Expand Up @@ -1075,9 +1088,7 @@ def _log_completed_trial(self, trial: trial_module.FrozenTrial) -> None:
if not _logger.isEnabledFor(logging.INFO):
return

metric_names = self._storage.get_study_system_attrs(self._study_id).get(
_SYSTEM_ATTR_METRIC_NAMES
)
metric_names = self.metric_names

if len(trial.values) > 1:
trial_values: list[float] | dict[str, float]
Expand Down
5 changes: 1 addition & 4 deletions optuna/visualization/_pareto_front.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from optuna.exceptions import ExperimentalWarning
from optuna.study import Study
from optuna.study._multi_objective import _get_pareto_front_trials_by_trials
from optuna.study.study import _SYSTEM_ATTR_METRIC_NAMES
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
Expand Down Expand Up @@ -343,9 +342,7 @@ def _infer_n_targets(
)

if target_names is None:
metric_names = study._storage.get_study_system_attrs(study._study_id).get(
_SYSTEM_ATTR_METRIC_NAMES
)
metric_names = study.metric_names
if metric_names is None:
target_names = [f"Objective {i}" for i in range(n_targets)]
else:
Expand Down
15 changes: 15 additions & 0 deletions tests/study_tests/test_study.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,3 +1601,18 @@ def test_set_invalid_metric_names() -> None:
study = create_study(directions=["minimize", "minimize"])
with pytest.raises(ValueError):
study.set_metric_names(metric_names)


def test_get_metric_names() -> None:
study = create_study()
assert study.metric_names is None
study.set_metric_names(["v0"])
assert study.metric_names == ["v0"]
study.set_metric_names(["v1"])
assert study.metric_names == ["v1"]


def test_get_metric_names_experimental_warning() -> None:
study = create_study()
with pytest.warns(ExperimentalWarning):
study.metric_names

0 comments on commit e06046b

Please sign in to comment.