-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_source.lua
215 lines (191 loc) · 7.85 KB
/
data_source.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
local classic = require 'classic'
local hdf5 = require 'hdf5'
local image = require 'image'
local torch = require 'torch'
local __ = require 'moses'
require 'classic.torch' -- Necessary for serializing classic classes.
-- Abstract interface for a data source. Each mustHave call below names a
-- method that concrete subclasses are expected to provide (see the classic
-- library documentation for how mustHave is enforced).
local DataSource = classic.class('DataSource')
-- Return the total number of samples available from this data source.
DataSource:mustHave('num_samples')
-- Given an array of keys, load the associated images and labels.
DataSource:mustHave('load_data')
-- Return a table mapping data sample keys to labels for the sample.
DataSource:mustHave('key_label_map')
-- Return the number of labels attached to each sample.
DataSource:mustHave('num_labels')
-- Abstract interface for data sources whose samples are frames of videos.
-- Declared with 'DataSource' named as its parent class.
local VideoDataSource = classic.class('VideoDataSource', 'DataSource')
-- Return a table mapping video name to an ordered array of keys for frames in
-- the video.
VideoDataSource:mustHave('video_keys')
-- Given a frame key, return the video name and the frame's number within that
-- video.
VideoDataSource:mustHave('frame_video_offset')
-- Sentinel that may appear in place of a frame key in the key arrays passed to
-- load_data; it is also exported from this module as END_OF_SEQUENCE.
VideoDataSource.END_OF_SEQUENCE = -1
-- Concrete video data source; its methods are defined below.
local DiskFramesHdf5LabelsDataSource = classic.class(
'DiskFramesHdf5LabelsDataSource', 'VideoDataSource')
function DiskFramesHdf5LabelsDataSource:_init(options)
  --[[
  Data source for loading frames from disk and labels from HDF5.

  Args (in options):
    frames_root (str): Contains subdirectories titled <video_name> for each
      video, which in turn contain frames of the form frame%04d.png
    labels_hdf5 (str or num): If str, specifies path to HDF5 file
      containing <video_name> keys, with (num_frames, num_labels) binary
      label matrices as values. If num, specifies number of labels, and
      we will set the labels matrix to be a matrix of all 1s. This is
      useful for running on images without labels.

  Sets the following fields on self:
    frames_root: Copied from options.
    labels: Table mapping video name to its label matrix.
    num_labels_: Number of label columns.
    video_keys_: Table mapping video name to its frame keys.
    num_samples_: Total number of frames over all videos.
  ]]--
  self.frames_root = options.frames_root
  local labels_hdf5 = options.labels_hdf5
  self.num_samples_ = 0

  if type(labels_hdf5) == "number" then
    -- No label file: discover the frames on disk and label everything 1.
    self.num_labels_ = labels_hdf5
    self.video_keys_ =
      DiskFramesHdf5LabelsDataSource.static.collect_video_frames(
        self.frames_root)
    self.labels = {}
    for video, video_keys in pairs(self.video_keys_) do
      -- NOTE(review): '#' is only well defined when the frames of a video are
      -- numbered contiguously starting at 1 -- confirm for your data.
      local num_frames = #video_keys
      self.num_samples_ = self.num_samples_ + num_frames
      self.labels[video] = torch.ones(num_frames, self.num_labels_)
    end
  else
    local hdf5_labels_file = hdf5.open(labels_hdf5, 'r')
    self.labels = hdf5_labels_file:all()
    -- all() has already read every label matrix into memory, so release the
    -- file handle instead of keeping it open for the object's lifetime.
    hdf5_labels_file:close()

    -- Every matrix is expected to have the same number of columns, so any
    -- one of them tells us the number of labels.
    local first_video = __.keys(self.labels)[1]
    assert(first_video ~= nil,
           'No label matrices found in ' .. tostring(labels_hdf5))
    self.num_labels_ = self.labels[first_video]:size(2)

    self.video_keys_ = {}
    for video_name, video_labels in pairs(self.labels) do
      local num_frames = video_labels:size(1)
      local frame_keys = {}
      for i = 1, num_frames do
        frame_keys[i] = video_name .. '-' .. i
      end
      self.video_keys_[video_name] = frame_keys
      self.num_samples_ = self.num_samples_ + num_frames
    end
  end
