# -*- coding: utf-8 -*-
import os
import math
import json

import numpy as np
import torch.nn as nn

# Batch sizes used for training the network
precalculated_batch = {'0': 1}  # [192, 192, 96]

def plan_experiment(task, batch, patience, fold, rank, model_name, root, ver):
    """Derive training hyperparameters and model settings from dataset stats."""
    root = os.path.join(root, ver)
    with open(os.path.join(root, f'fold{fold}', 'dataset_stats.json'), 'r') as f:
        dataset = json.load(f)
    mean_size = dataset['mean_size']
    small_size = dataset['small_size']
    volume = dataset['volume_small']
    modalities = len(dataset['modality'])
    classes = len(dataset['labels'])

    # Assuming the structures are cubes: if the volume of a structure is
    # smaller than 32^3 voxels, we should use only 4 downsamples.
    threshold = 32 * 32 * 32
    num_downsamples = 4 if volume < threshold else 5
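    # e.g. volume = 30**3 = 27000 < 32768 -> 4 downsamples,
    # while volume = 40**3 = 64000 >= 32768 -> 5 downsamples.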
    # Analysis per axis
    tr_size, val_size, p_size, strides, size_last = calculate_sizes(
        num_downsamples, mean_size, small_size)

    # MEMORY CONSTRAINT 1
    # If the feature maps in the final stage are too big, make sure to use
    # five downsamples.
    if size_last >= (12 * 12 * 6):
        tr_size, val_size, p_size, strides, size_last = calculate_sizes(
            5, mean_size, small_size)
        num_downsamples = 5
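    # e.g. a 192x192x192 training patch downsampled 4 times per axis leaves
    # 12x12x12 = 1728 >= 864 final-stage voxels, which forces 5 downsamples.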

    # MEMORY CONSTRAINT 2
    # If the feature maps in the final stage are still too big,
    # we reduce the input size by capping the per-axis power of two at 6.
    if size_last >= (12 * 12 * 4):
        tr_size, val_size, p_size, strides, size_last = calculate_sizes(
            num_downsamples, mean_size, small_size, 6)

    feature_size = 48
    # Transpose the per-axis stride lists into per-stage stride triplets.
    strides = list(map(list, zip(*strides)))
    if len(strides) == 4:
        strides.insert(0, [1, 1, 1])
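    # e.g. per-axis strides [[1, 2, 2, 2, 2], [1, 2, 2, 2, 2], [1, 1, 2, 2, 2]]
    # transpose to [[1, 1, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]];
    # with 4 downsamples a leading [1, 1, 1] stage is prepended above so that
    # there are always five stage entries.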
    if rank == 0:
        print('Current task is {}, with {} modalities and {} classes'.format(
            task, modalities, classes))
        print('The mean size of the images in this dataset is:', mean_size)
        print('--- Training input size set to {} ---'.format(tr_size))
        print('--- Validation input size set to {} ---'.format(val_size))

    hyperparams = {
        'task': task,
        'classes': classes,
        'p_size': p_size,  # size of the patch before the transformations
        'in_size': tr_size,  # size of the patch that enters the network
        'val_size': val_size,  # size of the patch used during validation
        'test_size': val_size,  # size of the patch used at test time
        'batch': batch,
        'test_batch': int(batch * 1),  # currently the same as the training batch
        'patience': patience,  # Make it dependent on the data?
        'seed': 12345,
        'output_folder': '',
        'root': os.path.join(root, f'fold{fold}'),
        'data_file': os.path.join(root, f'fold{fold}', 'dataset.json'),
    }
    model = {
        'classes': classes,
        'modalities': modalities,
        'strides': strides[:3],  # only the first three stage strides are passed on
        'img_size': tr_size,
        'in_channels': modalities,
        'out_channels': classes,
        'feature_size': feature_size,
        'use_checkpoint': True,
    }
    return hyperparams, model

def calculate_sizes(num_downsamples, mean_size, small_size, max_pow=7):
    """Compute patch sizes and per-axis strides for a given number of downsamples."""
    tr_size = []
    val_size = []
    strides = []
    p_size = []
    size_last = 1
    # Analysis per axis
    for i, j in zip(mean_size, small_size):
        # If the image is too big, we treat it as if it were smaller to avoid
        # ending up with huge input patches (memory constraint).
        i_big = i
        if i > 128:
            i *= 0.7
        # Calculate the maximum possible input patch size
        power_t = min(int(math.log(i, 2)), max_pow)  # Max 2^6 = 64 or 2^7 = 128
        power_v = min(int(math.log(i, 2)), 7)  # Max 128
        sz_t = pow(2, power_t)
        sz_v = pow(2, power_v)
        # Calculate the number of strides
        stride = min(min(int(math.log(j, 2)), num_downsamples), power_t - 2)
        temp = np.ones(num_downsamples, dtype=int) * 2
        temp[:-stride] = 1  # e.g. num_downsamples=5, stride=3 -> [1, 1, 2, 2, 2]
        strides.append(list(temp))
        # Calculate the input patch size; with 5 downsamples, skip the 1.5x
        # enlargement if the final feature map per axis is already >= 8.
        constraint = sz_t / (2 ** stride) >= 8 and num_downsamples == 5
        if sz_t * 1.5 < i and not constraint:
            sz_t *= 1.5  # Max 192
        if sz_v * 1.5 < i:
            sz_v *= 1.5  # Max 192
        tr_size.append(int(sz_t))
        val_size.append(np.maximum(int(i_big // 2), int(sz_v)))
        size_last *= (sz_t / (2 ** stride))
        if sz_t + 20 < i:
            p_size.append(int(sz_t + 20))
        else:
            p_size.append(int(sz_t))
    return tr_size, val_size, p_size, strides, size_last
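
# Example usage: a minimal sketch, assuming a `dataset_stats.json` file exists
# under `<root>/<ver>/fold<fold>/`. The task id, model name, paths, patience
# and version tag below are hypothetical placeholders.
if __name__ == '__main__':
    hyperparams, model = plan_experiment(
        task='Task0',                    # hypothetical task id
        batch=precalculated_batch['0'],  # precomputed batch size from above
        patience=50,                     # hypothetical early-stopping patience
        fold=0,
        rank=0,                          # rank 0 prints the planning summary
        model_name='model',              # hypothetical; unused by this function
        root='/path/to/experiments',     # hypothetical root directory
        ver='v0',                        # hypothetical version tag
    )
    print(hyperparams)
    print(model)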