Skip to content

Commit

Permalink
TYP: plotting (#55829)
Browse files Browse the repository at this point in the history
* TYP: plotting

* mypy fixup
  • Loading branch information
jbrockmendel authored Nov 6, 2023
1 parent ef52fea commit 6d662b8
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions pandas/plotting/_matplotlib/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import (
TYPE_CHECKING,
Literal,
final,
)
import warnings

Expand Down Expand Up @@ -274,6 +275,7 @@ def __init__(

self._validate_color_args()

@final
def _validate_subplots_kwarg(
self, subplots: bool | Sequence[Sequence[str]]
) -> bool | list[tuple[int, ...]]:
Expand Down Expand Up @@ -420,6 +422,7 @@ def _validate_color_args(self):
"other or pass 'style' without a color symbol"
)

@final
def _iter_data(self, data=None, keep_index: bool = False, fillna=None):
if data is None:
data = self.data
Expand All @@ -445,9 +448,11 @@ def nseries(self) -> int:
else:
return self.data.shape[1]

@final
def draw(self) -> None:
self.plt.draw_if_interactive()

@final
def generate(self) -> None:
self._args_adjust()
self._compute_plot_data()
Expand All @@ -465,11 +470,13 @@ def generate(self) -> None:
def _args_adjust(self) -> None:
pass

@final
def _has_plotted_object(self, ax: Axes) -> bool:
"""check whether ax has data"""
return len(ax.lines) != 0 or len(ax.artists) != 0 or len(ax.containers) != 0

def _maybe_right_yaxis(self, ax: Axes, axes_num):
@final
def _maybe_right_yaxis(self, ax: Axes, axes_num: int):
if not self.on_right(axes_num):
# secondary axes may be passed via ax kw
return self._get_ax_layer(ax)
Expand Down Expand Up @@ -497,6 +504,7 @@ def _maybe_right_yaxis(self, ax: Axes, axes_num):
new_ax.set_yscale("symlog")
return new_ax

@final
def _setup_subplots(self) -> Figure:
if self.subplots:
naxes = (
Expand Down Expand Up @@ -567,6 +575,7 @@ def result(self):
else:
return self.axes[0]

@final
def _convert_to_ndarray(self, data):
# GH31357: categorical columns are processed separately
if isinstance(data.dtype, CategoricalDtype):
Expand All @@ -585,6 +594,7 @@ def _convert_to_ndarray(self, data):

return data

@final
def _compute_plot_data(self):
data = self.data

Expand Down Expand Up @@ -642,6 +652,7 @@ def _compute_plot_data(self):
def _make_plot(self, fig: Figure):
raise AbstractMethodError(self)

@final
def _add_table(self) -> None:
if self.table is False:
return
Expand All @@ -652,6 +663,7 @@ def _add_table(self) -> None:
ax = self._get_ax(0)
tools.table(ax, data)

@final
def _post_plot_logic_common(self, ax, data):
"""Common post process for each axes"""
if self.orientation == "vertical" or self.orientation is None:
Expand All @@ -674,6 +686,7 @@ def _post_plot_logic_common(self, ax, data):
def _post_plot_logic(self, ax, data) -> None:
"""Post process for each axes. Overridden in child classes"""

@final
def _adorn_subplots(self, fig: Figure):
"""Common post process unrelated to data"""
if len(self.axes) > 0:
Expand Down Expand Up @@ -735,6 +748,7 @@ def _adorn_subplots(self, fig: Figure):
raise ValueError(msg)
self.axes[0].set_title(self.title)

@final
def _apply_axis_properties(
self, axis: Axis, rot=None, fontsize: int | None = None
) -> None:
Expand Down Expand Up @@ -764,6 +778,7 @@ def legend_title(self) -> str | None:
stringified = map(pprint_thing, self.data.columns.names)
return ",".join(stringified)

@final
def _mark_right_label(self, label: str, index: int) -> str:
"""
Append ``(right)`` to the label of a line if it's plotted on the right axis.
Expand All @@ -774,6 +789,7 @@ def _mark_right_label(self, label: str, index: int) -> str:
label += " (right)"
return label

@final
def _append_legend_handles_labels(self, handle: Artist, label: str) -> None:
"""
Append current handle and label to ``legend_handles`` and ``legend_labels``.
Expand Down Expand Up @@ -819,6 +835,7 @@ def _make_legend(self) -> None:
if ax.get_visible():
ax.legend(loc="best")

@final
def _get_ax_legend(self, ax: Axes):
"""
Take in axes and return ax and legend under different scenarios
Expand All @@ -834,6 +851,7 @@ def _get_ax_legend(self, ax: Axes):
ax = other_ax
return ax, leg

@final
@cache_readonly
def plt(self):
import matplotlib.pyplot as plt
Expand All @@ -842,6 +860,7 @@ def plt(self):

_need_to_set_index = False

@final
def _get_xticks(self, convert_period: bool = False):
index = self.data.index
is_datetype = index.inferred_type in ("datetime", "date", "datetime64", "time")
Expand Down Expand Up @@ -896,6 +915,7 @@ def _get_custom_index_name(self):
"""Specify whether xlabel/ylabel should be used to override index name"""
return self.xlabel

@final
def _get_index_name(self) -> str | None:
if isinstance(self.data.index, ABCMultiIndex):
name = self.data.index.names
Expand All @@ -915,6 +935,7 @@ def _get_index_name(self) -> str | None:

return name

@final
@classmethod
def _get_ax_layer(cls, ax, primary: bool = True):
"""get left (primary) or right (secondary) axes"""
Expand All @@ -923,6 +944,7 @@ def _get_ax_layer(cls, ax, primary: bool = True):
else:
return getattr(ax, "right_ax", ax)

@final
def _col_idx_to_axis_idx(self, col_idx: int) -> int:
"""Return the index of the axis where the column at col_idx should be plotted"""
if isinstance(self.subplots, list):
Expand All @@ -936,6 +958,7 @@ def _col_idx_to_axis_idx(self, col_idx: int) -> int:
# subplots is True: one ax per column
return col_idx

@final
def _get_ax(self, i: int):
# get the twinx ax if appropriate
if self.subplots:
Expand All @@ -950,6 +973,7 @@ def _get_ax(self, i: int):
ax.get_yaxis().set_visible(True)
return ax

@final
@classmethod
def get_default_ax(cls, ax) -> None:
import matplotlib.pyplot as plt
Expand All @@ -959,13 +983,15 @@ def get_default_ax(cls, ax) -> None:
ax = plt.gca()
ax = cls._get_ax_layer(ax)

def on_right(self, i):
@final
def on_right(self, i: int):
if isinstance(self.secondary_y, bool):
return self.secondary_y

if isinstance(self.secondary_y, (tuple, list, np.ndarray, ABCIndex)):
return self.data.columns[i] in self.secondary_y

@final
def _apply_style_colors(self, colors, kwds, col_num, label: str):
"""
Manage style and color based on column number and its label.
Expand Down Expand Up @@ -1006,6 +1032,7 @@ def _get_colors(
color=self.kwds.get(color_kwds),
)

@final
def _parse_errorbars(self, label, err):
"""
Look for error keyword arguments and return the actual errorbar data
Expand Down Expand Up @@ -1095,6 +1122,7 @@ def match_labels(data, e):

return err

@final
def _get_errorbars(
self, label=None, index=None, xerr: bool = True, yerr: bool = True
):
Expand All @@ -1116,6 +1144,7 @@ def _get_errorbars(
errors[kw] = err
return errors

@final
def _get_subplots(self, fig: Figure):
from matplotlib.axes import Subplot

Expand All @@ -1125,6 +1154,7 @@ def _get_subplots(self, fig: Figure):
if (isinstance(ax, Subplot) and ax.get_subplotspec() is not None)
]

@final
def _get_axes_layout(self, fig: Figure) -> tuple[int, int]:
axes = self._get_subplots(fig)
x_set = set()
Expand Down Expand Up @@ -1163,17 +1193,20 @@ def __init__(self, data, x, y, **kwargs) -> None:
self.x = x
self.y = y

@final
@property
def nseries(self) -> int:
return 1

@final
def _post_plot_logic(self, ax: Axes, data) -> None:
x, y = self.x, self.y
xlabel = self.xlabel if self.xlabel is not None else pprint_thing(x)
ylabel = self.ylabel if self.ylabel is not None else pprint_thing(y)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)

@final
def _plot_colorbar(self, ax: Axes, *, fig: Figure, **kwds):
# Addresses issues #10611 and #10678:
# When plotting scatterplots and hexbinplots in IPython
Expand Down Expand Up @@ -1353,10 +1386,12 @@ def __init__(self, data, **kwargs) -> None:
if "x_compat" in self.kwds:
self.x_compat = bool(self.kwds.pop("x_compat"))

@final
def _is_ts_plot(self) -> bool:
# this is slightly deceptive
return not self.x_compat and self.use_index and self._use_dynamic_x()

@final
def _use_dynamic_x(self):
return use_dynamic_x(self._get_ax(0), self.data)

Expand Down Expand Up @@ -1424,6 +1459,7 @@ def _plot( # type: ignore[override]
cls._update_stacker(ax, stacking_id, y)
return lines

@final
def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
# accept x to be consistent with normal plot func,
# x is not passed to tsplot as it uses data.index as x coordinate
Expand All @@ -1444,12 +1480,14 @@ def _ts_plot(self, ax: Axes, x, data, style=None, **kwds):
format_dateaxis(ax, ax.freq, data.index)
return lines

@final
def _get_stacking_id(self):
if self.stacked:
return id(self.data)
else:
return None

@final
@classmethod
def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
if stacking_id is None:
Expand All @@ -1461,6 +1499,7 @@ def _initialize_stacker(cls, ax: Axes, stacking_id, n: int) -> None:
ax._stacker_pos_prior[stacking_id] = np.zeros(n)
ax._stacker_neg_prior[stacking_id] = np.zeros(n)

@final
@classmethod
def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
if stacking_id is None:
Expand All @@ -1480,6 +1519,7 @@ def _get_stacked_values(cls, ax: Axes, stacking_id, values, label):
f"Column '{label}' contains both positive and negative values"
)

@final
@classmethod
def _update_stacker(cls, ax: Axes, stacking_id, values) -> None:
if stacking_id is None:
Expand Down

0 comments on commit 6d662b8

Please sign in to comment.