forked from harvardnlp/seq2seq-attn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.lua
166 lines (153 loc) · 5.5 KB
/
data.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
--
-- Manages encoder/decoder data matrices.
--
-- `data` is registered as a Torch class through torch.class("data"); the
-- functions attached to it in this file load an HDF5 dataset, pre-slice it
-- into batches, and serve those batches by numeric index.
--
local data = torch.class("data")
--- Load every dataset from an HDF5 file and pre-slice it into batches.
--
-- @param opt        options table; the fields read here are load_key_vecs,
--                   start_symbol, use_chars_enc, use_chars_dec and reverse_src
-- @param data_file  path of the HDF5 file, opened read-only
--
-- After construction, self.batches[i] is a table holding, in order:
--   [1] target input slice, transposed with transpose(1,2)
--   [2] target output slice, transposed with transpose(1,2)
--   [3] self.target_nonzeros[i]
--   [4] source input slice, transposed with transpose(1,2)
--   [5] batch size, [6] max target length, [7] max source length
--   [8] per-sentence target lengths for the batch
--   [9] keyword index slice and [10] keyword length
--       (slots 9 and 10 exist only when opt.load_key_vecs == 1)
function data:__init(opt, data_file)
   local f = hdf5.open(data_file, 'r')

   self.source = f:read('source'):all()
   self.target = f:read('target'):all()
   self.target_output = f:read('target_output'):all()
   self.target_l = f:read('target_l'):all() --max target length each batch
   self.target_l_all = f:read('target_l_all'):all()
   -- Stored per-sentence lengths are decremented by one, in place.
   self.target_l_all:add(-1)
   self.batch_l = f:read('batch_l'):all()
   self.source_l = f:read('batch_w'):all() --max source length each batch

   if opt.load_key_vecs == 1 then
      self.keyword_vecs = f:read('vecs'):all()
      self.batch_keyword_l = f:read('batch_keyword_l'):all()
      self.keyword_size = f:read('keyword_size'):all()[1]
   end

   if opt.start_symbol == 0 then
      -- Drop the first and last source column and shrink the lengths to match
      -- (presumably these columns are the start/end symbols -- confirm against
      -- the preprocessing script).
      self.source_l:add(-2)
      self.source = self.source[{{},{2, self.source:size(2)-1}}]
   end

   self.batch_idx = f:read('batch_idx'):all()
   self.target_size = f:read('target_size'):all()[1]
   self.source_size = f:read('source_size'):all()[1]
   self.target_nonzeros = f:read('target_nonzeros'):all()

   if opt.use_chars_enc == 1 then
      self.source_char = f:read('source_char'):all()
      self.char_size = f:read('char_size'):all()[1]
      self.char_length = self.source_char:size(3)
      if opt.start_symbol == 0 then
         self.source_char = self.source_char[{{}, {2, self.source_char:size(2)-1}}]
      end
   end

   if opt.use_chars_dec == 1 then
      self.target_char = f:read('target_char'):all()
      self.char_size = f:read('char_size'):all()[1]
      self.char_length = self.target_char:size(3)
   end

   -- Everything needed was read into tensors with :all(); release the file
   -- handle instead of leaking it for the lifetime of the process.
   f:close()

   self.length = self.batch_l:size(1)
   self.seq_length = self.target:size(2)
   self.batches = {}

   -- source_l_rev = {max, max-1, ..., 1}; a tail of it is used below as a
   -- reversed index list for a source of any length up to max_source_l.
   local max_source_l = self.source_l:max()
   local source_l_rev = torch.ones(max_source_l):long()
   for i = 1, max_source_l do
      source_l_rev[i] = max_source_l - i + 1
   end

   for i = 1, self.length do
      -- Row range of batch i inside the flat data matrices.
      local first = self.batch_idx[i]
      local last = first + self.batch_l[i] - 1
      local source_len = self.source_l[i]
      local target_len = self.target_l[i]

      local source_i, target_i, keyword_vec_i
      local target_output_i = self.target_output:sub(first, last, 1, target_len)
      local target_l_i = self.target_l_all:sub(first, last)

      if opt.use_chars_enc == 1 then
         source_i = self.source_char:sub(first, last, 1, source_len)
            :transpose(1,2):contiguous()
      else
         source_i = self.source:sub(first, last, 1, source_len):transpose(1,2)
      end

      if opt.load_key_vecs == 1 then
         local keyword_len = self.batch_keyword_l[i]
         keyword_vec_i = self.keyword_vecs:sub(first, last, 1, keyword_len)
            :transpose(1,2)
      end

      if opt.reverse_src == 1 then
         -- The selected tail of source_l_rev is {source_len, ..., 1}, so this
         -- reverses source_i along its first dimension.
         source_i = source_i:index(1,
            source_l_rev[{{max_source_l - source_len + 1, max_source_l}}])
      end

      if opt.use_chars_dec == 1 then
         target_i = self.target_char:sub(first, last, 1, target_len)
            :transpose(1,2):contiguous()
      else
         target_i = self.target:sub(first, last, 1, target_len):transpose(1,2)
      end

      local batch = {target_i,
                     target_output_i:transpose(1,2),
                     self.target_nonzeros[i],
                     source_i,
                     self.batch_l[i],
                     target_len,
                     source_len,
                     target_l_i}
      if opt.load_key_vecs == 1 then
         batch[9] = keyword_vec_i
         batch[10] = self.batch_keyword_l[i]
      end
      self.batches[i] = batch
   end
