AttMIL_Deployment.py
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 11 13:45:05 2021
@author: Narmin Ghaffari Laleh

Deployment script: applies a trained tile-level classifier to a test cohort,
writes the tile-based predictions to a CSV file and reports the patient-wise AUC.
"""
##############################################################################
import utils.utils as utils
from utils.core_utils import Validate_model_Classic
from utils.data_utils import ConcatCohorts_Classic, DatasetLoader_Classic, GetTiles
from eval.eval import CalculatePatientWiseAUC
import torch.nn as nn
import torchvision
import pandas as pd
import argparse
import torch
import os
import random
from sklearn import preprocessing
##############################################################################
parser = argparse.ArgumentParser(description = 'Main Script to Run Deployment')
parser.add_argument('--adressExp', type = str, default = r"L:\Experiments\DACHS_RESNET50_TestFull.txt", help = 'Address of the experiment file')
parser.add_argument('--modelAdr', type = str, default = r"L:\Experiments\DACHS_RESNET50_TRAINFULL_isMSIH_1\RESULTS\bestModel", help = 'Address of the trained model to deploy')
args = parser.parse_args()
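# Example invocation (the default paths above are machine-specific; pass your own):
#   python AttMIL_Deployment.py --adressExp <experiment .txt file> --modelAdr <saved model checkpoint>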
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('\nTORCH Detected: {}\n'.format(device))
##############################################################################
if __name__ == '__main__':

    # Load the experiment definition and fix the GPU and the random seed.
    args = utils.ReadExperimentFile(args, deploy = True)
    torch.cuda.set_device(args.gpuNo)
    random.seed(args.seed)

    targetLabels = args.target_labels
    for targetLabel in targetLabels:
        args.target_label = targetLabel

        # Create one output folder per target label; skip labels that were already deployed.
        args.projectFolder = utils.CreateProjectFolder(args.project_name, args.adressExp, targetLabel, args.model_name)
        if os.path.exists(args.projectFolder):
            continue
        else:
            os.mkdir(args.projectFolder)

        args.result_dir = os.path.join(args.projectFolder, 'RESULTS')
        os.makedirs(args.result_dir, exist_ok = True)
        args.split_dir = os.path.join(args.projectFolder, 'SPLITS')
        os.makedirs(args.split_dir, exist_ok = True)

        reportFile = open(os.path.join(args.projectFolder, 'Report.txt'), 'a', encoding = "utf-8")
        reportFile.write('-' * 30 + '\n')
        reportFile.write(str(args))
        reportFile.write('-' * 30 + '\n')

        print('\nLOAD THE DATASET FOR TESTING...\n')

        # Merge the test cohorts (clinical table, slide table and tile folders) into a single CSV.
        patientsList, labelsList, args.csvFile = ConcatCohorts_Classic(imagesPath = args.datadir_test,
                                                                       cliniTablePath = args.clini_dir, slideTablePath = args.slide_dir,
                                                                       label = targetLabel, minNumberOfTiles = args.minNumBlocks,
                                                                       outputPath = args.projectFolder, reportFile = reportFile,
                                                                       csvName = args.csv_name, patientNumber = args.numPatientToUse)

        # Encode the target labels as consecutive integers.
        labelsList = utils.CheckForTargetType(labelsList)
        le = preprocessing.LabelEncoder()
        labelsList = le.fit_transform(labelsList)
        args.num_classes = len(set(labelsList))
        args.target_labelDict = dict(zip(le.classes_, range(len(le.classes_))))

        utils.Summarize(args, list(labelsList), reportFile)

        print('-' * 30)
        print('IT IS A DEPLOYMENT FOR ' + targetLabel + '!')
        print('GENERATE NEW TILES...')

        # Sample up to maxBlockNum tiles per patient and store the resulting test split.
        test_data = GetTiles(csvFile = args.csvFile, label = targetLabel, target_labelDict = args.target_labelDict,
                             maxBlockNum = args.maxBlockNum, test = True)
        test_x = list(test_data['TilePath'])
        test_y = list(test_data[targetLabel])
        test_data.to_csv(os.path.join(args.split_dir, 'TestSplit.csv'), index = False)

        print()
        print('-' * 30)

        # Build the classifier and the tile-level data loader.
        model, input_size = utils.Initialize_model(model_name = args.model_name, num_classes = args.num_classes,
                                                   feature_extract = False, use_pretrained = True)
        params = {'batch_size': args.batch_size,
                  'shuffle': False,   # keep the tile order fixed so predictions stay aligned with the rows of test_data
                  'num_workers': 0,
                  'pin_memory': False}

        test_set = DatasetLoader_Classic(test_x, test_y, transform = torchvision.transforms.ToTensor(), target_patch_size = input_size)
        testGenerator = torch.utils.data.DataLoader(test_set, **params)

        # Load the trained weights: first try a state_dict; otherwise assume the full model was serialized.
        try:
            model.load_state_dict(torch.load(args.modelAdr))
        except:
            model = torch.load(args.modelAdr)
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()   # defined for completeness; not used during deployment

        print('START DEPLOYING...')
        print('')

        # Tile-level class probabilities, returned in the same order as test_data.
        probsList = Validate_model_Classic(model = model, dataloaders = testGenerator)

        # Arrange the probabilities into one column per class name.
        probs = {}
        for key in list(args.target_labelDict.keys()):
            probs[key] = []
            for item in probsList:
                probs[key].append(item[utils.get_value_from_key(args.target_labelDict, key)])

        probs = pd.DataFrame.from_dict(probs)
        testResults = pd.concat([test_data, probs], axis = 1)
        testResultsPath = os.path.join(args.result_dir, 'TEST_RESULT_TILE_BASED_FULL.csv')
        testResults.to_csv(testResultsPath, index = False)

        # Aggregate the tile-level scores to patient level and report the AUC.
        CalculatePatientWiseAUC(resultCSVPath = testResultsPath, args = args, foldcounter = None, clamMil = False, reportFile = reportFile)

        reportFile.write('-' * 100 + '\n')
        print('\n')
        print('-' * 30)
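##############################################################################
# Illustration only, not part of the original pipeline: CalculatePatientWiseAUC
# turns the tile-level probabilities stored in TEST_RESULT_TILE_BASED_FULL.csv
# into one score per patient. A minimal sketch of that kind of aggregation,
# assuming the CSV has a 'PATIENT' column and that the patient score is the
# mean tile probability of the positive class, could look like this:
#
#   import pandas as pd
#   from sklearn.metrics import roc_auc_score
#
#   df = pd.read_csv(testResultsPath)
#   posClass = list(args.target_labelDict.keys())[1]           # assumed positive class name
#   patientScore = df.groupby('PATIENT')[posClass].mean()      # average tile probabilities per patient
#   patientLabel = df.groupby('PATIENT')[targetLabel].first()  # one ground-truth label per patient
#   print(roc_auc_score(patientLabel, patientScore))
#
# The actual implementation used above lives in eval.eval.CalculatePatientWiseAUC.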