test.py
import os
from os.path import join
import argparse
import torch
import numpy as np
import random
from data import build_dataloader
from models.modeling import build_gzsl_pipeline
from models.config import cfg
from models.utils.comm import *
from models.utils import ReDirectSTD
from models.engine.inferencer import eval_zs_gzsl


def test_model(cfg, local_rank, distributed):
    # Build the GZSL pipeline and restore weights from the saved checkpoint,
    # keeping only the parameters whose names match the current model.
    model = build_gzsl_pipeline(cfg)
    model_dict = model.state_dict()
    saved_dict = torch.load('checkpoints/best_model.pth')
    saved_dict = {k: v for k, v in saved_dict['model'].items() if k in model_dict}
    model_dict.update(saved_dict)
    model.load_state_dict(model_dict)

    device = torch.device(cfg.MODEL.DEVICE)
    model = model.to(device)

    # Build the data loaders and evaluate under both ZSL and GZSL settings.
    tr_dataloader, tu_loader, ts_loader, res = build_dataloader(cfg, is_distributed=distributed)

    test_gamma = cfg.TEST.GAMMA
    acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(
        tu_loader,
        ts_loader,
        res,
        model,
        test_gamma,
        device)
    print('zsl: %.4f, gzsl: seen=%.4f, unseen=%.4f, h=%.4f' % (acc_zs, acc_seen, acc_novel, H))
    return model


def main():
    parser = argparse.ArgumentParser(description="PyTorch Zero-Shot Learning Testing")
    parser.add_argument(
        "--config-file",
        default="config/ip102.yaml",
        metavar="FILE",
        help="path to config file",
        type=str,
    )
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # Initialize distributed inference only when launched with more than one process.
    num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = num_gpus > 1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(
            backend="nccl", init_method="env://"
        )
        synchronize()

    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    # Redirect stdout to the log file on the main process only.
    output_dir = cfg.OUTPUT_DIR
    log_file_name = cfg.LOG_FILE_NAME
    log_file_path = join(output_dir, log_file_name)
    if is_main_process():
        ReDirectSTD(log_file_path, 'stdout', True)

    print("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        print(config_str)
    print("Running with config:\n{}".format(cfg))

    torch.backends.cudnn.benchmark = True
    model = test_model(cfg=cfg, local_rank=args.local_rank, distributed=args.distributed)


if __name__ == '__main__':
    # Restrict visible devices to GPU 1 before any CUDA context is created.
    os.environ["CUDA_VISIBLE_DEVICES"] = '1'
    main()
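
Usage sketch: with the default config at `config/ip102.yaml` and a checkpoint saved at `checkpoints/best_model.pth` (both paths are hard-coded above), a single-GPU evaluation can be run as `python test.py --config-file config/ip102.yaml`. Multi-GPU evaluation assumes the script is started through a distributed launcher that sets `WORLD_SIZE` and passes `--local_rank`, e.g. `python -m torch.distributed.launch --nproc_per_node=2 test.py --config-file config/ip102.yaml`.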