Skip to content

Commit

Permalink
perf: Prefer getitem instead of loc for selecting multiple columns in…
Browse files Browse the repository at this point in the history
… pandas-like (#1355)
  • Loading branch information
MarcoGorelli authored Nov 12, 2024
1 parent 8b1c054 commit b9d5fe5
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 11 deletions.
37 changes: 33 additions & 4 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import parse_exprs_and_named_exprs
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import generate_temporary_column_name
Expand Down Expand Up @@ -116,7 +117,14 @@ def select(

if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
# This is a simple slice => fastpath!
return self._from_native_frame(self._native_frame.loc[:, exprs])
return self._from_native_frame(
select_columns_by_name(
self._native_frame,
list(exprs), # type: ignore[arg-type]
self._backend_version,
self._implementation,
)
)

new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)

Expand All @@ -136,7 +144,12 @@ def select(
)
return self._from_native_frame(df)

df = self._native_frame.assign(**new_series).loc[:, list(new_series.keys())]
df = select_columns_by_name(
self._native_frame.assign(**new_series),
list(new_series.keys()),
self._backend_version,
self._implementation,
)
return self._from_native_frame(df)

def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
Expand Down Expand Up @@ -257,8 +270,16 @@ def join(
n_bytes=8, columns=[*self.columns, *other.columns]
)

if right_on is None: # pragma: no cover
msg = "`right_on` cannot be `None` in anti-join"
raise TypeError(msg)
other_native = (
other._native_frame.loc[:, right_on]
select_columns_by_name(
other._native_frame,
right_on,
self._backend_version,
self._implementation,
)
.rename( # rename to avoid creating extra columns in join
columns=dict(zip(right_on, left_on)) # type: ignore[arg-type]
)
Expand All @@ -276,8 +297,16 @@ def join(
)

if how == "semi":
if right_on is None: # pragma: no cover
msg = "`right_on` cannot be `None` in semi-join"
raise TypeError(msg)
other_native = (
other._native_frame.loc[:, right_on]
select_columns_by_name(
other._native_frame,
right_on,
self._backend_version,
self._implementation,
)
.rename( # rename to avoid creating extra columns in join
columns=dict(zip(right_on, left_on)) # type: ignore[arg-type]
)
Expand Down
40 changes: 36 additions & 4 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals._pandas_like.utils import create_compliant_series
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals.dependencies import is_numpy_array
from narwhals.utils import Implementation
Expand Down Expand Up @@ -232,7 +233,14 @@ def __getitem__(

elif is_sequence_but_not_str(item) or (is_numpy_array(item) and item.ndim == 1):
if all(isinstance(x, str) for x in item) and len(item) > 0:
return self._from_native_frame(self._native_frame.loc[:, item])
return self._from_native_frame(
select_columns_by_name(
self._native_frame,
item,
self._backend_version,
self._implementation,
)
)
return self._from_native_frame(self._native_frame.iloc[item])

elif isinstance(item, slice):
Expand Down Expand Up @@ -328,7 +336,15 @@ def select(
) -> Self:
if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
# This is a simple slice => fastpath!
return self._from_native_frame(self._native_frame.loc[:, list(exprs)])
column_names = list(exprs)
return self._from_native_frame(
select_columns_by_name(
self._native_frame,
column_names, # type: ignore[arg-type]
self._backend_version,
self._implementation,
)
)
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
if not new_series:
# return empty dataframe, like Polars does
Expand Down Expand Up @@ -545,9 +561,17 @@ def join(
indicator_token = generate_temporary_column_name(
n_bytes=8, columns=[*self.columns, *other.columns]
)
if right_on is None: # pragma: no cover
msg = "`right_on` cannot be `None` in anti-join"
raise TypeError(msg)

other_native = (
other._native_frame.loc[:, right_on]
select_columns_by_name(
other._native_frame,
right_on,
self._backend_version,
self._implementation,
)
.rename( # rename to avoid creating extra columns in join
columns=dict(zip(right_on, left_on)), # type: ignore[arg-type]
copy=False,
Expand All @@ -567,8 +591,16 @@ def join(
)

if how == "semi":
if right_on is None: # pragma: no cover
msg = "`right_on` cannot be `None` in semi-join"
raise TypeError(msg)
other_native = (
other._native_frame.loc[:, right_on]
select_columns_by_name(
other._native_frame,
right_on,
self._backend_version,
self._implementation,
)
.rename( # rename to avoid creating extra columns in join
columns=dict(zip(right_on, left_on)), # type: ignore[arg-type]
copy=False,
Expand Down
23 changes: 20 additions & 3 deletions narwhals/_pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals._pandas_like.utils import native_series_from_iterable
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals.utils import Implementation
from narwhals.utils import remove_prefix
from narwhals.utils import tupleify
Expand Down Expand Up @@ -38,7 +39,15 @@ def __init__(
): # pragma: no cover
if (
not drop_null_keys
and self._df._native_frame.loc[:, self._keys].isna().any().any()
and select_columns_by_name(
self._df._native_frame,
self._keys,
self._df._backend_version,
self._df._implementation,
)
.isna()
.any()
.any()
):
msg = "Grouping by null values is not supported in pandas < 1.0.0"
raise NotImplementedError(msg)
Expand Down Expand Up @@ -227,7 +236,11 @@ def agg_pandas( # noqa: PLR0915
result_aggs = native_namespace.DataFrame(
list(grouped.groups.keys()), columns=keys
)
return from_dataframe(result_aggs.loc[:, output_names])
return from_dataframe(
select_columns_by_name(
result_aggs, output_names, backend_version, implementation
)
)

if dataframe_is_empty:
# Don't even attempt this, it's way too inconsistent across pandas versions.
Expand Down Expand Up @@ -275,4 +288,8 @@ def func(df: Any) -> Any:
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
result_complex.reset_index(inplace=True) # noqa: PD002

return from_dataframe(result_complex.loc[:, output_names])
return from_dataframe(
select_columns_by_name(
result_complex, output_names, backend_version, implementation
)
)
18 changes: 18 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import TypeVar

from narwhals._arrow.utils import (
Expand Down Expand Up @@ -635,3 +636,20 @@ def calculate_timestamp_date(s: pd.Series, time_unit: str) -> pd.Series:
else:
result = s * 1_000
return result


def select_columns_by_name(
    df: T,
    column_names: Sequence[str],
    backend_version: tuple[int, ...],
    implementation: Implementation,
) -> T:
    """Return the columns of `df` listed in `column_names`.

    Plain `df[column_names]` is used whenever possible, since it is generally
    more performant than `df.loc[:, column_names]`. The `.loc` form is kept
    for two cases: the column index has a boolean dtype, or the backend is
    pandas older than 1.5.

    Arguments:
        df: Native pandas-like dataframe to select from.
        column_names: Names of the columns to keep.
        backend_version: Version tuple of the backend library.
        implementation: Which pandas-like backend `df` belongs to.

    Returns:
        The result of indexing `df` with `column_names`.
    """
    has_boolean_labels = df.columns.dtype.kind == "b"  # type: ignore[attr-defined]
    if not has_boolean_labels:
        # Only look at the backend once the boolean-label case is ruled out,
        # mirroring a short-circuiting `or`.
        is_old_pandas = (
            implementation is Implementation.PANDAS and backend_version < (1, 5)
        )
        if not is_old_pandas:
            return df[column_names]  # type: ignore[no-any-return, index]
    # See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122
    # for why we need this
    # NOTE(review): pandas may treat a list of booleans given to `.loc` as a
    # mask rather than as labels - confirm this is the intended behaviour.
    return df.loc[:, column_names]  # type: ignore[no-any-return, attr-defined]
12 changes: 12 additions & 0 deletions tests/frame/select_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

Expand Down Expand Up @@ -32,3 +33,14 @@ def test_non_string_select_invalid() -> None:
df = nw.from_native(pd.DataFrame({0: [1, 2], "b": [3, 4]}))
with pytest.raises(TypeError, match="\n\nHint: if you were trying to select"):
nw.to_native(df.select(0)) # type: ignore[arg-type]


def test_select_boolean_cols(request: pytest.FixtureRequest) -> None:
    """Exercise `group_by` and `select` on a pandas frame with boolean column labels."""
    if PANDAS_VERSION < (1, 1):
        # bug in old pandas
        request.applymarker(pytest.mark.xfail)

    native_frame = pd.DataFrame({True: [1, 2], False: [3, 4]})
    frame = nw.from_native(native_frame, eager_only=True)

    # Group by the `True` column and aggregate the `False` column.
    aggregated = frame.group_by(True).agg(nw.col(False).max())  # type: ignore[arg-type]# noqa: FBT003
    assert_equal_data(aggregated.to_dict(as_series=False), {True: [1, 2]})  # type: ignore[dict-item]

    # Select both boolean-labelled columns through an expression.
    selected = frame.select(nw.col([False, True]))  # type: ignore[list-item]
    assert_equal_data(selected.to_dict(as_series=False), {True: [1, 2], False: [3, 4]})  # type: ignore[dict-item]

0 comments on commit b9d5fe5

Please sign in to comment.