forked from facebookarchive/fb.resnet.torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimagenet.lua
105 lines (90 loc) · 2.73 KB
/
imagenet.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- ImageNet dataset loader
--
local image = require 'image'
local paths = require 'paths'
local t = require 'datasets/transforms'
local ffi = require 'ffi'
-- Module-local table handed to torch.class below; the file later returns
-- M.ImagenetDataset, so the class constructor is expected to end up here
-- rather than in a global namespace.
local M = {}
-- Declare the Torch class with type name 'resnet.ImagenetDataset'. The value
-- kept in `ImagenetDataset` is what the methods below are attached to.
local ImagenetDataset = torch.class('resnet.ImagenetDataset', M)
-- Set up a dataset bound to a single split.
--
-- imageInfo: table indexed by split name; only imageInfo[split] is retained
-- opt:       options table; opt.data is used as the dataset root directory
-- split:     split name, used both as the imageInfo key and as the
--            sub-directory of opt.data holding the images
--
-- Raises an error when the split directory does not exist.
function ImagenetDataset:__init(imageInfo, opt, split)
   local splitInfo = imageInfo[split]
   local splitDir = paths.concat(opt.data, split)

   self.imageInfo = splitInfo
   self.opt = opt
   self.split = split
   self.dir = splitDir

   assert(paths.dirp(splitDir), 'directory does not exist: ' .. splitDir)
end
-- Fetch sample number i.
--
-- The i-th entry of imageInfo.imagePath is converted to a Lua string with
-- ffi.string (presumably a NUL-terminated char row -- confirm against the
-- code that builds imageInfo) and resolved relative to self.dir.
--
-- Returns a table { input = <loaded image>, target = imageInfo.imageClass[i] }.
function ImagenetDataset:get(i)
   local info = self.imageInfo
   local relPath = ffi.string(info.imagePath[i]:data())
   local fullPath = paths.concat(self.dir, relPath)

   -- Fields are evaluated in order: the image is loaded before the class
   -- label is looked up.
   return {
      input = self:_loadImage(fullPath),
      target = info.imageClass[i],
   }
end
-- Load the image stored at `path`, asking the image package for depth 3 and
-- the 'float' tensor type.
--
-- image.load is tried first. If it raises, the raw bytes of the file are read
-- and passed to image.decompress instead.
--
-- path: location of the image file
--
-- Returns the value produced by image.load, or by image.decompress on the
-- fallback path.
-- Raises an error if the file cannot be opened or read, is empty, or if
-- image.decompress itself fails.
function ImagenetDataset:_loadImage(path)
   local ok, input = pcall(image.load, path, 3, 'float')
   if ok then
      return input
   end

   -- Sometimes image.load fails because the file extension does not match the
   -- image format. In that case, use image.decompress on a ByteTensor.
   --
   -- The file must be opened in binary mode: in text mode some platforms
   -- translate line endings or stop at an end-of-file marker byte, which
   -- corrupts compressed image data.
   local f = io.open(path, 'rb')
   assert(f, 'Error reading: ' .. tostring(path))
   local data = f:read('*a')
   f:close()

   -- A nil or empty read (e.g. the path is a directory or a zero-byte file)
   -- would otherwise surface as an unrelated error further down.
   assert(data and #data > 0, 'Error reading: ' .. tostring(path))

   local nbytes = #data
   local b = torch.ByteTensor(nbytes)
   ffi.copy(b:data(), data, nbytes)

   local decoded = image.decompress(b, 3, 'float')
   return decoded
end
-- Number of samples in this split, taken as the extent of
-- imageInfo.imageClass along its first dimension.
function ImagenetDataset:size()
   local classes = self.imageInfo.imageClass
   return classes:size(1)
end
-- Normalisation statistics passed to t.ColorNormalize in preprocess().
-- Computed from random subset of ImageNet training images.
local meanstd = {
   mean = { 0.485, 0.456, 0.406 },
   std = { 0.229, 0.224, 0.225 },
}

-- Eigenvalues and eigenvectors passed to t.Lighting in preprocess().
local pca = {
   eigval = torch.Tensor({ 0.2175, 0.0188, 0.0045 }),
   eigvec = torch.Tensor({
      { -0.5675, 0.7192, 0.4009 },
      { -0.5808, -0.0045, -0.8140 },
      { -0.5836, -0.6948, 0.4203 },
   }),
}
-- Build the transform pipeline for this dataset's split.
--
-- 'train': random-sized crop to 224, colour jitter, lighting noise, colour
--          normalisation, then a horizontal flip with argument 0.5.
-- 'val':   scale to 256, colour normalisation, then a 224 crop made by
--          t.TenCrop when opt.tenCrop is set and by t.CenterCrop otherwise.
--
-- Returns the result of t.Compose over those steps.
-- Raises an error for any other split name.
function ImagenetDataset:preprocess()
   local split = self.split

   if split == 'train' then
      local steps = {
         t.RandomSizedCrop(224),
         t.ColorJitter({
            brightness = 0.4,
            contrast = 0.4,
            saturation = 0.4,
         }),
         t.Lighting(0.1, pca.eigval, pca.eigvec),
         t.ColorNormalize(meanstd),
         t.HorizontalFlip(0.5),
      }
      return t.Compose(steps)
   end

   if split == 'val' then
      -- Choose the crop constructor before any transform is built.
      local makeCrop = self.opt.tenCrop and t.TenCrop or t.CenterCrop
      local steps = {
         t.Scale(256),
         t.ColorNormalize(meanstd),
         makeCrop(224),
      }
      return t.Compose(steps)
   end

   error('invalid split: ' .. split)
end
-- The value of this module is the ImagenetDataset entry of the local table M
-- (the table that was passed to torch.class above).
return M.ImagenetDataset