Skip to content

Commit

Permalink
Small abstraction in data_source.lua
Browse files Browse the repository at this point in the history
  • Loading branch information
achalddave committed Nov 29, 2017
1 parent b94b32d commit 46fd4d8
Showing 1 changed file with 23 additions and 12 deletions.
35 changes: 23 additions & 12 deletions data_source.lua
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,27 @@ function DiskFramesHdf5LabelsDataSource:key_label_map(return_label_map)
return key_labels
end

function DiskFramesHdf5LabelsDataSource:load_image(key)
--[[
Load image and label for a given key.
]]--
local video_name, frame_number = self:frame_video_offset(key)
local frame_path = string.format('%s/%s/frame%04d.png',
self.frames_root,
video_name,
frame_number)
local frame = image.load(
frame_path, 3 --[[depth]], 'byte' --[[type]])
-- For backwards compatibility, use BGR images.
frame = frame:index(1, torch.LongTensor{3, 2, 1})
return frame
end

function DiskFramesHdf5LabelsDataSource:load_labels(key)
local video_name, frame_number = self:frame_video_offset(key)
return self.labels[video_name][frame_number]
end

function DiskFramesHdf5LabelsDataSource:load_data(keys, load_images)
--[[
Load images and labels for a set of keys.
Expand Down Expand Up @@ -155,22 +176,12 @@ function DiskFramesHdf5LabelsDataSource:load_data(keys, load_images)
VideoDataSource.END_OF_SEQUENCE)
batch_labels[{step, sequence}]:zero()
else
local video_name, frame_number = self:frame_video_offset(key)
if load_images then
local frame_path = string.format('%s/%s/frame%04d.png',
self.frames_root,
video_name,
frame_number)
local frame = image.load(
frame_path, 3 --[[depth]], 'byte' --[[type]])
-- For backwards compatibility, use BGR images.
frame = frame:index(1, torch.LongTensor{3, 2, 1})
batch_images[step][sequence] = frame
batch_images[step][sequence] = self:load_image(key)
else
batch_images[step][sequence] = torch.ByteTensor()
end
batch_labels[{step, sequence, {}}] =
self.labels[video_name][frame_number]
batch_labels[{step, sequence, {}}] = self:load_labels(key)
end
end
end
Expand Down

0 comments on commit 46fd4d8

Please sign in to comment.