Skip to content

Commit

Permalink
fixup for data with no nulls
Browse files Browse the repository at this point in the history
  • Loading branch information
raisadz committed Nov 26, 2024
1 parent ad20205 commit 9442ec9
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 39 deletions.
17 changes: 9 additions & 8 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,18 @@ def ewm_mean(
ser = self._native_series
mask_na = ser.isna()
if self._implementation is Implementation.CUDF:
if min_periods == 0 and not ignore_nulls:
if (min_periods == 0 and not ignore_nulls) or (not mask_na.any()):
result = ser.ewm(com, span, half_life, alpha, adjust).mean()
elif min_periods != 0:
msg = "`min_periods != 0` is not yet implemented for cuDF"
raise NotImplementedError(msg)
else:
msg = "`ignore_nulls=True` is not yet implemented for cuDF"
msg = (
"cuDF only supports `ewm_mean` when there are no missing values "
"or when both `min_period=0` and `ignore_nulls=False`"
)
raise NotImplementedError(msg)
result = ser.ewm(
com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls
).mean()
else:
result = ser.ewm(
com, span, half_life, alpha, min_periods, adjust, ignore_na=ignore_nulls
).mean()
result[mask_na] = None
return self._from_native_series(result)

Expand Down
46 changes: 15 additions & 31 deletions tests/expr_and_series/ewm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -36,7 +36,7 @@ def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor)
def test_ewm_mean_series(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin", "cudf")) or (
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or (
"polars" in str(constructor_eager) and POLARS_VERSION < (1,)
):
request.applymarker(pytest.mark.xfail)
Expand Down Expand Up @@ -75,9 +75,9 @@ def test_ewm_mean_expr_adjust(
adjust: bool, # noqa: FBT001
expected: dict[str, list[float]],
) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand Down Expand Up @@ -189,40 +189,24 @@ def test_ewm_mean_params(
df.select(nw.col("a").ewm_mean(span=1.5, half_life=0.75, ignore_nulls=False))


@pytest.mark.filterwarnings(
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_cudf_default_params(
constructor: Constructor,
request: pytest.FixtureRequest,
) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
expected: dict[str, list[float | None]] = {"a": [1.0, 1.0, 1.5714285714285714]}
assert_equal_data(
df.select(nw.col("a").ewm_mean(com=1, min_periods=0)),
expected,
)


@pytest.mark.filterwarnings(
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_cudf_raise() -> None: # pragma: no cover
pytest.importorskip("cudf")
import cudf

df = nw.from_native(cudf.DataFrame(data))
df_with_nulls = nw.from_native(cudf.DataFrame({"a": [2.0, 4.0, None, 3.0]}))
df_ignore_nulls = nw.from_native(cudf.DataFrame(data))
with pytest.raises(
NotImplementedError,
match="`min_periods != 0` is not yet implemented for cuDF",
match="cuDF only supports `ewm_mean` when there are no missing values",
):
df.select(nw.col("a").ewm_mean(com=1))
df_with_nulls.select(nw.col("a").ewm_mean(com=1))
with pytest.raises(
NotImplementedError,
match="`ignore_nulls=True` is not yet implemented for cuDF",
match="cuDF only supports `ewm_mean` when there are no missing values",
):
df.select(nw.col("a").ewm_mean(com=1, min_periods=0, ignore_nulls=True))
df_ignore_nulls.select(
nw.col("a").ewm_mean(com=1, min_periods=0, ignore_nulls=True)
)

0 comments on commit 9442ec9

Please sign in to comment.