-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtest_model.py
124 lines (105 loc) · 5.26 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
""" test_model.py
Test models
Collaboratively developed
by Avi Schwarzschild, Eitan Borgnia,
Arpit Bansal, and Zeyad Emam.
Developed for DeepThinking project
October 2021
"""
import logging
import os
import sys
from collections import OrderedDict
import json
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
import deepthinking as dt
# Ignore statements for pylint:
# Too many branches (R0912), Too many statements (R0915), No member (E1101),
# Not callable (E1102), Invalid name (C0103), No exception (W0702),
# Too many local variables (R0914), Missing docstring (C0116, C0115).
# pylint: disable=R0912, R0915, E1101, E1102, C0103, W0702, R0914, C0116, C0115
@hydra.main(config_path="config", config_name="test_model_config")
def main(cfg: DictConfig):
    """Evaluate a trained DeepThinking checkpoint and write ``stats.json``.

    The hyperparameters that define the trained network are read back from
    the training run's saved Hydra config at
    ``<cfg.problem.model.model_path>/.hydra/config.yaml``. They override the
    matching entries in ``cfg``: alpha, epochs, lr, lr_factor, max_iters,
    model, optimizer, train_mode and width. The weights are then loaded from
    ``<model_path>/model_best.pth``.

    Recurrent models are evaluated at every iteration count in
    ``[test_iterations.low, test_iterations.high]``. Models whose name
    contains "feedforward" are evaluated only at ``max_iters``.

    Args:
        cfg: Config composed by Hydra from ``config/test_model_config`` plus
            command-line overrides. It must provide ``run_id`` and
            ``quick_test`` in addition to the ``problem`` subtree. It is
            mutated in place: the training hyperparameters are merged in and
            ``problem.model.model_path`` is rewritten to point at the
            checkpoint file.

    Side effects:
        Logs the config and the accuracies, and writes ``stats.json`` to the
        current working directory.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.backends.cudnn.benchmark = True

    # Default the checkpoint period to the number of epochs when it is unset.
    if cfg.problem.hyp.save_period is None:
        cfg.problem.hyp.save_period = cfg.problem.hyp.epochs

    log = logging.getLogger()
    log.info("\n_________________________________________________\n")
    log.info("test_model.py main() running.")
    log.info(OmegaConf.to_yaml(cfg))

    # Restore the hyperparameters the checkpoint was trained with, so that
    # the network is rebuilt with the matching architecture and the stats
    # below describe the model actually being evaluated.
    training_args = OmegaConf.load(os.path.join(cfg.problem.model.model_path, ".hydra/config.yaml"))
    cfg_keys_to_load = [("hyp", "alpha"),
                        ("hyp", "epochs"),
                        ("hyp", "lr"),
                        ("hyp", "lr_factor"),
                        ("model", "max_iters"),
                        ("model", "model"),
                        ("hyp", "optimizer"),
                        ("hyp", "train_mode"),
                        ("model", "width")]
    for k1, k2 in cfg_keys_to_load:
        cfg["problem"][k1][k2] = training_args["problem"][k1][k2]
    # NOTE(review): a former `cfg.problem.train_data = cfg.problem.train_data`
    # self-assignment (a no-op) was removed here. If the intent was to restore
    # the training dataset from the training run, read it from training_args
    # explicitly.
    log.info(OmegaConf.to_yaml(cfg))

    ####################################################
    #         Dataset and Network and Optimizer
    loaders = dt.utils.get_dataloaders(cfg.problem)

    # From here on model_path names the checkpoint file rather than the run
    # directory. The stats recorded below therefore store the .pth path.
    cfg.problem.model.model_path = os.path.join(cfg.problem.model.model_path, "model_best.pth")
    # The start epoch and optimizer state are only needed for resuming
    # training, so they are discarded.
    net, _, _ = dt.utils.load_model_from_checkpoint(cfg.problem.name,
                                                    cfg.problem.model,
                                                    device)
    pytorch_total_params = sum(p.numel() for p in net.parameters())

    log.info(f"This {cfg.problem.model.model} has {pytorch_total_params/1e6:0.3f} million parameters.")
    ####################################################

    ####################################################
    #        Test
    log.info("==> Starting testing...")

    # Feed-forward models have a fixed depth, so only max_iters is meaningful.
    if "feedforward" in cfg.problem.model.model:
        test_iterations = [cfg.problem.model.max_iters]
    else:
        test_iterations = list(range(cfg.problem.model.test_iterations["low"],
                                     cfg.problem.model.test_iterations["high"] + 1))

    if cfg.quick_test:
        # Quick mode evaluates only the test split. Train and val accuracy
        # are reported (and logged) as None.
        test_acc = dt.test(net, [loaders["test"]], cfg.problem.hyp.test_mode, test_iterations, cfg.problem.name, device)
        test_acc = test_acc[0]
        val_acc, train_acc = None, None
    else:
        test_acc, val_acc, train_acc = dt.test(net,
                                               [loaders["test"], loaders["val"], loaders["train"]],
                                               cfg.problem.hyp.test_mode,
                                               test_iterations,
                                               cfg.problem.name, device)

    log.info(f"{dt.utils.now()} Training accuracy: {train_acc}")
    log.info(f"{dt.utils.now()} Val accuracy: {val_acc}")
    log.info(f"{dt.utils.now()} Testing accuracy (hard data): {test_acc}")

    model_name_str = f"{cfg.problem.model.model}_width={cfg.problem.model.width}"
    stats = OrderedDict([("epochs", cfg.problem.hyp.epochs),
                         ("lr", cfg.problem.hyp.lr),
                         ("lr_factor", cfg.problem.hyp.lr_factor),
                         ("max_iters", cfg.problem.model.max_iters),
                         ("model", model_name_str),
                         ("model_path", cfg.problem.model.model_path),
                         ("num_params", pytorch_total_params),
                         ("optimizer", cfg.problem.hyp.optimizer),
                         ("val_acc", val_acc),
                         ("run_id", cfg.run_id),
                         ("test_acc", test_acc),
                         ("test_data", cfg.problem.test_data),
                         ("test_iters", test_iterations),
                         ("test_mode", cfg.problem.hyp.test_mode),
                         ("train_data", cfg.problem.train_data),
                         ("train_acc", train_acc),
                         ("train_batch_size", cfg.problem.hyp.train_batch_size),
                         ("train_mode", cfg.problem.hyp.train_mode),
                         ("alpha", cfg.problem.hyp.alpha)])
    # Relative path: the file lands in the current working directory.
    with open("stats.json", "w", encoding="utf-8") as fp:
        json.dump(stats, fp)
    log.info(stats)
####################################################
if __name__ == "__main__":
    # Tag this evaluation with a freshly generated identifier. It is passed
    # to Hydra as an extra command-line override, so main() sees it as
    # cfg.run_id.
    run_id = dt.utils.generate_run_id()
    sys.argv.extend([f"+run_id={run_id}"])
    main()