-
Notifications
You must be signed in to change notification settings - Fork 15
/
decode.py
231 lines (182 loc) · 9.17 KB
/
decode.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import numpy as np
class Decoder(object):
class TracebackNode(object):
def __init__(self, label, predecessor, boundary = False):
self.label = label
self.predecessor = predecessor
self.boundary = boundary
class HypDict(dict):
class Hypothesis(object):
def __init__(self, score, cur_len, traceback):
self.score = score
self.cur_len = cur_len
self.traceback = traceback
def update(self, key, score, cur_len, traceback):
if (not key in self) or (self[key].score <= score):
self[key] = self.Hypothesis(score, cur_len, traceback)
def __init__(self, grammar, frame_sampling=1, max_hypotheses=np.inf, thresh=0.1):
self.grammar = grammar
self.frame_sampling = frame_sampling
self.max_hypotheses = max_hypotheses
self.thresh = thresh
def decode(self, log_frame_probs):
frame_scores = log_frame_probs
# create initial hypotheses
hyps = self.init_decoding(frame_scores)
# decode each following time step
for t in range(2 * self.frame_sampling - 1, frame_scores.shape[0], self.frame_sampling):
hyps = self.decode_frame(t, hyps, frame_scores)
self.prune(hyps, self.thresh)
# transition to end symbol
final_hyp = self.finalize_decoding(hyps, frame_scores)
labels, segments = self.traceback(final_hyp, frame_scores.shape[0])
return labels, final_hyp.score
### helper functions ###
def frame_score(self, frame_scores, t, label):
return frame_scores[t, label]
def prune(self, hyps, thresh=None):
tmp = sorted( [ (hyps[key].score, key) for key in hyps ] )
half_thresh = thresh
count = 0
for i in range(len(tmp) - 1):
if tmp[i][0] < half_thresh:
count += 1
else:
break
del_keys = [ x[1] for x in tmp[0 : count] ]
for key in del_keys:
del hyps[key]
if len(hyps) > self.max_hypotheses:
tmp = sorted( [ (hyps[key].score, key) for key in hyps ] )
end = len(tmp) - self.max_hypotheses
base_score = tmp[end][0]
i = 0
for i in range(end-1, -1, -1):
if base_score == tmp[i][0]:
end -= 1
else:
break
del_keys = [ x[1] for x in tmp[0 : end] ]
for key in del_keys:
del hyps[key]
def init_decoding(self, frame_scores):
hyps = self.HypDict()
context = (self.grammar.start_symbol(),)
for label in self.grammar.possible_successors(context):
key = context + (label, self.frame_sampling, '0')
score = np.inf
hyps.update(key, score, 1, self.TracebackNode(label, None, boundary = True))
return hyps
def decode_frame(self, t, old_hyp, frame_scores):
new_hyp = self.HypDict()
for key, hyp in old_hyp.items():
context, label, length, hist = key[0:-3], key[-3], key[-2], key[-1]
full_context = tuple(list(context) + [label])
cur_idx_1, cur_idx_2 = self.grammar.next_idx_range[context]
next_idx_1, next_idx_2 = self.grammar.next_idx_range[full_context]
if hyp.cur_len == next_idx_1:
new_label = list(self.grammar.possible_successors(full_context))[0]
if new_label == self.grammar.end_symbol():
continue
hist_sum = hist.split(',')
hist_sum = [int(item) for item in hist_sum]
hist_sum = sum(hist_sum)
new_hist = hist + ',' + str(hyp.cur_len - hist_sum)
new_key = full_context + (new_label, self.frame_sampling, new_hist)
score = self.cal_hist_score(new_hist, hyp.cur_len + 1, frame_scores, hyp.score)
new_hyp.update(new_key, score, hyp.cur_len + 1, self.TracebackNode(new_label, hyp.traceback, boundary = True))
continue
if hyp.cur_len <= cur_idx_2:
new_key = context + (label, length + self.frame_sampling, hist)
score = hyp.score
new_hyp.update(new_key, score, hyp.cur_len + 1, self.TracebackNode(label, hyp.traceback, boundary = False))
continue
# stay in the same label...
new_key = context + (label, length + self.frame_sampling, hist)
score = hyp.score
new_hyp.update(new_key, score, hyp.cur_len + 1, self.TracebackNode(label, hyp.traceback, boundary = False))
# ... or go to the next label
new_label = list(self.grammar.possible_successors(full_context))[0]
if new_label == self.grammar.end_symbol():
continue
hist_sum = hist.split(',')
hist_sum = [int(item) for item in hist_sum]
hist_sum = sum(hist_sum)
new_hist = hist + ',' + str(hyp.cur_len - hist_sum)
new_key = full_context + (new_label, self.frame_sampling, new_hist)
score = self.cal_hist_score(new_hist, hyp.cur_len + 1, frame_scores, hyp.score)
new_hyp.update(new_key, score, hyp.cur_len + 1, self.TracebackNode(new_label, hyp.traceback, boundary = True))
return new_hyp
def finalize_decoding(self, old_hyp, frame_scores):
final_hyp = self.HypDict.Hypothesis(-np.inf, 0, None)
for key, hyp in old_hyp.items():
context, label, length, hist = key[0:-3], key[-3], key[-2], key[-1]
context = context + (label,)
labels, segments = self.traceback(hyp, frame_scores.shape[0])
labels = np.array(labels)
seg_list = np.where(labels > 0)
seg_list_bkg = np.where(labels < 1)
score_act = self.oic_score(seg_list, frame_scores[:, 1])
score_bkg = self.oic_score(seg_list_bkg, frame_scores[:, 0])
score = sum(score_act + score_bkg) / (len(score_act) + len(score_bkg))
if score >= final_hyp.score:
final_hyp.score, final_hyp.traceback = score, hyp.traceback
return final_hyp
def traceback(self, hyp, n_frames):
class Segment(object):
def __init__(self, label):
self.label, self.length = label, 0
traceback = hyp.traceback
if traceback is None:
print("", end="")
labels = []
segments = [Segment(traceback.label)]
while not traceback == None:
segments[-1].length += self.frame_sampling
labels += [traceback.label] * self.frame_sampling
if traceback.boundary and not traceback.predecessor == None:
segments.append( Segment(traceback.predecessor.label) )
traceback = traceback.predecessor
segments[0].length += n_frames - len(labels) # append length of missing frames
labels += [hyp.traceback.label] * (n_frames - len(labels)) # append labels for missing frames
return list(reversed(labels)), list(reversed(segments))
def cal_hist_score(self, hist, cur_len, frame_scores, cur_score, _lambda=0.25):
durations = hist.split(',')
segments = []
len_seg = int(durations[-1])
start = cur_len - len_seg - 1
end = cur_len - 2
label = self.grammar.transcript[len(durations) - 2]
inner_score = np.mean(frame_scores[start:end + 1, label])
outer_s = max(0, int(start - _lambda * len_seg))
outer_e = min(int(frame_scores.shape[0] - 1), int(end + _lambda * len_seg))
outer_seg = list(range(outer_s, start)) + list(range(end + 1, outer_e + 1))
if len(outer_seg) == 0:
outer_score = 0
else:
outer_score = np.mean(frame_scores[outer_seg, label])
if len(durations) == 2:
score = inner_score - outer_score
else:
score = ((cur_score * (len(durations) - 2)) + inner_score - outer_score) / (len(durations) - 1)
return score
def oic_score(self, seg_list, frame_scores, _lambda=0.25):
final_score = []
temp_list = np.array(seg_list)[0]
if temp_list.shape[0] > 0:
grouped_temp_list = grouping(temp_list)
for j in range(len(grouped_temp_list)):
inner_score = np.mean(frame_scores[grouped_temp_list[j]])
len_proposal = len(grouped_temp_list[j])
outer_s = max(0, int(grouped_temp_list[j][0] - _lambda * len_proposal))
outer_e = min(int(frame_scores.shape[0] - 1), int(grouped_temp_list[j][-1] + _lambda * len_proposal))
outer_temp_list = list(range(outer_s, int(grouped_temp_list[j][0]))) + list(range(int(grouped_temp_list[j][-1] + 1), outer_e + 1))
if len(outer_temp_list) == 0:
outer_score = 0
else:
outer_score = np.mean(frame_scores[outer_temp_list])
c_score = inner_score - outer_score
final_score.append(c_score)
return final_score
def grouping(arr):
return np.split(arr, np.where(np.diff(arr) != 1)[0] + 1)