diff --git a/returnn/datasets/generating.py b/returnn/datasets/generating.py index 9b20d1c181..312c970db3 100644 --- a/returnn/datasets/generating.py +++ b/returnn/datasets/generating.py @@ -729,6 +729,7 @@ def __init__(self, input_dim, output_dim, num_seqs, seq_len=None, """ if seq_len is None: seq_len = {'data': 10, 'classes': 20} + self._seq_order = None # type: typing.Optional[typing.List[int]] super(DummyDatasetMultipleSequenceLength, self).__init__( input_dim=input_dim, output_dim=output_dim, @@ -756,6 +757,34 @@ def generate_seq(self, seq_idx): for i in range(i1, i2)]) return DatasetSeq(seq_idx=seq_idx, features=features, targets=targets) + def get_all_tags(self): + """ + :rtype: list[str] + """ + return ["seq-%i" % i for i in range(self._num_seqs)] + + def init_seq_order(self, epoch=None, seq_list=None, seq_order=None): + """ + If random_shuffle_epoch1, for epoch 1 with "random" ordering, we leave the given order as is. + Otherwise, this is mostly the default behavior. + + :param int|None epoch: + :param list[str]|None seq_list: List of sequence tags, to set a predefined order. + :param list[int]|None seq_order: List of corpus sequence indices, to set a predefined order. + :rtype: bool + :returns whether the order changed (True is always safe to return) + """ + super(DummyDatasetMultipleSequenceLength, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order) + + def get_seq_len(i): + return self.seq_len['data'] + + self._seq_order = self.get_seq_order_for_epoch(epoch, self._num_seqs, get_seq_len=get_seq_len) + self._num_seqs = len(self._seq_order) + + def get_current_seq_order(self): + return self._seq_order + class DummyDatasetMultipleDataKeys(DummyDataset): """