Skip to content

Commit

Permalink
axis arg
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrilJl committed May 28, 2024
1 parent 3547ffd commit d5195ed
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 111 deletions.
87 changes: 87 additions & 0 deletions batchstats/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import numpy as np

from ._misc import any_nan, check_params


class BatchStat:
    """
    Base class for calculating statistics over batches of data.

    Args:
        axis (int or tuple of int, optional): Axis or axes of a batch that index the samples.
            Negative values count from the last axis. Default is 0.

    Attributes:
        axis (int or tuple of int): Sample axis or axes, as returned by ``check_params``.
        n_samples (int): Total number of samples processed.
    """

    def __init__(self, axis=0):
        self.axis = check_params(param=axis, types=(int, tuple))
        self.n_samples = 0

    def _sample_axes(self, ndim):
        """
        Return the sample axes as a tuple of indices for an ``ndim``-dimensional batch.

        Negative axes are converted to their non-negative equivalent so they can be
        compared against ``range(ndim)``.
        """
        axes = (self.axis,) if isinstance(self.axis, int) else tuple(self.axis)
        return tuple(a + ndim if a < 0 else a for a in axes)

    def _complementary_axis(self, ndim):
        """
        Return every axis of an ``ndim``-dimensional batch that is not a sample axis.

        Args:
            ndim (int): Number of dimensions of the batch.

        Returns:
            tuple of int: The remaining axes, in ascending order.
        """
        sample_axes = set(self._sample_axes(ndim))
        return tuple(a for a in range(ndim) if a not in sample_axes)

    def _process_batch(self, batch, assume_valid=False):
        """
        Process the input batch, handling NaN values if necessary.

        The batch is converted to an array with at least two dimensions and ``n_samples``
        is increased by the number of samples kept.

        Args:
            batch (numpy.ndarray): Input batch.
            assume_valid (bool, optional): If True, assumes all elements in the batch are valid
                and skips the NaN check. Default is False.

        Returns:
            numpy.ndarray: Processed batch, without the samples flagged by ``any_nan``.
        """
        batch = np.atleast_2d(np.asarray(batch))
        sample_axes = self._sample_axes(batch.ndim)

        if assume_valid:
            # Nothing is dropped: the sample count is the product of the sample-axis sizes.
            n_batch = 1
            for a in sample_axes:
                n_batch *= batch.shape[a]
            self.n_samples += n_batch
            return batch

        # Assumes any_nan returns a boolean mask with one entry per sample, which the
        # indexing below already relies on -- TODO confirm against ._misc.any_nan.
        nan_mask = any_nan(batch, axis=self._complementary_axis(ndim=batch.ndim))
        if not nan_mask.any():
            valid_batch = batch
        elif isinstance(self.axis, int):
            # Drop invalid samples along the configured axis, not unconditionally axis 0.
            valid_batch = np.compress(~nan_mask, batch, axis=sample_axes[0])
        else:
            # NOTE(review): with several sample axes, boolean indexing only lines up when they
            # are the leading axes, and it flattens them into one; kept as in the original.
            valid_batch = batch[~nan_mask]
        self.n_samples += int(np.count_nonzero(~nan_mask))
        return valid_batch

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class BatchNanStat:
    """
    Base class for batch statistics computed on data that can contain NaN values.

    Args:
        axis: Value passed as ``axis`` when counting the finite entries of a batch. Default is 0.

    Attributes:
        n_samples (numpy.ndarray or None): Running count of finite entries along ``axis``;
            None until the first batch has been processed.
        axis: The axis given at construction, stored unchanged.
    """

    def __init__(self, axis=0):
        """
        Store the axis and start without any sample count.
        """
        self.axis = axis
        self.n_samples = None

    def _process_batch(self, batch):
        """
        Convert the batch to an array and add its finite-entry count to ``n_samples``.

        Args:
            batch (numpy.ndarray): Input batch.

        Returns:
            numpy.ndarray: The batch as an array with at least two dimensions; no entry is removed.
        """
        batch = np.atleast_2d(np.asarray(batch))
        # np.isfinite is False for NaN and for +/-inf, so neither is counted.
        finite_count = np.isfinite(batch).sum(axis=self.axis)
        if self.n_samples is not None:
            self.n_samples += finite_count
        else:
            self.n_samples = finite_count
        return batch
