Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partly fix #1493: error message changed #1515

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._pandas_like.utils import set_axis
from narwhals._pandas_like.utils import to_datetime
from narwhals.dependencies import is_numpy_scalar
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module

Expand Down Expand Up @@ -112,7 +113,7 @@ def __getitem__(self, idx: int) -> Any: ...
def __getitem__(self, idx: slice | Sequence[int]) -> Self: ...

def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self:
if isinstance(idx, int):
if isinstance(idx, int) or is_numpy_scalar(idx):
return self._native_series.iloc[idx]
return self._from_native_series(self._native_series.iloc[idx])

Expand Down
5 changes: 5 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,11 @@ def is_numpy_array(arr: Any) -> TypeGuard[np.ndarray]:
return (np := get_numpy()) is not None and isinstance(arr, np.ndarray)


def is_numpy_scalar(scalar: Any) -> TypeGuard[np.generic]:
"""Check whether `scalar` is a NumPy Scalar without importing NumPy."""
return (np := get_numpy()) is not None and np.isscalar(scalar)


def is_pandas_like_dataframe(df: Any) -> bool:
"""Check whether `df` is a pandas-like DataFrame without doing any imports.

Expand Down
3 changes: 2 additions & 1 deletion narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import TypeVar
from typing import overload

from narwhals.dependencies import is_numpy_scalar
from narwhals.dtypes import _validate_dtype
from narwhals.typing import IntoSeriesT
from narwhals.utils import _validate_rolling_arguments
Expand Down Expand Up @@ -67,7 +68,7 @@ def __getitem__(self: Self, idx: int) -> Any: ...
def __getitem__(self: Self, idx: slice | Sequence[int]) -> Self: ...

def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self:
if isinstance(idx, int):
if isinstance(idx, int) or is_numpy_scalar(idx):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we check (is_numpy_scalar(idx) and idx.dtype.kind in ('i', 'u'))?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I understand it, we would ideally capture any sort of numpy scalar that can be used to index a pandas series. Can't strings, floats and datetime objects also be used?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we only allow positional indexing here

return self._compliant_series[idx]
return self._from_compliant_series(self._compliant_series[idx])

Expand Down
12 changes: 12 additions & 0 deletions tests/series_only/scalar_index_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from __future__ import annotations

import narwhals.stable.v1 as nw

np = nw.dependencies.get_numpy()
pd = nw.dependencies.get_pandas()
Comment on lines +5 to +6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can just import these directly here (though if you use constructor_eager they're probably not necessary)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you. I can confirm that constructor_eager works without them.



def test_index() -> None:
s = pd.Series([0, 1, 2])
snw = nw.from_native(s, series_only=True)
assert snw[snw[0]] == np.int64(0)
Comment on lines +9 to +12
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use constructor_eager here? like

s = nw.from_native(constructor_eager({'a': [0,1,2]}), eager_only=True)['a']
assert s[s[0]] == 0
```

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't understand constructor_eager before, but having tried it out, it seems like it runs the test for different libraries. Correct?
The problem here is that, while numpy.int64(0) == 0, for some reason <pyarrow.Int64Scalar: 0> != 0. Not sure about other libraries...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @thevro ! yeah Series.__getitem__ should return a Python scalar for PyArrow. i'll just make a PR, then if you fetch and merge then the comparison should work

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done #1554

if you fetch upstream and then merge upstream/main then the == 0 comparison should work

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And doing

import pyarrow as pa
from pyarrow import compute as pc
pazero = pc.cast(0, pa.int64())
...
idx = snw[0]  # test indexing using builtin int
assert idx == 0 or idx == pazero # idx should have type numpy.int64 or pa.int64 
assert snw[idx] == 0 or idx == pazero # test indexing using third-party int

produces a test error:

>       assert snw[idx] == 0 or idx == pazero # test indexing using numpy.int64

tests/series_only/scalar_index_test.py:19: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
narwhals/series.py:73: in __getitem__
    return self._from_compliant_series(self._compliant_series[idx])
narwhals/_arrow/series.py:378: in __getitem__
    return self._from_native_series(self._native_series[idx])
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   ???
E   TypeError: 'pyarrow.lib.Int64Scalar' object cannot be interpreted as an integer

pyarrow/table.pxi:312: TypeError

I suppose this comes of handling only numpy scalars.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, fixed now, if you fetch and merge upstream/main then == 0 should work

Loading