Skip to content

Commit

Permalink
Support constraints plot_rank (optuna#4899)
Browse files Browse the repository at this point in the history
* refactor: Passing color as normalized rgb ndarray.

* feat: Support optimization with constraints in plot_rank.

* fix: Make color ditribution same between tick and points.

* refactor: Passing color with 0~255 RGB not 0~1.

* test: Test color converter by assert_array_equal.

* docs: Add constraints to rank_plot example.

* fix: Delete debug message in test.

Co-authored-by: Shinichi Hemmi <[email protected]>

---------

Co-authored-by: Shinichi Hemmi <[email protected]>
  • Loading branch information
ryota717 and Alnusjaponica authored Sep 26, 2023
1 parent a1b43b2 commit 52465d9
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 38 deletions.
51 changes: 37 additions & 14 deletions optuna/visualization/_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from optuna._experimental import experimental_func
from optuna.logging import get_logger
from optuna.samplers._base import _CONSTRAINTS_KEY
from optuna.study import Study
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
Expand Down Expand Up @@ -45,15 +46,15 @@ class _RankSubplotInfo(NamedTuple):
ys: list[Any]
trials: list[FrozenTrial]
zs: np.ndarray
color_idxs: np.ndarray
colors: np.ndarray


class _RankPlotInfo(NamedTuple):
params: list[str]
sub_plot_infos: list[list[_RankSubplotInfo]]
target_name: str
zs: np.ndarray
color_idxs: np.ndarray
colors: np.ndarray
has_custom_target: bool


Expand Down Expand Up @@ -81,10 +82,18 @@ def plot_rank(
def objective(trial):
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
c0 = 400 - (x + y)**2
trial.set_user_attr("constraint", [c0])
return x ** 2 + y
sampler = optuna.samplers.TPESampler(seed=10)
def constraints(trial):
return trial.user_attrs["constraint"]
sampler = optuna.samplers.TPESampler(seed=10, constraints_func=constraints)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=30)
Expand Down Expand Up @@ -160,16 +169,18 @@ def target(trial: FrozenTrial) -> float:
target_values = np.array([target(trial) for trial in trials])
raw_ranks = _get_order_with_same_order_averaging(target_values)
color_idxs = raw_ranks / (len(trials) - 1) if len(trials) >= 2 else np.array([0.5])
colors = _convert_color_idxs_to_scaled_rgb_colors(color_idxs)

sub_plot_infos: list[list[_RankSubplotInfo]]
if len(params) == 2:
x_param = params[0]
y_param = params[1]
sub_plot_info = _get_rank_subplot_info(trials, target_values, color_idxs, x_param, y_param)
sub_plot_info = _get_rank_subplot_info(trials, target_values, colors, x_param, y_param)
sub_plot_infos = [[sub_plot_info]]
else:
sub_plot_infos = [
[
_get_rank_subplot_info(trials, target_values, color_idxs, x_param, y_param)
_get_rank_subplot_info(trials, target_values, colors, x_param, y_param)
for x_param in params
]
for y_param in params
Expand All @@ -180,21 +191,29 @@ def target(trial: FrozenTrial) -> float:
sub_plot_infos=sub_plot_infos,
target_name=target_name,
zs=target_values,
color_idxs=color_idxs,
colors=colors,
has_custom_target=has_custom_target,
)


def _get_rank_subplot_info(
trials: list[FrozenTrial],
target_values: np.ndarray,
color_idxs: np.ndarray,
colors: np.ndarray,
x_param: str,
y_param: str,
) -> _RankSubplotInfo:
xaxis = _get_axis_info(trials, x_param)
yaxis = _get_axis_info(trials, y_param)

infeasible_trial_ids = []
for i in range(len(trials)):
constraints = trials[i].system_attrs.get(_CONSTRAINTS_KEY)
if constraints is not None and any([x > 0.0 for x in constraints]):
infeasible_trial_ids.append(i)

colors[infeasible_trial_ids] = plotly.colors.hex_to_rgb("#cccccc")

filtered_ids = [
i
for i in range(len(trials))
Expand All @@ -204,15 +223,15 @@ def _get_rank_subplot_info(
xs = [trial.params[x_param] for trial in filtered_trials]
ys = [trial.params[y_param] for trial in filtered_trials]
zs = target_values[filtered_ids]
color_idxs = color_idxs[filtered_ids]
colors = colors[filtered_ids]
return _RankSubplotInfo(
xaxis=xaxis,
yaxis=yaxis,
xs=xs,
ys=ys,
trials=filtered_trials,
zs=np.array(zs),
color_idxs=np.array(color_idxs),
colors=colors,
)


Expand Down Expand Up @@ -269,10 +288,6 @@ def _get_axis_info(trials: list[FrozenTrial], param_name: str) -> _AxisInfo:
def _get_rank_subplot(
info: _RankSubplotInfo, target_name: str, print_raw_objectives: bool
) -> "Scatter":
colormap = "RdYlBu_r"
# sample_colorscale requires plotly >= 5.0.0.
colors = plotly.colors.sample_colorscale(colormap, info.color_idxs)

def get_hover_text(trial: FrozenTrial, target_value: float) -> str:
lines = [f"Trial #{trial.number}"]
lines += [f"{k}: {v}" for k, v in trial.params.items()]
Expand All @@ -285,7 +300,7 @@ def get_hover_text(trial: FrozenTrial, target_value: float) -> str:
x=info.xs,
y=info.ys,
marker={
"color": colors,
"color": list(map(plotly.colors.label_rgb, info.colors)),
"line": {"width": 0.5, "color": "Grey"},
},
mode="markers",
Expand Down Expand Up @@ -403,3 +418,11 @@ def _get_rank_plot(
)
figure.add_trace(colorbar_trace)
return figure


def _convert_color_idxs_to_scaled_rgb_colors(color_idxs: np.ndarray) -> np.ndarray:
colormap = "RdYlBu_r"
# sample_colorscale requires plotly >= 5.0.0.
labeled_colors = plotly.colors.sample_colorscale(colormap, color_idxs)
scaled_rgb_colors = np.array([plotly.colors.unlabel_rgb(cl) for cl in labeled_colors])
return scaled_rgb_colors
17 changes: 12 additions & 5 deletions optuna/visualization/matplotlib/_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,18 @@ def plot_rank(
def objective(trial):
x = trial.suggest_float("x", -100, 100)
y = trial.suggest_categorical("y", [-1, 0, 1])
c0 = 400 - (x + y)**2
trial.set_user_attr("constraint", [c0])
return x ** 2 + y
sampler = optuna.samplers.TPESampler(seed=10)
def constraints(trial):
return trial.user_attrs["constraint"]
sampler = optuna.samplers.TPESampler(seed=10, constraints_func=constraints)
study = optuna.create_study(sampler=sampler)
study.optimize(objective, n_trials=30)
Expand Down Expand Up @@ -126,7 +134,8 @@ def _get_rank_plot(

tick_info = _get_tick_info(info.zs)

cbar = fig.colorbar(pc, ax=axs, ticks=tick_info.coloridxs, cmap=plt.get_cmap("RdYlBu_r"))
pc.set_cmap(plt.get_cmap("RdYlBu_r"))
cbar = fig.colorbar(pc, ax=axs, ticks=tick_info.coloridxs)
cbar.ax.set_yticklabels(tick_info.text)
cbar.outline.set_edgecolor("gray")
return axs
Expand All @@ -151,6 +160,4 @@ def _add_rank_subplot(
if info.yaxis.is_log:
ax.set_yscale("log")

return ax.scatter(
x=info.xs, y=info.ys, c=info.color_idxs, cmap=plt.get_cmap("RdYlBu_r"), edgecolors="grey"
)
return ax.scatter(x=info.xs, y=info.ys, c=info.colors / 255, edgecolors="grey")
Loading

0 comments on commit 52465d9

Please sign in to comment.