nn.ModuleCriterion #1232

Merged: 1 commit, May 26, 2017
44 changes: 44 additions & 0 deletions ModuleCriterion.lua
@@ -0,0 +1,44 @@
local ModuleCriterion, parent = torch.class("nn.ModuleCriterion", "nn.Criterion")

function ModuleCriterion:__init(criterion, inputModule, targetModule, castTarget)
   self.inputModule = inputModule
   self.targetModule = targetModule
   self.castTarget = (castTarget == nil) and true or castTarget
   if self.inputModule then
      local params = self.inputModule:parameters()
      if params and #params > 0 then
         print("Warning: nn.ModuleCriterion doesn't support parameter updates")
      end
   end
   self.criterion = criterion
end

function ModuleCriterion:updateOutput(input, target)
   if self.inputModule then
      self.input = self.inputModule:forward(input)
   end
   if self.targetModule then
      self.target = self.targetModule:forward(target)
   end
   self.output = self.criterion:forward(self.input or input, self.target or target)
   return self.output
end

function ModuleCriterion:updateGradInput(input, target)
   self.gradInput = self.criterion:backward(self.input or input, self.target or target)
   if self.inputModule then
      self.gradInput = self.inputModule:backward(input, self.gradInput)
   end
   return self.gradInput
end

function ModuleCriterion:type(type, typecache)
   if self.inputModule then
      self.inputModule:type(type, typecache)
   end
   if self.castTarget and self.targetModule then
      self.targetModule:type(type, typecache)
   end
   self.criterion:type(type, typecache)
   return parent.type(self, type, typecache)
end
15 changes: 15 additions & 0 deletions doc/criterion.md
@@ -29,6 +29,7 @@ target, they compute a gradient according to a given loss function.
* [`MultiCriterion`](#nn.MultiCriterion) : a weighted sum of other criterions each applied to the same input and target;
* [`ParallelCriterion`](#nn.ParallelCriterion) : a weighted sum of other criterions each applied to a different input and target;
* [`MarginRankingCriterion`](#nn.MarginRankingCriterion): ranks two inputs;
* [`ModuleCriterion`](#nn.ModuleCriterion) : adds an optional `inputModule` and `targetModule` before a decorated criterion;

<a name="nn.Criterion"></a>
## Criterion ##
@@ -877,3 +878,17 @@ for i = 1, 100 do
end
end
```

<a name='nn.ModuleCriterion'></a>
## ModuleCriterion ##

```lua
criterion = nn.ModuleCriterion(criterion [, inputModule, targetModule, castTarget])
```

This criterion decorates a `criterion` by allowing the `input` and `target` to be
fed through an optional `inputModule` and `targetModule` before being passed to the
`criterion`. The `inputModule` should not contain learnable parameters, as they would
not be updated.

When `castTarget = true` (the default), the `targetModule` is cast along with the
`inputModule` and `criterion` when `type()` is called. Otherwise, the `targetModule`
keeps its current type.
1 change: 1 addition & 0 deletions init.lua
@@ -201,6 +201,7 @@ require('nn.BCECriterion')
require('nn.CrossEntropyCriterion')
require('nn.ParallelCriterion')
require('nn.DistanceRatioCriterion')
require('nn.ModuleCriterion')

require('nn.PixelShuffle')

18 changes: 18 additions & 0 deletions test.lua
@@ -2175,7 +2175,25 @@ function nntest.MarginRankingCriterion()
   local v = torch.rand(2, batch_size)
   local t = torch.Tensor(batch_size):random(0,1):mul(2):add(-1)
   criterionJacobianTest1DTable(crit,v,t)
end

function nntest.ModuleCriterion()
   local input = torch.randn(8,4)
   local target = torch.randn(8,4)
   local inputModule = nn.Tanh()
   local criterion = nn.MSECriterion()
   local mc = nn.ModuleCriterion(criterion, inputModule)

   local err = mc:forward(input, target)
   local gradInput = mc:backward(input, target)

   local output = inputModule:forward(input)
   local err2 = criterion:forward(output, target)
   local gradOutput = criterion:backward(output, target)
   local gradInput2 = inputModule:backward(input, gradOutput)

   mytester:assert(err == err2, "ModuleCriterion forward err")
   mytester:assertTensorEq(gradInput, gradInput2, 0.000001, "ModuleCriterion backward err")
end

function nntest.MaskedSelect()