Skip to content

Commit

Permalink
update labels and label sizes in plots
Browse files Browse the repository at this point in the history
  • Loading branch information
mastoffel committed Nov 27, 2024
1 parent 8c3ddb8 commit 44d7f8b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
22 changes: 14 additions & 8 deletions autoemulate/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def _validate_inputs(cv_results, model_name):
)


def check_multioutput(y, output_index):
def _check_multioutput(y, output_index):
"""Checks if y is multi-output and if the output_index is valid."""
if y.ndim > 1:
if (output_index > y.shape[1] - 1) | (output_index < 0):
raise ValueError(
f"Output index {output_index} is out of range. The index should be between 0 and {y.shape[1] - 1}."
)
print(
f"""Multiple outputs detected. Plotting the output variable with index {output_index}.
f"""Plotting the output variable with index {output_index}.
To plot other outputs, set `output_index` argument to the desired index."""
)

Expand Down Expand Up @@ -148,6 +148,8 @@ def _plot_single_fold(
y_test_std,
ax,
title=f"{model_name} - {title_suffix}",
input_index=input_index,
output_index=output_index,
)
else:
display = PredictionErrorDisplay.from_predictions(
Expand Down Expand Up @@ -334,7 +336,7 @@ def _plot_cv(
"""

_validate_inputs(cv_results, model_name)
check_multioutput(y, output_index)
_check_multioutput(y, output_index)

if model_name:
figure = _plot_model_folds(
Expand Down Expand Up @@ -449,7 +451,9 @@ def _plot_model(
y_pred[:, out_idx],
y_std[:, out_idx] if y_std is not None else None,
ax=axs[plot_index],
title=f"X{in_idx} vs. y{out_idx}",
title=f"$X_{in_idx}$ vs. $y_{out_idx}$",
input_index=in_idx,
output_index=out_idx,
)
plot_index += 1
else:
Expand Down Expand Up @@ -479,7 +483,9 @@ def _plot_model(
return fig


def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
def _plot_Xy(
X, y, y_pred, y_std=None, ax=None, title="Xy", input_index=0, output_index=0
):
"""
Plots observed and predicted values vs. features, including 2σ error bands where available.
"""
Expand Down Expand Up @@ -533,9 +539,9 @@ def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
label="pred.",
)

ax.set_xlabel("X")
ax.set_ylabel("y")
ax.set_title(title)
ax.set_xlabel(f"$X_{input_index}$", fontsize=13)
ax.set_ylabel(f"$y_{output_index}$", fontsize=13)
ax.set_title(title, fontsize=13)
ax.grid(True, alpha=0.3)

# Get the handles and labels for the scatter plots
Expand Down
12 changes: 6 additions & 6 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from autoemulate.compare import AutoEmulate
from autoemulate.emulators import RadialBasisFunctions
from autoemulate.plotting import _check_multioutput
from autoemulate.plotting import _plot_cv
from autoemulate.plotting import _plot_model
from autoemulate.plotting import _plot_single_fold
from autoemulate.plotting import _predict_with_optional_std
from autoemulate.plotting import _validate_inputs
from autoemulate.plotting import check_multioutput


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -72,7 +72,7 @@ def test_check_multioutput_with_single_output():
y = np.array([1, 2, 3, 4, 5])
output_index = 0
try:
check_multioutput(y, output_index)
_check_multioutput(y, output_index)
except ValueError as e:
assert False, f"Unexpected ValueError: {str(e)}"

Expand All @@ -81,7 +81,7 @@ def test_check_multioutput_with_multioutput():
y = np.array([[1, 2, 3], [4, 5, 6]])
output_index = 1
try:
check_multioutput(y, output_index)
_check_multioutput(y, output_index)
except ValueError as e:
assert False, f"Unexpected ValueError: {str(e)}"

Expand All @@ -90,7 +90,7 @@ def test_check_multioutput_with_invalid_output_index():
y = np.array([[1, 2, 3], [4, 5, 6]])
output_index = 3
try:
check_multioutput(y, output_index)
_check_multioutput(y, output_index)
assert False, "Expected ValueError to be raised"
except ValueError as e:
assert (
Expand Down Expand Up @@ -354,7 +354,7 @@ def test__plot_model_int(ae_single_output):
output_index=0,
)
assert isinstance(fig, plt.Figure)
assert fig.axes[0].get_title() == "X0 vs. y0"
assert all(term in fig.axes[0].get_title() for term in ["X", "y", "vs."])


def test__plot_model_list(ae_single_output):
Expand All @@ -367,7 +367,7 @@ def test__plot_model_list(ae_single_output):
output_index=[0],
)
assert isinstance(fig, plt.Figure)
assert fig.axes[1].get_title() == "X1 vs. y0"
assert all(term in fig.axes[1].get_title() for term in ["X", "y", "vs."])


def test__plot_model_int_out_of_range(ae_single_output):
Expand Down

0 comments on commit 44d7f8b

Please sign in to comment.