Skip to content

Commit

Permalink
Use T_DataArray in Weighted (#8630)
Browse files Browse the repository at this point in the history
* Use `T_DataArray` in `Weighted`

Allows subtypes.

(I had this in my git stash, so commiting it...)

* Apply suggestions from code review
  • Loading branch information
max-sixty authored Jan 22, 2024
1 parent 5bd3d8b commit e571d1c
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from xarray.core.alignment import align, broadcast
from xarray.core.computation import apply_ufunc, dot
from xarray.core.pycompat import is_duck_dask_array
from xarray.core.types import Dims, T_Xarray
from xarray.core.types import Dims, T_DataArray, T_Xarray
from xarray.util.deprecation_helpers import _deprecate_positional_args

# Weighted quantile methods are a subset of the numpy supported quantile methods.
Expand Down Expand Up @@ -145,7 +145,7 @@ class Weighted(Generic[T_Xarray]):

__slots__ = ("obj", "weights")

def __init__(self, obj: T_Xarray, weights: DataArray) -> None:
def __init__(self, obj: T_Xarray, weights: T_DataArray) -> None:
"""
Create a Weighted object
Expand Down Expand Up @@ -189,7 +189,7 @@ def _weight_check(w):
_weight_check(weights.data)

self.obj: T_Xarray = obj
self.weights: DataArray = weights
self.weights: T_DataArray = weights

def _check_dim(self, dim: Dims):
"""raise an error if any dimension is missing"""
Expand All @@ -208,11 +208,11 @@ def _check_dim(self, dim: Dims):

@staticmethod
def _reduce(
da: DataArray,
weights: DataArray,
da: T_DataArray,
weights: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""reduce using dot; equivalent to (da * weights).sum(dim, skipna)
for internal use only
Expand All @@ -230,7 +230,7 @@ def _reduce(
# DataArray (if `weights` has additional dimensions)
return dot(da, weights, dim=dim)

def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:
def _sum_of_weights(self, da: T_DataArray, dim: Dims = None) -> T_DataArray:
"""Calculate the sum of weights, accounting for missing values"""

# we need to mask data values that are nan; else the weights are wrong
Expand All @@ -255,10 +255,10 @@ def _sum_of_weights(self, da: DataArray, dim: Dims = None) -> DataArray:

def _sum_of_squares(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``sum_of_squares`` along some dimension(s)."""

demeaned = da - da.weighted(self.weights).mean(dim=dim)
Expand All @@ -267,20 +267,20 @@ def _sum_of_squares(

def _weighted_sum(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``sum`` along some dimension(s)."""

return self._reduce(da, self.weights, dim=dim, skipna=skipna)

def _weighted_mean(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``mean`` along some dimension(s)."""

weighted_sum = self._weighted_sum(da, dim=dim, skipna=skipna)
Expand All @@ -291,10 +291,10 @@ def _weighted_mean(

def _weighted_var(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``var`` along some dimension(s)."""

sum_of_squares = self._sum_of_squares(da, dim=dim, skipna=skipna)
Expand All @@ -305,21 +305,21 @@ def _weighted_var(

def _weighted_std(
self,
da: DataArray,
da: T_DataArray,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Reduce a DataArray by a weighted ``std`` along some dimension(s)."""

return cast("DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))
return cast("T_DataArray", np.sqrt(self._weighted_var(da, dim, skipna)))

def _weighted_quantile(
self,
da: DataArray,
da: T_DataArray,
q: ArrayLike,
dim: Dims = None,
skipna: bool | None = None,
) -> DataArray:
) -> T_DataArray:
"""Apply a weighted ``quantile`` to a DataArray along some dimension(s)."""

def _get_h(n: float, q: np.ndarray, method: QUANTILE_METHODS) -> np.ndarray:
Expand Down

0 comments on commit e571d1c

Please sign in to comment.