add grid lstm #345
Open
christopher5106 wants to merge 7 commits into Element-Research:master from christopher5106:master
Commits (7)
344c155  christopher5106  add grid lstm
32bd83e  christopher5106  some corrections
a89a728  christopher5106  n
61ae656  christopher5106  initialization of hidden and cell states in the grid module
97da470  christopher5106  taking lookup table out of module to be more generic
05c188f  christopher5106  for printing
6065689  christopher5106  using userPrevCell to initialize cells after forget
@@ -0,0 +1,248 @@
local Grid2DLSTM, parent = torch.class("nn.Grid2DLSTM", 'nn.AbstractRecurrent')

function lstm(h_t, h_d, prev_c, rnn_size)
  local all_input_sums = nn.CAddTable()({h_t, h_d})
  local reshaped = nn.Reshape(4, rnn_size)(all_input_sums)
  local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4)
  -- decode the gates
  local in_gate = nn.Sigmoid()(n1)
  local forget_gate = nn.Sigmoid()(n2)
  local out_gate = nn.Sigmoid()(n3)
  -- decode the write inputs
  local in_transform = nn.Tanh()(n4)
  -- perform the LSTM update
  local next_c = nn.CAddTable()({
    nn.CMulTable()({forget_gate, prev_c}),
    nn.CMulTable()({in_gate, in_transform})
  })
  -- gated cells form the output
  local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})
  return next_c, next_h
end
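-- For reference: with z = h_t + h_d (the summed 4*rnn_size pre-activations)
-- split into four rnn_size blocks z1..z4, the graph above computes
--   in_gate      = sigmoid(z1)
--   forget_gate  = sigmoid(z2)
--   out_gate     = sigmoid(z3)
--   in_transform = tanh(z4)
--   next_c = forget_gate .* prev_c + in_gate .* in_transform
--   next_h = out_gate .* tanh(next_c)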

function Grid2DLSTM:__init(outputSize, nb_layers, dropout, tie_weights, rho, cell2gate)
  parent.__init(self, rho or 9999)
  self.inputSize = outputSize -- for compatibility with tostring function in AbstractRecurrent
  self.outputSize = outputSize
  -- tie the depth-direction weights by default (note: `tie_weights or true` would always be true)
  self.should_tie_weights = (tie_weights == nil) and true or tie_weights
  self.dropout = dropout or 0
  self.nb_layers = nb_layers
  -- build the model
  self.cell2gate = (cell2gate == nil) and true or cell2gate
  self.recurrentModule = self:buildModel()

  -- make it work with nn.Container
  self.modules[1] = self.recurrentModule
  self.sharedClones[1] = self.recurrentModule

  -- for output(0), cell(0) and gradCell(T)
  self.zeroTensor = torch.Tensor()

  self.cells = {}
  self.gradCells = {}

  -- initialization
  -- local net_params = self.recurrentModule:parameters()
  --
  -- for _, p in pairs(net_params) do
  --   p:uniform(-0.08, 0.08)
  -- end
  --
  -- initialize the LSTM forget gates with slightly higher biases to encourage remembering in the beginning
  for layer_idx = 1, self.nb_layers do
    for _, node in ipairs(self.recurrentModule.forwardnodes) do
      if node.data.annotations.name == "i2h_" .. layer_idx then
        print('setting forget gate biases to 1 in LSTM layer ' .. layer_idx)
        -- the gates are, in order, i,f,o,g, so f is the 2nd block of weights
        node.data.module.bias[{{self.outputSize+1, 2*self.outputSize}}]:fill(1.0)
      end
    end
  end
end

function Grid2DLSTM:buildModel()
  require 'nngraph'
  assert(nngraph, "Missing nngraph package")

  -- There will be 2*nb_layers + 2 inputs
  local inputs = {}
  table.insert(inputs, nn.Identity()()) -- input c for depth dimension
  table.insert(inputs, nn.Identity()()) -- input h for depth dimension
  for L = 1, self.nb_layers do
    table.insert(inputs, nn.Identity()()) -- prev_c[L] for time dimension
    table.insert(inputs, nn.Identity()()) -- prev_h[L] for time dimension
  end

  local shared_weights
  if self.should_tie_weights == true then
    shared_weights = {nn.Linear(self.outputSize, 4 * self.outputSize), nn.Linear(self.outputSize, 4 * self.outputSize)}
  end

  local outputs_t = {} -- Outputs being handed to the next time step along the time dimension
  local outputs_d = {} -- Outputs being handed from one layer to the next along the depth dimension

  for L = 1, self.nb_layers do
    -- Take hidden and memory cell from previous time steps
    local prev_c_t = inputs[L*2+1]
    local prev_h_t = inputs[L*2+2]

    local prev_c_d, prev_h_d
    if L == 1 then
      -- We're in the first layer
      prev_c_d = inputs[1] -- input_c_d: the starting depth dimension memory cell, just a zero vec.
      prev_h_d = inputs[2] -- input_h_d: the starting depth dimension hidden state.
    else
      -- We're in the higher layers 2...N
      -- Take hidden and memory cell from layers below
      prev_c_d = outputs_d[((L-1)*2)-1]
      prev_h_d = outputs_d[((L-1)*2)]
      if self.dropout > 0 then prev_h_d = nn.Dropout(self.dropout)(prev_h_d):annotate{name='drop_' .. L} end -- apply dropout, if any
    end

    -- Evaluate the input sums at once for efficiency
    local t2h_t = nn.Linear(self.outputSize, 4 * self.outputSize)(prev_h_t):annotate{name='i2h_'..L}
    local d2h_t = nn.Linear(self.outputSize, 4 * self.outputSize)(prev_h_d):annotate{name='h2h_'..L}

    -- Get transformed memory and hidden states pointing in the time direction first
    local next_c_t, next_h_t = lstm(t2h_t, d2h_t, prev_c_t, self.outputSize)

    -- Pass memory cell and hidden state to next timestep
    table.insert(outputs_t, next_c_t)
    table.insert(outputs_t, next_h_t)

    -- Evaluate the input sums at once for efficiency
    local t2h_d = nn.Linear(self.outputSize, 4 * self.outputSize)(next_h_t):annotate{name='i2h_'..L}
    local d2h_d = nn.Linear(self.outputSize, 4 * self.outputSize)(prev_h_d):annotate{name='h2h_'..L}

    -- See section 3.5, "Weight Sharing" of http://arxiv.org/pdf/1507.01526.pdf
    -- The weights along the temporal dimension are already tied (cloned many times in train.lua)
    -- Here we can tie the weights along the depth dimension. Having invariance in computation
    -- along the depth appears to be critical to solving the 15 digit addition problem w/ high accuracy.
    -- See fig. 4 to compare tied vs untied grid LSTMs on this task.
    if self.should_tie_weights == true then
      print("tying weights along the depth dimension")
      t2h_d.data.module:share(shared_weights[1], 'weight', 'bias', 'gradWeight', 'gradBias')
      d2h_d.data.module:share(shared_weights[2], 'weight', 'bias', 'gradWeight', 'gradBias')
    end

    -- Create the lstm gated update pointing in the depth direction.
    -- We 'prioritize' the depth dimension by using the updated temporal hidden state as input
    -- instead of the previous temporal hidden state. This implements Section 3.2, "Priority Dimensions"
    local next_c_d, next_h_d = lstm(t2h_d, d2h_d, prev_c_d, self.outputSize)

    -- Pass the depth dimension memory cell and hidden state to layer above
    table.insert(outputs_d, next_c_d)
    table.insert(outputs_d, next_h_d)
  end

  -- the topmost depth-direction hidden state is the module output
  local top_h = outputs_d[#outputs_d]
  table.insert(outputs_t, top_h)

  -- outputs_t now contains nb_layers x (next_c_t, next_h_t), followed by top_h

  return nn.gModule(inputs, outputs_t)
end
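-- Interface of the gModule built above:
--   inputs  = { input_c_d, input_h_d, prev_c_1, prev_h_1, ..., prev_c_N, prev_h_N }  (2*nb_layers + 2 tensors)
--   outputs = { next_c_1, next_h_1, ..., next_c_N, next_h_N, top_h }                 (2*nb_layers + 1 tensors)
-- top_h is what updateOutput below returns as self.output.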

function Grid2DLSTM:updateOutput(input)

  if self.step == 1 then
    -- the initial state of the cell/hidden states
    if self.userPrevCell then
      -- print("Initializing the cell/hidden states with user previous cell")
      self.cells = { [0] = self.userPrevCell }
    else
      -- print("Initializing the cell/hidden states with zero")
      self.cells = {[0] = {}}

      for L = 1, self.nb_layers do
        local h_init = torch.zeros(input:size(1), self.outputSize):cuda()
        table.insert(self.cells[0], h_init:clone())
        table.insert(self.cells[0], h_init:clone()) -- extra initial state for prev_c
      end
    end
  end

  local input_mem_cell = torch.zeros(input:size(1), self.outputSize):float():cuda()

  local rnn_inputs = {input_mem_cell, input, unpack(self.cells[self.step-1])}
  -- print(input:size())
  local lst
  if self.train ~= false then
    self:recycle()
    local recurrentModule = self:getStepModule(self.step)
    -- the actual forward propagation
    lst = recurrentModule:updateOutput(rnn_inputs)
  else
    lst = self.recurrentModule:updateOutput(rnn_inputs)
  end

  self.cells[self.step] = {}
  for i = 1, #(self.cells[0]) do table.insert(self.cells[self.step], lst[i]) end

  self.outputs[self.step] = lst[#lst]
  self.output = lst[#lst]

  self.step = self.step + 1
  self.gradPrevOutput = nil
  self.updateGradInputStep = nil
  self.accGradParametersStep = nil
  -- note that we don't return the cell, just the output
  return self.output
end
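-- Note on state layout: self.cells[step] mirrors the per-layer inputs of the
-- graph, i.e. { prev_c_1, prev_h_1, ..., prev_c_N, prev_h_N }. If the caller
-- sets self.userPrevCell (as the last commit's message suggests, e.g. after a
-- forget()), it is expected to be a table with that same layout.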

function Grid2DLSTM:_updateGradInput(input, gradOutput)
  assert(self.step > 1, "expecting at least one updateOutput")
  local step = self.updateGradInputStep - 1
  assert(step >= 1)

  -- set the output/gradOutput states of current Module
  local recurrentModule = self:getStepModule(step)

  -- backward propagate through this step
  if (step == self.step-1) then
    self.gradCells = {[step] = {}}
    for L = 1, self.nb_layers do
      local h_init = torch.zeros(input:size(1), self.outputSize):cuda()
      table.insert(self.gradCells[step], h_init:clone())
      table.insert(self.gradCells[step], h_init:clone()) -- extra initial state for prev_c
    end
    local input_mem_cell = torch.zeros(input:size(1), self.outputSize):float():cuda()
    self.rnn_inputs = {input_mem_cell, input, unpack(self.cells[step-1])}
  end

  table.insert(self.gradCells[step], gradOutput)

  local dlst = recurrentModule:updateGradInput(self.rnn_inputs, self.gradCells[step])
  self.gradCells[step-1] = {}
  local gradInput = {}
  for k, v in pairs(dlst) do
    if k > 2 then
      -- entries 3..(2+2*nb_layers) are the gradients w.r.t. the per-layer
      -- (prev_c, prev_h) states; shift by 2 to line them up with gradCells
      self.gradCells[step-1][k-2] = v
    else
      -- entries 1 and 2 are the gradients w.r.t. input_mem_cell and the input
      table.insert(gradInput, v)
    end
  end
  return gradInput[2] -- gradient w.r.t. the input (entry 1 is the zero memory cell)
end
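-- Gradient bookkeeping: self.gradCells[step] ends up with 2*nb_layers + 1
-- entries (the per-layer dnext_c/dnext_h, plus the gradOutput appended for
-- top_h), matching the graph's output table; _updateGradInput then splits the
-- result back into per-layer state gradients for the previous step and the
-- gradient w.r.t. this step's input.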

function Grid2DLSTM:_accGradParameters(input, gradOutput, scale)
  local step = self.accGradParametersStep - 1
  assert(step >= 1)

  -- set the output/gradOutput states of current Module
  local recurrentModule = self:getStepModule(step)

  -- backward propagate through this step
  local input_mem_cell = torch.zeros(input:size(1), self.outputSize):float():cuda()
  local rnn_inputs = {input_mem_cell, input, unpack(self.cells[step-1])}

  recurrentModule:accGradParameters(rnn_inputs, self.gradCells[step], scale)
end
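For reviewers who want to try the module, here is a minimal usage sketch. The sizes, the Linear projection in front of the grid module, and the loop shape are hypothetical; it assumes the new file is reachable from `require 'rnn'` (or dofile'd directly) and that a CUDA setup is available, since the module currently hard-codes :cuda().

require 'rnn'
require 'cunn'

local inputSize, hiddenSize, nLayers, seqLen, batchSize = 10, 128, 2, 5, 8

-- the grid module expects an input of size batchSize x outputSize, so project
-- the raw features to hiddenSize first (this stands in for the lookup table
-- that an earlier commit moved out of the module)
local grid = nn.Grid2DLSTM(hiddenSize, nLayers, 0.25, true)
local net = nn.Sequential()
  :add(nn.Linear(inputSize, hiddenSize))
  :add(grid)
  :cuda()

for step = 1, seqLen do
  local x = torch.CudaTensor(batchSize, inputSize):uniform()
  local y = net:forward(x) -- batchSize x hiddenSize
end
grid:forget() -- reset the stored cell/hidden states before the next sequence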
You will have to fix this option so it can work without a GPU; I'm getting errors on the CPU version.
I tried removing the cuda() references but could not get it to work on the rnn-sin demo with a 4x2 tensor/table.
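One possible direction for the GPU dependency (an untested sketch, not part of this PR): allocate the zero states from the input tensor's own type instead of hard-coding :cuda(), for example:

-- hypothetical helper; input.new() creates an empty tensor of the same
-- type/device as `input`, so the zero states follow whatever tensor type
-- the caller is using (FloatTensor, DoubleTensor, CudaTensor, ...)
local function zeroState(input, outputSize)
  return input.new():resize(input:size(1), outputSize):zero()
end

-- in updateOutput / _updateGradInput / _accGradParameters, replace
--   torch.zeros(input:size(1), self.outputSize):cuda()
-- with
--   zeroState(input, self.outputSize)

Whether the rest of the module then runs cleanly on CPU (e.g. on the rnn-sin demo mentioned above) would still need to be verified.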