From 093bdabe42848a5a4a45b28d7e32d99e7b89b920 Mon Sep 17 00:00:00 2001 From: Achal Dave Date: Wed, 29 Nov 2017 14:55:11 -0500 Subject: [PATCH] Add options for padding batch when video ends --- samplers.lua | 36 +++++++++++++++++++++++++++++++++--- trainer.lua | 4 +++- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/samplers.lua b/samplers.lua index 9987023..78dcae6 100644 --- a/samplers.lua +++ b/samplers.lua @@ -440,6 +440,10 @@ end local SequentialBatchSampler, SequentialBatchSamplerSuper = classic.class( 'SequentialBatchSampler', VideoSampler) +SequentialBatchSampler.ON_VIDEO_END = { + NEW = 'new', + PAD = 'pad' +} function SequentialBatchSampler:_init( data_source_obj, sequence_length, step_size, use_boundary_frames, options) @@ -463,12 +467,29 @@ function SequentialBatchSampler:_init( batch index 1: [frame 1, frame 2] batch index 2: [frame 2, frame 3] Default: sequence_length + on_video_end (string): One of "new" or "pad". This + determines how to fill a batch if the current video ends before + the batch is full. + "new" (default): Fill batch with frames from a new video + "pad": Fill batch with copies of the last frame of the current + video. ]]-- SequentialBatchSamplerSuper._init(self, data_source_obj, sequence_length, step_size, use_boundary_frames, options) self.stride = self.options.stride == nil and self.sequence_length or self.options.stride + if self.options.on_video_end == self.ON_VIDEO_END.NEW or + self.options.on_video_end == self.ON_VIDEO_END.PAD then + log.debug('on_video_end set to', self.options.on_video_end) + self.on_video_end = self.options.on_video_end + elseif self.options.on_video_end == nil then + log.debug('on_video_end was nil, setting to "new"') + self.on_video_end = self.ON_VIDEO_END.NEW + else + error(string.format('Unrecognized "on_video_end" option: %s', + self.options.on_video_end)) + end self.videos = Sampler.permute(__.keys(self.video_keys)) self.video_index = 1 @@ -490,11 +511,20 @@ function SequentialBatchSampler:sample_keys(batch_size) for _ = 1, self.sequence_length do table.insert(batch_keys, {}) end - for _ = 1, batch_size do + for batch_index = 1, batch_size do if not self:_is_valid_start() then - self:advance_video() - assert(self:_is_valid_start()) + -- If the batch just started, or the sampler is configured to start + -- a new video on end of old video. + if batch_index == 1 or self.on_video_end == self.ON_VIDEO_END.NEW + then + self:advance_video() + assert(self:_is_valid_start()) + elseif self.on_video_end == self.ON_VIDEO_END.PAD then + self.frame_index = self.frame_index - self.stride + assert(self:_is_valid_start()) + end end + local sequence = self:get_sequence(self.videos[self.video_index], self.frame_index) for step = 1, #sequence do diff --git a/trainer.lua b/trainer.lua index af53b3a..99a3c52 100644 --- a/trainer.lua +++ b/trainer.lua @@ -200,11 +200,13 @@ function Trainer:_train_or_evaluate_batch(train_mode) local images, labels = self:_load_batch(data, train_mode) local loss = 0 local outputs + local function forward_backward() if train_mode then self.model:zeroGradParameters() end - for i = 1, math.ceil(self.batch_size / self.computational_batch_size) do + local num_images = images:size(2) + for i = 1, math.ceil(num_images / self.computational_batch_size) do local start_index = (i - 1) * self.computational_batch_size + 1 local end_index = math.min( i * self.computational_batch_size, self.batch_size)