Skip to content

Commit

Permalink
TYP: towards matplotlib 3.8 (#55253)
Browse files Browse the repository at this point in the history
* TYP: towards matplotlib 3.8

* test 3.8

* ignore pyright errors

* merging error

* Conditional on import

* Disable parallel build to see docbuild error

* Some unnecessary ignores

* Add typing in test_sql

* type ignores

* Multiple ignores

* Uncomment

---------

Co-authored-by: Matthew Roeschke <[email protected]>
  • Loading branch information
twoertwein and mroeschke authored Nov 15, 2023
1 parent 800ae25 commit d999aac
Show file tree
Hide file tree
Showing 16 changed files with 149 additions and 62 deletions.
2 changes: 1 addition & 1 deletion ci/deps/actions-310.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- gcsfs>=2022.11.0
- jinja2>=3.1.2
- lxml>=4.9.2
- matplotlib>=3.6.3, <3.8
- matplotlib>=3.6.3
- numba>=0.56.4
- numexpr>=2.8.4
- odfpy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/actions-311-downstream_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies:
- gcsfs>=2022.11.0
- jinja2>=3.1.2
- lxml>=4.9.2
- matplotlib>=3.6.3, <3.8
- matplotlib>=3.6.3
- numba>=0.56.4
- numexpr>=2.8.4
- odfpy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/actions-311.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- gcsfs>=2022.11.0
- jinja2>=3.1.2
- lxml>=4.9.2
- matplotlib>=3.6.3, <3.8
- matplotlib>=3.6.3
- numba>=0.56.4
- numexpr>=2.8.4
- odfpy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/actions-39.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- gcsfs>=2022.11.0
- jinja2>=3.1.2
- lxml>=4.9.2
- matplotlib>=3.6.3, <3.8
- matplotlib>=3.6.3
- numba>=0.56.4
- numexpr>=2.8.4
- odfpy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/circle-310-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ dependencies:
- gcsfs>=2022.11.0
- jinja2>=3.1.2
- lxml>=4.9.2
- matplotlib>=3.6.3, <3.8
- matplotlib>=3.6.3
- numba>=0.56.4
- numexpr>=2.8.4
- odfpy>=1.4.1
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies:
- ipython
- jinja2>=3.1.2
- lxml>=4.9.2
- matplotlib>=3.6.3, <3.8
- matplotlib>=3.6.3
- numba>=0.56.4
- numexpr>=2.8.4
- openpyxl>=3.1.0
Expand Down
28 changes: 22 additions & 6 deletions pandas/plotting/_matplotlib/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
if TYPE_CHECKING:
from collections.abc import Generator

from matplotlib.axis import Axis

from pandas._libs.tslibs.offsets import BaseOffset


Expand Down Expand Up @@ -187,7 +189,7 @@ class TimeFormatter(Formatter):
def __init__(self, locs) -> None:
self.locs = locs

def __call__(self, x, pos: int = 0) -> str:
def __call__(self, x, pos: int | None = 0) -> str:
"""
Return the time of day as a formatted string.
Expand Down Expand Up @@ -364,8 +366,14 @@ def get_locator(self, dmin, dmax):
locator = MilliSecondLocator(self.tz)
locator.set_axis(self.axis)

locator.axis.set_view_interval(*self.axis.get_view_interval())
locator.axis.set_data_interval(*self.axis.get_data_interval())
# error: Item "None" of "Axis | _DummyAxis | _AxisWrapper | None"
# has no attribute "get_data_interval"
locator.axis.set_view_interval( # type: ignore[union-attr]
*self.axis.get_view_interval() # type: ignore[union-attr]
)
locator.axis.set_data_interval( # type: ignore[union-attr]
*self.axis.get_data_interval() # type: ignore[union-attr]
)
return locator

return mdates.AutoDateLocator.get_locator(self, dmin, dmax)
Expand Down Expand Up @@ -950,6 +958,8 @@ class TimeSeries_DateLocator(Locator):
day : {int}, optional
"""

axis: Axis

def __init__(
self,
freq: BaseOffset,
Expand Down Expand Up @@ -999,7 +1009,9 @@ def __call__(self):
base = self.base
(d, m) = divmod(vmin, base)
vmin = (d + 1) * base
locs = list(range(vmin, vmax + 1, base))
# error: No overload variant of "range" matches argument types "float",
# "float", "int"
locs = list(range(vmin, vmax + 1, base)) # type: ignore[call-overload]
return locs

def autoscale(self):
Expand Down Expand Up @@ -1038,6 +1050,8 @@ class TimeSeries_DateFormatter(Formatter):
Whether the formatter works in dynamic mode or not.
"""

axis: Axis

def __init__(
self,
freq: BaseOffset,
Expand Down Expand Up @@ -1084,7 +1098,7 @@ def set_locs(self, locs) -> None:
(vmin, vmax) = (vmax, vmin)
self._set_default_format(vmin, vmax)

def __call__(self, x, pos: int = 0) -> str:
def __call__(self, x, pos: int | None = 0) -> str:
if self.formatdict is None:
return ""
else:
Expand All @@ -1107,6 +1121,8 @@ class TimeSeries_TimedeltaFormatter(Formatter):
Formats the ticks along an axis controlled by a :class:`TimedeltaIndex`.
"""

axis: Axis

@staticmethod
def format_timedelta_ticks(x, pos, n_decimals: int) -> str:
"""
Expand All @@ -1124,7 +1140,7 @@ def format_timedelta_ticks(x, pos, n_decimals: int) -> str:
s = f"{int(d):d} days {s}"
return s

def __call__(self, x, pos: int = 0) -> str:
def __call__(self, x, pos: int | None = 0) -> str:
(vmin, vmax) = tuple(self.axis.get_view_interval())
n_decimals = min(int(np.ceil(np.log10(100 * 10**9 / abs(vmax - vmin)))), 9)
return self.format_timedelta_ticks(x, pos, n_decimals)
82 changes: 57 additions & 25 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,9 +529,16 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num: int) -> Axes:
# otherwise, create twin axes
orig_ax, new_ax = ax, ax.twinx()
# TODO: use Matplotlib public API when available
new_ax._get_lines = orig_ax._get_lines
new_ax._get_patches_for_fill = orig_ax._get_patches_for_fill
orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
new_ax._get_lines = orig_ax._get_lines # type: ignore[attr-defined]
# TODO #54485
new_ax._get_patches_for_fill = ( # type: ignore[attr-defined]
orig_ax._get_patches_for_fill # type: ignore[attr-defined]
)
# TODO #54485
orig_ax.right_ax, new_ax.left_ax = ( # type: ignore[attr-defined]
new_ax,
orig_ax,
)

if not self._has_plotted_object(orig_ax): # no data on left y
orig_ax.get_yaxis().set_visible(False)
Expand All @@ -540,7 +547,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num: int) -> Axes:
new_ax.set_yscale("log")
elif self.logy == "sym" or self.loglog == "sym":
new_ax.set_yscale("symlog")
return new_ax
return new_ax # type: ignore[return-value]

@final
@cache_readonly
Expand Down Expand Up @@ -1206,12 +1213,15 @@ def _get_errorbars(

@final
def _get_subplots(self, fig: Figure):
from matplotlib.axes import Subplot
if Version(mpl.__version__) < Version("3.8"):
from matplotlib.axes import Subplot as Klass
else:
from matplotlib.axes import Axes as Klass

return [
ax
for ax in fig.get_axes()
if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None)
if (isinstance(ax, Klass) and ax.get_subplotspec() is not None)
]

@final
Expand Down Expand Up @@ -1255,8 +1265,10 @@ def _post_plot_logic(self, ax: Axes, data) -> None:
x, y = self.x, self.y
xlabel = self.xlabel if self.xlabel is not None else pprint_thing(x)
ylabel = self.ylabel if self.ylabel is not None else pprint_thing(y)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
# error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible
# type "Hashable"; expected "str"
ax.set_xlabel(xlabel) # type: ignore[arg-type]
ax.set_ylabel(ylabel) # type: ignore[arg-type]

@final
def _plot_colorbar(self, ax: Axes, *, fig: Figure, **kwds):
Expand Down Expand Up @@ -1393,7 +1405,7 @@ def _get_norm_and_cmap(self, c_values, color_by_categorical: bool):
else:
cmap = None

if color_by_categorical:
if color_by_categorical and cmap is not None:
from matplotlib import colors

n_cats = len(self.data[c].cat.categories)
Expand Down Expand Up @@ -1584,13 +1596,13 @@ def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
decorate_axes(ax.left_ax, freq, kwds)
if hasattr(ax, "right_ax"):
decorate_axes(ax.right_ax, freq, kwds)
ax._plot_data.append((data, self._kind, kwds))
# TODO #54485
ax._plot_data.append((data, self._kind, kwds)) # type: ignore[attr-defined]

lines = self._plot(ax, data.index, np.asarray(data.values), style=style, **kwds)
# set date formatter, locators and rescale limits
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
# expected "DatetimeIndex | PeriodIndex"
format_dateaxis(ax, ax.freq, data.index) # type: ignore[arg-type]
# TODO #54485
format_dateaxis(ax, ax.freq, data.index) # type: ignore[arg-type, attr-defined]
return lines

@final
Expand All @@ -1606,11 +1618,15 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
if stacking_id is None:
return
if not hasattr(ax, "_stacker_pos_prior"):
ax._stacker_pos_prior = {}
# TODO #54485
ax._stacker_pos_prior = {} # type: ignore[attr-defined]
if not hasattr(ax, "_stacker_neg_prior"):
ax._stacker_neg_prior = {}
ax._stacker_pos_prior[stacking_id] = np.zeros(n)
ax._stacker_neg_prior[stacking_id] = np.zeros(n)
# TODO #54485
ax._stacker_neg_prior = {} # type: ignore[attr-defined]
# TODO #54485
ax._stacker_pos_prior[stacking_id] = np.zeros(n) # type: ignore[attr-defined]
# TODO #54485
ax._stacker_neg_prior[stacking_id] = np.zeros(n) # type: ignore[attr-defined]

@final
@classmethod
Expand All @@ -1624,9 +1640,17 @@ def _get_stacked_values(
cls._initialize_stacker(ax, stacking_id, len(values))

if (values >= 0).all():
return ax._stacker_pos_prior[stacking_id] + values
# TODO #54485
return (
ax._stacker_pos_prior[stacking_id] # type: ignore[attr-defined]
+ values
)
elif (values <= 0).all():
return ax._stacker_neg_prior[stacking_id] + values
# TODO #54485
return (
ax._stacker_neg_prior[stacking_id] # type: ignore[attr-defined]
+ values
)

raise ValueError(
"When stacked is True, each column must be either "
Expand All @@ -1640,9 +1664,11 @@ def _update_stacker(cls, ax: Axes, stacking_id: int | None, values) -> None:
if stacking_id is None:
return
if (values >= 0).all():
ax._stacker_pos_prior[stacking_id] += values
# TODO #54485
ax._stacker_pos_prior[stacking_id] += values # type: ignore[attr-defined]
elif (values <= 0).all():
ax._stacker_neg_prior[stacking_id] += values
# TODO #54485
ax._stacker_neg_prior[stacking_id] += values # type: ignore[attr-defined]

def _post_plot_logic(self, ax: Axes, data) -> None:
from matplotlib.ticker import FixedLocator
Expand All @@ -1658,7 +1684,9 @@ def get_label(i):
if self._need_to_set_index:
xticks = ax.get_xticks()
xticklabels = [get_label(x) for x in xticks]
ax.xaxis.set_major_locator(FixedLocator(xticks))
# error: Argument 1 to "FixedLocator" has incompatible type "ndarray[Any,
# Any]"; expected "Sequence[float]"
ax.xaxis.set_major_locator(FixedLocator(xticks)) # type: ignore[arg-type]
ax.set_xticklabels(xticklabels)

# If the index is an irregular time series, then by default
Expand Down Expand Up @@ -1737,9 +1765,11 @@ def _plot( # type: ignore[override]
if stacking_id is None:
start = np.zeros(len(y))
elif (y >= 0).all():
start = ax._stacker_pos_prior[stacking_id]
# TODO #54485
start = ax._stacker_pos_prior[stacking_id] # type: ignore[attr-defined]
elif (y <= 0).all():
start = ax._stacker_neg_prior[stacking_id]
# TODO #54485
start = ax._stacker_neg_prior[stacking_id] # type: ignore[attr-defined]
else:
start = np.zeros(len(y))

Expand Down Expand Up @@ -2005,7 +2035,9 @@ def _decorate_ticks(
ax.set_yticklabels(ticklabels)
if name is not None and self.use_index:
ax.set_ylabel(name)
ax.set_xlabel(self.xlabel)
# error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible type
# "Hashable | None"; expected "str"
ax.set_xlabel(self.xlabel) # type: ignore[arg-type]


class PiePlot(MPLPlot):
Expand Down
26 changes: 21 additions & 5 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,21 @@ def _get_column_weights(weights, i: int, y):

def _post_plot_logic(self, ax: Axes, data) -> None:
if self.orientation == "horizontal":
ax.set_xlabel("Frequency" if self.xlabel is None else self.xlabel)
ax.set_ylabel(self.ylabel)
# error: Argument 1 to "set_xlabel" of "_AxesBase" has incompatible
# type "Hashable"; expected "str"
ax.set_xlabel(
"Frequency"
if self.xlabel is None
else self.xlabel # type: ignore[arg-type]
)
ax.set_ylabel(self.ylabel) # type: ignore[arg-type]
else:
ax.set_xlabel(self.xlabel)
ax.set_ylabel("Frequency" if self.ylabel is None else self.ylabel)
ax.set_xlabel(self.xlabel) # type: ignore[arg-type]
ax.set_ylabel(
"Frequency"
if self.ylabel is None
else self.ylabel # type: ignore[arg-type]
)

@property
def orientation(self) -> PlottingOrientation:
Expand Down Expand Up @@ -447,8 +457,14 @@ def hist_series(
ax.grid(grid)
axes = np.array([ax])

# error: Argument 1 to "set_ticks_props" has incompatible type "ndarray[Any,
# dtype[Any]]"; expected "Axes | Sequence[Axes]"
set_ticks_props(
axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
axes, # type: ignore[arg-type]
xlabelsize=xlabelsize,
xrot=xrot,
ylabelsize=ylabelsize,
yrot=yrot,
)

else:
Expand Down
4 changes: 3 additions & 1 deletion pandas/plotting/_matplotlib/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,9 @@ def _is_single_string_color(color: Color) -> bool:
"""
conv = matplotlib.colors.ColorConverter()
try:
conv.to_rgba(color)
# error: Argument 1 to "to_rgba" of "ColorConverter" has incompatible type
# "str | Sequence[float]"; expected "tuple[float, float, float] | ..."
conv.to_rgba(color) # type: ignore[arg-type]
except ValueError:
return False
else:
Expand Down
Loading

0 comments on commit d999aac

Please sign in to comment.