forked from torch/nngraph
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinit.lua
50 lines (40 loc) · 1.26 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
require 'nn'
require 'graph'
-- Package table. Deliberately assigned as a global (no `local`): the code
-- below reaches it through the global name (e.g. nngraph.Node).
nngraph = {}
-- Load the package's component files. The code further down relies on
-- nngraph.Node and nn.ModuleFromCriterion, which are presumably defined by
-- these files -- keep the includes ahead of the __call__ overrides.
torch.include('nngraph','node.lua')
torch.include('nngraph','gmodule.lua')
torch.include('nngraph','graphinspecting.lua')
torch.include('nngraph','ModuleFromCriterion.lua')
-- handy functions
-- utils.lua is expected to return a table of helpers; only `istable` is used
-- in the visible part of this file (`istensor` and `istorchclass` are not).
local utils = paths.dofile('utils.lua')
local istensor = torch.isTensor
local istable = utils.istable
local istorchclass = utils.istorchclass
-- simpler todot functions
-- Whatever simple_print.lua returns is exposed as nngraph.simple_print.
nngraph.simple_print = paths.dofile('simple_print.lua')
-- Modify the __call function to hack into nn.Module
local Module = torch.getmetatable('nn.Module')

--- Calling a module builds a graph node for it and attaches the given inputs.
--
-- Accepted call shapes (at most one argument):
--   module()          -> node with no inputs attached
--   module(node)      -> node with a single nngraph.Node attached
--   module({n1, n2})  -> node with every listed nngraph.Node attached, in order
--
-- @param ... nothing, one nngraph.Node, or one table of nngraph.Nodes
--            (whether a value counts as a table is decided by utils.istable)
-- @return a new nngraph.Node created with {module=self}
-- Raises an error if more than one argument is given, if the single argument
-- is an explicit nil, or if any slot of the input table -- including a nil
-- hole such as the one left by a misspelled variable -- is not an
-- nngraph.Node.
function Module:__call__(...)
   local nArgs = select("#", ...)
   assert(nArgs <= 1, 'Use {input1, input2} to pass multiple inputs.')

   local input = ...
   -- module(nil) is almost always an undefined variable; module() is the
   -- legitimate way to ask for a node without inputs.
   if nArgs == 1 and input == nil then
      error('what is this in the input? nil')
   end
   if not istable(input) then
      input = {input}
   end

   -- ipairs would stop at the first nil and silently drop every later input
   -- ({a, nil, b} would connect only a). Walk up to the largest numeric key
   -- instead, so a nil hole reaches the type check below and is reported.
   local maxIndex = 0
   for key in pairs(input) do
      if type(key) == 'number' and key > maxIndex then
         maxIndex = key
      end
   end

   local mnode = nngraph.Node({module=self})
   for i = 1, maxIndex do
      local dnode = input[i]
      if torch.typename(dnode) ~= 'nngraph.Node' then
         error('what is this in the input? ' .. tostring(dnode))
      end
      mnode:add(dnode,true)
   end
   return mnode
end
local Criterion = torch.getmetatable('nn.Criterion')

--- Calling a criterion wraps it with nn.ModuleFromCriterion and then calls
-- that wrapper with the same arguments.
-- @param ... forwarded unchanged to the wrapper's call
-- @return whatever calling the wrapper returns
function Criterion:__call__(...)
   local asModule = nn.ModuleFromCriterion(self)
   return asModule(...)
end
-- Hand the package table back to whoever loaded this file.
return nngraph