end
function DiskFramesHdf5LabelsDataSource.static.collect_video_frames(path)
  --[[
  Scan a frames directory and build the frame keys of every video in it.

  Args:
    path (str): Directory containing one subdirectory per video. Files in a
      subdirectory whose name contains 'frame<digits>' are treated as frames;
      all other files are ignored.

  Returns:
    video_keys: Table mapping video name to a table indexed by frame number,
      whose values are keys of the form '<video_name>-<frame_number>'. The
      frame number in the key carries no zero padding, matching the keys
      built by _init (HDF5 mode) and key_label_map.
  ]]--
  local paths = require 'paths'
  local video_keys = {}
  for video in paths.iterdirs(path) do
    local frame_keys = {}
    for frame in paths.iterfiles(path .. '/' .. video) do
      -- Capture the digits that directly follow 'frame' rather than the first
      -- digit run anywhere in the file name.
      local digits = string.match(frame, 'frame([0-9]+)')
      if digits ~= nil then
        local frame_number = tonumber(digits)
        frame_keys[frame_number] = string.format('%s-%d', video, frame_number)
      end
    end
    video_keys[video] = frame_keys
  end
  return video_keys
end
function DiskFramesHdf5LabelsDataSource:num_labels()
  -- Number of labels per frame, as determined in _init.
  local label_count = self.num_labels_
  return label_count
end
function DiskFramesHdf5LabelsDataSource:video_keys()
  -- Table mapping video name to the frame keys of that video, as built in
  -- _init. The stored table itself is returned, not a copy.
  local keys_by_video = self.video_keys_
  return keys_by_video
end
function DiskFramesHdf5LabelsDataSource:key_label_map(return_label_map)
  --[[
  Build the mapping from frame keys to the labels active on each frame.

  Note: The result is a very large table. Drop it as soon as it is no longer
  needed; keeping it alive (e.g. in a global or as an object attribute) slows
  down *all* future calls to collectgarbage().

  Args:
    return_label_map (bool): Must be nil or false. Returning a map from label
      names to label ids is not supported here, and the assertion below
      rejects any truthy value.

  Returns:
    key_labels: Table mapping each frame key ('<video_name>-<frame_number>')
      to an array of the indices of the non-zero entries in that frame's row
      of the label matrix.
  ]]--
  assert(not return_label_map)
  local map = {}
  for video, label_matrix in pairs(self.labels) do
    local num_frames = label_matrix:size(1)
    for frame = 1, num_frames do
      local active = label_matrix[{frame, {}}]:nonzero():squeeze()
      -- Always store a plain Lua array, whether squeeze() produced a tensor
      -- or a bare value.
      if not torch.isTensor(active) then
        active = {active}
      else
        active = active:totable()
      end
      map[video .. '-' .. frame] = active
    end
  end
  return map
end
function DiskFramesHdf5LabelsDataSource:load_image(key)
  --[[
  Load the frame image for a given key.

  Args:
    key: Frame key of the form '<video_name>-<frame_number>'.

  Returns:
    The frame read from <frames_root>/<video_name>/frame%04d.png as a
    3-channel byte image with its channels reordered to BGR.
  ]]--
  local video, frame_index = self:frame_video_offset(key)
  local path = string.format(
    '%s/%s/frame%04d.png', self.frames_root, video, frame_index)
  local rgb = image.load(path, 3 --[[depth]], 'byte' --[[type]])
  -- For backwards compatibility, use BGR images: reverse the channel order
  -- along the first dimension.
  local channel_order = torch.LongTensor{3, 2, 1}
  return rgb:index(1, channel_order)
end
function DiskFramesHdf5LabelsDataSource:load_labels(key)
  -- Return the row of the video's label matrix that belongs to the frame
  -- identified by `key` ('<video_name>-<frame_number>').
  local video, frame_index = self:frame_video_offset(key)
  local video_labels = self.labels[video]
  return video_labels[frame_index]
