diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 7210afae..2855d822 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -28,7 +28,10 @@ from pandas.core.generic import NDFrame from pandas.core.groupby.grouper import Grouper from pandas.core.indexes.base import Index from pandas.core.series import Series -from typing_extensions import TypeAlias +from typing_extensions import ( + ParamSpec, + TypeAlias, +) from pandas._libs.interval import Interval from pandas._libs.tslibs import ( @@ -447,6 +450,7 @@ JSONSerializable: TypeAlias = PythonScalar | list | dict Axes: TypeAlias = AnyArrayLike | list | dict | range | tuple Renamer: TypeAlias = Mapping[Any, Label] | Callable[[Any], Label] T = TypeVar("T") +P = ParamSpec("P") FuncType: TypeAlias = Callable[..., Any] F = TypeVar("F", bound=FuncType) HashableT = TypeVar("HashableT", bound=Hashable) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 2de05988..5231b976 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -113,7 +113,6 @@ from pandas._typing import ( StorageOptions, StrLike, Suffixes, - T as TType, TimestampConvention, ValidationOptions, WriteBuffer, @@ -1829,12 +1828,6 @@ class DataFrame(NDFrame, OpsMixin): freq=..., **kwargs, ) -> DataFrame: ... - def pipe( - self, - func: Callable[..., TType] | tuple[Callable[..., TType], _str], - *args, - **kwargs, - ) -> TType: ... def pop(self, item: _str) -> Series: ... def pow( self, diff --git a/pandas-stubs/core/generic.pyi b/pandas-stubs/core/generic.pyi index 1d8b3aa4..e491b918 100644 --- a/pandas-stubs/core/generic.pyi +++ b/pandas-stubs/core/generic.pyi @@ -19,7 +19,10 @@ from pandas import Index import pandas.core.indexing as indexing from pandas.core.series import Series import sqlalchemy.engine -from typing_extensions import Self +from typing_extensions import ( + Concatenate, + Self, +) from pandas._typing import ( S1, @@ -40,6 +43,7 @@ from pandas._typing import ( IgnoreRaise, IndexLabel, Level, + P, ReplaceMethod, SortKind, StorageOptions, @@ -352,8 +356,19 @@ class NDFrame(indexing.IndexingMixin): ) -> Self: ... def head(self, n: int = ...) -> Self: ... def tail(self, n: int = ...) -> Self: ... + @overload + def pipe( + self, + func: Callable[Concatenate[Self, P], T], + *args: P.args, + **kwargs: P.kwargs, + ) -> T: ... + @overload def pipe( - self, func: Callable[..., T] | tuple[Callable[..., T], str], *args, **kwargs + self, + func: tuple[Callable[..., T], str], + *args: Any, + **kwargs: Any, ) -> T: ... def __finalize__(self, other, method=..., **kwargs) -> Self: ... def __setattr__(self, name: _str, value) -> None: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 46f844c5..8056e283 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -1436,21 +1436,123 @@ def foo(df: pd.DataFrame) -> pd.DataFrame: .pipe(foo) ) + df = pd.DataFrame({"a": [1], "b": [2]}) check(assert_type(val, pd.DataFrame), pd.DataFrame) - check(assert_type(pd.DataFrame({"a": [1]}).pipe(foo), pd.DataFrame), pd.DataFrame) + check(assert_type(df.pipe(foo), pd.DataFrame), pd.DataFrame) def bar(val: Styler) -> Styler: return val - check( - assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(bar), Styler), Styler - ) + check(assert_type(df.style.pipe(bar), Styler), Styler) def baz(val: Styler) -> str: return val.to_latex() - check(assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(baz), str), str) + check(assert_type(df.style.pipe(baz), str), str) + + def qux( + df: pd.DataFrame, + positional_only: int, + /, + argument_1: list[float], + argument_2: str, + *, + keyword_only: tuple[int, int], + ) -> pd.DataFrame: + return pd.DataFrame(df) + + check( + assert_type( + df.pipe(qux, 1, [1.0, 2.0], argument_2="hi", keyword_only=(1, 2)), + pd.DataFrame, + ), + pd.DataFrame, + ) + + if TYPE_CHECKING_INVALID_USAGE: + df.pipe( + qux, + "a", # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + [1.0, 2.0], + argument_2="hi", + keyword_only=(1, 2), + ) + df.pipe( + qux, + 1, + [1.0, "b"], # type: ignore[list-item] # pyright: ignore[reportGeneralTypeIssues] + argument_2="hi", + keyword_only=(1, 2), + ) + df.pipe( + qux, + 1, + [1.0, 2.0], + argument_2=11, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + keyword_only=(1, 2), + ) + df.pipe( + qux, + 1, + [1.0, 2.0], + argument_2="hi", + keyword_only=(1,), # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + ) + df.pipe( # type: ignore[call-arg] + qux, + 1, + [1.0, 2.0], + argument_3="hi", # pyright: ignore[reportGeneralTypeIssues] + keyword_only=(1, 2), + ) + df.pipe( # type: ignore[misc] + qux, + 1, + [1.0, 2.0], + 11, # type: ignore[arg-type] + (1, 2), # pyright: ignore[reportGeneralTypeIssues] + ) + df.pipe( # type: ignore[call-arg] + qux, + positional_only=1, # pyright: ignore[reportGeneralTypeIssues] + argument_1=[1.0, 2.0], + argument_2=11, # type: ignore[arg-type] + keyword_only=(1, 2), + ) + + def dataframe_not_first_arg(x: int, df: pd.DataFrame) -> pd.DataFrame: + return df + + check( + assert_type( + df.pipe( + ( + dataframe_not_first_arg, + "df", + ), + 1, + ), + pd.DataFrame, + ), + pd.DataFrame, + ) + + if TYPE_CHECKING_INVALID_USAGE: + df.pipe( + ( + dataframe_not_first_arg, # type: ignore[arg-type] + 1, # pyright: ignore[reportGeneralTypeIssues] + ), + 1, + ) + df.pipe( + ( # pyright: ignore[reportGeneralTypeIssues] + 1, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + "df", + ), + 1, + ) # set_flags() method added in 1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html diff --git a/tests/test_series.py b/tests/test_series.py index f93a4a8f..673e4623 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -2914,3 +2914,113 @@ def test_timedeltaseries_operators() -> None: pd.Series, pd.Timedelta, ) + + +def test_pipe() -> None: + ser = pd.Series(range(10)) + + def first_arg_series( + ser: pd.Series, + positional_only: int, + /, + argument_1: list[float], + argument_2: str, + *, + keyword_only: tuple[int, int], + ) -> pd.Series: + return ser + + check( + assert_type( + ser.pipe( + first_arg_series, + 1, + [1.0, 2.0], + argument_2="hi", + keyword_only=(1, 2), + ), + pd.Series, + ), + pd.Series, + ) + + if TYPE_CHECKING_INVALID_USAGE: + ser.pipe( + first_arg_series, + "a", # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + [1.0, 2.0], + argument_2="hi", + keyword_only=(1, 2), + ) + ser.pipe( + first_arg_series, + 1, + [1.0, "b"], # type: ignore[list-item] # pyright: ignore[reportGeneralTypeIssues] + argument_2="hi", + keyword_only=(1, 2), + ) + ser.pipe( + first_arg_series, + 1, + [1.0, 2.0], + argument_2=11, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + keyword_only=(1, 2), + ) + ser.pipe( + first_arg_series, + 1, + [1.0, 2.0], + argument_2="hi", + keyword_only=(1,), # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + ) + ser.pipe( # type: ignore[call-arg] + first_arg_series, + 1, + [1.0, 2.0], + argument_3="hi", # pyright: ignore[reportGeneralTypeIssues] + keyword_only=(1, 2), + ) + ser.pipe( # type: ignore[misc] + first_arg_series, + 1, + [1.0, 2.0], + 11, # type: ignore[arg-type] + (1, 2), # pyright: ignore[reportGeneralTypeIssues] + ) + ser.pipe( # type: ignore[call-arg] + first_arg_series, + positional_only=1, # pyright: ignore[reportGeneralTypeIssues] + argument_1=[1.0, 2.0], + argument_2=11, # type: ignore[arg-type] + keyword_only=(1, 2), + ) + + def first_arg_not_series(argument_1: int, ser: pd.Series) -> pd.Series: + return ser + + check( + assert_type( + ser.pipe( + (first_arg_not_series, "ser"), + 1, + ), + pd.Series, + ), + pd.Series, + ) + + if TYPE_CHECKING_INVALID_USAGE: + ser.pipe( + ( + first_arg_not_series, # type: ignore[arg-type] + 1, # pyright: ignore[reportGeneralTypeIssues] + ), + 1, + ) + ser.pipe( + ( + 1, # type: ignore[arg-type] # pyright: ignore[reportGeneralTypeIssues] + "df", + ), + 1, + )