From e11c9ec8aca95d5065306c73eff660e25e9002ca Mon Sep 17 00:00:00 2001 From: Achal Dave Date: Sat, 4 Mar 2017 19:13:06 -0500 Subject: [PATCH] Error when calling backward more than rho+1 times --- AbstractRecurrent.lua | 4 ++++ test/test.lua | 11 +++++++++++ 2 files changed, 15 insertions(+) diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua index 276ee73..7607808 100644 --- a/AbstractRecurrent.lua +++ b/AbstractRecurrent.lua @@ -55,6 +55,8 @@ end function AbstractRecurrent:updateGradInput(input, gradOutput) -- updateGradInput should be called in reverse order of time self.updateGradInputStep = self.updateGradInputStep or self.step + assert(self.updateGradInputStep >= self.step - self.rho, + string.format('Called backward more than rho+1=%d times', self.rho+1)) -- BPTT for one time-step self.gradInput = self:_updateGradInput(input, gradOutput) @@ -68,6 +70,8 @@ function AbstractRecurrent:accGradParameters(input, gradOutput, scale) -- accGradParameters should be called in reverse order of time assert(self.updateGradInputStep < self.step, "Missing updateGradInput") self.accGradParametersStep = self.accGradParametersStep or self.step + assert(self.accGradParametersStep >= self.step - self.rho, + string.format('Called backward more than rho+1=%d times', self.rho+1)) -- BPTT for one time-step self:_accGradParameters(input, gradOutput, scale) diff --git a/test/test.lua b/test/test.lua index 62d33b0..374fc0a 100644 --- a/test/test.lua +++ b/test/test.lua @@ -640,6 +640,17 @@ function rnntest.Recurrent_old() mytester:assert(err < 0.0001, "Recurrent optim.checkgrad error") end +function rnntest.RecurrentErrorOnExtraBackward() + local model = nn.Recurrent( + nn.Identity(), nn.Identity(), nil, nil, 1 --[[rho]]) + local input = torch.rand(1) + model:training() + for i = 1, 3 do model:forward(input) end + for j = 1, 2 do model:backward(input, input) end + mytester:assertErrorPattern(function() model:backward(input, input) end, + 'Called backward more than rho%+1=2 times') +end + function rnntest.Recurrent() local batchSize = 4 local dictSize = 100