Skip to content

Commit

Permalink
REF: make plotting less stateful (2) (#55850)
Browse files Browse the repository at this point in the history
* REF: validate_sharex

* REF: make plotting less stateful (2)

* pylint ignore

* mypy fixup
  • Loading branch information
jbrockmendel authored Nov 7, 2023
1 parent 8c52003 commit 1ce10d6
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 57 deletions.
1 change: 0 additions & 1 deletion pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def __init__(self, data, return_type: str = "axes", **kwargs) -> None:
# Do not call LinePlot.__init__ which may fill nan
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called

def _args_adjust(self) -> None:
if self.subplots:
# Disable label ax sharing. Otherwise, all subplots shows last
# column label
Expand Down
64 changes: 25 additions & 39 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pandas.core.dtypes.common import (
is_any_real_numeric_dtype,
is_bool,
is_float,
is_float_dtype,
is_hashable,
Expand Down Expand Up @@ -130,7 +131,7 @@ def __init__(
kind=None,
by: IndexLabel | None = None,
subplots: bool | Sequence[Sequence[str]] = False,
sharex=None,
sharex: bool | None = None,
sharey: bool = False,
use_index: bool = True,
figsize: tuple[float, float] | None = None,
Expand Down Expand Up @@ -191,17 +192,7 @@ def __init__(

self.subplots = self._validate_subplots_kwarg(subplots)

if sharex is None:
# if by is defined, subplots are used and sharex should be False
if ax is None and by is None:
self.sharex = True
else:
# if we get an axis, the users should do the visibility
# setting...
self.sharex = False
else:
self.sharex = sharex

self.sharex = self._validate_sharex(sharex, ax, by)
self.sharey = sharey
self.figsize = figsize
self.layout = layout
Expand Down Expand Up @@ -275,6 +266,20 @@ def __init__(

self._validate_color_args()

@final
def _validate_sharex(self, sharex: bool | None, ax, by) -> bool:
if sharex is None:
# if by is defined, subplots are used and sharex should be False
if ax is None and by is None: # pylint: disable=simplifiable-if-statement
sharex = True
else:
# if we get an axis, the users should do the visibility
# setting...
sharex = False
elif not is_bool(sharex):
raise TypeError("sharex must be a bool or None")
return bool(sharex)

@final
def _validate_subplots_kwarg(
self, subplots: bool | Sequence[Sequence[str]]
Expand Down Expand Up @@ -454,7 +459,6 @@ def draw(self) -> None:

@final
def generate(self) -> None:
self._args_adjust()
self._compute_plot_data()
fig = self._setup_subplots()
self._make_plot(fig)
Expand All @@ -466,10 +470,6 @@ def generate(self) -> None:
self._post_plot_logic_common(ax, self.data)
self._post_plot_logic(ax, self.data)

@abstractmethod
def _args_adjust(self) -> None:
pass

@final
def _has_plotted_object(self, ax: Axes) -> bool:
"""check whether ax has data"""
Expand Down Expand Up @@ -1326,9 +1326,6 @@ def _make_plot(self, fig: Figure):
err_kwds["ecolor"] = scatter.get_facecolor()[0]
ax.errorbar(data[x].values, data[y].values, linestyle="none", **err_kwds)

def _args_adjust(self) -> None:
pass


class HexBinPlot(PlanePlot):
@property
Expand Down Expand Up @@ -1361,9 +1358,6 @@ def _make_plot(self, fig: Figure) -> None:
def _make_legend(self) -> None:
pass

def _args_adjust(self) -> None:
pass


class LinePlot(MPLPlot):
_default_rot = 0
Expand Down Expand Up @@ -1529,9 +1523,6 @@ def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
elif (values <= 0).all():
ax._stacker_neg_prior[stacking_id] += values

def _args_adjust(self) -> None:
pass

def _post_plot_logic(self, ax: Axes, data) -> None:
from matplotlib.ticker import FixedLocator

Expand Down Expand Up @@ -1641,9 +1632,6 @@ def _plot( # type: ignore[override]
res = [rect]
return res

def _args_adjust(self) -> None:
pass

def _post_plot_logic(self, ax: Axes, data) -> None:
LinePlot._post_plot_logic(self, ax, data)

Expand Down Expand Up @@ -1676,8 +1664,14 @@ def __init__(self, data, **kwargs) -> None:
kwargs.setdefault("align", "center")
self.tick_pos = np.arange(len(data))

self.bottom = kwargs.pop("bottom", 0)
self.left = kwargs.pop("left", 0)
bottom = kwargs.pop("bottom", 0)
left = kwargs.pop("left", 0)
if is_list_like(bottom):
bottom = np.array(bottom)
if is_list_like(left):
left = np.array(left)
self.bottom = bottom
self.left = left

self.log = kwargs.pop("log", False)
MPLPlot.__init__(self, data, **kwargs)
Expand All @@ -1698,12 +1692,6 @@ def __init__(self, data, **kwargs) -> None:

self.ax_pos = self.tick_pos - self.tickoffset

def _args_adjust(self) -> None:
if is_list_like(self.bottom):
self.bottom = np.array(self.bottom)
if is_list_like(self.left):
self.left = np.array(self.left)

# error: Signature of "_plot" incompatible with supertype "MPLPlot"
@classmethod
def _plot( # type: ignore[override]
Expand Down Expand Up @@ -1874,8 +1862,6 @@ def __init__(self, data, kind=None, **kwargs) -> None:
if (data < 0).any().any():
raise ValueError(f"{self._kind} plot doesn't allow negative values")
MPLPlot.__init__(self, data, kind=kind, **kwargs)

def _args_adjust(self) -> None:
self.grid = False
self.logy = False
self.logx = False
Expand Down
29 changes: 12 additions & 17 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,34 @@ def __init__(
bottom: int | np.ndarray = 0,
**kwargs,
) -> None:
self.bins = bins # use mpl default
if is_list_like(bottom):
bottom = np.array(bottom)
self.bottom = bottom

self.xlabel = kwargs.get("xlabel")
self.ylabel = kwargs.get("ylabel")
# Do not call LinePlot.__init__ which may fill nan
MPLPlot.__init__(self, data, **kwargs) # pylint: disable=non-parent-init-called

def _args_adjust(self) -> None:
# calculate bin number separately in different subplots
# where subplots are created based on by argument
if is_integer(self.bins):
self.bins = self._adjust_bins(bins)

def _adjust_bins(self, bins: int | np.ndarray | list[np.ndarray]):
if is_integer(bins):
if self.by is not None:
by_modified = unpack_single_str_list(self.by)
grouped = self.data.groupby(by_modified)[self.columns]
self.bins = [self._calculate_bins(group) for key, group in grouped]
bins = [self._calculate_bins(group, bins) for key, group in grouped]
else:
self.bins = self._calculate_bins(self.data)

if is_list_like(self.bottom):
self.bottom = np.array(self.bottom)
bins = self._calculate_bins(self.data, bins)
return bins

def _calculate_bins(self, data: DataFrame) -> np.ndarray:
def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
"""Calculate bins given data"""
nd_values = data.infer_objects(copy=False)._get_numeric_data()
values = np.ravel(nd_values)
values = values[~isna(values)]

hist, bins = np.histogram(
values, bins=self.bins, range=self.kwds.get("range", None)
)
hist, bins = np.histogram(values, bins=bins, range=self.kwds.get("range", None))
return bins

# error: Signature of "_plot" incompatible with supertype "LinePlot"
Expand Down Expand Up @@ -211,9 +209,6 @@ def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
self.bw_method = bw_method
self.ind = ind

def _args_adjust(self) -> None:
pass

def _get_ind(self, y):
if self.ind is None:
# np.nanmax() and np.nanmin() ignores the missing values
Expand Down

0 comments on commit 1ce10d6

Please sign in to comment.