NumpyDumpDataset.py (forked from rwth-i6/returnn)
from Dataset import Dataset, DatasetSeq
from Log import log
import os
import numpy


class NumpyDumpDataset(Dataset):
  """
  Reads sequences which were dumped as Numpy text files,
  one "<prefix><i>.data<postfix>" / "<prefix><i>.targets<postfix>" file pair per sequence.
  """

  file_format_data = "%i.data"
  file_format_targets = "%i.targets"

  def __init__(self, prefix, postfix=".txt.gz",
               start_seq=0, end_seq=None,
               num_inputs=None, num_outputs=None, **kwargs):
    super(NumpyDumpDataset, self).__init__(**kwargs)
    self.file_format_data = prefix + self.file_format_data + postfix
    self.file_format_targets = prefix + self.file_format_targets + postfix
    self.start_seq = start_seq
    self._init_num_seqs(end_seq)
    self._seq_index = None
    self.cached_seqs = []  # :type: list[DatasetSeq]
    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    assert num_inputs and num_outputs
  def _init_num_seqs(self, end_seq=None):
    """
    Counts how many consecutive (data, targets) file pairs exist on disk,
    starting at self.start_seq, and sets self._num_seqs. end_seq is exclusive.
    """
    last_seq = None
    i = self.start_seq
    while True:
      if end_seq is not None and i >= end_seq:
        break
      if not os.path.exists(self.file_format_data % i):
        break
      if not os.path.exists(self.file_format_targets % i):
        break
      last_seq = i
      i += 1
    if end_seq is None:
      assert last_seq is not None, "No data/targets files found. Check %s." % (self.file_format_data % self.start_seq)
      end_seq = last_seq + 1  # last_seq is the last existing index; end_seq is exclusive
    else:
      assert last_seq == end_seq - 1, "Check %s." % (self.file_format_data % end_seq)
    assert end_seq > self.start_seq
    self._num_seqs = end_seq - self.start_seq
  def _load_numpy_seq(self, seq_idx):
    real_idx = self._seq_index[seq_idx]
    features = numpy.loadtxt(self.file_format_data % real_idx)
    targets = numpy.loadtxt(self.file_format_targets % real_idx)
    assert features.ndim == 2
    assert features.shape[1] == self.num_inputs
    assert targets.ndim == 1
    self._add_cache_seq(seq_idx, features, targets)
  # ------------ Dataset API --------------

  def init_seq_order(self, epoch=None, seq_list=None):
    super(NumpyDumpDataset, self).init_seq_order(epoch=epoch, seq_list=seq_list)
    if seq_list:
      raise NotImplementedError
    if self.seq_ordering == "sorted":  # not supported atm
      self.seq_ordering = "default"
    self._seq_index = [i + self.start_seq for i in self.get_seq_order_for_epoch(epoch, self.num_seqs)]
    self.cached_seqs[:] = []
    return True

  def _load_seqs(self, start, end):
    self._cleanup_old_seq_cache(start)
    for i in range(start, end):
      if not self._have_cache_seq(i):
        self._load_numpy_seq(i)

  def get_input_data(self, seq_idx):
    return self._get_cache_seq(seq_idx).features

  def get_targets(self, target, seq_idx):
    return self._get_cache_seq(seq_idx).targets.get(target, None)

  def get_ctc_targets(self, seq_idx):
    assert False, "No CTC targets."

  def get_seq_length(self, seq_idx):
    # This is different from the other get_* functions:
    # load_seqs() might not have been called before.
    if not self._have_cache_seq(seq_idx):
      self._load_numpy_seq(seq_idx)
    return self._get_cache_seq(seq_idx).num_frames

  @property
  def num_seqs(self):
    return self._num_seqs

  def len_info(self):
    return "%s, %i seqs" % (self.__class__.__name__, self.num_seqs)
  # ------------ Seq cache management -----------

  def _cleanup_old_seq_cache(self, seq_end):
    # Drop all cached seqs with seq_idx < seq_end.
    i = 0
    while i < len(self.cached_seqs):
      if self.cached_seqs[i].seq_idx >= seq_end:
        break
      i += 1
    del self.cached_seqs[:i]

  def _get_cache_seq(self, seq_idx, error_not_found=True):
    for data in self.cached_seqs:
      if data.seq_idx == seq_idx:
        return data
    if error_not_found:
      raise Exception("seq %i not loaded" % seq_idx)
    return None

  def _have_cache_seq(self, seq_idx):
    return self._get_cache_seq(seq_idx, error_not_found=False) is not None

  def _get_cache_last_seq_idx(self):
    if self.cached_seqs:
      return self.cached_seqs[-1].seq_idx
    return -1

  def _add_cache_seq(self, seq_idx, features, targets):
    # Seqs must be added in order, without gaps.
    last_seq_idx = self._get_cache_last_seq_idx()
    assert seq_idx == last_seq_idx + 1
    self.cached_seqs.append(DatasetSeq(seq_idx, features, targets))
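

# ----------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module).
# It writes two tiny dump file pairs with numpy.savetxt into a temp dir, then
# reads them back through NumpyDumpDataset. The prefix, the sizes and the
# "classes" target key are assumptions made for this demo; adjust them to
# whatever your actual dumps contain.

if __name__ == "__main__":
  import tempfile

  tmp_dir = tempfile.mkdtemp()
  prefix = os.path.join(tmp_dir, "dump.")
  for i in range(2):
    # Features: shape (time, num_inputs). Targets: shape (time,).
    numpy.savetxt(prefix + "%i.data.txt.gz" % i, numpy.random.rand(5, 3))
    numpy.savetxt(prefix + "%i.targets.txt.gz" % i, numpy.random.randint(0, 10, size=(5,)))

  dataset = NumpyDumpDataset(prefix=prefix, num_inputs=3, num_outputs=10)
  dataset.init_seq_order(epoch=1)
  dataset.load_seqs(0, dataset.num_seqs)
  for seq_idx in range(dataset.num_seqs):
    features = dataset.get_input_data(seq_idx)  # shape (time, num_inputs)
    targets = dataset.get_targets("classes", seq_idx)  # shape (time,), or None if the key differs
    print("seq %i: %i frames" % (seq_idx, features.shape[0]))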