rho -> seqlen; refactored AbstractRecurrent:updateOutput
Nicholas Leonard committed Apr 27, 2017
1 parent 1987c4e commit d7dd75c
Showing 12 changed files with 96 additions and 144 deletions.
56 changes: 24 additions & 32 deletions AbstractRecurrent.lua
@@ -9,7 +9,7 @@ function AbstractRecurrent:__init(stepmodule)
parent.__init(self)

assert(torch.isTypeOf(stepmodule, 'nn.Module'), torch.type(self).." expecting nn.Module instance at arg 1")
self.rho = 99999 --the maximum number of time steps to BPTT
self.seqlen = 99999 --the maximum number of time steps to BPTT

self.outputs = {}
self.gradInputs = {}
@@ -52,6 +52,20 @@ function AbstractRecurrent:trimZero(nInputDim)
return self
end

function AbstractRecurrent:updateOutput(input)
-- feed-forward for one time-step
self.output = self:_updateOutput(input)

self.outputs[self.step] = self.output

self.step = self.step + 1
self.gradPrevOutput = nil
self.updateGradInputStep = nil
self.accGradParametersStep = nil

return self.output
end

function AbstractRecurrent:updateGradInput(input, gradOutput)
-- updateGradInput should be called in reverse order of time
self.updateGradInputStep = self.updateGradInputStep or self.step
@@ -84,16 +98,16 @@ function AbstractRecurrent:recycle(offset)
local _ = require 'moses'
self.nSharedClone = self.nSharedClone or _.size(self.sharedClones)

local rho = math.max(self.rho + 1, self.nSharedClone)
local seqlen = math.max(self.seqlen + 1, self.nSharedClone)
if self.sharedClones[self.step] == nil then
self.sharedClones[self.step] = self.sharedClones[self.step-rho]
self.sharedClones[self.step-rho] = nil
self._gradOutputs[self.step] = self._gradOutputs[self.step-rho]
self._gradOutputs[self.step-rho] = nil
self.sharedClones[self.step] = self.sharedClones[self.step-seqlen]
self.sharedClones[self.step-seqlen] = nil
self._gradOutputs[self.step] = self._gradOutputs[self.step-seqlen]
self._gradOutputs[self.step-seqlen] = nil
end

self.outputs[self.step-rho-1] = nil
self.gradInputs[self.step-rho-1] = nil
self.outputs[self.step-seqlen-1] = nil
self.gradInputs[self.step-seqlen-1] = nil

return self
end
@@ -226,8 +240,8 @@ function AbstractRecurrent:setOutputStep(step)
self.gradInput = self.gradInputs[step]
end

function AbstractRecurrent:maxBPTTstep(rho)
self.rho = rho
function AbstractRecurrent:maxBPTTstep(seqlen)
self.seqlen = seqlen
end

-- get stored hidden state: h[t] where h[t] = f(x[t], h[t-1])
@@ -258,28 +272,6 @@ AbstractRecurrent.recursiveAdd = rnn.recursiveAdd
AbstractRecurrent.recursiveTensorEq = rnn.recursiveTensorEq
AbstractRecurrent.recursiveNormal = rnn.recursiveNormal



function AbstractRecurrent:backwardThroughTime(step, rho)
error"DEPRECATED Jan 8, 2016"
end

function AbstractRecurrent:updateGradInputThroughTime(step, rho)
error"DEPRECATED Jan 8, 2016"
end

function AbstractRecurrent:accGradParametersThroughTime(step, rho)
error"DEPRECATED Jan 8, 2016"
end

function AbstractRecurrent:accUpdateGradParametersThroughTime(lr, step, rho)
error"DEPRECATED Jan 8, 2016"
end

function AbstractRecurrent:backwardUpdateThroughTime(learningRate)
error"DEPRECATED Jan 8, 2016"
end

function AbstractRecurrent:__tostring__()
if self.inputSize and self.outputSize then
return self.__typename .. string.format("(%d -> %d)", self.inputSize, self.outputSize)
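After this refactor, the per-step bookkeeping (saving `self.outputs[self.step]`, advancing `self.step`, clearing `gradPrevOutput`, `updateGradInputStep` and `accGradParametersStep`) lives once in `AbstractRecurrent:updateOutput`, and concrete modules only implement `_updateOutput`. The following is a minimal sketch of the new subclass contract; the module name `MyStepRNN` and the use of `getStepModule`/`getHiddenState` as helpers are illustrative assumptions, not part of this commit.

```lua
-- Illustrative sketch (not from this commit): a subclass now only computes one
-- time step and returns it; AbstractRecurrent:updateOutput stores the result in
-- self.outputs[self.step], increments self.step and resets backward-pass state.
local MyStepRNN, parent = torch.class('nn.MyStepRNN', 'nn.AbstractRecurrent')

function MyStepRNN:_updateOutput(input)
   -- h[t-1] (assumed helper, mirroring GRU/RecLSTM in this diff)
   local prevOutput = self:getHiddenState(self.step-1, input)

   local output
   if self.train ~= false then
      -- training: one shared clone of the step module per time step
      local stepmodule = self:getStepModule(self.step)
      output = stepmodule:updateOutput{input, prevOutput}
   else
      -- evaluation: reuse the single step module
      output = self.modules[1]:updateOutput{input, prevOutput}
   end
   return output -- no bookkeeping here anymore
end
```

This removes the near-identical tails that GRU, NormStabilizer, RecLSTM, Recurrence and Recursor each carried before the change.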
13 changes: 2 additions & 11 deletions GRU.lua
@@ -150,7 +150,7 @@ function GRU:setHiddenState(step, hiddenState)
end

------------------------- forward backward -----------------------------
function GRU:updateOutput(input)
function GRU:_updateOutput(input)
local prevOutput = self:getHiddenState(self.step-1, input)

-- output(t) = gru{input(t), output(t-1)}
@@ -164,16 +164,7 @@ function GRU:updateOutput(input)
output = self.modules[1]:updateOutput{input, prevOutput}
end

self.outputs[self.step] = output

self.output = output

self.step = self.step + 1
self.gradPrevOutput = nil
self.updateGradInputStep = nil
self.accGradParametersStep = nil
-- note that we don't return the cell, just the output
return self.output
return output
end


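From the caller's side nothing changes: an `AbstractRecurrent` module such as `nn.GRU` is still fed one time step per `forward` call, and the renamed `seqlen` only bounds how far back BPTT reaches. A small hedged sketch, with sizes and the `nn.GRU(inputSize, outputSize)` constructor assumed from the package:

```lua
require 'rnn'

-- one forward call per time step; internal state is kept between calls
local gru = nn.GRU(10, 10)   -- assumed signature: inputSize, outputSize
local output
for step = 1, 5 do
   output = gru:forward(torch.randn(3, 10))   -- batchsize x inputSize
end
gru:forget()   -- reset self.step and the stored states before the next sequence
```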
12 changes: 2 additions & 10 deletions NormStabilizer.lua
@@ -17,7 +17,7 @@ function NS:_accGradParameters(input, gradOutput, scale)
-- No parameters to update
end

function NS:updateOutput(input)
function NS:_updateOutput(input)
assert(input:dim() == 2)
local output
if self.train ~= false then
@@ -31,15 +31,7 @@ function NS:updateOutput(input)
output = self.modules[1]:updateOutput(input)
end

self.outputs[self.step] = output

self.output = output
self.step = self.step + 1
self.gradPrevOutput = nil
self.updateGradInputStep = nil
self.accGradParametersStep = nil

return self.output
return output
end

-- returns norm-stabilizer loss as defined in ref. A
12 changes: 2 additions & 10 deletions RecLSTM.lua
@@ -59,7 +59,7 @@ function RecLSTM:setHiddenState(step, hiddenState)
end

------------------------- forward backward -----------------------------
function RecLSTM:updateOutput(input)
function RecLSTM:_updateOutput(input)
local prevOutput, prevCell = unpack(self:getHiddenState(self.step-1, input))

-- output(t), cell(t) = lstm{input(t), output(t-1), cell(t-1)}
@@ -73,18 +73,10 @@ function RecLSTM:updateOutput(input)
output, cell = unpack(self.modules[1]:updateOutput{input, prevOutput, prevCell})
end

self.outputs[self.step] = output
self.cells[self.step] = cell

self.output = output
self.cell = cell

self.step = self.step + 1
self.gradPrevOutput = nil
self.updateGradInputStep = nil
self.accGradParametersStep = nil
-- note that we don't return the cell, just the output
return self.output
return output
end

function RecLSTM:getGradHiddenState(step)
13 changes: 2 additions & 11 deletions Recurrence.lua
@@ -102,7 +102,7 @@ function Recurrence:setHiddenState(step, hiddenState)
end
end

function Recurrence:updateOutput(input)
function Recurrence:_updateOutput(input)
-- output(t-1)
local prevOutput = self:getHiddenState(self.step-1, input)[1]

@@ -117,16 +117,7 @@ function Recurrence:updateOutput(input)
output = self.modules[1]:updateOutput{input, prevOutput}
end

self.outputs[self.step] = output

self.output = output

self.step = self.step + 1
self.gradPrevOutput = nil
self.updateGradInputStep = nil
self.accGradParametersStep = nil

return self.output
return output
end

function Recurrence:getGradHiddenState(step)
9 changes: 2 additions & 7 deletions Recursor.lua
@@ -6,7 +6,7 @@
------------------------------------------------------------------------
local Recursor, parent = torch.class('nn.Recursor', 'nn.AbstractRecurrent')

function Recursor:updateOutput(input)
function Recursor:_updateOutput(input)
local output
if self.train ~= false then -- if self.train or self.train == nil then
-- set/save the output states
@@ -17,12 +17,7 @@ function Recursor:updateOutput(input)
output = self.modules[1]:updateOutput(input)
end

self.outputs[self.step] = output
self.output = output
self.step = self.step + 1
self.updateGradInputStep = nil
self.accGradParametersStep = nil
return self.output
return output
end

function Recursor:_updateGradInput(input, gradOutput)
56 changes: 28 additions & 28 deletions Repeater.lua
@@ -1,20 +1,20 @@
------------------------------------------------------------------------
--[[ Repeater ]]--
-- Encapsulates an AbstractRecurrent instance (rnn) which is repeatedly
-- presented with the same input for rho time steps.
-- The output is a table of rho outputs of the rnn.
-- Encapsulates an AbstractRecurrent instance (rnn) which is repeatedly
-- presented with the same input for seqlen time steps.
-- The output is a table of seqlen outputs of the rnn.
------------------------------------------------------------------------
assert(not nn.Repeater, "update nnx package : luarocks install nnx")
local Repeater, parent = torch.class('nn.Repeater', 'nn.AbstractSequencer')

function Repeater:__init(module, rho)
function Repeater:__init(module, seqlen)
parent.__init(self)
assert(torch.type(rho) == 'number', "expecting number value for arg 2")
self.rho = rho
assert(torch.type(seqlen) == 'number', "expecting number value for arg 2")
self.seqlen = seqlen
self.module = (not torch.isTypeOf(module, 'nn.AbstractRecurrent')) and nn.Recursor(module) or module
self.module:maxBPTTstep(rho) -- hijack rho (max number of time-steps for backprop)

self.module:maxBPTTstep(seqlen) -- hijack seqlen (max number of time-steps for backprop)

self.modules[1] = self.module
self.output = {}
end
@@ -24,21 +24,21 @@ function Repeater:updateOutput(input)

self.module:forget()
-- TODO make copy outputs optional
for step=1,self.rho do
for step=1,self.seqlen do
self.output[step] = nn.rnn.recursiveCopy(self.output[step], self.module:updateOutput(input))
end
return self.output
end

function Repeater:updateGradInput(input, gradOutput)
assert(self.module.step - 1 == self.rho, "inconsistent rnn steps")
assert(self.module.step - 1 == self.seqlen, "inconsistent rnn steps")
assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
assert(#gradOutput == self.rho, "gradOutput should have rho elements")
assert(#gradOutput == self.seqlen, "gradOutput should have seqlen elements")

-- back-propagate through time (BPTT)
for step=self.rho,1,-1 do
for step=self.seqlen,1,-1 do
local gradInput = self.module:updateGradInput(input, gradOutput[step])
if step == self.rho then
if step == self.seqlen then
self.gradInput = nn.rnn.recursiveCopy(self.gradInput, gradInput)
else
nn.rnn.recursiveAdd(self.gradInput, gradInput)
@@ -49,29 +49,29 @@ function Repeater:updateGradInput(input, gradOutput)
end

function Repeater:accGradParameters(input, gradOutput, scale)
assert(self.module.step - 1 == self.rho, "inconsistent rnn steps")
assert(self.module.step - 1 == self.seqlen, "inconsistent rnn steps")
assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
assert(#gradOutput == self.rho, "gradOutput should have rho elements")
assert(#gradOutput == self.seqlen, "gradOutput should have seqlen elements")

-- back-propagate through time (BPTT)
for step=self.rho,1,-1 do
for step=self.seqlen,1,-1 do
self.module:accGradParameters(input, gradOutput[step], scale)
end

end

function Repeater:maxBPTTstep(rho)
self.rho = rho
self.module:maxBPTTstep(rho)
function Repeater:maxBPTTstep(seqlen)
self.seqlen = seqlen
self.module:maxBPTTstep(seqlen)
end

function Repeater:accUpdateGradParameters(input, gradOutput, lr)
assert(self.module.step - 1 == self.rho, "inconsistent rnn steps")
assert(self.module.step - 1 == self.seqlen, "inconsistent rnn steps")
assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
assert(#gradOutput == self.rho, "gradOutput should have rho elements")
assert(#gradOutput == self.seqlen, "gradOutput should have seqlen elements")

-- back-propagate through time (BPTT)
for step=self.rho,1,-1 do
for step=self.seqlen,1,-1 do
self.module:accUpdateGradParameters(input, gradOutput[step], lr)
end
end
@@ -84,7 +84,7 @@ function Repeater:__tostring__()
str = str .. tab .. ' V V V '.. line
str = str .. tab .. tostring(self.modules[1]):gsub(line, line .. tab) .. line
str = str .. tab .. ' V V V '.. line
str = str .. tab .. '[output(1),output(2),...,output('..self.rho..')]' .. line
str = str .. tab .. '[output(1),output(2),...,output('..self.seqlen..')]' .. line
str = str .. '}'
return str
end
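After the rename, `nn.Repeater(module, seqlen)` feeds the same input to the wrapped recurrent module for `seqlen` steps and returns a table of `seqlen` outputs; the backward pass expects a matching table of gradients. A minimal usage sketch, with sizes and the `nn.RecLSTM(inputsize, outputsize)` constructor assumed for illustration:

```lua
require 'rnn'

local seqlen, batchsize, hiddensize = 5, 3, 10

-- present the same input for seqlen time steps
local rep = nn.Repeater(nn.RecLSTM(hiddensize, hiddensize), seqlen)

local input = torch.randn(batchsize, hiddensize)
local outputs = rep:forward(input)        -- table of seqlen outputs
assert(#outputs == seqlen)

-- BPTT needs one gradOutput per step, in a table of seqlen elements
local gradOutputs = {}
for step = 1, seqlen do
   gradOutputs[step] = torch.randn(batchsize, hiddensize)
end
local gradInput = rep:backward(input, gradOutputs)
```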
2 changes: 1 addition & 1 deletion Sequencer.lua
@@ -46,7 +46,7 @@ function Sequencer:updateOutput(input)
nStep = #input
end

-- Note that the Sequencer hijacks the rho attribute of the rnn
-- Note that the Sequencer hijacks the seqlen attribute of the rnn
self.module:maxBPTTstep(nStep)
if self.train ~= false then
-- TRAINING
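The comment change above is the whole story for `Sequencer`: it takes `seqlen` from the length of the sequence it is given and passes it to the wrapped module via `maxBPTTstep`, so backpropagation always covers the full input. A short sketch, with sizes and the `nn.RecLSTM` constructor assumed:

```lua
require 'rnn'

local seqlen, batchsize, hiddensize = 7, 4, 10
local seq = nn.Sequencer(nn.RecLSTM(hiddensize, hiddensize))

-- tensor input of size seqlen x batchsize x inputsize (tables of steps also work)
local input = torch.randn(seqlen, batchsize, hiddensize)
local output = seq:forward(input)   -- seqlen x batchsize x hiddensize

-- Sequencer called maxBPTTstep(seqlen) internally, so the full sequence is backpropagated
seq:backward(input, torch.randn(seqlen, batchsize, hiddensize))
```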
6 changes: 3 additions & 3 deletions examples/README.md
@@ -1,15 +1,15 @@
# Examples

This directory contains various training scripts.
This directory contains various training scripts.

Torch blog posts
* The torch.ch blog contains detailed posts about the *rnn* package.
1. [recurrent-visual-attention.lua](recurrent-visual-attention.lua): training script used in [Recurrent Model for Visual Attention](http://torch.ch/blog/2015/09/21/rmva.html). Implements the REINFORCE learning rule to learn an attention mechanism for classifying MNIST digits, sometimes translated.
2. [noise-contrastive-esimate.lua](noise-contrastive-estimate.lua): one of two training scripts used in [Language modeling a billion words](http://torch.ch/blog/2016/07/25/nce.html). Single-GPU script for training recurrent language models on the Google billion words dataset.
3. [multigpu-nce-rnnlm.lua](multigpu-nce-rnnlm.lua) : 4-GPU version of `noise-contrastive-estimate.lua` for training larger multi-GPU models. Two of two training scripts used in the [Language modeling a billion words](http://torch.ch/blog/2016/07/25/nce.html).

Simple training scripts.
Simple training scripts.
* Showcases the fundamental principles of the package. In chronological order of introduction date.
1. [simple-recurrent-network.lua](simple-recurrent-network.lua): uses the `nn.Recurrent` module to instantiate a Simple RNN. Illustrates the first AbstractRecurrent instance in action. It has since been surpassed by the more flexible `nn.Recursor` and `nn.Recurrence`. The `nn.Recursor` class decorates any module to make it conform to the nn.AbstractRecurrent interface. The `nn.Recurrence` implements the recursive `h[t] <- forward(h[t-1], x[t])`. Together, `nn.Recursor` and `nn.Recurrence` can be used to implement a wide range of experimental recurrent architectures.
1. [simple-recurrent-network.lua](simple-recurrent-network.lua): uses the `nn.LookupRNN` module to instantiate a Simple RNN. Illustrates the first AbstractRecurrent instance in action. It has since been surpassed by the more flexible `nn.Recursor` and `nn.Recurrence`. The `nn.Recursor` class decorates any module to make it conform to the nn.AbstractRecurrent interface. The `nn.Recurrence` implements the recursive `h[t] <- forward(h[t-1], x[t])`. Together, `nn.Recursor` and `nn.Recurrence` can be used to implement a wide range of experimental recurrent architectures.
2. [simple-sequencer-network.lua](simple-sequencer-network.lua): uses the `nn.Sequencer` module to accept a batch of sequences as `input` of size `seqlen x batchsize x ...`. Both tables and tensors are accepted as input and produce the same type of output (table->table, tensor->tensor). The `Sequencer` class abstract away the implementation of back-propagation through time. It also provides a `remember(['neither','both'])` method for triggering what the `Sequencer` remembers between iterations (forward,backward,update).
3. [simple-recurrence-network.lua](simple-recurrence-network.lua): uses the `nn.Recurrence` module to define the h[t] <- sigmoid(h[t-1], x[t]) Simple RNN. Decorates it using `nn.Sequencer` so that an entire batch of sequences (`input`) can forward and backward propagated per update.
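As a companion to the entries above, the `nn.Recurrence` pattern described for simple-recurrence-network.lua can be sketched in a few lines. Everything below (sizes, the Linear-based step module) is illustrative rather than a copy of the example script:

```lua
require 'rnn'

local inputsize, hiddensize, batchsize, seqlen = 4, 7, 3, 5

-- one step of the recurrence: {x[t], h[t-1]} -> h[t] = sigmoid(Wx*x[t] + Wh*h[t-1])
local stepmodule = nn.Sequential()
   :add(nn.ParallelTable()
      :add(nn.Linear(inputsize, hiddensize))    -- applied to x[t]
      :add(nn.Linear(hiddensize, hiddensize)))  -- applied to h[t-1]
   :add(nn.CAddTable())
   :add(nn.Sigmoid())

-- nn.Recurrence(stepmodule, outputSize, nInputDim): feeds {input, prevOutput} to stepmodule
local rec = nn.Recurrence(stepmodule, hiddensize, 1)

-- Sequencer handles an entire sequence per call and the BPTT behind it
local seq = nn.Sequencer(rec)

local inputs = {}
for step = 1, seqlen do
   inputs[step] = torch.randn(batchsize, inputsize)
end
local outputs = seq:forward(inputs)   -- table of seqlen hidden states
```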
6 changes: 3 additions & 3 deletions examples/nested-recurrence-lstm.lua
@@ -4,7 +4,7 @@ require 'rnn'

-- hyper-parameters
batchSize = 8
rho = 5 -- sequence length
seqlen = 5 -- sequence length
hiddenSize = 7
nIndex = 10
lr = 0.1
@@ -49,10 +49,10 @@ offsets = torch.LongTensor(offsets)
-- training
local iteration = 1
while true do
-- 1. create a sequence of rho time-steps
-- 1. create a sequence of seqlen time-steps

local inputs, targets = {}, {}
for step=1,rho do
for step=1,seqlen do
-- a batch of inputs
inputs[step] = sequence:index(1, offsets)
-- increment indices
1 change: 0 additions & 1 deletion examples/simple-sequencer-network.lua
@@ -7,7 +7,6 @@ hiddenSize = 7
nIndex = 10
lr = 0.1


local rnn = nn.Sequential()
:add(nn.LookupRNN(nIndex, hiddenSize))
:add(nn.Linear(hiddenSize, nIndex))