Skip to content

Commit

Permalink
add BatchPeakToPeak
Browse files Browse the repository at this point in the history
  • Loading branch information
CyrilJl committed May 20, 2024
1 parent 7b26d7b commit face914
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
7 changes: 4 additions & 3 deletions batchstats/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .nanstats import BatchNanMean, BatchNanStat, BatchNanSum
from .stats import BatchCov, BatchMax, BatchMean, BatchMin, BatchStat, BatchSum, BatchVar
from .stats import (BatchCov, BatchMax, BatchMean, BatchMin, BatchPeakToPeak,
BatchStat, BatchSum, BatchVar)

__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchStat', 'BatchSum', 'BatchVar',
__all__ = ['BatchCov', 'BatchMax', 'BatchMean', 'BatchMin', 'BatchPeakToPeak', 'BatchStat', 'BatchSum', 'BatchVar',
'BatchNanMean', 'BatchNanStat', 'BatchNanSum']

__version__ = '0.3'
__version__ = '0.3.1'
43 changes: 42 additions & 1 deletion batchstats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import numpy as np

from ._misc import NoValidSamplesError, UnequalSamplesNumber, any_nan, check_params
from ._misc import (NoValidSamplesError, UnequalSamplesNumber, any_nan,
check_params)


class BatchStat:
Expand Down Expand Up @@ -239,6 +240,46 @@ def __call__(self) -> np.ndarray:
return self.mean.copy()


class BatchPeakToPeak(BatchStat):
"""
Class for calculating the peak-to-peak (max - min) of batches of data.
"""

def __init__(self):
super().__init__()
self.batchmax = BatchMax()
self.batchmin = BatchMin()

def update_batch(self, batch, assume_valid=False):
"""
Update the peak-to-peak with a new batch of data.
Args:
batch (numpy.ndarray): Input batch.
assume_valid (bool, optional): If True, assumes all elements in the batch are valid. Default is False.
Returns:
BatchPeakToPeak: Updated BatchPeakToPeak object.
"""
self.batchmax.update_batch(batch, assume_valid=assume_valid)
self.batchmin.update_batch(batch, assume_valid=assume_valid)
return self

def __call__(self) -> np.ndarray:
"""
Calculate the peak-to-peak.
Returns:
numpy.ndarray: Peak-to-peak of the batches.
Raises:
NoValidSamplesError: If no valid samples are available.
"""
return self.batchmax() - self.batchmin()


class BatchVar(BatchMean):
"""
Class for calculating the variance of batches of data.
Expand Down
13 changes: 12 additions & 1 deletion tests/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest

from batchstats import BatchCov, BatchMax, BatchMean, BatchMin, BatchSum, BatchVar
from batchstats import (BatchCov, BatchMax, BatchMean, BatchMin,
BatchPeakToPeak, BatchSum, BatchVar)


@pytest.fixture
Expand Down Expand Up @@ -51,6 +52,16 @@ def test_mean(data, n_batches):
assert np.allclose(true_stat, batch_stat)


def test_ptp(data, n_batches):
true_stat = np.ptp(data, axis=0)

batchptp = BatchPeakToPeak()
for batch_data in np.array_split(data, n_batches):
batchptp.update_batch(batch=batch_data)
batch_stat = batchptp()
assert np.allclose(true_stat, batch_stat)


def test_sum(data, n_batches):
true_stat = np.sum(data, axis=0)

Expand Down

0 comments on commit face914

Please sign in to comment.