From e0bce245266ae281e9ac1465b62917cae758b17d Mon Sep 17 00:00:00 2001 From: CyrilJl Date: Sun, 9 Jun 2024 12:02:58 +0200 Subject: [PATCH] 0.4.1 --- batchstats/__init__.py | 2 +- batchstats/_misc.py | 21 +++++++++++++++++ batchstats/core.py | 11 ++++++++- batchstats/stats.py | 52 +++++++++++++++++++++++++++++++++++++++++- docs/source/future.rst | 8 +++++++ setup.py | 21 ++++++++--------- tests/test_stats.py | 4 ++-- 7 files changed, 103 insertions(+), 16 deletions(-) diff --git a/batchstats/__init__.py b/batchstats/__init__.py index 3986f46..bcba35e 100644 --- a/batchstats/__init__.py +++ b/batchstats/__init__.py @@ -4,4 +4,4 @@ __all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchPeakToPeak', 'BatchStat', 'BatchStd', 'BatchSum', 'BatchVar', 'BatchNanMean', 'BatchNanStat', 'BatchNanSum'] -__version__ = '0.4' +__version__ = '0.4.1' diff --git a/batchstats/_misc.py b/batchstats/_misc.py index 61eebe5..01237d5 100644 --- a/batchstats/_misc.py +++ b/batchstats/_misc.py @@ -20,6 +20,27 @@ class UnequalSamplesNumber(ValueError): pass +class DifferentAxisError(ValueError): + """ + Error raised when two BatchStats objects are merged but have different `axis`. + """ + pass + + +class DifferentShapesError(ValueError): + """ + Error raised when two BatchStats objects are merged but have different shapes. + """ + pass + + +class DifferentStatsError(ValueError): + """ + Error raised when two BatchStats objects are merged but hare not of the same type. + """ + pass + + def any_nan(x, axis=None): """ Check if there are any NaN values in the input array. diff --git a/batchstats/core.py b/batchstats/core.py index c180058..b4bfece 100644 --- a/batchstats/core.py +++ b/batchstats/core.py @@ -1,6 +1,6 @@ import numpy as np -from ._misc import any_nan, check_params +from ._misc import DifferentAxisError, DifferentShapesError, DifferentStatsError, any_nan, check_params class BatchStat: @@ -51,6 +51,15 @@ def _process_batch(self, batch, assume_valid=False): def __repr__(self): return f"{self.__class__.__name__}()" + def merge_test(self, other, field: str): + if type(self) != type(other): + raise DifferentStatsError() + if self.axis != other.axis: + raise DifferentAxisError() + if hasattr(self, field) and hasattr(other, field): + if getattr(self, field).shape != getattr(other, field).shape: + raise DifferentShapesError() + class BatchNanStat: """ diff --git a/batchstats/stats.py b/batchstats/stats.py index 8456f52..98cd927 100644 --- a/batchstats/stats.py +++ b/batchstats/stats.py @@ -53,6 +53,18 @@ def __call__(self) -> np.ndarray: else: return self.sum.copy() + def __add__(self, other): + self.merge_test(other, field='sum') + if self.n_samples == 0: + return other + elif other.n_samples == 0: + return self + else: + ret = BatchSum(axis=self.axis) + ret.n_samples = self.n_samples + other.n_samples + ret.sum = self.sum + other.sum + return ret + class BatchMax(BatchStat): """ @@ -101,6 +113,18 @@ def __call__(self) -> np.ndarray: else: return self.max.copy() + def __add__(self, other): + self.merge_test(other, field='max') + if self.n_samples == 0: + return other + elif other.n_samples == 0: + return self + else: + ret = BatchMax(axis=self.axis) + ret.n_samples = self.n_samples + other.n_samples + ret.max = np.maximum(self.max, other.max) + return ret + class BatchMin(BatchStat): """ @@ -149,6 +173,18 @@ def __call__(self) -> np.ndarray: else: return self.min.copy() + def __add__(self, other): + self.merge_test(other, field='min') + if self.n_samples == 0: + return other + elif other.n_samples == 0: + return self + else: + ret = BatchMin(axis=self.axis) + ret.n_samples = self.n_samples + other.n_samples + ret.max = np.minimum(self.max, other.max) + return ret + class BatchMean(BatchStat): """ @@ -178,7 +214,8 @@ def update_batch(self, batch, assume_valid=False): if self.mean is None: self.mean = np.mean(valid_batch, axis=self.axis) else: - self.mean = ((self.n_samples - n) * self.mean + np.sum(valid_batch, axis=self.axis)) / self.n_samples + mean_batch = np.mean(valid_batch, axis=self.axis) + self.mean = ((self.n_samples - n) * self.mean + np.sum(valid_batch-mean_batch, axis=self.axis) + n*mean_batch) / self.n_samples return self def __call__(self) -> np.ndarray: @@ -197,6 +234,19 @@ def __call__(self) -> np.ndarray: else: return self.mean.copy() + def __add__(self, other): + self.merge_test(other, field='mean') + if self.n_samples == 0: + return other + elif other.n_samples == 0: + return self + else: + ret = BatchMean(axis=self.axis) + ret.n_samples = self.n_samples + other.n_samples + ret.mean = self.n_samples*self.mean + other.n_samples*other.mean + ret.mean /= ret.n_samples + return ret + class BatchPeakToPeak(BatchStat): """ diff --git a/docs/source/future.rst b/docs/source/future.rst index 3162a15..5c4acd4 100644 --- a/docs/source/future.rst +++ b/docs/source/future.rst @@ -1,5 +1,13 @@ .. Future Development +What's New ? +============ + +Version 0.4.1 (June 9, 2024) +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +- Improved numerical stability of the ``BatchMean``'s algorithm, based on `Numerically Stable Parallel Computation of (Co-)Variance by Schubert and Gertz`_ + + Future Plans ============ diff --git a/setup.py b/setup.py index ba5f99d..fd19f4c 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +import numpy as np from setuptools import find_packages, setup # Read version from the __init__.py file @@ -7,14 +8,12 @@ version = line.strip().split()[-1][1:-1] break -setup( - name='batchstats', - version=version, - author='Cyril Joly', - description='Efficient batch statistics computation library for Python.', - long_description=open('README.md').read(), - long_description_content_type='text/markdown', - url='https://github.com/CyrilJl/BatchStats', - packages=find_packages(), - install_requires=['numpy'], -) +setup(name='batchstats', + version=version, + author='Cyril Joly', + description='Efficient batch statistics computation library for Python.', + long_description=open('README.md').read(), + long_description_content_type='text/markdown', + url='https://github.com/CyrilJl/BatchStats', + packages=find_packages(), + install_requires=['numpy']) diff --git a/tests/test_stats.py b/tests/test_stats.py index d32e1dc..fc74f0c 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -7,13 +7,13 @@ @pytest.fixture def data(): m, n = 1_000_000, 50 - return np.random.randn(m, n) + return 1e1*np.random.randn(m, n) + 1e3 @pytest.fixture def data_2d_features(): m, n, o = 100_000, 50, 60 - return np.random.randn(m, n, o) + return 1e1*np.random.randn(m, n, o) + 1e3 @pytest.fixture