50 changes: 8 additions & 42 deletions batchstats/nanstats.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,7 @@
import numpy as np

from ._misc import NoValidSamplesError


class BatchNanStat:
"""
Base class for calculating statistics over batches of data that can contain NaN values.
Attributes:
n_samples (numpy.ndarray): Total number of samples processed, accounting for NaN values.
"""

def __init__(self):
"""
Initialize the BatchNanStat object.
"""
self.n_samples = None

def _process_batch(self, batch):
"""
Process the input batch, counting NaN values.
Args:
batch (numpy.ndarray): Input batch.
Returns:
numpy.ndarray: Processed batch.
"""
batch = np.atleast_2d(np.asarray(batch))
axis = tuple(range(1, batch.ndim))
if self.n_samples is None:
self.n_samples = np.isfinite(batch).sum(axis=0)
else:
self.n_samples += np.isfinite(batch).sum(axis=0)
return batch
from .core import BatchNanStat


class BatchNanSum(BatchNanStat):
Expand All @@ -44,11 +10,11 @@ class BatchNanSum(BatchNanStat):
"""

def __init__(self):
def __init__(self, axis=0):
"""
Initialize the BatchNanSum object.
"""
super().__init__()
super().__init__(axis=axis)
self.sum = None

def update_batch(self, batch):
Expand All @@ -65,9 +31,9 @@ def update_batch(self, batch):
batch = self._process_batch(batch)
axis = tuple(range(1, batch.ndim))
if self.sum is None:
self.sum = np.nansum(batch, axis=0)
self.sum = np.nansum(batch, axis=self.axis)
else:
self.sum += np.nansum(batch, axis=0)
self.sum += np.nansum(batch, axis=self.axis)
return self

def __call__(self):
Expand All @@ -93,12 +59,12 @@ class BatchNanMean(BatchNanStat):
"""

def __init__(self):
def __init__(self, axis=0):
"""
Initialize the BatchNanMean object.
"""
super().__init__()
self.sum = BatchNanSum()
super().__init__(axis=axis)
self.sum = BatchNanSum(axis=axis)

def update_batch(self, batch):
"""
Expand Down
96 changes: 27 additions & 69 deletions batchstats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,7 @@
import numpy as np

from ._misc import NoValidSamplesError, UnequalSamplesNumber, any_nan, check_params


class BatchStat:
"""
Base class for calculating statistics over batches of data.
Attributes:
n_samples (int): Total number of samples processed.
"""

def __init__(self):
self.n_samples = 0

def _process_batch(self, batch, assume_valid=False):
"""
Process the input batch, handling NaN values if necessary.
Args:
batch (numpy.ndarray): Input batch.
assume_valid (bool, optional): If True, assumes all elements in the batch are valid. Default is False.
Returns:
numpy.ndarray: Processed batch.
"""
batch = np.atleast_2d(np.asarray(batch))
if assume_valid:
self.n_samples += len(batch)
return batch
else:
axis = tuple(range(1, batch.ndim))
nan_mask = any_nan(batch, axis=axis)
if nan_mask.any():
valid_batch = batch[~nan_mask]
else:
valid_batch = batch
self.n_samples += len(valid_batch)
return valid_batch

def __repr__(self):
return f"{self.__class__.__name__}()"
from .core import BatchStat


class BatchSum(BatchStat):
Expand All @@ -53,8 +12,8 @@ class BatchSum(BatchStat):
"""

def __init__(self):
super().__init__()
def __init__(self, axis=0):
super().__init__(axis=axis)
self.sum = None

def update_batch(self, batch, assume_valid=False):
Expand All @@ -73,9 +32,9 @@ def update_batch(self, batch, assume_valid=False):
n = len(valid_batch)
if n > 0:
if self.sum is None:
self.sum = np.sum(a=valid_batch, axis=0)
self.sum = np.sum(a=valid_batch, axis=self.axis)
else:
self.sum += np.sum(a=valid_batch, axis=0)
self.sum += np.sum(a=valid_batch, axis=self.axis)
return self

