From 4fb7876d2c55a48ccd3b47e2ad2cefd191113f21 Mon Sep 17 00:00:00 2001 From: "Paulo S. Costa" Date: Wed, 6 Dec 2023 22:29:34 -0800 Subject: [PATCH] Type args and kwargs in frame pipe --- pandas-stubs/_typing.pyi | 6 ++- pandas-stubs/core/frame.pyi | 12 ++++-- tests/test_frame.py | 79 ++++++++++++++++++++++++++++++++++--- 3 files changed, 87 insertions(+), 10 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 5ab88069..272b9eec 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -28,7 +28,10 @@ from pandas.core.generic import NDFrame from pandas.core.groupby.grouper import Grouper from pandas.core.indexes.base import Index from pandas.core.series import Series -from typing_extensions import TypeAlias +from typing_extensions import ( + ParamSpec, + TypeAlias, +) from pandas._libs.interval import Interval from pandas._libs.tslibs import ( @@ -446,6 +449,7 @@ JSONSerializable: TypeAlias = PythonScalar | list | dict Axes: TypeAlias = AnyArrayLike | list | dict | range | tuple Renamer: TypeAlias = Mapping[Any, Label] | Callable[[Any], Label] T = TypeVar("T") +P = ParamSpec("P") FuncType: TypeAlias = Callable[..., Any] F = TypeVar("F", bound=FuncType) HashableT = TypeVar("HashableT", bound=Hashable) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 6cd3236b..168bd352 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -52,7 +52,10 @@ from pandas.core.window.rolling import ( Rolling, Window, ) -from typing_extensions import Self +from typing_extensions import ( + Concatenate, + Self, +) import xarray as xr from pandas._libs.missing import NAType @@ -100,6 +103,7 @@ from pandas._typing import ( MergeHow, NaPosition, NDFrameT, + P, ParquetEngine, QuantileInterpolation, RandomState, @@ -1832,9 +1836,9 @@ class DataFrame(NDFrame, OpsMixin): ) -> DataFrame: ... def pipe( self, - func: Callable[..., TType] | tuple[Callable[..., TType], _str], - *args, - **kwargs, + func: Callable[Concatenate[Self, P], TType] | tuple[Callable[..., TType], _str], + *args: P.args, + **kwargs: P.kwargs, ) -> TType: ... def pop(self, item: _str) -> Series: ... def pow( diff --git a/tests/test_frame.py b/tests/test_frame.py index a25e9b68..15ae6694 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1398,21 +1398,90 @@ def foo(df: pd.DataFrame) -> pd.DataFrame: .pipe(foo) ) + df = pd.DataFrame({"a": [1], "b": [2]}) check(assert_type(val, pd.DataFrame), pd.DataFrame) - check(assert_type(pd.DataFrame({"a": [1]}).pipe(foo), pd.DataFrame), pd.DataFrame) + check(assert_type(df.pipe(foo), pd.DataFrame), pd.DataFrame) def bar(val: Styler) -> Styler: return val - check( - assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(bar), Styler), Styler - ) + check(assert_type(df.style.pipe(bar), Styler), Styler) def baz(val: Styler) -> str: return val.to_latex() - check(assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(baz), str), str) + check(assert_type(df.style.pipe(baz), str), str) + + def qux( + df: pd.DataFrame, + positional_only: int, + /, + argument_1: list[float], + argument_2: str, + *, + keyword_only: tuple[int, int], + ) -> pd.DataFrame: + return pd.DataFrame(df) + + check( + assert_type( + df.pipe(qux, 1, [1.0, 2.0], argument_2="hi", keyword_only=(1, 2)), + pd.DataFrame, + ), + pd.DataFrame, + ) + + if TYPE_CHECKING_INVALID_USAGE: + df.pipe( + qux, # type: ignore[arg-type] + "a", # pyright: ignore[reportGeneralTypeIssues] + [1.0, 2.0], + argument_2="hi", + keyword_only=(1, 2), + ) + df.pipe( + qux, # type: ignore[arg-type] + 1, + [1.0, "b"], # pyright: ignore[reportGeneralTypeIssues] + argument_2="hi", + keyword_only=(1, 2), + ) + df.pipe( + qux, # type: ignore[arg-type] + 1, + [1.0, 2.0], + argument_2=11, # pyright: ignore[reportGeneralTypeIssues] + keyword_only=(1, 2), + ) + df.pipe( + qux, # type: ignore[arg-type] + 1, + [1.0, 2.0], + argument_2="hi", + keyword_only=(1,), # pyright: ignore[reportGeneralTypeIssues] + ) + df.pipe( + qux, # type: ignore[arg-type] + 1, + [1.0, 2.0], + argument_3="hi", # pyright: ignore[reportGeneralTypeIssues] + keyword_only=(1, 2), + ) + df.pipe( + qux, # type: ignore[arg-type] + 1, + [1.0, 2.0], + 11, + (1, 2), # pyright: ignore[reportGeneralTypeIssues] + ) + df.pipe( + qux, # type: ignore[arg-type] + positional_only=1, # pyright: ignore[reportGeneralTypeIssues] + argument_1=[1.0, 2.0], + argument_2=11, + keyword_only=(1, 2), + ) # set_flags() method added in 1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html