Skip to content

Commit

Permalink
Merge pull request #1226 from nicholas-leonard/ZipTable
Browse files Browse the repository at this point in the history
nn.ZipTable
  • Loading branch information
nicholas-leonard authored May 25, 2017
2 parents e40e281 + 78f9a49 commit d1f66cb
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 4 deletions.
34 changes: 34 additions & 0 deletions ZipTable.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
local ZipTable, parent = torch.class('nn.ZipTable', 'nn.Module')

-- Transposes a table of tables, e.g.
-- input  : { {a1,a2}, {b1,b2}, {c1,c2} }
-- output : { {a1,b1,c1}, {a2,b2,c2} }
function ZipTable:__init()
   parent.__init(self)
   -- both output and gradInput are tables of tables, rebuilt on each pass
   self.output, self.gradInput = {}, {}
end

-- Transpose the table of tables: zipped[j][i] = inputTable[i][j].
function ZipTable:updateOutput(inputTable)
   local zipped = {}
   for i, row in ipairs(inputTable) do
      for j, element in ipairs(row) do
         if not zipped[j] then
            zipped[j] = {}
         end
         zipped[j][i] = element
      end
   end
   self.output = zipped
   return self.output
end

-- The backward pass is the same transposition applied to the gradients:
-- gradInput[j][i] = gradOutputTable[i][j].
function ZipTable:updateGradInput(inputTable, gradOutputTable)
   local unzipped = {}
   for i, gradRow in ipairs(gradOutputTable) do
      for j, grad in ipairs(gradRow) do
         if not unzipped[j] then
            unzipped[j] = {}
         end
         unzipped[j][i] = grad
      end
   end
   self.gradInput = unzipped
   return self.gradInput
end

37 changes: 37 additions & 0 deletions ZipTableOneToMany.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
local ZipTableOneToMany, parent = torch.class('nn.ZipTableOneToMany', 'nn.Module')

-- based on ZipTable in dpnn

-- Pairs one element with every entry of a table, e.g.
-- input  : { v, {a, b, c} }
-- output : { {v,a}, {v,b}, {v,c} }
function ZipTableOneToMany:__init()
   parent.__init(self)
   self.output, self.gradInput = {}, {}
   -- reusable buffer that accumulates the gradient w.r.t. the single element
   self.gradInputEl = torch.Tensor()
end

