Commit

Merge pull request #1227 from nicholas-leonard/Collapse
nn.Collapse
nicholas-leonard authored May 25, 2017
2 parents d1f66cb + eb6548a commit c6f1da5
Showing 4 changed files with 69 additions and 2 deletions.
30 changes: 30 additions & 0 deletions Collapse.lua
@@ -0,0 +1,30 @@
local Collapse, parent = torch.class('nn.Collapse', 'nn.Module')

-- Collapses all non-batch dimensions into one.
function Collapse:__init(nInputDim)
   parent.__init(self)
   self.nInputDim = nInputDim
end

function Collapse:updateOutput(input)
   -- view() requires contiguous memory, so copy a non-contiguous input first
   if not input:isContiguous() then
      self._input = self._input or input.new()
      self._input:resize(input:size()):copy(input)
      input = self._input
   end
   if input:dim() > self.nInputDim then
      -- batched input: keep the first (batch) dimension, flatten the rest
      self.output:view(input, input:size(1), -1)
   else
      -- non-batched input: flatten everything
      self.output:view(input, -1)
   end
   return self.output
end

function Collapse:updateGradInput(input, gradOutput)
   -- the backward pass is just a reshape back to the input's size
   self.gradInput:view(gradOutput, input:size())
   return self.gradInput
end

function Collapse:clearState()
   self._input = nil
end
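The reshape logic of `updateOutput` can be mirrored outside Torch. A minimal NumPy analogue (illustrative only, not part of the commit; the function name `collapse` is hypothetical):

```python
import numpy as np

def collapse(x, n_input_dim):
    """Mirror nn.Collapse:updateOutput.

    If x has more than n_input_dim dimensions, the first dimension is
    treated as the batch dimension and everything after it is flattened;
    otherwise the whole tensor is flattened.
    """
    # Torch's view() requires contiguous storage, hence the module's copy;
    # np.ascontiguousarray plays the same role here.
    x = np.ascontiguousarray(x)
    if x.ndim > n_input_dim:
        return x.reshape(x.shape[0], -1)
    return x.reshape(-1)

batch = np.zeros((8, 3, 4, 5))
print(collapse(batch, 3).shape)     # (8, 60)  -- batched: batch dim kept
print(collapse(batch[0], 3).shape)  # (60,)    -- non-batched: fully flattened
```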
24 changes: 22 additions & 2 deletions doc/simple.md
@@ -61,7 +61,8 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode;
* [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding;
* [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging);
* [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`.
* [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`;
* [Collapse](#nn.Collapse) : just like `nn.View(-1)`.

<a name="nn.Linear"></a>
## Linear ##
@@ -1029,6 +1030,8 @@ Example 2:
[torch.LongStorage of size 2]
```

For collapsing non-batch dims, check out [nn.Collapse](#nn.Collapse).

<a name="nn.Contiguous"></a>
## Contiguous ##

@@ -1783,4 +1786,21 @@ print(module:backward(input, gradOutput))
[torch.DoubleTensor of size 2]
```

The module zeros the `gradInput` but forwards the `input` as-is.

<a name='nn.Collapse'></a>
## Collapse ##

```lua
module = nn.Collapse(nInputDim)
```

This module is the equivalent of:
```lua
view = nn.View(-1)
view:setNumInputDim(nInputDim)
```

It collapses all non-batch dimensions. This is useful for converting a spatial feature map to the single dimension required by a dense hidden layer like `Linear`.
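The spatial-to-dense conversion described above can be sketched in NumPy (an illustration with hypothetical sizes, not code from the commit):

```python
import numpy as np

# Hypothetical sizes: a batch of 8 feature maps of shape 16 x 28 x 28
# coming out of a convolution, flattened for a dense layer.
features = np.random.randn(8, 16, 28, 28)

# What nn.Collapse(3) produces: batch dim kept, the rest flattened.
flat = features.reshape(features.shape[0], -1)

# A Linear layer then only needs a 2D (batch x features) input.
weights = np.random.randn(16 * 28 * 28, 10)
logits = flat @ weights
print(logits.shape)  # (8, 10)
```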
1 change: 1 addition & 0 deletions init.lua
@@ -172,6 +172,7 @@ require('nn.NarrowTable')
require('nn.MapTable')
require('nn.ZipTable')
require('nn.ZipTableOneToMany')
require('nn.Collapse')

require('nn.Criterion')
require('nn.MSECriterion')
16 changes: 16 additions & 0 deletions test.lua
@@ -8596,6 +8596,22 @@ function nntest.ZipTableOneToMany()
mytester:assertTensorEq(torch.mul(input[1], 3), gradInput[1], 0.000001, "ZipTableOneToMany gradInput21")
end

function nntest.Collapse()
   local c = nn.Collapse(3)
   local input = torch.randn(8,3,4,5)
   local output = c:forward(input)
   mytester:assertTensorEq(input:view(8,-1), output, 0.000001, "Collapse:forward")
   local gradInput = c:backward(input, output)
   mytester:assertTensorEq(gradInput, input, 0.000001, "Collapse:backward")
   mytester:assertTableEq(gradInput:size():totable(), input:size():totable(), 0.000001, "Collapse:backward size")
   local input2 = input:transpose(1,4)
   local output2 = c:forward(input2)
   mytester:assertTensorEq(input2:contiguous():view(5,-1), output2, 0.000001, "Collapse:forward non-contiguous")
   local gradInput2 = c:backward(input2, output2)
   mytester:assertTensorEq(gradInput2, input2, 0.000001, "Collapse:backward non-contiguous")
   mytester:assertTableEq(gradInput2:size():totable(), input2:size():totable(), 0.000001, "Collapse:backward size non-contiguous")
end
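The backward behavior these tests check is a pure reshape: the gradient flowing back through the collapsed output is simply viewed with the input's original size. In NumPy terms (a sketch, not code from the commit):

```python
import numpy as np

x = np.random.randn(8, 3, 4, 5)

# Gradient w.r.t. the collapsed (8, 60) output; using x's own values here,
# as the test above does, so the round trip is easy to check.
grad_out = x.reshape(8, -1)

# nn.Collapse:updateGradInput: view the gradient back to the input's size.
grad_in = grad_out.reshape(x.shape)
print(grad_in.shape)  # (8, 3, 4, 5)
```

Since no values are transformed, the recovered gradient is element-for-element identical to what went in.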


mytester:add(nntest)

