diff --git a/pandas/plotting/_matplotlib/boxplot.py b/pandas/plotting/_matplotlib/boxplot.py index 5fcea796a9c6e..8027cc76a9a40 100644 --- a/pandas/plotting/_matplotlib/boxplot.py +++ b/pandas/plotting/_matplotlib/boxplot.py @@ -35,6 +35,7 @@ from collections.abc import Collection from matplotlib.axes import Axes + from matplotlib.figure import Figure from matplotlib.lines import Line2D from pandas._typing import MatplotlibColor @@ -177,7 +178,7 @@ def maybe_color_bp(self, bp) -> None: if not self.kwds.get("capprops"): setp(bp["caps"], color=caps, alpha=1) - def _make_plot(self) -> None: + def _make_plot(self, fig: Figure) -> None: if self.subplots: self._return_obj = pd.Series(dtype=object) diff --git a/pandas/plotting/_matplotlib/core.py b/pandas/plotting/_matplotlib/core.py index d88605db60720..fa5b7c906355c 100644 --- a/pandas/plotting/_matplotlib/core.py +++ b/pandas/plotting/_matplotlib/core.py @@ -80,6 +80,7 @@ from matplotlib.artist import Artist from matplotlib.axes import Axes from matplotlib.axis import Axis + from matplotlib.figure import Figure from pandas._typing import ( IndexLabel, @@ -241,7 +242,8 @@ def __init__( self.stacked = kwds.pop("stacked", False) self.ax = ax - self.fig = fig + # TODO: deprecate fig keyword as it is ignored, not passed in tests + # as of 2023-11-05 self.axes = np.array([], dtype=object) # "real" version get set in `generate` # parse errorbar input if given @@ -449,11 +451,11 @@ def draw(self) -> None: def generate(self) -> None: self._args_adjust() self._compute_plot_data() - self._setup_subplots() - self._make_plot() + fig = self._setup_subplots() + self._make_plot(fig) self._add_table() self._make_legend() - self._adorn_subplots() + self._adorn_subplots(fig) for ax in self.axes: self._post_plot_logic_common(ax, self.data) @@ -495,7 +497,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num): new_ax.set_yscale("symlog") return new_ax - def _setup_subplots(self): + def _setup_subplots(self) -> Figure: if self.subplots: naxes = ( self.nseries if isinstance(self.subplots, bool) else len(self.subplots) @@ -538,8 +540,8 @@ def _setup_subplots(self): elif self.logy == "sym" or self.loglog == "sym": [a.set_yscale("symlog") for a in axes] - self.fig = fig self.axes = axes + return fig @property def result(self): @@ -637,7 +639,7 @@ def _compute_plot_data(self): self.data = numeric_data.apply(self._convert_to_ndarray) - def _make_plot(self): + def _make_plot(self, fig: Figure): raise AbstractMethodError(self) def _add_table(self) -> None: @@ -672,11 +674,11 @@ def _post_plot_logic_common(self, ax, data): def _post_plot_logic(self, ax, data) -> None: """Post process for each axes. Overridden in child classes""" - def _adorn_subplots(self): + def _adorn_subplots(self, fig: Figure): """Common post process unrelated to data""" if len(self.axes) > 0: - all_axes = self._get_subplots() - nrows, ncols = self._get_axes_layout() + all_axes = self._get_subplots(fig) + nrows, ncols = self._get_axes_layout(fig) handle_shared_axes( axarr=all_axes, nplots=len(all_axes), @@ -723,7 +725,7 @@ def _adorn_subplots(self): for ax, title in zip(self.axes, self.title): ax.set_title(title) else: - self.fig.suptitle(self.title) + fig.suptitle(self.title) else: if is_list_like(self.title): msg = ( @@ -1114,17 +1116,17 @@ def _get_errorbars( errors[kw] = err return errors - def _get_subplots(self): + def _get_subplots(self, fig: Figure): from matplotlib.axes import Subplot return [ ax - for ax in self.fig.get_axes() + for ax in fig.get_axes() if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None) ] - def _get_axes_layout(self) -> tuple[int, int]: - axes = self._get_subplots() + def _get_axes_layout(self, fig: Figure) -> tuple[int, int]: + axes = self._get_subplots(fig) x_set = set() y_set = set() for ax in axes: @@ -1172,7 +1174,7 @@ def _post_plot_logic(self, ax: Axes, data) -> None: ax.set_xlabel(xlabel) ax.set_ylabel(ylabel) - def _plot_colorbar(self, ax: Axes, **kwds): + def _plot_colorbar(self, ax: Axes, *, fig: Figure, **kwds): # Addresses issues #10611 and #10678: # When plotting scatterplots and hexbinplots in IPython # inline backend the colorbar axis height tends not to @@ -1189,7 +1191,7 @@ def _plot_colorbar(self, ax: Axes, **kwds): # use the last one which contains the latest information # about the ax img = ax.collections[-1] - return self.fig.colorbar(img, ax=ax, **kwds) + return fig.colorbar(img, ax=ax, **kwds) class ScatterPlot(PlanePlot): @@ -1209,7 +1211,7 @@ def __init__(self, data, x, y, s=None, c=None, **kwargs) -> None: c = self.data.columns[c] self.c = c - def _make_plot(self): + def _make_plot(self, fig: Figure): x, y, c, data = self.x, self.y, self.c, self.data ax = self.axes[0] @@ -1274,7 +1276,7 @@ def _make_plot(self): ) if cb: cbar_label = c if c_is_column else "" - cbar = self._plot_colorbar(ax, label=cbar_label) + cbar = self._plot_colorbar(ax, fig=fig, label=cbar_label) if color_by_categorical: cbar.set_ticks(np.linspace(0.5, n_cats - 0.5, n_cats)) cbar.ax.set_yticklabels(self.data[c].cat.categories) @@ -1306,7 +1308,7 @@ def __init__(self, data, x, y, C=None, **kwargs) -> None: C = self.data.columns[C] self.C = C - def _make_plot(self) -> None: + def _make_plot(self, fig: Figure) -> None: x, y, data, C = self.x, self.y, self.data, self.C ax = self.axes[0] # pandas uses colormap, matplotlib uses cmap. @@ -1321,7 +1323,7 @@ def _make_plot(self) -> None: ax.hexbin(data[x].values, data[y].values, C=c_values, cmap=cmap, **self.kwds) if cb: - self._plot_colorbar(ax) + self._plot_colorbar(ax, fig=fig) def _make_legend(self) -> None: pass @@ -1358,7 +1360,7 @@ def _is_ts_plot(self) -> bool: def _use_dynamic_x(self): return use_dynamic_x(self._get_ax(0), self.data) - def _make_plot(self) -> None: + def _make_plot(self, fig: Figure) -> None: if self._is_ts_plot(): data = maybe_convert_index(self._get_ax(0), self.data) @@ -1680,7 +1682,7 @@ def _plot( # type: ignore[override] def _start_base(self): return self.bottom - def _make_plot(self) -> None: + def _make_plot(self, fig: Figure) -> None: colors = self._get_colors() ncolors = len(colors) @@ -1842,7 +1844,7 @@ def _args_adjust(self) -> None: def _validate_color_args(self) -> None: pass - def _make_plot(self) -> None: + def _make_plot(self, fig: Figure) -> None: colors = self._get_colors(num_colors=len(self.data), color_kwds="colors") self.kwds.setdefault("colors", colors) diff --git a/pandas/plotting/_matplotlib/hist.py b/pandas/plotting/_matplotlib/hist.py index 076b95a885d5e..0a2d158d5f3e8 100644 --- a/pandas/plotting/_matplotlib/hist.py +++ b/pandas/plotting/_matplotlib/hist.py @@ -39,6 +39,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes + from matplotlib.figure import Figure from pandas._typing import PlottingOrientation @@ -113,7 +114,7 @@ def _plot( # type: ignore[override] cls._update_stacker(ax, stacking_id, n) return patches - def _make_plot(self) -> None: + def _make_plot(self, fig: Figure) -> None: colors = self._get_colors() stacking_id = self._get_stacking_id()