#run_experiment.py
#Copyright (c) 2020 Rachel Lea Ballantyne Draelos
#MIT License
#Permission is hereby granted, free of charge, to any person obtaining a copy
#of this software and associated documentation files (the "Software"), to deal
#in the Software without restriction, including without limitation the rights
#to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#copies of the Software, and to permit persons to whom the Software is
#furnished to do so, subject to the following conditions:
#The above copyright notice and this permission notice shall be included in all
#copies or substantial portions of the Software.
#THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#SOFTWARE
import os
import timeit
import datetime
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch, torch.nn as nn, torch.nn.functional as F
import torchvision
from torchvision import transforms, models, utils
import evaluate
from load_dataset import custom_datasets
#Set seeds
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
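#Note: these seeds make weight initialization and numpy-based shuffling
#repeatable. For stricter run-to-run determinism one could additionally set
#torch.backends.cudnn.deterministic = True and
#torch.backends.cudnn.benchmark = False (optional; may slow down training).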
class DukeCTModel(object):
    def __init__(self, descriptor, custom_net, custom_net_args,
                 loss, loss_args, num_epochs, patience, batch_size, device,
                 data_parallel, use_test_set, task, old_params_dir,
                 dataset_class, dataset_args):
        """Variables:
        <descriptor>: string describing the experiment
        <custom_net>: class defining a model
        <custom_net_args>: dictionary where keys correspond to custom net
            input arguments, and values are the desired values
        <loss>: 'bce' for binary cross entropy
        <loss_args>: arguments to pass to the loss function, if any
        <num_epochs>: int for the maximum number of epochs to train
        <patience>: number of epochs for which the validation loss must fail
            to improve in order to cause early stopping
        <batch_size>: int for the number of examples per batch
        <device>: int specifying which device to use, or 'all' for all devices
        <data_parallel>: if True, then parallelize across available GPUs
        <use_test_set>: if True, then run the model on the test set. If False,
            use only the training and validation sets.
        <task>:
            'train_eval': train and evaluate a new model. Evaluation always
                uses the validation set. If <use_test_set> is True, then
                evaluation also includes calculation of test set performance
                for the best validation epoch.
            'predict_on_test': load a trained model and make predictions on
                the test set using that model.
        <old_params_dir>: only needed if <task>=='predict_on_test'. This is
            the path to the parameters that will be loaded into the model.
        <dataset_class>: CT Dataset class for preprocessing the data
        <dataset_args>: arguments for the dataset class specifying how the
            data should be prepared.
        See the usage sketch at the bottom of this file for an example of how
        these arguments fit together."""
        self.descriptor = descriptor
        self.set_up_results_dirs()
        self.custom_net = custom_net
        self.custom_net_args = custom_net_args
        self.loss = loss
        self.loss_args = loss_args
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        print('self.batch_size=',self.batch_size)
        #num_workers is the number of threads to use for data loading:
        #batch_size 1 = num_workers 4, batch_size 2 = num_workers 8,
        #batch_size 4 = num_workers 16, and so on.
        self.num_workers = int(batch_size*4)
        print('self.num_workers=',self.num_workers)
        if self.num_workers == 1:
            print('Warning: Using only one worker will slow down data loading')

        #Set Device and Data Parallelism
        if device in [0,1,2,3]: #i.e. if a GPU number was specified
            self.device = torch.device('cuda:'+str(device))
            print('using device:',str(self.device),'\ndescriptor: ',self.descriptor)
        elif device == 'all':
            self.device = torch.device('cuda')
        else:
            raise ValueError('Invalid device specification: '+str(device))
        self.data_parallel = data_parallel
        if self.data_parallel:
            assert device == 'all' #use all devices when running data parallel

        #Set Task
        self.use_test_set = use_test_set
        self.task = task
        assert self.task in ['train_eval','predict_on_test']
        if self.task == 'predict_on_test':
            #overwrite the params dir that was created in the call to
            #set_up_results_dirs() with the dir you want to load from
            self.params_dir = old_params_dir

        #Data and Labels
        self.CTDatasetClass = dataset_class
        self.dataset_args = dataset_args
        #Get label meanings, a list of descriptive strings (list elements must
        #be strings found in the column headers of the labels file)
        self.set_up_label_meanings(self.dataset_args['label_meanings'])
        if self.task == 'train_eval':
            self.dataset_train = self.CTDatasetClass(setname = 'train', **self.dataset_args)
            self.dataset_valid = self.CTDatasetClass(setname = 'valid', **self.dataset_args)
        if self.use_test_set:
            self.dataset_test = self.CTDatasetClass(setname = 'test', **self.dataset_args)

        #Tracking losses and evaluation results
        self.train_loss = np.zeros((self.num_epochs))
        self.valid_loss = np.zeros((self.num_epochs))
        self.eval_results_valid, self.eval_results_test = evaluate.initialize_evaluation_dfs(self.label_meanings, self.num_epochs)

        #For early stopping
        self.initial_patience = patience
        self.patience_remaining = patience
        self.best_valid_epoch = 0
        self.min_val_loss = np.inf

        #Run everything
        self.run_model()
    ### Methods ###
    def set_up_label_meanings(self, label_meanings):
        if label_meanings == 'all': #get the full list of all available labels
            temp = custom_datasets.read_in_labels(self.dataset_args['label_type_ld'], 'valid')
            self.label_meanings = temp.columns.values.tolist()
        else: #use the label meanings that were passed in
            self.label_meanings = label_meanings
        print('label meanings ('+str(len(self.label_meanings))+' labels total):',self.label_meanings)
    def set_up_results_dirs(self):
        if not os.path.isdir('results'):
            os.mkdir('results')
        self.results_dir = os.path.join('results',datetime.datetime.today().strftime('%Y-%m-%d')+'_'+self.descriptor)
        if not os.path.isdir(self.results_dir):
            os.mkdir(self.results_dir)
        self.params_dir = os.path.join(self.results_dir,'params')
        if not os.path.isdir(self.params_dir):
            os.mkdir(self.params_dir)
        self.backup_dir = os.path.join(self.results_dir,'backup')
        if not os.path.isdir(self.backup_dir):
            os.mkdir(self.backup_dir)
    def run_model(self):
        if self.data_parallel:
            self.model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device)
        else:
            self.model = self.custom_net(**self.custom_net_args).to(self.device)
        self.sigmoid = torch.nn.Sigmoid()
        self.set_up_loss_function()
        momentum = 0.99
        print('Running with optimizer lr=1e-3, momentum='+str(round(momentum,2))+' and weight_decay=1e-7')
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=1e-3, momentum=momentum, weight_decay=1e-7)
        if self.task == 'train_eval':
            #The training and validation datasets exist only in 'train_eval'
            #mode, so their dataloaders are created here rather than
            #unconditionally (which would fail for 'predict_on_test')
            train_dataloader = DataLoader(self.dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
            valid_dataloader = DataLoader(self.dataset_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
            for epoch in range(self.num_epochs):
                t0 = timeit.default_timer()
                self.train(train_dataloader, epoch)
                self.valid(valid_dataloader, epoch)
                self.save_evals(epoch)
                if self.patience_remaining <= 0:
                    print('No more patience (',self.initial_patience,') left at epoch',epoch)
                    print('--> Implementing early stopping. Best epoch was:',self.best_valid_epoch)
                    break
                t1 = timeit.default_timer()
                self.back_up_model_every_ten(epoch)
                print('Epoch',epoch,'time:',round((t1 - t0)/60.0,2),'minutes')
        if self.use_test_set:
            self.test(DataLoader(self.dataset_test, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers))
        self.save_final_summary()
    def set_up_loss_function(self):
        if self.loss == 'bce':
            self.loss_func = nn.BCEWithLogitsLoss() #includes application of the sigmoid for numerical stability
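            #Note: for logits z and targets y, nn.BCEWithLogitsLoss()(z, y)
            #computes the same value as nn.BCELoss()(torch.sigmoid(z), y),
            #but uses the log-sum-exp trick internally, so it remains
            #numerically stable for large-magnitude logits.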
    def train(self, dataloader, epoch):
        model = self.model.train()
        epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=True)
        self.train_loss[epoch] = epoch_loss
        self.plot_roc_and_pr_curves('train', epoch, pred_epoch, gr_truth_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Train Loss', epoch_loss))
    def valid(self, dataloader, epoch):
        model = self.model.eval()
        with torch.no_grad():
            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
        self.valid_loss[epoch] = epoch_loss
        self.eval_results_valid = evaluate.evaluate_all(self.eval_results_valid, epoch,
            self.label_meanings, gr_truth_epoch, pred_epoch)
        self.early_stopping_check(epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Valid Loss', epoch_loss))
    def early_stopping_check(self, epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch):
        """Check whether the criteria for early stopping are met, and update
        the counters accordingly"""
        val_loss = self.valid_loss[epoch]
        if (val_loss < self.min_val_loss) or epoch==0: #then save the parameters
            self.min_val_loss = val_loss
            check_point = {'params': self.model.state_dict(),
                           'optimizer': self.optimizer.state_dict()}
            torch.save(check_point, os.path.join(self.params_dir, self.descriptor))
            self.best_valid_epoch = epoch
            self.patience_remaining = self.initial_patience
            print('model saved, val loss',val_loss)
            self.plot_roc_and_pr_curves('valid', epoch, val_pred_epoch, val_gr_truth_epoch)
            self.save_all_pred_probs('valid', epoch, val_pred_epoch, val_gr_truth_epoch, val_volume_accs_epoch)
        else:
            self.patience_remaining -= 1
    def back_up_model_every_ten(self, epoch):
        """Back up the model parameters every 10 epochs"""
        if epoch % 10 == 0:
            check_point = {'params': self.model.state_dict(),
                           'optimizer': self.optimizer.state_dict()}
            torch.save(check_point, os.path.join(self.backup_dir, self.descriptor+'_ep_'+str(epoch)))
    def test(self, dataloader):
        epoch = self.best_valid_epoch
        if self.data_parallel:
            model = nn.DataParallel(self.custom_net(**self.custom_net_args)).to(self.device).eval()
        else:
            model = self.custom_net(**self.custom_net_args).to(self.device).eval()
        params_path = os.path.join(self.params_dir,self.descriptor)
        print('For test set predictions, loading model params from params_path=',params_path)
        #map_location ensures the checkpoint loads onto the current device
        #even if it was saved from a different one
        check_point = torch.load(params_path, map_location=self.device)
        model.load_state_dict(check_point['params'])
        with torch.no_grad():
            epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch = self.iterate_through_batches(model, dataloader, epoch, training=False)
        self.eval_results_test = evaluate.evaluate_all(self.eval_results_test, epoch,
            self.label_meanings, gr_truth_epoch, pred_epoch)
        self.plot_roc_and_pr_curves('test', epoch, pred_epoch, gr_truth_epoch)
        self.save_all_pred_probs('test', epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch)
        print("{:5s} {:<3d} {:11s} {:.3f}".format('Epoch', epoch, 'Test Loss', epoch_loss))
    def iterate_through_batches(self, model, dataloader, epoch, training):
        epoch_loss = 0
        #Initialize numpy arrays for storing results. examples x labels
        #Do NOT use concatenation, or else you will have memory fragmentation.
        num_examples = len(dataloader.dataset)
        num_labels = len(self.label_meanings)
        pred_epoch = np.zeros([num_examples, num_labels])
        gr_truth_epoch = np.zeros([num_examples, num_labels])
        volume_accs_epoch = np.empty(num_examples, dtype='U32') #need to use U32 to allow strings of length up to 32
        for batch_idx, batch in enumerate(dataloader):
            data, gr_truth = self.move_data_to_device(batch)
            self.optimizer.zero_grad()
            if training:
                out = model(data)
            else:
                with torch.set_grad_enabled(False):
                    out = model(data)
            loss = self.loss_func(out, gr_truth)
            if training:
                loss.backward()
                self.optimizer.step()
            epoch_loss += loss.item()
            torch.cuda.empty_cache()
            #Save predictions and ground truth across batches
            pred = self.sigmoid(out.data).detach().cpu().numpy()
            gr_truth = gr_truth.detach().cpu().numpy()
            start_row = batch_idx*self.batch_size
            stop_row = min(start_row + self.batch_size, num_examples)
            pred_epoch[start_row:stop_row,:] = pred #pred_epoch is e.g. [25355,80] and pred is e.g. [1,80] for a batch size of 1
            gr_truth_epoch[start_row:stop_row,:] = gr_truth #gr_truth_epoch has the same shape as pred_epoch
            volume_accs_epoch[start_row:stop_row] = batch['volume_acc'] #stores the volume accessions in the order they were used
            #Emptying the cache again here is necessary in order to reduce
            #memory usage and avoid an OOM error:
            torch.cuda.empty_cache()
        return epoch_loss, pred_epoch, gr_truth_epoch, volume_accs_epoch
    def move_data_to_device(self, batch):
        """Move the data and the ground truth to the device."""
        #Only the 'single' crop type is supported here
        assert self.dataset_args['crop_type'] == 'single'
        data = batch['data'].to(self.device)
        #Ground truth to device
        gr_truth = batch['gr_truth'].to(self.device)
        return data, gr_truth
    def plot_roc_and_pr_curves(self, setname, epoch, pred_epoch, gr_truth_epoch):
        outdir = os.path.join(self.results_dir,'curves')
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        evaluate.plot_roc_curve_multi_class(label_meanings=self.label_meanings,
            y_test=gr_truth_epoch, y_score=pred_epoch,
            outdir=outdir, setname=setname, epoch=epoch)
        evaluate.plot_pr_curve_multi_class(label_meanings=self.label_meanings,
            y_test=gr_truth_epoch, y_score=pred_epoch,
            outdir=outdir, setname=setname, epoch=epoch)
    def save_all_pred_probs(self, setname, epoch, pred_epoch, gr_truth_epoch, volume_accs_epoch):
        outdir = os.path.join(self.results_dir,'pred_probs')
        if not os.path.isdir(outdir):
            os.mkdir(outdir)
        (pd.DataFrame(pred_epoch,columns=self.label_meanings,index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_predprob_ep'+str(epoch)+'.csv'))
        (pd.DataFrame(gr_truth_epoch,columns=self.label_meanings,index=volume_accs_epoch.tolist())).to_csv(os.path.join(outdir, setname+'_grtruth_ep'+str(epoch)+'.csv'))
    def save_evals(self, epoch):
        evaluate.save(self.eval_results_valid, self.results_dir, self.descriptor+'_valid')
        if self.use_test_set:
            evaluate.save(self.eval_results_test, self.results_dir, self.descriptor+'_test')
        evaluate.plot_learning_curves(self.train_loss, self.valid_loss, self.results_dir, self.descriptor)
    def save_final_summary(self):
        evaluate.save_final_summary(self.eval_results_valid, self.best_valid_epoch, 'valid', self.results_dir)
        if self.use_test_set:
            evaluate.save_final_summary(self.eval_results_test, self.best_valid_epoch, 'test', self.results_dir)
        evaluate.clean_up_output_files(self.best_valid_epoch, self.results_dir)
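
#Usage sketch: the class is normally driven from a separate script. The model
#class and several argument values below are hypothetical placeholders for
#illustration; the dataset_args keys 'label_meanings', 'crop_type', and
#'label_type_ld' are the keys this file actually reads, and the real dataset
#class may require additional keys.
#
#if __name__ == '__main__':
#    from models import MyCTNet #hypothetical model class
#    DukeCTModel(descriptor='myctnet_baseline',
#                custom_net=MyCTNet,
#                custom_net_args={'n_outputs':80}, #hypothetical constructor arg
#                loss='bce', loss_args={},
#                num_epochs=100, patience=15,
#                batch_size=1, device=0, data_parallel=False,
#                use_test_set=False, task='train_eval',
#                old_params_dir='', #only used when task=='predict_on_test'
#                dataset_class=custom_datasets.CTDataset, #hypothetical class name
#                dataset_args={'label_meanings':'all',
#                              'crop_type':'single',
#                              'label_type_ld':'disease'}) #'disease' value is hypothetical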