forked from rosejn/lua-util
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathnan.lua
114 lines (92 loc) · 2.03 KB
/
nan.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
-- Replace NaN-generating torch functions with wrapped versions that assert on arguments that would produce NaN
require 'torch'
require 'util'
--- Report whether an argument contains a negative value.
-- Tensors are inspected element-wise through torch.lt. Plain Lua numbers are
-- compared directly, since torch.lt is a tensor operation and is not expected
-- to accept a bare number as its first argument (NOTE(review): confirm against
-- the torch version in use).
-- @param x tensor or number to inspect
-- @return true if x, or at least one element of x, is below zero
local function negative(x)
  if type(x) == 'number' then
    return x < 0
  end
  -- Summing the comparison mask counts the elements that are below zero.
  return torch.lt(x, 0):sum() > 0
end
--- Check whether a tensor holds at least one zero.
-- @param x value handed to torch.eq (callers in this file pass tensors)
-- @return true when one or more elements of x compare equal to 0
local function zero(x)
  local zero_mask = torch.eq(x, 0)
  local zero_count = zero_mask:sum()
  return zero_count > 0
end
-- Guarded replacement for torch.DoubleTensor.log.
-- Handles both call shapes: `x:log()` and `res:log(src)`. In the second shape
-- the first argument is only the destination, so the source is the operand
-- that gets validated and both arguments are forwarded to the original.
-- @param x   operand, or the destination tensor when src is given
-- @param src optional source tensor
-- @return whatever the original log returns
-- Raises an assertion error when the operand contains a negative value.
do
  local orig = torch.DoubleTensor.log
  torch.DoubleTensor.log = function(x, src)
    if src == nil then
      assert(not negative(x), 'log called with negative argument')
      return orig(x)
    end
    -- Forward exactly two arguments; the one-argument call above avoids
    -- passing an explicit nil to the original.
    assert(not negative(src), 'log called with negative argument')
    return orig(x, src)
  end
end
-- Guarded replacement for torch.log.
-- Handles both call shapes: `torch.log(x)` and `torch.log(res, src)`. In the
-- second shape the first argument is only the destination, so the source is
-- the operand that gets validated and both arguments reach the original.
-- @param x   operand, or the destination tensor when src is given
-- @param src optional source tensor
-- @return whatever the original torch.log returns
-- Raises an assertion error when the operand contains a negative value.
do
  local orig = torch.log
  torch.log = function(x, src)
    if src == nil then
      assert(not negative(x), 'log called with negative argument')
      return orig(x)
    end
    -- Forward exactly two arguments; the one-argument call above avoids
    -- passing an explicit nil to the original.
    assert(not negative(src), 'log called with negative argument')
    return orig(x, src)
  end
end
-- Guarded replacement for torch.DoubleTensor.sqrt.
-- Handles both call shapes: `x:sqrt()` and `res:sqrt(src)`. In the second
-- shape the first argument is only the destination, so the source is the
-- operand that gets validated and both arguments are forwarded.
-- @param x   operand, or the destination tensor when src is given
-- @param src optional source tensor
-- @return whatever the original sqrt returns
-- Raises an assertion error when the operand contains a negative value.
do
  local orig = torch.DoubleTensor.sqrt
  torch.DoubleTensor.sqrt = function(x, src)
    if src == nil then
      assert(not negative(x), 'sqrt called with negative argument')
      return orig(x)
    end
    -- Forward exactly two arguments; the one-argument call above avoids
    -- passing an explicit nil to the original.
    assert(not negative(src), 'sqrt called with negative argument')
    return orig(x, src)
  end
end
-- Guarded replacement for torch.sqrt.
-- Handles both call shapes: `torch.sqrt(x)` and `torch.sqrt(res, src)`. In
-- the second shape the first argument is only the destination, so the source
-- is the operand that gets validated and both arguments reach the original.
-- @param x   operand, or the destination tensor when src is given
-- @param src optional source tensor
-- @return whatever the original torch.sqrt returns
-- Raises an assertion error when the operand contains a negative value.
do
  local orig = torch.sqrt
  torch.sqrt = function(x, src)
    if src == nil then
      assert(not negative(x), 'sqrt called with negative argument')
      return orig(x)
    end
    -- Forward exactly two arguments; the one-argument call above avoids
    -- passing an explicit nil to the original.
    assert(not negative(src), 'sqrt called with negative argument')
    return orig(x, src)
  end
