Commit 649f392
Merge pull request #1258 from nicholas-leonard/container-fix
parameters() uses torch.type instead of type
nicholas-leonard authored Jul 11, 2017
2 parents 4bd94cb + 0aeb67b commit 649f392
Showing 11 changed files with 36 additions and 36 deletions.
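
The change this merge makes everywhere is mechanical but worth spelling out: Lua's built-in `type` reports only primitive types, whereas `torch.type` resolves the registered class name of torch objects, and torch class instances (including every nn module) are implemented as Lua tables. A minimal sketch of the difference (printed names assume the default DoubleTensor type):

   require 'nn'

   local t = torch.Tensor(2)
   print(type(t))        --> userdata            (Lua primitive type)
   print(torch.type(t))  --> torch.DoubleTensor  (registered class name)

   local m = nn.Linear(2, 3)
   print(type(m))        --> table               (modules are Lua tables)
   print(torch.type(m))  --> nn.Linear

So `type(x) == 'table'` is also true of any module instance, while `torch.type(x) == 'table'` matches only raw tables.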
2 changes: 1 addition & 1 deletion Bilinear.lua
@@ -2,7 +2,7 @@ local Bilinear, parent = torch.class('nn.Bilinear', 'nn.Module')

 local function isint(x) return type(x) == 'number' and x == math.floor(x) end
 function Bilinear:__assertInput(input)
-   assert(input and type(input) == 'table' and #input == 2,
+   assert(input and torch.type(input) == 'table' and #input == 2,
       'input should be a table containing two data Tensors')
    assert(input[1]:nDimension() == 2 and input[2]:nDimension() == 2,
       'input Tensors should be two-dimensional')
4 changes: 2 additions & 2 deletions Container.lua
@@ -105,7 +105,7 @@ end

 function Container:parameters()
    local function tinsert(to, from)
-      if type(from) == 'table' then
+      if torch.type(from) == 'table' then
         for i=1,#from do
            tinsert(to,from[i])
         end
@@ -131,7 +131,7 @@ function Container:clearState()
       if self[f] then
          if torch.isTensor(self[f]) then
            self[f] = self[f].new()
-         elseif type(self[f]) == 'table' then
+         elseif torch.type(self[f]) == 'table' then
            self[f] = {}
         else
            self[f] = nil
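
My reading of why this matters for parameters(), not stated in the commit: tinsert flattens nested parameter tables, and since torch class instances are themselves Lua tables, the old `type(from) == 'table'` test would also descend into any class instance that ends up in a parameters list. `torch.type` reports those under their class name, so only raw tables are flattened.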
2 changes: 1 addition & 1 deletion DontCast.lua
@@ -19,7 +19,7 @@ local function recursiveTypeCopy(dst, src, type_str)
 end

 local function tableTensorType(src)
-   if type(src) == 'table' then
+   if type(src) == 'table' then -- Note: don't use torch.type here
      local type_str, found
      for k,v in pairs(src) do
         type_str, found = tableTensorType(v)
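
This is the one place the commit deliberately keeps Lua's `type`, per the new comment. Presumably the intent is that tableTensorType should keep descending into torch class instances while hunting for a tensor type; under `torch.type` those report their class name rather than 'table', which would stop the recursion.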
16 changes: 8 additions & 8 deletions FlattenTable.lua
@@ -12,7 +12,7 @@ end
 local function flatten(output, input)
   local input_map -- has the same structure as input, but stores the
                   -- indices to the corresponding output
-  if type(input) == 'table' then
+  if torch.type(input) == 'table' then
    input_map = {}
    -- forward DFS order
    for i = 1, #input do
@@ -30,8 +30,8 @@ local function checkMapping(output, input, input_map)
  if input_map == nil or output == nil or input == nil then
    return false
  end
- if type(input) == 'table' then
-   if type(input_map) ~= 'table' then
+ if torch.type(input) == 'table' then
+   if torch.type(input_map) ~= 'table' then
      return false
    end
    if #input ~= #input_map then
@@ -46,7 +46,7 @@ local function checkMapping(output, input, input_map)
    end
    return true
  else
-   if type(input_map) ~= 'number' then
+   if torch.type(input_map) ~= 'number' then
      return false
    end
    return output[input_map] == input
@@ -56,7 +56,7 @@ end
 -- During BPROP we have to build a gradInput with the same shape as the
 -- input. This is a recursive function to build up a gradInput
 local function inverseFlatten(gradOutput, input_map)
- if type(input_map) == 'table' then
+ if torch.type(input_map) == 'table' then
    local gradInput = {}
    for i = 1, #input_map do
      gradInput[#gradInput + 1] = inverseFlatten(gradOutput, input_map[i])
@@ -68,7 +68,7 @@ local function inverseFlatten(gradOutput, input_map)
 end

 function FlattenTable:updateOutput(input)
-  assert(type(input) == 'table', 'input must be a table')
+  assert(torch.type(input) == 'table', 'input must be a table')
   -- to avoid rebuilding the flattened table on every updateOutput call
   -- we will do a DFS pass over the existing output table and the inputs to
   -- see if it needs to be rebuilt.
@@ -80,8 +80,8 @@ function FlattenTable:updateOutput(input)
 end

 function FlattenTable:updateGradInput(input, gradOutput)
-  assert(type(input) == 'table', 'input must be a table')
-  assert(type(input) == 'table', 'gradOutput must be a table')
+  assert(torch.type(input) == 'table', 'input must be a table')
+  assert(torch.type(input) == 'table', 'gradOutput must be a table')
   -- If the input changes between the updateOutput and updateGradInput call,
   -- then we may have to rebuild the input_map! However, let's assume that
   -- the input_map is valid and that forward has already been called.
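
To make the mapping machinery above concrete: nn.FlattenTable flattens an arbitrarily nested table of tensors in forward DFS order, and input_map records where each output entry came from so inverseFlatten can rebuild a gradInput of the input's shape. A minimal sketch (tensor values are arbitrary):

   require 'nn'

   local m = nn.FlattenTable()
   local x, y, z = torch.Tensor{1}, torch.Tensor{2}, torch.Tensor{3}
   local out = m:forward({x, {y, {z}}})
   -- out is {x, y, z}: nesting removed, DFS order preserved
   print(#out) --> 3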
2 changes: 1 addition & 1 deletion Identity.lua
@@ -17,7 +17,7 @@ function Identity:clearState()
       if self[f] then
          if torch.isTensor(self[f]) then
            self[f] = self[f].new()
-         elseif type(self[f]) == 'table' then
+         elseif torch.type(self[f]) == 'table' then
            self[f] = {}
         else
            self[f] = nil
4 changes: 2 additions & 2 deletions IndexLinear.lua
@@ -73,7 +73,7 @@ function IndexLinear:reset(stdv)
 end

 function IndexLinear:reshapeInput(input)
-   assert(type(input) == 'table')
+   assert(torch.type(input) == 'table')

    local ninputs = 0
    for _, v in ipairs(input) do
@@ -108,7 +108,7 @@ function IndexLinear:reshapeInput(input)
    -- { torch.LongTensor(size1), torch.LongTensor(size2), ..., torch.LongTensor(sizeN) }, -- batch of keys
    -- { torch.Tensor(size1), torch.Tensor(size2), ..., torch.Tensor(sizeN) }, -- batch of values,
    -- }
-   if type(keys) == 'table' and type(values) == 'table' then
+   if torch.type(keys) == 'table' and torch.type(values) == 'table' then
       lkeys, lvalues = keys, values
       self.isFlat = false
       self.noBatch = false
12 changes: 6 additions & 6 deletions SparseLinear.lua
@@ -15,7 +15,7 @@ function SparseLinear:__init(inputSize, outputSize, doGradInput)
    self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
    self.gradBias = torch.Tensor(outputSize):zero()

-   assert(type(self.doGradInput) == type(true))
+   assert(type(self.doGradInput) == 'boolean')

    self.lastInput = nil
    self.sparseUpdate = NO_LAST_INPUT
@@ -39,7 +39,7 @@ function SparseLinear:reset(stdv)
 end

 function SparseLinear:reshapeInput(input)
-   if type(input) == 'table' then
+   if torch.type(input) == 'table' then
      return input, true, false
   else
      if input:dim() == 2 then
@@ -57,7 +57,7 @@ function SparseLinear:updateOutput(input)
    local input, batchMode, legacyMode = self:reshapeInput(input)
    self.legacyMode = legacyMode

-   if legacyMode then
+   if legacyMode then
      input.THNN.SparseLinear_legacyUpdateOutput(
         input:cdata(),
         self.output:cdata(),
@@ -149,8 +149,8 @@ function SparseLinear:accGradParameters(input, gradOutput, scale)
 end

 function SparseLinear:updateGradInput(input, gradOutput)
-   if self.legacyMode then
-      if type(self.gradInput) ~= type(gradOutput) then self.gradInput = gradOutput.new() end
+   if self.legacyMode then
+      if torch.type(self.gradInput) ~= torch.type(gradOutput) then self.gradInput = gradOutput.new() end
      self.gradInput:resizeAs(input)
   else
      self.gradInput = {}
@@ -185,7 +185,7 @@ function SparseLinear:updateGradInput(input, gradOutput)
    return self.gradInput
 end

--- These functions do sparse updates / zeros. However, if we accumulated
+-- These functions do sparse updates / zeros. However, if we accumulated
 -- gradients multiple times, we can't depend on the last input to do sparse
 -- updates.
 function SparseLinear:updateParameters(learningRate)
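
Two details here are easy to miss: `type(true)` always evaluates to the string 'boolean', so the rewritten assert is behaviorally identical and simply says what it means; and the `if legacyMode then` / `-- These functions do sparse updates` hunks appear to change only trailing whitespace.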
10 changes: 5 additions & 5 deletions SpatialFullConvolution.lua
@@ -72,7 +72,7 @@ function SpatialFullConvolution:updateOutput(input)

    -- The input can be a table where the second element indicates the target
    -- output size, in which case the adj factors are computed automatically
-   if type(inputTensor) == 'table' then
+   if torch.type(inputTensor) == 'table' then
      inputTensor = input[1]
      local targetTensor = input[2]
      local tDims = targetTensor:dim()
@@ -113,7 +113,7 @@ function SpatialFullConvolution:updateGradInput(input, gradOutput)

    -- The input can be a table where the second element indicates the target
    -- output size, in which case the adj factors are computed automatically
-   if type(inputTensor) == 'table' then
+   if torch.type(inputTensor) == 'table' then
      inputTensor = input[1]
      local targetTensor = input[2]
      local tDims = targetTensor:dim()
@@ -122,7 +122,7 @@ function SpatialFullConvolution:updateGradInput(input, gradOutput)
      adjW = calculateAdj(tW, self.kW, self.padW, self.dW)
      adjH = calculateAdj(tH, self.kH, self.padH, self.dH)
      -- Momentarily extract the gradInput tensor
-     if type(self.gradInput) == 'table' then
+     if torch.type(self.gradInput) == 'table' then
        self.gradInput = self.gradInput[1] or inputTensor.new()
      end
   end
@@ -139,7 +139,7 @@ function SpatialFullConvolution:updateGradInput(input, gradOutput)
       adjW, adjH
    )

-   if type(input) == 'table' then
+   if torch.type(input) == 'table' then
      -- Create a zero tensor to be expanded and used as gradInput[2].
      self.zeroScalar = self.zeroScalar or input[2].new(1):zero()
      self.ones:resize(input[2]:dim()):fill(1)
@@ -162,7 +162,7 @@ function SpatialFullConvolution:accGradParameters(input, gradOutput, scale)

    -- The input can be a table where the second element indicates the target
    -- output size, in which case the adj factors are computed automatically
-   if type(inputTensor) == 'table' then
+   if torch.type(inputTensor) == 'table' then
      inputTensor = input[1]
      local targetTensor = input[2]
      local tDims = targetTensor:dim()
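
For context on the table-input convention these hunks test for: passing {input, target} pins the module's output size to the target's trailing dimensions by deriving the adj factors. A minimal sketch, with hyper-parameters and sizes that are illustrative assumptions rather than anything from the commit:

   require 'nn'

   -- nInputPlane=16, nOutputPlane=8, kW=kH=4, dW=dH=2, padW=padH=1
   local conv = nn.SpatialFullConvolution(16, 8, 4, 4, 2, 2, 1, 1)
   local input = torch.Tensor(16, 32, 32):normal()

   local out1 = conv:forward(input)           -- 8 x 64 x 64 (default adj = 0)

   -- Only the target's trailing sizes are read; its values are unused.
   local target = torch.Tensor(8, 65, 65)
   local out2 = conv:forward({input, target}) -- 8 x 65 x 65 (adjW = adjH = 1)

In this mode gradInput becomes a table as well, with a zero tensor expanded as gradInput[2], which is what the updateGradInput hunks above handle. VolumetricFullConvolution below follows the same pattern with an extra time dimension.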
10 changes: 5 additions & 5 deletions VolumetricFullConvolution.lua
@@ -93,7 +93,7 @@ function VolumetricFullConvolution:updateOutput(input)

    -- The input can be a table where the second element indicates the target
    -- output size, in which case the adj factors are computed automatically
-   if type(inputTensor) == 'table' then
+   if torch.type(inputTensor) == 'table' then
      inputTensor = input[1]
      local targetTensor = input[2]
      local tDims = targetTensor:dim()
@@ -128,7 +128,7 @@ function VolumetricFullConvolution:updateGradInput(input, gradOutput)

    -- The input can be a table where the second element indicates the target
    -- output size, in which case the adj factors are computed automatically
-   if type(inputTensor) == 'table' then
+   if torch.type(inputTensor) == 'table' then
      inputTensor = input[1]
      local targetTensor = input[2]
      local tDims = targetTensor:dim()
@@ -139,7 +139,7 @@ function VolumetricFullConvolution:updateGradInput(input, gradOutput)
      adjW = calculateAdj(tW, self.kW, self.padW, self.dW)
      adjH = calculateAdj(tH, self.kH, self.padH, self.dH)
      -- Momentarily extract the gradInput tensor
-     if type(self.gradInput) == 'table' then
+     if torch.type(self.gradInput) == 'table' then
        self.gradInput = self.gradInput[1]
      end
   end
@@ -156,7 +156,7 @@ function VolumetricFullConvolution:updateGradInput(input, gradOutput)
       adjT, adjW, adjH
    )

-   if type(input) == 'table' then
+   if torch.type(input) == 'table' then
      -- Create a zero tensor to be expanded and used as gradInput[2].
      self.zeroScalar = self.zeroScalar or input[2].new(1):zero()
      self.ones:resize(input[2]:dim()):fill(1)
@@ -177,7 +177,7 @@ function VolumetricFullConvolution:accGradParameters(input, gradOutput, scale)

    -- The input can be a table where the second element indicates the target
    -- output size, in which case the adj factors are computed automatically
-   if type(inputTensor) == 'table' then
+   if torch.type(inputTensor) == 'table' then
      inputTensor = input[1]
      local targetTensor = input[2]
      local tDims = targetTensor:dim()
4 changes: 2 additions & 2 deletions hessian.lua
@@ -216,7 +216,7 @@ function nn.hessian.enable()
    function nn.SpatialConvolution.initDiagHessianParameters(self)
       initDiagHessianParameters(self,{'gradWeight','gradBias'},{'diagHessianWeight','diagHessianBias'})
    end
-
+
    ----------------------------------------------------------------------
    -- SpatialConvolutionLocal
    ----------------------------------------------------------------------
@@ -361,7 +361,7 @@ function nn.hessian.enable()

    function nn.Sequential.parameters(self)
       local function tinsert(to, from)
-         if type(from) == 'table' then
+         if torch.type(from) == 'table' then
            for i=1,#from do
               tinsert(to,from[i])
            end
6 changes: 3 additions & 3 deletions utils.lua
@@ -158,7 +158,7 @@ function nn.utils.addSingletonDimension(...)
   else
      view, t, dim = select(1,...)
      assert(torch.isTensor(view),
-            "output tensor expected, got " .. type(view))
+            "output tensor expected, got " .. torch.type(view))
   end

   assert(torch.isTensor(t), "input tensor expected")
@@ -202,14 +202,14 @@ end
 --    nn.utils.clearState(self, '_buffer', '_buffer2')
 function nn.utils.clear(self, ...)
    local arg = {...}
-   if #arg > 0 and type(arg[1]) == 'table' then
+   if #arg > 0 and torch.type(arg[1]) == 'table' then
      arg = arg[1]
   end
   local function clear(f)
      if self[f] then
         if torch.isTensor(self[f]) then
            self[f]:set()
-        elseif type(self[f]) == 'table' then
+        elseif torch.type(self[f]) == 'table' then
           self[f] = {}
        else
           self[f] = nil
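
A small convenience visible in the second hunk: nn.utils.clear accepts the field names either as varargs or as one table, so both of these clear the same fields (a sketch; the module and field names are arbitrary):

   local m = nn.Linear(10, 10)
   nn.utils.clear(m, '_buffer', '_buffer2')
   nn.utils.clear(m, {'_buffer', '_buffer2'})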