-- Forward: pair the single element input[1] with each entry of input[2].
function ZipTableOneToMany:updateOutput(input)
   assert(#input == 2, "input must be table of element and table")
   local el = input[1]
   local tab = input[2]
   local pairsOut = {}
   for i = 1, #tab do
      pairsOut[i] = {el, tab[i]}
   end
   self.output = pairsOut
   return self.output
end

-- Backward: the single element received a copy in every output pair, so its
-- gradient is the SUM of the first entries of gradOutput; the table entries
-- each received exactly one copy, so their gradients pass through unchanged.
function ZipTableOneToMany:updateGradInput(input, gradOutput)
   assert(#input == 2, "input must be table of element and table")
   local el = input[1]
   -- accumulate into the persistent buffer (avoids reallocating each call)
   self.gradInputEl:resizeAs(el):zero()
   local tableGrads = {}
   for i = 1, #gradOutput do
      local pairGrad = gradOutput[i]
      self.gradInputEl:add(pairGrad[1])
      tableGrads[i] = pairGrad[2]
   end
   self.gradInput = {self.gradInputEl, tableGrads}
   return self.gradInput
end

40 changes: 36 additions & 4 deletions doc/table.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ This allows one to build very rich architectures:
* [`SelectTable`](#nn.SelectTable): select one element from a `table`;
* [`NarrowTable`](#nn.NarrowTable): select a slice of elements from a `table`;
* [`FlattenTable`](#nn.FlattenTable): flattens a nested `table` hierarchy;
* [`ZipTable`](#nn.ZipTable) : zip a table of tables into a table of tables;
* [`ZipTableOneToMany`](#nn.ZipTableOneToMany) : zip a single element with each element of a table;
* Pair Modules compute a measure like distance or similarity from a pair (`table`) of input `Tensor`s:
* [`PairwiseDistance`](#nn.PairwiseDistance): outputs the `p`-norm. distance between inputs;
* [`DotProduct`](#nn.DotProduct): outputs the dot product (similarity) between inputs;
Expand Down Expand Up @@ -692,7 +694,7 @@ Forwarding a batch of 2 examples gives us something like this:

`module` = `SelectTable(index)`

Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative).
Creates a module that takes a (nested) `table` as input and outputs the element at index `index`. `index` can be strings or integers (positive or negative).
This can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).

The gradients of the non-`index` elements are zeroed `Tensor`s of the same size. This is true regardless of the
Expand Down Expand Up @@ -731,7 +733,7 @@ Example 2:

> gradInput = nn.SelectTable("A"):backward(input, torch.randn(2, 3))

> gradInput
> gradInput
{
A : DoubleTensor - size: 2x3
B : DoubleTensor - size: 2x1
Expand Down Expand Up @@ -811,11 +813,11 @@ Example 3:

`module` = `NarrowTable(offset [, length])`

Creates a module that takes a `table` as input and outputs the subtable
Creates a module that takes a `table` as input and outputs the subtable
starting at index `offset` having `length` elements (defaults to 1 element).
The elements can be either a `table` or a [`Tensor`](https://github.com/torch/torch7/blob/master/doc/tensor.md#tensor).

The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size.
The gradients of the elements not included in the subtable are zeroed `Tensor`s of the same size.
This is true regardless of the depth of the encapsulated `Tensor` as the function used internally to do so is recursive.

Example:
Expand Down Expand Up @@ -883,6 +885,36 @@ gives the output:
}
```

<a name='nn.ZipTable'></a>
## ZipTable ##

```lua
module = nn.ZipTable()
```

Transposes a table of tables: the `j`-th element of the `i`-th input table becomes the `i`-th element of the `j`-th output table.

Example:
```lua
print(module:forward{ {'a1','a2'}, {'b1','b2'}, {'c1','c2'} })
{ {'a1','b1','c1'}, {'a2','b2','c2'} }
```

<a name='nn.ZipTableOneToMany'></a>
## ZipTableOneToMany ##

```lua
module = nn.ZipTableOneToMany()
```

Zips a table of an element `el` and a table of elements `tab` into a table of tables, where the `i`-th inner table contains the element `el` and the `i`-th element of `tab`.

Example:
```lua
print(module:forward{ 'el', {'a','b','c'} })
{ {'el','a'}, {'el','b'}, {'el','c'} }
```

<a name="nn.PairwiseDistance"></a>
## PairwiseDistance ##

Expand Down
2 changes: 2 additions & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ require('nn.CriterionTable')
require('nn.FlattenTable')
require('nn.NarrowTable')
require('nn.MapTable')
require('nn.ZipTable')
require('nn.ZipTableOneToMany')

require('nn.Criterion')
require('nn.MSECriterion')
Expand Down
48 changes: 48 additions & 0 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8548,6 +8548,54 @@ function nntest.ZeroGrad()
mytester:assertTensorEq(gradInput, gradInput2, 0.0000001)
end

function nntest.ZipTable()
   -- forward transposes: { {a1,a2}, {b1,b2}, {c1,c2} } -> { {a1,b1,c1}, {a2,b2,c2} }
   local zip = nn.ZipTable()
   local input = {}
   for i = 1, 3 do
      input[i] = {torch.randn(3,4), torch.randn(3,4)}
   end
   local output = zip:forward(input)
   mytester:assert(#output == 2, "ZipTable #output")
   mytester:assert(#(output[1]) == 3, "ZipTable #output[1]")
   mytester:assertTensorEq(input[1][1], output[1][1], 0.000001, "ZipTable input11")
   mytester:assertTensorEq(input[1][2], output[2][1], 0.000001, "ZipTable input12")
   mytester:assertTensorEq(input[3][2], output[2][3], 0.000001, "ZipTable input32")
   -- backward transposes back, so feeding the output as gradOutput must
   -- reproduce the input layout exactly
   local gradInput = zip:backward(input, output)
   mytester:assert(#gradInput == 3, "ZipTable #gradInput")
   mytester:assert(#(gradInput[1]) == 2, "ZipTable #gradInput[1]")
   mytester:assertTensorEq(input[1][1], gradInput[1][1], 0.000001, "ZipTable gradInput11")
   mytester:assertTensorEq(input[1][2], gradInput[1][2], 0.000001, "ZipTable gradInput12")
   mytester:assertTensorEq(input[3][2], gradInput[3][2], 0.000001, "ZipTable gradInput32")
end

function nntest.ZipTableOneToMany()
   -- input : { v, {a,b,c} }
   -- output : { {v,a}, {v,b}, {v,c} }
   local z = nn.ZipTableOneToMany()
   local input = { torch.randn(3), { torch.randn(4), torch.rand(4), torch.rand(4) } }
   local output = z:forward(input)
   mytester:assert(#output == 3, "ZipTableOneToMany #output")
   mytester:assert(#(output[1]) == 2, "ZipTableOneToMany #output[1]")
   mytester:assert(#(output[2]) == 2, "ZipTableOneToMany #output[2]")
   mytester:assert(#(output[3]) == 2, "ZipTableOneToMany #output[3]")
   -- the single element appears as the first entry of every output pair
   mytester:assertTensorEq(input[1], output[1][1], 0.000001, "ZipTableOneToMany input1 output11")
   mytester:assertTensorEq(input[1], output[2][1], 0.000001, "ZipTableOneToMany input1 output21")
   mytester:assertTensorEq(input[1], output[3][1], 0.000001, "ZipTableOneToMany input1 output31")
   mytester:assertTensorEq(input[2][1], output[1][2], 0.000001, "ZipTableOneToMany input21")
   mytester:assertTensorEq(input[2][2], output[2][2], 0.000001, "ZipTableOneToMany input22")
   mytester:assertTensorEq(input[2][3], output[3][2], 0.000001, "ZipTableOneToMany input23")
   -- feeding the output as gradOutput: table-entry gradients pass through
   local gradInput = z:backward(input, output)
   mytester:assert(#gradInput == 2, "ZipTableOneToMany #gradInput")
   mytester:assert(#(gradInput[2]) == 3, "ZipTableOneToMany #gradInput[2]")
   mytester:assertTensorEq(input[2][1], gradInput[2][1], 0.000001, "ZipTableOneToMany gradInput21")
   mytester:assertTensorEq(input[2][2], gradInput[2][2], 0.000001, "ZipTableOneToMany gradInput22")
   -- fixed assertion label: this checks gradInput[2][3], not gradInput[3][2]
   mytester:assertTensorEq(input[2][3], gradInput[2][3], 0.000001, "ZipTableOneToMany gradInput23")
   -- the element's gradient is accumulated over all 3 pairs, so it equals 3*v
   -- fixed assertion label: this checks gradInput[1] (was a duplicate "gradInput21")
   mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput1")
end


mytester:add(nntest)

Expand Down

0 comments on commit d1f66cb

Please sign in to comment.