Skip to content

Commit

Permalink
Added horovod partition test for CombinedDataset with sampling_sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-wilken committed Oct 11, 2021
1 parent e06e39e commit f40d178
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions tests/test_Dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,49 @@ def __init__(self, config):
assert set(data_out) == set(range(num_seqs))


def test_horovod_partition_combined_dataset_sampling():
num_seqs = 10
sampling_size = 12
dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]
from returnn.datasets.meta import CombinedDataset
dataset = MapDatasetWrapper(FromListDataset(data_list=dummy_data))
combined_dataset = CombinedDataset(
datasets={"dataset": dataset}, data_map={("dataset", "data"): "data"}, sampling_sizes={"dataset": sampling_size},
data_dims={"data": (1, 1)}, seq_ordering="random")
from returnn.config import get_global_config
global_config = get_global_config(auto_create=True)
global_config.set("use_horovod", True)
global_config.set("horovod_dataset_distribution", "partition")
from returnn.tf import horovod

horovod_size = 3
data_out = []
for rank in range(horovod_size):
# Simulating a multi-gpu setup.
def get_dummy_ctx(config=None):
class DummyHorovodContext(horovod.HorovodContext):
def __init__(self, config):
self._rank = rank
self._size = horovod_size
self._config = config
return DummyHorovodContext(config or global_config)
horovod.get_ctx = get_dummy_ctx
combined_dataset.init_seq_order(epoch=None)
seq_idx = 0
while combined_dataset.is_less_than_num_seqs(seq_idx):
combined_dataset.load_seqs(seq_idx, seq_idx + 1)
data = combined_dataset.get_data(seq_idx, "data")
data_out.extend(data.tolist())
seq_idx += 1
# We sample 12 values from range(10) "in order", so 0 and 1 should appear twice, all other values once. This e.g.
# would not be the case if the sub-dataset is partitioned before sampling,
# see Dataset.disable_horovod_partition.
assert len(data_out) == sampling_size
assert set(data_out) == set(range(num_seqs))
assert data_out.count(0) == 2
assert data_out.count(1) == 2


if __name__ == "__main__":
better_exchook.install()
if len(sys.argv) <= 1:
Expand Down

0 comments on commit f40d178

Please sign in to comment.