Skip to content

Commit

Permalink
feat: raise NotImplementedError for not supported parameters in `ew…
Browse files Browse the repository at this point in the history
…m_mean` for cuDF (#1449)
  • Loading branch information
raisadz authored Nov 26, 2024
1 parent 73e67c7 commit 8ea0f09
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
18 changes: 15 additions & 3 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,21 @@ def ewm_mean(
) -> PandasLikeSeries:
ser = self._native_series
mask_na = ser.isna()
result = ser.ewm(
com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls
).mean()
if self._implementation is Implementation.CUDF:
if (min_periods == 0 and not ignore_nulls) or (not mask_na.any()):
result = ser.ewm(
com=com, span=span, halflife=half_life, alpha=alpha, adjust=adjust
).mean()
else:
msg = (
"cuDF only supports `ewm_mean` when there are no missing values "
"or when both `min_period=0` and `ignore_nulls=False`"
)
raise NotImplementedError(msg)
else:
result = ser.ewm(
com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls
).mean()
result[mask_na] = None
return self._from_native_series(result)

Expand Down
29 changes: 22 additions & 7 deletions tests/expr_and_series/ewm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -36,7 +36,7 @@ def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor)
def test_ewm_mean_series(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin", "cudf")) or (
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or (
"polars" in str(constructor_eager) and POLARS_VERSION < (1,)
):
request.applymarker(pytest.mark.xfail)
Expand Down Expand Up @@ -75,9 +75,9 @@ def test_ewm_mean_expr_adjust(
adjust: bool, # noqa: FBT001
expected: dict[str, list[float]],
) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand Down Expand Up @@ -187,3 +187,18 @@ def test_ewm_mean_params(

with pytest.raises(ValueError, match="mutually exclusive"):
df.select(nw.col("a").ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False))


@pytest.mark.filterwarnings(
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_cudf_raise() -> None: # pragma: no cover
pytest.importorskip("cudf")
import cudf

df = nw.from_native(cudf.DataFrame({"a": [2.0, 4.0, None, 3.0]}))
with pytest.raises(
NotImplementedError,
match="cuDF only supports `ewm_mean` when there are no missing values",
):
df.select(nw.col("a").ewm_mean(com=1))

0 comments on commit 8ea0f09

Please sign in to comment.