Do not call clearState on model while training
Calling clearState() seems to cause issues that, after 4-5 days of
debugging, I haven't been able to fix. See, for example:

torch/nn#1141
torch/cunn#441

Further, it's unclear to me whether `getParameters` and memory management in
general work well when a call to `clearState` can destroy modules (and
therefore weight tensors). The easiest solution to all of this is simply
to never call `clearState` on the model while it is training.

When saving the model, we create a copy of it on the CPU and call
`clearState` on this CPU copy, which we then save to disk.
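
In code, the save path boils down to the following pattern (a minimal sketch of what the diff below implements; `model` is the network after unwrapping any `nn.DataParallelTable`):

```lua
-- Copy the network's table structure: deep_copy creates new module tables
-- but shares the underlying tensors. :float() then allocates fresh CPU
-- tensors for the copy, so the live (possibly GPU) model is untouched, and
-- clearState() frees intermediate buffers on the copy only.
local cpu_model = Trainer.deep_copy(model):float():clearState()
torch.save(paths.concat(directory, 'model_' .. epoch .. '.t7'), cpu_model)
```

Because `clearState` is only ever called on the throwaway CPU copy, training can continue on the original model without invalidating the flattened parameter tensors from `getParameters`.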
achalddave committed Feb 27, 2017
1 parent 0944aaa commit c6b93fe
Showing 1 changed file with 20 additions and 5 deletions.
25 changes: 20 additions & 5 deletions trainer.lua
@@ -141,18 +141,35 @@ function Trainer:evaluate_epoch(epoch, num_batches)
     self:_train_or_evaluate_epoch(epoch, num_batches, false --[[train_mode]])
 end
 
+function Trainer.static.deep_copy(tbl)
+    -- Copied from fb.resnet.torch repo.
+    -- Creates a copy of a network with new modules and the same tensors.
+    local copy = {}
+    for k, v in pairs(tbl) do
+        if type(v) == 'table' then
+            copy[k] = Trainer.deep_copy(v)
+        else
+            copy[k] = v
+        end
+    end
+    if torch.typename(tbl) then
+        torch.setmetatable(copy, torch.typename(tbl))
+    end
+    return copy
+end
+
 function Trainer:save(directory, epoch)
     --[[
     Save model, optimization config, and optimization state to a directory.
     ]]--
     -- Clear intermediate states in the model before saving to disk to minimize
     -- disk space usage.
-    self.model:clearState()
     local model = self.model
-    if torch.isTypeOf(self.model, 'nn.DataParallelTable') then
+    if torch.isTypeOf(model, 'nn.DataParallelTable') then
         model = model:get(1)
     end
-    torch.save(paths.concat(directory, 'model_' .. epoch .. '.t7'), model)
+    local cpu_model = Trainer.deep_copy(model):float():clearState()
+    torch.save(paths.concat(directory, 'model_' .. epoch .. '.t7'), cpu_model)
     torch.save(paths.concat(directory, 'optim_config_' .. epoch .. '.t7'),
                self.optimization_config)
     torch.save(paths.concat(directory, 'optim_state_' .. epoch .. '.t7'),
@@ -207,7 +224,6 @@ end
 
 function Trainer:_train_or_evaluate_epoch(epoch, num_batches, train_mode)
     if train_mode then
-        self.model:clearState()
         self.model:training()
         self:update_optim_config(epoch)
     else
@@ -502,7 +518,6 @@ end
 function SequentialTrainer:_train_or_evaluate_epoch(
         epoch, num_sequences, train_mode)
     if train_mode then
-        self.model:clearState()
         self.model:training()
         self:update_optim_config(epoch)
     else
