Skip to content

Commit

Permalink
add nanstats
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrilJl committed May 13, 2024
1 parent f95afd9 commit 212c2e4
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 33 deletions.
8 changes: 5 additions & 3 deletions batchstats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .core import BatchCov, BatchMax, BatchMean, BatchMin, BatchStat, BatchSum, BatchVar
from .nanstats import BatchNanMean, BatchNanStat, BatchNanSum
from .stats import BatchCov, BatchMax, BatchMean, BatchMin, BatchStat, BatchSum, BatchVar

# Names exported by `from batchstats import *`.
__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar']
# Same export list, extended with the NaN-aware classes imported from `.nanstats`.
__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar',
'BatchNanMean', 'BatchNanStat', 'BatchNanSum']

# Package version string.
__version__ = '0.2'
__version__ = '0.2'
31 changes: 31 additions & 0 deletions batchstats/_misc.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,40 @@
import warnings

import numpy as np

# Customize the warning format: a warning is rendered as just its message
# followed by a newline. Any further arguments handed to the formatter are
# accepted but ignored.
# NOTE: this rebinds `warnings.formatwarning` itself, so the change is not
# limited to warnings issued by this package.
warnings.formatwarning = lambda msg, *args, **kwargs: str(msg) + '\n'


class NoValidSamplesError(ValueError):
    """Signal that a calculation was requested but no valid samples are available."""


class UnequalSamplesNumber(ValueError):
    """Signal that two batches do not have the same length."""


def any_nan(x, axis=None):
    """
    Check if there are any NaN values in the input array.

    The sum along ``axis`` is used as a cheap first probe, because NaN
    propagates through addition. A NaN sum is not proof of a NaN input,
    though: ``inf + (-inf)`` is NaN as well. Whenever the probe flags
    something, the answer is therefore recomputed with an exact
    element-wise test.

    Args:
        x (numpy.ndarray): Input array.
        axis (int or tuple of ints, optional): Axis or axes along which to operate. Default is None.

    Returns:
        numpy.ndarray: Boolean array indicating NaN presence.
    """
    # The sum is only a probe, so floating-point warnings it may trigger
    # (for instance on inf - inf) carry no information for the caller.
    with np.errstate(invalid='ignore', over='ignore'):
        suspected = np.isnan(np.add.reduce(array=x, axis=axis))
    if not np.any(suspected):
        # No NaN in any sum means no NaN in the data.
        return suspected
    # At least one sum is NaN: tell real NaN inputs apart from inf/-inf mixes.
    return np.any(np.isnan(x), axis=axis)


def check_params(param, params=None, types=None):
# Check if the parameter's type matches the accepted types
if (types is not None) and (not isinstance(param, types)):
Expand Down
51 changes: 51 additions & 0 deletions batchstats/nanstats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import numpy as np

from ._misc import NoValidSamplesError


class BatchNanStat:
    """
    Base class for NaN-aware batch statistics.

    Keeps a running count, per position along the first axis' complement
    (i.e. the result of reducing over axis 0), of the entries that are not NaN.

    Attributes:
        n_samples: ``None`` until a batch has been processed, then an integer
            array holding the number of non-NaN entries seen so far.
    """

    def __init__(self):
        self.n_samples = None

    def _process_batch(self, batch):
        """
        Convert ``batch`` to an array and add its non-NaN entries to ``n_samples``.

        Args:
            batch (array-like): Incoming batch. It is converted with
                ``np.asarray`` and promoted to at least two dimensions.

        Returns:
            numpy.ndarray: The converted batch.
        """
        batch = np.atleast_2d(np.asarray(batch))
        # Only NaN marks a missing value. Infinite entries are kept by the
        # NaN-aware reductions (e.g. np.nansum), so they have to be counted
        # as samples too; np.isfinite would drop them.
        n_valid = (~np.isnan(batch)).sum(axis=0)
        if self.n_samples is None:
            self.n_samples = n_valid
        else:
            self.n_samples = self.n_samples + n_valid
        return batch


class BatchNanSum(BatchNanStat):
    """
    Sum over axis 0, accumulated batch by batch, treating NaN as missing.

    Attributes:
        sum: ``None`` until a batch has been processed, then the running
            ``np.nansum(..., axis=0)`` over all batches seen so far.
    """

    def __init__(self):
        super().__init__()
        self.sum = None

    def update_batch(self, batch):
        """
        Add one batch to the running sum.

        Args:
            batch (array-like): Incoming batch.

        Returns:
            BatchNanSum: ``self``, so that calls can be chained.
        """
        batch = self._process_batch(batch)
        batch_sum = np.nansum(batch, axis=0)
        if self.sum is None:
            self.sum = batch_sum
        else:
            # Deliberately not `+=`: an in-place add must cast the new values
            # into the dtype of the first batch, which raises for int <- float
            # and silently loses precision for float32 <- float64.
            self.sum = self.sum + batch_sum
        return self

    def __call__(self):
        """
        Return the accumulated sum.

        Returns:
            numpy.ndarray: The running sum, with NaN wherever ``n_samples``
            is zero.

        Raises:
            NoValidSamplesError: If no batch has been processed yet.
        """
        if self.sum is None:
            raise NoValidSamplesError('No batch has been processed yet.')
        return np.where(self.n_samples > 0, self.sum, np.nan)


