-
Notifications
You must be signed in to change notification settings - Fork 958
/
Copy pathView.lua
96 lines (81 loc) · 2.41 KB
/
View.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
-- Declare the nn.View class as a subclass of nn.Module. `View` is the class
-- table the methods below are attached to; `parent` is nn.Module, whose
-- initializer View:__init invokes.
local View, parent = torch.class('nn.View', 'nn.Module')
--- Set the target size of the view.
-- Accepts either a single torch.LongStorage (kept by reference, not copied)
-- or a list of sizes that is packed into a new torch.LongStorage.
-- Every entry must be >= 0, except that at most one entry may be -1 (the
-- dimension whose extent is left to be inferred).
-- Caches in self.numElements the product of all non-negative entries.
-- @return self, so calls can be chained
function View:resetSize(...)
   local nargs = select('#', ...)
   local first = ...
   if nargs == 1 and torch.typename(first) == 'torch.LongStorage' then
      self.size = first
   else
      self.size = torch.LongStorage({...})
   end

   self.numElements = 1
   local sawInferred = false
   for d = 1, #self.size do
      local dim = self.size[d]
      if dim < 0 then
         -- only the -1 placeholder is allowed, and only once
         assert(dim == -1, 'size should be positive or -1')
         assert(not sawInferred, 'only one dimension can be at -1')
         sawInferred = true
      else
         self.numElements = self.numElements * dim
      end
   end

   return self
end
--- Constructor. Takes the same arguments as View:resetSize (a single
-- torch.LongStorage or a list of sizes).
function View:__init(...)
   parent.__init(self)
   -- no explicit per-sample rank until View:setNumInputDims is called
   self.numInputDims = nil
   self:resetSize(...)
end
--- Declare how many trailing input dimensions make up one sample.
-- Any leading dimensions beyond that many are folded into a batch dimension
-- when View:updateOutput builds the view. Passing nil restores the default,
-- in which the whole input is considered.
-- @return self, so calls can be chained
function View:setNumInputDims(nDims)
   self.numInputDims = nDims
   return self
end
-- Work out the leading batch dimension (if any) that View:updateOutput has to
-- prepend to the requested view size.
--
-- @param input        the tensor being viewed
-- @param size         torch.LongStorage of requested sizes (may hold one -1)
-- @param numInputDims optional number of trailing input dims forming one
--                     sample; dims in front of them are batch dims
-- @param numElements  product of the non-negative entries of `size`
-- @return the batch size, or nil when the view should be applied to `size`
--         directly without a batch dimension
-- Raises an error when the input's element count is incompatible with `size`.
local function batchsize(input, size, numInputDims, numElements)
   local ind = input:nDimension()
   local isz = input:size()

   -- The trailing `maxdim` dims form one sample. Clamp to the input's rank:
   -- when numInputDims exceeds it there are no batch dims at all, and walking
   -- further would index the size storage at non-positive positions.
   local maxdim = numInputDims or ind
   if maxdim > ind then
      maxdim = ind
   end

   local ine = 1
   for i = ind, ind - maxdim + 1, -1 do
      ine = ine * isz[i]
   end

   -- numElements == 0 (a 0 in `size`) can never match; test it explicitly so
   -- that `% 0` (NaN on Lua 5.1, a runtime error on 5.3+) is never evaluated.
   if numElements == 0 or ine % numElements ~= 0 then
      error(string.format(
         'input view (%s) and desired view (%s) do not match',
         table.concat(input:size():totable(), 'x'),
         table.concat(size:totable(), 'x')))
   end

   -- the remainder is either the batch...
   -- (the division is exact here; math.floor keeps bsz an integer on
   -- Lua 5.3+, where `/` always produces a float)
   local bsz = math.floor(ine / numElements)

   -- ... or the missing size dim
   for i = 1, size:size() do
      if size[i] == -1 then
         bsz = 1
         break
      end
   end

   -- for dim over maxdim, it is definitively the batch
   for i = ind - maxdim, 1, -1 do
      bsz = bsz * isz[i]
   end

   -- special case: a trivial batch of 1 with no leading batch dims means
   -- "no batch dimension"
   if bsz == 1 and (not numInputDims or ind <= numInputDims) then
      return
   end
   return bsz
end
--- Forward pass: make self.output a view of `input` with the configured size.
-- When batchsize() reports a batch dimension it is placed in front of the
-- configured sizes; otherwise self.size is used unchanged.
-- @return self.output
function View:updateOutput(input)
   if not self.output then
      self.output = input.new()
   end
   local out = self.output

   local nbatch = batchsize(input, self.size, self.numInputDims, self.numElements)
   if not nbatch then
      out:view(input, self.size)
   else
      out:view(input, nbatch, table.unpack(self.size:totable()))
   end

   return out
end
--- Backward pass: reshape gradOutput back to the shape of `input`.
-- @param input      the tensor given to the forward pass
-- @param gradOutput gradient w.r.t. this module's output
-- @return self.gradInput, a tensor with input's size holding gradOutput's values
function View:updateGradInput(input, gradOutput)
   self.gradInput = self.gradInput or gradOutput.new()
   -- Tensor:view only works on contiguous memory. gradOutput is produced by the
   -- next module's backward pass, so its layout is outside our control.
   -- Tensor:contiguous() returns gradOutput itself when it is already
   -- contiguous and a contiguous copy otherwise.
   self.gradInput:view(gradOutput:contiguous(), input:size())
   return self.gradInput
end
--- Printable form: the module's type followed by its configured sizes,
-- e.g. "nn.View(2, 3)".
function View:__tostring__()
   local dims = table.concat(self.size:totable(), ', ')
   return string.format('%s(%s)', torch.type(self), dims)
end