-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest-sort.cpp
94 lines (74 loc) · 2.03 KB
/
test-sort.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#include <algorithm>
#include <cstdint>
#include <exception>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include "simd_helpers/simd_int32.hpp"
#include "simd_helpers/simd_int64.hpp"
#include "simd_helpers/simd_float32.hpp"
#include "simd_helpers/simd_float64.hpp"
#include "simd_helpers/simd_ntuple.hpp"
#include "simd_helpers/sort.hpp"
#include "simd_helpers/simd_debug.hpp"

using namespace std;
using namespace simd_helpers;
// ------------------------------------------------------------------------------------------------
// Randomized check of simd_sort() on a simd_ntuple<T,S,N>.
//
// Each trial fills S*N scalars via uniform_rand<T>(rng, -10, 10), loads them into a
// simd_ntuple<T,S,N> with loadu(), calls simd_sort(), and stores the result with storeu().
// The reference result is built with std::sort() separately for each of the S lanes, where
// lane i is the N scalars x[j*S+i], j = 0..N-1. That is, the sort runs across the N tuple
// elements, not within one simd vector. (The x[j*S+i] indexing implies loadu()/storeu() place
// tuple element j at x[j*S .. j*S+S-1].)
//
//   T, S, N : scalar type, simd width, and number of tuple elements.
//   rng     : random engine passed to uniform_rand<T>() for every generated value.
//   niter   : number of random trials (defaults to the original 10000).
//
// Throws std::runtime_error on the first mismatch. The message starts with "sort failed" and
// names S, N, the trial, the flat index, and both values, so a failure can be located.
template<typename T, int S, int N>
static void test_sort1(std::mt19937 &rng, int niter = 10000)
{
    // Buffers are allocated once and reused by every trial.
    vector<T> x(S*N, 0);   // random input
    vector<T> y(S*N, 0);   // output of simd_sort()
    vector<T> z(S*N, 0);   // reference output from std::sort()
    vector<T> t(N, 0);     // scratch: one lane's N values

    for (int iter = 0; iter < niter; iter++) {
        for (int i = 0; i < S*N; i++)
            x[i] = uniform_rand<T>(rng, -10, 10);

        simd_ntuple<T,S,N> a;
        a.loadu(x.data());
        simd_sort(a);
        a.storeu(y.data());

        // Reference: sort each lane independently across the N tuple elements.
        for (int i = 0; i < S; i++) {
            for (int j = 0; j < N; j++)
                t[j] = x[j*S+i];
            std::sort(t.begin(), t.end());
            for (int j = 0; j < N; j++)
                z[j*S+i] = t[j];
        }

        // Exact comparison is intended: sorting only permutes the inputs, so the outputs
        // must be equal even for floating-point T.
        for (int i = 0; i < S*N; i++) {
            if (y[i] != z[i]) {
                throw runtime_error("sort failed (S=" + std::to_string(S)
                                    + ", N=" + std::to_string(N)
                                    + ", iteration " + std::to_string(iter)
                                    + ", index " + std::to_string(i)
                                    + ": got " + std::to_string(y[i])
                                    + ", expected " + std::to_string(z[i]) + ")");
            }
        }
    }
}
template<typename T, int S>
static void test_sort(std::mt19937 &rng)
{
test_sort1<T,S,2> (rng);
test_sort1<T,S,3> (rng);
test_sort1<T,S,4> (rng);
test_sort1<T,S,5> (rng);
test_sort1<T,S,6> (rng);
test_sort1<T,S,7> (rng);
test_sort1<T,S,8> (rng);
test_sort1<T,S,9> (rng);
test_sort1<T,S,10> (rng);
test_sort1<T,S,11> (rng);
test_sort1<T,S,12> (rng);
test_sort1<T,S,13> (rng);
test_sort1<T,S,14> (rng);
test_sort1<T,S,15> (rng);
test_sort1<T,S,16> (rng);
cout << "test_sort<" << type_name<T>() << "," << S << ">: pass" << endl;
}
int main(int argc, char **argv)
{
std::random_device rd;
std::mt19937 rng(rd());
test_sort<int,4> (rng);
test_sort<int64_t,2> (rng);
test_sort<float,4> (rng);
test_sort<double,2> (rng);
#ifdef __AVX__
test_sort<int,8> (rng);
test_sort<int64_t,4> (rng);
test_sort<float,8> (rng);
test_sort<double,4> (rng);
#endif
return 0;
}