Skip to content

Commit

Permalink
More type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mdbenito committed Dec 23, 2024
1 parent 129317b commit d462e8c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
34 changes: 19 additions & 15 deletions src/pydvl/reporting/plots.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import partial
from typing import Any, List, Literal, Optional, OrderedDict, Sequence
from typing import Any, List, Literal, Optional, OrderedDict, Sequence, cast

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -62,9 +62,9 @@ def shaded_mean_std(
ax.fill_between(abscissa, mean - std, mean + std, alpha=0.3, color=shade_color)
ax.plot(abscissa, mean, color=mean_color, **kwargs)

ax.set_title(title)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title or "")
ax.set_xlabel(xlabel or "")
ax.set_ylabel(ylabel or "")

return ax

Expand Down Expand Up @@ -110,9 +110,11 @@ def plot_ci_array(
variances=variances,
counts=np.ones_like(means, dtype=np.int_) * m,
indices=np.arange(n),
data_names=np.array(abscissa, dtype=str)
if abscissa is not None
else np.arange(n, dtype=str),
data_names=(
np.array(abscissa, dtype=str)
if abscissa is not None
else np.arange(n, dtype=str)
),
)

return plot_ci_values(
Expand All @@ -135,7 +137,7 @@ def plot_ci_values(
shade_color: Optional[str] = "lightblue",
ax: Optional[plt.Axes] = None,
**kwargs,
):
) -> plt.Axes:
"""Plot values and a confidence interval.
Uses `values.data_names` for the x-axis.
Expand Down Expand Up @@ -163,9 +165,11 @@ def plot_ci_values(
ppfs = {
"normal": norm.ppf,
"t": partial(t.ppf, df=values.counts - 1),
"auto": norm.ppf
if np.min(values.counts) > 30
else partial(t.ppf, df=values.counts - 1),
"auto": (
norm.ppf
if np.min(values.counts) > 30
else partial(t.ppf, df=values.counts - 1)
),
}

try:
Expand Down Expand Up @@ -264,9 +268,9 @@ def plot_shapley(
yerr = norm.ppf(1 - level / 2) * df[f"{prefix}_stderr"]

ax.errorbar(x=df.index, y=df[prefix], yerr=yerr, fmt="o", capsize=6)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title(title)
ax.set_xlabel(xlabel or "")
ax.set_ylabel(ylabel or "")
ax.set_title(title or "")
plt.xticks(rotation=60)
return ax

Expand All @@ -288,7 +292,7 @@ def plot_influence_distribution(
ax.set_xlabel("Influence values")
ax.set_ylabel("Number of samples")
ax.set_title(f"Distribution of influences {title_extra}")
return ax
return cast(plt.Axes, ax)


def plot_influence_distribution_by_label(
Expand Down
4 changes: 2 additions & 2 deletions src/pydvl/valuation/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,13 @@ def __getitem__(self, idx: int | slice | Sequence[int]) -> Dataset:
y=self._y[idx],
feature_names=self.feature_names,
target_names=self.target_names,
data_names=self._data_names[idx], # type: ignore
data_names=self._data_names[idx],
description="(SLICED): " + self.description,
)

def feature(self, name: str) -> tuple[slice, int]:
try:
return np.index_exp[:, self.feature_names.index(name)]
return np.index_exp[:, self.feature_names.index(name)] # type: ignore
except ValueError:
raise ValueError(f"Feature {name} is not in {self.feature_names}")

Expand Down

0 comments on commit d462e8c

Please sign in to comment.