From d462e8cd876e95ca005257f5ef2ad17cfdabedb8 Mon Sep 17 00:00:00 2001 From: Miguel de Benito Delgado Date: Mon, 23 Dec 2024 15:59:05 +0100 Subject: [PATCH] More type fixes --- src/pydvl/reporting/plots.py | 34 +++++++++++++++++++--------------- src/pydvl/valuation/dataset.py | 4 ++-- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/src/pydvl/reporting/plots.py b/src/pydvl/reporting/plots.py index 147ae1d7a..abe2b0a7f 100644 --- a/src/pydvl/reporting/plots.py +++ b/src/pydvl/reporting/plots.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, List, Literal, Optional, OrderedDict, Sequence +from typing import Any, List, Literal, Optional, OrderedDict, Sequence, cast import matplotlib.pyplot as plt import numpy as np @@ -62,9 +62,9 @@ def shaded_mean_std( ax.fill_between(abscissa, mean - std, mean + std, alpha=0.3, color=shade_color) ax.plot(abscissa, mean, color=mean_color, **kwargs) - ax.set_title(title) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) + ax.set_title(title or "") + ax.set_xlabel(xlabel or "") + ax.set_ylabel(ylabel or "") return ax @@ -110,9 +110,11 @@ def plot_ci_array( variances=variances, counts=np.ones_like(means, dtype=np.int_) * m, indices=np.arange(n), - data_names=np.array(abscissa, dtype=str) - if abscissa is not None - else np.arange(n, dtype=str), + data_names=( + np.array(abscissa, dtype=str) + if abscissa is not None + else np.arange(n, dtype=str) + ), ) return plot_ci_values( @@ -135,7 +137,7 @@ def plot_ci_values( shade_color: Optional[str] = "lightblue", ax: Optional[plt.Axes] = None, **kwargs, -): +) -> plt.Axes: """Plot values and a confidence interval. Uses `values.data_names` for the x-axis. @@ -163,9 +165,11 @@ def plot_ci_values( ppfs = { "normal": norm.ppf, "t": partial(t.ppf, df=values.counts - 1), - "auto": norm.ppf - if np.min(values.counts) > 30 - else partial(t.ppf, df=values.counts - 1), + "auto": ( + norm.ppf + if np.min(values.counts) > 30 + else partial(t.ppf, df=values.counts - 1) + ), } try: @@ -264,9 +268,9 @@ def plot_shapley( yerr = norm.ppf(1 - level / 2) * df[f"{prefix}_stderr"] ax.errorbar(x=df.index, y=df[prefix], yerr=yerr, fmt="o", capsize=6) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - ax.set_title(title) + ax.set_xlabel(xlabel or "") + ax.set_ylabel(ylabel or "") + ax.set_title(title or "") plt.xticks(rotation=60) return ax @@ -288,7 +292,7 @@ def plot_influence_distribution( ax.set_xlabel("Influence values") ax.set_ylabel("Number of samples") ax.set_title(f"Distribution of influences {title_extra}") - return ax + return cast(plt.Axes, ax) def plot_influence_distribution_by_label( diff --git a/src/pydvl/valuation/dataset.py b/src/pydvl/valuation/dataset.py index 2f474baa2..ae821aaf6 100644 --- a/src/pydvl/valuation/dataset.py +++ b/src/pydvl/valuation/dataset.py @@ -152,13 +152,13 @@ def __getitem__(self, idx: int | slice | Sequence[int]) -> Dataset: y=self._y[idx], feature_names=self.feature_names, target_names=self.target_names, - data_names=self._data_names[idx], # type: ignore + data_names=self._data_names[idx], description="(SLICED): " + self.description, ) def feature(self, name: str) -> tuple[slice, int]: try: - return np.index_exp[:, self.feature_names.index(name)] + return np.index_exp[:, self.feature_names.index(name)] # type: ignore except ValueError: raise ValueError(f"Feature {name} is not in {self.feature_names}")