From 784ce579c3c84a7ec5ea3faafdab603382ae5f89 Mon Sep 17 00:00:00 2001 From: "Paulo S. Costa" Date: Wed, 6 Dec 2023 22:29:34 -0800 Subject: [PATCH] Type args and kwargs in frame pipe --- pandas-stubs/_typing.pyi | 6 +++++- pandas-stubs/core/frame.pyi | 12 ++++++++---- tests/test_frame.py | 10 ++++++++++ 3 files changed, 23 insertions(+), 5 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 5ab88069..272b9eec 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -28,7 +28,10 @@ from pandas.core.generic import NDFrame from pandas.core.groupby.grouper import Grouper from pandas.core.indexes.base import Index from pandas.core.series import Series -from typing_extensions import TypeAlias +from typing_extensions import ( + ParamSpec, + TypeAlias, +) from pandas._libs.interval import Interval from pandas._libs.tslibs import ( @@ -446,6 +449,7 @@ JSONSerializable: TypeAlias = PythonScalar | list | dict Axes: TypeAlias = AnyArrayLike | list | dict | range | tuple Renamer: TypeAlias = Mapping[Any, Label] | Callable[[Any], Label] T = TypeVar("T") +P = ParamSpec("P") FuncType: TypeAlias = Callable[..., Any] F = TypeVar("F", bound=FuncType) HashableT = TypeVar("HashableT", bound=Hashable) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 6c9a11b6..feedf99b 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -52,7 +52,10 @@ from pandas.core.window.rolling import ( Rolling, Window, ) -from typing_extensions import Self +from typing_extensions import ( + Concatenate, + Self, +) import xarray as xr from pandas._libs.missing import NAType @@ -100,6 +103,7 @@ from pandas._typing import ( MergeHow, NaPosition, NDFrameT, + P, ParquetEngine, QuantileInterpolation, RandomState, @@ -1832,9 +1836,9 @@ class DataFrame(NDFrame, OpsMixin): ) -> DataFrame: ... def pipe( self, - func: Callable[..., TType] | tuple[Callable[..., TType], _str], - *args, - **kwargs, + func: Callable[Concatenate[Self, P], TType] | tuple[Callable[..., TType], _str], + *args: P.args, + **kwargs: P.kwargs, ) -> TType: ... def pop(self, item: _str) -> Series: ... def pow( diff --git a/tests/test_frame.py b/tests/test_frame.py index 9bc32ddb..7441129a 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1386,6 +1386,16 @@ def baz(val: Styler) -> str: check(assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(baz), str), str) + def qux(df: pd.DataFrame, a: int, b: str) -> pd.DataFrame: + return pd.DataFrame(df) + + check( + assert_type( + pd.DataFrame({"a": [1], "b": [1]}).pipe(qux, 1, y="a"), pd.DataFrame + ), + pd.DataFrame, + ) + # set_flags() method added in 1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html def test_types_set_flags() -> None: