forked from Element-Research/dpnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DontCast.lua
82 lines (70 loc) · 2.42 KB
/
DontCast.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
--- nn.DontCast wraps a module so that :type() on the wrapper never converts
-- the wrapped module; only the wrapper's own output/gradInput buffers are
-- converted (and only when the corresponding cast flag is set).
-- `parent` is nn.Decorator, used to reach base-class methods.
local DontCast, parent = torch.class("nn.DontCast", "nn.Decorator")

--- Construct a DontCast decorator.
-- @param module      the decorated module
-- @param castin      when truthy, inputs whose type differs from `moduleType`
--                    are copied into a `moduleType` buffer before being
--                    forwarded, and the module's gradInput is copied into this
--                    decorator's own gradInput buffer
-- @param castout     when truthy, the module's output is copied into this
--                    decorator's own output buffer, and gradOutputs whose type
--                    differs from `moduleType` are copied into a `moduleType`
--                    buffer; defaults to `castin` when nil
-- @param moduleType  tensor type name of the decorated module (a name accepted
--                    by torch.getmetatable, e.g. "torch.FloatTensor"); when
--                    omitted and casting is enabled, it is inferred from
--                    `module.output`, which must then be a tensor
function DontCast:__init(module, castin, castout, moduleType)
   parent.__init(self, module)
   self.castin = castin
   -- explicit nil check: the `cond and a or b` idiom would yield nil rather
   -- than castin when castin is false
   if castout == nil then
      castout = castin
   end
   self.castout = castout
   self.moduleType = moduleType
   -- moduleType is only consulted when casting is enabled, so only require it
   -- to be inferable in that case; this lets modules whose output is not a
   -- tensor be wrapped without casting
   if (self.castin or self.castout) and not self.moduleType then
      assert(torch.isTensor(module.output), "cannot extrapolate module type")
      self.moduleType = torch.typename(module.output)
   end
end
--- Forward pass through the decorated module.
-- If `castin` is set and `input` is not of `moduleType`, the input is first
-- copied into the reusable `self._input` buffer of that type.
-- If `castout` is set, the module's output is copied into this decorator's
-- own `self.output` buffer; otherwise `self.output` becomes the module's
-- output tensor itself.
-- @param input  input tensor
-- @return self.output
function DontCast:updateOutput(input)
   local moduleInput = input
   local needsCast = self.castin and torch.type(input) ~= self.moduleType
   if needsCast then
      if not self._input then
         -- lazily allocate a buffer of the module's tensor type
         self._input = torch.getmetatable(self.moduleType).new()
      end
      self._input:resize(input:size()):copy(input)
      moduleInput = self._input
   end

   local moduleOutput = self.module:updateOutput(moduleInput)

   if not self.castout then
      self.output = moduleOutput
      return self.output
   end
   self.output:resize(moduleOutput:size()):copy(moduleOutput)
   return self.output
end
--- Backward pass (gradient w.r.t. input) through the decorated module.
-- When the input needed casting, the `self._input` buffer filled by
-- updateOutput is passed instead (assumes updateOutput was called with the
-- same input — TODO confirm callers honour this).
-- If `castout` is set and `gradOutput` is not of `moduleType`, it is copied
-- into the reusable `self._gradOutput` buffer of that type.
-- If `castin` is set, the module's gradInput is copied into this decorator's
-- own `self.gradInput`; otherwise `self.gradInput` becomes the module's.
-- @param input       input tensor given to updateOutput
-- @param gradOutput  gradient w.r.t. the output
-- @return self.gradInput
function DontCast:updateGradInput(input, gradOutput)
   local moduleInput = input
   if self.castin and torch.type(input) ~= self.moduleType then
      moduleInput = self._input
   end

   local moduleGradOutput = gradOutput
   if self.castout and torch.type(gradOutput) ~= self.moduleType then
      if not self._gradOutput then
         self._gradOutput = torch.getmetatable(self.moduleType).new()
      end
      self._gradOutput:resize(gradOutput:size()):copy(gradOutput)
      moduleGradOutput = self._gradOutput
   end

   local moduleGradInput = self.module:updateGradInput(moduleInput, moduleGradOutput)

   if not self.castin then
      self.gradInput = moduleGradInput
   else
      self.gradInput:resize(moduleGradInput:size()):copy(moduleGradInput)
   end
   return self.gradInput
end
--- Accumulate parameter gradients in the decorated module.
-- Inputs/gradOutputs that would have been cast are replaced by the
-- `self._input` / `self._gradOutput` buffers filled by updateOutput and
-- updateGradInput (assumes those were called with the same arguments —
-- TODO confirm). No copies are made here.
-- @param input       input tensor given to updateOutput
-- @param gradOutput  gradient w.r.t. the output
-- @param scale       forwarded unchanged to the module
function DontCast:accGradParameters(input, gradOutput, scale)
   local inputWasCast = self.castin and torch.type(input) ~= self.moduleType
   local gradWasCast = self.castout and torch.type(gradOutput) ~= self.moduleType

   local moduleInput, moduleGradOutput = input, gradOutput
   if inputWasCast then
      moduleInput = self._input
   end
   if gradWasCast then
      moduleGradOutput = self._gradOutput
   end

   self.module:accGradParameters(moduleInput, moduleGradOutput, scale)
end
--- Accumulate-and-update parameters in the decorated module.
-- Same argument substitution as accGradParameters: cast inputs/gradOutputs
-- are swapped for the `self._input` / `self._gradOutput` buffers filled by
-- the forward and updateGradInput passes (assumed to have run with the same
-- arguments — TODO confirm).
-- @param input       input tensor given to updateOutput
-- @param gradOutput  gradient w.r.t. the output
-- @param lr          forwarded unchanged to the module
function DontCast:accUpdateGradParameters(input, gradOutput, lr)
   local moduleInput = input
   local moduleGradOutput = gradOutput

   if self.castin and torch.type(input) ~= self.moduleType then
      moduleInput = self._input
   end

   if self.castout and torch.type(gradOutput) ~= self.moduleType then
      moduleGradOutput = self._gradOutput
   end

   self.module:accUpdateGradParameters(moduleInput, moduleGradOutput, lr)
end
--- Don't cast the decorated module: only this decorator's own buffers are
-- converted, and the wrapped module keeps its tensor type.
-- `self.output` is converted only when `castout` is set (it is then a buffer
-- filled by copying; after updateOutput without castout it is the module's
-- own output). Likewise `self.gradInput` is converted only when `castin` is
-- set.
-- @param typeName  tensor type name to convert to, e.g. "torch.CudaTensor"
-- @return self
function DontCast:type(typeName)
   -- tensor:type() called without an argument returns the type *name*
   -- instead of converting, which would overwrite the buffers with strings;
   -- with no target type there is nothing to convert
   if typeName == nil then
      return self
   end
   if self.castout then
      self.output = self.output:type(typeName)
   end
   if self.castin then
      self.gradInput = self.gradInput:type(typeName)
   end
   return self
end