end
function DiskFramesHdf5LabelsDataSource:load_data(keys, load_images)
  --[[
  Load images and labels for a batch of key sequences.

  Args:
    keys (array): Array of arrays of keys, one inner array per step of the
      sequence. All inner arrays must have the same length (the batch size).
      An entry may be VideoDataSource.END_OF_SEQUENCE instead of a frame key.
    load_images (bool): Defaults to true. If false, only labels are loaded
      and each image slot holds an empty ByteTensor.

  Returns:
    batch_images: Array of arrays mirroring `keys`; each entry is the loaded
      frame, an empty ByteTensor (when load_images is false), or
      END_OF_SEQUENCE where the key was END_OF_SEQUENCE.
    batch_labels: ByteTensor of shape (num_steps, batch_size, num_labels).
      Rows for END_OF_SEQUENCE entries are zeroed.
  ]]--
  if load_images == nil then
    load_images = true
  end

  local num_steps = #keys
  local batch_size = #keys[1]
  local batch_labels = torch.ByteTensor(
    num_steps, batch_size, self.num_labels_)

  local batch_images = {}
  for step = 1, num_steps do
    batch_images[step] = {}
  end

  for step, step_keys in ipairs(keys) do
    local step_images = batch_images[step]
    for sequence, key in ipairs(step_keys) do
      if key ~= VideoDataSource.END_OF_SEQUENCE then
        if load_images then
          step_images[sequence] = self:load_image(key)
        else
          step_images[sequence] = torch.ByteTensor()
        end
        batch_labels[{step, sequence, {}}] = self:load_labels(key)
      else
        -- Past the end of this sequence: pass the sentinel through and give
        -- the slot an all-zero label row.
        table.insert(step_images, VideoDataSource.END_OF_SEQUENCE)
        batch_labels[{step, sequence}]:zero()
      end
    end
  end

  return batch_images, batch_labels
end
-- luacheck: push no unused args
function DiskFramesHdf5LabelsDataSource:frame_video_offset(key)
  -- Return the video name and frame number encoded in `key`. This is an
  -- instance-method wrapper around the static key parser; `self` is unused.
  local parse = DiskFramesHdf5LabelsDataSource.static.parse_frame_key
  return parse(key)
end
-- luacheck: pop
function DiskFramesHdf5LabelsDataSource:num_samples()
  -- Total number of frames over all videos, as counted in _init.
  local sample_count = self.num_samples_
  return sample_count
end
function DiskFramesHdf5LabelsDataSource.static.parse_frame_key(frame_key)
  --[[
  Split a frame key of the form '<filename>-<frame_number>' into its parts.

  The key is split at its LAST '-', so file names that themselves contain
  hyphens (e.g. 'my-video-12') are parsed correctly. For keys with a single
  '-' this is identical to splitting at the first one.

  Args:
    frame_key: Key of the form '<filename>-<frame_number>'.

  Returns:
    filename (str): Everything before the last '-'.
    frame_number (num): The text after the last '-' converted with tonumber,
      or nil if that text is not a number.

  Raises:
    An error if the key contains no '-' at all.
  ]]--
  -- '.*' is greedy, so the literal '-' that follows it is the last one in the
  -- key; '[^-]*$' then captures whatever trails it.
  local filename, frame_text = string.match(frame_key, '^(.*)%-([^-]*)$')
  if filename == nil then
    error("Frame key '" .. tostring(frame_key) ..
          "' is not of the form '<filename>-<frame_number>'")
  end
  return filename, tonumber(frame_text)
end
-- Public interface of this module: the two abstract interfaces, the concrete
-- disk/HDF5-backed implementation, and the END_OF_SEQUENCE sentinel re-exported
-- for convenience.
local exports = {
  END_OF_SEQUENCE = VideoDataSource.END_OF_SEQUENCE,
  DataSource = DataSource,
  VideoDataSource = VideoDataSource,
  DiskFramesHdf5LabelsDataSource = DiskFramesHdf5LabelsDataSource,
}
return exports