diff --git a/README.md b/README.md index 4429a0a..98efdb5 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,30 @@ np.allclose(a, b) >>> 306 ms ± 5.09 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) ``` +## NaN handling possibility + +While the previous `Batch*` classes exclude every sample containing at least one NaN from the computations, the `BatchNan*` classes adopt a more flexible approach to handling NaN values, similar to `np.nansum`, `np.nanmean`, etc. Consequently, the resulting statistics may be computed from a different number of samples for each feature: + +```python +import numpy as np +from batchstats import BatchNanSum + +m, n = 1_000_000, 50 +nan_ratio = 0.05 +n_batches = 17 + +data = np.random.randn(m, n) +num_nans = int(m * n * nan_ratio) +nan_indices = np.random.choice(range(m * n), num_nans, replace=False) +data.ravel()[nan_indices] = np.nan + +batchsum = BatchNanSum() +for batch_data in np.array_split(data, n_batches): + batchsum.update_batch(batch=batch_data) +np.allclose(np.nansum(data, axis=0), batchsum()) +>>> True +``` + ## Documentation The documentation is available [here](https://batchstats.readthedocs.io/en/latest/). diff --git a/batchstats/__init__.py b/batchstats/__init__.py index f4f2591..ca89757 100644 --- a/batchstats/__init__.py +++ b/batchstats/__init__.py @@ -4,4 +4,4 @@ __all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar', 'BatchNanMean', 'BatchNanStat', 'BatchNanSum'] -__version__ = '0.2' +__version__ = '0.3' diff --git a/batchstats/nanstats.py b/batchstats/nanstats.py index 6e0ca6c..7ea864a 100644 --- a/batchstats/nanstats.py +++ b/batchstats/nanstats.py @@ -4,10 +4,31 @@ class BatchNanStat: + """ + Base class for calculating statistics over batches of data that can contain NaN values. + + Attributes: + n_samples (numpy.ndarray): Total number of samples processed, accounting for NaN values. + + """ + def __init__(self): + """ + Initialize the BatchNanStat object. 
+ """ self.n_samples = None def _process_batch(self, batch): + """ + Process the input batch, counting NaN values. + + Args: + batch (numpy.ndarray): Input batch. + + Returns: + numpy.ndarray: Processed batch. + + """ batch = np.atleast_2d(np.asarray(batch)) axis = tuple(range(1, batch.ndim)) if self.n_samples is None: @@ -18,11 +39,29 @@ def _process_batch(self, batch): class BatchNanSum(BatchNanStat): + """ + Class for calculating the sum of batches of data that can contain NaN values. + + """ + def __init__(self): + """ + Initialize the BatchNanSum object. + """ super().__init__() self.sum = None def update_batch(self, batch): + """ + Update the sum with a new batch of data that can contain NaN values. + + Args: + batch (numpy.ndarray): Input batch. + + Returns: + BatchNanSum: Updated BatchNanSum object. + + """ batch = self._process_batch(batch) axis = tuple(range(1, batch.ndim)) if self.sum is None: @@ -32,6 +71,16 @@ def update_batch(self, batch): return self def __call__(self): + """ + Calculate the sum of the batches that can contain NaN values. + + Returns: + numpy.ndarray: Sum of the batches. + + Raises: + NoValidSamplesError: If no valid samples are available. + + """ if self.sum is None: raise NoValidSamplesError() else: @@ -39,13 +88,38 @@ def __call__(self): class BatchNanMean(BatchNanStat): + """ + Class for calculating the mean of batches of data that can contain NaN values. + + """ + def __init__(self): + """ + Initialize the BatchNanMean object. + """ super().__init__() self.sum = BatchNanSum() def update_batch(self, batch): + """ + Update the mean with a new batch of data that can contain NaN values. + + Args: + batch (numpy.ndarray): Input batch. + + Returns: + BatchNanMean: Updated BatchNanMean object. + + """ self.sum.update_batch(batch) return self def __call__(self): - return self.sum()/self.sum.n_samples + """ + Calculate the mean of the batches that can contain NaN values. + + Returns: + numpy.ndarray: Mean of the batches. 
+ + """ + return self.sum() / self.sum.n_samples diff --git a/batchstats/stats.py b/batchstats/stats.py index db1b131..601b43e 100644 --- a/batchstats/stats.py +++ b/batchstats/stats.py @@ -243,6 +243,8 @@ class BatchVar(BatchMean): """ Class for calculating the variance of batches of data. + Args: + ddof (int, optional): Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero. """ def __init__(self, ddof=0): @@ -272,9 +274,9 @@ def init_var(cls, v, vm): def compute_incremental_variance(v, p, u): """ Compute incremental variance. - For v 2D and p/u 1D, equivalent to ((v-p).T@(v-u)).sum(axis=0) or - np.einsum('ji,ji->i', v - p, v - u). faster and less memory consumer because - no intermediate 2d array are created. + For v 2D and p/u 1D, equivalent to ``((v-p)*(v-u)).sum(axis=0)`` or + ``np.einsum('ji,ji->i', v - p, v - u)``. Faster and less memory-consuming because + no intermediate 2D array is created. Args: v (numpy.ndarray): Input data. @@ -341,6 +343,8 @@ class BatchCov(BatchStat): """ Class for calculating the covariance of batches of data. + Args: + ddof (int, optional): Delta Degrees of Freedom. The divisor used in calculations is N - ddof, where N represents the number of elements. By default ddof is zero. """ def __init__(self, ddof=0): diff --git a/docs/source/api_reference.rst b/docs/source/api_reference.rst index 45f0310..fc4f34a 100644 --- a/docs/source/api_reference.rst +++ b/docs/source/api_reference.rst @@ -10,7 +10,19 @@ The class ``BatchStat`` is the parent class from which other classes inherit. It The following classes inherit from ``BatchStat``, and enable the user to compute various statistics over batch-accessed data: .. 
automodule:: batchstats - :members: + :members: BatchCov, BatchMax, BatchMean, BatchMin, BatchSum, BatchVar :undoc-members: :show-inheritance: - :exclude-members: BatchStat \ No newline at end of file + + +The class ``BatchNanStat`` is the parent class from which other classes that can treat NaNs inherit. It allows for the factorization of the ``_process_batch`` method, which keeps track of the number of NaNs per feature. + +.. autoclass:: batchstats.BatchNanStat + +The following classes inherit from ``BatchNanStat``: + +.. automodule:: batchstats + :members: BatchNanMean, BatchNanSum + :undoc-members: + :show-inheritance: + :no-index: \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py index af3757e..c58bffe 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -27,7 +27,6 @@ templates_path = ['_templates'] exclude_patterns = [] - # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output