Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: make plotting less stateful (2) #55850

Merged
merged 5 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
# Do not call LinePlot.__init__ which may fill nan
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called

def _args_adjust(self) -> None:
if self.subplots:
# Disable label ax sharing. Otherwise, all subplots shows last
# column label
Expand Down
64 changes: 25 additions & 39 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pandas.core.dtypes.common import (
is_any_real_numeric_dtype,
is_bool,
is_float,
is_float_dtype,
is_hashable,
Expand Down Expand Up @@ -130,7 +131,7 @@ def __init__(
kind=None,
by: IndexLabel | None = None,
subplots: bool | Sequence[Sequence[str]] = False,
sharex=None,
sharex: bool | None = None,
sharey: bool = False,
use_index: bool = True,
figsize: tuple[float, float] | None = None,
Expand Down Expand Up @@ -191,17 +192,7 @@ def __init__(

self.subplots = self._validate_subplots_kwarg(subplots)

if sharex is None:
# if by is defined, subplots are used and sharex should be False
if ax is None and by is None:
self.sharex = True
else:
# if we get an axis, the users should do the visibility
# setting...
self.sharex = False
else:
self.sharex = sharex

self.sharex = self._validate_sharex(sharex, ax, by)
self.sharey = sharey
self.figsize = figsize
self.layout = layout
Expand Down Expand Up @@ -275,6 +266,20 @@ def __init__(

self._validate_color_args()

@final
def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
if sharex is None:
# if by is defined, subplots are used and sharex should be False
if ax is None and by is None: # pylint: disable=simplifiable-if-statement
sharex = True
else:
# if we get an axis, the users should do the visibility
# setting...
sharex = False
elif not is_bool(sharex):
raise TypeError("sharex must be a bool or None")
return bool(sharex)

@final
def _validate_subplots_kwarg(
self, subplots: bool | Sequence[Sequence[str]]
Expand Down Expand Up @@ -454,7 +459,6 @@ def draw(self) -> None:

@final
def generate(self) -> None:
self._args_adjust()
self._compute_plot_data()
fig = self._setup_subplots()
self._make_plot(fig)
Expand All @@ -466,10 +470,6 @@ def generate(self) -> None:
self._post_plot_logic_common(ax, self.data)
self._post_plot_logic(ax, self.data)

@abstractmethod
def _args_adjust(self) -> None:
pass

@final
def _has_plotted_object(self, ax: Axes) -> bool:
"""check whether ax has data"""
Expand Down Expand Up @@ -1326,9 +1326,6 @@ def _make_plot(self, fig: Figure):
err_kwds["ecolor"] = scatter.get_facecolor()[0]
ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)

def _args_adjust(self) -> None:
pass


class HexBinPlot(PlanePlot):
@property
Expand Down Expand Up @@ -1361,9 +1358,6 @@ def _make_plot(self, fig: Figure) -> None:
def _make_legend(self) -> None:
pass

def _args_adjust(self) -> None:
pass


class LinePlot(MPLPlot):
_default_rot = 0
Expand Down Expand Up @@ -1529,9 +1523,6 @@ def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
elif (values <= 0).all():
ax._stacker_neg_prior[stacking_id] += values

def _args_adjust(self) -> None:
pass

def _post_plot_logic(self, ax: Axes, data) -> None:
from matplotlib.ticker import FixedLocator

Expand Down Expand Up @@ -1641,9 +1632,6 @@ def _plot( # type: ignore[override]
res = [rect]
return res

def _args_adjust(self) -> None:
pass

def _post_plot_logic(self, ax: Axes, data) -> None:
LinePlot._post_plot_logic(self, ax, data)

Expand Down Expand Up @@ -1676,8 +1664,14 @@ def __init__(self, data, **kwargs) -> None:
kwargs.setdefault("align", "center")
self.tick_pos = np.arange(len(data))

self.bottom = kwargs.pop("bottom", 0)
self.left = kwargs.pop("left", 0)
bottom = kwargs.pop("bottom", 0)
left = kwargs.pop("left", 0)
if is_list_like(bottom):
bottom = np.array(bottom)
if is_list_like(left):
left = np.array(left)
self.bottom = bottom
self.left = left

self.log = kwargs.pop("log", False)
MPLPlot.__init__(self, data, **kwargs)
Expand All @@ -1698,12 +1692,6 @@ def __init__(self, data, **kwargs) -> None:

self.ax_pos = self.tick_pos - self.tickoffset

def _args_adjust(self) -> None:
if is_list_like(self.bottom):
self.bottom = np.array(self.bottom)
if is_list_like(self.left):
self.left = np.array(self.left)

# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
Expand Down Expand Up @@ -1874,8 +1862,6 @@ def __init__(self, data, kind=None, **kwargs) -> None:
if (data < 0).any().any():
raise ValueError(f"{self._kind} plot doesn't allow negative values")
MPLPlot.__init__(self, data, kind=kind, **kwargs)

def _args_adjust(self) -> None:
self.grid = False
self.logy = False
self.logx = False
Expand Down
29 changes: 12 additions & 17 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,34 @@ def __init__(
bottom: int | np.ndarray = 0,
**kwargs,
) -> None:
self.bins = bins # use mpl default
if is_list_like(bottom):
bottom = np.array(bottom)
self.bottom = bottom

self.xlabel = kwargs.get("xlabel")
self.ylabel = kwargs.get("ylabel")
# Do not call LinePlot.__init__ which may fill nan
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called

def _args_adjust(self) -> None:
# calculate bin number separately in different subplots
# where subplots are created based on by argument
if is_integer(self.bins):
self.bins = self._adjust_bins(bins)

def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
if is_integer(bins):
if self.by is not None:
by_modified = unpack_single_str_list(self.by)
grouped = self.data.groupby(by_modified)[self.columns]
self.bins = [self._calculate_bins(group) for key, group in grouped]
bins = [self._calculate_bins(group, bins) for key, group in grouped]
else:
self.bins = self._calculate_bins(self.data)

if is_list_like(self.bottom):
self.bottom = np.array(self.bottom)
bins = self._calculate_bins(self.data, bins)
return bins

def _calculate_bins(self, data: DataFrame) -> np.ndarray:
def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
"""Calculate bins given data"""
nd_values = data.infer_objects(copy=False)._get_numeric_data()
values = np.ravel(nd_values)
values = values[~isna(values)]

hist, bins = np.histogram(
values, bins=self.bins, range=self.kwds.get("range", None)
)
hist, bins = np.histogram(values, bins=bins, range=self.kwds.get("range", None))
return bins

# error: Signature of "_plot" incompatible with supertype "LinePlot"
Expand Down Expand Up @@ -211,9 +209,6 @@ def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
self.bw_method = bw_method
self.ind = ind

def _args_adjust(self) -> None:
pass

def _get_ind(self, y):
if self.ind is None:
# np.nanmax() and np.nanmin() ignores the missing values
Expand Down
Loading