Skip to content

Commit

Permalink
Add options for padding batch when video ends
Browse files Browse the repository at this point in the history
  • Loading branch information
achalddave committed Nov 29, 2017
1 parent 1cedac9 commit 093bdab
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 4 deletions.
36 changes: 33 additions & 3 deletions samplers.lua
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,10 @@ end

local SequentialBatchSampler, SequentialBatchSamplerSuper = classic.class(
'SequentialBatchSampler', VideoSampler)
SequentialBatchSampler.ON_VIDEO_END = {
NEW = 'new',
PAD = 'pad'
}
function SequentialBatchSampler:_init(
data_source_obj, sequence_length, step_size, use_boundary_frames,
options)
Expand All @@ -463,12 +467,29 @@ function SequentialBatchSampler:_init(
batch index 1: [frame 1, frame 2]
batch index 2: [frame 2, frame 3]
Default: sequence_length
on_video_end (string): One of "new" or "pad". This
determines how to fill a batch if the current video ends before
the batch is full.
"new" (default): Fill batch with frames from a new video
"pad": Fill batch with copies of the last frame of the current
video.
]]--
SequentialBatchSamplerSuper._init(self, data_source_obj, sequence_length,
step_size, use_boundary_frames, options)

self.stride = self.options.stride == nil and
self.sequence_length or self.options.stride
if self.options.on_video_end == self.ON_VIDEO_END.NEW or
self.options.on_video_end == self.ON_VIDEO_END.PAD then
log.debug('on_video_end set to', self.options.on_video_end)
self.on_video_end = self.options.on_video_end
elseif self.options.on_video_end == nil then
log.debug('on_video_end was nil, setting to "new"')
self.on_video_end = self.ON_VIDEO_END.NEW
else
error(string.format('Unrecognized "on_video_end" option: %s',
self.options.on_video_end))
end

self.videos = Sampler.permute(__.keys(self.video_keys))
self.video_index = 1
Expand All @@ -490,11 +511,20 @@ function SequentialBatchSampler:sample_keys(batch_size)
for _ = 1, self.sequence_length do
table.insert(batch_keys, {})
end
for _ = 1, batch_size do
for batch_index = 1, batch_size do
if not self:_is_valid_start() then
self:advance_video()
assert(self:_is_valid_start())
-- If the batch just started, or the sampler is configured to start
-- a new video on end of old video.
if batch_index == 1 or self.on_video_end == self.ON_VIDEO_END.NEW
then
self:advance_video()
assert(self:_is_valid_start())
elseif self.on_video_end == self.ON_VIDEO_END.PAD then
self.frame_index = self.frame_index - self.stride
assert(self:_is_valid_start())
end
end

local sequence = self:get_sequence(self.videos[self.video_index],
self.frame_index)
for step = 1, #sequence do
Expand Down
4 changes: 3 additions & 1 deletion trainer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,13 @@ function Trainer:_train_or_evaluate_batch(train_mode)
local images, labels = self:_load_batch(data, train_mode)
local loss = 0
local outputs

local function forward_backward()
if train_mode then
self.model:zeroGradParameters()
end
for i = 1, math.ceil(self.batch_size / self.computational_batch_size) do
local num_images = images:size(2)
for i = 1, math.ceil(num_images / self.computational_batch_size) do
local start_index = (i - 1) * self.computational_batch_size + 1
local end_index = math.min(
i * self.computational_batch_size, self.batch_size)
Expand Down

0 comments on commit 093bdab

Please sign in to comment.