From d7dd75c89e7bb8167c3d7c95945e383f01245533 Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Thu, 27 Apr 2017 14:20:41 -0400 Subject: [PATCH] rho -> seqlen; refactored AbstractRecurrent:updateOutput --- AbstractRecurrent.lua | 56 ++++++++++++--------------- GRU.lua | 13 +------ NormStabilizer.lua | 12 +----- RecLSTM.lua | 12 +----- Recurrence.lua | 13 +------ Recursor.lua | 9 +---- Repeater.lua | 56 +++++++++++++-------------- Sequencer.lua | 2 +- examples/README.md | 6 +-- examples/nested-recurrence-lstm.lua | 6 +-- examples/simple-sequencer-network.lua | 1 - test/test.lua | 54 +++++++++++++------------- 12 files changed, 96 insertions(+), 144 deletions(-) diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua index d516e09..589c80a 100644 --- a/AbstractRecurrent.lua +++ b/AbstractRecurrent.lua @@ -9,7 +9,7 @@ function AbstractRecurrent:__init(stepmodule) parent.__init(self) assert(torch.isTypeOf(stepmodule, 'nn.Module'), torch.type(self).." expecting nn.Module instance at arg 1") - self.rho = 99999 --the maximum number of time steps to BPTT + self.seqlen = 99999 --the maximum number of time steps to BPTT self.outputs = {} self.gradInputs = {} @@ -52,6 +52,20 @@ function AbstractRecurrent:trimZero(nInputDim) return self end +function AbstractRecurrent:updateOutput(input) + -- feed-forward for one time-step + self.output = self:_updateOutput(input) + + self.outputs[self.step] = self.output + + self.step = self.step + 1 + self.gradPrevOutput = nil + self.updateGradInputStep = nil + self.accGradParametersStep = nil + + return self.output +end + function AbstractRecurrent:updateGradInput(input, gradOutput) -- updateGradInput should be called in reverse order of time self.updateGradInputStep = self.updateGradInputStep or self.step @@ -84,16 +98,16 @@ function AbstractRecurrent:recycle(offset) local _ = require 'moses' self.nSharedClone = self.nSharedClone or _.size(self.sharedClones) - local rho = math.max(self.rho + 1, self.nSharedClone) + local seqlen = math.max(self.seqlen + 1, self.nSharedClone) if self.sharedClones[self.step] == nil then - self.sharedClones[self.step] = self.sharedClones[self.step-rho] - self.sharedClones[self.step-rho] = nil - self._gradOutputs[self.step] = self._gradOutputs[self.step-rho] - self._gradOutputs[self.step-rho] = nil + self.sharedClones[self.step] = self.sharedClones[self.step-seqlen] + self.sharedClones[self.step-seqlen] = nil + self._gradOutputs[self.step] = self._gradOutputs[self.step-seqlen] + self._gradOutputs[self.step-seqlen] = nil end - self.outputs[self.step-rho-1] = nil - self.gradInputs[self.step-rho-1] = nil + self.outputs[self.step-seqlen-1] = nil + self.gradInputs[self.step-seqlen-1] = nil return self end @@ -226,8 +240,8 @@ function AbstractRecurrent:setOutputStep(step) self.gradInput = self.gradInputs[step] end -function AbstractRecurrent:maxBPTTstep(rho) - self.rho = rho +function AbstractRecurrent:maxBPTTstep(seqlen) + self.seqlen = seqlen end -- get stored hidden state: h[t] where h[t] = f(x[t], h[t-1]) @@ -258,28 +272,6 @@ AbstractRecurrent.recursiveAdd = rnn.recursiveAdd AbstractRecurrent.recursiveTensorEq = rnn.recursiveTensorEq AbstractRecurrent.recursiveNormal = rnn.recursiveNormal - - -function AbstractRecurrent:backwardThroughTime(step, rho) - error"DEPRECATED Jan 8, 2016" -end - -function AbstractRecurrent:updateGradInputThroughTime(step, rho) - error"DEPRECATED Jan 8, 2016" -end - -function AbstractRecurrent:accGradParametersThroughTime(step, rho) - error"DEPRECATED Jan 8, 2016" -end - -function 
AbstractRecurrent:accUpdateGradParametersThroughTime(lr, step, rho) - error"DEPRECATED Jan 8, 2016" -end - -function AbstractRecurrent:backwardUpdateThroughTime(learningRate) - error"DEPRECATED Jan 8, 2016" -end - function AbstractRecurrent:__tostring__() if self.inputSize and self.outputSize then return self.__typename .. string.format("(%d -> %d)", self.inputSize, self.outputSize) diff --git a/GRU.lua b/GRU.lua index 743ce78..b1328fd 100644 --- a/GRU.lua +++ b/GRU.lua @@ -150,7 +150,7 @@ function GRU:setHiddenState(step, hiddenState) end ------------------------- forward backward ----------------------------- -function GRU:updateOutput(input) +function GRU:_updateOutput(input) local prevOutput = self:getHiddenState(self.step-1, input) -- output(t) = gru{input(t), output(t-1)} @@ -164,16 +164,7 @@ function GRU:updateOutput(input) output = self.modules[1]:updateOutput{input, prevOutput} end - self.outputs[self.step] = output - - self.output = output - - self.step = self.step + 1 - self.gradPrevOutput = nil - self.updateGradInputStep = nil - self.accGradParametersStep = nil - -- note that we don't return the cell, just the output - return self.output + return output end diff --git a/NormStabilizer.lua b/NormStabilizer.lua index b11fbfa..3e11f8f 100644 --- a/NormStabilizer.lua +++ b/NormStabilizer.lua @@ -17,7 +17,7 @@ function NS:_accGradParameters(input, gradOutput, scale) -- No parameters to update end -function NS:updateOutput(input) +function NS:_updateOutput(input) assert(input:dim() == 2) local output if self.train ~= false then @@ -31,15 +31,7 @@ function NS:updateOutput(input) output = self.modules[1]:updateOutput(input) end - self.outputs[self.step] = output - - self.output = output - self.step = self.step + 1 - self.gradPrevOutput = nil - self.updateGradInputStep = nil - self.accGradParametersStep = nil - - return self.output + return output end -- returns norm-stabilizer loss as defined in ref. 
A diff --git a/RecLSTM.lua b/RecLSTM.lua index de1c362..942f96a 100644 --- a/RecLSTM.lua +++ b/RecLSTM.lua @@ -59,7 +59,7 @@ function RecLSTM:setHiddenState(step, hiddenState) end ------------------------- forward backward ----------------------------- -function RecLSTM:updateOutput(input) +function RecLSTM:_updateOutput(input) local prevOutput, prevCell = unpack(self:getHiddenState(self.step-1, input)) -- output(t), cell(t) = lstm{input(t), output(t-1), cell(t-1)} @@ -73,18 +73,10 @@ function RecLSTM:updateOutput(input) output, cell = unpack(self.modules[1]:updateOutput{input, prevOutput, prevCell}) end - self.outputs[self.step] = output self.cells[self.step] = cell - - self.output = output self.cell = cell - - self.step = self.step + 1 - self.gradPrevOutput = nil - self.updateGradInputStep = nil - self.accGradParametersStep = nil -- note that we don't return the cell, just the output - return self.output + return output end function RecLSTM:getGradHiddenState(step) diff --git a/Recurrence.lua b/Recurrence.lua index 7b0e7b0..6469575 100644 --- a/Recurrence.lua +++ b/Recurrence.lua @@ -102,7 +102,7 @@ function Recurrence:setHiddenState(step, hiddenState) end end -function Recurrence:updateOutput(input) +function Recurrence:_updateOutput(input) -- output(t-1) local prevOutput = self:getHiddenState(self.step-1, input)[1] @@ -117,16 +117,7 @@ function Recurrence:updateOutput(input) output = self.modules[1]:updateOutput{input, prevOutput} end - self.outputs[self.step] = output - - self.output = output - - self.step = self.step + 1 - self.gradPrevOutput = nil - self.updateGradInputStep = nil - self.accGradParametersStep = nil - - return self.output + return output end function Recurrence:getGradHiddenState(step) diff --git a/Recursor.lua b/Recursor.lua index 32bae3f..e8c7b3a 100644 --- a/Recursor.lua +++ b/Recursor.lua @@ -6,7 +6,7 @@ ------------------------------------------------------------------------ local Recursor, parent = torch.class('nn.Recursor', 'nn.AbstractRecurrent') -function Recursor:updateOutput(input) +function Recursor:_updateOutput(input) local output if self.train ~= false then -- if self.train or self.train == nil then -- set/save the output states @@ -17,12 +17,7 @@ function Recursor:updateOutput(input) output = self.modules[1]:updateOutput(input) end - self.outputs[self.step] = output - self.output = output - self.step = self.step + 1 - self.updateGradInputStep = nil - self.accGradParametersStep = nil - return self.output + return output end function Recursor:_updateGradInput(input, gradOutput) diff --git a/Repeater.lua b/Repeater.lua index eaedcc6..c98b06e 100644 --- a/Repeater.lua +++ b/Repeater.lua @@ -1,20 +1,20 @@ ------------------------------------------------------------------------ --[[ Repeater ]]-- --- Encapsulates an AbstractRecurrent instance (rnn) which is repeatedly --- presented with the same input for rho time steps. --- The output is a table of rho outputs of the rnn. +-- Encapsulates an AbstractRecurrent instance (rnn) which is repeatedly +-- presented with the same input for seqlen time steps. +-- The output is a table of seqlen outputs of the rnn. 
------------------------------------------------------------------------ assert(not nn.Repeater, "update nnx package : luarocks install nnx") local Repeater, parent = torch.class('nn.Repeater', 'nn.AbstractSequencer') -function Repeater:__init(module, rho) +function Repeater:__init(module, seqlen) parent.__init(self) - assert(torch.type(rho) == 'number', "expecting number value for arg 2") - self.rho = rho + assert(torch.type(seqlen) == 'number', "expecting number value for arg 2") + self.seqlen = seqlen self.module = (not torch.isTypeOf(module, 'nn.AbstractRecurrent')) and nn.Recursor(module) or module - - self.module:maxBPTTstep(rho) -- hijack rho (max number of time-steps for backprop) - + + self.module:maxBPTTstep(seqlen) -- hijack seqlen (max number of time-steps for backprop) + self.modules[1] = self.module self.output = {} end @@ -24,21 +24,21 @@ function Repeater:updateOutput(input) self.module:forget() -- TODO make copy outputs optional - for step=1,self.rho do + for step=1,self.seqlen do self.output[step] = nn.rnn.recursiveCopy(self.output[step], self.module:updateOutput(input)) end return self.output end function Repeater:updateGradInput(input, gradOutput) - assert(self.module.step - 1 == self.rho, "inconsistent rnn steps") + assert(self.module.step - 1 == self.seqlen, "inconsistent rnn steps") assert(torch.type(gradOutput) == 'table', "expecting gradOutput table") - assert(#gradOutput == self.rho, "gradOutput should have rho elements") - + assert(#gradOutput == self.seqlen, "gradOutput should have seqlen elements") + -- back-propagate through time (BPTT) - for step=self.rho,1,-1 do + for step=self.seqlen,1,-1 do local gradInput = self.module:updateGradInput(input, gradOutput[step]) - if step == self.rho then + if step == self.seqlen then self.gradInput = nn.rnn.recursiveCopy(self.gradInput, gradInput) else nn.rnn.recursiveAdd(self.gradInput, gradInput) @@ -49,29 +49,29 @@ function Repeater:updateGradInput(input, gradOutput) end function Repeater:accGradParameters(input, gradOutput, scale) - assert(self.module.step - 1 == self.rho, "inconsistent rnn steps") + assert(self.module.step - 1 == self.seqlen, "inconsistent rnn steps") assert(torch.type(gradOutput) == 'table', "expecting gradOutput table") - assert(#gradOutput == self.rho, "gradOutput should have rho elements") - + assert(#gradOutput == self.seqlen, "gradOutput should have seqlen elements") + -- back-propagate through time (BPTT) - for step=self.rho,1,-1 do + for step=self.seqlen,1,-1 do self.module:accGradParameters(input, gradOutput[step], scale) end - + end -function Repeater:maxBPTTstep(rho) - self.rho = rho - self.module:maxBPTTstep(rho) +function Repeater:maxBPTTstep(seqlen) + self.seqlen = seqlen + self.module:maxBPTTstep(seqlen) end function Repeater:accUpdateGradParameters(input, gradOutput, lr) - assert(self.module.step - 1 == self.rho, "inconsistent rnn steps") + assert(self.module.step - 1 == self.seqlen, "inconsistent rnn steps") assert(torch.type(gradOutput) == 'table', "expecting gradOutput table") - assert(#gradOutput == self.rho, "gradOutput should have rho elements") - + assert(#gradOutput == self.seqlen, "gradOutput should have seqlen elements") + -- back-propagate through time (BPTT) - for step=self.rho,1,-1 do + for step=self.seqlen,1,-1 do self.module:accUpdateGradParameters(input, gradOutput[step], lr) end end @@ -84,7 +84,7 @@ function Repeater:__tostring__() str = str .. tab .. ' V V V '.. line str = str .. tab .. tostring(self.modules[1]):gsub(line, line .. tab) .. line str = str .. tab .. 
' V V V '.. line - str = str .. tab .. '[output(1),output(2),...,output('..self.rho..')]' .. line + str = str .. tab .. '[output(1),output(2),...,output('..self.seqlen..')]' .. line str = str .. '}' return str end diff --git a/Sequencer.lua b/Sequencer.lua index 46d0eee..3a267a9 100644 --- a/Sequencer.lua +++ b/Sequencer.lua @@ -46,7 +46,7 @@ function Sequencer:updateOutput(input) nStep = #input end - -- Note that the Sequencer hijacks the rho attribute of the rnn + -- Note that the Sequencer hijacks the seqlen attribute of the rnn self.module:maxBPTTstep(nStep) if self.train ~= false then -- TRAINING diff --git a/examples/README.md b/examples/README.md index 91e79e0..c9ca2f0 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,6 +1,6 @@ # Examples -This directory contains various training scripts. +This directory contains various training scripts. Torch blog posts * The torch.ch blog contains detailed posts about the *rnn* package. @@ -8,8 +8,8 @@ Torch blog posts 2. [noise-contrastive-esimate.lua](noise-contrastive-estimate.lua): one of two training scripts used in [Language modeling a billion words](http://torch.ch/blog/2016/07/25/nce.html). Single-GPU script for training recurrent language models on the Google billion words dataset. 3. [multigpu-nce-rnnlm.lua](multigpu-nce-rnnlm.lua) : 4-GPU version of `noise-contrastive-estimate.lua` for training larger multi-GPU models. Two of two training scripts used in the [Language modeling a billion words](http://torch.ch/blog/2016/07/25/nce.html). -Simple training scripts. +Simple training scripts. * Showcases the fundamental principles of the package. In chronological order of introduction date. - 1. [simple-recurrent-network.lua](simple-recurrent-network.lua): uses the `nn.Recurrent` module to instantiate a Simple RNN. Illustrates the first AbstractRecurrent instance in action. It has since been surpassed by the more flexible `nn.Recursor` and `nn.Recurrence`. The `nn.Recursor` class decorates any module to make it conform to the nn.AbstractRecurrent interface. The `nn.Recurrence` implements the recursive `h[t] <- forward(h[t-1], x[t])`. Together, `nn.Recursor` and `nn.Recurrence` can be used to implement a wide range of experimental recurrent architectures. + 1. [simple-recurrent-network.lua](simple-recurrent-network.lua): uses the `nn.LookupRNN` module to instantiate a Simple RNN. Illustrates the first AbstractRecurrent instance in action. It has since been surpassed by the more flexible `nn.Recursor` and `nn.Recurrence`. The `nn.Recursor` class decorates any module to make it conform to the nn.AbstractRecurrent interface. The `nn.Recurrence` implements the recursive `h[t] <- forward(h[t-1], x[t])`. Together, `nn.Recursor` and `nn.Recurrence` can be used to implement a wide range of experimental recurrent architectures. 2. [simple-sequencer-network.lua](simple-sequencer-network.lua): uses the `nn.Sequencer` module to accept a batch of sequences as `input` of size `seqlen x batchsize x ...`. Both tables and tensors are accepted as input and produce the same type of output (table->table, tensor->tensor). The `Sequencer` class abstract away the implementation of back-propagation through time. It also provides a `remember(['neither','both'])` method for triggering what the `Sequencer` remembers between iterations (forward,backward,update). 3. [simple-recurrence-network.lua](simple-recurrence-network.lua): uses the `nn.Recurrence` module to define the h[t] <- sigmoid(h[t-1], x[t]) Simple RNN. 
Decorates it using `nn.Sequencer` so that an entire batch of sequences (`input`) can forward and backward propagated per update. diff --git a/examples/nested-recurrence-lstm.lua b/examples/nested-recurrence-lstm.lua index 3cf12e2..69afd29 100644 --- a/examples/nested-recurrence-lstm.lua +++ b/examples/nested-recurrence-lstm.lua @@ -4,7 +4,7 @@ require 'rnn' -- hyper-parameters batchSize = 8 -rho = 5 -- sequence length +seqlen = 5 -- sequence length hiddenSize = 7 nIndex = 10 lr = 0.1 @@ -49,10 +49,10 @@ offsets = torch.LongTensor(offsets) -- training local iteration = 1 while true do - -- 1. create a sequence of rho time-steps + -- 1. create a sequence of seqlen time-steps local inputs, targets = {}, {} - for step=1,rho do + for step=1,seqlen do -- a batch of inputs inputs[step] = sequence:index(1, offsets) -- incement indices diff --git a/examples/simple-sequencer-network.lua b/examples/simple-sequencer-network.lua index 9f2c9c7..d0c3a98 100644 --- a/examples/simple-sequencer-network.lua +++ b/examples/simple-sequencer-network.lua @@ -7,7 +7,6 @@ hiddenSize = 7 nIndex = 10 lr = 0.1 - local rnn = nn.Sequential() :add(nn.LookupRNN(nIndex, hiddenSize)) :add(nn.Linear(hiddenSize, nIndex)) diff --git a/test/test.lua b/test/test.lua index 9833e93..4c1fbb0 100644 --- a/test/test.lua +++ b/test/test.lua @@ -41,7 +41,7 @@ function rnntest.RecLSTM_main() local gradInput = gradInputs[1] - local mlp2 -- this one will simulate rho = nStep + local mlp2 -- this one will simulate seqlen = nStep local inputs for step=1,nStep do -- iteratively build an LSTM out of non-recurrent components @@ -152,7 +152,7 @@ function rnntest.GRU() gradInput = gru:backward(input[step], gradOutput[step], 1) end - local mlp2 -- this one will simulate rho = nStep + local mlp2 -- this one will simulate seqlen = nStep local inputs for step=1,nStep do -- iteratively build an GRU out of non-recurrent components @@ -236,7 +236,7 @@ function rnntest.RecurrentAttention() locatorHiddenSize = 20, imageHiddenSize = 20, hiddenSize = 20, - rho = 5, + seqlen = 5, locatorStd = 0.1, inputSize = 28, nClass = 10, @@ -276,11 +276,11 @@ function rnntest.RecurrentAttention() locator:add(nn.HardTanh()) -- bounds sample between -1 and 1 -- model is a reinforcement learning agent - local rva = nn.RecurrentAttention(rnn:clone(), locator:clone(), opt.rho, {opt.hiddenSize}) + local rva = nn.RecurrentAttention(rnn:clone(), locator:clone(), opt.seqlen, {opt.hiddenSize}) local input = torch.randn(opt.batchSize,1,opt.inputSize,opt.inputSize) local gradOutput = {} - for step=1,opt.rho do + for step=1,opt.seqlen do table.insert(gradOutput, torch.randn(opt.batchSize, opt.hiddenSize)) end @@ -288,7 +288,7 @@ function rnntest.RecurrentAttention() local output = rva:forward(input) - mytester:assert(#output == opt.rho, "RecurrentAttention #output err") + mytester:assert(#output == opt.seqlen, "RecurrentAttention #output err") local reward = torch.randn(opt.batchSize) rva:reinforce(reward) @@ -483,7 +483,7 @@ function rnntest.Sequencer_main() mytester:assertTensorEq(params7[i], params8[i], 0.0000001, "Sequencer "..torch.type(rnn7).." 
remember params err "..i) end - -- test in evaluation mode with remember and variable rho + -- test in evaluation mode with remember and variable seqlen local rnn7 = rnn:clone() -- a fresh copy (no hidden states) local params7 = rnn7:parameters() local params9 = rnn9:parameters() -- not a fresh copy @@ -976,7 +976,7 @@ function rnntest.Sequencer_tensor() mytester:assertTensorEq(params7[i], params8[i], 0.0000001, "Sequencer "..torch.type(rnn7).." remember params err "..i) end - -- test in evaluation mode with remember and variable rho + -- test in evaluation mode with remember and variable seqlen local rnn7 = rnn:clone() -- a fresh copy (no hidden states) local params7 = rnn7:parameters() local params9 = rnn9:parameters() -- not a fresh copy @@ -2356,10 +2356,10 @@ function rnntest.Recursor() local inputSize = 3 local hiddenSize = 12 local outputSize = 7 - local rho = 5 + local seqlen = 5 - local inputs = torch.randn(rho, batchSize, inputSize) - local gradOutputs = torch.randn(rho, batchSize, outputSize) + local inputs = torch.randn(seqlen, batchSize, inputSize) + local gradOutputs = torch.randn(seqlen, batchSize, outputSize) -- USE CASE 1. Recursor(LSTM) @@ -2369,18 +2369,18 @@ function rnntest.Recursor() re:zeroGradParameters() re2:zeroGradParameters() - local outputs = torch.Tensor(rho, batchSize, outputSize) + local outputs = torch.Tensor(seqlen, batchSize, outputSize) local outputs2 = outputs:clone() - local gradInputs = torch.Tensor(rho, batchSize, inputSize) + local gradInputs = torch.Tensor(seqlen, batchSize, inputSize) - for i=1,rho do + for i=1,seqlen do -- forward outputs[i] = re:forward(inputs[i]) outputs2[i] = re2:forward(inputs[i]) end - local gradInputs_2 = torch.Tensor(rho, batchSize, inputSize) - for i=rho,1,-1 do + local gradInputs_2 = torch.Tensor(seqlen, batchSize, inputSize) + for i=seqlen,1,-1 do -- backward gradInputs_2[i] = re2:backward(inputs[i], gradOutputs[i]) gradInputs[i] = re:backward(inputs[i], gradOutputs[i]) @@ -2418,11 +2418,11 @@ function rnntest.Recursor() local outputs = seq:forward(inputs) local gradInputs = seq:backward(inputs, gradOutputs) - for i=1,rho do + for i=1,seqlen do outputs2[i] = re2:forward(inputs[i]) end - for i=rho,1,-1 do + for i=seqlen,1,-1 do gradInputs_2[i] = re2:backward(inputs[i], gradOutputs[i]) end @@ -2450,13 +2450,13 @@ function rnntest.Recursor() re:zeroGradParameters() re2:zeroGradParameters() - for i=1,rho do + for i=1,seqlen do -- forward outputs = re:forward(inputs[i]) outputs2 = re2:forward(inputs[i]) end - for i=rho,1,-1 do + for i=seqlen,1,-1 do -- backward gradInputs_2[i] = re2:backward(inputs[i], gradOutputs[i]) end @@ -2464,7 +2464,7 @@ function rnntest.Recursor() re2:updateParameters(0.1) -- recursor requires reverse-time-step order during backward - for i=rho,1,-1 do + for i=seqlen,1,-1 do gradInputs[i] = re:backward(inputs[i], gradOutputs[i]) end @@ -2543,7 +2543,7 @@ function rnntest.Recurrence_nested() local batchSize = 4 local hiddenSize = 2 - local rho = 3 + local seqlen = 3 local lstm = nn.RecLSTM(hiddenSize,hiddenSize) @@ -2558,7 +2558,7 @@ function rnntest.Recurrence_nested() local seq = nn.Sequencer(rnn) local inputs, gradOutputs = {}, {} - for i=1,rho do + for i=1,seqlen do inputs[i] = torch.randn(batchSize, hiddenSize) gradOutputs[i] = torch.randn(batchSize, hiddenSize) end @@ -4280,7 +4280,7 @@ function rnntest.NormStabilizer() -- Make a simple RNN and training set to test gradients -- hyper-parameters local batchSize = 3 - local rho = 2 + local seqlen = 2 local hiddenSize = 3 local inputSize = 4 local 
lr = 0.1 @@ -4302,7 +4302,7 @@ function rnntest.NormStabilizer() while iteration < 5 do -- generate a random data point local inputs, targets = {}, {} - for step=1,rho do + for step=1,seqlen do inputs[step] = torch.randn(batchSize, inputSize) targets[step] = torch.randn(batchSize, hiddenSize) end @@ -4403,7 +4403,7 @@ function rnntest.NormStabilizer() local seq2 = nn.Sequencer(ns2) local inputs, gradOutputs = {}, {} - for step=1,rho do + for step=1,seqlen do inputs[step] = torch.randn(batchSize, inputSize) gradOutputs[step] = torch.randn(batchSize, inputSize) end @@ -4413,7 +4413,7 @@ function rnntest.NormStabilizer() local gradInputs = seq:backward(inputs, gradOutputs) local gradInputs2 = seq2:backward(inputs, gradOutputs) - for step=1,rho do + for step=1,seqlen do mytester:assertTensorEq(outputs[step], outputs2[step], 0.0000001) mytester:assertTensorEq(gradInputs[step], gradInputs2[step], 0.0000001) end
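
A minimal usage sketch of the renamed API (illustrative only, not part of the patch; it assumes the rnn package built at this commit, and the module choices, sizes, and variable names below are arbitrary):

require 'rnn'

local batchsize, inputsize, outputsize, seqlen = 4, 3, 5, 7

-- Repeater now takes `seqlen` (formerly `rho`) as its second argument and
-- presents the same input for seqlen time steps.
local rep = nn.Repeater(nn.RecLSTM(inputsize, outputsize), seqlen)

local input = torch.randn(batchsize, inputsize)
local output = rep:forward(input)   -- table of seqlen output tensors
assert(#output == seqlen)

-- gradOutput must be a table with one element per time step
local gradOutput = {}
for step=1,seqlen do
   gradOutput[step] = torch.randn(batchsize, outputsize)
end
rep:backward(input, gradOutput)

-- On a bare AbstractRecurrent instance, the BPTT horizon is still capped via
-- maxBPTTstep, which now sets self.seqlen instead of self.rho:
local lstm = nn.RecLSTM(inputsize, outputsize)
lstm:maxBPTTstep(seqlen)

-- With the refactor, subclasses only implement _updateOutput for a single
-- time step; AbstractRecurrent:updateOutput records self.outputs[self.step]
-- and advances self.step, so manual stepping still goes through forward:
local out1 = lstm:forward(torch.randn(batchsize, inputsize))  -- step 1
local out2 = lstm:forward(torch.randn(batchsize, inputsize))  -- step 2
lstm:forget()  -- reset the step counter between independent sequences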