Skip to content

Commit

Permalink
0.4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrilJl committed Jun 9, 2024
1 parent 838c0a9 commit e0bce24
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 16 deletions.
2 changes: 1 addition & 1 deletion batchstats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchPeakToPeak', 'BatchStat', 'BatchStd', 'BatchSum',
'BatchVar', 'BatchNanMean', 'BatchNanStat', 'BatchNanSum']

__version__ = '0.4'
__version__ = '0.4.1'
21 changes: 21 additions & 0 deletions batchstats/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,27 @@ class UnequalSamplesNumber(ValueError):
pass


class DifferentAxisError(ValueError):
    """Raised when two BatchStats objects being merged were built with different `axis` values."""


class DifferentShapesError(ValueError):
    """Raised when two BatchStats objects being merged hold statistics of different shapes."""


class DifferentStatsError(ValueError):
    """Raised when two BatchStats objects being merged are not of the same type."""


def any_nan(x, axis=None):
"""
Check if there are any NaN values in the input array.
Expand Down
11 changes: 10 additions & 1 deletion batchstats/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from ._misc import any_nan, check_params
from ._misc import DifferentAxisError, DifferentShapesError, DifferentStatsError, any_nan, check_params


class BatchStat:
Expand Down Expand Up @@ -51,6 +51,15 @@ def _process_batch(self, batch, assume_valid=False):
def __repr__(self):
return f"{self.__class__.__name__}()"

def merge_test(self, other, field: str):
    """
    Check that `other` can be merged with this object.

    Parameters
    ----------
    other : object
        The object that is about to be merged with `self`.
    field : str
        Name of the attribute holding the accumulated statistic (for instance
        ``'sum'`` or ``'mean'``). Its shape is compared between both objects.

    Raises
    ------
    DifferentStatsError
        If `other` is not of exactly the same class as `self`.
    DifferentAxisError
        If both objects do not share the same `axis`.
    DifferentShapesError
        If both objects hold a value in `field` and the shapes of these values differ.
    """
    # Exact class comparison (not isinstance), as the merge is only defined between
    # two objects of the very same statistic class.
    if type(self) is not type(other):
        raise DifferentStatsError(
            f"Cannot merge {type(self).__name__} with {type(other).__name__}.")
    if self.axis != other.axis:
        raise DifferentAxisError(
            f"Cannot merge objects with different axis: {self.axis} and {other.axis}.")

    own_value = getattr(self, field, None)
    other_value = getattr(other, field, None)
    # An object that has not processed any batch may still hold None in `field`
    # (BatchMean.mean does before its first batch). None has no shape, so there is
    # nothing to compare; the caller handles the empty operand afterwards.
    if own_value is None or other_value is None:
        return
    if own_value.shape != other_value.shape:
        raise DifferentShapesError(
            f"Cannot merge '{field}' of shapes {own_value.shape} and {other_value.shape}.")


class BatchNanStat:
"""
Expand Down
52 changes: 51 additions & 1 deletion batchstats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def __call__(self) -> np.ndarray:
else:
return self.sum.copy()

def __add__(self, other):
    """
    Merge two ``BatchSum`` objects into one covering the samples of both.

    `merge_test` is run first on the ``'sum'`` field and raises if the operands
    are not compatible. If one operand has ``n_samples == 0``, the other operand
    itself is returned (not a copy). Otherwise a new ``BatchSum`` is returned
    whose `n_samples` and `sum` are the sums of the operands' values.
    """
    self.merge_test(other, field='sum')
    # An operand without samples contributes nothing: hand back the other one as-is.
    if self.n_samples == 0:
        return other
    if other.n_samples == 0:
        return self
    merged = BatchSum(axis=self.axis)
    merged.n_samples = self.n_samples + other.n_samples
    merged.sum = self.sum + other.sum
    return merged


class BatchMax(BatchStat):
"""
Expand Down Expand Up @@ -101,6 +113,18 @@ def __call__(self) -> np.ndarray:
else:
return self.max.copy()

def __add__(self, other):
    """
    Merge two ``BatchMax`` objects into one covering the samples of both.

    `merge_test` is run first on the ``'max'`` field and raises if the operands
    are not compatible. If one operand has ``n_samples == 0``, the other operand
    itself is returned (not a copy). Otherwise a new ``BatchMax`` is returned
    whose `n_samples` is the sum of both counts and whose `max` is the
    element-wise maximum of both `max` values.
    """
    self.merge_test(other, field='max')
    # An operand without samples contributes nothing: hand back the other one as-is.
    if self.n_samples == 0:
        return other
    if other.n_samples == 0:
        return self
    merged = BatchMax(axis=self.axis)
    merged.n_samples = self.n_samples + other.n_samples
    merged.max = np.maximum(self.max, other.max)
    return merged


class BatchMin(BatchStat):
"""
Expand Down Expand Up @@ -149,6 +173,18 @@ def __call__(self) -> np.ndarray:
else:
return self.min.copy()

