Commit
Merge pull request #1225 from nicholas-leonard/ZeroGrad
nn.ZeroGrad
nicholas-leonard authored May 24, 2017
2 parents 5913e31 + 1535bd3 commit e40e281
Showing 4 changed files with 47 additions and 1 deletion.
14 changes: 14 additions & 0 deletions ZeroGrad.lua
@@ -0,0 +1,14 @@
local ZeroGrad, parent = torch.class('nn.ZeroGrad', 'nn.Module')

function ZeroGrad:updateOutput(input)
   self.output:set(input)
   return self.output
end

-- The gradient is simply zeroed.
-- Useful when you don't want to backpropagate through certain paths.
function ZeroGrad:updateGradInput(input, gradOutput)
   self.gradInput = nn.utils.recursiveResizeAs(self.gradInput, input)
   self.gradInput = nn.utils.recursiveFill(self.gradInput, 0)
   return self.gradInput
end
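The intended use is as a stop-gradient: place the module at the point where backpropagation should be cut off. Below is a minimal sketch, not part of the commit; the layer sizes and the `zeroGradParameters` call are illustrative assumptions.

```lua
require 'nn'

local net = nn.Sequential()
   :add(nn.Linear(4, 3))   -- receives a zeroed gradOutput: ZeroGrad cuts the path
   :add(nn.ZeroGrad())
   :add(nn.Linear(3, 2))   -- trained as usual

net:zeroGradParameters()
local input = torch.randn(4)
local output = net:forward(input)
net:backward(input, torch.randn(2))

-- the first Linear's accumulated gradient stays zero
print(net:get(1).gradWeight:sum()) -- 0
```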
23 changes: 22 additions & 1 deletion doc/simple.md
@@ -61,6 +61,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
* [WhiteNoise](#nn.WhiteNoise) : adds isotropic Gaussian noise to the signal when in training mode;
* [OneHot](#nn.OneHot) : transforms a tensor of indices into [one-hot](https://en.wikipedia.org/wiki/One-hot) encoding;
* [PrintSize](#nn.PrintSize) : prints the size of `input` and `gradOutput` (useful for debugging);
* [ZeroGrad](#nn.ZeroGrad) : forwards the `input` as-is, yet zeros the `gradInput`.

<a name="nn.Linear"></a>
## Linear ##
@@ -1762,4 +1763,24 @@ module = nn.PrintSize(name)
This module is useful for debugging complicated module composites.
It prints the size of the `input` and `gradOutput` during `forward`
and `backward` propagation respectively.
The `name` is a string used to identify the module alongside the printed size.

<a name='nn.ZeroGrad'></a>
## ZeroGrad ##

```lua
module = nn.ZeroGrad()
input = torch.Tensor{1,2}
gradOutput = torch.Tensor{3,4}
print(module:forward(input))
1
2
[torch.DoubleTensor of size 2]

print(module:backward(input, gradOutput))
0
0
[torch.DoubleTensor of size 2]
```

The module zeros the `gradInput` but forwards the `input` as-is.
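Another common pattern, sketched here under assumed sizes (not from this commit), is to detach one branch from a shared trunk: the auxiliary head still trains its own parameters, but `ZeroGrad` stops its gradient from reaching the trunk.

```lua
require 'nn'

-- two heads over a shared 8-dimensional feature
local heads = nn.ConcatTable()
   :add(nn.Linear(8, 2))        -- main head: its gradient reaches the trunk
   :add(nn.Sequential()
      :add(nn.ZeroGrad())       -- detaches the auxiliary head's input
      :add(nn.Linear(8, 2)))    -- still accumulates its own gradWeight

local net = nn.Sequential()
   :add(nn.Linear(10, 8))       -- trunk: only sees the main head's gradient
   :add(heads)
```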
1 change: 1 addition & 0 deletions init.lua
@@ -63,6 +63,7 @@ require('nn.VolumetricDropout')
require('nn.WhiteNoise')
require('nn.OneHot')
require('nn.PrintSize')
require('nn.ZeroGrad')

require('nn.CAddTable')
require('nn.CDivTable')
10 changes: 10 additions & 0 deletions test.lua
@@ -8538,6 +8538,16 @@ function nntest.OneHot()
end
end

function nntest.ZeroGrad()
   local input = torch.randn(3,4)
   local zg = nn.ZeroGrad()
   local output = zg:forward(input)
   mytester:assertTensorEq(input, output, 0.00000001)
   local gradInput = zg:backward(input, input)
   local gradInput2 = gradInput:clone():zero()
   mytester:assertTensorEq(gradInput, gradInput2, 0.0000001)
end
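Assuming a standard torch/nn checkout, this single test can be run from the Torch REPL; the call below is a hedged usage note, not part of the commit.

```lua
require 'nn'
nn.test{'ZeroGrad'}   -- torch.Tester runs only the named test
```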


mytester:add(nntest)

