-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathevaluate.py
39 lines (33 loc) · 1.69 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import argparse
import torch
import os
from trainer import Trainer
from models import DANNModel, OneDomainModel
from dataloader import create_data_generators
from metrics import AccuracyScoreFromLogits
import configs.dann_config as dann_config
# Pin this process to GPU 2 by default, but respect a CUDA_VISIBLE_DEVICES
# value the user already exported (e.g. `CUDA_VISIBLE_DEVICES=0 python
# evaluate.py ...`) instead of silently overriding it. CUDA reads this
# variable when it initializes, so it must be set before any CUDA call.
# NOTE(review): assumes none of the project imports above initialize CUDA at
# import time — confirm.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '2')
# Device used for the model and the data generators; falls back to CPU when
# no (visible) GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if __name__ == '__main__':
    # Evaluate a saved model checkpoint on the target domain configured in
    # configs.dann_config and print the resulting scores. Uses the
    # module-level `device` (CUDA if available, else CPU).
    parser = argparse.ArgumentParser(
        description='Evaluate a trained model checkpoint on the target domain.')
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='path to model checkpoint')
    parser.add_argument('--model', choices=('dann', 'one_domain'), default='dann',
                        help='architecture the checkpoint was trained with '
                             '(default: dann)')
    args = parser.parse_args()

    # split_ratios=[1, 0, 0]: presumably the whole target-domain dataset goes
    # into the first returned generator, which is the one evaluated here —
    # verify against dataloader.create_data_generators.
    test_gen_t, _, _ = create_data_generators(dann_config.DATASET,
                                              dann_config.TARGET_DOMAIN,
                                              batch_size=dann_config.BATCH_SIZE,
                                              infinite_train=False,
                                              image_size=dann_config.IMAGE_SIZE,
                                              split_ratios=[1, 0, 0],
                                              num_workers=dann_config.NUM_WORKERS,
                                              device=device)

    model_cls = OneDomainModel if args.model == 'one_domain' else DANNModel
    model = model_cls().to(device)
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on GPU can still be loaded when the script falls back to CPU.
    model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    model.eval()

    acc = AccuracyScoreFromLogits()
    # NOTE(review): the second Trainer argument is passed as None; presumably
    # it is only needed for training, not scoring — confirm in trainer.py.
    tr = Trainer(model, None)
    scores = tr.score(test_gen_t, [acc])
    scores_string = ' '.join(['{}: {:.5f}\t'.format(k, float(v)) for k, v in scores.items()])
    print(f"scores on dataset \"{dann_config.DATASET}\", domain \"{dann_config.TARGET_DOMAIN}\":")
    print(scores_string)