Skip to content

Commit

Permalink
chore: rename validate_column_comparand utility functions, add pep740…
Browse files Browse the repository at this point in the history
… badge (#1477)

* docs: Add trusted publishing badge to README

* chore: rename validate_column_comparand utility functions, add pep740 badge

* simplify
  • Loading branch information
MarcoGorelli authored Dec 1, 2024
1 parent dcf2533 commit cc76526
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 87 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@

[![PyPI version](https://badge.fury.io/py/narwhals.svg)](https://badge.fury.io/py/narwhals)
[![Downloads](https://static.pepy.tech/badge/narwhals/month)](https://pepy.tech/project/narwhals)
[![Trusted publishing](https://img.shields.io/badge/Trusted_publishing-Provides_attestations-bright_green)](https://peps.python.org/pep-0740/)

Extremely lightweight and extensible compatibility layer between dataframe libraries!

- **Full API support**: cuDF, Modin, pandas, Polars, PyArrow
- **Lazy-only support**: Dask
- **Interchange-level support**: Ibis, Vaex, anything else which implements the DataFrame Interchange Protocol
- **Interchange-level support**: DuckDB, Ibis, Vaex, anything which implements the DataFrame Interchange Protocol

Seamlessly support all, without depending on any!

Expand Down
6 changes: 5 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

![](assets/image.png)

[![PyPI version](https://badge.fury.io/py/narwhals.svg)](https://badge.fury.io/py/narwhals)
[![Downloads](https://static.pepy.tech/badge/narwhals/month)](https://pepy.tech/project/narwhals)
[![Trusted publishing](https://img.shields.io/badge/Trusted_publishing-Provides_attestations-bright_green)](https://peps.python.org/pep-0740/)

Extremely lightweight and extensible compatibility layer between dataframe libraries!

- **Full API support**: cuDF, Modin, pandas, Polars, PyArrow
- **Lazy-only support**: Dask
- **Interchange-level support**: Ibis, DuckDB, Vaex, anything else which implements the DataFrame Interchange Protocol
- **Interchange-level support**: DuckDB, Ibis, Vaex, anything which implements the DataFrame Interchange Protocol

Seamlessly support all, without depending on any!

Expand Down
50 changes: 26 additions & 24 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from typing import Sequence
from typing import overload

from narwhals._arrow.utils import broadcast_and_extract_native
from narwhals._arrow.utils import cast_for_truediv
from narwhals._arrow.utils import floordiv_compat
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import parse_datetime_format
from narwhals._arrow.utils import validate_column_comparand
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name

Expand Down Expand Up @@ -101,67 +101,67 @@ def __len__(self: Self) -> int:
def __eq__(self: Self, other: object) -> Self: # type: ignore[override]
import pyarrow.compute as pc

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.equal(ser, other))

def __ne__(self: Self, other: object) -> Self: # type: ignore[override]
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.not_equal(ser, other))

def __ge__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.greater_equal(ser, other))

def __gt__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.greater(ser, other))

def __le__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.less_equal(ser, other))

def __lt__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.less(ser, other))

def __and__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.and_kleene(ser, other))

def __rand__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.and_kleene(other, ser))

def __or__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.or_kleene(ser, other))

def __ror__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.or_kleene(other, ser))

def __add__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.add(ser, other))

def __radd__(self: Self, other: Any) -> Self:
Expand All @@ -170,7 +170,7 @@ def __radd__(self: Self, other: Any) -> Self:
def __sub__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.subtract(ser, other))

def __rsub__(self: Self, other: Any) -> Self:
Expand All @@ -179,7 +179,7 @@ def __rsub__(self: Self, other: Any) -> Self:
def __mul__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.multiply(ser, other))

def __rmul__(self: Self, other: Any) -> Self:
Expand All @@ -188,28 +188,28 @@ def __rmul__(self: Self, other: Any) -> Self:
def __pow__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.power(ser, other))

def __rpow__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(pc.power(other, ser))

def __floordiv__(self: Self, other: Any) -> Self:
ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(floordiv_compat(ser, other))

def __rfloordiv__(self: Self, other: Any) -> Self:
ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
return self._from_native_series(floordiv_compat(other, ser))

def __truediv__(self: Self, other: Any) -> Self:
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
if not isinstance(other, (pa.Array, pa.ChunkedArray)):
# scalar
other = pa.scalar(other)
Expand All @@ -219,7 +219,7 @@ def __rtruediv__(self: Self, other: Any) -> Self:
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
if not isinstance(other, (pa.Array, pa.ChunkedArray)):
# scalar
other = pa.scalar(other)
Expand All @@ -229,15 +229,15 @@ def __mod__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

floor_div = (self // other)._native_series
ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
res = pc.subtract(ser, pc.multiply(floor_div, other))
return self._from_native_series(res)

def __rmod__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

floor_div = (other // self)._native_series
ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
res = pc.subtract(other, pc.multiply(floor_div, ser))
return self._from_native_series(res)

Expand All @@ -251,7 +251,7 @@ def len(self: Self, *, _return_py_scalar: bool = True) -> int:

def filter(self: Self, other: Any) -> Self:
if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
ser, other = validate_column_comparand(self, other, self._backend_version)
ser, other = broadcast_and_extract_native(self, other, self._backend_version)
else:
ser = self._native_series
return self._from_native_series(ser.filter(other))
Expand Down Expand Up @@ -382,7 +382,9 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self:
mask = np.zeros(self.len(), dtype=bool)
mask[indices] = True
if isinstance(values, self.__class__):
ser, values = validate_column_comparand(self, values, self._backend_version)
ser, values = broadcast_and_extract_native(
self, values, self._backend_version
)
else:
ser = self._native_series
if isinstance(values, pa.ChunkedArray):
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> pa.D
raise AssertionError(msg)


def validate_column_comparand(
def broadcast_and_extract_native(
lhs: ArrowSeries, rhs: Any, backend_version: tuple[int, ...]
) -> tuple[pa.ChunkedArray, Any]:
"""Validate RHS of binary operation.
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def __init__(
def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
from narwhals._expression_parsing import parse_into_expr
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._pandas_like.utils import validate_column_comparand
from narwhals._pandas_like.utils import broadcast_align_and_extract_native

plx = PandasLikeNamespace(
implementation=self._implementation,
Expand All @@ -501,7 +501,7 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
)
value_series = cast(PandasLikeSeries, value_series)

value_series_native, condition_native = validate_column_comparand(
value_series_native, condition_native = broadcast_align_and_extract_native(
value_series, condition
)

Expand Down
Loading

0 comments on commit cc76526

Please sign in to comment.