-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathmake_schoop.py
129 lines (109 loc) · 4.67 KB
/
make_schoop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
""" make_schoop.py
For generating schoopy plots
Collaboratively developed
by Avi Schwarzschild, Eitan Borgnia,
Arpit Bansal, and Zeyad Emam.
Developed for DeepThinking project
October 2021
"""
import argparse
from datetime import datetime
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from make_table import get_table
def get_schoopy_plot(table, error_bars=True):
fig, ax = plt.subplots(figsize=(20, 9))
models = set(table.model)
test_datas = set(table.test_data)
alphas = set(table.alpha)
sns.lineplot(data=table,
x="test_iter",
y="test_acc_mean",
hue="model",
size="alpha",
sizes=(2, 8),
style="test_data" if len(test_datas) > 1 else None,
palette="dark",
dashes=True,
units=None,
legend="auto",
ax=ax)
if error_bars and "test_acc_sem" in table.keys():
for model in models:
for test_data in test_datas:
for alpha in alphas:
data = table[(table.model == model) &
(table.test_data == test_data) &
(table.alpha == alpha)]
plt.fill_between(data.test_iter,
data.test_acc_mean - data.test_acc_sem,
data.test_acc_mean + data.test_acc_sem,
alpha=0.1, color="k")
tr = table.max_iters.max() # training regime number
ax.fill_between([0, tr], [105, 105], alpha=0.3, label="Training Regime")
return ax
def main():
parser = argparse.ArgumentParser(description="Analysis parser")
parser.add_argument("--alpha_list", type=float, nargs="+", default=None,
help="only plot models with alphas in given list")
parser.add_argument("filepath", type=str)
parser.add_argument("--filter", type=float, default=None,
help="cutoff for filtering by training acc?")
parser.add_argument("--plot_name", type=str, default=None, help="where to save image?")
parser.add_argument("--max_iters_list", type=int, nargs="+", default=None,
help="only plot models with max iters in given list")
parser.add_argument("--model_list", type=str, nargs="+", default=None,
help="only plot models with model name in given list")
parser.add_argument("--width_list", type=str, nargs="+", default=None,
help="only plot models with widths in given list")
parser.add_argument("--max", action="store_true", help="add max values to table?")
parser.add_argument("--min", action="store_true", help="add min values too table?")
parser.add_argument("--xlim", type=float, nargs="+", default=None, help="x limits for plotting")
parser.add_argument("--ylim", type=float, nargs="+", default=None, help="y limits for plotting")
args = parser.parse_args()
if args.plot_name is None:
now = datetime.now().strftime("%m%d-%H.%M")
args.plot_name = f"schoop{now}.png"
plot_title = "Schoopy Plot"
else:
plot_title = args.plot_name[:-4]
# get table of results
table = get_table(args.filepath,
args.max,
args.min,
filter_at=args.filter,
max_iters_list=args.max_iters_list,
alpha_list=args.alpha_list,
width_list=args.width_list,
model_list=args.model_list)
# reformat and reindex table for plotting purposes
table.columns = table.columns.map("_".join)
table.columns.name = None
table = table.reset_index()
print(table.round(2).to_markdown())
ax = get_schoopy_plot(table)
ax.legend(fontsize=26, loc="upper left", bbox_to_anchor=(1.0, 0.8))
x_max = table.test_iter.max()
x = np.arange(20, x_max + 1, 10 if (x_max <= 100) else 100)
ax.tick_params(axis="y", labelsize=34)
ax.set_xticks(x)
ax.set_xticklabels(x, fontsize=34, rotation=37)
if args.xlim is None:
ax.set_xlim([x.min() - 0.5, x.max() + 0.5])
else:
ax.set_xlim(args.xlim)
if args.ylim is None:
ax.set_ylim([0, 103])
else:
ax.set_ylim(args.ylim)
ax.set_xlabel("Test-Time Iterations", fontsize=34)
ax.set_ylabel("Accuracy (%)", fontsize=34)
ax.set_title(plot_title, fontsize=34)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
plt.tight_layout()
plt.savefig(args.plot_name)
# plt.show()
if __name__ == "__main__":
main()