From 212c2e46ff7d49905d1e5e67c4f1ee4bb396df29 Mon Sep 17 00:00:00 2001 From: CyrilJl Date: Mon, 13 May 2024 23:28:58 +0200 Subject: [PATCH] add nanstats --- batchstats/__init__.py | 8 +++-- batchstats/_misc.py | 31 ++++++++++++++++ batchstats/nanstats.py | 51 +++++++++++++++++++++++++++ batchstats/{core.py => stats.py} | 31 +--------------- tests/test_nanstats.py | 40 +++++++++++++++++++++ tests/{test_core.py => test_stats.py} | 0 6 files changed, 128 insertions(+), 33 deletions(-) create mode 100644 batchstats/nanstats.py rename batchstats/{core.py => stats.py} (95%) create mode 100644 tests/test_nanstats.py rename tests/{test_core.py => test_stats.py} (100%) diff --git a/batchstats/__init__.py b/batchstats/__init__.py index 968fac9..f4f2591 100644 --- a/batchstats/__init__.py +++ b/batchstats/__init__.py @@ -1,5 +1,7 @@ -from .core import BatchCov, BatchMax, BatchMean, BatchMin, BatchStat, BatchSum, BatchVar +from .nanstats import BatchNanMean, BatchNanStat, BatchNanSum +from .stats import BatchCov, BatchMax, BatchMean, BatchMin, BatchStat, BatchSum, BatchVar -__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar'] +__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar', + 'BatchNanMean', 'BatchNanStat', 'BatchNanSum'] -__version__ = '0.2' \ No newline at end of file +__version__ = '0.2' diff --git a/batchstats/_misc.py b/batchstats/_misc.py index 50a5ed3..61eebe5 100644 --- a/batchstats/_misc.py +++ b/batchstats/_misc.py @@ -1,9 +1,40 @@ import warnings +import numpy as np + # Customize the warning format warnings.formatwarning = lambda msg, *args, **kwargs: str(msg) + '\n' +class NoValidSamplesError(ValueError): + """ + Error raised when there are no valid samples for calculation. + """ + pass + + +class UnequalSamplesNumber(ValueError): + """ + Error raised when two batches have unequal lengths. + """ + pass + + +def any_nan(x, axis=None): + """ + Check if there are any NaN values in the input array. + + Args: + x (numpy.ndarray): Input array. + axis (int or tuple of ints, optional): Axis or axes along which to operate. Default is None. + + Returns: + numpy.ndarray: Boolean array indicating NaN presence. + + """ + return np.isnan(np.add.reduce(array=x, axis=axis)) + + def check_params(param, params=None, types=None): # Check if the parameter's type matches the accepted types if (types is not None) and (not isinstance(param, types)): diff --git a/batchstats/nanstats.py b/batchstats/nanstats.py new file mode 100644 index 0000000..6e0ca6c --- /dev/null +++ b/batchstats/nanstats.py @@ -0,0 +1,51 @@ +import numpy as np + +from ._misc import NoValidSamplesError + + +class BatchNanStat: + def __init__(self): + self.n_samples = None + + def _process_batch(self, batch): + batch = np.atleast_2d(np.asarray(batch)) + axis = tuple(range(1, batch.ndim)) + if self.n_samples is None: + self.n_samples = np.isfinite(batch).sum(axis=0) + else: + self.n_samples += np.isfinite(batch).sum(axis=0) + return batch + + +class BatchNanSum(BatchNanStat): + def __init__(self): + super().__init__() + self.sum = None + + def update_batch(self, batch): + batch = self._process_batch(batch) + axis = tuple(range(1, batch.ndim)) + if self.sum is None: + self.sum = np.nansum(batch, axis=0) + else: + self.sum += np.nansum(batch, axis=0) + return self + + def __call__(self): + if self.sum is None: + raise NoValidSamplesError() + else: + return np.where(self.n_samples > 0, self.sum, np.nan) + + +class BatchNanMean(BatchNanStat): + def __init__(self): + super().__init__() + self.sum = BatchNanSum() + + def update_batch(self, batch): + self.sum.update_batch(batch) + return self + + def __call__(self): + return self.sum()/self.sum.n_samples diff --git a/batchstats/core.py b/batchstats/stats.py similarity index 95% rename from batchstats/core.py rename to batchstats/stats.py index 79adbd1..db1b131 100644 --- a/batchstats/core.py +++ b/batchstats/stats.py @@ -2,36 +2,7 @@ import numpy as np -from ._misc import check_params - - -class NoValidSamplesError(ValueError): - """ - Error raised when there are no valid samples for calculation. - """ - pass - - -class UnequalSamplesNumber(ValueError): - """ - Error raised when two batches have unequal lengths. - """ - pass - - -def any_nan(x, axis=None): - """ - Check if there are any NaN values in the input array. - - Args: - x (numpy.ndarray): Input array. - axis (int or tuple of ints, optional): Axis or axes along which to operate. Default is None. - - Returns: - numpy.ndarray: Boolean array indicating NaN presence. - - """ - return np.isnan(np.add.reduce(array=x, axis=axis)) +from ._misc import NoValidSamplesError, UnequalSamplesNumber, any_nan, check_params class BatchStat: diff --git a/tests/test_nanstats.py b/tests/test_nanstats.py new file mode 100644 index 0000000..f4228b7 --- /dev/null +++ b/tests/test_nanstats.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest + +from batchstats import BatchNanMean, BatchNanSum + + +@pytest.fixture +def data(): + m, n = 1_000_000, 50 + nan_ratio = 0.05 + data = np.random.randn(m, n) + num_nans = int(m * n * nan_ratio) + nan_indices = np.random.choice(range(m * n), num_nans, replace=False) + data.ravel()[nan_indices] = np.nan + return data + + +@pytest.fixture +def n_batches(): + return 31 + + +def test_nansum(data, n_batches): + true_stat = np.nansum(data, axis=0) + + batchsum = BatchNanSum() + for batch_data in np.array_split(data, n_batches): + batchsum.update_batch(batch=batch_data) + batch_stat = batchsum() + assert np.allclose(true_stat, batch_stat) + + +def test_nanmean(data, n_batches): + true_stat = np.nanmean(data, axis=0) + + batchmean = BatchNanMean() + for batch_data in np.array_split(data, n_batches): + batchmean.update_batch(batch=batch_data) + batch_stat = batchmean() + assert np.allclose(true_stat, batch_stat) diff --git a/tests/test_core.py b/tests/test_stats.py similarity index 100% rename from tests/test_core.py rename to tests/test_stats.py