forked from Element-Research/rnn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRecursor.lua
102 lines (84 loc) · 2.85 KB
/
Recursor.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
------------------------------------------------------------------------
--[[ Recursor ]]--
-- Decorator that lets an arbitrary module be used inside an
-- AbstractSequencer: it wraps the module so that it conforms to the
-- AbstractRecurrent interface (the base class of LSTM and Recurrent).
------------------------------------------------------------------------
local Recursor, parent = torch.class("nn.Recursor", "nn.AbstractRecurrent")
--- Wrap `module` so it can be driven step by step like an AbstractRecurrent.
-- The same module object is stored as self.module, self.recurrentModule,
-- self.modules[1] and self.sharedClones[1].
-- @param module the module to decorate
-- @param rho    optional horizon forwarded to AbstractRecurrent.__init;
--               9999999 is used when rho is nil (or false)
function Recursor:__init(module, rho)
   local horizon = rho or 9999999
   parent.__init(self, horizon)

   self.recurrentModule = module
   self.module = module
   self.modules = {module}
   -- NOTE(review): registering the decorated module as clone #1 presumably
   -- lets getStepModule(1) reuse it instead of cloning -- confirm in parent.
   self.sharedClones[1] = module
end
--- Forward one time-step.
-- Evaluation mode (self.train == false) runs the decorated module itself.
-- Any other value of self.train (true or nil) counts as training: recycle()
-- is called and the step module returned by getStepModule(self.step) runs.
-- The result is recorded in self.outputs[self.step] and self.output, the
-- step counter is incremented, and updateGradInputStep /
-- accGradParametersStep are cleared.
-- @param input input for the current time-step
-- @return self.output
function Recursor:updateOutput(input)
   local output
   if self.train == false then
      output = self.recurrentModule:updateOutput(input)
   else
      self:recycle()
      local stepModule = self:getStepModule(self.step)
      output = stepModule:updateOutput(input)
   end

   local currentStep = self.step
   self.outputs[currentStep] = output
   self.output = output
   self.step = currentStep + 1

   -- a new forward step invalidates the backward-pass step counters
   self.updateGradInputStep = nil
   self.accGradParametersStep = nil
   return self.output
end
--- Compute gradInput for a single stored time-step.
-- Requires at least one prior updateOutput (self.step > 1). Targets step
-- self.updateGradInputStep - 1 (asserted >= 1): fetches that step's module,
-- calls setOutputStep on it, then its updateGradInput.
-- NOTE(review): updateGradInputStep is presumably set by
-- AbstractRecurrent:updateGradInput before this hook runs -- confirm.
-- @param input      forwarded to the step module's updateGradInput
-- @param gradOutput forwarded to the step module's updateGradInput
-- @return the (first) value returned by the step module's updateGradInput
function Recursor:_updateGradInput(input, gradOutput)
   assert(self.step > 1, "expecting at least one updateOutput")
   local targetStep = self.updateGradInputStep - 1
   assert(targetStep >= 1)

   local stepModule = self:getStepModule(targetStep)
   stepModule:setOutputStep(targetStep)
   -- parentheses keep the single-value return of the original
   return (stepModule:updateGradInput(input, gradOutput))
end
--- Accumulate parameter gradients for a single stored time-step.
-- Targets step self.accGradParametersStep - 1 (asserted >= 1): fetches that
-- step's module, calls setOutputStep on it, then its accGradParameters.
-- NOTE(review): accGradParametersStep is presumably set by
-- AbstractRecurrent:accGradParameters before this hook runs -- confirm.
-- @param input      forwarded to the step module
-- @param gradOutput forwarded to the step module
-- @param scale      forwarded unchanged to accGradParameters
function Recursor:_accGradParameters(input, gradOutput, scale)
   local t = self.accGradParametersStep - 1
   assert(t >= 1)

   local clone = self:getStepModule(t)
   clone:setOutputStep(t)
   clone:accGradParameters(input, gradOutput, scale)
