diff --git a/returnn/datasets/basic.py b/returnn/datasets/basic.py
index b2884ffef6..857da2b33f 100644
--- a/returnn/datasets/basic.py
+++ b/returnn/datasets/basic.py
@@ -121,6 +121,7 @@ def __init__(self, name=None,
     self.random_seed_offset = random_seed_offset
     self.partition_epoch = partition_epoch or 1
     self.repeat_epoch = repeat_epoch or 1
+    self.disable_horovod_partition = False  # can be set by a meta-dataset to handle multi-GPU partitioning at the meta level
     self.seq_tags_filter = set(self._load_seq_list_file(seq_list_filter_file)) if seq_list_filter_file else None
     self.unique_seq_tags = unique_seq_tags
     self._seq_order_seq_lens_file = seq_order_seq_lens_file
@@ -483,6 +484,8 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
       seq_index = self._apply_partition_epoch(seq_index, partition_epoch, epoch)
     if repeat_epoch > 1:
       seq_index = seq_index * repeat_epoch
+    if not self.disable_horovod_partition:
+      seq_index = self._apply_multi_gpu_partition(seq_index)
     if self.seq_tags_filter is not None:
       # Note: This is as generic as possible, but requires that get_all_tags is implemented.
       assert seq_index
@@ -517,6 +520,31 @@ def _apply_partition_epoch(cls, seq_index, partition_epoch, epoch):
 
     return seq_index
 
+  @classmethod
+  def _apply_multi_gpu_partition(cls, seq_index):
+    """
+    Applies multi-GPU partitioning via horovod_dataset_distribution = "partition"; does nothing if that is not set.
+
+    :param list[int] seq_index:
+    :return: partition of seq_index for the current process, i.e. we split onto the different GPUs
+    :rtype: list[int]
+    """
+    from returnn.config import get_global_config
+    config = get_global_config(raise_exception=False)
+    if not config or not config.is_true("use_horovod"):
+      return seq_index
+
+    import returnn.tf.horovod
+    if returnn.tf.horovod.get_ctx().get_dataset_distribution_type() != "partition":
+      return seq_index
+
+    rank = returnn.tf.horovod.get_ctx().rank() + 1  # one-based such that it can be used as "epoch"
+    num_gpus = returnn.tf.horovod.get_ctx().size()
+
+    # Reuse the partition epoch logic to split the current sub-epoch between the different GPUs.
+    seq_index = cls._apply_partition_epoch(seq_index, partition_epoch=num_gpus, epoch=rank)
+    return seq_index
+
   def _get_random_seed_for_epoch(self, epoch):
     """
     :param int|None epoch:
diff --git a/returnn/datasets/map.py b/returnn/datasets/map.py
index f31a2e9d0c..d73a26efe8 100644
--- a/returnn/datasets/map.py
+++ b/returnn/datasets/map.py
@@ -102,8 +102,11 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
     """
     super(MapDatasetWrapper, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)
 
-    if seq_list or seq_order:
+    if seq_list:
       raise NotImplementedError
+    if seq_order:
+      self._seq_order = seq_order
+      return True
 
     try:
       self._seq_order = self._dataset.get_seq_order(epoch=epoch)
diff --git a/returnn/datasets/meta.py b/returnn/datasets/meta.py
index a825fcecf8..9cfae6e49f 100644
--- a/returnn/datasets/meta.py
+++ b/returnn/datasets/meta.py
@@ -863,8 +863,9 @@ def __init__(self,
 
     # This will only initialize datasets needed for features occurring in data_map
     self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}
-    self._estimated_num_seqs = sum([self.datasets[k].estimated_num_seqs for k in sorted(self.datasets.keys())])
     self.estimated_num_seq_per_subset = [self.datasets[k].estimated_num_seqs for k in sorted(self.datasets.keys())]
+    if all(num_seq is not None for num_seq in self.estimated_num_seq_per_subset):
+      self._estimated_num_seqs = sum(self.estimated_num_seq_per_subset)
 
     if data_dims:
       data_dims = convert_data_dims(data_dims)
@@ -913,6 +914,9 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
     # partition epoch of the individual sub-datasets is still supported. Later we will call init_seq_order again with a
     # sequence list to e.g. apply joint sorting or partition epoch of all sequences.
     for dataset in self.datasets.values():
+      if self.sampling_sizes:
+        # Partitioning does not make sense if we sample a fixed number of sequences anyway.
+        dataset.disable_horovod_partition = True
       dataset.init_seq_order(epoch=epoch)
 
     # noinspection PyBroadException
@@ -1076,6 +1080,8 @@ def _get_sampling_seq_order(self):
     # We want to additionally sort the sequences in the current sample. For this, create a sequence order on a
     # range of length of the number of sequences in the sample. Note that we have to map the indices to make use
     # of self._get_seq_length here.
+    # This get_seq_order_for_epoch call now also handles horovod_dataset_distribution = 'partition', which we
+    # disabled at the sub-dataset level via 'disable_horovod_partition' above.
     seq_order_remapping = self.get_seq_order_for_epoch(
       epoch=epoch, num_seqs=len(seq_order), get_seq_len=lambda i: self._get_seq_length(seq_order[i]))
 
diff --git a/returnn/tf/horovod.py b/returnn/tf/horovod.py
index 6abf05534b..29aa5e85f4 100644
--- a/returnn/tf/horovod.py
+++ b/returnn/tf/horovod.py
@@ -99,7 +99,7 @@ def get_dataset_distribution_type(self):
     :rtype: str
     """
     dataset_distribution = self._config.value("horovod_dataset_distribution", "shard")
-    assert dataset_distribution in {"shard", "random_seed_offset"}
+    assert dataset_distribution in {"shard", "random_seed_offset", "partition"}
     return dataset_distribution
 
   def is_dataset_distribution_shard(self):
diff --git a/tests/test_Dataset.py b/tests/test_Dataset.py
index 51fbe604a1..604108ddcb 100644
--- a/tests/test_Dataset.py
+++ b/tests/test_Dataset.py
@@ -6,8 +6,10 @@
 import sys
 import _setup_test_env  # noqa
 import unittest
+import numpy
 from nose.tools import assert_equal, assert_is_instance, assert_in, assert_not_in, assert_true, assert_false
 from returnn.datasets.generating import GeneratingDataset, DummyDataset, DummyDatasetMultipleSequenceLength
+from returnn.datasets.map import FromListDataset, MapDatasetWrapper
 from returnn.engine.batch import Batch
 from returnn.datasets.basic import DatasetSeq
 from returnn.util.basic import NumbersDict
@@ -320,6 +322,84 @@ def test_task12ax_window():
   assert_equal(list(data2a[-1, 2]), [0] * input_dim)  # zero-padded right
 
 
+def test_horovod_partition():
+  num_seqs = 10
+  dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]
+  # Use FromListDataset because DummyDataset does not support sequence ordering and thus no partitioning.
+  dataset = MapDatasetWrapper(
+    FromListDataset(data_list=dummy_data, data_types=None), seq_ordering="random")
+  from returnn.config import get_global_config
+  global_config = get_global_config(auto_create=True)
+  global_config.set("use_horovod", True)
+  global_config.set("horovod_dataset_distribution", "partition")
+  from returnn.tf import horovod
+
+  horovod_size = 3
+  data_out = []
+  for rank in range(horovod_size):
+    # Simulate a multi-GPU setup.
+    def get_dummy_ctx(config=None):
+      class DummyHorovodContext(horovod.HorovodContext):
+        def __init__(self, config):
+          self._rank = rank
+          self._size = horovod_size
+          self._config = config
+      return DummyHorovodContext(config or global_config)
+    horovod.get_ctx = get_dummy_ctx
+    dataset.init_seq_order(epoch=1)
+    seq_idx = 0
+    while dataset.is_less_than_num_seqs(seq_idx):
+      dataset.load_seqs(seq_idx, seq_idx + 1)
+      data = dataset.get_data(seq_idx, "data")
+      data_out.extend(data.tolist())
+      seq_idx += 1
+  assert len(data_out) == num_seqs
+  assert set(data_out) == set(range(num_seqs))
+
+
+def test_horovod_partition_combined_dataset_sampling():
+  num_seqs = 10
+  sampling_size = 12
+  dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]
+  from returnn.datasets.meta import CombinedDataset
+  dataset = MapDatasetWrapper(FromListDataset(data_list=dummy_data))
+  combined_dataset = CombinedDataset(
+    datasets={"dataset": dataset}, data_map={("dataset", "data"): "data"}, sampling_sizes={"dataset": sampling_size},
+    data_dims={"data": (1, 1)}, seq_ordering="random")
+  from returnn.config import get_global_config
+  global_config = get_global_config(auto_create=True)
+  global_config.set("use_horovod", True)
+  global_config.set("horovod_dataset_distribution", "partition")
+  from returnn.tf import horovod
+
+  horovod_size = 3
+  data_out = []
+  for rank in range(horovod_size):
+    # Simulate a multi-GPU setup.
+    def get_dummy_ctx(config=None):
+      class DummyHorovodContext(horovod.HorovodContext):
+        def __init__(self, config):
+          self._rank = rank
+          self._size = horovod_size
+          self._config = config
+      return DummyHorovodContext(config or global_config)
+    horovod.get_ctx = get_dummy_ctx
+    combined_dataset.init_seq_order(epoch=None)
+    seq_idx = 0
+    while combined_dataset.is_less_than_num_seqs(seq_idx):
+      combined_dataset.load_seqs(seq_idx, seq_idx + 1)
+      data = combined_dataset.get_data(seq_idx, "data")
+      data_out.extend(data.tolist())
+      seq_idx += 1
+  # We sample 12 values from range(10) "in order", so 0 and 1 should appear twice, all other values once. This e.g.
+  # would not be the case if the sub-dataset were partitioned before sampling,
+  # see Dataset.disable_horovod_partition.
+  assert len(data_out) == sampling_size
+  assert set(data_out) == set(range(num_seqs))
+  assert data_out.count(0) == 2
+  assert data_out.count(1) == 2
+
+
 if __name__ == "__main__":
   better_exchook.install()
   if len(sys.argv) <= 1:
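
Usage sketch (not part of the diff): with this change, a RETURNN config running
under Horovod on multiple GPUs can select the new distribution mode, which gives
each rank a disjoint partition of the current (sub-)epoch's sequence order. The
two config keys below are the ones this diff reads; the multi-process Horovod
launch itself (e.g. via mpirun) is assumed:

    # RETURNN config (Python)
    use_horovod = True
    # New option added here; previously allowed values: "shard", "random_seed_offset".
    horovod_dataset_distribution = "partition"

For CombinedDataset with sampling_sizes, the partitioning is applied once at the
meta level (sub-datasets set disable_horovod_partition), so sampling still sees
the full sub-datasets before the sample is split across ranks.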