end
-- Guarded replacement for torch.DoubleTensor.div (division by a scalar).
-- Two call shapes are handled:
--   div(x, y)    : dividend x, divisor y
--   div(x, y, z) : x is forwarded untouched, dividend y, divisor z
-- An assertion error '0 / 0' is raised when the divisor equals 0 and the
-- dividend contains a zero element.
-- @return whatever the original div returns
do
  local orig = torch.DoubleTensor.div
  torch.DoubleTensor.div = function(x, y, z)
    if z == nil then
      -- Compare the divisor first: the tensor scan in zero() is only needed
      -- when the divisor really is 0.
      assert(not (y == 0 and zero(x)), '0 / 0')
      return orig(x, y)
    end
    -- Same ordering here: cheap scalar test, then type check, then scan.
    assert(not (z == 0 and util.is_tensor(y) and zero(y)), '0 / 0')
    return orig(x, y, z)
  end
end
-- Guarded replacement for torch.div (division by a scalar).
-- Two call shapes are handled:
--   torch.div(x, y)    : dividend x, divisor y
--   torch.div(x, y, z) : x is forwarded untouched, dividend y, divisor z
-- An assertion error '0 / 0' is raised when the divisor equals 0 and the
-- dividend contains a zero element.
-- @return whatever the original torch.div returns
do
  local orig = torch.div
  torch.div = function(x, y, z)
    if z == nil then
      -- Compare the divisor first: the tensor scan in zero() is only needed
      -- when the divisor really is 0.
      assert(not (y == 0 and zero(x)), '0 / 0')
      return orig(x, y)
    end
    -- Same ordering here: cheap scalar test, then type check, then scan.
    assert(not (z == 0 and util.is_tensor(y) and zero(y)), '0 / 0')
    return orig(x, y, z)
  end
end
-- Guarded replacement for torch.DoubleTensor.__div (the `/` operator hook).
-- Raises an assertion error '0 / 0' when the right operand equals 0 and the
-- left operand contains a zero element; otherwise defers to the original.
-- @param x left operand
-- @param y right operand
-- @return whatever the original __div returns
do
  local orig = torch.DoubleTensor.__div
  torch.DoubleTensor.__div = function(x, y)
    -- Test the divisor first so zero() only scans x when y really is 0; any
    -- other operand combination goes straight to the original.
    assert(not (y == 0 and zero(x)), '0 / 0')
    return orig(x, y)
  end
end
-- Guarded replacement for torch.cdiv (element-wise tensor division).
-- Two call shapes are handled:
--   torch.cdiv(x, y)    : numerator x, denominator y
--   torch.cdiv(x, y, z) : x is forwarded untouched, numerator y, denominator z
-- An assertion error '0 / 0' is raised only when numerator and denominator
-- hold a zero at the same position. Zeros at different positions give 0 or
-- inf rather than NaN and are let through.
-- @return whatever the original torch.cdiv returns
do
  local orig = torch.cdiv

  -- True when some position is zero in both tensors. Multiplying the two
  -- equality masks keeps a 1 only where both operands are zero.
  local function zero_over_zero(num, den)
    local both_zero = torch.cmul(torch.eq(num, 0), torch.eq(den, 0))
    return both_zero:sum() > 0
  end

  torch.cdiv = function(x, y, z)
    if z == nil then
      assert(not zero_over_zero(x, y), '0 / 0')
      return orig(x, y)
    end
    assert(not zero_over_zero(y, z), '0 / 0')
    return orig(x, y, z)
  end
end
-- Guarded replacement for torch.DoubleTensor.cdiv (element-wise division).
-- Two call shapes are handled:
--   cdiv(x, y)    : numerator x, denominator y
--   cdiv(x, y, z) : x is forwarded untouched, numerator y, denominator z
-- An assertion error '0 / 0' is raised only when numerator and denominator
-- hold a zero at the same position. Zeros at different positions give 0 or
-- inf rather than NaN and are let through.
-- @return whatever the original cdiv returns
do
  local orig = torch.DoubleTensor.cdiv

  -- True when some position is zero in both tensors. Multiplying the two
  -- equality masks keeps a 1 only where both operands are zero.
  local function zero_over_zero(num, den)
    local both_zero = torch.cmul(torch.eq(num, 0), torch.eq(den, 0))
    return both_zero:sum() > 0
  end

  torch.DoubleTensor.cdiv = function(x, y, z)
    if z == nil then
      assert(not zero_over_zero(x, y), '0 / 0')
      return orig(x, y)
    end
    assert(not zero_over_zero(y, z), '0 / 0')
    return orig(x, y, z)
  end
end