-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinfer.py
264 lines (221 loc) · 10 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
import argparse
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import SimpleITK as sitk
from skimage.transform import resize
from collections import OrderedDict
import warnings
warnings.filterwarnings('ignore')
from argparse import Namespace
import os
from accelerate import Accelerator
from LaMed.src.model.language_model import *
import json
from tqdm import tqdm
import monai.transforms as mtf
from generate_green_score_new import GenerateGreenScore
import pandas as pd
from LaMed.src.dataset.multi_dataset import prompt_templates
import re
from LaMed.src.dataset.utils import read_numpy_or_dicom
from utils.postprocessor import PostProcessor
from LaMed.src.dataset.multi_dataset import AMOSCapDataset
from RaTEScore import RaTEScore
import re
from nltk.translate.bleu_score import sentence_bleu
from rouge import Rouge
from nltk import word_tokenize
from nltk.translate.meteor_score import meteor_score
# Regex that extracts the findings section from generated text.
# NOTE(review): the opening and closing delimiters are both "<FINDINGS>" —
# if the model emits "</FINDINGS>" as the closing tag this will never match
# and the full raw generation is used instead; confirm the expected format.
pattern = r"<FINDINGS>(.*?)<FINDINGS>"
def seed_everything(seed):
    """Seed every RNG in play (NumPy, torch CPU and all CUDA devices)
    and pin cuDNN to its deterministic, non-benchmarking mode so runs
    are reproducible."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
from torch.utils.data.dataloader import default_collate
def custom_collate(batch):
    """Collate a list of sample dicts into a dict of index-keyed dicts.

    Unlike the default collate (which stacks tensors), each of the four
    fields becomes ``{sample_index: value}``, so per-sample entries of
    differing shapes/types can coexist in one batch.
    """
    fields = ('image', 'input_id', 'answer', 'image_name')
    indexed = {field: dict(enumerate(sample[field] for sample in batch))
               for field in fields}
    return dict(
        image=indexed['image'],
        input_id=indexed['input_id'],
        answer=indexed['answer'],
        image_name=indexed['image_name'],
    )
# Set the seed for reproducibility
def main():
    """Generate per-organ radiology findings for the AMOS-MM validation set.

    Loads a LaMed causal-LM checkpoint (Llama/Gemma/Qwen2/Phi3, picked by a
    substring of the checkpoint path), runs greedy generation per organ for
    every case, checkpoints results to a CSV after each case (so interrupted
    runs resume), optionally applies post-processing, then computes GREEN
    plus BLEU / ROUGE-L / METEOR / RaTE metrics and writes them to
    ``other_metrics.json`` next to the CSV.
    """
    parser = argparse.ArgumentParser(description='Script configuration')
    # NOTE(review): argparse `type=bool` treats ANY non-empty string
    # (including "False") as True — confirm these flags are only ever left
    # at their defaults or passed as empty strings.
    parser.add_argument('--is_val', type=bool, default=False, help='Validation flag')
    parser.add_argument('--model_name_or_path', type=str, default='/scratch/ssd004/scratch/mohammed/results/hilt_64_320_1024', help='Model path or name')
    parser.add_argument('--json_path', type=str, default="/scratch/ssd004/scratch/mohammed/AMOSMM/AMOSMMVal.json", help='Path to JSON file')
    parser.add_argument('--model_max_length', type=int, default=768, help='Maximum model length')
    parser.add_argument('--proj_out_num', type=int, default=512, help='Project output number')
    parser.add_argument('--image_path', type=str, default="/scratch/ssd004/datasets/med-img-data/amosmm/ori_nii/imagesVa", help='Path to the image directory')
    parser.add_argument('--prompt', type=str, default="")
    parser.add_argument('--organs', metavar='N', type=str, nargs='+', default=["abdomen", "pelvis", "chest"])
    parser.add_argument('--zoom', type=bool, default=False)
    parser.add_argument('--post_process', metavar='N', type=str, nargs='+', default=[])
    args = parser.parse_args()
    seed_everything(42)
    device = torch.device('cuda')
    dtype = torch.bfloat16 # or bfloat16, float16, float32
    with_template = True
    # When True: generate only for organs with ground truth, keep GT text in
    # the CSV, and run GREEN scoring at the end.
    green = True
    # HACK: promote every CLI argument to a module-level global so the rest
    # of this function can use bare names (model_name_or_path, organs, ...).
    # Fragile — any arg name colliding with a module-level name is clobbered.
    for key, value in vars(args).items():
        globals()[key] = value
    # Select the LaMed wrapper class from a substring of the checkpoint path;
    # Phi3 is the fallback.
    if "llama" in model_name_or_path:
        model = LamedLlamaForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir='/scratch/ssd004/datasets/med-img-data/amosmm/trained/cache/',
            torch_dtype=dtype,
            device_map='auto',
            trust_remote_code=True)
    elif "gemma" in model_name_or_path:
        model = LamedGemmaForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir='/scratch/ssd004/datasets/med-img-data/amosmm/trained/cache/',
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map='auto')
    elif "qwen" in model_name_or_path:
        model = LamedQwen2ForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir='/scratch/ssd004/datasets/med-img-data/amosmm/trained/cache/',
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map='auto')
    else:
        model = LamedPhi3ForCausalLM.from_pretrained(
            model_name_or_path,
            cache_dir='/scratch/ssd004/datasets/med-img-data/amosmm/trained/cache/',
            torch_dtype=dtype,
            device_map='auto',
            trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        cache_dir='/scratch/ssd004/datasets/med-img-data/amosmm/trained/cache/',
        model_max_length=model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True
    )
    model = model.eval()
    # Image resize target comes from the model config (any-res size wins).
    if model.config.any_res_image_size:
        resize_size = model.config.any_res_image_size
    else:
        resize_size = model.config.image_size
    # NOTE(review): "multipler" (sic) is the attribute name on the model
    # config — do not "fix" the spelling here without changing the config.
    proj_out_num = args.proj_out_num * model.config.multipler
    args_dict = vars(args)
    print("Arguments received:")
    for key, value in args_dict.items():
        print(f"{key}: {value}")
    # Results CSV lives next to the checkpoint, named after the JSON file;
    # reloading an existing CSV lets an interrupted run resume.
    tag = json_path.split(os.sep)[-1].split(".")[0]
    path = model_name_or_path + os.sep + f'{tag}.csv'
    if os.path.exists(path):
        results = pd.read_csv(path)
        results = results.to_dict(orient='list')
    else:
        results = OrderedDict()
        results['names'] = []
        for organ in organs:
            results[f'generated-{organ}'] = []
            results[f'gt-{organ}'] = []
    # Minimal namespace consumed by AMOSCapDataset.
    data_args = Namespace()
    data_args.proj_out_num = proj_out_num
    data_args.json_path = json_path
    data_args.data_root = ""
    data_args.max_length = model_max_length
    data_args.prompt = prompt
    data_args.zoom_in = zoom
    data_args.organs = organs
    data_args.with_seg_mask = False
    data_args.with_template= with_template
    data_args.data_img_size = resize_size
    dataset = AMOSCapDataset(data_args, tokenizer, mode='validation')
    for item in tqdm(dataset):
        image_name = item["image_name"]
        # Resume support: skip cases already present in the loaded CSV.
        if image_name in results['names']:
            print(f"Skipping {image_name}--already done.")
            continue
        organs_ = ["abdomen", "pelvis", "chest"]
        if green:
            # Only generate for organs that have ground-truth answers.
            organs_ = item["answer"].keys()
        for organ in organs_:
            image = item["image"][organ].unsqueeze(0).to(device, dtype=dtype)
            input_id = item["input_id"][organ].to(device)
            # NOTE(review): do_sample=False makes this greedy decoding, so
            # top_p/temperature are ignored — confirm that is intentional.
            generation = model.generate(image, input_id, segs=None, max_new_tokens=512, do_sample=False, top_p=0.9, temperature=1)
            generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)[0]
            # Keep only the text between <FINDINGS> delimiters when present;
            # otherwise fall back to the whole decoded output.
            findings_match = re.search(pattern, generated_texts, re.DOTALL)
            generated_texts = findings_match.group(1).strip() if findings_match else generated_texts.strip()
            results['generated-' + str(organ)].append(generated_texts)
            if green:
                gt_text = item["answer"][organ]
                results['gt-' + str(organ)].append(gt_text)
            else:
                results['gt-' + str(organ)].append("")
        # Pad columns for organs skipped this case so every column list stays
        # the same length (required for the DataFrame below).
        missing_organs = [o for o in organs if o not in organs_]
        for m_organ in missing_organs:
            results['gt-' + str(m_organ)].append("")
            results['generated-' + str(m_organ)].append("")
        results['names'].append(image_name)
        # Checkpoint after every case so a crash loses at most one case.
        results_df = pd.DataFrame(results)
        results_df.to_csv(path, index=False)
    if len(post_process) > 0:
        print("Using post processing methods:", post_process)
        pp = PostProcessor(results, post_process, dataset)
        results = pp.run()
        results_df = pd.DataFrame(results)
        results_df.to_csv(path, index=False)
    if green:
        print("Generating Green")
        g = GenerateGreenScore(path, cache_dir="/checkpoint/datasets.damaged/med-img-data/amosmm/green", organs=organs)
        results = g.run()
    # --- text-similarity metrics over the accumulated results ---
    bleu_scores = {'abdomen': [], 'chest': [], 'pelvis': []}
    rouge_scores = {'abdomen': [], 'chest': [], 'pelvis': []}
    meteor_scores = {'abdomen': [], 'chest': [], 'pelvis': []}
    rate_scores = {'abdomen': [], 'chest': [], 'pelvis': []}
    rouge = Rouge()
    ratescore = RaTEScore()
    print("Calculating other metrics")
    # categories = organs_
    categories = ["abdomen", "pelvis", "chest"]
    # categories = ["abdomen"]
    for i in tqdm(range(len(results['names']))):
        for category in categories:
            gen_key = f"generated-{category}"
            gt_key = f"gt-{category}"
            # Score only rows where both generated and GT text are non-trivial
            # strings (padding rows are "" and NaN after a CSV round-trip).
            if isinstance(results[gen_key][i], str) and len(results[gen_key][i]) > 1 \
                    and isinstance(results[gt_key][i], str) and len(results[gt_key][i]) > 1:
                ref_tokens = word_tokenize(results[gt_key][i].lower())
                cand_tokens = word_tokenize(results[gen_key][i].lower())
                # NOTE(review): bare except silently drops any scoring failure
                # for this row — consider logging the exception instead.
                try:
                    bleu_scores[category].append(sentence_bleu([ref_tokens], cand_tokens, weights=(0.5, 0.5, 0, 0)))
                    rouge_result = rouge.get_scores(results[gen_key][i], results[gt_key][i])[0]
                    rouge_scores[category].append(rouge_result["rouge-l"]['f'])
                    meteor_scores[category].append(meteor_score([ref_tokens], cand_tokens))
                    rate_scores[category].append(ratescore.compute_score([results[gen_key][i]], [results[gt_key][i]])[0])
                except:
                    continue

    def filtered_average(scores):
        # Mean of per-organ means, ignoring organs with no (or all-zero) scores.
        averaged_region = {k: sum(v) / len(v) if v else 0 for k, v in scores.items()}
        filtered_scores = [score for score in averaged_region.values() if score > 0]
        return sum(filtered_scores) / len(filtered_scores) if filtered_scores else 0

    final = {
        "bleu": filtered_average(bleu_scores),
        "rouge": filtered_average(rouge_scores),
        "meteor": filtered_average(meteor_scores),
        "rate": filtered_average(rate_scores)
    }
    # Write aggregate metrics into the checkpoint directory, next to the CSV.
    other_path = f"{os.sep}".join(path.split(os.sep)[:-1])
    with open(other_path + os.sep + "other_metrics.json", 'w') as json_file:
        json.dump(final, json_file, indent=4)


if __name__ == '__main__':
    main()