From 40bc84d52c898b071229cbfbc634ad1686cf0e77 Mon Sep 17 00:00:00 2001 From: "Paulo S. Costa" Date: Mon, 15 Jan 2024 10:01:43 -0600 Subject: [PATCH] TYP: Persist typing information for pipe args and kwargs (#56760) * Type generic pipe with function params * Type common pipe with function params * Type resample pipe with function params * Type groupby pipe with function params * Type style pipe function params and tuple func --- pandas/_typing.py | 13 +++++++++++- pandas/core/common.py | 36 +++++++++++++++++++++++++++++++--- pandas/core/generic.py | 26 +++++++++++++++++++++--- pandas/core/groupby/groupby.py | 34 +++++++++++++++++++++++++++----- pandas/core/resample.py | 29 ++++++++++++++++++++++++--- pandas/io/formats/style.py | 31 +++++++++++++++++++++++++++-- 6 files changed, 152 insertions(+), 17 deletions(-) diff --git a/pandas/_typing.py b/pandas/_typing.py index a80f9603493a7..fa9dc14bb4bd7 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -90,18 +90,29 @@ from typing import SupportsIndex if sys.version_info >= (3, 10): + from typing import Concatenate # pyright: ignore[reportUnusedImport] + from typing import ParamSpec from typing import TypeGuard # pyright: ignore[reportUnusedImport] else: - from typing_extensions import TypeGuard # pyright: ignore[reportUnusedImport] + from typing_extensions import ( # pyright: ignore[reportUnusedImport] + Concatenate, + ParamSpec, + TypeGuard, + ) + + P = ParamSpec("P") if sys.version_info >= (3, 11): from typing import Self # pyright: ignore[reportUnusedImport] else: from typing_extensions import Self # pyright: ignore[reportUnusedImport] + else: npt: Any = None + ParamSpec: Any = None Self: Any = None TypeGuard: Any = None + Concatenate: Any = None HashableT = TypeVar("HashableT", bound=Hashable) MutableMappingT = TypeVar("MutableMappingT", bound=MutableMapping) diff --git a/pandas/core/common.py b/pandas/core/common.py index 7d864e02be54e..69b602feee3ea 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -24,6 +24,7 @@ TYPE_CHECKING, Any, Callable, + TypeVar, cast, overload, ) @@ -51,7 +52,9 @@ from pandas._typing import ( AnyArrayLike, ArrayLike, + Concatenate, NpDtype, + P, RandomState, T, ) @@ -463,8 +466,34 @@ def random_state(state: RandomState | None = None): ) +_T = TypeVar("_T") # Secondary TypeVar for use in pipe's type hints + + +@overload +def pipe( + obj: _T, + func: Callable[Concatenate[_T, P], T], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + ... + + +@overload +def pipe( + obj: Any, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, +) -> T: + ... + + def pipe( - obj, func: Callable[..., T] | tuple[Callable[..., T], str], *args, **kwargs + obj: _T, + func: Callable[Concatenate[_T, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, ) -> T: """ Apply a function ``func`` to object ``obj`` either by passing obj as the @@ -490,12 +519,13 @@ def pipe( object : the return type of ``func``. """ if isinstance(func, tuple): - func, target = func + # Assigning to func_ so pyright understands that it's a callable + func_, target = func if target in kwargs: msg = f"{target} is both the pipe target and a keyword argument" raise ValueError(msg) kwargs[target] = obj - return func(*args, **kwargs) + return func_(*args, **kwargs) else: return func(obj, *args, **kwargs) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index a61148a09be18..caac11b6ab4f6 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -50,6 +50,7 @@ Axis, AxisInt, CompressionOptions, + Concatenate, DtypeArg, DtypeBackend, DtypeObj, @@ -213,6 +214,7 @@ ) from pandas._libs.tslibs import BaseOffset + from pandas._typing import P from pandas import ( DataFrame, @@ -6118,13 +6120,31 @@ def sample( return result + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + ... + + @overload + def pipe( + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: + ... + @final @doc(klass=_shared_doc_kwargs["klass"]) def pipe( self, - func: Callable[..., T] | tuple[Callable[..., T], str], - *args, - **kwargs, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, ) -> T: r""" Apply chainable functions that expect Series or DataFrames. diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index f36297a59498d..ab22d4e3dc200 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -29,6 +29,7 @@ class providing the base-class of operations. Union, cast, final, + overload, ) import warnings @@ -55,7 +56,6 @@ class providing the base-class of operations. PositionalIndexer, RandomState, Scalar, - T, npt, ) from pandas.compat.numpy import function as nv @@ -147,7 +147,13 @@ class providing the base-class of operations. ) if TYPE_CHECKING: - from typing import Any + from pandas._typing import ( + Any, + Concatenate, + P, + Self, + T, + ) from pandas.core.resample import Resampler from pandas.core.window import ( @@ -989,6 +995,24 @@ def _selected_obj(self): def _dir_additions(self) -> set[str]: return self.obj._dir_additions() + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + ... + + @overload + def pipe( + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: + ... + @Substitution( klass="GroupBy", examples=dedent( @@ -1014,9 +1038,9 @@ def _dir_additions(self) -> set[str]: @Appender(_pipe_template) def pipe( self, - func: Callable[..., T] | tuple[Callable[..., T], str], - *args, - **kwargs, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, ) -> T: return com.pipe(self, func, *args, **kwargs) diff --git a/pandas/core/resample.py b/pandas/core/resample.py index 31309777c154d..924f9e6d49040 100644 --- a/pandas/core/resample.py +++ b/pandas/core/resample.py @@ -9,6 +9,7 @@ cast, final, no_type_check, + overload, ) import warnings @@ -97,12 +98,16 @@ from collections.abc import Hashable from pandas._typing import ( + Any, AnyArrayLike, Axis, AxisInt, + Concatenate, Frequency, IndexLabel, InterpolateOptions, + P, + Self, T, TimedeltaConvertibleTypes, TimeGrouperOrigin, @@ -254,6 +259,24 @@ def _get_binner(self): bin_grouper = BinGrouper(bins, binlabels, indexer=self._indexer) return binner, bin_grouper + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + ... + + @overload + def pipe( + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: + ... + @final @Substitution( klass="Resampler", @@ -278,9 +301,9 @@ def _get_binner(self): @Appender(_pipe_template) def pipe( self, - func: Callable[..., T] | tuple[Callable[..., T], str], - *args, - **kwargs, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, ) -> T: return super().pipe(func, *args, **kwargs) diff --git a/pandas/io/formats/style.py b/pandas/io/formats/style.py index 0fbfae22f4663..5289a21adfbb4 100644 --- a/pandas/io/formats/style.py +++ b/pandas/io/formats/style.py @@ -9,7 +9,6 @@ import operator from typing import ( TYPE_CHECKING, - Any, Callable, overload, ) @@ -66,15 +65,20 @@ from matplotlib.colors import Colormap from pandas._typing import ( + Any, Axis, AxisInt, + Concatenate, FilePath, IndexLabel, IntervalClosedType, Level, + P, QuantileInterpolation, Scalar, + Self, StorageOptions, + T, WriteBuffer, WriteExcelBuffer, ) @@ -3614,7 +3618,30 @@ class MyStyler(cls): # type: ignore[valid-type,misc] return MyStyler - def pipe(self, func: Callable, *args, **kwargs): + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: + ... + + @overload + def pipe( + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: + ... + + def pipe( + self, + func: Callable[Concatenate[Self, P], T] | tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, + ) -> T: """ Apply ``func(self, *args, **kwargs)``, and return the result.