From f40d17829a35fa76583f599406973dc948590060 Mon Sep 17 00:00:00 2001
From: patrick-wilken
Date: Mon, 11 Oct 2021 11:25:41 -0400
Subject: [PATCH] Added horovod partition test for CombinedDataset with sampling_sizes

---
 tests/test_Dataset.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py
index 384bdea1bd..604108ddcb 100644
--- a/tests/test_Dataset.py
+++ b/tests/test_Dataset.py
@@ -357,6 +357,49 @@ def __init__(self, config):
   assert set(data_out) == set(range(num_seqs))
 
 
+def test_horovod_partition_combined_dataset_sampling():
+  num_seqs = 10
+  sampling_size = 12  # larger than num_seqs, so some sequences have to be sampled more than once
+  dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]
+  from returnn.datasets.meta import CombinedDataset
+  dataset = MapDatasetWrapper(FromListDataset(data_list=dummy_data))
+  combined_dataset = CombinedDataset(
+    datasets={"dataset": dataset}, data_map={("dataset", "data"): "data"}, sampling_sizes={"dataset": sampling_size},
+    data_dims={"data": (1, 1)}, seq_ordering="random")
+  from returnn.config import get_global_config
+  global_config = get_global_config(auto_create=True)
+  global_config.set("use_horovod", True)
+  global_config.set("horovod_dataset_distribution", "partition")
+  from returnn.tf import horovod
+
+  horovod_size = 3
+  data_out = []  # values collected over all simulated ranks
+  for rank in range(horovod_size):
+    # Simulate a multi-GPU setup: make horovod.get_ctx report this rank out of horovod_size workers.
+    def get_dummy_ctx(config=None):
+      class DummyHorovodContext(horovod.HorovodContext):
+        def __init__(self, config):
+          self._rank = rank  # taken from the enclosing loop at the time the context is created
+          self._size = horovod_size
+          self._config = config
+      return DummyHorovodContext(config or global_config)
+    horovod.get_ctx = get_dummy_ctx  # NOTE(review): patched globally and not restored after the test
+    combined_dataset.init_seq_order(epoch=None)
+    seq_idx = 0
+    while combined_dataset.is_less_than_num_seqs(seq_idx):
+      combined_dataset.load_seqs(seq_idx, seq_idx + 1)
+      data = combined_dataset.get_data(seq_idx, "data")
+      data_out.extend(data.tolist())
+      seq_idx += 1
+  # All ranks together must yield exactly sampling_size (12) values sampled "in order" from range(10), i.e. 0 and 1
+  # twice and all other values once. This would e.g. not be the case if the sub-dataset is partitioned before
+  # sampling, see Dataset.disable_horovod_partition.
+  assert len(data_out) == sampling_size
+  assert set(data_out) == set(range(num_seqs))
+  assert data_out.count(0) == 2
+  assert data_out.count(1) == 2
+
+
 if __name__ == "__main__":
   better_exchook.install()
   if len(sys.argv) <= 1: