forked from harvardnlp/seq2seq-attn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.lua
executable file
·242 lines (226 loc) · 7.88 KB
/
models.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
dofile 'util.lua'
--- Flag this module as eligible for memory reuse (see nn.Module:setReuse).
-- Returns the module itself so the call can be chained while building a graph,
-- e.g. nn.Sigmoid():reuseMem()(x).
-- @return self
nn.Module.reuseMem = function(self)
  self.reuse = true
  return self
end
--- For a module flagged with reuseMem(), make gradInput refer to the same
-- tensor as output; modules without the flag are left unchanged.
-- NOTE(review): not called in this file; presumably invoked by training code
-- elsewhere after the graph is built -- confirm.
nn.Module.setReuse = function(self)
  if not self.reuse then
    return
  end
  self.gradInput = self.output
end
--- Build one timestep of a stacked LSTM (encoder or decoder) as an nngraph
-- gModule.
--
-- Graph inputs, in order:
--   x                     token input; word indices when use_chars == 0,
--                         otherwise the character input fed to the char CNN
--   context               (decoder only) encoder context
--   prev context_attn     (decoder only, opt.input_feed == 1) previous attention
--                         output, concatenated onto the first layer's input
--   prev_c[L], prev_h[L]  for L = 1..opt.num_layers
-- i.e. 2*n+1 inputs for the encoder, 2*n+2 for the decoder and 2*n+3 for the
-- decoder with input feeding.
--
-- Graph outputs: next_c[L], next_h[L] for L = 1..n and, for the decoder only,
-- one extra output combining the top hidden state with the context (the
-- attention layer when opt.attn == 1, otherwise tanh(W [h; context])).
--
-- @param data       vocabulary sizes: source_size / target_size (word input)
--                   or char_size (character input)
-- @param opt        model options (num_layers, rnn_size, dropout, input_feed,
--                   attn, res_net, multi_attn, cudnn, char-CNN/highway sizes)
-- @param model      'enc' or 'dec'
-- @param use_chars  0 for a word LookupTable, any other value for the char CNN
-- @return nn.gModule
function make_lstm(data, opt, model, use_chars)
  assert(model == 'enc' or model == 'dec')
  local name = '_' .. model
  local dropout = opt.dropout or 0
  local n = opt.num_layers
  local rnn_size = opt.rnn_size
  local input_size
  if use_chars == 0 then
    input_size = opt.word_vec_size
  else
    input_size = opt.num_kernels
  end
  -- number of decoder-only inputs that sit between x and the (c, h) pairs
  local offset = 0
  -- inputs: x, [context], [prev context_attn], then (prev_c, prev_h) per layer
  -- => 2*n+1 (enc), 2*n+2 (dec) or 2*n+3 (dec with input_feed) inputs
  local inputs = {}
  table.insert(inputs, nn.Identity()()) -- x (batch_size x max_word_l)
  if model == 'dec' then
    table.insert(inputs, nn.Identity()()) -- all context (batch_size x source_l x rnn_size)
    offset = offset + 1
    if opt.input_feed == 1 then
      table.insert(inputs, nn.Identity()()) -- prev context_attn (batch_size x rnn_size)
      offset = offset + 1
    end
  end
  for L = 1, n do
    table.insert(inputs, nn.Identity()()) -- prev_c[L]
    table.insert(inputs, nn.Identity()()) -- prev_h[L]
  end
  local x, input_size_L
  local outputs = {}
  for L = 1, n do
    -- c,h from previous timesteps
    local prev_c = inputs[L*2+offset]
    local prev_h = inputs[L*2+1+offset]
    -- the input to this layer
    if L == 1 then
      if use_chars == 0 then
        local word_vecs
        if model == 'enc' then
          word_vecs = nn.LookupTable(data.source_size, input_size)
        else
          word_vecs = nn.LookupTable(data.target_size, input_size)
        end
        word_vecs.name = 'word_vecs' .. name
        x = word_vecs(inputs[1]) -- batch_size x word_vec_size
      else
        local char_vecs = nn.LookupTable(data.char_size, opt.char_vec_size)
        -- NOTE(review): the char table reuses the 'word_vecs' name, presumably
        -- so code that looks layers up by that name also finds it -- confirm.
        char_vecs.name = 'word_vecs' .. name
        -- pass opt.cudnn explicitly so make_cnn follows this call's options
        -- rather than whatever global `opt` happens to be in scope
        local charcnn = make_cnn(opt.char_vec_size, opt.kernel_width,
          opt.num_kernels, opt.cudnn)
        charcnn.name = 'charcnn' .. name
        x = charcnn(char_vecs(inputs[1]))
        if opt.num_highway_layers > 0 then
          local mlp = make_highway(input_size, opt.num_highway_layers)
          mlp.name = 'mlp' .. name
          x = mlp(x)
        end
      end
      input_size_L = input_size
      if model == 'dec' then
        if opt.input_feed == 1 then
          x = nn.JoinTable(2)({x, inputs[1+offset]}) -- batch_size x (word_vec_size + rnn_size)
          input_size_L = input_size + rnn_size
        end
      end
    else
      -- hidden state of the layer below
      x = outputs[(L-1)*2]
      if opt.res_net == 1 and L > 2 then
        -- residual connection: add the hidden state from two layers below
        x = nn.CAddTable()({x, outputs[(L-2)*2]})
      end
      input_size_L = rnn_size
      if opt.multi_attn == L and model == 'dec' then
        -- optional extra attention (additive "simple" mode) on this layer's input
        local multi_attn = make_decoder_attn(data, opt, 1)
        multi_attn.name = 'multi_attn' .. L
        x = multi_attn({x, inputs[2]})
      end
      if dropout > 0 then
        x = nn.Dropout(dropout, nil, false)(x)
      end
    end
    -- evaluate the input sums at once for efficiency
    local i2h = nn.Linear(input_size_L, 4 * rnn_size):reuseMem()(x)
    local h2h = nn.LinearNoBias(rnn_size, 4 * rnn_size):reuseMem()(prev_h)
    local all_input_sums = nn.CAddTable()({i2h, h2h})
    local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
    local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
    -- decode the gates
    local in_gate = nn.Sigmoid():reuseMem()(n1)
    local forget_gate = nn.Sigmoid():reuseMem()(n2)
    local out_gate = nn.Sigmoid():reuseMem()(n3)
    -- decode the write inputs
    local in_transform = nn.Tanh():reuseMem()(n4)
    -- perform the LSTM update
    local next_c = nn.CAddTable()({
      nn.CMulTable()({forget_gate, prev_c}),
      nn.CMulTable()({in_gate, in_transform})
    })
    -- gated cells form the output
    local next_h = nn.CMulTable()({out_gate, nn.Tanh():reuseMem()(next_c)})
    table.insert(outputs, next_c)
    table.insert(outputs, next_h)
  end
  if model == 'dec' then
    local top_h = outputs[#outputs]
    local decoder_out
    if opt.attn == 1 then
      local decoder_attn = make_decoder_attn(data, opt)
      decoder_attn.name = 'decoder_attn'
      decoder_out = decoder_attn({top_h, inputs[2]})
    else
      -- NOTE(review): joining along dim 2 and projecting from rnn_size*2 assumes
      -- the caller feeds a (batch x rnn_size) context here when attn ~= 1 -- confirm.
      decoder_out = nn.JoinTable(2)({top_h, inputs[2]})
      decoder_out = nn.Tanh()(nn.LinearNoBias(opt.rnn_size*2, opt.rnn_size)(decoder_out))
    end
    if dropout > 0 then
      decoder_out = nn.Dropout(dropout, nil, false)(decoder_out)
    end
    table.insert(outputs, decoder_out)
  end
  return nn.gModule(inputs, outputs)
end
--- Attention layer over the encoder context, built as an nngraph gModule.
-- Inputs {h, context}; per the shape comments below, h is batch_l x rnn_size
-- and context is batch_l x source_l x rnn_size.
-- Scores are context * (W h) with a bias-free W, normalised by a SoftMax
-- module named 'softmax_attn'. The resulting weighted sum c of context rows
-- is combined with h as:
--   simple == 0 (default):  tanh(W_c [c; h])   (W_c bias-free)
--   otherwise:              c + h
-- @param data   not used by this function
-- @param opt    option table; only opt.rnn_size is read
-- @param simple optional combination flag, treated as 0 when nil/false
-- @return nn.gModule with a single output
function make_decoder_attn(data, opt, simple)
  local dim = opt.rnn_size
  local h_in = nn.Identity()()    -- target_t: batch_l x rnn_size
  local ctx_in = nn.Identity()()  -- context: batch_l x source_l x rnn_size
  local mode = simple or 0

  -- scores = context (b x s x r) times projected query (b x r x 1) -> b x s x 1
  local query = nn.LinearNoBias(dim, dim)(h_in)
  local scores = nn.MM()({ctx_in, nn.Replicate(1, 3)(query)})
  scores = nn.Sum(3)(scores)                       -- batch_l x source_l

  local softmax = nn.SoftMax()
  softmax.name = 'softmax_attn'
  local probs = softmax(scores)
  probs = nn.Replicate(1, 2)(probs)                -- batch_l x 1 x source_l

  -- attention-weighted sum of the context rows
  local summary = nn.MM()({probs, ctx_in})         -- batch_l x 1 x rnn_size
  summary = nn.Sum(2)(summary)                     -- batch_l x rnn_size

  local result
  if mode ~= 0 then
    result = nn.CAddTable()({summary, h_in})
  else
    local joined = nn.JoinTable(2)({summary, h_in}) -- batch_l x rnn_size*2
    result = nn.Tanh()(nn.LinearNoBias(dim * 2, dim)(joined))
  end
  return nn.gModule({h_in, ctx_in}, {result})
end
--- Output layer and loss over the target vocabulary.
-- Model: Linear(opt.rnn_size -> data.target_size) followed by LogSoftMax.
-- Loss: ClassNLLCriterion with sizeAverage = false (summed, not averaged),
-- weighted 1 for every class except class 1, which gets weight 0 so targets
-- equal to 1 add nothing to the loss (presumably the padding index -- confirm).
-- @return model, criterion
function make_generator(data, opt)
  local vocab_size = data.target_size
  local generator = nn.Sequential()
    :add(nn.Linear(opt.rnn_size, vocab_size))
    :add(nn.LogSoftMax())
  local class_weights = torch.ones(vocab_size)
  class_weights[1] = 0
  local nll = nn.ClassNLLCriterion(class_weights)
  nll.sizeAverage = false
  return generator, nll
end
--- Independent per-keyword probability head.
-- Model: Linear(opt.rnn_size -> vecs:size(1)) followed by a Sigmoid, so each
-- output is a separate probability; loss: BCECriterion with sizeAverage = false.
-- @param vecs tensor whose first dimension gives the number of outputs
-- @param opt  option table; only opt.rnn_size is read
-- @return model, criterion
function keyword_generator(vecs, opt)
  local num_outputs = vecs:size(1)
  local net = nn.Sequential()
    :add(nn.Linear(opt.rnn_size, num_outputs))
    :add(nn.Sigmoid())
  local bce = nn.BCECriterion()
  bce.sizeAverage = false
  return net, bce
end
--- Character-level CNN unit: one convolution over the character sequence,
-- tanh, then a max over the sequence positions.
-- NOTE(review): the shape handling below assumes a batched input of
-- (batch x word_len x input_size) -- confirm against the caller.
-- @param input_size   size of each character embedding
-- @param kernel_width number of characters covered by each filter
-- @param num_kernels  number of filters (size of the output features)
-- @param use_cudnn    optional; 1/true selects cudnn.SpatialConvolution,
--                     0/false selects nn.TemporalConvolution. When nil, the
--                     global `opt.cudnn` is consulted as before, and a missing
--                     global `opt` means "no cudnn" instead of an error.
-- @return nn.gModule with one input and one output
function make_cnn(input_size, kernel_width, num_kernels, use_cudnn)
  if use_cudnn == nil then
    -- backward compatibility for three-argument callers
    use_cudnn = opt ~= nil and opt.cudnn == 1
  end
  local with_cudnn = (use_cudnn == true or use_cudnn == 1)
  local output
  local input = nn.Identity()()
  if with_cudnn then
    -- treat the input as a one-channel image; the kernel spans the whole
    -- embedding (kW = input_size) and kernel_width characters (kH)
    local conv = cudnn.SpatialConvolution(1, num_kernels, input_size,
      kernel_width, 1, 1, 0)
    local conv_layer = conv(nn.View(1, -1, input_size):setNumInputDims(2)(input))
    -- max over character positions, then Sum(3) drops the leftover size-1 dim
    output = nn.Sum(3)(nn.Max(3)(nn.Tanh()(conv_layer)))
  else
    local conv = nn.TemporalConvolution(input_size, num_kernels, kernel_width)
    local conv_layer = conv(input)
    output = nn.Max(2)(nn.Tanh()(conv_layer))
  end
  return nn.gModule({input}, {output})
end
--- Highway network as an nngraph gModule.
-- Each layer computes  t * f(W_h x) + (1 - t) * P x,  where the transform gate
-- is t = sigmoid(W_t x + bias) and P is the identity when the layer's input and
-- output sizes match, otherwise a bias-free linear projection.
-- @param input_size  dimensionality of the inputs
-- @param num_layers  number of highway layers (default 1)
-- @param output_size dimensionality of every layer's output (default input_size)
-- @param bias        fixed constant added to the transform-gate pre-activation
--                    via nn.AddConstant (default -2); it is not a learned parameter
-- @param f           non-linearity module (default nn.ReLU()); layers after the
--                    first use f:clone(), so any parameters f has are not shared
-- @return nn.gModule mapping one input to the last layer's output
function make_highway(input_size, num_layers, output_size, bias, f)
  num_layers = num_layers or 1
  output_size = output_size or input_size
  bias = bias or -2
  f = f or nn.ReLU()
  local start = nn.Identity()()
  local input = start
  for i = 1, num_layers do
    if i > 1 then
      -- later layers consume the previous layer's output
      input_size = output_size
    end
    -- An nn module keeps its output/gradInput buffers on the instance, so the
    -- same instance must not appear at two nodes of one graph: the second
    -- forward overwrites the output tensor that the first layer's downstream
    -- nodes still reference during backward. Give each later layer its own f.
    local nonlin = f
    if i > 1 then
      nonlin = f:clone()
    end
    local output = nonlin(nn.Linear(input_size, output_size)(input))
    local transform_gate = nn.Sigmoid()(nn.AddConstant(bias, true)(
      nn.Linear(input_size, output_size)(input)))
    -- carry gate = 1 - transform gate
    local carry_gate = nn.AddConstant(1, true)(nn.MulConstant(-1)(transform_gate))
    local proj
    if input_size == output_size then
      proj = nn.Identity()
    else
      proj = nn.LinearNoBias(input_size, output_size)
    end
    input = nn.CAddTable()({
      nn.CMulTable()({transform_gate, output}),
      nn.CMulTable()({carry_gate, proj(input)})})
  end
  return nn.gModule({start}, {input})
end