-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathtest_fast_products.py
134 lines (100 loc) · 2.85 KB
/
test_fast_products.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Test the fast product operations (Hadamard, matrix-vector and matrix-matrix) between torchtt.TT objects.
"""
import pytest
import torchtt as tntt
import torch as tn
import numpy as np
def err_rel(t, ref):
    """
    Return the error of ``t`` measured against the reference ``ref``.

    The result is ``norm(t - ref) / norm(ref)`` when ``norm(ref)`` is
    strictly positive, the absolute error ``norm(t - ref)`` otherwise, and
    ``np.inf`` when the two tensors do not have the same shape.

    Args:
        t: torch tensor to be checked.
        ref: torch tensor holding the reference values.

    Returns:
        The error as a numpy scalar / 0-d array, or ``np.inf`` on shape mismatch.
    """
    # Mismatching shapes cannot be compared elementwise: report an infinite error.
    if ref.shape != t.shape:
        return np.inf

    diff_norm = tn.linalg.norm(t - ref).numpy()
    ref_norm = tn.linalg.norm(ref).numpy()

    # Fall back to the absolute error when the reference has zero norm.
    if ref_norm > 0:
        return diff_norm / ref_norm
    return diff_norm
# Dtypes used to parametrize the tests below: real and complex double precision.
parameters = [tn.float64, tn.complex128]
@pytest.mark.parametrize("dtype", parameters)
def test_hadamard(dtype):
    """
    Test the fast Hadamard multiplication between two TTs.

    The result of ``tntt.fast_hadammard`` is compared with the reference
    obtained from the plain ``*`` product of the operands.
    """
    mode_sizes = [2, 3, 4, 2, 3]
    x = tntt.random(mode_sizes, [1, 3, 2, 3, 4, 1], dtype=dtype)
    y = tntt.random(mode_sizes, [1, 2, 2, 5, 4, 1], dtype=dtype)

    # Each operand is added to itself twice, so X equals 4*x and Y equals 4*y
    # in value (presumably to inflate the TT ranks of the inputs - confirm).
    X = x.clone()
    for _ in range(2):
        X = X + X
    Y = y.clone()
    for _ in range(2):
        Y = Y + Y

    # (4*x) * (4*y) == 16 * x * y
    z_ref = 16 * x * y
    z = tntt.fast_hadammard(X, Y, 1e-9)

    assert z.N == z_ref.N
    assert err_rel(z.full(), z_ref.full()) < 1e-9
@pytest.mark.parametrize("dtype", parameters)
def test_hadamard_ttm(dtype):
    """
    Test the fast Hadamard multiplication between two TTMs.

    The result of ``tntt.fast_hadammard`` is compared with the reference
    obtained from the plain ``*`` product of the operands.
    """
    m_dims = [3, 2, 2, 4]
    n_dims = [2, 3, 4, 2]
    # One (M, N) pair per core.
    shape = list(zip(m_dims, n_dims))
    x = tntt.random(shape, [1, 3, 2, 3, 1], dtype=dtype)
    y = tntt.random(shape, [1, 2, 2, 5, 1], dtype=dtype)

    # Each operand is added to itself twice, so X equals 4*x and Y equals 4*y
    # in value (presumably to inflate the TT ranks of the inputs - confirm).
    X = x.clone()
    for _ in range(2):
        X = X + X
    Y = y.clone()
    for _ in range(2):
        Y = Y + Y

    # (4*x) * (4*y) == 16 * x * y
    z_ref = 16 * x * y
    z = tntt.fast_hadammard(X, Y, 1e-9)

    assert z.N == z_ref.N
    assert z.M == z_ref.M
    assert err_rel(z.full(), z_ref.full()) < 1e-9
@pytest.mark.parametrize("dtype", parameters)
def test_mv(dtype):
    """
    Test the fast multiplication between a TTM and a TT.

    The result of ``tntt.fast_mv`` is compared with the reference obtained
    from the plain ``@`` product of the operands.
    """
    m_dims = [3, 2, 2, 4]
    n_dims = [2, 3, 4, 2]
    # One (M, N) pair per core for the TTM operand.
    matrix_shape = list(zip(m_dims, n_dims))
    x = tntt.random(matrix_shape, [1, 3, 2, 3, 1], dtype=dtype)
    y = tntt.random(n_dims, [1, 2, 2, 5, 1], dtype=dtype)

    # Each operand is added to itself twice, so X equals 4*x and Y equals 4*y
    # in value (presumably to inflate the TT ranks of the inputs - confirm).
    X = x.clone()
    for _ in range(2):
        X = X + X
    Y = y.clone()
    for _ in range(2):
        Y = Y + Y

    # (4*x) @ (4*y) == (16 * x) @ y
    z_ref = 16 * x @ y
    z = tntt.fast_mv(X, Y, 1e-9)

    assert z.N == z_ref.N
    assert err_rel(z.full(), z_ref.full()) < 1e-9
@pytest.mark.parametrize("dtype", parameters)
def test_mm(dtype):
    """
    Test the fast multiplication between a TTM and a TTM.

    The result of ``tntt.fast_mm`` is compared with the reference obtained
    from the plain ``@`` product of the operands.
    """
    m_dims = [3, 2, 2, 4]
    n_dims = [2, 3, 4, 2]
    k_dims = [4, 2, 3, 3]
    # Left operand uses (M, N) pairs, right operand uses (N, K) pairs.
    left_shape = list(zip(m_dims, n_dims))
    right_shape = list(zip(n_dims, k_dims))
    x = tntt.random(left_shape, [1, 3, 2, 3, 1], dtype=dtype)
    y = tntt.random(right_shape, [1, 2, 2, 5, 1], dtype=dtype)

    # Each operand is added to itself twice, so X equals 4*x and Y equals 4*y
    # in value (presumably to inflate the TT ranks of the inputs - confirm).
    X = x.clone()
    for _ in range(2):
        X = X + X
    Y = y.clone()
    for _ in range(2):
        Y = Y + Y

    # (4*x) @ (4*y) == (16 * x) @ y
    z_ref = 16 * x @ y
    z = tntt.fast_mm(X, Y, 1e-9)

    assert z.N == z_ref.N
    assert z.M == z_ref.M
    assert err_rel(z.full(), z_ref.full()) < 1e-9