def __call__(self) -> np.ndarray:
Expand All @@ -101,8 +60,8 @@ class BatchMax(BatchStat):
"""

def __init__(self):
super().__init__()
def __init__(self, axis=0):
super().__init__(axis=axis)
self.max = None

def update_batch(self, batch, assume_valid=False):
Expand All @@ -121,9 +80,9 @@ def update_batch(self, batch, assume_valid=False):
n = len(valid_batch)
if n > 0:
if self.max is None:
self.max = np.max(valid_batch, axis=0)
self.max = np.max(valid_batch, axis=self.axis)
else:
np.maximum(self.max, np.max(valid_batch, axis=0), out=self.max)
np.maximum(self.max, np.max(valid_batch, axis=self.axis), out=self.max)
return self

def __call__(self) -> np.ndarray:
Expand All @@ -149,8 +108,8 @@ class BatchMin(BatchStat):
"""

def __init__(self):
super().__init__()
def __init__(self, axis=0):
super().__init__(axis=axis)
self.min = None

def update_batch(self, batch, assume_valid=False):
Expand All @@ -169,9 +128,9 @@ def update_batch(self, batch, assume_valid=False):
n = len(valid_batch)
if n > 0:
if self.min is None:
self.min = np.min(valid_batch, axis=0)
self.min = np.min(valid_batch, axis=self.axis)
else:
np.minimum(self.min, np.min(valid_batch, axis=0), out=self.min)
np.minimum(self.min, np.min(valid_batch, axis=self.axis), out=self.min)
return self

def __call__(self) -> np.ndarray:
Expand All @@ -197,8 +156,8 @@ class BatchMean(BatchStat):
"""

def __init__(self):
super().__init__()
def __init__(self, axis=0):
super().__init__(axis=axis)
self.mean = None

def update_batch(self, batch, assume_valid=False):
Expand All @@ -217,9 +176,9 @@ def update_batch(self, batch, assume_valid=False):
n = len(valid_batch)
if n > 0:
if self.mean is None:
self.mean = np.mean(valid_batch, axis=0)
self.mean = np.mean(valid_batch, axis=self.axis)
else:
self.mean = ((self.n_samples - n) * self.mean + np.sum(valid_batch, axis=0)) / self.n_samples
self.mean = ((self.n_samples - n) * self.mean + np.sum(valid_batch, axis=self.axis)) / self.n_samples
return self

def __call__(self) -> np.ndarray:
Expand All @@ -244,10 +203,10 @@ class BatchPeakToPeak(BatchStat):
Class for calculating the peak-to-peak (max - min) of batches of data.
"""

def __init__(self):
super().__init__()
self.batchmax = BatchMax()
self.batchmin = BatchMin()
def __init__(self, axis=0):
super().__init__(axis=axis)
self.batchmax = BatchMax(axis=axis)
self.batchmin = BatchMin(axis=axis)

def update_batch(self, batch, assume_valid=False):
"""
Expand Down Expand Up @@ -287,9 +246,9 @@ class BatchVar(BatchMean):
ddof (int, optional): Means Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero.
"""

def __init__(self, ddof=0):
super().__init__()
self.mean = BatchMean()
def __init__(self, axis=0, ddof=0):
super().__init__(axis=axis)
self.mean = BatchMean(axis=axis)
self.var = None
self.ddof = check_params(param=ddof, types=int)

Expand Down Expand Up @@ -386,9 +345,9 @@ class BatchStd(BatchStat):
ddof (int, optional): Means Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero.
"""

def __init__(self, ddof=0):
super().__init__()
self.var = BatchVar(ddof=ddof)
def __init__(self, axis=0, ddof=0):
super().__init__(axis=axis)
self.var = BatchVar(axis=axis, ddof=ddof)

def update_batch(self, batch, assume_valid=False):
"""Update the standard deviation with a new batch of data.
Expand Down Expand Up @@ -512,4 +471,3 @@ def __call__(self) -> np.ndarray:
if self.cov is None:
raise NoValidSamplesError("No valid samples for calculating covariance.")
return self.n_samples/(self.n_samples - self.ddof)*self.cov
return self.n_samples/(self.n_samples - self.ddof)*self.cov

0 comments on commit d5195ed

Please sign in to comment.