Skip to content

Commit

Permalink
Type args and kwargs in frame pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
paw-lu committed Dec 21, 2023
1 parent c0ca527 commit 4fb7876
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 10 deletions.
6 changes: 5 additions & 1 deletion pandas-stubs/_typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ from pandas.core.generic import NDFrame
from pandas.core.groupby.grouper import Grouper
from pandas.core.indexes.base import Index
from pandas.core.series import Series
from typing_extensions import TypeAlias
from typing_extensions import (
ParamSpec,
TypeAlias,
)

from pandas._libs.interval import Interval
from pandas._libs.tslibs import (
Expand Down Expand Up @@ -446,6 +449,7 @@ JSONSerializable: TypeAlias = PythonScalar | list | dict
Axes: TypeAlias = AnyArrayLike | list | dict | range | tuple
Renamer: TypeAlias = Mapping[Any, Label] | Callable[[Any], Label]
T = TypeVar("T")
P = ParamSpec("P")
FuncType: TypeAlias = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
HashableT = TypeVar("HashableT", bound=Hashable)
Expand Down
12 changes: 8 additions & 4 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ from pandas.core.window.rolling import (
Rolling,
Window,
)
from typing_extensions import Self
from typing_extensions import (
Concatenate,
Self,
)
import xarray as xr

from pandas._libs.missing import NAType
Expand Down Expand Up @@ -100,6 +103,7 @@ from pandas._typing import (
MergeHow,
NaPosition,
NDFrameT,
P,
ParquetEngine,
QuantileInterpolation,
RandomState,
Expand Down Expand Up @@ -1832,9 +1836,9 @@ class DataFrame(NDFrame, OpsMixin):
) -> DataFrame: ...
def pipe(
self,
func: Callable[..., TType] | tuple[Callable[..., TType], _str],
*args,
**kwargs,
func: Callable[Concatenate[Self, P], TType] | tuple[Callable[..., TType], _str],
*args: P.args,
**kwargs: P.kwargs,
) -> TType: ...
def pop(self, item: _str) -> Series: ...
def pow(
Expand Down
79 changes: 74 additions & 5 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,21 +1398,90 @@ def foo(df: pd.DataFrame) -> pd.DataFrame:
.pipe(foo)
)

df = pd.DataFrame({"a": [1], "b": [2]})
check(assert_type(val, pd.DataFrame), pd.DataFrame)

check(assert_type(pd.DataFrame({"a": [1]}).pipe(foo), pd.DataFrame), pd.DataFrame)
check(assert_type(df.pipe(foo), pd.DataFrame), pd.DataFrame)

def bar(val: Styler) -> Styler:
return val

check(
assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(bar), Styler), Styler
)
check(assert_type(df.style.pipe(bar), Styler), Styler)

def baz(val: Styler) -> str:
return val.to_latex()

check(assert_type(pd.DataFrame({"a": [1], "b": [1]}).style.pipe(baz), str), str)
check(assert_type(df.style.pipe(baz), str), str)

def qux(
df: pd.DataFrame,
positional_only: int,
/,
argument_1: list[float],
argument_2: str,
*,
keyword_only: tuple[int, int],
) -> pd.DataFrame:
return pd.DataFrame(df)

check(
assert_type(
df.pipe(qux, 1, [1.0, 2.0], argument_2="hi", keyword_only=(1, 2)),
pd.DataFrame,
),
pd.DataFrame,
)

if TYPE_CHECKING_INVALID_USAGE:
df.pipe(
qux, # type: ignore[arg-type]
"a", # pyright: ignore[reportGeneralTypeIssues]
[1.0, 2.0],
argument_2="hi",
keyword_only=(1, 2),
)
df.pipe(
qux, # type: ignore[arg-type]
1,
[1.0, "b"], # pyright: ignore[reportGeneralTypeIssues]
argument_2="hi",
keyword_only=(1, 2),
)
df.pipe(
qux, # type: ignore[arg-type]
1,
[1.0, 2.0],
argument_2=11, # pyright: ignore[reportGeneralTypeIssues]
keyword_only=(1, 2),
)
df.pipe(
qux, # type: ignore[arg-type]
1,
[1.0, 2.0],
argument_2="hi",
keyword_only=(1,), # pyright: ignore[reportGeneralTypeIssues]
)
df.pipe(
qux, # type: ignore[arg-type]
1,
[1.0, 2.0],
argument_3="hi", # pyright: ignore[reportGeneralTypeIssues]
keyword_only=(1, 2),
)
df.pipe(
qux, # type: ignore[arg-type]
1,
[1.0, 2.0],
11,
(1, 2), # pyright: ignore[reportGeneralTypeIssues]
)
df.pipe(
qux, # type: ignore[arg-type]
positional_only=1, # pyright: ignore[reportGeneralTypeIssues]
argument_1=[1.0, 2.0],
argument_2=11,
keyword_only=(1, 2),
)


# set_flags() method added in 1.2.0 https://pandas.pydata.org/docs/whatsnew/v1.2.0.html
Expand Down

0 comments on commit 4fb7876

Please sign in to comment.