From 78f9a498a6e5444eedc04fc670a2ab108ef1511d Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Wed, 24 May 2017 18:43:29 -0400 Subject: [PATCH] nn.ZipTable --- ZipTable.lua | 34 ++++++++++++++++++++++++++++++ ZipTableOneToMany.lua | 37 +++++++++++++++++++++++++++++++++ doc/table.md | 40 ++++++++++++++++++++++++++++++++---- init.lua | 2 ++ test.lua | 48 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 157 insertions(+), 4 deletions(-) create mode 100644 ZipTable.lua create mode 100644 ZipTableOneToMany.lua diff --git a/ZipTable.lua b/ZipTable.lua new file mode 100644 index 000000000..7b18619eb --- /dev/null +++ b/ZipTable.lua @@ -0,0 +1,34 @@ +local ZipTable, parent = torch.class('nn.ZipTable', 'nn.Module') + +-- input : { {a1,a2}, {b1,b2}, {c1,c2} } +-- output : { {a1,b1,c1}, {a2,b2,c2} } +function ZipTable:__init() + parent.__init(self) + self.output = {} + self.gradInput = {} +end + +function ZipTable:updateOutput(inputTable) + self.output = {} + for i,inTable in ipairs(inputTable) do + for j,input in ipairs(inTable) do + local output = self.output[j] or {} + output[i] = input + self.output[j] = output + end + end + return self.output +end + +function ZipTable:updateGradInput(inputTable, gradOutputTable) + self.gradInput = {} + for i,gradOutTable in ipairs(gradOutputTable) do + for j,gradOutput in ipairs(gradOutTable) do + local gradInput = self.gradInput[j] or {} + gradInput[i] = gradOutput + self.gradInput[j] = gradInput + end + end + return self.gradInput +end + diff --git a/ZipTableOneToMany.lua b/ZipTableOneToMany.lua new file mode 100644 index 000000000..d4a80fe0d --- /dev/null +++ b/ZipTableOneToMany.lua @@ -0,0 +1,37 @@ +local ZipTableOneToMany, parent = torch.class('nn.ZipTableOneToMany', 'nn.Module') + +-- based on ZipTable in dpnn + +-- input : { v, {a, b, c} } +-- output : { {v,a}, {v,b}, {v,c} } +function ZipTableOneToMany:__init() + parent.__init(self) + self.output = {} + self.gradInput = {} + -- make buffer to update during forward/backward + self.gradInputEl = torch.Tensor() +end + +function ZipTableOneToMany:updateOutput(input) + assert(#input == 2, "input must be table of element and table") + local inputEl, inputTable = input[1], input[2] + self.output = {} + for i,v in ipairs(inputTable) do + self.output[i] = {inputEl, v} + end + return self.output +end + +function ZipTableOneToMany:updateGradInput(input, gradOutput) + assert(#input == 2, "input must be table of element and table") + local inputEl, inputTable = input[1], input[2] + self.gradInputEl:resizeAs(inputEl):zero() + local gradInputTable = {} + for i,gradV in ipairs(gradOutput) do + self.gradInputEl:add(gradV[1]) + gradInputTable[i] = gradV[2] + end + self.gradInput = {self.gradInputEl, gradInputTable} + return self.gradInput +end + diff --git a/doc/table.md b/doc/table.md index b3e2e5f86..1924eaddf 100644 --- a/doc/table.md +++ b/doc/table.md @@ -15,6 +15,8 @@ This allows one to build very rich architectures: * [`SelectTable`](#nn.SelectTable): select one element from a `table`; * [`NarrowTable`](#nn.NarrowTable): select a slice of elements from a `table`; * [`FlattenTable`](#nn.FlattenTable): flattens a nested `table` hierarchy; + * [`ZipTable`](#nn.ZipTable) : zip a table of tables into a table of tables; + * [`ZipTableOneToMany`](#nn.ZipTableOneToMany) : zip a table to a single tensor; * Pair Modules compute a measure like distance or similarity from a pair (`table`) of input `Tensor`s: * [`PairwiseDistance`](#nn.PairwiseDistance): outputs the `p`-norm. distance between inputs; * [`DotProduct`](#nn.DotProduct): outputs the dot product (similarity) between inputs; @@ -692,7 +694,7 @@ Forwarding a batch of 2 examples gives us something like this: `module` = `SelectTable(index)` -Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative). +Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative). This can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor). The gradients of the non-`index` elements are zeroed `Tensor`s of the same size. This is true regardless of the @@ -731,7 +733,7 @@ Exmaple 2: > gradInput = nn.SelectTable("A"):backward(input, torch.randn(2, 3)) -> gradInput +> gradInput { A : DoubleTensor - size: 2x3 B : DoubleTensor - size: 2x1 @@ -811,11 +813,11 @@ Example 3: `module` = `NarrowTable(offset [, length])` -Creates a module that takes a `table` as input and outputs the subtable +Creates a module that takes a `table` as input and outputs the subtable starting at index `offset` having `length` elements (defaults to 1 element). The elements can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor). -The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size. +The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size. This is true regardless of the depth of the encapsulated `Tensor` as the function used internally to do so is recursive. Example: @@ -883,6 +885,36 @@ gives the output: } ``` + +## ZipTable ## + +```lua +module = nn.ZipTable() +``` + +Zips a table of tables into a table of tables. + +Example: +```lua +print(module:forward{ {'a1','a2'}, {'b1','b2'}, {'c1','c2'} }) +{ {'a1','b1','c1'}, {'a2','b2','c2'} } +``` + + +## ZipTableOneToMany ## + +```lua +module = nn.ZipTableOneToMany() +``` + +Zips a table of element `el` and table of elements `tab` into a table of tables, where the i-th table contains the element `el` and the i-th element in table `tab` + +Example: +```lua +print(module:forward{ 'el', {'a','b','c'} }) +{ {'el','a'}, {'el','b'}, {'el','c'} } +``` + ## PairwiseDistance ## diff --git a/init.lua b/init.lua index b397d770d..447d357d8 100755 --- a/init.lua +++ b/init.lua @@ -170,6 +170,8 @@ require('nn.CriterionTable') require('nn.FlattenTable') require('nn.NarrowTable') require('nn.MapTable') +require('nn.ZipTable') +require('nn.ZipTableOneToMany') require('nn.Criterion') require('nn.MSECriterion') diff --git a/test.lua b/test.lua index 2dafb099c..16bae0954 100755 --- a/test.lua +++ b/test.lua @@ -8548,6 +8548,54 @@ function nntest.ZeroGrad() mytester:assertTensorEq(gradInput, gradInput2, 0.0000001) end +function nntest.ZipTable() + -- input : { {a1,a2}, {b1,b2}, {c1,c2} } + -- output : { {a1,b1,c1}, {a2,b2,c2} } + local z = nn.ZipTable() + local input = { + {torch.randn(3,4), torch.randn(3,4)}, + {torch.randn(3,4), torch.randn(3,4)}, + {torch.randn(3,4), torch.randn(3,4)} + } + local output = z:forward(input) + mytester:assert(#output == 2, "ZipTable #output") + mytester:assert(#(output[1]) == 3, "ZipTable #output[1]") + mytester:assertTensorEq(input[1][1], output[1][1], 0.000001, "ZipTable input11") + mytester:assertTensorEq(input[1][2], output[2][1], 0.000001, "ZipTable input12") + mytester:assertTensorEq(input[3][2], output[2][3], 0.000001, "ZipTable input32") + local gradInput = z:backward(input, output) + mytester:assert(#gradInput == 3, "ZipTable #gradInput") + mytester:assert(#(gradInput[1]) == 2, "ZipTable #gradInput[1]") + mytester:assertTensorEq(input[1][1], gradInput[1][1], 0.000001, "ZipTable gradInput11") + mytester:assertTensorEq(input[1][2], gradInput[1][2], 0.000001, "ZipTable gradInput12") + mytester:assertTensorEq(input[3][2], gradInput[3][2], 0.000001, "ZipTable gradInput32") +end + +function nntest.ZipTableOneToMany() + -- input : { v, {a,b,c} } + -- output : { {v,a}, {v,b}, {v,c} } + local z = nn.ZipTableOneToMany() + local input = { torch.randn(3), { torch.randn(4), torch.rand(4), torch.rand(4) } } + local output = z:forward(input) + mytester:assert(#output == 3, "ZipTableOneToMany #output") + mytester:assert(#(output[1]) == 2, "ZipTableOneToMany #output[1]") + mytester:assert(#(output[2]) == 2, "ZipTableOneToMany #output[2]") + mytester:assert(#(output[3]) == 2, "ZipTableOneToMany #output[3]") + mytester:assertTensorEq(input[1], output[1][1], 0.000001, "ZipTableOneToMany input1 output11") + mytester:assertTensorEq(input[1], output[2][1], 0.000001, "ZipTableOneToMany input1 output21") + mytester:assertTensorEq(input[1], output[3][1], 0.000001, "ZipTableOneToMany input1 output31") + mytester:assertTensorEq(input[2][1], output[1][2], 0.000001, "ZipTableOneToMany input21") + mytester:assertTensorEq(input[2][2], output[2][2], 0.000001, "ZipTableOneToMany input22") + mytester:assertTensorEq(input[2][3], output[3][2], 0.000001, "ZipTableOneToMany input23") + local gradInput = z:backward(input, output) + mytester:assert(#gradInput == 2, "ZipTableOneToMany #gradInput") + mytester:assert(#(gradInput[2]) == 3, "ZipTableOneToMany #gradInput[2]") + mytester:assertTensorEq(input[2][1], gradInput[2][1], 0.000001, "ZipTableOneToMany gradInput21") + mytester:assertTensorEq(input[2][2], gradInput[2][2], 0.000001, "ZipTableOneToMany gradInput22") + mytester:assertTensorEq(input[2][3], gradInput[2][3], 0.000001, "ZipTableOneToMany gradInput32") + mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput21") +end + mytester:add(nntest)