-
Notifications
You must be signed in to change notification settings - Fork 958
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1228 from nicholas-leonard/Convert
nn.Convert
- Loading branch information
Showing
4 changed files
with
358 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,245 @@ | ||
------------------------------------------------------------------------ | ||
--[ nn.Convert ]-- | ||
-- Module to convert between different data formats | ||
-- nn.Convert('bchw', 'bf') or nn.Convert('chw', 'f') | ||
-- Automatically converts input to same type as self.output | ||
-- Simplest use is for automatic input type converions : nn.Convert() | ||
------------------------------------------------------------------------ | ||
local _ = require 'moses' | ||
local Convert, parent = torch.class("nn.Convert", "nn.Container") | ||
|
||
function Convert:__init(inputShape, outputShape) | ||
if outputShape and not inputShape then | ||
error"Expecting non-nil arg 1 when arg 2 is provided" | ||
end | ||
inputShape = inputShape or 'b*' | ||
outputShape = outputShape or inputShape | ||
self.inputShape = inputShape:find('b') and inputShape or ('b'..inputShape) | ||
self.outputShape = outputShape:find('b') and outputShape or ('b'..outputShape) | ||
self.inputBatchDim = self.inputShape:find('b') | ||
self.outputBatchDim = self.outputShape:find('b') | ||
if self.inputShape == 'b*' or self.outputShape == 'b*' then | ||
assert(self.inputShape == 'b*' and self.outputShape == 'b*', 'Both or neither shapes must be b*') | ||
self.nInputDim = -1 | ||
self.nOutputDim = -1 | ||
self.transposition = true | ||
else | ||
-- number of dims in batch mode | ||
self.nInputDim = #self.inputShape | ||
self.nOutputDim = #self.outputShape | ||
-- is the outputShape just a transposition of the inputShape? | ||
if self.nInputDim == self.nOutputDim then | ||
self.transposition = true | ||
for i=1,self.nInputDim do | ||
if not self.outputShape:find(self.inputShape:sub(i,i)) then | ||
self.transposition = false | ||
break | ||
end | ||
end | ||
end | ||
end | ||
parent.__init(self) | ||
end | ||
|
||
-- post-initialization | ||
function Convert:buildConverter(input) | ||
if self.transposition then | ||
self.converter = self:transpose(self.outputShape) | ||
else | ||
if (torch.type(self[self.outputShape]) ~= 'function') then | ||
error(string.format("Unrecognized conversion of shape %s to %s", self.inputShape, self.outputShape)) | ||
end | ||
self.converter = self[self.outputShape](self, input) | ||
end | ||
assert(torch.isTensor(self.output), "Expecting Tensor output") | ||
|
||
self.converter:type(torch.type(self.output)) | ||
|
||
self.modules[1] = self.converter | ||
end | ||
|
||
function Convert:updateOutput(input) | ||
assert(torch.isTensor(input), "expecting Tensor") | ||
if not torch.isTypeOf(input, torch.type(self.output)) then | ||
-- handle different input type | ||
self._input = self._input or self.output.new() | ||
self._input:resize(input:size()):copy(input) | ||
input = self._input | ||
end | ||
self.batchMode = true | ||
if input:dim() < self.nInputDim then | ||
-- handle non-batch mode | ||
local inputSize = input:size():totable() | ||
table.insert(inputSize, self.inputBatchDim, 1) | ||
self.__input = self.__input or input.new() | ||
self.__input:set(input):resize(table.unpack(inputSize)) | ||
input = self.__input | ||
self.batchMode = false | ||
end | ||
if not self.converter then | ||
self:buildConverter(input) | ||
end | ||
|
||
self.output = self.converter:updateOutput(input) | ||
|
||
if not self.batchMode then | ||
local outputSize = self.output:size():totable() | ||
table.remove(outputSize, self.outputBatchDim) | ||
self.__output = self.__output or self.output.new() | ||
self.__output:set(self.output):resize(table.unpack(outputSize)) | ||
self.output = self.__output | ||
end | ||
return self.output | ||
end | ||
|
||
function Convert:updateGradInput(input, gradOutput) | ||
local input_ = input | ||
input = self._input or input | ||
if not self.batchMode then | ||
input = self.__input | ||
self.__gradOutput = self.__gradOutput or gradOutput.new() | ||
self.__gradOutput:set(gradOutput):resize(self.converter.output:size()) | ||
gradOutput = self.__gradOutput | ||
end | ||
|
||
local gradInput = self.converter:updateGradInput(input, gradOutput) | ||
|
||
if not self.batchMode then | ||
self.__gradInput = self.__gradInput or gradInput.new() | ||
self.__gradInput:set(gradInput):resize(input_:size()) | ||
gradInput = self.__gradInput | ||
end | ||
if self._input then | ||
self._gradInput = self._gradInput or input.new() | ||
self._gradInput:resize(input:size()):copy(gradInput) | ||
self.gradInput = self._gradInput | ||
else | ||
self.gradInput = gradInput | ||
end | ||
|
||
return self.gradInput | ||
end | ||
|
||
function Convert:accGradParameters(input, gradOutput, scale) | ||
input = self.batchMode and self.__input or self._input or input | ||
gradOutput = self.batchMode and self.__gradOutput or gradOutput | ||
self.converter:accGradParameters(input, gradOutput, scale) | ||
end | ||
|
||
function Convert:accUpdateGradParameters(input, gradOutput, lr) | ||
input = self.batchMode and self.__input or self._input or input | ||
gradOutput = self.batchMode and self.__gradOutput or gradOutput | ||
self.converter:accUpdateGradParameters(input, gradOutput, lr) | ||
end | ||
|
||
-- batch feature | ||
function Convert:bf(input) | ||
local b_pos = self:findAxis('b', self.inputShape) | ||
local dim = #self.inputShape | ||
if self.inputShape == 'bt' then | ||
error"Conversion of shape bt to bf not supported: open an issue on github" | ||
end | ||
-- was b | ||
if dim == 1 then | ||
return nn.Reshape(1) | ||
end | ||
-- was b... | ||
local modula | ||
if b_pos ~= 1 then | ||
modula = nn.Transpose({1, b_pos}) | ||
end | ||
if dim > 2 then | ||
local transpose = modula | ||
local sampleSize = input:select(self:findAxis('b'),1):nElement() | ||
local reshape = nn.Reshape(sampleSize) | ||
if transpose then | ||
modula = nn.Sequential() | ||
modula:add(transpose) | ||
modula:add(reshape) | ||
else | ||
modula = reshape | ||
end | ||
end | ||
return modula or nn.Identity() | ||
end | ||
|
||
-- each example is a scalar; batch is a vector | ||
function Convert:b(input) | ||
local b_pos = self:findAxis('b') | ||
if self.inputShape == 'bt' or self.inputShape == 'tb' then | ||
local t_pos = self:findAxis('t') | ||
-- select first set of classes | ||
return nn.Select(t_pos, 1) | ||
elseif self.inputShape == 'bf' or self.inputShape == 'fb' then | ||
-- this wont work as expected with size(f) > 1 | ||
local f_pos = self:findAxis('f') | ||
if input:size(f_pos) > 1 then | ||
error("Cannot convert shape "..self.inputShape.." to b when feature > 1") | ||
end | ||
return nn.Select(f_pos, 1) | ||
else | ||
error("Cannot convert shape "..self.inputShape.." to shape b") | ||
end | ||
end | ||
|
||
-- returns the current shape of the data | ||
function Convert:default() | ||
return nn.Identity() | ||
end | ||
|
||
-- multi-class (batch target) | ||
function Convert:bt() | ||
local b_pos = self:findAxis('b') | ||
local modula | ||
if self.inputShape == 'b' then | ||
modula = nn.Reshape(1) | ||
else | ||
error("cannot convert shape '"..self.inputShape.."' to bt") | ||
end | ||
return modula | ||
end | ||
|
||
-- a generic function for transposing shape axes | ||
function Convert:transpose(newShape) | ||
if newShape == self.inputShape then | ||
return nn.Identity() | ||
end | ||
local inputShape = {} | ||
for i=1,#self.inputShape do | ||
table.insert(inputShape, self.inputShape:sub(i,i)) | ||
end | ||
local transpositions = {} | ||
for i=1,#newShape do | ||
local j = _.indexOf(inputShape, newShape:sub(i,i)) | ||
if i ~= j then | ||
local char = inputShape[i] | ||
inputShape[i] = inputShape[j] | ||
inputShape[j] = char | ||
table.insert(transpositions, {j, i}) | ||
end | ||
end | ||
return nn.Transpose(table.unpack(transpositions)) | ||
end | ||
|
||
function Convert:findAxis(axis_char, shape, silent) | ||
shape = shape or self.inputShape | ||
local axis_pos = shape:find(axis_char) | ||
if (not silent) and (not axis_pos) then | ||
error("Provided shape '"..shape.."' has no axis '"..axis_char.."'", 2) | ||
end | ||
return axis_pos | ||
end | ||
|
||
function Convert:clearState() | ||
self._input = nil | ||
self._gradInput = nil | ||
self.__input = nil | ||
self.__output = nil | ||
self.__gradInput = nil | ||
self.__gradOutput = nil | ||
end | ||
|
||
function Convert:type(type) | ||
self:clearState() | ||
return parent.type(self, type) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters