-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmiscal.py
44 lines (42 loc) · 1.5 KB
/
miscal.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import numpy as np
from scipy.interpolate import interp1d
def _member_cdf(targets, taus, y_grid):
    """Evaluate one member's piecewise-linear CDF, built from its quantile predictions, on ``y_grid``."""
    # np.unique sorts the predicted values and drops repeats so the knots are
    # strictly increasing; return_index gives each value's first position
    # along the tau axis, so among tied predictions the lowest tau is kept
    # (deterministic, unlike an unstable argsort).
    knots, first = np.unique(targets, return_index=True)
    # Linear interpolation between knots, 0 below the lowest knot and 1 above
    # the highest. np.interp also accepts a single knot (all quantile
    # predictions equal), yielding a step: 0 below that value, 1 above it.
    return np.interp(y_grid, knots, taus[first], left=0.0, right=1.0)


def _invert_cdf(cdf, taus, y_grid):
    """Return, for each tau, the first ``y_grid`` value at which ``cdf`` reaches tau."""
    # The running maximum is non-decreasing, so searchsorted(side="left")
    # gives the first index i with cdf[i] >= tau. Several taus may map to the
    # same grid point when the CDF rises steeply between grid points.
    idx = np.searchsorted(np.maximum.accumulate(cdf), taus, side="left")
    # Levels the CDF never reaches on the grid fall back to the last grid point.
    return y_grid[np.minimum(idx, y_grid.size - 1)]


def get_uncertainty_prediction(unc_preds, *, taus=None, num_grid=1000):
    """Combine an ensemble of quantile predictions into one set of quantiles.

    For each x, every ensemble member's quantile predictions are turned into a
    piecewise-linear CDF evaluated on a shared grid of y values; the member
    CDFs are averaged, and the averaged CDF is inverted back into quantiles.

    Args:
        unc_preds: 3-D array-like of shape ``(num_ens, num_taus, num_x)``.
            ``unc_preds[e, t, x]`` is ensemble member ``e``'s prediction of
            quantile level ``taus[t]`` at input ``x``.
        taus: 1-D sequence of quantile levels matching axis 1 of
            ``unc_preds``. Defaults to ``np.arange(0.01, 1, 0.01)``
            (0.01, 0.02, ..., 0.99), i.e. 99 levels.
        num_grid: Number of evenly spaced y values, spanning the global min
            and max of ``unc_preds`` (shared by all x), on which CDFs are
            evaluated. Returned quantiles are always points of this grid.

    Returns:
        ndarray of shape ``(num_taus, num_x)``: row ``t`` holds the combined
        ``taus[t]`` quantile for each x. A level the averaged CDF never
        reaches on the grid is assigned the grid maximum.

    Raises:
        ValueError: if ``taus`` is not 1-D, ``unc_preds`` is not 3-D with
            axis 1 of length ``len(taus)``, ``unc_preds`` is empty, or
            ``num_grid < 1``.
    """
    unc_preds = np.asarray(unc_preds, dtype=float)
    if taus is None:
        taus = np.arange(0.01, 1, 0.01)
    taus = np.asarray(taus, dtype=float)
    if taus.ndim != 1:
        raise ValueError("taus must be 1-D")
    if unc_preds.ndim != 3 or unc_preds.shape[1] != taus.size:
        raise ValueError(
            f"unc_preds must have shape (num_ens, {taus.size}, num_x); "
            f"got {unc_preds.shape}"
        )
    if unc_preds.size == 0:
        raise ValueError("unc_preds must be non-empty")
    if num_grid < 1:
        raise ValueError("num_grid must be at least 1")

    # One grid for every x, spanning all predictions of all members.
    y_grid = np.linspace(unc_preds.min(), unc_preds.max(), num_grid)
    num_ens, _, num_x = unc_preds.shape
    quants = np.empty((taus.size, num_x))
    for x_idx in range(num_x):
        avg_cdf = np.mean(
            [_member_cdf(unc_preds[e, :, x_idx], taus, y_grid) for e in range(num_ens)],
            axis=0,
        )
        quants[:, x_idx] = _invert_cdf(avg_cdf, taus, y_grid)
    return quants