end
--- Run f() with self.modules temporarily widened to include every shared clone.
-- While f runs:
--   * self.modules is a fresh list holding the entries of the original
--     self.modules followed by the values of self.sharedClones (both
--     collected with pairs, so order within each group is not guaranteed);
--   * self.sharedClones is nil.
-- NOTE(review): presumably this lets generic container methods that iterate
-- self.modules also reach the clones -- confirm with the callers.
-- After __init, self.modules[1] and self.sharedClones[1] are the same module,
-- so it can appear twice in the temporary list.
-- Both fields are restored afterwards, also when f raises; the error object
-- is then re-raised unchanged (its traceback will start here, not inside f).
-- @param f function called with no arguments
-- @return every value returned by f, including embedded or trailing nils
function Recursor:includingSharedClones(f)
   local modules = self.modules
   local sharedClones = self.sharedClones

   -- `or {}` guards against a nil group: ipairs stops at the first nil, so a
   -- nil self.modules used to hide all of the shared clones.
   local combined = {}
   for _, group in ipairs({modules or {}, sharedClones or {}}) do
      for _, module in pairs(group) do
         combined[#combined + 1] = module
      end
   end
   self.modules = combined
   self.sharedClones = nil

   -- Capture f's results with their exact count so unpack does not depend on
   -- `#` over a table that may contain nil holes.
   local function capture(ok, ...)
      return ok, select('#', ...), {...}
   end
   local ok, nResults, results = capture(pcall(f))

   -- restore unconditionally so a failing f cannot leave the object half-swapped
   self.modules = modules
   self.sharedClones = sharedClones

   if not ok then
      error(results[1], 0) -- level 0: re-raise the original message untouched
   end
   return unpack(results, 1, nResults)
end
--- Reset the recurrence state.
-- Runs AbstractRecurrent's forget(offset) first, then nn.Module.forget on
-- this object.
-- NOTE(review): nn.Module.forget is not defined by stock nn; it presumably
-- comes from rnn's Module extensions and forwards to self.modules -- confirm.
-- @param offset forwarded to AbstractRecurrent.forget
-- @return self, for chaining
function Recursor:forget(offset)
   local recurrentForget = parent.forget
   local moduleForget = nn.Module.forget
   recurrentForget(self, offset)
   moduleForget(self)
   return self
end
--- Set the backpropagation-through-time horizon.
-- Stores rho on this object, then hands it to nn.Module.maxBPTTstep.
-- NOTE(review): nn.Module.maxBPTTstep presumably propagates rho to
-- self.modules (rnn Module extension) -- confirm.
-- @param rho new horizon
function Recursor:maxBPTTstep(rho)
   self.rho = rho
   local propagate = nn.Module.maxBPTTstep
   propagate(self, rho)
end
--- Delegate getHiddenState to the decorated module (self.modules[1]),
-- forwarding all arguments and returning all of its results.
function Recursor:getHiddenState(...)
   local decorated = self.modules[1]
   return decorated:getHiddenState(...)
end
--- Delegate setHiddenState to the decorated module (self.modules[1]),
-- forwarding all arguments and returning all of its results.
function Recursor:setHiddenState(...)
   local decorated = self.modules[1]
   return decorated:setHiddenState(...)
end
--- Delegate getGradHiddenState to the decorated module (self.modules[1]),
-- forwarding all arguments and returning all of its results.
function Recursor:getGradHiddenState(...)
   local decorated = self.modules[1]
   return decorated:getGradHiddenState(...)
end
--- Delegate setGradHiddenState to the decorated module (self.modules[1]),
-- forwarding all arguments and returning all of its results.
function Recursor:setGradHiddenState(...)
   local decorated = self.modules[1]
   return decorated:setGradHiddenState(...)
end
-- Print a Recursor exactly the way nn.Decorator prints its instances.
-- NOTE(review): assumes nn.Decorator's formatter reads self.module, which
-- Recursor:__init sets -- confirm against nn.Decorator.
do
   local decoratorToString = nn.Decorator.__tostring__
   Recursor.__tostring__ = decoratorToString
end