From 9cffea51a9a5065342d6ea80f965d5fad32729e1 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Wed, 13 Apr 2016 17:20:46 +0200
Subject: [PATCH] Remove unnecessary function override in unpooling modules
 (#749)

Doing so breaks deserialization on systems with different architectures.
These operations are cheap, so it's not a problem to perform them even when
there's no unpooling associated.
---
 SpatialMaxPooling.lua      | 5 +++++
 SpatialMaxUnpooling.lua    | 9 +--------
 VolumetricMaxPooling.lua   | 5 +++++
 VolumetricMaxUnpooling.lua | 8 --------
 4 files changed, 11 insertions(+), 16 deletions(-)

diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua
index 29274b216..8475b13f5 100644
--- a/SpatialMaxPooling.lua
+++ b/SpatialMaxPooling.lua
@@ -30,6 +30,11 @@ end
 
 function SpatialMaxPooling:updateOutput(input)
    self.indices = self.indices or input.new()
+
+   local dims = input:dim()
+   self.iheight = input:size(dims-1)
+   self.iwidth = input:size(dims)
+
    -- backward compatibility
    self.ceil_mode = self.ceil_mode or false
    self.padW = self.padW or 0
diff --git a/SpatialMaxUnpooling.lua b/SpatialMaxUnpooling.lua
index d88b1dd4c..408bcc052 100644
--- a/SpatialMaxUnpooling.lua
+++ b/SpatialMaxUnpooling.lua
@@ -2,16 +2,9 @@ local SpatialMaxUnpooling, parent = torch.class('nn.SpatialMaxUnpooling', 'nn.Mo
 
 function SpatialMaxUnpooling:__init(poolingModule)
    parent.__init(self)
-   assert(torch.type(poolingModule)=='nn.SpatialMaxPooling', 'Argument must be a nn.SPatialMaxPooling module')
+   assert(torch.type(poolingModule)=='nn.SpatialMaxPooling', 'Argument must be a nn.SpatialMaxPooling module')
    assert(poolingModule.kH==poolingModule.dH and poolingModule.kW==poolingModule.dW, "The size of pooling module's kernel must be equal to its stride")
    self.pooling = poolingModule
-
-   poolingModule.updateOutput = function(pool, input)
-      local dims = input:dim()
-      pool.iheight = input:size(dims-1)
-      pool.iwidth = input:size(dims)
-      return nn.SpatialMaxPooling.updateOutput(pool, input)
-   end
 end
 
 function SpatialMaxUnpooling:setParams()
diff --git a/VolumetricMaxPooling.lua b/VolumetricMaxPooling.lua
index 036f2c860..fd652310e 100644
--- a/VolumetricMaxPooling.lua
+++ b/VolumetricMaxPooling.lua
@@ -36,6 +36,11 @@ function VolumetricMaxPooling:floor()
 end
 
 function VolumetricMaxPooling:updateOutput(input)
+   local dims = input:dim()
+   self.itime = input:size(dims-2)
+   self.iheight = input:size(dims-1)
+   self.iwidth = input:size(dims)
+
    self.indices = self.indices or input.new()
    input.THNN.VolumetricMaxPooling_updateOutput(
       input:cdata(),
diff --git a/VolumetricMaxUnpooling.lua b/VolumetricMaxUnpooling.lua
index 1bb04ed18..6291f5b85 100644
--- a/VolumetricMaxUnpooling.lua
+++ b/VolumetricMaxUnpooling.lua
@@ -5,14 +5,6 @@ function VolumetricMaxUnpooling:__init(poolingModule)
    assert(torch.type(poolingModule)=='nn.VolumetricMaxPooling', 'Argument must be a nn.VolumetricMaxPooling module')
    assert(poolingModule.kT==poolingModule.dT and poolingModule.kH==poolingModule.dH and poolingModule.kW==poolingModule.dW, "The size of pooling module's kernel must be equal to its stride")
    self.pooling = poolingModule
-
-   poolingModule.updateOutput = function(pool, input)
-      local dims = input:dim()
-      pool.itime = input:size(dims-2)
-      pool.iheight = input:size(dims-1)
-      pool.iwidth = input:size(dims)
-      return nn.VolumetricMaxPooling.updateOutput(pool, input)
-   end
 end
 
 function VolumetricMaxUnpooling:setParams()
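
A minimal usage sketch, not part of the patch (the 2x2 pooling and the 1x3x8x8
input are illustrative). It shows why the override is no longer needed: after
this change nn.SpatialMaxPooling records its input height and width (iheight,
iwidth) on every forward pass, so a paired nn.SpatialMaxUnpooling can read them
directly instead of monkey-patching updateOutput with a per-instance closure,
which, as the commit message notes, did not deserialize portably across
architectures.

   require 'nn'

   local pool = nn.SpatialMaxPooling(2, 2, 2, 2)   -- kernel size must equal stride for unpooling
   local unpool = nn.SpatialMaxUnpooling(pool)     -- unpooling is paired with the pooling module

   local input = torch.rand(1, 3, 8, 8)
   local pooled = pool:forward(input)              -- now also sets pool.iheight = 8, pool.iwidth = 8
   local restored = unpool:forward(pooled)         -- reads pool.indices and the recorded input sizes

   print(restored:size())                          -- 1x3x8x8, same shape as the original input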