From 8ea0f09fd1771ec6ac20344eb10e8756b12279cc Mon Sep 17 00:00:00 2001 From: raisadz <34237447+raisadz@users.noreply.github.com> Date: Tue, 26 Nov 2024 14:53:29 +0000 Subject: [PATCH] feat: raise `NotImplementedError` for not supported parameters in `ewm_mean` for cuDF (#1449) --- narwhals/_pandas_like/series.py | 18 +++++++++++++++--- tests/expr_and_series/ewm_test.py | 29 ++++++++++++++++++++++------- 2 files changed, 37 insertions(+), 10 deletions(-) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index c8520529a..6ae76c7b5 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -188,9 +188,21 @@ def ewm_mean( ) -> PandasLikeSeries: ser = self._native_series mask_na = ser.isna() - result = ser.ewm( - com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls - ).mean() + if self._implementation is Implementation.CUDF: + if (min_periods == 0 and not ignore_nulls) or (not mask_na.any()): + result = ser.ewm( + com=com, span=span, halflife=half_life, alpha=alpha, adjust=adjust + ).mean() + else: + msg = ( + "cuDF only supports `ewm_mean` when there are no missing values " + "or when both `min_period=0` and `ignore_nulls=False`" + ) + raise NotImplementedError(msg) + else: + result = ser.ewm( + com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls + ).mean() result[mask_na] = None return self._from_native_series(result) diff --git a/tests/expr_and_series/ewm_test.py b/tests/expr_and_series/ewm_test.py index 641e8961d..5277576ce 100644 --- a/tests/expr_and_series/ewm_test.py +++ b/tests/expr_and_series/ewm_test.py @@ -16,9 +16,9 @@ "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." ) def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any( - x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf") - ) or ("polars" in str(constructor) and POLARS_VERSION < (1,)): + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( + "polars" in str(constructor) and POLARS_VERSION < (1,) + ): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -36,7 +36,7 @@ def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) def test_ewm_mean_series( request: pytest.FixtureRequest, constructor_eager: ConstructorEager ) -> None: - if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin", "cudf")) or ( + if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or ( "polars" in str(constructor_eager) and POLARS_VERSION < (1,) ): request.applymarker(pytest.mark.xfail) @@ -75,9 +75,9 @@ def test_ewm_mean_expr_adjust( adjust: bool, # noqa: FBT001 expected: dict[str, list[float]], ) -> None: - if any( - x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf") - ) or ("polars" in str(constructor) and POLARS_VERSION < (1,)): + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( + "polars" in str(constructor) and POLARS_VERSION < (1,) + ): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -187,3 +187,18 @@ def test_ewm_mean_params( with pytest.raises(ValueError, match="mutually exclusive"): df.select(nw.col("a").ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False)) + + +@pytest.mark.filterwarnings( + "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." +) +def test_ewm_mean_cudf_raise() -> None: # pragma: no cover + pytest.importorskip("cudf") + import cudf + + df = nw.from_native(cudf.DataFrame({"a": [2.0, 4.0, None, 3.0]})) + with pytest.raises( + NotImplementedError, + match="cuDF only supports `ewm_mean` when there are no missing values", + ): + df.select(nw.col("a").ewm_mean(com=1))