end
--- Number of batches held by this dataset.
-- @return the value stored in self.length by the constructor
function data:size()
   local num_batches = self.length
   return num_batches
end
--- Index metamethod: string keys resolve to class members, any other key
-- is treated as a batch number and returns that batch ready for training.
--
-- Reads the global `opt` table (load_key_vecs, gpuid, gpuid2).
--
-- @param self  the data object
-- @param idx   a string member name, or a batch index into self.batches
-- @return for a string key, data[idx]; otherwise a table
--         {target_input, target_output, nonzeros, source_input, batch_l,
--          target_l, source_l, target_l_all}, with keyword_vec appended as a
--         ninth element when opt.load_key_vecs == 1
function data.__index(self, idx)
   if type(idx) == "string" then
      return data[idx]
   end

   local batch = self.batches[idx]
   local target_input = batch[1]
   local target_output = batch[2]
   local nonzeros = batch[3]
   local source_input = batch[4]
   local batch_l = batch[5]
   local target_l = batch[6]
   local source_l = batch[7]
   local target_l_all = batch[8]

   -- Slots 9 and 10 are only stored when keyword vectors were loaded, so they
   -- (and keyword_vec) must not be touched otherwise; doing so unconditionally
   -- raised an error on every batch lookup with load_key_vecs disabled.
   local keyword_vec
   if opt.load_key_vecs == 1 then
      local keyword_idx_vec = batch[9]:transpose(1,2)
      local keyword_len = batch[10]
      -- Build a batch_l x keyword_size indicator matrix: entry (s, k) is 1
      -- when index k appears among sentence s's keywords; 0 entries in the
      -- index matrix are skipped.
      keyword_vec = torch.zeros(batch_l, self.keyword_size)
      for sent_id = 1, batch_l do
         for word_num = 1, keyword_len do
            local keyword_id = keyword_idx_vec[sent_id][word_num]
            if keyword_id ~= 0 then
               keyword_vec[sent_id][keyword_id] = 1
            end
         end
      end
      -- Callers receive it as keyword_size x batch_l.
      keyword_vec = keyword_vec:transpose(1,2)
   end

   if opt.gpuid >= 0 then --if multi-gpu, source lives in gpuid1, rest on gpuid2
      cutorch.setDevice(opt.gpuid)
      source_input = source_input:cuda()
      if opt.load_key_vecs == 1 then
         keyword_vec = keyword_vec:cuda()
      end
      if opt.gpuid2 >= 0 then
         cutorch.setDevice(opt.gpuid2)
      end
      target_input = target_input:cuda()
      target_output = target_output:cuda()
      target_l_all = target_l_all:cuda()
   end

   if opt.load_key_vecs == 1 then
      return {target_input, target_output, nonzeros, source_input,
              batch_l, target_l, source_l, target_l_all, keyword_vec}
   else
      return {target_input, target_output, nonzeros, source_input,
              batch_l, target_l, source_l, target_l_all}
   end
end
-- Hand the class table back to whoever loads this file.
return data