Skip to content

Commit

Permalink
Dataset.reduce pass through non-numeric scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause committed Oct 10, 2024
1 parent c057d13 commit 7310715
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 17 deletions.
22 changes: 11 additions & 11 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7009,17 +7009,17 @@ def reduce(
if not reduce_dims:
variables[name] = var
else:
if (
# Some reduction functions (e.g. std, var) need to run on variables
# that don't have the reduce dims: PR5393
not is_extension_array_dtype(var.dtype)
and (
not reduce_dims
or not numeric_only
or np.issubdtype(var.dtype, np.number)
or (var.dtype == np.bool_)
)
):

is_numeric = (not is_extension_array_dtype(var.dtype)) and (
np.issubdtype(var.dtype, np.number) or var.dtype == np.bool_
)

# pass through non-numeric scalar
if numeric_only and not is_numeric and var.ndim == 0:
variables[name] = var

elif not reduce_dims or not numeric_only or is_numeric:

# prefer to aggregate over axis=None rather than
# axis=(0, 1) if they will be equivalent, because
# the former is often more efficient
Expand Down
33 changes: 27 additions & 6 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5661,14 +5661,35 @@ def test_reduce_non_numeric(self) -> None:
data = np.random.randint(0, 100, size=size).astype(np.str_)
data1[v] = (dims, data, {"foo": "variable"})
# var4 is extension array categorical and should be dropped
assert (
"var4" not in data1.mean()
and "var5" not in data1.mean()
and "var6" not in data1.mean()
)

assert "var4" not in data1.mean()
assert "var5" not in data1.mean()
assert "var6" not in data1.mean()

assert_equal(data1.mean(), data2.mean())
assert_equal(data1.mean(dim="dim1"), data2.mean(dim="dim1"))
assert "var5" not in data1.mean(dim="dim2") and "var6" in data1.mean(dim="dim2")

assert "var5" not in data1.mean(dim="dim2")
assert "var6" in data1.mean(dim="dim2")

@pytest.mark.parametrize("op", ("sum", "prod", "mean", "std"))
def test_reduce_non_numeric_scalar(self, op) -> None:
    # Ensure a non-numeric scalar variable is passed through by reductions.
    baseline = create_test_data(seed=44, dim_sizes=(1, 2, 3))

    # Same dataset with an extra scalar string variable attached.
    with_scalar = baseline.assign(var4="string")

    # First reduce with no arguments, then along "dim1" only; in both cases
    # the result must equal the reduced baseline with the scalar re-attached.
    for reduce_args in ((), ("dim1",)):
        actual = getattr(with_scalar, op)(*reduce_args)
        wanted = getattr(baseline, op)(*reduce_args).assign(var4="string")

        assert_equal(actual, wanted)

@pytest.mark.filterwarnings(
"ignore:Once the behaviour of DataArray:DeprecationWarning"
Expand Down

0 comments on commit 7310715

Please sign in to comment.