LookupTable.lua
local LookupTable, parent = torch.class('nn.LookupTable', 'nn.Module')

LookupTable.__version = 4

function LookupTable:__init(nIndex, nOutput)
   parent.__init(self)

   -- one nOutput-dimensional embedding per index in [1, nIndex]
   self.weight = torch.Tensor(nIndex, nOutput)
   self.gradWeight = torch.Tensor(nIndex, nOutput):zero()
   self._count = torch.IntTensor()
   self._input = torch.LongTensor()
   self.shouldScaleGradByFreq = false

   self:reset()
end

-- drop gradWeight so the module can only be trained via in-place
-- accUpdateGradParameters (saves the memory of a second weight-sized tensor)
function LookupTable:accUpdateOnly()
   self.gradWeight = nil
   return self
end

-- scale each index's gradient by the inverse of its frequency in the input
function LookupTable:scaleGradByFreq()
   self.shouldScaleGradByFreq = true
   return self
end

function LookupTable:reset(stdv)
   stdv = stdv or 1
   self.weight:normal(0, stdv)
end

function LookupTable:makeInputContiguous(input)
   -- make sure input is a contiguous torch.LongTensor
   if (not input:isContiguous()) or torch.type(input) ~= torch.type(self._input) then
      self.copiedInput = true
      self._input:resize(input:size()):copy(input)
      return self._input
   end
   self.copiedInput = false
   return input
end

function LookupTable:updateOutput(input)
   input = self:makeInputContiguous(input)
   if input:dim() == 1 then
      self.output:index(self.weight, 1, input)
   elseif input:dim() == 2 then
      -- flatten the indices, gather the rows, then restore the batch shape
      self.output:index(self.weight, 1, input:view(-1))
      self.output = self.output:view(input:size(1), input:size(2), self.weight:size(2))
   else
      error("input must be a vector or matrix")
   end
   return self.output
end

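-- Illustrative sketch (added for clarity, not part of the original file):
-- with a 2D index tensor, forward returns one embedding row per index,
-- batched along the first two dimensions:
--
--   local lut = nn.LookupTable(10, 4)
--   local out = lut:forward(torch.LongTensor{{1, 2, 3}, {4, 5, 6}})
--   -- out has size 2x3x4 and out[i][j] equals lut.weight[input[i][j]]
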
function LookupTable:accGradParameters(input, gradOutput, scale)
   input = self.copiedInput and self._input or input
   if input:dim() == 2 then
      input = input:view(-1)
   elseif input:dim() ~= 1 then
      error("input must be a vector or matrix")
   end
   -- dispatch to the compiled implementation for the current tensor type
   self.gradWeight.nn.LookupTable_accGradParameters(self, input, gradOutput, scale)
end

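-- Illustrative sketch (an assumption about typical use, not from the
-- original file): repeated indices accumulate their gradients, e.g.
--
--   local lut = nn.LookupTable(10, 4)
--   local input = torch.LongTensor{2, 2}
--   lut:forward(input)
--   lut:backward(input, torch.ones(2, 4))
--   -- lut.gradWeight[2] is now a row of 2s; with lut:scaleGradByFreq()
--   -- the accumulated gradient is divided by the index frequency instead
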
function LookupTable:type(type)
   parent.type(self, type)
   if type == 'torch.CudaTensor' then
      -- CUDA uses _sorted and _indices temporary tensors
      self._sorted = self.weight.new()
      self._indices = self.weight.new()
   else
      -- self._count and self._input should only be converted when using CUDA
      self._count = torch.IntTensor()
      self._input = torch.LongTensor()
   end
   return self
end

-- we do not need to accumulate parameters when sharing
LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters
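
--[[
Usage sketch (illustrative; assumes `require 'nn'` has been run):

   local lut = nn.LookupTable(10, 4)         -- 10 embeddings of dimension 4
   local indices = torch.LongTensor{1, 3, 5} -- look up rows 1, 3 and 5
   local emb = lut:forward(indices)          -- 3x4 tensor of embedding rows
--]]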