-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluate.lua
164 lines (136 loc) · 4.7 KB
/
evaluate.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
require 'dp'
require 'rnn'
require 'optim'
-- References :
-- A. http://papers.nips.cc/paper/5542-recurrent-models-of-visual-attention.pdf
-- B. http://incompleteideas.net/sutton/williams-92.pdf
--[[command line arguments]]--
cmd = torch.CmdLine()
cmd:text()
cmd:text('Evaluate a Recurrent Model for Visual Attention')
cmd:text('Options:')
cmd:option('--xpPath', '', 'path to a previously saved model')
cmd:option('--cuda', false, 'model was saved with cuda')
cmd:option('--evalTest', false, 'model was saved with cuda')
cmd:option('--stochastic', false, 'evaluate the model stochatically. Generate glimpses stochastically')
cmd:option('--dataset', 'Mnist', 'which dataset to use : Mnist | TranslattedMnist | etc')
cmd:option('--overwrite', false, 'overwrite checkpoint')
cmd:text()
local opt = cmd:parse(arg or {})
-- check that saved model exists
assert(paths.filep(opt.xpPath), opt.xpPath..' does not exist')
if opt.cuda then
require 'cunn'
end
xp = torch.load(opt.xpPath)
model = xp:model().module
tester = xp:tester() or xp:validator() -- dp.Evaluator
tester:sampler()._epoch_size = nil
conf = tester:feedback() -- dp.Confusion
cm = conf._cm -- optim.ConfusionMatrix
print("Last evaluation of "..(xp:tester() and 'test' or 'valid').." set :")
print(cm)
if opt.dataset == 'TranslatedMnist' then
ds = torch.checkpoint(
paths.concat(dp.DATA_DIR, 'checkpoint/dp.TranslatedMnist_test.t7'),
function()
local ds = dp[opt.dataset]{load_all=false}
ds:loadTest()
return ds
end,
opt.overwrite
)
else
ds = dp[opt.dataset]()
end
ra = model:findModules('nn.RecurrentAttention')[1]
sg = model:findModules('nn.SpatialGlimpse')[1]
-- stochastic or deterministic
for i=1,#ra.actions do
local rn = ra.action:getStepModule(i):findModules('nn.ReinforceNormal')[1]
rn.stochastic = opt.stochastic
end
if opt.evalTest then
conf:reset()
tester:propagateEpoch(ds:testSet())
print((opt.stochastic and "Stochastic" or "Deterministic") .. "evaluation of test set :")
print(cm)
end
inputs = ds:get('test','inputs')
targets = ds:get('test','targets', 'b')
input = inputs:narrow(1,1,10)
model:training() -- otherwise the rnn doesn't save intermediate time-step states
if not opt.stochastic then
for i=1,#ra.actions do
local rn = ra.action:getStepModule(i):findModules('nn.ReinforceNormal')[1]
rn.stdev = 0 -- deterministic
end
end
output = model:forward(input)
function drawBox(img, bbox, channel)
channel = channel or 1
local x1, y1 = torch.round(bbox[1]), torch.round(bbox[2])
local x2, y2 = torch.round(bbox[1] + bbox[3]), torch.round(bbox[2] + bbox[4])
x1, y1 = math.max(1, x1), math.max(1, y1)
x2, y2 = math.min(img:size(3), x2), math.min(img:size(2), y2)
local max = img:max()
for i=x1,x2 do
img[channel][y1][i] = max
img[channel][y2][i] = max
end
for i=y1,y2 do
img[channel][i][x1] = max
img[channel][i][x2] = max
end
return img
end
locations = ra.actions
input = nn.Convert(ds:ioShapes(),'bchw'):forward(input)
glimpses = {}
patches = {}
params = nil
print((input:size(1)).. "input size ")
for i=1,input:size(1) do
local img = input[i]
print("I am here")
---print(img)
for j,location in ipairs(locations) do
local glimpse = glimpses[j] or {}
glimpses[j] = glimpse
local patch = patches[j] or {}
patches[j] = patch
local xy = location[i]
-- (-1,-1) top left corner, (1,1) bottom right corner of image
local x, y = xy:select(1,1), xy:select(1,2)
-- (0,0), (1,1)
x, y = (x+1)/2, (y+1)/2
-- (1,1), (input:size(3), input:size(4))
x, y = x*(input:size(3)-1)+1, y*(input:size(4)-1)+1
local gimg = img:clone()
print((sg.depth) .. " sg depth")
for d=1,sg.depth do
local size = sg.size*(sg.scale^(d-1))
local bbox = {y-size/2, x-size/2, size, size}
drawBox(gimg, bbox, 1)
end
glimpse[i] = gimg
local sg_, ps
if j == 1 then
sg_ = ra.rnn.initialModule:findModules('nn.SpatialGlimpse')[1]
else
sg_ = ra.rnn.sharedClones[j]:findModules('nn.SpatialGlimpse')[1]
end
--print(sg_.output[i]:narrow(1,1,1):size())
--print(img:clone():size())
--patch[i] = image.scale(img:clone():float(), sg_.output[i]:narrow(1,1,1):float())
patch[i] = image.scale(img:clone():float(), sg_.output[i]:float())
collectgarbage()
end
end
paths.mkdir('glimpse')
for j,glimpse in ipairs(glimpses) do
local g = image.toDisplayTensor{input=glimpse,nrow=10,padding=3}
local p = image.toDisplayTensor{input=patches[j],nrow=10,padding=3}
image.save("glimpse/glimpse"..j..".png", g)
image.save("glimpse/patch"..j..".png", p)
end