-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbenchmark_utils.py
135 lines (117 loc) · 4.47 KB
/
benchmark_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
Utility functions to perform benchmarks.
"""
from time import perf_counter, time
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

import numpy as np

from utils import build_matrices, check_sol, solve_bartels_stewart
def benchmark(solve_fun: Callable,
              vary_param: Tuple[str, Iterable],
              log_context: Optional[Dict[str, Any]] = None,
              check_solution: bool = True,
              n_runs: Union[int, Tuple[int, int]] = 5,
              bartel_stewart: bool = False,
              **kwargs):
    """Interface function for benchmarking.

    For each value of the varied parameter, run the solver several times and
    collect one result record per value.

    :param solve_fun: solver forwarded to ``multiple_runs`` (or to
        ``multiple_runs_bertel_stewart`` when ``bartel_stewart`` is True)
    :param vary_param: ``(name, values)``; each value is substituted in turn for
        the keyword argument ``name``. When ``name == 'dim'`` each value is an
        ``(m, n)`` pair. Any iterable (including a generator) is accepted.
    :param log_context: extra key/value pairs merged into every result record;
        they are applied last, so they override keys of the same name
    :param check_solution: forwarded to the runner
    :param n_runs: fixed if int. Provide boundaries (nruns_min, nruns_max) to vary
        nruns in function of matrix size; this requires varying 'dim'
    :param bartel_stewart: use ``multiple_runs_bertel_stewart`` instead of
        ``multiple_runs``
    :param kwargs: fixed keyword arguments forwarded to the runner; must contain
        ``dim`` unless ``dim`` is the varied parameter
    :return: list of dicts, one per value: the run configuration, the runner's
        output under ``'time'``, then the ``log_context`` entries
    :raises ValueError: if ``n_runs`` is a range but the varied parameter is not
        'dim'
    """
    if log_context is None:
        log_context = {}
    variable, values = vary_param
    # Materialize once: the values may be scanned for min/max below and are
    # iterated again in the main loop, which would silently exhaust a generator.
    values = list(values)

    if isinstance(n_runs, int):
        def get_nruns(_size):
            return n_runs
    else:
        if variable != 'dim':
            raise ValueError(
                "n_runs given as a (min, max) range requires vary_param "
                "to vary 'dim'")
        if not values:
            return []
        # Smallest matrices get the most runs, largest matrices the fewest.
        runs_small, runs_large = max(n_runs), min(n_runs)
        sizes = [dim[0] for dim in values]
        nmin, nmax = min(sizes), max(sizes)

        def get_nruns(size):
            return compute_nruns(nmin, runs_small, nmax, runs_large, size)

    runner = multiple_runs_bertel_stewart if bartel_stewart else multiple_runs
    results = []
    for value in values:
        # The run count is driven by m: the varied dim if 'dim' is being
        # varied, otherwise the fixed dim passed through kwargs.
        size = value[0] if variable == 'dim' else kwargs['dim'][0]
        r = get_nruns(size)
        print(f'{variable}={value} {r} runs')
        run_config = kwargs.copy()
        run_config[variable] = value
        times = runner(
            solve_fun=solve_fun,
            n_runs=r,
            check_solution=check_solution,
            **run_config
        )
        # Log computational time
        run_config['time'] = times
        # Additional log
        run_config.update(log_context)
        results.append(run_config)
    return results
def compute_nruns(nmin: int, nruns: int, nmax: int, nruns2: int, n: int) -> int:
    """Compute the number of runs given the matrix size n.

    The run count is interpolated linearly in log(n) between ``nruns`` at size
    ``nmin`` and ``nruns2`` at size ``nmax``. The result is clipped to the range
    spanned by those two counts and truncated to an int.

    :param nmin: smallest matrix size, mapped to ``nruns`` runs
    :param nruns: number of runs at size ``nmin``
    :param nmax: largest matrix size, mapped to ``nruns2`` runs
    :param nruns2: number of runs at size ``nmax``
    :param n: matrix size to compute the run count for (must be > 0)
    :return: number of runs for size ``n``
    """
    if nmax == nmin:
        # Degenerate range (e.g. a single size): the slope would divide by
        # zero and produce NaN, which int() cannot convert.
        return int(nruns)
    log_nmin, log_nmax, log_n = np.log([nmin, nmax, n])
    slope = (nruns2 - nruns) / (log_nmax - log_nmin)
    r = (log_n - log_nmin) * slope + nruns
    # Clip before truncating so float round-off at the endpoints cannot drop
    # below the lower bound.
    r = np.clip(r, min(nruns, nruns2), max(nruns, nruns2))
    return int(r)
def multiple_runs(solve_fun: Callable, n_runs: int, dim: Tuple[int, int], check_solution: bool = True, **kwargs):
    """Return computation times for multiple runs.

    Each run builds fresh matrices with ``build_matrices(m, n)``, copies C
    into X and times ``solve_fun(A, B, X, **kwargs)``. The return value of
    solve_fun is ignored and X is what gets checked, so solve_fun presumably
    writes its solution into X in place. When ``check_solution`` is True, a run
    whose solution fails ``check_sol`` is discarded and retried with new
    matrices. There is no retry limit, so a solver that never succeeds loops
    forever.

    :param solve_fun: function taking matrices (np.ndarray) A mxm, B nxn, C mxn and optional keyword arguments
    :param n_runs: number of replications
    :param dim: dimension of matrices (m, n), m size of A, n size of B
    :param check_solution: whether to check solution after calling solve_fun
    :param kwargs: passed to solve_fun
    :return: list of solving times in seconds, one per accepted run
    :raises ValueError: if n_runs is negative
    """
    if n_runs < 0:
        raise ValueError(f'n_runs must be non-negative, got {n_runs}')
    m, n = dim
    times = []
    while len(times) < n_runs:
        A, B, C = build_matrices(m, n)
        # Keep C untouched so the solution can be checked against it.
        X = C.copy()
        # perf_counter is monotonic and high-resolution, unlike wall-clock time().
        start = perf_counter()
        solve_fun(A, B, X, **kwargs)
        elapsed = perf_counter() - start
        if check_solution and not check_sol(A, B, C, X):
            print('WARNING: incorrect solution, retrying...')
            continue
        times.append(elapsed)
    return times
def multiple_runs_bertel_stewart(solve_fun: Callable,
                                 n_runs: int,
                                 dim: Tuple[int, int],
                                 check_solution: bool = True, **kwargs):
    """Return computation times for multiple runs, calling solve_bartels_stewart.

    Each run builds fresh matrices with ``build_matrices(m, n)`` and calls
    ``solve_bartels_stewart(A, B, C, solve_fun, **kwargs)``. Element 0 of its
    result is the solution and the remaining elements are the recorded
    timings. When ``check_solution`` is True, a run whose solution fails
    ``check_sol`` is discarded, and its timings are dropped too, then retried
    with new matrices. There is no retry limit.

    :param solve_fun: function taking matrices (np.ndarray) A mxm, B nxn, C mxn and optional keyword arguments
    :param n_runs: number of replications
    :param dim: dimension of matrices (m, n), m size of A, n size of B
    :param check_solution: whether to check solution after calling solve_fun
    :param kwargs: passed to solve_fun
    :return: list of tuples (time_schur, time_solve, time_back), one per accepted run
    """
    m, n = dim
    res = []
    while len(res) < n_runs:
        A, B, C = build_matrices(m, n)
        r = solve_bartels_stewart(A, B, C, solve_fun, **kwargs)
        if check_solution and not check_sol(A, B, C, r[0]):
            # Do not record timings of a run that produced a wrong solution.
            print('WARNING: incorrect solution, retrying...')
            continue
        res.append(r[1:])
    return res