def __add__(self, other):
    """
    Merge two ``BatchMin`` objects into one covering the samples of both.

    `merge_test` is run first on the ``'min'`` field and raises if the operands
    are not compatible. If one operand has ``n_samples == 0``, the other operand
    itself is returned (not a copy). Otherwise a new ``BatchMin`` is returned
    whose `n_samples` is the sum of both counts and whose `min` is the
    element-wise minimum of both `min` values.
    """
    self.merge_test(other, field='min')
    # An operand without samples contributes nothing: hand back the other one as-is.
    if self.n_samples == 0:
        return other
    if other.n_samples == 0:
        return self
    ret = BatchMin(axis=self.axis)
    ret.n_samples = self.n_samples + other.n_samples
    # The running minimum lives in `min`; reading or writing `max` here would
    # leave the merged object without the value `__call__` returns.
    ret.min = np.minimum(self.min, other.min)
    return ret


class BatchMean(BatchStat):
"""
Expand Down Expand Up @@ -178,7 +214,8 @@ def update_batch(self, batch, assume_valid=False):
if self.mean is None:
self.mean = np.mean(valid_batch, axis=self.axis)
else:
self.mean = ((self.n_samples - n) * self.mean + np.sum(valid_batch, axis=self.axis)) / self.n_samples
mean_batch = np.mean(valid_batch, axis=self.axis)
self.mean = ((self.n_samples - n) * self.mean + np.sum(valid_batch-mean_batch, axis=self.axis) + n*mean_batch) / self.n_samples
return self

def __call__(self) -> np.ndarray:
Expand All @@ -197,6 +234,19 @@ def __call__(self) -> np.ndarray:
else:
return self.mean.copy()

def __add__(self, other):
    """
    Merge two ``BatchMean`` objects into one covering the samples of both.

    `merge_test` is run first on the ``'mean'`` field and raises if the operands
    are not compatible. If one operand has ``n_samples == 0``, the other operand
    itself is returned (not a copy). Otherwise a new ``BatchMean`` is returned
    whose `n_samples` is the sum of both counts and whose `mean` is the average
    of both means weighted by their respective `n_samples`.
    """
    self.merge_test(other, field='mean')
    # An operand without samples contributes nothing: hand back the other one as-is.
    if self.n_samples == 0:
        return other
    if other.n_samples == 0:
        return self
    merged = BatchMean(axis=self.axis)
    merged.n_samples = self.n_samples + other.n_samples
    # Each mean times its sample count gives that operand's total; add the totals, then renormalize.
    weighted_total = self.n_samples*self.mean + other.n_samples*other.mean
    merged.mean = weighted_total / merged.n_samples
    return merged


class BatchPeakToPeak(BatchStat):
"""
Expand Down
8 changes: 8 additions & 0 deletions docs/source/future.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
.. Future Development
What's New ?
============

Version 0.4.1 (June 9, 2024)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- Improved the numerical stability of the ``BatchMean`` algorithm, based on `Numerically Stable Parallel Computation of (Co-)Variance by Schubert and Gertz <https://ds.ifi.uni-heidelberg.de/files/Team/eschubert/publications/SSDBM18-covariance-authorcopy.pdf>`_


Future Plans
============

Expand Down
21 changes: 10 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from setuptools import find_packages, setup

# Read version from the __init__.py file
Expand All @@ -7,14 +8,12 @@
version = line.strip().split()[-1][1:-1]
break

setup(
name='batchstats',
version=version,
author='Cyril Joly',
description='Efficient batch statistics computation library for Python.',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
url='https://github.com/CyrilJl/BatchStats',
packages=find_packages(),
install_requires=['numpy'],
)
setup(name='batchstats',
version=version,
author='Cyril Joly',
description='Efficient batch statistics computation library for Python.',
long_description=open('README.md').read(),
long_description_content_type='text/markdown',
url='https://github.com/CyrilJl/BatchStats',
packages=find_packages(),
install_requires=['numpy'])
4 changes: 2 additions & 2 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
@pytest.fixture
def data():
m, n = 1_000_000, 50
return np.random.randn(m, n)
return 1e1*np.random.randn(m, n) + 1e3


@pytest.fixture
def data_2d_features():
m, n, o = 100_000, 50, 60
return np.random.randn(m, n, o)
return 1e1*np.random.randn(m, n, o) + 1e3


@pytest.fixture
Expand Down

0 comments on commit e0bce24

Please sign in to comment.