test.py
import os
import sys

sys.path.append(os.getcwd())

import argparse
import glob

import GPUtil
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from main import instantiate_from_config


def get_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-g",
        "--gpu_id",
        type=str,
        default="5",
        help="GPU id",
    )
    parser.add_argument(
        "-r",
        "--resume",
        type=str,
        nargs="?",
        help="load from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
        # default=["configs/efficientUnet_prostate_F-rest.yaml"],
        # default=["configs/efficientUnet_SABSCT_to_CHAOS.yaml"],  # CT->MRI (abdominal)
        # default=["configs/efficientUnet_CHAOS_to_SABSCT.yaml"],  # MRI->CT (abdominal)
        # default=["configs/efficientUnet_bSSFP_to_LEG.yaml"],  # bSSFP->LGE (cardiac)
        default=["configs/efficientUnet_LEG_to_bSSFP.yaml"],  # LGE->bSSFP (cardiac)
    )
    parser.add_argument(
        "-c",
        "--config",
        nargs="?",
        metavar="single_config.yaml",
        help="path to single config. If specified, base configs will be ignored "
        "(except for the last one if left unspecified).",
        const=True,
        default="",
    )
    parser.add_argument(
        "--ignore_base_data",
        action="store_true",
        help="Ignore data specification from base configs. Useful if you want "
        "to specify a custom dataset on the command line.",
    )
    return parser
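

# Example invocations (illustrative only; the experiment name and checkpoint
# path below are placeholders, not files shipped with this repository):
#
#   python test.py -g 0 -r logs/<experiment_name>
#       picks the checkpoint whose filename contains "latest" from
#       logs/<experiment_name>/checkpoints/ and prepends any
#       logs/<experiment_name>/configs/*-project.yaml to the base configs
#
#   python test.py -g 0 -r logs/<experiment_name>/checkpoints/latest.ckpt
#       loads that checkpoint directly and infers the logdir from the path
#
# Any remaining arguments are merged into the config via OmegaConf.from_dotlist,
# i.e. dotlist-style `key=value` overrides.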


if __name__ == "__main__":
    sys.path.append(os.getcwd())

    parser = get_parser()
    opt, unknown = parser.parse_known_args()

    ckpt = None
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # Resume from an explicit checkpoint file: infer the logdir from the path.
            paths = opt.resume.split("/")
            try:
                idx = len(paths) - paths[::-1].index("logs") + 1
            except ValueError:
                idx = -2  # take a guess: path/to/logdir/checkpoints/model.ckpt
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            # Resume from a logdir: pick the checkpoint whose name contains "latest".
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            for f in os.listdir(os.path.join(logdir, "checkpoints")):
                if "latest" in f:
                    ckpt = os.path.join(logdir, "checkpoints", f)
        print(f"logdir: {logdir}")
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*-project.yaml")))
        opt.base = base_configs + opt.base

    if opt.config:
        if type(opt.config) == str:
            opt.base = [opt.config]
        else:
            opt.base = [opt.base[-1]]

    if opt.gpu_id:
        if opt.gpu_id != "-1":
            devices = opt.gpu_id
        else:
            # "-1" means auto-select: take the GPU with the most free memory.
            devices = "%s" % GPUtil.getFirstAvailable(order="memory")[0]
        os.environ["CUDA_VISIBLE_DEVICES"] = devices

    # Merge the base configs with any command-line overrides.
    configs = [OmegaConf.load(cfg) for cfg in opt.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    model_config = config.pop("model", OmegaConf.create())
    print(model_config)

    gpu = True
    eval_mode = True
    show_config = False

    # Instantiate the model and load the checkpoint weights.
    model = instantiate_from_config(model_config)
    pl_sd = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(pl_sd["model"], strict=False)
    model.cuda().eval()

    # Build the data module and run inference on the test split.
    data = instantiate_from_config(config.data)
    data.prepare_data()
    data.setup()
    # val_loader = DataLoader(data.datasets["validation"], batch_size=1, num_workers=1)
    test_loader = DataLoader(data.datasets["test"], batch_size=1, num_workers=1)

    from engine import prediction_wrapper

    label_name = data.datasets["train"].all_label_names
    out_prediction_list, dsc_table, error_dict, domain_names = prediction_wrapper(
        model, test_loader, 0, label_name, path=opt.resume, save_prediction=True
    )
    print(f"Selected Model: {ckpt}")