Skip to content

Commit

Permalink
Error when calling backward more than rho+1 times
Browse files Browse the repository at this point in the history
  • Loading branch information
achalddave committed Mar 5, 2017
1 parent ef98a97 commit 018cd71
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions AbstractRecurrent.lua
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ end
function AbstractRecurrent:updateGradInput(input, gradOutput)
-- updateGradInput should be called in reverse order of time
self.updateGradInputStep = self.updateGradInputStep or self.step
assert(self.updateGradInputStep >= self.step - self.rho,
string.format('Called backward more than rho=%d times', self.rho))

-- BPTT for one time-step
self.gradInput = self:_updateGradInput(input, gradOutput)
Expand All @@ -68,6 +70,8 @@ function AbstractRecurrent:accGradParameters(input, gradOutput, scale)
-- accGradParameters should be called in reverse order of time
assert(self.updateGradInputStep < self.step, "Missing updateGradInput")
self.accGradParametersStep = self.accGradParametersStep or self.step
assert(self.accGradParametersStep >= self.step - self.rho,
string.format('Called backward more than rho=%d times', self.rho))

-- BPTT for one time-step
self:_accGradParameters(input, gradOutput, scale)
Expand Down
11 changes: 11 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,17 @@ function rnntest.Recurrent_old()
mytester:assert(err < 0.0001, "Recurrent optim.checkgrad error")
end

function rnntest.RecurrentErrorOnExtraBackward()
local model = nn.Recurrent(
nn.Identity(), nn.Identity(), nil, nil, 1 --[[rho]])
local input = torch.rand(1)
model:training()
for i = 1, 3 do model:forward(input) end
for j = 1, 2 do model:backward(input, input) end
mytester:assertErrorPattern(function() model:backward(input, input) end,
'Called backward more than rho%+1=2 times')
end

function rnntest.Recurrent()
local batchSize = 4
local dictSize = 100
Expand Down

0 comments on commit 018cd71

Please sign in to comment.