Commit

LookupRNN, LinearRNN
Nicholas Leonard committed Apr 18, 2017
1 parent 2238f86 commit 31b0b89
Showing 6 changed files with 179 additions and 9 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -54,6 +54,8 @@ SET(luasrc
StepLSTM.lua
RecLSTM.lua
utils.lua
LinearRNN.lua
LookupRNN.lua
)

ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
17 changes: 17 additions & 0 deletions LinearRNN.lua
@@ -0,0 +1,17 @@

local LinearRNN, parent = torch.class("nn.LinearRNN", "nn.Recurrence")

function LinearRNN:__init(inputsize, outputsize, transfer)
   transfer = transfer or nn.Sigmoid()
   local stepmodule = nn.Sequential() -- input is {x[t], h[t-1]}
      :add(nn.JoinTable(1,1))
      :add(nn.Linear(inputsize+outputsize, outputsize))
      :add(transfer)
   parent.__init(self, stepmodule, outputsize, 1)
   self.inputsize = inputsize
   self.outputsize = outputsize
end

function LinearRNN:__tostring__()
   return torch.type(self) .. "(" .. self.inputsize .. ", " .. self.outputsize .. ")"
end
23 changes: 23 additions & 0 deletions LookupRNN.lua
@@ -0,0 +1,23 @@
local LookupRNN, parent = torch.class("nn.LookupRNN", "nn.Recurrence")

function LookupRNN:__init(nindex, outputsize, transfer, merge)
   transfer = transfer or nn.Sigmoid()
   merge = merge or nn.CAddTable()
   local stepmodule = nn.Sequential() -- input is {x[t], h[t-1]}
      :add(nn.ParallelTable()
         :add(nn.LookupTable(nindex, outputsize)) -- input layer
         :add(nn.Linear(outputsize, outputsize))) -- recurrent layer
      :add(merge)
      :add(transfer)
   parent.__init(self, stepmodule, outputsize, 0)
   self.nindex = nindex
   self.outputsize = outputsize
end

function LookupRNN:maskZero()
   error"Not Implemented"
end

function LookupRNN:__tostring__()
   return torch.type(self) .. "(" .. self.nindex .. ", " .. self.outputsize .. ")"
end
91 changes: 87 additions & 4 deletions README.md
@@ -1,13 +1,15 @@
# rnn: recurrent neural networks #
# Torch recurrent neural networks #

This is a Recurrent Neural Network library that extends Torch's nn.
This is a recurrent neural network (RNN) library that extends Torch's nn.
You can use it to build RNNs, LSTMs, GRUs, BRNNs, BLSTMs, and so forth.
This library includes documentation for the following objects:

Modules that consider successive calls to `forward` as different time-steps in a sequence :
* [AbstractRecurrent](#rnn.AbstractRecurrent) : an abstract class inherited by Recurrent and LSTM;
* [LSTM](#rnn.LSTM) : a vanilla Long-Short Term Memory module;
* [RecLSTM](#rnn.RecLSTM) : a faster LSTM (based on SeqLSTM) that doesn't use peephole connections;
* [LookupRNN](#rnn.LookupRNN): implements a simple RNN where the input layer is a `LookupTable`;
* [LinearRNN](#rnn.LinearRNN): implements a simple RNN where the input layer is a `Linear`;
* [LSTM](#rnn.LSTM) : a vanilla Long Short-Term Memory module (uses peephole connections);
* [RecLSTM](#rnn.RecLSTM) : a faster LSTM (based on `SeqLSTM`) that doesn't use peephole connections;
* [GRU](#rnn.GRU) : Gated Recurrent Units module;
* [Recursor](#rnn.Recursor) : decorates a module to make it conform to the [AbstractRecurrent](#rnn.AbstractRecurrent) interface;
* [Recurrence](#rnn.Recurrence) : decorates a module that outputs `output(t)` given `{input(t), output(t-1)}`;
@@ -219,6 +221,87 @@ only the previous step is remembered. This is very efficient memory-wise,
such that evaluation can be performed using potentially infinite-length
sequences.

<a name='rnn.LookupRNN'></a>
## LookupRNN

References :
* A. [Sutskever Thesis Sec. 2.5 and 2.8](http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf)
* B. [Mikolov Thesis Sec. 3.2 and 3.3](http://www.fit.vutbr.cz/~imikolov/rnnlm/thesis.pdf)
* C. [RNN and Backpropagation Guide](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.3.9311&rep=rep1&type=pdf)

This module subclasses the [Recurrence](#rnn.Recurrence) module to implement a simple RNN where the input layer is a
`LookupTable` module and the recurrent layer is a `Linear` module.
Note that to fully implement a Simple RNN, you need to add the output `Linear [+ SoftMax]` module after the `LookupRNN`.

The `nn.LookupRNN(nindex, outputsize, [transfer, merge])` constructor takes up to 4 arguments:
* `nindex` : the number of embeddings in the `LookupTable(nindex, outputsize)` (that is, the *input layer*);
* `outputsize` : the number of output units. This defines the size of the *recurrent layer*, which is a `Linear(outputsize, outputsize)`;
* `transfer` : a non-linear module used to process the output of the `merge` module. Defaults to `nn.Sigmoid()`;
* `merge` : a [table Module](https://github.com/torch/nn/blob/master/doc/table.md#table-layers) that merges the outputs of the `LookupTable` and `Linear` modules before they are forwarded through the `transfer` module. Defaults to `nn.CAddTable()`.

The `LookupRNN` is essentially the following:

```lua
nn.Recurrence(
   nn.Sequential() -- input is {x[t], h[t-1]}
      :add(nn.ParallelTable()
         :add(nn.LookupTable(nindex, outputsize)) -- input layer
         :add(nn.Linear(outputsize, outputsize))) -- recurrent layer
      :add(merge)
      :add(transfer)
   , outputsize, 0)
```
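
For instance, a rough sketch of stepping a `LookupRNN` one time-step at a time (the sizes and indices here are made-up values, not taken from the library's examples):

```lua
require 'rnn'

local nindex, outputsize = 100, 50
local rnn = nn.LookupRNN(nindex, outputsize)

-- each forward call is one time-step; the input is a LongTensor of indices
local x1 = torch.LongTensor{3, 42}   -- batch of 2 indices
local x2 = torch.LongTensor{7, 15}

local h1 = rnn:forward(x1)           -- 2 x 50 hidden state
local h2 = rnn:forward(x2)           -- depends on h1 via the recurrent Linear

rnn:forget()                         -- reset before starting a new sequence
```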

An RNN is used to process a sequence of inputs.
As an `AbstractRecurrent` subclass, the `LookupRNN` propagates each step of a sequence with its own call to `forward` (and `backward`).
Each call to `LookupRNN.forward` keeps a log of the intermediate states (the `input` and many `Module.outputs`)
and increments the `step` attribute by 1.
The `backward` method must be called in reverse order of the calls to `forward` in
order to backpropagate through time (BPTT). This reverse order is necessary
to return a `gradInput` for each call to `forward`.

The `step` attribute is only reset to 1 when the `forget` method is called,
at which point the module is ready to process the next sequence (or batch thereof).
Note that the longer the sequence, the more memory that will be required to store all the
`output` and `gradInput` states (one for each time step).

To use this module with batches, we suggest using sequences of the same length within a batch,
calling `updateParameters` every `seqlen` steps, and calling `forget` at the end of each
sequence, as in the training sketch below.
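
A minimal BPTT training sketch along those lines (the data, criterion, and learning rate are placeholder assumptions, not part of the library):

```lua
require 'rnn'

local nindex, outputsize = 100, 50
local seqlen, batchsize, lr = 5, 8, 0.1

local rnn = nn.LookupRNN(nindex, outputsize)
local criterion = nn.MSECriterion() -- placeholder criterion

-- placeholder data: seqlen steps of batchsize indices and dense targets
local inputs, targets = {}, {}
for t = 1, seqlen do
   inputs[t] = torch.LongTensor(batchsize):random(1, nindex)
   targets[t] = torch.rand(batchsize, outputsize)
end

rnn:zeroGradParameters()

-- 1. forward the sequence, one call per time-step
local outputs = {}
for t = 1, seqlen do
   outputs[t] = rnn:forward(inputs[t])
end

-- 2. backward in reverse order of the forward calls (BPTT)
for t = seqlen, 1, -1 do
   criterion:forward(outputs[t], targets[t])
   local gradOutput = criterion:backward(outputs[t], targets[t])
   rnn:backward(inputs[t], gradOutput)
end

-- 3. update parameters every seqlen steps and forget before the next sequence
rnn:updateParameters(lr)
rnn:forget()
```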

Note that calling the `evaluate` method turns off long-term memory;
the RNN will only remember the previous output. This allows the RNN
to handle long sequences without allocating any additional memory.
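
For example (a sketch; `rnn` and `inputs` are assumed to be a `LookupRNN` instance and a long list of step inputs):

```lua
rnn:evaluate()                  -- only the previous output is retained
for t = 1, #inputs do           -- the sequence can be arbitrarily long
   local output = rnn:forward(inputs[t])
   -- ... use output ...
end
rnn:training()                  -- switch back before resuming training
rnn:forget()
```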

For a simple, concise example of how to use this module, consult the
[simple-recurrent-network.lua](examples/simple-recurrent-network.lua)
training script.

<a name='rnn.LinearRNN'></a>
## LinearRNN

This module subclasses the [Recurrence](#rnn.Recurrence) module to implement a simple RNN where the input and the recurrent layer
are combined into a single `Linear` module.
Note that to fully implement a Simple RNN, you need to add the output `Linear [+ SoftMax]` module after the `LinearRNN`.

The `nn.LinearRNN(inputsize, outputsize, [transfer])` constructor takes up to 3 arguments:
* `inputsize` : the number of input units;
* `outputsize` : the number of output units;
* `transfer` : a non-linear module for activating the RNN. Defaults to `nn.Sigmoid()`.

The `LinearRNN` is essentially the following:

```lua
nn.Recurrence(
   nn.Sequential()
      :add(nn.JoinTable(1,1))
      :add(nn.Linear(inputsize+outputsize, outputsize))
      :add(transfer)
   , outputsize, 1)
```
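
A minimal sketch of running a `LinearRNN` over a whole sequence with a `Sequencer` (the sizes are arbitrary; this mirrors the unit test added in this commit):

```lua
require 'rnn'

local inputsize, outputsize = 10, 20
local seqlen, batchsize = 5, 3

-- Sequencer handles the per-step forward/backward calls internally
local lrnn = nn.Sequencer(nn.LinearRNN(inputsize, outputsize))

local input = torch.randn(seqlen, batchsize, inputsize)
local output = lrnn:forward(input) -- seqlen x batchsize x outputsize
```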

Combining the input and recurrent layers into a single `Linear` module makes this implementation quite efficient.

<a name='rnn.LSTM'></a>
## LSTM ##
References :
12 changes: 7 additions & 5 deletions init.lua
@@ -34,16 +34,15 @@ torch.include('rnn', 'CopyGrad.lua')
torch.include('rnn', 'VariableLength.lua')

-- recurrent modules
torch.include('rnn', 'LookupTableMaskZero.lua')
torch.include('rnn', 'MaskZero.lua')
torch.include('rnn', 'TrimZero.lua')
torch.include('rnn', 'AbstractRecurrent.lua')
torch.include('rnn', 'Recursor.lua')
torch.include('rnn', 'Recurrence.lua')
torch.include('rnn', 'LinearRNN.lua')
torch.include('rnn', 'LookupRNN.lua')
torch.include('rnn', 'LSTM.lua')
torch.include('rnn', 'RecLSTM.lua')
torch.include('rnn', 'GRU.lua')
torch.include('rnn', 'Mufuru.lua')
torch.include('rnn', 'Recursor.lua')
torch.include('rnn', 'Recurrence.lua')
torch.include('rnn', 'NormStabilizer.lua')

-- sequencer modules
@@ -63,6 +62,9 @@ torch.include('rnn', 'SeqBRNN.lua')

-- step modules
torch.include('rnn', 'StepLSTM.lua')
torch.include('rnn', 'LookupTableMaskZero.lua')
torch.include('rnn', 'MaskZero.lua')
torch.include('rnn', 'TrimZero.lua')

-- recurrent criterions:
torch.include('rnn', 'SequencerCriterion.lua')
43 changes: 43 additions & 0 deletions test/test.lua
@@ -7159,6 +7159,49 @@ function rnntest.RecLSTM_maskzero()
end
end

function rnntest.LinearRNN()
   local inputsize, outputsize = 3, 4
   local seqlen, batchsize = 5, 2

   local input = torch.randn(seqlen, batchsize, inputsize)
   local gradOutput = torch.randn(seqlen, batchsize, outputsize)

   local lrnn = nn.Sequencer(nn.LinearRNN(inputsize, outputsize))

   local output = lrnn:forward(input)
   lrnn:zeroGradParameters()
   local gradInput = lrnn:backward(input, gradOutput)

   mytester:assert(output:isSameSizeAs(gradOutput))
   mytester:assert(gradInput:isSameSizeAs(input))

   local params, gradParams = lrnn:parameters()
   for i=1,2 do
      mytester:assert(gradParams[i]:abs():mean() > 0.000001)
   end
end

function rnntest.LookupRNN()
   local nindex, outputsize = 3, 4
   local seqlen, batchsize = 5, 2

   local input = torch.LongTensor(seqlen, batchsize):random(1,nindex)
   local gradOutput = torch.randn(seqlen, batchsize, outputsize)

   local lrnn = nn.Sequencer(nn.LookupRNN(nindex, outputsize))

   local output = lrnn:forward(input)
   lrnn:zeroGradParameters()
   lrnn:backward(input, gradOutput)

   mytester:assert(output:isSameSizeAs(gradOutput))

   local params, gradParams = lrnn:parameters()
   for i=1,2 do
      mytester:assert(gradParams[i]:abs():mean() > 0.000001)
   end
end

function rnn.test(tests, benchmark_, exclude)
mytester = torch.Tester()
mytester:add(rnntest)