Skip to content

Commit

Permalink
raise for not implemented params for cudf, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
raisadz committed Nov 26, 2024
1 parent d51b44e commit e645bb2
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 13 deletions.
12 changes: 9 additions & 3 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,15 @@ def ewm_mean(
) -> PandasLikeSeries:
ser = self._native_series
mask_na = ser.isna()
if self._implementation is Implementation.CUDF and sum(mask_na) > 0:
msg = "`ewm_mean` with null values is not yet implemented for cuDF"
raise NotImplementedError(msg)
if self._implementation is Implementation.CUDF:
if min_periods == 0 and not ignore_nulls:
result = ser.ewm(com, span, half_life, alpha, adjust).mean()
elif min_periods != 0:
msg = "`min_periods != 0` is not yet implemented for cuDF"
raise NotImplementedError(msg)
else:
msg = "`ignore_nulls=True` is not yet implemented for cuDF"
raise NotImplementedError(msg)
result = ser.ewm(
com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls
).mean()
Expand Down
59 changes: 49 additions & 10 deletions tests/expr_and_series/ewm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -36,7 +36,7 @@ def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor)
def test_ewm_mean_series(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or (
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin", "cudf")) or (
"polars" in str(constructor_eager) and POLARS_VERSION < (1,)
):
request.applymarker(pytest.mark.xfail)
Expand Down Expand Up @@ -75,9 +75,9 @@ def test_ewm_mean_expr_adjust(
adjust: bool, # noqa: FBT001
expected: dict[str, list[float]],
) -> None:
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand Down Expand Up @@ -154,9 +154,9 @@ def test_ewm_mean_params(
request: pytest.FixtureRequest,
constructor: Constructor,
) -> None:
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor({"a": [2, 5, 3]}))
Expand Down Expand Up @@ -187,3 +187,42 @@ def test_ewm_mean_params(

with pytest.raises(ValueError, match="mutually exclusive"):
df.select(nw.col("a").ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False))


@pytest.mark.filterwarnings(
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_cudf_default_params(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
expected: dict[str, list[float | None]] = {"a": [1.0, 1.0, 1.5714285714285714]}
assert_equal_data(
df.select(nw.col("a").ewm_mean(com=1, min_periods=0)),
expected,
)


@pytest.mark.filterwarnings(
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_cudf_raise() -> None:
pytest.importorskip("cudf")
import cudf

df = nw.from_native(cudf.DataFrame(data))
with pytest.raises(
NotImplementedError,
match="`min_periods != 0` is not yet implemented for cuDF",
):
df.select(nw.col("a").ewm_mean(com=1))
with pytest.raises(
NotImplementedError,
match="`ignore_nulls=True` is not yet implemented for cuDF",
):
df.select(nw.col("a").ewm_mean(com=1, min_periods=0, ignore_nulls=True))

0 comments on commit e645bb2

Please sign in to comment.