diff --git a/SpatialMaxPooling.lua b/SpatialMaxPooling.lua index 29274b216..8475b13f5 100644 --- a/SpatialMaxPooling.lua +++ b/SpatialMaxPooling.lua @@ -30,6 +30,11 @@ end function SpatialMaxPooling:updateOutput(input) self.indices = self.indices or input.new() + + local dims = input:dim() + self.iheight = input:size(dims-1) + self.iwidth = input:size(dims) + -- backward compatibility self.ceil_mode = self.ceil_mode or false self.padW = self.padW or 0 diff --git a/SpatialMaxUnpooling.lua b/SpatialMaxUnpooling.lua index d88b1dd4c..408bcc052 100644 --- a/SpatialMaxUnpooling.lua +++ b/SpatialMaxUnpooling.lua @@ -2,16 +2,9 @@ local SpatialMaxUnpooling, parent = torch.class('nn.SpatialMaxUnpooling', 'nn.Mo function SpatialMaxUnpooling:__init(poolingModule) parent.__init(self) - assert(torch.type(poolingModule)=='nn.SpatialMaxPooling', 'Argument must be a nn.SPatialMaxPooling module') + assert(torch.type(poolingModule)=='nn.SpatialMaxPooling', 'Argument must be a nn.SpatialMaxPooling module') assert(poolingModule.kH==poolingModule.dH and poolingModule.kW==poolingModule.dW, "The size of pooling module's kernel must be equal to its stride") self.pooling = poolingModule - - poolingModule.updateOutput = function(pool, input) - local dims = input:dim() - pool.iheight = input:size(dims-1) - pool.iwidth = input:size(dims) - return nn.SpatialMaxPooling.updateOutput(pool, input) - end end function SpatialMaxUnpooling:setParams() diff --git a/VolumetricMaxPooling.lua b/VolumetricMaxPooling.lua index 036f2c860..fd652310e 100644 --- a/VolumetricMaxPooling.lua +++ b/VolumetricMaxPooling.lua @@ -36,6 +36,11 @@ function VolumetricMaxPooling:floor() end function VolumetricMaxPooling:updateOutput(input) + local dims = input:dim() + self.itime = input:size(dims-2) + self.iheight = input:size(dims-1) + self.iwidth = input:size(dims) + self.indices = self.indices or input.new() input.THNN.VolumetricMaxPooling_updateOutput( input:cdata(), diff --git a/VolumetricMaxUnpooling.lua b/VolumetricMaxUnpooling.lua index 1bb04ed18..6291f5b85 100644 --- a/VolumetricMaxUnpooling.lua +++ b/VolumetricMaxUnpooling.lua @@ -5,14 +5,6 @@ function VolumetricMaxUnpooling:__init(poolingModule) assert(torch.type(poolingModule)=='nn.VolumetricMaxPooling', 'Argument must be a nn.VolumetricMaxPooling module') assert(poolingModule.kT==poolingModule.dT and poolingModule.kH==poolingModule.dH and poolingModule.kW==poolingModule.dW, "The size of pooling module's kernel must be equal to its stride") self.pooling = poolingModule - - poolingModule.updateOutput = function(pool, input) - local dims = input:dim() - pool.itime = input:size(dims-2) - pool.iheight = input:size(dims-1) - pool.iwidth = input:size(dims) - return nn.VolumetricMaxPooling.updateOutput(pool, input) - end end function VolumetricMaxUnpooling:setParams()