class BatchNanMean(BatchNanStat):
    """
    Mean over axis 0, accumulated batch by batch, treating NaN as missing.

    The work is delegated to an inner ``BatchNanSum``; the mean is its sum
    divided by its sample count.

    Attributes:
        sum (BatchNanSum): Inner accumulator holding the running sum and count.
        n_samples: Mirror of ``sum.n_samples``, refreshed on every update.
    """

    def __init__(self):
        super().__init__()
        self.sum = BatchNanSum()

    def update_batch(self, batch):
        """
        Add one batch to the running mean.

        Args:
            batch (array-like): Incoming batch.

        Returns:
            BatchNanMean: ``self``, so that calls can be chained.
        """
        self.sum.update_batch(batch)
        # Keep the attribute inherited from BatchNanStat meaningful; without
        # this it would stay None because the counting happens in `self.sum`.
        self.n_samples = self.sum.n_samples
        return self

    def __call__(self):
        """
        Return the accumulated mean.

        Returns:
            numpy.ndarray: Accumulated sum divided by the number of samples.

        Raises:
            NoValidSamplesError: Propagated from the inner ``BatchNanSum``
            when no batch has been processed yet.
        """
        total = self.sum()
        return total / self.sum.n_samples
31 changes: 1 addition & 30 deletions batchstats/core.py → batchstats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,7 @@

import numpy as np

from ._misc import check_params


class NoValidSamplesError(ValueError):
    """Signal that a calculation was requested but no valid samples are available."""


class UnequalSamplesNumber(ValueError):
    """Signal that two batches do not have the same length."""


def any_nan(x, axis=None):
    """
    Check if there are any NaN values in the input array.

    The sum along ``axis`` is used as a cheap first probe, because NaN
    propagates through addition. A NaN sum is not proof of a NaN input,
    though: ``inf + (-inf)`` is NaN as well. Whenever the probe flags
    something, the answer is therefore recomputed with an exact
    element-wise test.

    Args:
        x (numpy.ndarray): Input array.
        axis (int or tuple of ints, optional): Axis or axes along which to operate. Default is None.

    Returns:
        numpy.ndarray: Boolean array indicating NaN presence.
    """
    # The sum is only a probe, so floating-point warnings it may trigger
    # (for instance on inf - inf) carry no information for the caller.
    with np.errstate(invalid='ignore', over='ignore'):
        suspected = np.isnan(np.add.reduce(array=x, axis=axis))
    if not np.any(suspected):
        # No NaN in any sum means no NaN in the data.
        return suspected
    # At least one sum is NaN: tell real NaN inputs apart from inf/-inf mixes.
    return np.any(np.isnan(x), axis=axis)
from ._misc import NoValidSamplesError, UnequalSamplesNumber, any_nan, check_params


class BatchStat:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_nanstats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import numpy as np
import pytest

from batchstats import BatchNanMean, BatchNanSum


@pytest.fixture
def data():
    """
    Build a (1_000_000, 50) array of standard-normal values with 5% NaN entries.

    The NaN positions are drawn without replacement, so exactly
    ``int(m * n * nan_ratio)`` entries are NaN.
    """
    m, n = 1_000_000, 50
    nan_ratio = 0.05
    # Fixed seed: a failing run can be reproduced exactly.
    rng = np.random.default_rng(0)
    values = rng.standard_normal((m, n))
    num_nans = int(m * n * nan_ratio)
    # Passing the population size as an int avoids converting range(m * n)
    # into an array first.
    nan_indices = rng.choice(m * n, size=num_nans, replace=False)
    # `.flat` always writes into `values`; assigning through `.ravel()` only
    # does so when ravel happens to return a view.
    values.flat[nan_indices] = np.nan
    return values


@pytest.fixture
def n_batches():
    """Number of chunks the test data is split into."""
    batch_count = 31
    return batch_count


def test_nansum(data, n_batches):
    """Summing batch by batch must agree with ``np.nansum`` over the whole array."""
    accumulator = BatchNanSum()
    for chunk in np.array_split(data, n_batches):
        accumulator.update_batch(batch=chunk)

    expected = np.nansum(data, axis=0)
    assert np.allclose(expected, accumulator())


def test_nanmean(data, n_batches):
    """Averaging batch by batch must agree with ``np.nanmean`` over the whole array."""
    accumulator = BatchNanMean()
    for chunk in np.array_split(data, n_batches):
        accumulator.update_batch(batch=chunk)

    expected = np.nanmean(data, axis=0)
    assert np.allclose(expected, accumulator())
File renamed without changes.

0 comments on commit 212c2e4

Please sign in to comment.