-- Module.lua
-- Forked from Element-Research/dpnn
local Module = nn.Module

function Module:sparseParameters()
   return self:parameters()
end

function Module:updateParameters(learningRate)
   -- sparse params can have different learningRate scales per param
   local params, gradParams, scales = self:sparseParameters()
   if params then
      for i,param in pairs(params) do -- pairs for sparse params
         local scale = scales and scales[i] or 1
         param:add(-learningRate*scale, gradParams[i])
      end
   end
end

function Module:zeroGradParameters()
   local _,gradParams = self:sparseParameters()
   if gradParams then
      for i,gradParam in pairs(gradParams) do -- pairs for sparse params
         gradParam:zero()
      end
   end
end
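
-- Example (a minimal sketch, not part of the library): one manual SGD step
-- using the two methods above. Assumes an nn.Linear module, an nn.MSECriterion,
-- and random data; the names `mlp`, `x`, `y` and `lr` are illustrative only.
--[[
local mlp = nn.Linear(10, 2)
local criterion = nn.MSECriterion()
local x, y = torch.randn(10), torch.randn(2)
local lr = 0.1

mlp:zeroGradParameters()                       -- clear accumulated gradients
local output = mlp:forward(x)
local loss = criterion:forward(output, y)
mlp:backward(x, criterion:backward(output, y)) -- accumulate gradParams
mlp:updateParameters(lr)                       -- param = param - lr * gradParam
--]]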

------------------------ clone and type --------------------------------

Module.dpnn_parameters = {'weight', 'bias'}
Module.dpnn_gradParameters = {'gradWeight', 'gradBias'}

function Module:sharedClone(shareParams, shareGradParams)
   shareParams = (shareParams == nil) and true or shareParams
   shareGradParams = (shareGradParams == nil) and true or shareGradParams

   local moduleClones, modules
   if self.modules then
      moduleClones = {}
      for i,module in ipairs(self.modules) do
         moduleClones[i] = module:sharedClone(shareParams, shareGradParams)
      end
      modules = self.modules
      self.modules = nil -- to prevent recloning
   end

   local params, pointers = {}, {}
   if shareParams then
      for i,paramName in ipairs(self.dpnn_parameters) do
         local param = self[paramName]
         if param then
            params[paramName] = param
            self[paramName] = nil
            if param:storage() then
               pointers[torch.pointer(param:storage():data())] = true
            end
         end
      end
   end

   if shareGradParams then
      for i,paramName in ipairs(self.dpnn_gradParameters) do
         local gradParam = self[paramName]
         if gradParam then
            params[paramName] = gradParam
            self[paramName] = nil
            if gradParam:storage() then
               pointers[torch.pointer(gradParam:storage():data())] = true
            end
         end
      end
   end

   -- find all the tensors that share storage with the shared params
   for paramName, param in pairs(self) do
      if torch.isTensor(param) and param:storage() then
         if pointers[torch.pointer(param:storage():data())] then
            params[paramName] = param
            self[paramName] = nil
         end
      end
   end

   -- clone everything but parameters and/or gradients
   local clone = self:clone()

   for paramName, param in pairs(params) do
      assert(self[paramName] == nil)
      self[paramName] = param
      clone[paramName] = param.new():set(param)
   end

   if moduleClones then
      assert(self.modules == nil)
      self.modules = modules
      clone.modules = moduleClones
   end
   return clone
end
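
-- Example (a minimal sketch, not part of the library): a shared clone keeps
-- its own output/gradInput buffers but shares parameter storage with the
-- original, so updating one updates the other. The name `lin` is illustrative.
--[[
local lin = nn.Linear(4, 3)
local clone = lin:sharedClone()        -- shares weight, bias, gradWeight, gradBias
lin.weight:fill(1)
print(clone.weight[1][1])              -- 1 : storage is shared
print(lin.output == clone.output)      -- false : buffers are not shared
--]]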

-- By default, Module:type() preserves shared Tensors.
-- This is more sensible, necessary for RNNs, and fits in with
-- existing overridden methods. It preserves shared params created
-- with sharedClone.
function Module:type(type)
   assert(type, 'Module:type() must provide a type to convert to')
   -- key: pointer to old storage ; value : new storage
   local castmap = dpnn.castmap
   local root
   if not castmap then
      -- Contains torch.Storage instances used in Modules.
      -- The use of a global variable is ugly, but it is the only way
      -- to fit in with existing overridden Module:type() methods.
      root = true
      dpnn.castmap = {}
      castmap = dpnn.castmap
   end

   local function recursiveType(param, type_str)
      if torch.type(param) == 'table' then
         for k,v in pairs(param) do
            param[k] = recursiveType(v, type_str)
         end
      elseif torch.isTypeOf(param, 'nn.Module') or torch.isTypeOf(param, 'nn.Criterion') then
         param:type(type_str)
      else
         if torch.isTensor(param) then
            if param:storage() then
               local pointer = torch.pointer(param:storage():data())
               local storage = castmap[pointer]
               -- empty storages (cuda) have zero pointers.
               -- we assume that these aren't shared.
               -- https://github.com/torch/cutorch/issues/147
               if pointer > 0 then
                  if not storage then
                     local _param = param
                     -- cast entire storage
                     param = param.new(param:storage()):type(type_str)
                     if param:storage() then -- to handle cuda tensors ...
                        param:set(param:storage(), _param:storageOffset(), _param:size(), _param:stride())
                        castmap[pointer] = param:storage()
                        -- in case the module gets cast more than once:
                        castmap[torch.pointer(param:storage():data())] = param:storage()
                     end
                  else
                     -- set to point to existing storage
                     local _param = param
                     param = torch.getmetatable(type_str).new()
                     param:set(storage, _param:storageOffset(), _param:size(), _param:stride())
                  end
               else
                  assert(not storage)
                  param = param:type(type_str)
               end
            else
               param = param:type(type_str)
            end
         end
      end
      return param
   end

   -- find all tensors and convert them
   for key,param in pairs(self) do
      self[key] = recursiveType(param, type)
   end

   if root then
      -- reset the cast map
      dpnn.castmap = nil
   end
   return self
end
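
-- Example (a minimal sketch, not part of the library): casting a container
-- that holds shared clones keeps the sharing intact, which a naive per-tensor
-- cast would break. The names `lin` and `net` are illustrative only.
--[[
local lin = nn.Linear(4, 3)
local net = nn.Sequential():add(lin):add(lin:sharedClone())
net:type('torch.FloatTensor')          -- or net:float()
local a, b = net:get(1), net:get(2)
a.weight:fill(1)
print(b.weight[1][1])                  -- 1 : sharing preserved across the cast
--]]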

----------------- serialization (see nn.Serial) -------------------

Module.dpnn_mediumEmpty = {'output', 'gradInput', 'momGradParams', 'dpnn_input'}
Module.dpnn_lightEmpty = Module.dpnn_gradParameters
-- defaults to heavy serialization
Module.dpnn_serialEmpty = {}
Module.dpnn_serialType = false

-- sets the serialization behavior of the entire module structure
function Module:serialMode(empty, type)
   assert(torch.type(empty) == 'table', "Expecting table at arg 1")
   self.dpnn_serialEmpty = empty
   self.dpnn_serialType = type
   -- set the serial of all encapsulated modules
   local function recursiveSerial(tbl)
      for k,v in pairs(tbl) do
         if torch.isTypeOf(v, 'nn.Module') then
            v:serialMode(empty, type)
         elseif torch.type(v) == 'table' then
            recursiveSerial(v)
         end
      end
   end
   recursiveSerial(self)
   return self
end

-- serialMode : serialize everything
function Module:heavySerial(type)
   return self:serialMode({}, type)
end

-- serialMode : serialize everything except dpnn_mediumEmpty attributes
function Module:mediumSerial(type)
   self.dpnn_serialEmpty = self.dpnn_mediumEmpty
   self.dpnn_serialType = (type == nil) and 'float' or type
   -- set the serial of all encapsulated modules
   local function recursiveSerial(tbl)
      for k,v in pairs(tbl) do
         if torch.isTypeOf(v, 'nn.Module') then
            v:mediumSerial(type)
         elseif torch.type(v) == 'table' then
            recursiveSerial(v)
         end
      end
   end
   recursiveSerial(self)
   return self
end

-- serialMode : serialize everything except dpnn_mediumEmpty and dpnn_lightEmpty attributes
function Module:lightSerial(type)
   self.dpnn_serialEmpty = _.clone(self.dpnn_mediumEmpty)
   for k,v in ipairs(self.dpnn_lightEmpty) do
      table.insert(self.dpnn_serialEmpty, v)
   end
   self.dpnn_serialType = (type == nil) and 'float' or type
   -- set the serial of all encapsulated modules
   local function recursiveSerial(tbl)
      for k,v in pairs(tbl) do
         if torch.isTypeOf(v, 'nn.Module') then
            v:lightSerial(type)
         elseif torch.type(v) == 'table' then
            recursiveSerial(v)
         end
      end
   end
   recursiveSerial(self)
   return self
end
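
-- Example (a minimal sketch, not part of the library): the three presets
-- differ only in which attributes are emptied before serialization.
--[[
local net = nn.Sequential():add(nn.Linear(10, 5))
net:heavySerial()   -- keep everything
net:mediumSerial()  -- drop output, gradInput, momGradParams, dpnn_input; cast to float
net:lightSerial()   -- additionally drop gradWeight and gradBias; cast to float
--]]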

function Module:getSerialState(states)
   states = states or {}

   -- don't get the serial state of the same module twice (reuse existing)
   if states[self] then
      return states[self]
   end

   -- returns the object structure as tables (i.e. without metatables)
   local function recursiveState(tbl)
      local state = _.map(tbl,
         function(k,v)
            if torch.isTypeOf(tbl, 'nn.Module') and _.contains(tbl.dpnn_serialEmpty, k) then
               -- "empties" module attributes found in dpnn_serialEmpty
               if torch.type(v) == 'table' then
                  -- empty table
                  return {}
               elseif torch.isTensor(v) then
                  -- empty tensor
                  return v.new()
               else
                  -- not table nor tensor? then serialize as is
                  return v
               end
            elseif torch.isTypeOf(v, 'nn.Module') then
               -- recursive, yet can be overwritten
               return v:getSerialState(states)
            elseif torch.type(v) == 'table' then
               -- in case it is a table of modules
               if not states[v] then
                  states[v] = recursiveState(v)
               end
               return states[v]
            else
               return v
            end
         end
      )
      return state
   end

   local state = recursiveState(self)

   -- include typename so that module can be reconstructed from the state
   state.dpnn_typename = torch.type(self)
   states[self] = state

   return state
end

-- decorates self with nn.Serial
function Module:Serial()
   return nn.Serial(self)
end
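
-- Example (a minimal sketch, not part of the library): the serial mode takes
-- effect when the module is wrapped in nn.Serial (defined elsewhere in dpnn)
-- and saved. The file name is illustrative only.
--[[
local net = nn.Sequential():add(nn.Linear(10, 5)):add(nn.Tanh())
net = net:Serial():lightSerial()
torch.save('net.t7', net)              -- smaller file than heavy serialization
--]]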

----------------------- for training -----------------------------

-- useful to get the output size
-- I chose this method name because it is less likely to be overridden.
function Module:outside(insize)
   local input
   if torch.type(insize) == 'table' then
      input = torch.randn(unpack(insize))
   else
      input = torch.randn(insize)
   end
   local output = self:updateOutput(input)
   return output:size()
end
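
-- Example (a minimal sketch, not part of the library): query the output size
-- of a module by feeding it a random input of the given size.
--[[
local conv = nn.SpatialConvolution(3, 16, 5, 5)
print(conv:outside({1, 3, 32, 32}))    -- 1x16x28x28
--]]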

-- for those interested in implementing the visitor design pattern
function Module:accept(visitor)
   visitor:visit(self)
end

-- Can be used as a regularizer instead of weight decay.
-- Assumes that parameters are arranged (output dim x ... x input dim)
function Module:maxParamNorm(maxOutNorm, maxInNorm)
   -- this allows each module to set its own max[Out,In]Norm
   maxOutNorm = self.maxOutNorm or maxOutNorm
   maxInNorm = self.maxInNorm or maxInNorm
   if not (maxOutNorm or maxInNorm) then
      return
   end

   if self.modules then
      for i,module in ipairs(self.modules) do
         module:maxParamNorm(maxOutNorm, maxInNorm)
      end
   else
      local params = self:parameters()
      if not params then
         return
      end
      for k,param in pairs(params) do -- pairs for sparse params
         -- By default, only affects non-1D params.
         if param:dim() > 1 then
            if maxOutNorm and maxOutNorm > 0 then
               -- rows feed into output neurons
               param:renorm(2, 1, maxOutNorm)
            end
            if maxInNorm and maxInNorm > 0 then
               -- cols feed out from input neurons
               param:renorm(2, param:dim(), maxInNorm)
            end
         end
      end
   end
end
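
-- Example (a minimal sketch, not part of the library): constrain each output
-- neuron's incoming weight vector to an L2 norm of at most 2, typically
-- called after each parameter update.
--[[
local net = nn.Sequential():add(nn.Linear(100, 50)):add(nn.ReLU())
-- ... forward/backward/updateParameters ...
net:maxParamNorm(2)   -- renorm rows of each weight matrix to norm <= 2
--]]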

-- Similar to maxParamNorm, but the norm measured is global to the Module
-- on which this is called. If moduleLocal is true, the norm constraint is
-- instead applied to the parameters of each component (non-container)
-- module separately.
function Module:gradParamClip(cutoffNorm, moduleLocal)
   -- this allows each module to set its own cutoffNorm
   cutoffNorm = self.cutoffNorm or cutoffNorm
   if cutoffNorm <= 0 then
      return
   end
   if self.moduleLocal ~= nil then
      moduleLocal = self.moduleLocal
   end

   local norm = 0
   if moduleLocal and self.modules then
      for i,module in ipairs(self.modules) do
         norm = norm + math.pow(module:gradParamClip(cutoffNorm, moduleLocal), 2)
      end
      norm = math.sqrt(norm)
   else
      local params, gradParams = self:parameters()
      if not (params and gradParams) then
         return norm
      end
      for k,gradParam in pairs(gradParams) do -- pairs for sparse params
         norm = norm + math.pow(gradParam:norm(), 2)
      end
      norm = math.sqrt(norm)
      if norm > cutoffNorm then
         -- rescale gradParams to obtain desired cutoffNorm
         for k,gradParam in pairs(gradParams) do
            gradParam:mul(cutoffNorm/norm)
         end
      end
   end
   return norm
end
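
-- Example (a minimal sketch, not part of the library): clip the global
-- gradient norm to 5 before the parameter update, a common recipe for RNNs.
--[[
local net = nn.Sequential():add(nn.Linear(20, 20)):add(nn.Sigmoid())
-- ... forward/backward to accumulate gradParams ...
local norm = net:gradParamClip(5)      -- rescales gradParams if their norm > 5
print(norm)                            -- the pre-clipping gradient norm
net:updateParameters(0.01)
--]]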

-- Adds weight decay constraint on params with dims >= 2 (default).
-- TODO : allow inplace weightDecay (before calling accUpdateGradParameters)
function Module:weightDecay(wdFactor, wdMinDim)
   -- this allows each module to set its own hyper-parameters
   wdFactor = self.wdFactor or wdFactor
   if wdFactor <= 0 then
      return
   end
   wdMinDim = self.wdMinDim or wdMinDim or 2

   if self.modules then
      for i,module in ipairs(self.modules) do
         module:weightDecay(wdFactor, wdMinDim)
      end
   else
      local params, gradParams = self:parameters()
      if not (params and gradParams) then
         return
      end
      for i,param in pairs(params) do -- pairs for sparse params
         if param:dim() >= wdMinDim then
            gradParams[i]:add(wdFactor, param)
         end
      end
   end
end
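
-- Example (a minimal sketch, not part of the library): L2 weight decay is
-- applied by adding wdFactor * param to the gradient of each param with at
-- least wdMinDim dimensions, so 1D biases are left alone by default.
--[[
local net = nn.Sequential():add(nn.Linear(10, 5))
-- ... forward/backward to accumulate gradParams ...
net:weightDecay(1e-4)   -- gradWeight:add(1e-4, weight); gradBias untouched
net:updateParameters(0.1)
--]]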

function Module:momentumGradParameters()
   if (not self.momGradParams) or _.isEmpty(self.momGradParams) then
      local params, gradParams = self:parameters()
      if not gradParams or _.isEmpty(gradParams) then
         return
      end
      self.momGradParams = {}
      for i,gradParam in pairs(gradParams) do
         self.momGradParams[i] = gradParam.new():resizeAs(gradParam):copy(gradParam)
      end
   end
   return self.momGradParams
end

-- uses momentum learning to update gradParams
function Module:updateGradParameters(momFactor, momDamp, momNesterov)
   -- this allows each module to set its own hyper-parameters
   momFactor = self.momFactor or momFactor
   if momFactor <= 0 then
      return
   end
   momDamp = self.momDamp or momDamp or momFactor
   if self.momNesterov ~= nil then
      momNesterov = self.momNesterov
   end

   if self.modules then
      for i,module in ipairs(self.modules) do
         module:updateGradParameters(momFactor, momDamp, momNesterov)
      end
   else
      local params, gradParams = self:parameters()
      if (not params) or _.isEmpty(params) then
         return
      end
      local momGradParams = self:momentumGradParameters()
      for i,gradParam in pairs(gradParams) do
         momGradParams[i]:mul(momFactor):add(1-momDamp, gradParam)
      end

      if momNesterov then
         for i,gradParam in pairs(gradParams) do
            gradParam:add(momFactor, momGradParams[i])
         end
      else
         for i,gradParam in pairs(gradParams) do
            gradParam:copy(momGradParams[i])
         end
      end
   end
end
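
-- Example (a minimal sketch, not part of the library): classic momentum.
-- Call after backward() and before updateParameters(); the momGradParams
-- buffers are created lazily on first use.
--[[
local net = nn.Sequential():add(nn.Linear(10, 5))
-- ... forward/backward to accumulate gradParams ...
net:updateGradParameters(0.9, 0, false) -- momGrad = 0.9*momGrad + gradParam; gradParam = momGrad
net:updateParameters(0.01)
net:zeroGradParameters()
--]]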

function Module:checkParameters()
   local params = self:parameters() or {}
   for k,param in pairs(params) do
      if _.isNaN(param:sum()) then
         error("NaN Error for param at index " .. k)
      end
   end
end

function Module:dontBackward()
   self.updateGradInput = function() end
   self.accGradParameters = function() end
   self.accUpdateGradParameters = function() end
   return self
end
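
-- Example (a minimal sketch, not part of the library): freeze a pretrained
-- sub-network by turning its backward passes into no-ops, so it accumulates
-- no gradients during training.
--[[
local pretrained = nn.Linear(100, 50):dontBackward()
local net = nn.Sequential():add(pretrained):add(nn.Linear(50, 10))
-- backward() through net now leaves pretrained's gradParams untouched
--]]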

function Module:contiguousInput(input, backward)
   if backward then
      return self.dpnn_cinput or input
   end
   if not input:isContiguous() then
      self.dpnn_cinput = self.dpnn_cinput or input.new()
      self.dpnn_cinput:resizeAs(input):copy(input)
      input = self.dpnn_cinput
   end
   return input
end

function Module:toBatch(tensor, nDim, batchDim)
   local batchDim = batchDim or 1
   if tensor:dim() == nDim then
      self.dpnn_online = true
      local size = tensor:size():totable()
      table.insert(size, batchDim, 1)
      tensor = tensor:view(unpack(size))
   else
      self.dpnn_online = false
   end
   return tensor
end

function Module:fromBatch(tensor, batchDim)
   if self.dpnn_online then
      local size = tensor:size():totable()
      assert(table.remove(size, batchDim) == 1)
      tensor = tensor:view(unpack(size))
   end
   return tensor
end
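
-- Example (a minimal sketch, not part of the library): toBatch/fromBatch let
-- a module treat a single non-batched sample as a batch of size 1 internally
-- and then restore the original shape.
--[[
local m = nn.Linear(10, 5)               -- any module instance works here
local x = torch.randn(10)                -- a single sample, nDim == 1
local xb = m:toBatch(x, 1)               -- 1x10 view, dpnn_online flag set
local yb = torch.randn(1, 5)             -- pretend batched output
local y = m:fromBatch(yb, 1)             -- back to a 5-element vector
--]]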

function Module:extrapolateType()
   local params = self:parameters()
   if params then
      -- extrapolate the tensor type of the module
      local types = {}
      for i, param in ipairs(params) do
         local tensorType = torch.type(param)
         types[tensorType] = (types[tensorType] or 0) + 1
      end
      local maxCount = 0
      local maxType
      for tensorType, count in pairs(types) do
         if count > maxCount then
            maxType = tensorType
            maxCount = count
         end
      end
      return maxType
   end
   return nil -- unknown otherwise
end

function Module:profile()
   if self.modules then
      for i, module in ipairs(self.modules) do
         module:profile()
      end
   end
   self.dpnn_profile = true
end