Skip to content

Commit

Permalink
REF: make plotting less stateful (4) (#55872)
Browse files Browse the repository at this point in the history
* REF: make plotting less stateful (4)

* REF: make plotting less stateful (4)

* TYP: plotting

* REF: make plotting less stateful (4)
  • Loading branch information
jbrockmendel authored Nov 7, 2023
1 parent 21fa354 commit 97c61e8
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 22 deletions.
25 changes: 20 additions & 5 deletions pandas/plotting/_matplotlib/boxplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from matplotlib.artist import setp
import numpy as np

from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.common import is_dict_like
Expand Down Expand Up @@ -132,15 +133,29 @@ def _validate_color_args(self):
else:
self.color = None

@cache_readonly
def _color_attrs(self):
# get standard colors for default
colors = get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
# use 2 colors by default, for box/whisker and median
# flier colors isn't needed here
# because it can be specified by ``sym`` kw
self._boxes_c = colors[0]
self._whiskers_c = colors[0]
self._medians_c = colors[2]
self._caps_c = colors[0]
return get_standard_colors(num_colors=3, colormap=self.colormap, color=None)

@cache_readonly
def _boxes_c(self):
return self._color_attrs[0]

@cache_readonly
def _whiskers_c(self):
return self._color_attrs[0]

@cache_readonly
def _medians_c(self):
return self._color_attrs[2]

@cache_readonly
def _caps_c(self):
return self._color_attrs[0]

def _get_colors(
self,
Expand Down
20 changes: 14 additions & 6 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@
npt,
)

from pandas import Series


def _color_in_style(style: str) -> bool:
"""
Expand Down Expand Up @@ -471,7 +473,8 @@ def generate(self) -> None:
self._post_plot_logic(ax, self.data)

@final
def _has_plotted_object(self, ax: Axes) -> bool:
@staticmethod
def _has_plotted_object(ax: Axes) -> bool:
"""check whether ax has data"""
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0

Expand Down Expand Up @@ -576,7 +579,8 @@ def result(self):
return self.axes[0]

@final
def _convert_to_ndarray(self, data):
@staticmethod
def _convert_to_ndarray(data):
# GH31357: categorical columns are processed separately
if isinstance(data.dtype, CategoricalDtype):
return data
Expand Down Expand Up @@ -767,6 +771,7 @@ def _apply_axis_properties(
if fontsize is not None:
label.set_fontsize(fontsize)

@final
@property
def legend_title(self) -> str | None:
if not isinstance(self.data.columns, ABCMultiIndex):
Expand Down Expand Up @@ -836,7 +841,8 @@ def _make_legend(self) -> None:
ax.legend(loc="best")

@final
def _get_ax_legend(self, ax: Axes):
@staticmethod
def _get_ax_legend(ax: Axes):
"""
Take in axes and return ax and legend under different scenarios
"""
Expand Down Expand Up @@ -1454,7 +1460,7 @@ def _plot( # type: ignore[override]
return lines

@final
def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
def _ts_plot(self, ax: Axes, x, data: Series, style=None, **kwds):
# accept x to be consistent with normal plot func,
# x is not passed to tsplot as it uses data.index as x coordinate
# column_num must be in kwds for stacking purpose
Expand All @@ -1471,11 +1477,13 @@ def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):

lines = self._plot(ax, data.index, data.values, style=style, **kwds)
# set date formatter, locators and rescale limits
format_dateaxis(ax, ax.freq, data.index)
# error: Argument 3 to "format_dateaxis" has incompatible type "Index";
# expected "DatetimeIndex | PeriodIndex"
format_dateaxis(ax, ax.freq, data.index) # type: ignore[arg-type]
return lines

@final
def _get_stacking_id(self):
def _get_stacking_id(self) -> int | None:
if self.stacked:
return id(self.data)
else:
Expand Down
19 changes: 11 additions & 8 deletions pandas/plotting/_matplotlib/hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,16 @@ def __init__(
data,
bins: int | np.ndarray | list[np.ndarray] = 10,
bottom: int | np.ndarray = 0,
*,
range=None,
**kwargs,
) -> None:
if is_list_like(bottom):
bottom = np.array(bottom)
self.bottom = bottom

self._bin_range = range

self.xlabel = kwargs.get("xlabel")
self.ylabel = kwargs.get("ylabel")
# Do not call LinePlot.__init__ which may fill nan
Expand All @@ -85,7 +89,7 @@ def _calculate_bins(self, data: DataFrame, bins) -> np.ndarray:
values = np.ravel(nd_values)
values = values[~isna(values)]

hist, bins = np.histogram(values, bins=bins, range=self.kwds.get("range", None))
hist, bins = np.histogram(values, bins=bins, range=self._bin_range)
return bins

# error: Signature of "_plot" incompatible with supertype "LinePlot"
Expand Down Expand Up @@ -209,24 +213,23 @@ def __init__(self, data, bw_method=None, ind=None, **kwargs) -> None:
self.bw_method = bw_method
self.ind = ind

def _get_ind(self, y):
if self.ind is None:
@staticmethod
def _get_ind(y, ind):
if ind is None:
# np.nanmax() and np.nanmin() ignores the missing values
sample_range = np.nanmax(y) - np.nanmin(y)
ind = np.linspace(
np.nanmin(y) - 0.5 * sample_range,
np.nanmax(y) + 0.5 * sample_range,
1000,
)
elif is_integer(self.ind):
elif is_integer(ind):
sample_range = np.nanmax(y) - np.nanmin(y)
ind = np.linspace(
np.nanmin(y) - 0.5 * sample_range,
np.nanmax(y) + 0.5 * sample_range,
self.ind,
ind,
)
else:
ind = self.ind
return ind

@classmethod
Expand All @@ -252,7 +255,7 @@ def _plot(

def _make_plot_keywords(self, kwds, y):
kwds["bw_method"] = self.bw_method
kwds["ind"] = self._get_ind(y)
kwds["ind"] = self._get_ind(y, ind=self.ind)
return kwds

def _post_plot_logic(self, ax, data) -> None:
Expand Down
8 changes: 5 additions & 3 deletions pandas/plotting/_matplotlib/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
DataFrame,
DatetimeIndex,
Index,
PeriodIndex,
Series,
)

Expand Down Expand Up @@ -300,16 +301,17 @@ def maybe_convert_index(ax: Axes, data):
return data


# Patch methods for subplot. Only format_dateaxis is currently used.
# Do we need the rest for convenience?
# Patch methods for subplot.


def _format_coord(freq, t, y) -> str:
time_period = Period(ordinal=int(t), freq=freq)
return f"t = {time_period} y = {y:8f}"


def format_dateaxis(subplot, freq, index) -> None:
def format_dateaxis(
subplot, freq: BaseOffset, index: DatetimeIndex | PeriodIndex
) -> None:
"""
Pretty-formats the date axis (x-axis).
Expand Down

0 comments on commit 97c61e8

Please sign in to comment.