diff --git a/README.md b/README.md
index f46519a..9a33e67 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,25 @@ print("Batch Mean:", batchmean())
 print("Batch Variance:", batchvar())
 ```
 
+It is also possible to compute the covariance between two datasets:
+
+```python
+import numpy as np
+from batchstats import BatchCov
+
+n_samples, m, n = 10_000, 100, 50
+data1 = np.random.randn(n_samples, m)
+data2 = np.random.randn(n_samples, n)
+n_batches = 7
+
+batchcov = BatchCov()
+for batch_index in np.array_split(np.arange(n_samples), n_batches):
+    batchcov.update_batch(batch1=data1[batch_index], batch2=data2[batch_index])
+true_cov = (data1 - data1.mean(axis=0)).T@(data2 - data2.mean(axis=0))/n_samples
+np.allclose(true_cov, batchcov()), batchcov().shape
+>>> (True, (100, 50))
+```
+
 `batchstats` is also flexible in terms of input shapes, with the first dimension always representing the samples and the remaining dimensions representing the features:
 
 ```python