From 72091ff840b4396b4a89936ceb392570eefe1c87 Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Wed, 17 May 2017 11:27:53 -0400 Subject: [PATCH] ReverseTable+SeqReverseSequence -> ReverseSequnece --- CMakeLists.txt | 4 +- ReverseSequence.lua | 75 +++++++++++++++++++ ReverseTable.lua | 39 ---------- .../SeqReverseSequence.lua | 0 init.lua | 4 +- test/test.lua | 33 ++++++-- 6 files changed, 106 insertions(+), 49 deletions(-) create mode 100644 ReverseSequence.lua delete mode 100644 ReverseTable.lua rename SeqReverseSequence.lua => deprecated/SeqReverseSequence.lua (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0a31f15..8b326f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,7 +42,7 @@ SET(luasrc SeqGRU.lua SeqLSTM.lua deprecated/SeqLSTMP.lua - SeqReverseSequence.lua + deprecated/SeqReverseSequence.lua Sequencer.lua SequencerCriterion.lua ZeroGrad.lua @@ -85,7 +85,7 @@ SET(luasrc ReinforceCategorical.lua ReinforceGamma.lua ReinforceNormal.lua - ReverseTable.lua + ReverseSequence.lua Sequential.lua Serial.lua SimpleColorTransform.lua diff --git a/ReverseSequence.lua b/ReverseSequence.lua new file mode 100644 index 0000000..b97fcdc --- /dev/null +++ b/ReverseSequence.lua @@ -0,0 +1,75 @@ +local ReverseSequence, parent = torch.class("nn.ReverseSequence", "nn.Module") + +function ReverseSequence:updateOutput(input) + local seqlen + if torch.isTensor(input) then + seqlen = input:size(1) + self.output = torch.isTensor(self.output) and self.output or input.new() + self.output:resizeAs(input) + + self._range = self._range or torch.isCudaTensor(input) and torch.CudaLongTensor() or torch.LongTensor() + if self._range:nElement() ~= seqlen then + self._range:range(seqlen,1,-1) + end + self.output:index(input, 1, self._range) + else + seqlen = #input + self.output = torch.type(self.output) == 'table' and self.output or {} + assert(torch.type(input) == 'table', "Expecting table or tensor at arg 1") + + -- empty output table + for k,v in ipairs(self.output) do + self.output[k] = nil + end + + -- reverse input + local k = 1 + for i=seqlen,1,-1 do + self.output[k] = input[i] + k = k + 1 + end + end + + return self.output +end + +function ReverseSequence:updateGradInput(input, gradOutput) + local seqlen + if torch.isTensor(input) then + seqlen = input:size(1) + self.gradInput = torch.isTensor(self.gradInput) and self.gradInput or input.new() + self.gradInput:resizeAs(input) + + self.gradInput:index(gradOutput, 1, self._range) + else + seqlen = #input + self.gradInput = torch.type(self.gradInput) == 'table' and self.gradInput or {} + assert(torch.type(gradOutput) == 'table', "Expecting table or tensor at arg 2") + + -- empty gradInput table + for k,v in ipairs(self.gradInput) do + self.gradInput[k] = nil + end + + -- reverse gradOutput + local k = 1 + for i=seqlen,1,-1 do + self.gradInput[k] = gradOutput[i] + k = k + 1 + end + end + + return self.gradInput +end + +function ReverseSequence:clearState() + self.gradInput = torch.Tensor() + self.output = torch.Tensor() + self._range = nil +end + +function ReverseSequence:type(...) + self:clearState() + return parent.type(self, ...) +end + diff --git a/ReverseTable.lua b/ReverseTable.lua deleted file mode 100644 index 69660a0..0000000 --- a/ReverseTable.lua +++ /dev/null @@ -1,39 +0,0 @@ -local ReverseTable, parent = torch.class("nn.ReverseTable", "nn.Module") - -function ReverseTable:__init() - parent.__init(self) - self.output = {} - self.gradInput = {} -end - -function ReverseTable:updateOutput(inputTable) - assert(torch.type(inputTable) == 'table', "Expecting table at arg 1") - - -- empty output table - for k,v in ipairs(self.output) do - self.output[k] = nil - end - - -- reverse input - local k = 1 - for i=#inputTable,1,-1 do - self.output[k] = inputTable[i] - k = k + 1 - end - return self.output -end - -function ReverseTable:updateGradInput(inputTable, gradOutputTable) - -- empty gradInput table - for k,v in ipairs(self.gradInput) do - self.gradInput[k] = nil - end - - -- reverse gradOutput - local k = 1 - for i=#gradOutputTable,1,-1 do - self.gradInput[k] = gradOutputTable[i] - k = k + 1 - end - return self.gradInput -end diff --git a/SeqReverseSequence.lua b/deprecated/SeqReverseSequence.lua similarity index 100% rename from SeqReverseSequence.lua rename to deprecated/SeqReverseSequence.lua diff --git a/init.lua b/init.lua index 727373b..551f664 100644 --- a/init.lua +++ b/init.lua @@ -58,7 +58,7 @@ require('rnn.Collapse') require('rnn.ZipTable') require('rnn.ZipTableOneToMany') require('rnn.CAddTensorTable') -require('rnn.ReverseTable') +require('rnn.ReverseSequence') require('rnn.Dictionary') require('rnn.Inception') require('rnn.Clip') @@ -131,7 +131,6 @@ require('rnn.RecurrentAttention') -- sequencer + recurrent modules require('rnn.SeqLSTM') require('rnn.SeqGRU') -require('rnn.SeqReverseSequence') require('rnn.SeqBRNN') -- recurrent criterions: @@ -144,6 +143,7 @@ require('rnn.MaskZeroCriterion') require('rnn.LSTM') require('rnn.FastLSTM') require('rnn.SeqLSTMP') +require('rnn.SeqReverseSequence') -- prevent likely name conflicts nn.rnn = rnn diff --git a/test/test.lua b/test/test.lua index 7275df1..278e1fb 100644 --- a/test/test.lua +++ b/test/test.lua @@ -4991,24 +4991,45 @@ function rnntest.CAddTensorTable() mytester:assertTensorEq(output[1]+output[2]+output[3], gradInput[1], 0.000001, "CAddTensorTable gradInput1") end -function rnntest.ReverseTable() +function rnntest.ReverseSequence() + -- test table + -- input : { a, b, c, d } -- output : { c, b, a, d } - local r = nn.ReverseTable() + local r = nn.ReverseSequence() local input = {torch.randn(3,4), torch.randn(3,4), torch.randn(3,4), torch.randn(3,4)} local output = r:forward(input) - mytester:assert(#output == 4, "ReverseTable #output") + mytester:assert(#output == 4, "ReverseSequence #output") local k = 1 for i=#input,1,-1 do - mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseTable output err "..k) + mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseSequence output err "..k) k = k + 1 end local gradInput = r:backward(input, output) - mytester:assert(#gradInput == 4, "ReverseTable #gradInput") + mytester:assert(#gradInput == 4, "ReverseSequence #gradInput") for i=1,#input do - mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseTable gradInput err "..i) + mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseSequence gradInput err "..i) + end + + -- test tensor + + local r = nn.ReverseSequence() + local input = torch.randn(5,4,3) + local output = r:forward(input) + + mytester:assert(output:isSameSizeAs(input), "ReverseSequence #output") + local k = 1 + for i=5,1,-1 do + mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseSequence output err "..k) + k = k + 1 + end + + local gradInput = r:backward(input, output) + mytester:assert(gradInput:isSameSizeAs(input), "ReverseSequence #gradInput") + for i=1,5 do + mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseSequence gradInput err "..i) end end