From 46fd4d8317cf4157c89eb5f7c3e3a58420e6be29 Mon Sep 17 00:00:00 2001 From: Achal Dave Date: Wed, 29 Nov 2017 14:28:29 -0500 Subject: [PATCH] Small abstraction in data_source.lua --- data_source.lua | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/data_source.lua b/data_source.lua index 41bad2c..2791ad7 100644 --- a/data_source.lua +++ b/data_source.lua @@ -123,6 +123,27 @@ function DiskFramesHdf5LabelsDataSource:key_label_map(return_label_map) return key_labels end +function DiskFramesHdf5LabelsDataSource:load_image(key) + --[[ + Load image and label for a given key. + ]]-- + local video_name, frame_number = self:frame_video_offset(key) + local frame_path = string.format('%s/%s/frame%04d.png', + self.frames_root, + video_name, + frame_number) + local frame = image.load( + frame_path, 3 --[[depth]], 'byte' --[[type]]) + -- For backwards compatibility, use BGR images. + frame = frame:index(1, torch.LongTensor{3, 2, 1}) + return frame +end + +function DiskFramesHdf5LabelsDataSource:load_labels(key) + local video_name, frame_number = self:frame_video_offset(key) + return self.labels[video_name][frame_number] +end + function DiskFramesHdf5LabelsDataSource:load_data(keys, load_images) --[[ Load images and labels for a set of keys. @@ -155,22 +176,12 @@ function DiskFramesHdf5LabelsDataSource:load_data(keys, load_images) VideoDataSource.END_OF_SEQUENCE) batch_labels[{step, sequence}]:zero() else - local video_name, frame_number = self:frame_video_offset(key) if load_images then - local frame_path = string.format('%s/%s/frame%04d.png', - self.frames_root, - video_name, - frame_number) - local frame = image.load( - frame_path, 3 --[[depth]], 'byte' --[[type]]) - -- For backwards compatibility, use BGR images. - frame = frame:index(1, torch.LongTensor{3, 2, 1}) - batch_images[step][sequence] = frame + batch_images[step][sequence] = self:load_image(key) else batch_images[step][sequence] = torch.ByteTensor() end - batch_labels[{step, sequence, {}}] = - self.labels[video_name][frame_number] + batch_labels[{step, sequence, {}}] = self:load_labels(key) end end end