diff --git a/README.md b/README.md index 6f8b7f3..1dc41f3 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ np.allclose(true_cov, batchcov()), batchcov().shape >>> (True, (100, 50)) ``` -`batchstats` is also flexible in terms of input shapes, with the first dimension always representing the samples and the remaining dimensions representing the features: +`batchstats` is also flexible in terms of input shapes. By default, statistics are applied on the first axis: the first dimension representing the samples and the remaining dimensions representing the features: ```python import numpy as np @@ -80,6 +80,27 @@ np.allclose(true_sum, batchsum()), batchsum().shape >>> (True, (80, 90)) ``` +But as in the ``numpy`` associated functions, the user can specify the reduction axis or axes: + +```python +import numpy as np +from batchstats import BatchMean + +data = [np.random.randn(24, 7, 128) for _ in range(100)] + +batchmean = BatchMean(axis=(0, 2)) +for batch in data: + batchmean.update_batch(batch) +batchmean().shape +>>> (7,) + +batchmean = BatchMean(axis=2) +for batch in data: + batchmean.update_batch(batch) +batchmean().shape +>>> (24, 7) +``` + ## Available Classes/Stats - `BatchCov`: Compute the covariance matrix of two datasets (not necessarily square)