-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdata_loader.lua
108 lines (93 loc) · 3.13 KB
/
data_loader.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
--[[
Helper class to load data and labels from an LMDB containing LabeledVideoFrames
as values.
]]--
local classic = require 'classic'
local threads = require 'threads'
require 'torch'
require 'classic.torch'
local log = require 'util/log' -- luacheck: no unused
local DataLoader = classic.class('DataLoader')

function DataLoader:_init(data_source_obj, sampler)
    --[[
    Args:
        data_source_obj: Data source object. Its load_data(keys) method is
            called with keys produced by the sampler and its first two
            return values are used as (images, labels).
        sampler (Sampler): Sampler used for batches. Must provide
            sample_keys(batch_size), num_labels() and num_samples().
    ]]--
    self.data_source = data_source_obj
    self.sampler = sampler
    -- Most recently prefetched batch: batch_images, batch_labels and
    -- batch_keys. All fields start out nil, meaning "nothing prefetched".
    self._prefetched_data = {}
    -- A pool with a single worker, used for prefetching jobs.
    self._prefetching_thread = threads.Threads(1)
end
function DataLoader:num_labels()
    -- Number of labels; delegated to the sampler.
    local sampler = self.sampler
    return sampler:num_labels()
end
function DataLoader:num_samples()
    -- Number of samples; delegated to the sampler.
    local sampler = self.sampler
    return sampler:num_samples()
end
function DataLoader:load_batch(batch_size, return_keys)
    --[[
    Load a batch of images and labels.

    If a batch has already been prefetched (e.g. by an earlier call to
    fetch_batch_async), that batch is returned as-is and batch_size is not
    used. Otherwise a batch of batch_size keys is fetched and waited on.
    The returned batch is consumed: the next call fetches a new one.

    Args:
        batch_size (num)
        return_keys (bool): Whether to return the keys from the batch.
            Default false.
    Returns:
        images (Array of array of ByteTensors): Contains image sequences for
            the batch. Each element is a step in the sequence, so that
            #images = sequence_length, #images[1] = batch_size.
        labels (ByteTensor): Contains label ids. Size is
            (sequence_length, batch_size, num_labels)
        keys (Array of array of strings): Only returned if return_keys is
            true.
    ]]--
    -- Explicit default: the previous `x == nil and false or x` form yielded
    -- nil instead of false, because the middle operand of and/or was falsy.
    if return_keys == nil then
        return_keys = false
    end
    if not self:_data_fetched() then
        self:fetch_batch_async(batch_size)
        -- fetch_batch_async only queues the job; wait for it so that its
        -- end callback has stored the batch in self._prefetched_data.
        self._prefetching_thread:synchronize()
    end
    local prefetched = self._prefetched_data
    local batch_images = prefetched.batch_images
    local batch_labels = prefetched.batch_labels
    local batch_keys = prefetched.batch_keys
    -- Mark the batch as consumed so the next call fetches a fresh one.
    prefetched.batch_images = nil
    prefetched.batch_labels = nil
    prefetched.batch_keys = nil
    if return_keys then
        return batch_images, batch_labels, batch_keys
    end
    return batch_images, batch_labels
end
function DataLoader:fetch_batch_async(batch_size)
    --[[
    Queue a batch so that the next load_batch call can return it.

    Does nothing if a prefetched batch is already waiting. Keys are sampled
    right away on the calling thread; the loaded batch is stored in
    self._prefetched_data by the job's end callback.

    NOTE(review): load_data is called inside the *end* callback given to
    addjob. Per the torch `threads` documentation, end callbacks run on the
    main thread (during synchronize()/dojob()), and only the first callback
    runs on the worker. As written, the worker only loads modules, so the
    data loading itself is not overlapped with the caller's work. Moving
    load_data into the worker callback would require data_source_obj to be
    serializable to the worker thread (and its class to be loadable there
    before upvalues are deserialized) -- confirm before changing.
    ]]--
    local already_waiting = self:_data_fetched()
    if already_waiting then
        return
    end

    local keys = self.sampler:sample_keys(batch_size)
    local source = self.data_source

    -- Worker-side job: makes the modules available in the worker thread.
    local worker_job = function()
        require 'torch'
        require 'classic'
        require 'classic.torch'
        require 'data_source'
    end

    -- End callback: loads the batch for the sampled keys and stashes it.
    local on_job_done = function()
        local images, labels = source:load_data(keys)
        self._prefetched_data = {
            batch_images = images,
            batch_labels = labels,
            batch_keys = keys
        }
    end

    self._prefetching_thread:addjob(worker_job, on_job_done)
end
function DataLoader:_data_fetched()
    --[[
    Whether a prefetched batch is waiting to be consumed.

    Synchronizes the prefetching pool first, so any queued job (including
    the end callback that stores the batch) has finished before checking.
    ]]--
    local pool = self._prefetching_thread
    pool:synchronize()
    local images = self._prefetched_data.batch_images
    return images ~= nil
end
-- Public interface of this module.
local exports = {}
exports.DataLoader = DataLoader
return exports