From e7c456be016c891488254687d1053959fdf7c6ef Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Wed, 3 May 2017 00:44:25 -0400 Subject: [PATCH] MaskZero v2 --- AbstractRecurrent.lua | 75 +- AbstractSequencerCriterion.lua | 5 +- CMakeLists.txt | 1 - LinearRNN.lua | 2 +- LookupRNN.lua | 8 +- LookupTableMaskZero.lua | 6 +- MaskZero.lua | 128 ++- MaskZeroCriterion.lua | 5 +- Module.lua | 24 +- README.md | 3 +- RecLSTM.lua | 15 +- SeqLSTM.lua | 90 +- Sequencer.lua | 2 +- StepLSTM.lua | 59 +- VariableLength.lua | 40 +- examples/encoder-decoder-coupling.lua | 12 + examples/multigpu-nce-rnnlm.lua | 9 +- examples/noise-contrastive-estimate.lua | 21 +- .../simple-bisequencer-network-variable.lua | 10 +- examples/twitter_sentiment_rnn.lua | 13 +- init.lua | 4 +- test/test.lua | 884 ++---------------- utils.lua | 29 +- 23 files changed, 368 insertions(+), 1077 deletions(-) diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua index f175f42..cb87ff4 100644 --- a/AbstractRecurrent.lua +++ b/AbstractRecurrent.lua @@ -32,24 +32,18 @@ function AbstractRecurrent:getStepModule(step) return stepmodule end -function AbstractRecurrent:maskZero(nInputDim) - local stepmodule = nn.MaskZero(self.modules[1], nInputDim, true) - self.sharedClones = {stepmodule} - self.modules[1] = stepmodule - return self -end - -function AbstractRecurrent:trimZero(nInputDim) - if torch.typename(self)=='nn.GRU' and self.p ~= 0 then - assert(self.mono, "TrimZero for BGRU needs `mono` option.") +function AbstractRecurrent:updateOutput(input) + if self.zeroMask then + -- where zeroMask = 1, the past is forgotten, + -- that is, the output/gradOutput is zero'd + local stepmodule = self:getStepModule(self.step) + self.zeroMaskStep = self.zeroMaskStep + 1 + if self.zeroMaskStep > self.zeroMask:size(1) then + error"AbstractRecurrent.updateOutput called more times than self.zeroMask:size(1)" + end + stepmodule:setZeroMask(self.zeroMask[self.zeroMaskStep]) end - local stepmodule = nn.TrimZero(self.modules[1], nInputDim, true) - self.sharedClones = {stepmodule} - self.modules[1] = stepmodule - return self -end -function AbstractRecurrent:updateOutput(input) -- feed-forward for one time-step self.output = self:_updateOutput(input) @@ -64,6 +58,10 @@ function AbstractRecurrent:updateOutput(input) end function AbstractRecurrent:updateGradInput(input, gradOutput) + if self.zeroMask and self.zeroMask:size(1) ~= self.zeroMaskStep then + error"AbstractRecurrent.updateOutput called less times than self.zeroMask:size(1)" + end + -- updateGradInput should be called in reverse order of time self.updateGradInputStep = self.updateGradInputStep or self.step @@ -86,7 +84,7 @@ function AbstractRecurrent:accGradParameters(input, gradOutput, scale) self.accGradParametersStep = self.accGradParametersStep - 1 end --- goes hand in hand with the next method : forget() +-- goes hand in hand with forget() -- this methods brings the oldest memory to the current step function AbstractRecurrent:recycle() self.nSharedClone = self.nSharedClone or _.size(self.sharedClones) @@ -113,6 +111,7 @@ function nn.AbstractRecurrent:clearState() clone:clearState() end self.modules[1]:clearState() + self.zeroMask = nil return parent.clearState(self) end @@ -189,6 +188,32 @@ function AbstractRecurrent:type(type, tensorcache) end) end +function AbstractRecurrent:maskZero(v1) + if not self.maskzero then + self.maskzero = true + local stepmodule = nn.MaskZero(self.modules[1], v1) + self.sharedClones = {stepmodule} + self.modules[1] = stepmodule + end + return self +end + +function 
AbstractRecurrent:setZeroMask(zeroMask) + if zeroMask == false then + self.zeroMask = false + for k,stepmodule in pairs(self.sharedClones) do + stepmodule:setZeroMask(zeroMask) + end + elseif torch.isTypeOf(self.modules[1], 'nn.AbstractRecurrent') then + self.modules[1]:setZeroMask(zeroMask) + else + assert(zeroMask:dim() >= 2, "Expecting dim >= 2 for zeroMask. For example, seqlen x batchsize") + -- reserve for later. Each step will be masked in updateOutput + self.zeroMask = zeroMask + self.zeroMaskStep = 0 + end +end + function AbstractRecurrent:training() return self:includingSharedClones(function() return parent.training(self) @@ -261,18 +286,12 @@ function AbstractRecurrent:setGradHiddenState(step, hiddenState) error"Not Implemented" end --- backwards compatibility -AbstractRecurrent.recursiveResizeAs = rnn.recursiveResizeAs -AbstractRecurrent.recursiveSet = rnn.recursiveSet -AbstractRecurrent.recursiveCopy = rnn.recursiveCopy -AbstractRecurrent.recursiveAdd = rnn.recursiveAdd -AbstractRecurrent.recursiveTensorEq = rnn.recursiveTensorEq -AbstractRecurrent.recursiveNormal = rnn.recursiveNormal - function AbstractRecurrent:__tostring__() - if self.inputSize and self.outputSize then - return self.__typename .. string.format("(%d -> %d)", self.inputSize, self.outputSize) + local inputsize = self.inputsize or self.inputSize + local outputsize = self.outputsize or self.outputSize + if inputsize and outputsize then + return self.__typename .. string.format("(%d -> %d)", inputsize, outputsize) else - return parent.__tostring__(self) + return self.__typename end end diff --git a/AbstractSequencerCriterion.lua b/AbstractSequencerCriterion.lua index 3e3fe23..ebac701 100644 --- a/AbstractSequencerCriterion.lua +++ b/AbstractSequencerCriterion.lua @@ -8,7 +8,6 @@ local AbstractSequencerCriterion, parent = torch.class('nn.AbstractSequencerCrit function AbstractSequencerCriterion:__init(criterion, sizeAverage) parent.__init(self) - self.criterion = criterion if torch.isTypeOf(criterion, 'nn.ModuleCriterion') then error(torch.type(self).." shouldn't decorate a ModuleCriterion. ".. "Instead, try the other way around : ".. @@ -20,14 +19,14 @@ function AbstractSequencerCriterion:__init(criterion, sizeAverage) else self.sizeAverage = false end - self.clones = {} + self.clones = {criterion} end function AbstractSequencerCriterion:getStepCriterion(step) assert(step, "expecting step at arg 1") local criterion = self.clones[step] if not criterion then - criterion = self.criterion:clone() + criterion = self.clones[1]:clone() self.clones[step] = criterion end return criterion diff --git a/CMakeLists.txt b/CMakeLists.txt index 2bd74a2..8f5adf9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -45,7 +45,6 @@ SET(luasrc SeqReverseSequence.lua Sequencer.lua SequencerCriterion.lua - TrimZero.lua ZeroGrad.lua test/bigtest.lua test/test.lua diff --git a/LinearRNN.lua b/LinearRNN.lua index 5fbf6ca..25f98d8 100644 --- a/LinearRNN.lua +++ b/LinearRNN.lua @@ -13,5 +13,5 @@ function LinearRNN:__init(inputsize, outputsize, transfer) end function LinearRNN:__tostring__() - return torch.type(self) .. "(" .. self.inputsize .. ", " .. self.outputsize ..")" + return torch.type(self) .. "(" .. self.inputsize .. " -> " .. 
self.outputsize ..")" end \ No newline at end of file diff --git a/LookupRNN.lua b/LookupRNN.lua index 018d78a..691b1cb 100644 --- a/LookupRNN.lua +++ b/LookupRNN.lua @@ -5,7 +5,7 @@ function LookupRNN:__init(nindex, outputsize, transfer, merge) merge = merge or nn.CAddTable() local stepmodule = nn.Sequential() -- input is {x[t], h[t-1]} :add(nn.ParallelTable() - :add(nn.LookupTable(nindex, outputsize)) -- input layer + :add(nn.LookupTableMaskZero(nindex, outputsize)) -- input layer :add(nn.Linear(outputsize, outputsize))) -- recurrent layer :add(merge) :add(transfer) @@ -14,10 +14,6 @@ function LookupRNN:__init(nindex, outputsize, transfer, merge) self.outputsize = outputsize end -function LookupRNN:maskZero() - error"Not Implemented" -end - function LookupRNN:__tostring__() - return torch.type(self) .. "(" .. self.nindex .. ", " .. self.outputsize ..")" + return torch.type(self) .. "(" .. self.nindex .. " -> " .. self.outputsize ..")" end \ No newline at end of file diff --git a/LookupTableMaskZero.lua b/LookupTableMaskZero.lua index cdafc40..9721175 100644 --- a/LookupTableMaskZero.lua +++ b/LookupTableMaskZero.lua @@ -5,17 +5,17 @@ function LookupTableMaskZero:__init(nIndex, nOutput) end function LookupTableMaskZero:updateOutput(input) - self.weight[1]:zero() + self.weight[1]:zero() if self.__input and (torch.type(self.__input) ~= torch.type(input)) then self.__input = nil -- fixes old casting bug end self.__input = self.__input or input.new() self.__input:resizeAs(input):add(input, 1) - return parent.updateOutput(self, self.__input) + return parent.updateOutput(self, self.__input) end function LookupTableMaskZero:accGradParameters(input, gradOutput, scale) - parent.accGradParameters(self, self.__input, gradOutput, scale) + parent.accGradParameters(self, self.__input, gradOutput, scale) end function LookupTableMaskZero:type(type, cache) diff --git a/MaskZero.lua b/MaskZero.lua index bc533ea..7fe11a6 100644 --- a/MaskZero.lua +++ b/MaskZero.lua @@ -1,99 +1,83 @@ ------------------------------------------------------------------------ --[[ MaskZero ]]-- --- Decorator that zeroes the output rows of the encapsulated module --- for commensurate input rows which are tensors of zeros +-- Zeroes the elements of the state tensors +-- (output/gradOutput/input/gradInput) of the encapsulated module +-- for commensurate elements that are 1 in self.zeroMask. +-- By default only output/gradOutput are zeroMasked. +-- self.zeroMask is set with setZeroMask(zeroMask). +-- Only works in batch-mode. +-- Note that when input/gradInput are zeroMasked, it is in-place ------------------------------------------------------------------------ local MaskZero, parent = torch.class("nn.MaskZero", "nn.Decorator") -function MaskZero:__init(module, nInputDim, silent) +function MaskZero:__init(module, v1, maskinput, maskoutput) parent.__init(self, module) assert(torch.isTypeOf(module, 'nn.Module')) - if torch.isTypeOf(module, 'nn.AbstractRecurrent') and not silent then - print("Warning : you are most likely using MaskZero the wrong way. 
" - .."You should probably use AbstractRecurrent:maskZero() so that " - .."it wraps the internal AbstractRecurrent.recurrentModule instead of " - .."wrapping the AbstractRecurrent module itself.") - end - assert(torch.type(nInputDim) == 'number', 'Expecting nInputDim number at arg 1') - self.nInputDim = nInputDim -end - -function MaskZero:recursiveGetFirst(input) - if torch.type(input) == 'table' then - return self:recursiveGetFirst(input[1]) - else - assert(torch.isTensor(input)) - return input - end -end - -function MaskZero:recursiveMask(output, input, mask) - if torch.type(input) == 'table' then - output = torch.type(output) == 'table' and output or {} - for k,v in ipairs(input) do - output[k] = self:recursiveMask(output[k], v, mask) - end - else - assert(torch.isTensor(input)) - output = torch.isTensor(output) and output or input.new() - - -- make sure mask has the same dimension as the input tensor - local inputSize = input:size():fill(1) - if self.batchmode then - inputSize[1] = input:size(1) - end - mask:resize(inputSize) - -- build mask - local zeroMask = mask:expandAs(input) - output:resizeAs(input):copy(input) - output:maskedFill(zeroMask, 0) - end - return output + self.maskinput = maskinput -- defaults to false + self.maskoutput = maskoutput == nil and true or maskoutput -- defaults to true + self.v2 = not v1 end function MaskZero:updateOutput(input) - -- recurrent module input is always the first one - local rmi = self:recursiveGetFirst(input):contiguous() - if rmi:dim() == self.nInputDim then - self.batchmode = false - rmi = rmi:view(-1) -- collapse dims - elseif rmi:dim() - 1 == self.nInputDim then - self.batchmode = true - rmi = rmi:view(rmi:size(1), -1) -- collapse non-batch dims - else - error("nInputDim error: "..rmi:dim()..", "..self.nInputDim) + if self.v2 then + assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false") + else -- backwards compat + self.zeroMask = nn.utils.getZeroMaskBatch(input, self.zeroMask) end - -- build mask - local vectorDim = rmi:dim() - self._zeroMask = self._zeroMask or rmi.new() - self._zeroMask:norm(rmi, 2, vectorDim) - self.zeroMask = self.zeroMask or ( - (torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() - or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor() - or torch.ByteTensor() - ) - self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) + if self.maskinput and self.zeroMask then + nn.utils.recursiveZeroMask(input, self.zeroMask) + end -- forward through decorated module local output = self.modules[1]:updateOutput(input) - self.output = self:recursiveMask(self.output, output, self.zeroMask) + if self.maskoutput and self.zeroMask then + self.output = nn.utils.recursiveCopy(self.output, output) + nn.utils.recursiveZeroMask(self.output, self.zeroMask) + else + self.output = output + end + return self.output end function MaskZero:updateGradInput(input, gradOutput) - -- zero gradOutputs before backpropagating through decorated module - self.gradOutput = self:recursiveMask(self.gradOutput, gradOutput, self.zeroMask) + assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false") + + if self.maskoutput and self.zeroMask then + self.gradOutput = nn.utils.recursiveCopy(self.gradOutput, gradOutput) + nn.utils.recursiveZeroMask(self.gradOutput, self.zeroMask) + gradOutput = self.gradOutput + end + + self.gradInput = self.modules[1]:updateGradInput(input, gradOutput) + + if self.maskinput and self.zeroMask then + nn.utils.recursiveZeroMask(self.gradInput, self.zeroMask) + end - self.gradInput 
= self.modules[1]:updateGradInput(input, self.gradOutput) return self.gradInput end -function MaskZero:type(type, ...) +function MaskZero:clearState() + self.output = nil + self.gradInput = nil self.zeroMask = nil - self._zeroMask = nil - self._maskbyte = nil - self._maskindices = nil + return self +end + +function MaskZero:type(type, ...) + self:clearState() return parent.type(self, type, ...) end + +function MaskZero:setZeroMask(zeroMask) + if zeroMask == false then + self.zeroMask = false + else + assert(torch.isByteTensor(zeroMask)) + assert(zeroMask:isContiguous()) + self.zeroMask = zeroMask + end +end diff --git a/MaskZeroCriterion.lua b/MaskZeroCriterion.lua index d866966..9ecf328 100644 --- a/MaskZeroCriterion.lua +++ b/MaskZeroCriterion.lua @@ -5,11 +5,11 @@ ------------------------------------------------------------------------ local MaskZeroCriterion, parent = torch.class("nn.MaskZeroCriterion", "nn.Criterion") -function MaskZeroCriterion:__init(criterion) +function MaskZeroCriterion:__init(criterion, v1) parent.__init(self) self.criterion = criterion assert(torch.isTypeOf(criterion, 'nn.Criterion')) - self.v2 = true + self.v2 = not v1 end function MaskZeroCriterion:updateOutput(input, target) @@ -39,7 +39,6 @@ function MaskZeroCriterion:updateOutput(input, target) -- indexSelect the input self.input = nn.utils.recursiveIndex(self.input, input, 1, self._indices) self.target = nn.utils.recursiveIndex(self.target, target, 1, self._indices) - -- forward through decorated criterion self.output = self.criterion:updateOutput(self.input, self.target) end diff --git a/Module.lua b/Module.lua index d25f215..58d43ff 100644 --- a/Module.lua +++ b/Module.lua @@ -1,3 +1,4 @@ +local _ = require 'moses' local Module = nn.Module -- You can use this to manually forget past memories in AbstractRecurrent instances @@ -20,12 +21,25 @@ function Module:remember(remember) return self end -function Module:stepClone(shareParams, shareGradParams) - return self:sharedClone(shareParams, shareGradParams, true) +function Module:maskZero(v1) + if self.modules then + for i, module in ipairs(self.modules) do + module:maskZero(v1) + end + end + return self end -function Module:backwardOnline() - print("Deprecated Jan 6, 2016. 
By default rnn now uses backwardOnline, so no need to call this method") +function Module:setZeroMask(zeroMask) + if self.modules then + for i, module in ipairs(self.modules) do + module:setZeroMask(zeroMask) + end + end +end + +function Module:stepClone(shareParams, shareGradParams) + return self:sharedClone(shareParams, shareGradParams, true) end -- calls setOutputStep on all component AbstractRecurrent modules @@ -78,7 +92,7 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone) shareGradParams = (shareGradParams == nil) and true or shareGradParams if stepClone and self.dpnn_stepclone then - -- this is for AbstractRecurrent modules (in rnn) + -- this is for AbstractRecurrent modules return self end diff --git a/README.md b/README.md index 9b103a6..971e584 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,8 @@ Modules that `forward` entire sequences through a decorated `AbstractRecurrent` Miscellaneous modules and criterions : * [MaskZero](#rnn.MaskZero) : zeroes the `output` and `gradOutput` rows of the decorated module for commensurate `input` rows which are tensors of zeros; - * [TrimZero](#rnn.TrimZero) : same behavior as `MaskZero`, but more efficient when `input` contains lots zero-masked rows; * [LookupTableMaskZero](#rnn.LookupTableMaskZero) : extends `nn.LookupTable` to support zero indexes for padding. Zero indexes are forwarded as tensors of zeros; - * [MaskZeroCriterion](#rnn.MaskZeroCriterion) : zeros the `gradInput` and `err` rows of the decorated criterion for commensurate `input` rows which are tensors of zeros; + * [MaskZeroCriterion](#rnn.MaskZeroCriterion) : zeros the `gradInput` and `loss` rows of the decorated criterion for commensurate `zeroMask` elements which are 1; * [SeqReverseSequence](#rnn.SeqReverseSequence) : reverses an input sequence on a specific dimension; * [VariableLength](#rnn.VariableLength): decorates a `Sequencer` to accept and produce a table of variable length inputs and outputs; diff --git a/RecLSTM.lua b/RecLSTM.lua index f2f05ca..cf1303c 100644 --- a/RecLSTM.lua +++ b/RecLSTM.lua @@ -16,17 +16,18 @@ function RecLSTM:__init(inputsize, hiddensize, outputsize) self.zeroCell = torch.Tensor() end -function RecLSTM:maskZero() +function RecLSTM:maskZero(v1) assert(torch.isTypeOf(self.modules[1], 'nn.StepLSTM')) for i,stepmodule in pairs(self.sharedClones) do - stepmodule:maskZero() + stepmodule:maskZero(v1) end - self.modules[1]:maskZero() + self.modules[1]:maskZero(v1) return self end ------------------------- forward backward ----------------------------- function RecLSTM:_updateOutput(input) + assert(input:dim() == 2, "RecLSTM expecting batchsize x inputsize tensor (Only supports batchmode)") local prevOutput, prevCell = unpack(self:getHiddenState(self.step-1, input)) -- output(t), cell(t) = lstm{input(t), output(t-1), cell(t-1)} @@ -185,4 +186,12 @@ function RecLSTM:setGradHiddenState(step, gradHiddenState) self.gradOutputs[step] = gradHiddenState[1] self.gradCells[step] = gradHiddenState[2] +end + +function RecLSTM:__tostring__() + if self.weightO then + return self.__typename .. string.format("(%d -> %d -> %d)", self.inputsize, self.hiddensize, self.outputsize) + else + return self.__typename .. 
string.format("(%d -> %d)", self.inputsize, self.outputsize) + end end \ No newline at end of file diff --git a/SeqLSTM.lua b/SeqLSTM.lua index ca91c14..f2f279b 100644 --- a/SeqLSTM.lua +++ b/SeqLSTM.lua @@ -76,6 +76,7 @@ function SeqLSTM:__init(inputsize, hiddensize, outputsize) -- set this to true for variable length sequences that seperate -- independent sequences with a step of zeros (a tensor of size D) self.maskzero = false + self.v2 = true end function SeqLSTM:reset(std) @@ -88,22 +89,15 @@ function SeqLSTM:reset(std) return self end --- unlike MaskZero, the mask is applied in-place -function SeqLSTM:recursiveMask(output, mask) - if torch.type(output) == 'table' then - for k,v in ipairs(output) do - self:recursiveMask(output[k], mask) +function SeqLSTM:zeroMaskState(state, step, cur_x) + if self.maskzero and self.zeroMask ~= false then + if self.v2 then + assert(self.zeroMask ~= nil, torch.type(self).." expecting zeroMask tensor or false") + nn.utils.recursiveZeroMask(state, self.zeroMask[step]) + else -- backwards compat + self.zeroMask = nn.utils.getZeroMaskBatch(cur_x, self.zeroMask) + nn.utils.recursiveZeroMask(state, self.zeroMask) end - else - assert(torch.isTensor(output)) - - -- make sure mask has the same dimension as the output tensor - local outputSize = output:size():fill(1) - outputSize[1] = output:size(1) - mask:resize(outputSize) - -- build mask - local zeroMask = mask:expandAs(output) - output:maskedFill(zeroMask, 0) end end @@ -123,11 +117,21 @@ function SeqLSTM:updateOutput(input) local seqlen, batchsize = input:size(1), input:size(2) local inputsize, hiddensize, outputsize = self.inputsize, self.hiddensize, self.outputsize + if self.maskzero and self.v2 and self.zeroMask ~= false then + if not torch.isTensor(self.zeroMask) then + error(torch.type(self).." expecting previous call to setZeroMask(zeroMask) with maskzero=true") + end + if (self.zeroMask:size(1) ~= seqlen) or (self.zeroMask:size(2) ~= batchsize) then + error(torch.type(self).." expecting zeroMask of size seqlen x batchsize, got " + ..self.zeroMask:size(1).." x "..self.zeroMask:size(2).." instead of "..seqlen.." x "..batchsize ) + end + end + -- remember previous state? 
local remember = self:hasMemory() local c0 = self.c0 - if c0:nElement() == 0 or not remember then + if (c0:nElement() ~= batchsize * hiddensize) or not remember then c0:resize(batchsize, hiddensize):zero() elseif remember then assert(self.cell:size(2) == batchsize, 'batch sizes must be constant to remember states') @@ -135,7 +139,7 @@ function SeqLSTM:updateOutput(input) end local h0 = self.h0 - if h0:nElement() == 0 or not remember then + if (h0:nElement() ~= batchsize * outputsize) or not remember then h0:resize(batchsize, outputsize):zero() elseif remember then assert(self.output:size(2) == batchsize, 'batch sizes must be the same to remember states') @@ -166,16 +170,7 @@ function SeqLSTM:updateOutput(input) next_h, next_c) end - if self.maskzero then - -- build mask from input - local vectorDim = cur_x:dim() - self._zeroMask = self._zeroMask or cur_x.new() - self._zeroMask:norm(cur_x, 2, vectorDim) - self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) - self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) - -- zero masked output - self:recursiveMask({next_h, next_c, cur_gates}, self.zeroMask) - end + self:zeroMaskState({next_h, next_c, gates}, t, cur_x) prev_h, prev_c = next_h, next_c end @@ -208,16 +203,7 @@ function SeqLSTM:updateOutput(input) next_h:mm(self._hidden[t], self.weightO) end - if self.maskzero then - -- build mask from input - local vectorDim = cur_x:dim() - self._zeroMask = self._zeroMask or cur_x.new() - self._zeroMask:norm(cur_x, 2, vectorDim) - self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) - self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) - -- zero masked output - self:recursiveMask({next_h, next_c, cur_gates}, self.zeroMask) - end + self:zeroMaskState({next_h, next_c, cur_gates}, t, cur_x) prev_h, prev_c = next_h, next_c end @@ -253,17 +239,7 @@ function SeqLSTM:backward(input, gradOutput, scale) end grad_next_h:add(gradOutput[t]) - if self.maskzero then --and not self.weightO then - -- we only do this for sub-classes (LSTM doesn't need it) - -- build mask from input - local vectorDim = cur_x:dim() - self._zeroMask = self._zeroMask or cur_x.new() - self._zeroMask:norm(cur_x, 2, vectorDim) - self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) - self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) - -- zero masked gradOutput - self:recursiveMask({grad_next_h, grad_next_c}, self.zeroMask) - end + self:zeroMaskState({grad_next_h, grad_next_c}, t, cur_x) if self.weightO then self.grad_hidden = self.grad_hidden or cur_x.new() @@ -302,17 +278,7 @@ function SeqLSTM:backward(input, gradOutput, scale) local cur_x = input[t] - if self.maskzero and torch.type(self) ~= 'nn.SeqLSTM' then - -- we only do this for sub-classes (LSTM doesn't need it) - -- build mask from input - local vectorDim = cur_x:dim() - self._zeroMask = self._zeroMask or cur_x.new() - self._zeroMask:norm(cur_x, 2, vectorDim) - self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) - self._zeroMask.eq(self.zeroMask, self._zeroMask, 0) - -- zero masked gradOutput - self:recursiveMask(grad_next_h, self.zeroMask) - end + self:zeroMaskState({grad_next_h, grad_next_c}, t, cur_x) if self.weightO then -- LSTMP self.buffer3:resizeAs(grad_next_h):copy(grad_next_h) @@ -442,10 +408,8 @@ function SeqLSTM:evaluate() 
assert(self.train == false) end -function SeqLSTM:maskZero() - self.maskzero = true - return self -end +SeqLSTM.maskZero = nn.StepLSTM.maskZero +SeqLSTM.setZeroMask = nn.MaskZero.setZeroMask function SeqLSTM:parameters() return {self.weight, self.bias, self.weightO}, {self.gradWeight, self.gradBias, self.gradWeightO} diff --git a/Sequencer.lua b/Sequencer.lua index 3a267a9..da483f5 100644 --- a/Sequencer.lua +++ b/Sequencer.lua @@ -86,7 +86,7 @@ function Sequencer:updateOutput(input) end else for step=1,nStep do - self.tableoutput[step] = nn.rnn.recursiveCopy( + self.tableoutput[step] = nn.utils.recursiveCopy( self.tableoutput[step] or table.remove(self._output, 1), self.module:updateOutput(input[step]) ) diff --git a/StepLSTM.lua b/StepLSTM.lua index 4e4906e..7a3edc2 100644 --- a/StepLSTM.lua +++ b/StepLSTM.lua @@ -34,6 +34,7 @@ function StepLSTM:__init(inputsize, hiddensize, outputsize) -- set this to true for variable length sequences that seperate -- independent sequences with a step of zeros (a tensor of size D) self.maskzero = false + self.v2 = true end function StepLSTM:reset(std) @@ -46,26 +47,6 @@ function StepLSTM:reset(std) return self end --- unlike MaskZero, the mask is applied in-place -function StepLSTM:recursiveMask(output, mask) - if torch.type(output) == 'table' then - for k,v in ipairs(output) do - self:recursiveMask(output[k], mask) - end - else - assert(torch.isTensor(output)) - - -- make sure mask has the same dimension as the output tensor - local outputSize = output:size():fill(1) - outputSize[1] = output:size(1) - mask:resize(outputSize) - -- build mask - local zeroMask = mask:expandAs(output) - output:maskedFill(zeroMask, 0) - end -end - - function StepLSTM:updateOutput(input) self.recompute_backward = true local cur_x, prev_h, prev_c = input[1], input[2], input[3] @@ -123,14 +104,14 @@ function StepLSTM:updateOutput(input) end end - if self.maskzero then - -- build mask from input - local zero_mask = torch.getBuffer('StepLSTM', 'zero_mask', cur_x) - zero_mask:norm(cur_x, 2, 2) - self.zeroMask = self.zeroMask or ((torch.type(cur_x) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor()) - zero_mask.eq(self.zeroMask, zero_mask, 0) - -- zero masked output - self:recursiveMask({next_h, next_c, self.gates}, self.zeroMask) + if self.maskzero and self.zeroMask ~= false then + if self.v2 then + assert(self.zeroMask ~= nil, torch.type(self).." 
expecting zeroMask tensor or false") + else -- backwards compat + self.zeroMask = nn.utils.getZeroMaskBatch(cur_x, self.zeroMask) + end + -- zero masked outputs and gates + nn.utils.recursiveZeroMask({next_h, next_c, self.gates}, self.zeroMask) end return self.output @@ -148,9 +129,9 @@ function StepLSTM:backward(input, gradOutput, scale) local grad_gates = torch.getBuffer('StepLSTM', 'grad_gates', self.gates) -- batchsize x 4*outputsize local grad_gates_sum = torch.getBuffer('StepLSTM', 'grad_gates_sum', self.gates) -- 1 x 4*outputsize - if self.maskzero then + if self.maskzero and self.zeroMask ~= false then -- zero masked gradOutput - self:recursiveMask({grad_next_h, grad_next_c}, self.zeroMask) + nn.utils.recursiveZeroMask({grad_next_h, grad_next_c}, self.zeroMask) end if cur_x.nn.StepLSTM_backward and not self.forceLua then @@ -255,11 +236,6 @@ function StepLSTM:clearState() self.output[1]:set(); self.output[2]:set() self.gradInput[1]:set(); self.gradInput[2]:set(); self.gradInput[3]:set() - - self.zeroMask = nil - self._zeroMask = nil - self._maskbyte = nil - self._maskindices = nil end function StepLSTM:type(type, ...) @@ -271,11 +247,22 @@ function StepLSTM:parameters() return {self.weight, self.bias, self.weightO}, {self.gradWeight, self.gradBias, self.gradWeightO} end -function StepLSTM:maskZero() +function StepLSTM:maskZero(v1) self.maskzero = true + self.v2 = not v1 return self end +StepLSTM.setZeroMask = nn.MaskZero.setZeroMask + +function StepLSTM:__tostring__() + if self.weightO then + return self.__typename .. string.format("(%d -> %d -> %d)", self.inputsize, self.hiddensize, self.outputsize) + else + return self.__typename .. string.format("(%d -> %d)", self.inputsize, self.outputsize) + end +end + -- for sharedClone local _ = require 'moses' local params = _.clone(parent.dpnn_parameters) diff --git a/VariableLength.lua b/VariableLength.lua index 9e00ba8..0261b9b 100644 --- a/VariableLength.lua +++ b/VariableLength.lua @@ -1,41 +1,12 @@ local VariableLength, parent = torch.class("nn.VariableLength", "nn.Decorator") --- make sure your module has been set-up for zero-masking (that is, module:maskZero()) function VariableLength:__init(module, lastOnly) - parent.__init(self, module) + parent.__init(self, assert(module:maskZero())) -- only extract the last element of each sequence self.lastOnly = lastOnly -- defaults to false self.gradInput = {} end --- recursively masks input (inplace) -function VariableLength.recursiveMask(input, mask) - if torch.type(input) == 'table' then - for k,v in ipairs(input) do - self.recursiveMask(v, mask) - end - else - assert(torch.isTensor(input)) - - -- make sure mask has the same dimension as the input tensor - assert(mask:dim() == 2, "Expecting batchsize x seqlen mask tensor") - -- expand mask to input (if necessary) - local zeroMask - if input:dim() == 2 then - zeroMask = mask - elseif input:dim() > 2 then - local inputSize = input:size():fill(1) - inputSize[1] = input:size(1) - inputSize[2] = input:size(2) - zeroMask = mask:view(inputSize):expandAs(input) - else - error"Expecting batchsize x seqlen [ x ...] 
input tensor" - end - -- zero-mask input in between sequences - input:maskedFill(zeroMask, 0) - end -end - function VariableLength:updateOutput(input) -- input is a table of batchSize tensors assert(torch.type(input) == 'table') @@ -51,7 +22,8 @@ function VariableLength:updateOutput(input) self.indexes, self.mappedLengths = self._input.nn.VariableLength_FromSamples(input, self._input, self._mask) -- zero-mask the _input where mask is 1 - self.recursiveMask(self._input, self._mask) + nn.utils.recursiveZeroMask(self._input, self._mask) + self.modules[1]:setZeroMask(self._mask) -- feedforward the zero-mask format through the decorated module local output = self.modules[1]:updateOutput(self._input) @@ -82,7 +54,7 @@ function VariableLength:updateGradInput(input, gradOutput) end -- zero-mask the _gradOutput where mask is 1 - self.recursiveMask(self._gradOutput, self._mask) + nn.utils.recursiveZeroMask(self._gradOutput, self._mask) -- updateGradInput decorated module local gradInput = self.modules[1]:updateGradInput(self._input, self._gradOutput) @@ -107,4 +79,8 @@ function VariableLength:clearState() self._gradOutput = nil self._input = nil return parent.clearState(self) +end + +function VariableLength:setZeroMask() + error"Not Supported" end \ No newline at end of file diff --git a/examples/encoder-decoder-coupling.lua b/examples/encoder-decoder-coupling.lua index 71bc4e7..18b2517 100644 --- a/examples/encoder-decoder-coupling.lua +++ b/examples/encoder-decoder-coupling.lua @@ -65,10 +65,22 @@ local decInSeq = torch.Tensor({{6,1,2,3,4,0,0,0},{6,5,4,3,2,1,0,0}}):t() -- Label '7' represents the end of sentence (EOS). local decOutSeq = torch.Tensor({{1,2,3,4,7,0,0,0},{5,4,3,2,1,7,0,0}}):t() +-- the zeroMasks are used for zeroing intermediate RNN states where the zeroMask = 1 +-- randomly set the zeroMasks from the input sequence or explicitly +local encZeroMask = math.random() < 0.5 and nn.utils.getZeroMaskSequence(encInSeq) -- auto zeroMask from input sequence + or torch.ByteTensor({{1,1,1,1,0,0,0},{1,1,1,0,0,0,0}}):t():contiguous() -- explicit zeroMask +local decZeroMask = math.random() < 0.5 and nn.utils.getZeroMaskSequence(decInSeq) + or torch.ByteTensor({{0,0,0,0,0,1,1,1},{0,0,0,0,0,0,1,1}}):t():contiguous() + for i=1,opt.niter do enc:zeroGradParameters() dec:zeroGradParameters() + -- zero-masking + enc:setZeroMask(encZeroMask) + dec:setZeroMask(decZeroMask) + criterion:setZeroMask(decZeroMask) + -- Forward pass local encOut = enc:forward(encInSeq) forwardConnect(enc, dec) diff --git a/examples/multigpu-nce-rnnlm.lua b/examples/multigpu-nce-rnnlm.lua index 4dc4106..e18f399 100644 --- a/examples/multigpu-nce-rnnlm.lua +++ b/examples/multigpu-nce-rnnlm.lua @@ -112,10 +112,9 @@ if not lm then for i,hiddensize in ipairs(opt.hiddensize) do -- this is a faster version of nn.Sequencer(nn.RecLSTM(inpusize, hiddensize)) local rnn = opt.projsize < 1 and nn.SeqLSTM(inputsize, hiddensize) - or nn.SeqLSTMP(inputsize, opt.projsize, hiddensize) -- LSTM with a projection layer - rnn.maskzero = true + or nn.SeqLSTM(inputsize, opt.projsize, hiddensize) -- LSTM with a projection layer local device = i <= #opt.hiddensize/2 and 1 or 2 - lm:add(nn.GPU(rnn, device):cuda()) + lm:add(nn.GPU(rnn:maskZero(true), device):cuda()) if opt.dropout > 0 then lm:add(nn.GPU(nn.Dropout(opt.dropout), device):cuda()) end @@ -146,7 +145,7 @@ if not lm then :add(nn.ZipTable()) -- {{x1,x2,...}, {t1,t2,...}} -> {{x1,t1},{x2,t2},...} -- encapsulate stepmodule into a Sequencer - local masked = nn.MaskZero(ncemodule, 1):cuda() + local 
masked = nn.MaskZero(ncemodule, true):cuda() lm:add(nn.GPU(nn.Sequencer(masked), 3, opt.device):cuda()) -- remember previous state between batches @@ -165,7 +164,7 @@ end if not (criterion and targetmodule) then --[[ loss function ]]-- - local crit = nn.MaskZeroCriterion(nn.NCECriterion(), 0) + local crit = nn.MaskZeroCriterion(nn.NCECriterion(), true) -- target is also seqlen x batchsize. targetmodule = nn.Sequential() diff --git a/examples/noise-contrastive-estimate.lua b/examples/noise-contrastive-estimate.lua index 2319594..1066388 100644 --- a/examples/noise-contrastive-estimate.lua +++ b/examples/noise-contrastive-estimate.lua @@ -105,11 +105,10 @@ if not lm then -- rnn layers local inputsize = opt.inputsize for i,hiddensize in ipairs(opt.hiddensize) do - -- this is a faster version of nn.Sequencer(nn.RecSTM(inpusize, hiddensize)) - local rnn = opt.projsize < 1 and nn.SeqLSTM(inputsize, hiddensize) - or nn.SeqLSTMP(inputsize, opt.projsize, hiddensize) -- LSTM with a projection layer - rnn.maskzero = true - lm:add(rnn) + -- this is a faster version of nn.Sequencer(nn.RecLSTM(inpusize, hiddensize)) + local rnn = opt.projsize < 1 and nn.SeqLSTM(inputsize, hiddensize) + or nn.SeqLSTM(inputsize, opt.projsize, hiddensize) -- LSTM with a projection layer + lm:add(rnn:maskZero()) if opt.dropout > 0 then lm:add(nn.Dropout(opt.dropout)) end @@ -130,7 +129,7 @@ if not lm then :add(nn.ZipTable()) -- {{x1,x2,...}, {t1,t2,...}} -> {{x1,t1},{x2,t2},...} -- encapsulate stepmodule into a Sequencer - lm:add(nn.Sequencer(nn.MaskZero(ncemodule, 1))) + lm:add(nn.Sequencer(nn.MaskZero(ncemodule))) -- remember previous state between batches lm:remember() @@ -155,7 +154,7 @@ end if not (criterion and targetmodule) then --[[ loss function ]]-- - local crit = nn.MaskZeroCriterion(nn.NCECriterion(), 0) + local crit = nn.MaskZeroCriterion(nn.NCECriterion()) -- target is also seqlen x batchsize. 
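-- [Editor's sketch, not part of the patch, of the two MaskZero modes used in these
--  examples: passing true keeps the v1 behaviour, where the mask is inferred from
--  all-zero input rows; the v2 default expects an explicit ByteTensor via setZeroMask.
--  The decorated nn.Linear and the sizes are placeholders for brevity.]
local batchsize, insize, outsize = 4, 3, 2
local input = torch.randn(batchsize, insize)
input[2]:zero()                                       -- second sample is padding

-- v1 (backwards compatible): all-zero input rows are detected automatically
local mz1 = nn.MaskZero(nn.Linear(insize, outsize), true)
local out1 = mz1:forward(input)                       -- out1[2] is all zeros

-- v2 (default): the caller supplies the mask up front
local mz2 = nn.MaskZero(nn.Linear(insize, outsize))
local zeroMask = torch.ByteTensor(batchsize):zero()
zeroMask[2] = 1                                       -- 1 means "mask this row"
mz2:setZeroMask(zeroMask)
local out2 = mz2:forward(input)                       -- out2[2] is all zeros as well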
targetmodule = nn.SplitTable(1) @@ -199,6 +198,7 @@ if not xplog then end local ntrial = 0 +local zeroMask local epoch = xplog.epoch+1 opt.lr = opt.lr or opt.startlr opt.trainsize = opt.trainsize == -1 and trainset:size() or opt.trainsize @@ -215,6 +215,9 @@ while opt.maxepoch <= 0 or epoch <= opt.maxepoch do for i, inputs, targets in trainset:subiter(opt.seqlen, opt.trainsize) do targets = targetmodule:forward(targets) inputs = {inputs, targets} + -- zero-mask + zeroMask = nn.utils.getZeroMaskSequence(inputs[1], zeroMask) + nn.utils.setZeroMask({lm, criterion}, zeroMask, opt.cuda) -- forward local outputs = lm:forward(inputs) local err = criterion:forward(outputs, targets) @@ -273,6 +276,10 @@ while opt.maxepoch <= 0 or epoch <= opt.maxepoch do local sumErr = 0 for i, inputs, targets in validset:subiter(opt.seqlen, opt.validsize) do targets = targetmodule:forward(targets) + -- zero-mask + zeroMask = nn.utils.getZeroMaskSequence(inputs, zeroMask) + nn.utils.setZeroMask({lm, criterion}, zeroMask, opt.cuda) + -- forward local outputs = lm:forward{inputs, targets} local err = criterion:forward(outputs, targets) sumErr = sumErr + err diff --git a/examples/simple-bisequencer-network-variable.lua b/examples/simple-bisequencer-network-variable.lua index 389ad27..d0adb28 100644 --- a/examples/simple-bisequencer-network-variable.lua +++ b/examples/simple-bisequencer-network-variable.lua @@ -17,7 +17,7 @@ local sharedLookupTable = nn.LookupTableMaskZero(nIndex, hiddenSize) -- forward rnn local fwd = nn.Sequential() :add(sharedLookupTable) - :add(nn.RecLSTM(hiddenSize, hiddenSize):maskZero(1)) + :add(nn.RecLSTM(hiddenSize, hiddenSize):maskZero(true)) -- internally, rnn will be wrapped into a Recursor to make it an AbstractRecurrent instance. fwdSeq = nn.Sequencer(fwd) @@ -25,7 +25,7 @@ fwdSeq = nn.Sequencer(fwd) -- backward rnn (will be applied in reverse order of input sequence) local bwd = nn.Sequential() :add(sharedLookupTable:sharedClone()) - :add(nn.RecLSTM(hiddenSize, hiddenSize):maskZero(1)) + :add(nn.RecLSTM(hiddenSize, hiddenSize):maskZero(true)) bwdSeq = nn.Sequencer(bwd) -- merges the output of one time-step of fwd and bwd rnns. @@ -44,14 +44,14 @@ local brnn = nn.Sequential() local rnn = nn.Sequential() :add(brnn) - :add(nn.Sequencer(nn.MaskZero(nn.Linear(hiddenSize*2, nIndex), 1))) -- times two due to JoinTable - :add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), 1))) + :add(nn.Sequencer(nn.MaskZero(nn.Linear(hiddenSize*2, nIndex), true))) -- times two due to JoinTable + :add(nn.Sequencer(nn.MaskZero(nn.LogSoftMax(), true))) print(rnn) -- build criterion -criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(), 1)) +criterion = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.ClassNLLCriterion(), true)) -- build dummy dataset (task is to predict next item, given previous) sequence_ = torch.LongTensor():range(1,10) -- 1,2,3,4,5,6,7,8,9,10 diff --git a/examples/twitter_sentiment_rnn.lua b/examples/twitter_sentiment_rnn.lua index b06bbe7..7ea5a16 100644 --- a/examples/twitter_sentiment_rnn.lua +++ b/examples/twitter_sentiment_rnn.lua @@ -6,7 +6,6 @@ require 'paths' require 'optim' require 'rnn' -require 'nngraph' require 'cutorch' require 'cunn' local dl = require 'dataload' @@ -78,7 +77,7 @@ trainSet, validSet, testSet = dl.loadSentiment140(datapath, minFreq, -- Model if not opt.loadModel then print("Building model") - modelPath = paths.concat(savepath, + modelPath = paths.concat(savepath, "Sentiment140_model_" .. dl.uniqueid() .. 
".net") lookupDim = tonumber(opt.lookupDim) lookupDropout = tonumber(opt.lookupDropout) @@ -98,13 +97,11 @@ if not opt.loadModel then -- Recurrent layers local inputSize = lookupDim for i, hiddenSize in ipairs(hiddenSizes) do - local rnn = nn.SeqLSTM(inputSize, hiddenSize) - rnn.maskzero = true - model:add(rnn) + model:add(nn.SeqLSTM(inputSize, hiddenSize):maskZero(true)) if dropouts[i] ~= 0 and dropouts[i] ~= nil then model:add(nn.Dropout(dropouts[i])) end - inputSize = hiddenSize + inputSize = hiddenSize end model:add(nn.Select(1, -1)) @@ -119,7 +116,7 @@ end print("Model path: " .. modelPath) collectgarbage() --- Criterion +-- Criterion criterion = nn.ClassNLLCriterion() -- Training @@ -275,7 +272,7 @@ for epoch=1, epochs do else earlyStopCount = earlyStopCount + 1 end - + if earlyStopCount >= earlyStopThresh then print("Early stopping at epoch: " .. tostring(epoch)) break diff --git a/init.lua b/init.lua index 0fa12d7..989ddbd 100644 --- a/init.lua +++ b/init.lua @@ -53,6 +53,8 @@ require('rnn.BatchNormalization') -- modules +require('rnn.LookupTableMaskZero') +require('rnn.MaskZero') require('rnn.PrintSize') require('rnn.Convert') require('rnn.Constant') @@ -81,8 +83,6 @@ require('rnn.SAdd') require('rnn.CopyGrad') require('rnn.VariableLength') require('rnn.StepLSTM') -require('rnn.LookupTableMaskZero') -require('rnn.MaskZero') require('rnn.SpatialBinaryConvolution') require('rnn.SimpleColorTransform') require('rnn.PCAColorTransform') diff --git a/test/test.lua b/test/test.lua index b36ecab..a526ebb 100644 --- a/test/test.lua +++ b/test/test.lua @@ -83,12 +83,12 @@ function rnntest.RecLSTM_main() "LSTM gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2)) end - gradParams = lstm.recursiveCopy(nil, gradParams) + gradParams = nn.utils.recursiveCopy(nil, gradParams) gradInput = gradInput:clone() mytester:assert(lstm.zeroOutput:sum() == 0, "zeroOutput error") mytester:assert(lstm.zeroCell:sum() == 0, "zeroCell error") lstm:forget() - output = lstm.recursiveCopy(nil, output) + output = nn.utils.recursiveCopy(nil, output) local output3 = {} lstm:zeroGradParameters() for step=1,nStep do @@ -193,11 +193,11 @@ function rnntest.GRU() "gru gradParam "..i.." error "..tostring(gradParam).." "..tostring(gradParam2)) end - gradParams = gru.recursiveCopy(nil, gradParams) + gradParams = nn.utils.recursiveCopy(nil, gradParams) gradInput = gradInput:clone() mytester:assert(gru.zeroTensor:sum() == 0, "zeroTensor error") gru:forget() - output = gru.recursiveCopy(nil, output) + output = nn.utils.recursiveCopy(nil, output) local output3 = {} gru:zeroGradParameters() for step=1,nStep do @@ -1670,347 +1670,6 @@ function rnntest.RepeaterCriterion() end -function rnntest.LSTM_nn_vs_nngraph() - local model = {} - -- match the successful https://github.com/wojzaremba/lstm - -- We want to make sure our LSTM matches theirs. - -- Also, the ugliest unit test you have every seen. - -- Resolved 2-3 annoying bugs with it. 
- local success = pcall(function() require 'nngraph' end) - if not success then - return - end - - local vocabSize = 100 - local inputSize = 30 - local batchSize = 4 - local nLayer = 2 - local dropout = 0 - local nStep = 10 - local lr = 1 - - -- build nn equivalent of nngraph model - local model2 = nn.Sequential() - local container2 = nn.Container() - container2:add(nn.LookupTable(vocabSize, inputSize)) - model2:add(container2:get(1)) - local dropout2 = nn.Dropout(dropout) - model2:add(dropout2) - local seq21 = nn.SplitTable(1,2) - model2:add(seq21) - container2:add(nn.FastLSTM(inputSize, inputSize)) - local seq22 = nn.Sequencer(container2:get(2)) - model2:add(seq22) - local seq24 = nn.Sequencer(nn.Dropout(0)) - model2:add(seq24) - container2:add(nn.FastLSTM(inputSize, inputSize)) - local seq23 = nn.Sequencer(container2:get(3)) - model2:add(seq23) - local seq25 = nn.Sequencer(nn.Dropout(0)) - model2:add(seq25) - container2:add(nn.Linear(inputSize, vocabSize)) - local mlp = nn.Sequential():add(container2:get(4)):add(nn.LogSoftMax()) -- test double encapsulation - model2:add(nn.Sequencer(mlp)) - - local criterion2 = nn.ModuleCriterion(nn.SequencerCriterion(nn.ClassNLLCriterion()), nil, nn.SplitTable(1,1)) - - - -- nngraph model - local container = nn.Container() - local lstmId = 1 - local function lstm(x, prev_c, prev_h) - -- Calculate all four gates in one go - local i2h = nn.Linear(inputSize, 4*inputSize) - local dummy = nn.Container() - dummy:add(i2h) - i2h = i2h(x) - local h2h = nn.LinearNoBias(inputSize, 4*inputSize) - dummy:add(h2h) - h2h = h2h(prev_h) - container:add(dummy) - local gates = nn.CAddTable()({i2h, h2h}) - - -- Reshape to (batch_size, n_gates, hid_size) - -- Then slize the n_gates dimension, i.e dimension 2 - local reshaped_gates = nn.Reshape(4,inputSize)(gates) - local sliced_gates = nn.SplitTable(2)(reshaped_gates) - - -- Use select gate to fetch each gate and apply nonlinearity - local in_gate = nn.Sigmoid()(nn.SelectTable(1)(sliced_gates)) - local in_transform = nn.Tanh()(nn.SelectTable(2)(sliced_gates)) - local forget_gate = nn.Sigmoid()(nn.SelectTable(3)(sliced_gates)) - local out_gate = nn.Sigmoid()(nn.SelectTable(4)(sliced_gates)) - - local next_c = nn.CAddTable()({ - nn.CMulTable()({forget_gate, prev_c}), - nn.CMulTable()({in_gate, in_transform}) - }) - local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) - lstmId = lstmId + 1 - return next_c, next_h - end - local function create_network() - local x = nn.Identity()() - local y = nn.Identity()() - local prev_s = nn.Identity()() - local lookup = nn.LookupTable(vocabSize, inputSize) - container:add(lookup) - local identity = nn.Identity() - lookup = identity(lookup(x)) - local i = {[0] = lookup} - local next_s = {} - local split = {prev_s:split(2 * nLayer)} - for layer_idx = 1, nLayer do - local prev_c = split[2 * layer_idx - 1] - local prev_h = split[2 * layer_idx] - local dropped = nn.Dropout(dropout)(i[layer_idx - 1]) - local next_c, next_h = lstm(dropped, prev_c, prev_h) - table.insert(next_s, next_c) - table.insert(next_s, next_h) - i[layer_idx] = next_h - end - - local h2y = nn.Linear(inputSize, vocabSize) - container:add(h2y) - local dropped = nn.Dropout(dropout)(i[nLayer]) - local pred = nn.LogSoftMax()(h2y(dropped)) - local err = nn.ClassNLLCriterion()({pred, y}) - local module = nn.gModule({x, y, prev_s}, {err, nn.Identity()(next_s)}) - module:getParameters():uniform(-0.1, 0.1) - module._lookup = identity - return module - end - - local function g_cloneManyTimes(net, T) - local clones = {} - local 
params, gradParams = net:parameters() - local mem = torch.MemoryFile("w"):binary() - assert(net._lookup) - mem:writeObject(net) - for t = 1, T do - local reader = torch.MemoryFile(mem:storage(), "r"):binary() - local clone = reader:readObject() - reader:close() - local cloneParams, cloneGradParams = clone:parameters() - for i = 1, #params do - cloneParams[i]:set(params[i]) - cloneGradParams[i]:set(gradParams[i]) - end - clones[t] = clone - collectgarbage() - end - mem:close() - return clones - end - - local model = {} - local paramx, paramdx - local core_network = create_network() - - -- sync nn with nngraph model - local params, gradParams = container:getParameters() - local params2, gradParams2 = container2:getParameters() - params2:copy(params) - container:zeroGradParameters() - container2:zeroGradParameters() - paramx, paramdx = core_network:getParameters() - - model.s = {} - model.ds = {} - model.start_s = {} - for j = 0, nStep do - model.s[j] = {} - for d = 1, 2 * nLayer do - model.s[j][d] = torch.zeros(batchSize, inputSize) - end - end - for d = 1, 2 * nLayer do - model.start_s[d] = torch.zeros(batchSize, inputSize) - model.ds[d] = torch.zeros(batchSize, inputSize) - end - model.core_network = core_network - model.rnns = g_cloneManyTimes(core_network, nStep) - model.norm_dw = 0 - model.err = torch.zeros(nStep) - - -- more functions for nngraph baseline - local function g_replace_table(to, from) - assert(#to == #from) - for i = 1, #to do - to[i]:copy(from[i]) - end - end - - local function reset_ds() - for d = 1, #model.ds do - model.ds[d]:zero() - end - end - - local function reset_state(state) - state.pos = 1 - if model ~= nil and model.start_s ~= nil then - for d = 1, 2 * nLayer do - model.start_s[d]:zero() - end - end - end - - local function fp(state) - g_replace_table(model.s[0], model.start_s) - if state.pos + nStep > state.data:size(1) then - error"Not Supposed to happen in this unit test" - end - for i = 1, nStep do - local x = state.data[state.pos] - local y = state.data[state.pos + 1] - local s = model.s[i - 1] - model.err[i], model.s[i] = unpack(model.rnns[i]:forward({x, y, s})) - state.pos = state.pos + 1 - end - g_replace_table(model.start_s, model.s[nStep]) - return model.err:mean() - end - - model.dss = {} - local function bp(state) - paramdx:zero() - local __, gradParams = core_network:parameters() - for i=1,#gradParams do - mytester:assert(gradParams[i]:sum() == 0) - end - reset_ds() -- backward of last step in each sequence is zero - for i = nStep, 1, -1 do - state.pos = state.pos - 1 - local x = state.data[state.pos] - local y = state.data[state.pos + 1] - local s = model.s[i - 1] - local derr = torch.ones(1) - local tmp = model.rnns[i]:backward({x, y, s}, {derr, model.ds,})[3] - model.dss[i-1] = tmp - g_replace_table(model.ds, tmp) - end - state.pos = state.pos + nStep - paramx:add(-lr, paramdx) - end - - -- inputs and targets (for nngraph implementation) - local inputs = torch.Tensor(nStep*10, batchSize):random(1,vocabSize) - - -- is everything aligned between models? 
- local params_, gradParams_ = container:parameters() - local params2_, gradParams2_ = container2:parameters() - - for i=1,#params_ do - mytester:assertTensorEq(params_[i], params2_[i], 0.00001, "nn vs nngraph unaligned params err "..i) - mytester:assertTensorEq(gradParams_[i], gradParams2_[i], 0.00001, "nn vs nngraph unaligned gradParams err "..i) - end - - -- forward - local state = {pos=1,data=inputs} - local err = fp(state) - - local inputs2 = inputs:narrow(1,1,nStep):transpose(1,2) - local targets2 = inputs:narrow(1,2,nStep):transpose(1,2) - local outputs2 = model2:forward(inputs2) - local err2 = criterion2:forward(outputs2, targets2) - mytester:assert(math.abs(err - err2/nStep) < 0.0001, "nn vs nngraph err error") - - -- backward/update - bp(state) - - local gradOutputs2 = criterion2:backward(outputs2, targets2) - model2:backward(inputs2, gradOutputs2) - model2:updateParameters(lr) - model2:zeroGradParameters() - - for i=1,#gradParams2_ do - mytester:assert(gradParams2_[i]:sum() == 0) - end - - for i=1,#params_ do - mytester:assertTensorEq(params_[i], params2_[i], 0.00001, "nn vs nngraph params err "..i) - end - - for i=1,nStep do - mytester:assertTensorEq(model.rnns[i]._lookup.output, dropout2.output:select(2,i), 0.0000001) - mytester:assertTensorEq(model.rnns[i]._lookup.gradInput, dropout2.gradInput:select(2,i), 0.0000001) - end - - -- next_c, next_h, next_c... - for i=nStep-1,2,-1 do - mytester:assertTensorEq(model.dss[i][1], container2:get(2).gradCells[i], 0.0000001, "gradCells1 err "..i) - mytester:assertTensorEq(model.dss[i][2], container2:get(2)._gradOutputs[i] - seq24.gradInput[i], 0.0000001, "gradOutputs1 err "..i) - mytester:assertTensorEq(model.dss[i][3], container2:get(3).gradCells[i], 0.0000001, "gradCells2 err "..i) - mytester:assertTensorEq(model.dss[i][4], container2:get(3)._gradOutputs[i] - seq25.gradInput[i], 0.0000001, "gradOutputs2 err "..i) - end - - for i=1,#params2_ do - params2_[i]:copy(params_[i]) - gradParams_[i]:copy(gradParams2_[i]) - end - - local gradInputClone = dropout2.gradInput:select(2,1):clone() - - local start_s = _.map(model.start_s, function(k,v) return v:clone() end) - mytester:assertTensorEq(start_s[1], container2:get(2).cells[nStep], 0.0000001) - mytester:assertTensorEq(start_s[2], container2:get(2).outputs[nStep], 0.0000001) - mytester:assertTensorEq(start_s[3], container2:get(3).cells[nStep], 0.0000001) - mytester:assertTensorEq(start_s[4], container2:get(3).outputs[nStep], 0.0000001) - - -- and do it again - -- forward - -- reset_state(state) - - local inputs2 = inputs:narrow(1,nStep+1,nStep):transpose(1,2) - local targets2 = inputs:narrow(1,nStep+2,nStep):transpose(1,2) - model2:remember() - local outputs2 = model2:forward(inputs2) - - local inputsClone = seq21.output[nStep]:clone() - local outputsClone = container2:get(2).outputs[nStep]:clone() - local cellsClone = container2:get(2).cells[nStep]:clone() - local err2 = criterion2:forward(outputs2, targets2) - local state = {pos=nStep+1,data=inputs} - local err = fp(state) - mytester:assert(math.abs(err2/nStep - err) < 0.00001, "nn vs nngraph err error") - -- backward/update - bp(state) - - local gradOutputs2 = criterion2:backward(outputs2, targets2) - model2:backward(inputs2, gradOutputs2) - - mytester:assertTensorEq(start_s[1], container2:get(2).cells[nStep], 0.0000001) - mytester:assertTensorEq(start_s[2], container2:get(2).outputs[nStep], 0.0000001) - mytester:assertTensorEq(start_s[3], container2:get(3).cells[nStep], 0.0000001) - mytester:assertTensorEq(start_s[4], 
container2:get(3).outputs[nStep], 0.0000001) - - model2:updateParameters(lr) - - mytester:assertTensorEq(inputsClone, seq21.output[nStep], 0.000001) - mytester:assertTensorEq(outputsClone, container2:get(2).outputs[nStep], 0.000001) - mytester:assertTensorEq(cellsClone, container2:get(2).cells[nStep], 0.000001) - - -- next_c, next_h, next_c... - for i=nStep-1,2,-1 do - mytester:assertTensorEq(model.dss[i][1], container2:get(2).gradCells[i+nStep], 0.0000001, "gradCells1 err "..i) - mytester:assertTensorEq(model.dss[i][2], container2:get(2)._gradOutputs[i+nStep] - seq24.gradInput[i], 0.0000001, "gradOutputs1 err "..i) - mytester:assertTensorEq(model.dss[i][3], container2:get(3).gradCells[i+nStep], 0.0000001, "gradCells2 err "..i) - mytester:assertTensorEq(model.dss[i][4], container2:get(3)._gradOutputs[i+nStep] - seq25.gradInput[i], 0.0000001, "gradOutputs2 err "..i) - end - - mytester:assertTensorNe(gradInputClone, dropout2.gradInput:select(2,1), 0.0000001, "lookup table gradInput1 err") - - for i=1,nStep do - mytester:assertTensorEq(model.rnns[i]._lookup.output, dropout2.output:select(2,i), 0.0000001, "lookup table output err "..i) - mytester:assertTensorEq(model.rnns[i]._lookup.gradInput, dropout2.gradInput:select(2,i), 0.0000001, "lookup table gradInput err "..i) - end - - for i=1,#params_ do - mytester:assertTensorEq(params_[i], params2_[i], 0.00001, "nn vs nngraph second update params err "..i) - end -end - function rnntest.LSTM_char_rnn() -- benchmark our LSTM against char-rnn's LSTM if not benchmark then @@ -2018,7 +1677,6 @@ function rnntest.LSTM_char_rnn() end local success = pcall(function() - require 'nngraph' require 'cunn' end) if not success then @@ -2243,7 +1901,7 @@ function rnntest.LSTM_char_rnn() local inputSize = input_size for L=1,n do - seq:add(nn.FastLSTM(inputSize, rnn_size)) + seq:add(nn.RecLSTM(inputSize, rnn_size)) inputSize = rnn_size end @@ -2257,9 +1915,7 @@ function rnntest.LSTM_char_rnn() return lstm end - nn.FastLSTM.usenngraph = true local lstm2 = makeRnnLSTM(input_size, rnn_size, n_layer, gpu) - nn.FastLSTM.usenngraph = false local function trainRnn(x, y, fwdOnly) local outputs = lstm2:forward(x) @@ -2302,21 +1958,17 @@ function rnntest.LSTM_char_rnn() print("runtime: char, rnn, char/rnn", chartime, rnntime, chartime/rnntime) -- on NVIDIA Titan Black : - -- with FastLSTM.usenngraph = false : - -- setuptime : char, rnn, char/rnn 1.5070691108704 1.1547832489014 1.3050666541138 - -- runtime: char, rnn, char/rnn 1.0558769702911 1.7060630321503 0.61889681119246 - -- with FastLSTM.usenngraph = true : -- setuptime : char, rnn, char/rnn 1.5920469760895 2.4352579116821 0.65374881586558 -- runtime: char, rnn, char/rnn 1.0614919662476 1.124755859375 0.94375322199913 end -function rnntest.LSTM_checkgrad() +function rnntest.RecLSTM_checkgrad() if not pcall(function() require 'optim' end) then return end local hiddenSize = 2 local nIndex = 2 - local r = nn.LSTM(hiddenSize, hiddenSize) + local r = nn.RecLSTM(hiddenSize, hiddenSize) local rnn = nn.Sequential() rnn:add(r) @@ -2325,8 +1977,8 @@ function rnntest.LSTM_checkgrad() rnn = nn.Recursor(rnn) local criterion = nn.ClassNLLCriterion() - local inputs = torch.randn(4, 2) - local targets = torch.Tensor{1, 2, 1, 2}:resize(4, 1) + local inputs = torch.randn(3, 4, 2) + local targets = torch.Tensor(3,4):random(1,2) local parameters, grads = rnn:getParameters() function f(x) @@ -2349,7 +2001,7 @@ function rnntest.LSTM_checkgrad() end local err = optim.checkgrad(f, parameters:clone()) - mytester:assert(err < 0.0001, "LSTM 
optim.checkgrad error") + mytester:assert(err < 0.0001, "RecLSTM optim.checkgrad error") end function rnntest.Recursor() @@ -2609,11 +2261,6 @@ function rnntest.MaskZero_main() -- Note we use lstmModule input signature and firstElement to prevent duplicate code for name, recurrent in pairs(recurrents) do -- test encapsulated module first - -- non batch - local i = torch.rand(10) - local e = nn.Sigmoid():forward(i) - local o = firstElement(recurrent:forward({i, torch.zeros(10), torch.zeros(10)})) - mytester:assertlt(torch.norm(o - e), precision, 'mock ' .. name .. ' failed for non batch') -- batch local i = torch.rand(5, 10) local e = nn.Sigmoid():forward(i) @@ -2621,16 +2268,8 @@ function rnntest.MaskZero_main() mytester:assertlt(torch.norm(o - e), precision, 'mock ' .. name .. ' module failed for batch') -- test mask zero module now - local module = nn.MaskZero(recurrent, 1) - -- non batch forward - local i = torch.rand(10) - local e = firstElement(recurrent:forward({i, torch.rand(10), torch.rand(10)})) - local o = firstElement(module:forward({i, torch.rand(10), torch.rand(10)})) - mytester:assertgt(torch.norm(i - o), precision, 'error on non batch forward for ' .. name) - mytester:assertlt(torch.norm(e - o), precision, 'error on non batch forward for ' .. name) - local i = torch.zeros(10) - local o = firstElement(module:forward({i, torch.rand(10), torch.rand(10)})) - mytester:assertlt(torch.norm(i - o), precision, 'error on non batch forward for ' .. name) + local module = nn.MaskZero(recurrent) + module:setZeroMask(false) -- batch forward local i = torch.rand(5, 10) local e = firstElement(recurrent:forward({i, torch.rand(5, 10), torch.rand(5, 10)})) @@ -2638,9 +2277,11 @@ function rnntest.MaskZero_main() mytester:assertgt(torch.norm(i - o), precision, 'error on batch forward for ' .. name) mytester:assertlt(torch.norm(e - o), precision, 'error on batch forward for ' .. name) local i = torch.zeros(5, 10) + module:setZeroMask(torch.ByteTensor(5):fill(1)) local o = firstElement(module:forward({i, torch.rand(5, 10), torch.rand(5, 10)})) mytester:assertlt(torch.norm(i - o), precision, 'error on batch forward for ' .. name) local i = torch.Tensor({{0, 0, 0}, {1, 2, 5}}) + module:setZeroMask(torch.ByteTensor({1,0})) -- clone r because it will be update by module:forward call local r = firstElement(recurrent:forward({i, torch.rand(2, 3), torch.rand(2, 3)})):clone() local o = firstElement(module:forward({i, torch.rand(2, 3), torch.rand(2, 3)})) @@ -2655,22 +2296,25 @@ function rnntest.MaskZero_main() -- Use a SplitTable and SelectTable to adapt module local module = nn.Sequential() module:add(nn.SplitTable(1)) - module:add(nn.MaskZero(recurrent, 1)) + module:add(nn.MaskZero(recurrent)) if name == 'lstm' then module:add(nn.SelectTable(1)) end local input = torch.rand(name == 'lstm' and 3 or 2, 10) + module:setZeroMask(false) local err = jac.testJacobian(module, input) mytester:assertlt(err, precision, 'error on state for ' .. name) -- IO - local ferr,berr = jac.testIO(module,input) + function module.clearState(self) return self end + local ferr,berr = jac.testIO(module, input) mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err for ' .. name) mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err for ' .. 
name) -- batch -- rebuild module to avoid correlated tests local module = nn.Sequential() module:add(nn.SplitTable(1)) - module:add(nn.MaskZero(recurrent, 1)) + module:add(nn.MaskZero(recurrent)) if name == 'lstm' then module:add(nn.SelectTable(1)) end + module:setZeroMask(false) local input = torch.rand(name == 'lstm' and 3 or 2, 5, 10) local err = jac.testJacobian(module,input) @@ -2727,62 +2371,12 @@ function rnntest.AbstractRecurrent_maskZero() input:select(2,4):copy(sequence) - for i=1,4 do - table.insert(inputs, input[i]) - end - - - local function testmask(rnn) - local seq = nn.Sequencer(rnn:maskZero(1)) - - local outputs = seq:forward(inputs) - - mytester:assert(math.abs(outputs[1]:narrow(1,1,3):sum()) < 0.0000001, torch.type(rnn).." mask zero 1 err") - mytester:assert(math.abs(outputs[2]:narrow(1,1,2):sum()) < 0.0000001, torch.type(rnn).." mask zero 2 err") - mytester:assert(math.abs(outputs[3]:narrow(1,1,1):sum()) < 0.0000001, torch.type(rnn).." mask zero 3 err") - - mytester:assertTensorEq(outputs[1][4], outputs[2][3], 0.0000001, torch.type(rnn).." mask zero err") - mytester:assertTensorEq(outputs[1][4], outputs[3][2], 0.0000001, torch.type(rnn).." mask zero err") - mytester:assertTensorEq(outputs[1][4], outputs[4][1], 0.0000001, torch.type(rnn).." mask zero err") - - mytester:assertTensorEq(outputs[2][4], outputs[3][3], 0.0000001, torch.type(rnn).." mask zero err") - mytester:assertTensorEq(outputs[2][4], outputs[4][2], 0.0000001, torch.type(rnn).." mask zero err") - - mytester:assertTensorEq(outputs[3][4], outputs[4][3], 0.0000001, torch.type(rnn).." mask zero err") - end - - local rm = nn.Sequential() - :add(nn.ParallelTable() - :add(nn.Linear(10,10)) - :add(nn.Linear(10,10))) - :add(nn.CAddTable()) - :add(nn.Sigmoid()) - - testmask(nn.Recurrence(rm, 10, 1)) - testmask(nn.LSTM(10,10)) - testmask(nn.GRU(10,10)) -end - -function rnntest.AbstractRecurrent_trimZero() - local inputs = {} - - local input = torch.zeros(4,4,10) - local sequence = torch.randn(4,10) - input:select(2,1):select(1,4):copy(sequence[1]) - input:select(2,2):narrow(1,3,2):copy(sequence:narrow(1,1,2)) - input:select(2,3):narrow(1,2,3):copy(sequence:narrow(1,1,3)) - input:select(2,4):copy(sequence) - - - for i=1,4 do - table.insert(inputs, input[i]) - end - - local function testmask(rnn) - local seq = nn.Sequencer(rnn:trimZero(1)) + local seq = nn.Sequencer(rnn:maskZero()) - local outputs = seq:forward(inputs) + local zeroMask = nn.utils.getZeroMaskSequence(input) + seq:setZeroMask(zeroMask) + local outputs = seq:forward(input) mytester:assert(math.abs(outputs[1]:narrow(1,1,3):sum()) < 0.0000001, torch.type(rnn).." mask zero 1 err") mytester:assert(math.abs(outputs[2]:narrow(1,1,2):sum()) < 0.0000001, torch.type(rnn).." mask zero 2 err") @@ -2798,16 +2392,9 @@ function rnntest.AbstractRecurrent_trimZero() mytester:assertTensorEq(outputs[3][4], outputs[4][3], 0.0000001, torch.type(rnn).." 
mask zero err") end - local rm = nn.Sequential() - :add(nn.ParallelTable() - :add(nn.Linear(10,10)) - :add(nn.Linear(10,10))) - :add(nn.CAddTable()) - :add(nn.Sigmoid()) - - testmask(nn.Recurrence(rm, 10, 1)) - testmask(nn.LSTM(10,10)) - testmask(nn.GRU(10,10)) + testmask(nn.LinearRNN(10, 10)) + testmask(nn.RecLSTM(10, 10)) + testmask(nn.GRU(10, 10)) end local function forwardbackward(module, criterion, input, expected) @@ -2969,8 +2556,7 @@ function rnntest.MaskZero_where() local batchsize = 4 local seqlen = 7 - local rnn = nn.FastLSTM(hiddensize, hiddensize) - rnn:maskZero(1) + local rnn = nn.LinearRNN(hiddensize, hiddensize):maskZero() rnn = nn.Sequencer(rnn) -- is there any difference between start and end padding? @@ -2978,14 +2564,13 @@ function rnntest.MaskZero_where() local inputs, gradOutputs = {}, {} for i=1,seqlen do - if i==1 then - inputs[i] = torch.zeros(batchsize, hiddensize) - else - inputs[i] = torch.randn(batchsize, hiddensize) - end + inputs[i] = torch.randn(batchsize, hiddensize) gradOutputs[i] = torch.randn(batchsize, hiddensize) end + local zeroMask = torch.ByteTensor(seqlen, batchsize):zero() + zeroMask[1]:fill(1) + rnn:setZeroMask(zeroMask) local outputs = rnn:forward(inputs) rnn:zeroGradParameters() local gradInputs = rnn:backward(inputs, gradOutputs) @@ -3005,6 +2590,8 @@ function rnntest.MaskZero_where() inputs[seqlen] = table.remove(inputs, 1) gradOutputs[seqlen] = table.remove(gradOutputs, 1) + zeroMask:zero()[seqlen]:fill(1) + rnn:setZeroMask(zeroMask) rnn:forget() local outputs = rnn:forward(inputs) rnn:zeroGradParameters() @@ -3024,14 +2611,13 @@ function rnntest.MaskZero_where() local inputs, gradOutputs = {}, {} for i=1,seqlen do - if i==4 then - inputs[i] = torch.zeros(batchsize, hiddensize) - else - inputs[i] = torch.randn(batchsize, hiddensize) - end + inputs[i] = torch.randn(batchsize, hiddensize) gradOutputs[i] = torch.randn(batchsize, hiddensize) end + local zeroMask = torch.ByteTensor(seqlen, batchsize):zero() + zeroMask[4]:fill(1) + rnn:setZeroMask(zeroMask) rnn:forget() local rnn2 = rnn:clone() @@ -3043,6 +2629,8 @@ function rnntest.MaskZero_where() local inputs1 = _.first(inputs, 3) local gradOutputs1 = _.first(gradOutputs, 3) + local zeroMask = torch.ByteTensor(3, batchsize):zero() + rnn2:setZeroMask(zeroMask) local outputs1 = rnn2:forward(inputs1) rnn2:zeroGradParameters() local gradInputs1 = rnn2:backward(inputs1, gradOutputs1) @@ -3057,6 +2645,7 @@ function rnntest.MaskZero_where() local inputs2 = _.last(inputs, 3) local gradOutputs2 = _.last(gradOutputs, 3) + rnn2:setZeroMask(zeroMask) local outputs2 = rnn2:forward(inputs2) local gradInputs2 = rnn2:backward(inputs2, gradOutputs2) @@ -3118,52 +2707,6 @@ function rnntest.issue129() mytester:assertTensorEq(output, output2, 0.0002, "issue 129 err") end -function rnntest.issue170() - torch.manualSeed(123) - - local rnn_size = 8 - local vocabSize = 7 - local word_embedding_size = 10 - local rnn_dropout = .00000001 -- dropout ignores manualSeed() - local mono = true - local x = torch.Tensor{{1,2,3},{0,4,5},{0,0,7}} - local t = torch.ceil(torch.rand(x:size(2))) - local rnns = {'GRU'} - local methods = {'maskZero', 'trimZero'} - local loss = torch.Tensor(#rnns, #methods,1) - - for ir,arch in pairs(rnns) do - local rnn = nn[arch](word_embedding_size, rnn_size, nil, rnn_dropout, true) - local model = nn.Sequential() - :add(nn.LookupTableMaskZero(vocabSize, word_embedding_size)) - :add(nn.SplitTable(2)) - :add(nn.Sequencer(rnn)) - :add(nn.SelectTable(-1)) - :add(nn.Linear(rnn_size, 10)) - 
model:getParameters():uniform(-0.1, 0.1) - local criterion = nn.CrossEntropyCriterion() - local models = {} - for j=1,#methods do - table.insert(models, model:clone()) - end - for im,method in pairs(methods) do - model = models[im] - local rnn = model:get(3).module - rnn[method](rnn, 1) - for i=1,loss:size(3) do - model:zeroGradParameters() - local y = model:forward(x) - loss[ir][im][i] = criterion:forward(y,t) - local dy = criterion:backward(y,t) - model:backward(x, dy) - local w,dw = model:parameters() - model:updateParameters(.5) - end - end - end - mytester:assertTensorEq(loss:select(2,1), loss:select(2,2), 0.0000001, "loss check") -end - function rnntest.encoderdecoder() torch.manualSeed(123) @@ -3512,26 +3055,32 @@ function rnntest.SeqLSTM_main() local inputsize = 2 local outputsize = 3 - -- compare SeqLSTM to FastLSTM (forward, backward, update) + -- compare SeqLSTM to RecLSTM (forward, backward, update) local function testmodule(seqlstm, seqlen, batchsize, lstm2, remember, eval, seqlstm2, maskzero) lstm2 = lstm2 or seqlstm:toRecLSTM() remember = remember or 'neither' + seqlstm2 = seqlstm2 or nn.Sequencer(lstm2) local input, gradOutput input = torch.randn(seqlen, batchsize, inputsize) if maskzero then lstm2:maskZero() + local zeroMask = torch.ByteTensor(seqlen, batchsize):zero() for i=1,seqlen do for j=1,batchsize do if math.random() < 0.2 then - input[{i,j,{}}]:zero() + zeroMask[{i,j}] = 1 end end end + seqlstm:setZeroMask(zeroMask) + seqlstm2:setZeroMask(zeroMask) + else + seqlstm:setZeroMask(false) + seqlstm2:setZeroMask(false) end gradOutput = torch.randn(seqlen, batchsize, outputsize) - seqlstm2 = seqlstm2 or nn.Sequencer(lstm2) seqlstm2:remember(remember) mytester:assert(seqlstm2._remember == remember, tostring(seqlstm2._remember) ..'~='.. tostring(remember)) @@ -3577,8 +3126,7 @@ function rnntest.SeqLSTM_main() local seqlen = 4 local batchsize = 5 - local seqlstm = nn.SeqLSTM(inputsize, outputsize) - seqlstm.maskzero = true + local seqlstm = nn.SeqLSTM(inputsize, outputsize):maskZero() seqlstm:reset(0.1) local lstm2 = testmodule(seqlstm, seqlen, batchsize) @@ -3673,9 +3221,9 @@ function rnntest.SeqLSTM_maskzero() -- Note that more maskzero = true tests with masked inputs are in SeqLSTM unit test. 
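+   -- With the v2 masking API the mask is no longer inferred from all-zero input rows:
+   -- the caller builds a seqlen x batchsize zeroMask (ByteTensor, 1 = masked step) and
+   -- passes it explicitly via setZeroMask(zeroMask); setZeroMask(false) disables masking.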
local T, N, D, H = 3, 2, 4, 5 local seqlstm = nn.SeqLSTM(D,H) - seqlstm.maskzero = false - local seqlstm2 = seqlstm:clone() - seqlstm2.maskzero = true + local seqlstm2 = seqlstm:clone():maskZero() + local zeroMask = torch.ByteTensor(T, N):zero() + seqlstm2:setZeroMask(zeroMask) local input = torch.randn(T, N, D) local gradOutput = torch.randn(T, N, H) @@ -3710,9 +3258,9 @@ function rnntest.SeqLSTM_maskzero() input = input:cuda() gradOutput = gradOutput:cuda() seqlstm:cuda() + zeroMask = zeroMask:type('torch.CudaByteTensor') end - seqlstm.maskzero = false seqlstm:forward(input) seqlstm:backward(input, gradOutput) @@ -3733,7 +3281,8 @@ function rnntest.SeqLSTM_maskzero() end end - seqlstm.maskzero = true + seqlstm:maskZero() + seqlstm:setZeroMask(zeroMask) seqlstm:forward(input) seqlstm:backward(input, gradOutput) @@ -3758,7 +3307,7 @@ function rnntest.SeqLSTMP_main() local batchsize = 5 local lstm = nn.SeqLSTM(inputsize, outputsize) - local lstmp = nn.SeqLSTMP(inputsize, hiddensize, outputsize) + local lstmp = nn.SeqLSTM(inputsize, hiddensize, outputsize) local params, gradParams = lstm:parameters() local paramsp, gradParamsp = lstmp:parameters() @@ -3799,16 +3348,21 @@ function rnntest.SeqLSTMP_main() -- test with maskzero + lstmp:maskZero() + lstm:maskZero() + + local zeroMask = torch.ByteTensor(seqlen, batchsize):zero() + for i=1,seqlen do for j=1,batchsize do if math.random() < 0.2 then - input[{i,j,{}}]:zero() + zeroMask[{i,j}] = 1 end end end - lstmp.maskzero = true - lstm.maskzero = true + lstmp:setZeroMask(zeroMask) + lstm:setZeroMask(zeroMask) local output = lstm:forward(input) local outputp = lstmp:forward(input) @@ -3834,8 +3388,8 @@ function rnntest.SeqLSTMP_main() local hiddensize = 4 - local lstmp = nn.SeqLSTMP(inputsize, hiddensize, outputsize) - local lstmp2 = nn.SeqLSTMP(inputsize, hiddensize, outputsize) + local lstmp = nn.SeqLSTM(inputsize, hiddensize, outputsize) + local lstmp2 = nn.SeqLSTM(inputsize, hiddensize, outputsize) local params, gradParams = lstmp:parameters() local params2, gradParams2 = lstmp2:parameters() @@ -3848,9 +3402,11 @@ function rnntest.SeqLSTMP_main() lstmp2:zeroGradParameters() local input = torch.randn(seqlen, batchsize, inputsize) - input[3] = 0 -- zero the 3 time-step + zeroMask:zero() + zeroMask[3] = 1 -- zero the 3rd time-step - lstmp.maskzero = true + lstmp:maskZero() + lstmp:setZeroMask(zeroMask) local output = lstmp:forward(input) local gradInput = lstmp:backward(input, gradOutput) @@ -3872,94 +3428,6 @@ function rnntest.SeqLSTMP_main() end end -function rnntest.FastLSTM_issue203() - torch.manualSeed(123) - local nActions = 3 - local wordEmbDim = 4 - local lstmHidDim = 7 - - local input = {torch.randn(2), torch.randn(2)} - local target = {torch.IntTensor{1, 3}, torch.IntTensor{2, 3}} - - local seq = nn.Sequencer( - nn.Sequential() - :add(nn.Linear(2, wordEmbDim)) - :add(nn.Copy(nil,nil,true)) - :add(nn.FastLSTM(wordEmbDim, lstmHidDim)) - :add(nn.Linear(lstmHidDim, nActions)) - :add(nn.LogSoftMax()) - ) - - local seq2 = nn.Sequencer( - nn.Sequential() - :add(nn.Linear(2, wordEmbDim)) - :add(nn.FastLSTM(wordEmbDim, lstmHidDim)) - :add(nn.Linear(lstmHidDim, nActions)) - :add(nn.LogSoftMax()) - ) - - local parameters, grads = seq:getParameters() - local parameters2, grads2 = seq2:getParameters() - - parameters:copy(parameters2) - - local criterion = nn.SequencerCriterion(nn.ClassNLLCriterion()) - local criterion2 = nn.SequencerCriterion(nn.ClassNLLCriterion()) - - local output = seq:forward(input) - local loss = criterion:forward(output, 
target) - local gradOutput = criterion:backward(output, target) - seq:zeroGradParameters() - local gradInput = seq:backward(input, gradOutput) - - local output2 = seq2:forward(input) - local loss2 = criterion2:forward(output2, target) - local gradOutput2 = criterion2:backward(output2, target) - seq2:zeroGradParameters() - local gradInput2 = seq2:backward(input, gradOutput2) - - local t1 = seq.modules[1].sharedClones[2]:get(3).sharedClones[1].gradInput[1] - local t2 = seq2.modules[1].sharedClones[1]:get(2).sharedClones[1].gradInput[1] - mytester:assertTensorEq(t1, t2, 0.0000001, "LSTM gradInput1") - - local t1 = seq.modules[1].sharedClones[2]:get(3).sharedClones[2].gradInput[1] - local t2 = seq2.modules[1].sharedClones[1]:get(2).sharedClones[2].gradInput[1] - mytester:assertTensorEq(t1, t2, 0.0000001, "LSTM gradInput2") - - for i=1,2 do - mytester:assertTensorEq(output2[i], output[i], 0.0000001, "output "..i) - mytester:assertTensorEq(gradOutput2[i], gradOutput[i], 0.0000001, "gradOutput "..i) - mytester:assertTensorEq(gradInput2[i], gradInput[i], 0.0000001, "gradInput "..i) - end - - local params, gradParams = seq:parameters() - local params2, gradParams2 = seq2:parameters() - - for i=1,#params do - mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.000001, "gradParams "..tostring(gradParams[i])) - end - - if not pcall(function() require 'optim' end) then - return - end - - local seq_ = seq2 - local parameters_ = parameters2 - local grads_ = grads2 - local function f(x) - parameters_:copy(x) - -- seq:forget() - seq_:zeroGradParameters() - seq_:forward(input) - criterion:forward(seq_.output, target) - seq_:backward(input, criterion:backward(seq_.output, target)) - return criterion.output, grads_ - end - - local err = optim.checkgrad(f, parameters_:clone()) - mytester:assert(err < 0.000001, "error "..err) -end - function rnntest.SeqBRNNTest() local brnn = nn.SeqBRNN(5, 5) @@ -4074,16 +3542,7 @@ function rnntest.NormStabilizer() local SequencerCriterion, parent = torch.class('nn.SequencerCriterionNormStab', 'nn.SequencerCriterion') function SequencerCriterion:__init(criterion, beta) - parent.__init(self) - self.criterion = criterion - if torch.isTypeOf(criterion, 'nn.ModuleCriterion') then - error("SequencerCriterion shouldn't decorate a ModuleCriterion. ".. - "Instead, try the other way around : ".. - "ModuleCriterion decorates a SequencerCriterion. ".. 
- "Its modules can also be similarly decorated with a Sequencer.") - end - self.clones = {} - self.gradInput = {} + parent.__init(self, criterion) self.beta = beta end @@ -4255,7 +3714,6 @@ function rnntest.NCE_MaskZero() uniform = 0.1, hiddensize = {100}, vocabsize = 100, - dropout = 0, k = 25 } @@ -4265,20 +3723,12 @@ function rnntest.NCE_MaskZero() local lookup = nn.LookupTableMaskZero(opt.vocabsize, opt.hiddensize[1]) lookup.maxnormout = -1 -- prevent weird maxnormout behaviour lm:add(lookup) -- input is seqlen x batchsize - if opt.dropout > 0 then - lm:add(nn.Dropout(opt.dropout)) - end -- rnn layers local inputsize = opt.hiddensize[1] for i,hiddensize in ipairs(opt.hiddensize) do - -- this is a faster version of nnSequencer(nn.FastLSTM(inpusize, hiddensize)) - local rnn = nn.SeqLSTM(inputsize, hiddensize) - rnn.maskzero = true - lm:add(rnn) - if opt.dropout > 0 then - lm:add(nn.Dropout(opt.dropout)) - end + lm:add(nn.SeqLSTM(inputsize, hiddensize):maskZero()) + lm:add(nn.Dropout(opt.dropout)) inputsize = hiddensize end @@ -4296,7 +3746,7 @@ function rnntest.NCE_MaskZero() :add(nn.ZipTable()) -- {{x1,x2,...}, {t1,t2,...}} -> {{x1,t1},{x2,t2},...} -- encapsulate stepmodule into a Sequencer - lm:add(nn.Sequencer(nn.MaskZero(ncemodule, 1))) + lm:add(nn.Sequencer(nn.MaskZero(ncemodule))) -- remember previous state between batches lm:remember() @@ -4305,11 +3755,12 @@ function rnntest.NCE_MaskZero() for k,param in ipairs(lm:parameters()) do param:uniform(-opt.uniform, opt.uniform) end + ncemodule:reset() end --[[ loss function ]]-- - local crit = nn.MaskZeroCriterion(nn.NCECriterion(), 0) + local crit = nn.MaskZeroCriterion(nn.NCECriterion()) local targetmodule = nn.SplitTable(1) local criterion = nn.SequencerCriterion(crit) @@ -4319,6 +3770,7 @@ function rnntest.NCE_MaskZero() targets = torch.LongTensor(opt.datasize, opt.seqlen, opt.batchsize):random(1,opt.vocabsize) } + local _ = require 'moses' local starterr local err local found = false @@ -4327,6 +3779,10 @@ function rnntest.NCE_MaskZero() for i=1,opt.datasize do local input, target = data.inputs[i], data.targets[i] local target = targetmodule:forward(target) + local zeroMask = nn.utils.getZeroMaskSequence(input) + --print("INPUT ZM", input, zeroMask) + lm:setZeroMask(zeroMask) + criterion:setZeroMask(zeroMask) local output = lm:forward({input, target}) err = err + criterion:forward(output, target) local gradOutput = criterion:backward(output, target) @@ -4789,23 +4245,18 @@ function checkgrad(opfunc, x, eps) end function rnntest.MufuruGradients() - local batchSize = torch.random(1,2) - local inputSize = torch.random(1,3) - local outputSize = torch.random(1,4) - local seqlen = torch.random(1,5) + local batchSize = 2 + local inputSize = 3 + local outputSize = 4 + local seqlen = 5 local rnn = nn.MuFuRu(inputSize, outputSize) local module = nn.Sequencer(rnn) local w,dw = module:getParameters() - local crit = nn.CrossEntropyCriterion() + local crit = nn.SequencerCriterion(nn.CrossEntropyCriterion()) local input = torch.randn(seqlen, batchSize, inputSize) - local target = torch.LongTensor(seqlen, batchSize) - for i=1,seqlen do - for j=1,batchSize do - target[i][j] = torch.random(1, outputSize) - end - end + local target = torch.LongTensor(seqlen, batchSize):random(1,outputSize) local function feval(x) if w ~= x then w:copy(x) end module:zeroGradParameters() @@ -4819,154 +4270,6 @@ function rnntest.MufuruGradients() mytester:assertlt(err, precision, "error in computing grad parameters") end -function rnntest.inplaceBackward() - -- not 
implemented (work was started, but never finished, sorry) - if true then return end - - local lr = 0.1 - local seqlen, batchsize, hiddensize = 3, 4, 5 - local input = torch.randn(seqlen, batchsize, hiddensize) - local gradOutput = torch.randn(seqlen, batchsize, hiddensize) - - -- test sequencer(linear) - - local seq = nn.Sequencer(nn.Linear(hiddensize, hiddensize)) - local seq2 = seq:clone() - seq2:inplaceBackward() - - local output = seq:forward(input) - local output2 = seq2:forward(input) - - mytester:assertTensorEq(output, output2, 0.000001) - - seq:zeroGradParameters() - local gradInput = seq:backward(input, gradOutput) - seq:updateParameters(lr) - - local gradInput2 = seq2:backward(input, gradOutput, -lr) - - mytester:assertTensorEq(gradInput, gradInput2, 0.000001) - - local params = seq:parameters() - local params2 = seq2:parameters() - - for i=1,#params do - mytester:assertTensorEq(params[i], params2[i], 0.000001) - end - - -- test seqlstm - - local seq = nn.SeqLSTM(hiddensize, hiddensize) - local seq2 = seq:clone() - seq2:inplaceBackward() - - local output = seq:forward(input) - local output2 = seq2:forward(input) - - mytester:assertTensorEq(output, output2, 0.000001) - - seq:zeroGradParameters() - local gradInput = seq:backward(input, gradOutput) - seq:updateParameters(lr) - - local gradInput2 = seq2:backward(input, gradOutput, -lr) - - mytester:assertTensorEq(gradInput, gradInput2, 0.000001) - - local params = seq:parameters() - local params2 = seq2:parameters() - - for i=1,#params do - mytester:assertTensorEq(params[i], params2[i], 0.000001) - end - - - if true then return end - -- test language model - - local vocabsize = 100 - local input = torch.LongTensor(seqlen, batchsize):random(1,vocabsize) - local target = torch.LongTensor(seqlen, batchsize):random(1,vocabsize) - - local lm = nn.Sequential() - local lookup = nn.LookupTableMaskZero(vocabsize, hiddensize) - lm:add(lookup) - - for i=1,2 do - local rnn = nn.SeqLSTM(hiddensize, hiddensize) - rnn.maskzero = true - lm:add(rnn) - end - - lm:add(nn.SplitTable(1)) - - local unigram = torch.FloatTensor(vocabsize):uniform(1,10) - local ncemodule = nn.NCEModule(hiddensize, vocabsize, 10, unigram, -1) - local _sampleidx = torch.Tensor(1,10):random(1,vocabsize) - - function ncemodule.noiseSample(self, sampleidx, batchsize, k) - assert(batchsize == 1) - assert(k == 10) - sampleidx:resize(1, k):copy(_sampleidx) - return sampleidx - end - - lm = nn.Sequential() - :add(nn.ParallelTable() - :add(lm):add(nn.Identity())) - :add(nn.ZipTable()) - - lm:add(nn.Sequencer(nn.MaskZero(ncemodule, 1))) - lm:remember() - - local crit = nn.MaskZeroCriterion(nn.NCECriterion(), 0) - local targetmodule = nn.SplitTable(1) - local criterion = nn.SequencerCriterion(crit) - - local lm2 = lm:clone() - lm2:inplaceBackward() - - local criterion2 = criterion:clone() - - local target = targetmodule:forward(target) - - local inputTable = {input, target} - - local output = lm:forward(inputTable) - local output2 = lm2:forward(inputTable) - - for i=1,seqlen do - mytester:assertTensorEq(output[i][1], output2[i][1], 0.000001) - mytester:assertTensorEq(output[i][2], output2[i][2], 0.000001) - mytester:assertTensorEq(output[i][3], output2[i][3], 0.000001) - mytester:assertTensorEq(output[i][4], output2[i][4], 0.000001) - end - - local loss = criterion:forward(output, target) - local loss2 = criterion2:forward(output, target) - - local gradOutput = criterion:backward(output, target) - local gradOutput2 = criterion2:backward(output, target) - - for i=1,seqlen do - 
mytester:assertTensorEq(gradOutput[i][1], gradOutput2[i][1], 0.000001) - mytester:assertTensorEq(gradOutput[i][2], gradOutput2[i][2], 0.000001) - end - - lm:zeroGradParameters() - lm:backward(inputTable, gradOutput) - lm:updateParameters(lr) - - lm2:backward(inputTable, gradOutput2, -lr) - - local params = lm:parameters() - local params2 = lm2:parameters() - - for i=1,#params do - mytester:assertTensorEq(params[i], params2[i], 0.000001, "error in params "..i..": "..tostring(params[i]:size())) - end -end - function rnntest.getHiddenState() local seqlen, batchsize = 7, 3 local inputsize, outputsize = 4, 5 @@ -5220,6 +4523,7 @@ function rnntest.VariableLength_lstm() input2:select(2,i):narrow(1,maxLength-seqlen+1,seqlen):copy(input[i]) end + lstm:setZeroMask(nn.utils.getZeroMaskSequence(input2)) local output2 = lstm:forward(input2) if not lastOnly then @@ -5510,16 +4814,20 @@ function rnntest.RecLSTM_maskzero() local T, N, D, H = 3, 2, 4, 5 local reclstm = nn.RecLSTM(D,H):maskZero() local seqlstm = nn.Sequencer(reclstm) - local seqlstm2 = nn.SeqLSTM(D,H) + local seqlstm2 = nn.SeqLSTM(D,H):maskZero() seqlstm2.weight:copy(reclstm.modules[1].weight) seqlstm2.bias:copy(reclstm.modules[1].bias) - seqlstm2.maskzero = true local input = torch.randn(T, N, D) input[{2,1}]:fill(0) input[{3,2}]:fill(0) local gradOutput = torch.randn(T, N, H) + local zeroMask = torch.ByteTensor(T, N):zero() + zeroMask[{2,1}] = 1 + zeroMask[{3,2}] = 1 + seqlstm:setZeroMask(zeroMask) + seqlstm2:setZeroMask(zeroMask) local output = seqlstm:forward(input) local output2 = seqlstm2:forward(input) diff --git a/utils.lua b/utils.lua index 1f4bbba..b61aa84 100644 --- a/utils.lua +++ b/utils.lua @@ -57,8 +57,17 @@ function nn.utils.getZeroMaskSequence(sequence, zeroMask) sequence = sequence:contiguous():view(sequence:size(1), sequence:size(2), -1) -- build mask (1 where norm is 0 in first) - local _zeroMask = torch.getBuffer('getZeroMaskSequence', '_zeroMask', sequence) + local _zeroMask + if sequence.norm then + _zeroMask = torch.getBuffer('getZeroMaskSequence', '_zeroMask', sequence) + else + _zeroMask = torch.getBuffer('getZeroMaskSequence', '_zeroMask', 'torch.FloatTensor') + local _sequence = torch.getBuffer('getZeroMaskSequence', '_sequence', 'torch.FloatTensor') + _sequence:resize(sequence:size()):copy(sequence) + sequence = _sequence + end _zeroMask:norm(sequence, 2, 3) + zeroMask = zeroMask or ( (torch.type(sequence) == 'torch.CudaTensor') and torch.CudaByteTensor() or (torch.type(sequence) == 'torch.ClTensor') and torch.ClTensor() @@ -161,16 +170,19 @@ function nn.utils.recursiveGetFirst(input) end -- in-place set tensor to zero where zeroMask is 1 -function nn.utils.recursiveZeroMask(tensor, mask) +function nn.utils.recursiveZeroMask(tensor, zeroMask) if torch.type(tensor) == 'table' then for k,tensor_k in ipairs(tensor) do - nn.utils.recursiveMask(tensor_k) + nn.utils.recursiveZeroMask(tensor_k, zeroMask) end else assert(torch.isTensor(tensor)) local tensorSize = tensor:size():fill(1) tensorSize[1] = tensor:size(1) + if zeroMask:dim() == 2 then + tensorSize[2] = tensor:size(2) + end assert(zeroMask:dim() <= tensor:dim()) zeroMask = zeroMask:view(tensorSize):expandAs(tensor) -- set tensor to zero where zeroMask is 1 @@ -263,4 +275,15 @@ function nn.utils.recursiveMaskedCopy(dst, mask, src) torch.type(dst).." and "..torch.type(src).." 
instead") end return dst +end + +function nn.utils.setZeroMask(modules, zeroMask, cuda) + if cuda then + cuZeroMask = torch.getBuffer('setZeroMask', 'cuZeroMask', 'torch.CudaByteTensor') + cuZeroMask:resize(zeroMask:size()):copy(zeroMask) + zeroMask = cuZeroMask + end + for i,module in ipairs(torch.type(modules) == 'table' and modules or {modules}) do + module:setZeroMask(zeroMask) + end end \ No newline at end of file