diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md new file mode 100644 index 0000000..cbbbb08 --- /dev/null +++ b/.ipynb_checkpoints/README-checkpoint.md @@ -0,0 +1,6 @@ +# 3D-ConvLSTMs-for-Monte-Carlo + +This is the official repository for the article ***High-particle simulation of Monte-Carlo dose distribution with 3D ConvLSTMs*** presented in MICCAI 2021 (Strasbourg). + +![](https://github.com/soniamartinot/3D-ConvLSTMs-for-Monte-Carlo/blob/master/case_3339.gif) +![](https://github.com/soniamartinot/3D-ConvLSTMs-for-Monte-Carlo/blob/master/case_3115.gif) diff --git a/.ipynb_checkpoints/lunet4_bn_leakyrelu_3d-checkpoint.py b/.ipynb_checkpoints/lunet4_bn_leakyrelu_3d-checkpoint.py new file mode 100644 index 0000000..e9d6c53 --- /dev/null +++ b/.ipynb_checkpoints/lunet4_bn_leakyrelu_3d-checkpoint.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +from convlstm3D import * +from copy import deepcopy + +class DownBlock(nn.Module): + def __init__(self, in_channels, out_channels, to_bottleneck=False): + super(DownBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.to_bottleneck = to_bottleneck + self.conv1 = nn.Conv3d(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu1 = nn.LeakyReLU() + self.bn1 = nn.BatchNorm3d(self.out_channels) + self.conv2 = nn.Conv3d(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu2 = nn.LeakyReLU() + self.bn2 = nn.BatchNorm3d(self.out_channels) + self.maxpool = nn.MaxPool3d(kernel_size=2, + stride=2) + self.clstm = ConvLSTM3DCell(input_dim=self.out_channels, + hidden_dim=self.out_channels, + kernel_size=(3, 3, 3), + bias=True) + + + def forward(self, input, cur_state): + a1 = self.bn1(self.relu1(self.conv1(input))) + a2 = self.bn2(self.relu2(self.conv2(a1))) + h, c = self.clstm(a2, cur_state) + if not self.to_bottleneck: + going_down = self.maxpool(a2) + else: going_down = a2 + # h will be concatenated with a skip connection + return going_down, h, c + + + + +class Bottleneck(nn.Module): + def __init__(self, in_channels, out_channels): + super(Bottleneck, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.conv1 = nn.Conv3d(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu1 = nn.LeakyReLU() + self.bn1 = nn.BatchNorm3d(self.out_channels) + self.conv2 = nn.Conv3d(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu2 = nn.LeakyReLU() + self.bn2 = nn.BatchNorm3d(self.out_channels) + + def forward(self, input): + a1 = self.bn1(self.relu1(self.conv1(input))) + a2 = self.bn2(self.relu2(self.conv2(a1))) + return a2 + + + + +class UpBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(UpBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.in_channels = in_channels + self.conv1 = nn.Conv3d(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.bn1 = nn.BatchNorm3d(self.out_channels) + self.deconv2 = nn.ConvTranspose3d(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=2, + stride=2, + padding=0) + self.bn2 = nn.BatchNorm3d(self.out_channels) + + def forward(self, input_sequence): + a1 = self.bn1(self.conv1(input_sequence)) + a2 = self.bn2(self.deconv2(a1)) + return a2 + + + +def crop_weird_sizes(img): + h, w = img.shape[-2], img.shape[-1] + if h % 2 != 0: img = img[..., :-1, :] + if w % 2 != 0: img = img[..., :-1] + return img + + + +def adjust_crop(img_a, img_b): + H, W = img_a.shape[-2], img_a.shape[-1] + h_out, w_out = img_b.shape[-2], img_a.shape[-1] + + if h_out < H: + diff = H - h_out + img_a = img_a[..., :-diff, :] + if w_out < W: + diff = W - w_out + img_a = img_a[..., :-diff] + + try: assert img_b.shape[-2] == img_a.shape[-2] + except: print("Dimension mismatch (height): input {} vs output {} vs output after crop {}".format(H, h_out, img_a.shape[-2])) + try: assert img_b.shape[-1] == img_a.shape[-1] + except: print("Dimension mismatch (width): input {} vs output {} vs output after crop {}".format(W, w_out, img_a.shape[-1])) + return img_a + + + +class LUNet4BNLeaky3D(nn.Module): + def __init__(self, return_last=True): + super(LUNet4BNLeaky3D, self).__init__() + self.model_name = "lunet4" + self.return_last = return_last + + self.down1 = DownBlock(1, 64) + self.down2 = DownBlock(64, 128) + self.down3 = DownBlock(128, 256) + self.down4 = DownBlock(256, 512, to_bottleneck=True) + self.up1 = UpBlock(512, 256) + self.up2 = UpBlock(512, 128) + self.up3 = UpBlock(256, 1) + + self.Encoder = [self.down1, self.down2, self.down3, self.down4] + self.Bottleneck = Bottleneck(512, 512) + self.Decoder = [self.up1, self.up2, self.up3] + + def forward(self, input_sequence): + + # Initialize the lstm cells + b, _, _, H, W, D = input_sequence.size() + hidden_states = [] + for i, block in enumerate(self.Encoder): + height, width, depth = H // (2**i), W // (2**i), D // (2**i) +# height, width, depth = H, W // (2**i), D // (2**i) + h_t, c_t = block.clstm.init_hidden(b, (height, width, depth)) + hidden_states += [(h_t, c_t)] + + # Forward + time_outputs = [] + seq_len = input_sequence.shape[1] + for t in range(seq_len): + skip_inputs = [] + frame = input_sequence[:, t, ...] + + # Forward through encoder + for i, block in enumerate(self.Encoder): + h_t, c_t = hidden_states[i] + frame, h_t, c_t = block(frame, [h_t, c_t]) + hidden_states[i] = (h_t, c_t) + skip_inputs += [h_t] + + # We are at the bottleneck. + bottleneck = self.Bottleneck(h_t) + + # Forward through decoder + skip_inputs.reverse() + + for i, block in enumerate(self.Decoder): + # Concat with skipconnections + if i == 0: + decoded = block(bottleneck) + else: + skipped = skip_inputs[i] + concat = torch.cat([decoded, skipped], 1) + decoded = block(concat) + + if self.return_last and t == seq_len - 1: time_outputs = decoded + elif not self.return_last: time_outputs += [decoded] + + return time_outputs \ No newline at end of file diff --git a/.ipynb_checkpoints/mc_dataset_infinite_patch3D-checkpoint.py b/.ipynb_checkpoints/mc_dataset_infinite_patch3D-checkpoint.py new file mode 100644 index 0000000..ad30626 --- /dev/null +++ b/.ipynb_checkpoints/mc_dataset_infinite_patch3D-checkpoint.py @@ -0,0 +1,422 @@ +import torch +from torch.utils.data import Dataset +from torchvision import transforms +import numpy as np +import SimpleITK as sitk +import time, os, random +from tqdm import tqdm +from glob import glob +import multiprocessing +import matplotlib.pyplot as plt +from copy import copy + + + +class MC3DInfinitePatchDataset(Dataset): + def __init__(self, + train_list, + n_frames, + ct_path, + patch_size=80, + all_channels=False, + normalized_by_gt=True, + standardize=False, + uncertainty_thresh=0.02, + dose_thresh=0.8, + seed=1, + transform=True, + unet=False, + verbose=False, + add_ct=False, + ct_norm=False, + high_dose_only=False, + p1=0.5, + p2=0.2, + single_frame=False, + depth=8, + mode="infinite", + n_samples=15000, + raw=False): + + # Set attributes + self.train_list = train_list # list - List comprising the paths to the cases used in the dataset + self.n_frames = n_frames # int - Number of noisy frames to input the model + self.ct_path = ct_path # string - Path to CT images corresponding to cases in the train_list + self.patch_size = patch_size # int - Size of the patch + self.all_channels = all_channels # bool - Whether to create patches with respect to each dimension + self.normalized_by_gt = normalized_by_gt # bool - Whether to normalize the data beforehand + self.standardize = standardize # bool - Whether to standardize the data + self.uncertainty_thresh = uncertainty_thresh # float - Set the uncertainty threshold below which we select the training samples + self.dose_thresh = dose_thresh # float - Set the dose threshold above which we select the training samples + self.seed = seed # int - Set the random seed + self.transform = transform # bool - Whether to add basic data augmentation + self.unet = unet # bool - Whether the model is unet like + self.add_ct = add_ct # bool - Whether to add the corresponding CT slice to the model's input + self.ct_norm = ct_norm # bool - Whether to normalize the CT volume + self.high_dose_only = high_dose_only # bool - Whether to train on high dose regions only + self.p1 = p1 # float - Probability below which patches are drawn from low dose regions + self.p2 = p2 # float - Probability above which patches are drawn from high dose regions + self.single_frame = single_frame # bool - Whether to train on a single frame instead of a sequence + self.depth = depth # int - Depth of a patch + self.mode = mode # bool - Whether to train in "finite" (looping over the dataset) or "infinite" mode + self.n_samples = n_samples # float - Number of samples to train on + self.raw = raw # bool - Whether to train on raw data with no normalization whatsoever + + + a = time.time() + if verbose: print("\nLoading dataset...") + + # If we want a variable size dataset + random.seed(self.seed) + np.random.seed(self.seed) + torch.manual_seed(self.seed) + + + # Particles to path dictionnary + self.dict_particles = {case_path: self.get_particles_to_path(case_path) for case_path in tqdm(self.train_list)} + self.dict_case_path = {os.path.basename(case_path): case_path for case_path in self.train_list} + + self.dict_ct = {case_path: + np.load(self.ct_path + "cropped_{}.npy".format(os.path.basename(case_path)), allow_pickle=True) + for case_path in tqdm(self.train_list)} + + + + if self.mode == "finite": + self.path_idx_dict = {idx: random.choice(self.train_list) for idx in range(50)} + + if self.ct_norm: + self.ct_max = 3071 + self.ct_min = -1000 + for case_path, ct in self.dict_ct.items(): + self.dict_ct[case_path] = (ct - self.ct_min) / (self.ct_max - self.ct_min) + + if self.mode != "infinite": + print("Initialization of finite mode") + # Here code hard mining when cases are too hard + self.path_mapping = {idx: random.choice(self.train_list) for idx in range(self.n_samples)} + self.path_to_idx = {} + for idx, case_path in self.path_mapping.items(): + self.path_to_idx[case_path] = self.path_to_idx.get(case_path, []) + [idx] + + # Dictionnary mapping indexes to slice number + if self.all_channels: self.channel_mapping = {idx: np.random.randint(3) for idx in self.path_mapping} + else: self.channel_mapping = {idx: 0 for idx in self.path_mapping} + # init number of slices + self.init_slice_numbers() + + if verbose: print("Loading dataset with {} samples took: {:.2f} minutes.\n".format(self.n_samples, (time.time() - a)/60)) + + + def __len__(self): + if self.mode == 'infinite': return int(1e6) + else: return self.n_samples + + def init_slice_numbers(self): + self.slice_mapping = {} + for case_path, idx_list in tqdm(self.path_to_idx.items()): + particles_to_path = self.dict_particles[case_path] + particles = sorted(list(particles_to_path.keys())) + + # Choose the slices where the uncertainty is the lowest + case = os.path.basename(case_path) + if os.path.isfile(case_path + "/{}_uncertainty_{}_0.npy".format(case, particles[-1])): + relunc = np.load(case_path + "/{}_uncertainty_{}_0.npy".format(case, particles[-1]), allow_pickle=True) + else: + relunc = np.load(case_path + "/{}_uncertainty_{}.npy".format(case, particles[-1]), allow_pickle=True) + + # Choose where the dose is the highest + dose = particles_to_path[particles[-1]][0] + + # Probability + p = np.random.rand() + if self.high_dose_only: + if p > self.p1: thresh = 0.6 + elif self.p1 >= p > self.p2: thresh = 0.2 + else: thresh = 0 + else: + thresh = self.dose_thresh + x_gt, y_gt, z_gt = np.where(dose > thresh * np.max(dose)) + x_unc, y_unc, z_unc = np.where(relunc < self.uncertainty_thresh) + + x_thresh = self.common_member(x_gt, x_unc) + y_thresh = self.common_member(y_gt, y_unc) + z_thresh = self.common_member(z_gt, z_unc) + + x_shape, y_shape, z_shape = dose.shape + half_patch_size = int(self.patch_size / 2) + half_depth = int(self.depth / 2) + for idx in idx_list: + channel = self.channel_mapping[idx] + if channel == 0: + a = np.arange(half_depth, x_shape-half_depth) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 1: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_depth, y_shape-half_depth) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 2: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_depth, z_shape-half_depth) + + a = self.common_member(x_thresh, a) + b = self.common_member(y_thresh, b) + c = self.common_member(z_thresh, c) + self.slice_mapping[idx] = (np.random.randint(np.min(a), np.max(a)), + np.random.randint(np.min(b), np.max(b)), + np.random.randint(np.min(c), np.max(c))) + + + def get_particles_to_path(self, case_path): + particles_to_path = {} + for p in glob(case_path + "/*"): + if not 'uncertainty' in p and not 'squared' in p: + n = int(os.path.basename(p).split("/")[-1].split('_')[1].split('.')[0]) + particles_to_path[n] = particles_to_path.get(n, []) + [p] + particles = sorted(list(particles_to_path)) + particles_to_path[particles[-1]] = [np.load(particles_to_path[particles[-1]][0], allow_pickle=True)] + return particles_to_path + + + + def create_pair(self, path, channel=0, idx=None, patch=True): + particles_to_path = self.dict_particles[path] + particles = sorted(list(particles_to_path)) + gt = particles_to_path[particles[-1]][0] + + # Get patch + half_patch_size = int(self.patch_size / 2) + half_depth = int(self.depth / 2) + + if self.mode == "infinite": + # Probability + p = np.random.rand() + if self.high_dose_only: + if p > self.p1: thresh = 0.6 + elif self.p1 >= p > self.p2: thresh = 0.2 + else: thresh = 0 + else: + if p > 0.5: thresh = 0.3 + else: thresh = 0. + + x_gt, y_gt, z_gt = np.where(gt >= np.max(gt) * thresh) + + x_shape, y_shape, z_shape = gt.shape + if channel == 0: + a = np.arange(half_depth, x_shape-half_depth) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 1: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_depth, y_shape-half_depth) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 2: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_depth, z_shape-half_depth) + + + a = self.common_member(x_gt, a) + b = self.common_member(y_gt, b) + c = self.common_member(z_gt, c) + + + # Get slice numbers + x = random.randint(np.min(a), np.max(a)) + y = random.randint(np.min(b), np.max(b)) + z = random.randint(np.min(c), np.max(c)) + + elif idx is not None: + # Get slice_number + x, y, z = self.slice_mapping[idx] + + if patch: + # Get ground-truth + if channel == 0: ground_truth = copy(gt[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size]) + elif channel == 1: ground_truth = copy(gt[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size]) + elif channel == 2: ground_truth = copy(gt[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth]) + h, w, d = ground_truth.shape + + # If only a single input frame for example in the case of UNet + if self.single_frame: + # Create sequence with added CT in first place + if self.add_ct: + sequence = np.empty((2, h, w, d)) + n_particles = particles[int(self.n_frames -1)] + ind = np.random.randint(len(particles_to_path[n_particles])) + path = particles_to_path[n_particles][ind] + if channel == 0: + sequence[1] = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + sequence[1] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + sequence[1] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + case = os.path.basename(path).split("_")[0] + sequence[0] = self.dict_ct[self.dict_case_path[case]][x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + # Create sequence without CT + else: + sequence = np.empty((1, h, w, d)) + n_particles = particles[int(self.n_frames -1)] + ind = np.random.randint(len(particles_to_path[n_particles])) + path = particles_to_path[n_particles][ind] + if channel == 0: + sequence[0] = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + sequence[0] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + sequence[0] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + + # If several input frames + else: + # Create sequence with added CT in first place + if self.add_ct: + sequence = np.empty((self.n_frames+1, h, w, d)) + for i, n in enumerate(particles[:self.n_frames]): + ind = np.random.randint(len(particles_to_path[n])) + path = particles_to_path[n][ind] + if channel == 0: + frame = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + sequence[i+1] = frame + case = os.path.basename(path).split("_")[0] + sequence[0] = self.dict_ct[self.dict_case_path[case]][x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + # Create sequence without CT + else: + sequence = np.empty((self.n_frames, h, w, d)) + for i, n in enumerate(particles[:self.n_frames]): + ind = np.random.randint(len(particles_to_path[n])) + path = particles_to_path[n][ind] + if channel == 0: + frame = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + sequence[i] = frame + else: + ground_truth = gt + h, w, d = ground_truth.shape + sequence = np.empty((self.n_frames, h, w, d)) + for i, n in enumerate(particles[:self.n_frames]): + ind = np.random.randint(len(particles_to_path[n])) + path = particles_to_path[n][ind] + sequence[i] = np.load(path, allow_pickle=True) + + # Reshape + a, b, c, d = sequence.shape + sequence = sequence.reshape((a, 1, b, c, d)) + ground_truth = ground_truth.reshape((1, 1, b, c, d)) + # Normalize by the max dose of the complete sequence (including ground truth) + m = np.max(ground_truth) + if self.normalized_by_gt: + sequence /= m + ground_truth /= m + # Else, scale between -1 and 1 + elif self.standardize: + sequence = (sequence - np.mean(sequence)) / np.std(sequence) + ground_truth = (ground_truth - np.mean(ground_truth)) / np.std(ground_truth) + # Raw data + elif self.raw: + sequence = sequence + ground_truth = ground_truth + # Else put every frame between 0 and 1 + else: + sequence /= np.ndarray.max(sequence, axis=(2, 3, 4))[:, np.newaxis, np.newaxis, np.newaxis] + ground_truth /= m + return sequence, ground_truth + + + + def common_member(self, a, b): + a_set = set(a) + b_set = set(b) + + if (a_set & b_set): + return list(a_set & b_set) + else: + print("No common elements") + return b + + + + def crop_and_adapt(self, img): + r_h, r_w, r_d = None, None, None + H, W, D = img.shape[-3], img.shape[-2], img.shape[-1] + if H % 2**3 != 0: r_h = - (H % 2**3) + if W % 2**3 != 0: r_w = - (W % 2**3) + if D % 2**3 != 0: r_d = - (D % 2**3) + return img[..., :r_h, :r_w, :r_d] + + + def __getitem__(self, idx): + + if self.mode == "infinite": + # Get path to random case + path = random.choice(self.train_list) + else: + # Get path of precise case +# path = self.path_idx_dict[idx] +# path = random.choice(self.train_list) + path = self.path_mapping[idx] + + # Get sequence et the frame to predict + sequence, next_frame = self.create_pair(path, patch=True, idx=idx, + channel=0) + + # Turn into tensors + sequence = torch.from_numpy(sequence) + next_frame = torch.from_numpy(next_frame) + + # Apply transformations + p = np.random.rand() + if self.transform and p > 0.5: + torch.manual_seed(idx) + composed = transforms.Compose([transforms.RandomHorizontalFlip(p=1), + transforms.RandomVerticalFlip(p=1)]) + + # Concat to transform + all_seq = torch.cat([sequence, next_frame], axis=0) + all_seq = composed(all_seq) + sequence = all_seq[:-1] + next_frame = all_seq[-1:] + + if self.unet: + # Crop to be processed by UNet + sequence = self.crop_and_adapt(sequence) + next_frame = self.crop_and_adapt(next_frame) + t, c, h, w, d = sequence.shape + sequence = torch.reshape(sequence, (t, h, w, d)) + next_frame = torch.reshape(next_frame, (1, h, w, d)) + return sequence.float(), next_frame.float() + + + def get_volumes(self, case_path): + particles_to_path = self.dict_particles[case_path] + particles = sorted(list(particles_to_path.keys())) + sequence, gt = self.create_pair(case_path, + channel=0, + patch=False) + sequence = torch.from_numpy(sequence) + sequence = self.crop_and_adapt(sequence) + gt = self.crop_and_adapt(gt) + H, W, D = sequence.shape[-3], sequence.shape[-2], sequence.shape[-1] + + if self.unet and not self.single_frame: + sequence = self.crop_and_adapt(sequence) + gt = self.crop_and_adapt(gt) + t, _, h, w, d = sequence.shape + sequence = torch.reshape(sequence, (t, h, w, d)) + elif self.unet and self.single_frame: +# sequence = self.crop_and_adapt(sequence)[-1] +# gt = self.crop_and_adapt(gt) + sequence = sequence[-1] + _, h, w, d = sequence.shape + sequence = torch.reshape(sequence, (1, h, w, d)) + + gt = np.reshape(gt, (H, W, D)) + return sequence, gt \ No newline at end of file diff --git a/.ipynb_checkpoints/train-checkpoint.py b/.ipynb_checkpoints/train-checkpoint.py new file mode 100644 index 0000000..99e32bf --- /dev/null +++ b/.ipynb_checkpoints/train-checkpoint.py @@ -0,0 +1,247 @@ +from losses import * +from utils import * +from training_functions import * +from mc_dataset_infinite_patch3D import * +from convlstm3D import * +from stack_convlstm3D import * + + +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from torch.utils.tensorboard import SummaryWriter + +import numpy as np +from glob import glob +from tqdm import tqdm +import os, sys, time, random, argparse +import SimpleITK as sitk +import datetime +from configparser import ConfigParser + + + +# Parser parameters +args = parse_args() +# Get the cases paths +cases = list_cases(args.simu_path, exclude=[]) +# Instantiate the selected model and indicate whether it's UNet based or not +model, unet = instantiate_model(args) +# Create train and validation dataloaders +train_loader, val_loader = create_dataloaders(args, cases, unet) +# Get the optimizer +optimizer = instantiate_optimizer(args, model) +# Get the learning rate scheduler +my_lr_scheduler = create_lr_scheduler(args, optimizer) +# Get the loss function +loss = create_loss(args) +# Write training configuration file +save_path = create_saving_framework(args) +# Instantitate tensorboard writer to write results for training monitoring +writer = SummaryWriter(save_path) + + +torch.cuda.set_device(args.gpu_number) +model.cuda() +if args.mode == "infinite": + print("Infinite training (mode: {})".format(args.mode)) + + # Limited number of iterations + iter_limit = int(6e5 / args.batch_size) + + + count_no_improvement = 0 + best_val, best_train = np.inf, np.inf + model.train() + val_step = 10 + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + a = time.time() + for iteration, data in enumerate(train_loader, 0): + if iteration > iter_limit: + print("Stopped training at 1e5 iterations.") + break + + sequence, target = data + sequence = sequence.float().cuda() + target = target.float().cuda() + + loss_, ssim_, mse_, l1_ = train(sequence, target, model, loss, optimizer, unet) + + loss_train += loss_ / val_step + ssim_train += ssim_ / val_step + mse_train += mse_ / val_step + l1_train += l1_ / val_step + + + # Validation step + if iteration % val_step == 0: + loss_val, mse_val, ssim_val, l1_val, pred, gt = validate(model, loss, val_loader, n_val=n_val, unet=unet) + + # Decrease learning rate when needed + if lr_scheduler == "plateau": + my_lr_scheduler.step(loss_val) + else: + my_lr_scheduler.step() + + # Writing to tensorboard + writer.add_scalars("Loss: {}".format(loss_name), {"train":loss_train, "validation":loss_val}, iteration) + writer.add_scalars("SSIM", {"train":ssim_train, "validation":ssim_val}, iteration) + writer.add_scalars("MSE", {"train":mse_train, "validation":mse_val}, iteration) + writer.add_scalars("L1", {"train":l1_train, "validation":l1_val}, iteration) + writer.add_scalar("Learning rate", get_lr(optimizer), iteration) + + # Create figure of samples to visualize + if iteration % 20 == 0: + idx = int(target.shape[1] / 2) + for k in range(len(pred)): + fig = plt.figure(figsize=(12, 6)) + plt.subplot(121) + plt.title("Prediction") + plt.axis('off') + plt.imshow(pred[k, 0, idx], cmap="magma") + plt.subplot(122) + plt.title("Ground-truth") + plt.axis('off') + plt.imshow(gt[k, idx], cmap="magma") + writer.add_figure("Sample {}".format(k), fig, global_step=iteration, close=True) + writer.flush() + + print("Iteration {} {:.2f} sec:\tLoss train: {:.2e} \tLoss val: {:.2e} \tL1 train: {:.2e} \tL1 val: {:.2e} \tSSIM train: {:.2e} \tSSIM val: {:.2e}".format( + iteration, + time.time() - a, + loss_train, loss_val, + l1_train, l1_val, + ssim_train, ssim_val)) + # Save models when reaching new best on validation + if loss_val < best_val: + count_no_improvement = 0 + best_val = loss_val + torch.save({ + 'epoch': iteration, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path + "/best_val_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_val_model.pt") + elif count_no_improvement > 5000: + print("\nEarly stopping") + break + elif iteration > 500: + count_no_improvement += 1 + + if loss_train < best_train: + best_train = loss_train + torch.save({ + 'epoch': iteration, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path+ "/best_train_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_train_model.pt") + + # Reset + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + a = time.time() + + +else: + print('Finite training (mode: {})'.format(args.mode)) + n_epochs = 100 + count_no_improvement = 0 + model.train() + val_step = 10 + iter_limit = 1e5 + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + best_val, best_train = np.inf, np.inf + a = time.time() + + for epoch in range(n_epochs): + for iteration, data in enumerate(train_loader, 0): + if iteration + len(train_loader) * epoch > iter_limit: break + sequence, target = data + sequence = sequence.float().cuda() + target = target.float().cuda() + + loss_, ssim_, mse_, l1_ = train(sequence, target, model, loss, optimizer, unet) + + loss_train += loss_ / val_step + ssim_train += ssim_ / val_step + mse_train += mse_ / val_step + l1_train += l1_ / val_step + + # Validation + if iteration % val_step == 0: + loss_val, mse_val, ssim_val, l1_val, pred, gt = validate(model, loss, val_loader, n_val=n_val, unet=unet) + + # Decrease learning rate when needed + if lr_scheduler == "plateau": + my_lr_scheduler.step(loss_val) + else: + my_lr_scheduler.step() + + writer.add_scalars("Loss: {}".format(loss_name), {"train":loss_train, "validation":loss_val}, iteration + len(train_loader) * epoch) + writer.add_scalars("SSIM", {"train":ssim_train, "validation":ssim_val}, iteration + len(train_loader) * epoch) + writer.add_scalars("MSE", {"train":mse_train, "validation":mse_val}, iteration + len(train_loader) * epoch) + writer.add_scalars("L1", {"train":l1_train, "validation":l1_val}, iteration + len(train_loader) * epoch) + writer.add_scalar("Learning rate", get_lr(optimizer), iteration + len(train_loader) * epoch) + + if iteration % 20 == 0: + idx = int(args.patch_size / 2) + for k in range(len(pred)): + fig = plt.figure(figsize=(12, 6)) + plt.subplot(121) + plt.title("Prediction") + plt.axis('off') + plt.imshow(pred[k, 0, idx], cmap="magma") + plt.subplot(122) + plt.title("Ground-truth") + plt.axis('off') + plt.imshow(gt[k, idx], cmap="magma") + + writer.add_figure("Sample {}".format(k), fig, global_step=iteration + len(train_loader) * epoch, close=True) + writer.flush() + + print("Iteration {} {:.2f} sec:\tLoss train: {:.2e} \tLoss val: {:.2e} \tL1 train: {:.2e} \tL1 val: {:.2e} \tSSIM train: {:.2e} \tSSIM val: {:.2e}".format( + iteration + len(train_loader) * epoch, + time.time() - a, + loss_train, loss_val, + l1_train, l1_val, + ssim_train, ssim_val)) + # Save models + if loss_val < best_val: + count_no_improvement = 0 + best_val = loss_val + torch.save({ + 'epoch': iteration + len(train_loader) * epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path+ "/best_val_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_val_model.pt") + elif count_no_improvement > 2000: + print("\nEarly stopping") + break + else: + count_no_improvement += 1 + + if loss_train < best_train: + best_train = loss_train + torch.save({ + 'epoch': iteration + len(train_loader) * epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path+ "/best_train_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_train_model.pt") + + # Reset + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + a = time.time() diff --git a/.ipynb_checkpoints/training_functions-checkpoint.py b/.ipynb_checkpoints/training_functions-checkpoint.py new file mode 100644 index 0000000..7551ae2 --- /dev/null +++ b/.ipynb_checkpoints/training_functions-checkpoint.py @@ -0,0 +1,73 @@ +import torch +import pytorch_ssim + + +def validate(model, criterion, dataloader, n_val, unet=False): + running_val_loss, running_mse_loss, running_ssim_loss, running_l1_loss = 0, 0, 0, 0 + # Losses + mse_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + ssim_loss = pytorch_ssim.SSIM() + # Validation + count, count_batch = 0, 0 + with torch.no_grad(): + for i, data in enumerate(dataloader, 0): + + sequence, target = data + sequence = sequence.float().cuda() + target = target.float().cuda() + + outputs = model(sequence) + if unet: target = target[:, 0] + else: target = target[:, 0, 0] + + # Training loss + loss = criterion(outputs[:, 0], target) + + # Evaluation losses + mse_ = mse_loss(outputs[:, 0], target).item() + ssim_ = ssim_loss(outputs[:, 0], target).item() + l1 = l1_loss(outputs[:, 0], target).item() + + running_val_loss += loss.item() + running_l1_loss += l1_ + running_mse_loss += mse_ + running_ssim_loss += ssim_ + + if count > n_val: break + else: + count_batch += 1 + count += len(target) + + # Get the average loss per batch + running_val_loss /= count_batch + running_mse_loss /= count_batch + running_ssim_loss /= count_batch + running_l1_loss /= count_batch + return running_val_loss, running_mse_loss, running_ssim_loss, running_l1_loss, outputs.detach().cpu().numpy()[:5], target.detach().cpu().numpy()[:5] + +def train(sequence, target, model, loss, optimizer, unet=False): + + # Losses + mse_loss = torch.nn.MSELoss() + ssim_loss = pytorch_ssim.SSIM() + l1_loss = torch.nn.L1Loss() + + # Prediction + outputs = model(sequence) + + if unet: target = target[:, 0] + else: target = target[:, 0, 0] + loss_value = loss(outputs[:, 0], target) + + # Backpropagation + loss_value.backward() + optimizer.step() + optimizer.zero_grad() + + # print statistics + loss_ = loss_value.item() + ssim_ = ssim_loss(outputs[:, 0], target).item() + mse_ = mse_loss(outputs[:, 0], target).item() + l1_ = l1_loss(outputs[:, 0], target).item() + return loss_, ssim_, mse_, l1_ \ No newline at end of file diff --git a/.ipynb_checkpoints/utils-checkpoint.py b/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000..2e2d199 --- /dev/null +++ b/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,234 @@ +import os +import torch +from torch.utils.data import DataLoader +from glob import glob +import argparse +from configparser import ConfigParser +from datetime import datetime +from mc_dataset_infinite_patch3D import * +from convlstm3D import * +from stack_convlstm3D import * +from losses import * +from lunet4_bn_leakyrelu_3d import * +from bionet3d import * + + +def parse_args(): + parser = argparse.ArgumentParser( + description='3D ConvLSTM training', + add_help=True) + + parser.add_argument("--simu_path", '-simu', default='/workspace/workstation/lstm_denoising/numpy_simulations/', type=str, + help="Path pointing to the folder where the MC simulations are kept in .npy format") + parser.add_argument("--ct_path", '-ctp', default='/workspace/workstation/numpy_cropped_images/', type=str, + help="Path to CT images") + parser.add_argument("--n_samples", '-n', default=0, type=float, + help='Number of training samples') + parser.add_argument("--patch_size", '-ps', default=64, type=int, + help='Size the height and width of a patch') + parser.add_argument("--n_train", '-nt', default=40, type=int, + help='Number of cases to use in training set') + parser.add_argument("--model_name", '-m', default='stack3D_deep', type=str, + help='Name of the model') + parser.add_argument("--loss_name", '-l', default='ssim-smoothl1', type=str, + help='Name of the loss') + parser.add_argument("--learning_rate", '-lr', default=5e-5, type=float, + help='Initial learning rate') + parser.add_argument("--weight_decay", default=1e-6, type=float, + help='Initial weight decay') + parser.add_argument("--optimizer_name", default='adamw', type=str, + help='Name of the optimizer') + parser.add_argument("--save_path", '-save', default='.', type=str, + help="Path to save the model's weigths and results") + parser.add_argument('--gpu_number', '-g', default=4, type=int, + help='GPU identifier (default 0)') + parser.add_argument('--all_channels', '-ac', default=False, action='store_false', + help='Whether to select patches from all views') + parser.add_argument('--normalized_by_gt', '-ngt', default=True, action='store_true', + help='Whether to normalize input data by ground-truth maximum') + parser.add_argument("--standardize", '-st', default=False, action='store_false', + help='Whether to standardize input data') + parser.add_argument("--uncertainty_thresh", '-u', default=0.2, type=float, + help='Uncertainty threshold') + parser.add_argument("--dose_thresh", '-dt', default=0.2, type=float, + help='Dose threshold') + parser.add_argument("--n_frames", '-nf', default=3, type=int, + help='Number of frames in the input sequence') + parser.add_argument("--batch_size", '-bs', default=8, type=int, + help='Batch size') + parser.add_argument("--add_ct", '-ct', default=False, action='store_false', + help='Whether to add CT in input sequence') + parser.add_argument("--ct_norm", '-nct', default=True, action='store_true', + help='Whether to normalize the CT') + parser.add_argument("--high_dose_only", '-hd', default=False, action='store_false', + help='Whether to train only on high dose regions') + parser.add_argument("--p1", '-p1', default=0.1, type=float, + help="Probability below which you draw patches from low dose regions") + parser.add_argument("--p2", '-p2', default=0.6, type=float, + help='Probability above which you draw patches from high dose regions') + parser.add_argument("--single_frame", '-sf', action='store_false', + help='Whether you train on single frame instead of sequence') + parser.add_argument("--mode", '-mode', default='infinite', type=str, + help="Whether to do finite or infinite training") + parser.add_argument("--lr_scheduler", '-lrs', default='plateau', type=str, + help="Name of the learning rate scheduler to use") + parser.add_argument("--depth", '-d', default='64', type=int, + help='Depth of an input patch') + parser.add_argument("--raw", '-r', default=False, action='store_false', + help='Whether to train on raw data with no preprocessing whatsoever') + parser.add_argument("--n_layers", '-nl', default=3, type=int, + help="Number of layers for BiONet3D") + args = parser.parse_args() + return args + + +def instantiate_model(args): + + if args.model_name == "stack3D": + unet = False + model = stack_model_3D() + elif args.model_name == "stack3D_deep": + unet = False + model = stack_model_3D_deep() + elif args.model_name == "lunet4-bn-leaky3D": + model = LUNet4BNLeaky3D() + unet = False + elif args.model_name == "lunet4-bn-leaky3D_big": + model = LUNet4BNLeaky3D_big() + unet = False + elif args.model_name == "bionet3d": + if args.single_frame and args.add_ct: model = BiONet3D(input_channels=2, num_layers=args.num_layers) + elif args.single_frame: model = BiONet3D(input_channels=1, num_layers=args.num_layers) + elif not args.single_frame and args.add_ct: model = BiONet3D(input_channels=args.n_frames + 1, num_layers=args.num_layers) + else: model = BiONet3D(input_channels=args.n_frames, num_layers=args.num_layers) + unet = True + elif args.model_name == "unet3d": + if args.single_frame and args.add_ct: model = UNet3D(n_frames=2, num_layers=args.num_layers) + elif args.single_frame: model = UNet3D(n_frames=1, num_layers=args.num_layers) + elif not args.single_frame and args.add_ct: model = UNet3D(n_frames=args.n_frames + 1, num_layers=args.num_layers) + else: model = UNet3D(n_frames=args.n_frames, num_layers=args.num_layers) + unet = True + else: + print("Unrecognized model") + for i in range(1):break + print("Model: ", args.model_name) + return model, unet + + +def instantiate_optimizer(args, model): + params = [p for p in model.parameters() if p.requires_grad] + if args.optimizer_name == "adam": optimizer = torch.optim.Adam(params, lr=args.learning_rate, weight_decay=wargs.eight_decay) + elif args.optimizer_name == "adamw": optimizer = torch.optim.AdamW(params, lr=args.learning_rate) + elif args.optimizer_name == "sgd": optimizer = torch.optim.SGD(params, lr=args.learning_rate, weight_decay=args.weight_decay) + else: print('Unrecognized optimizer', args.optimizer_name) + return optimizer + + +def create_loss(args): + if args.loss_name == "mse": loss = nn.MSELoss() + elif args.loss_name == "l1": loss = nn.L1Loss() + elif args.loss_name == "l1smooth": loss = nn.SmoothL1Loss() + elif args.loss_name == "ssim": loss = ssim_loss + elif args.loss_name == "ssim-smoothl1": loss = ssim_smoothl1 + elif args.loss_name == "ssim-mse": loss = ssim_mse + print("Loss used: ", args.loss_name) + return loss + + + +def create_lr_scheduler(args, optimizer): + # Learning rate decay + if args.lr_scheduler == "plateau": + decayRate = 0.8 + my_lr_scheduler= torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=decayRate, + patience=15, + threshold=1e-2, + threshold_mode='rel', + cooldown=5, + min_lr=1e-7, + eps=1e-08, + verbose=True) + elif args.lr_scheduler == "cosine": + # Learning rate update: cosine anneeling + my_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, + T_max=500, + eta_min=1e-8, + verbose=True) + return my_lr_scheduler + +def list_cases(path, exclude=[]): + return [p for p in glob(path + "*") if len(os.path.basename(p)) == 4 and not os.path.basename(p) in exclude] + + +def get_lr(optimizer): + for param_group in optimizer.param_groups: + return param_group['lr'] + + +def create_saving_framework(args): + # Don't forget to save all intermediate results and model + now = str(datetime.now()) + training_name = "train_{}_{}".format(now.split(' ')[0], now.split(' ')[-1]) + os.system("mkdir {save_path}/{training_name}".format(save_path=args.save_path , + training_name=training_name)) + save_path = args.save_path + "/" + training_name + print("Training:", training_name) + + parameters = {arg: getattr(args, arg) for arg in vars(args)} + config_object = ConfigParser() + config_object['Parameters for {}'.format(training_name)] = parameters + + # Write configuration file + with open(save_path + '/config.ini', 'w') as conf: + config_object.write(conf) + return save_path + + +def create_dataloaders(args, cases, unet): + n_train = args.n_train + train = MC3DInfinitePatchDataset(cases[:n_train], + n_frames = args.n_frames, + ct_path = args.ct_path, + patch_size = args.patch_size, + all_channels = args.all_channels, + normalized_by_gt = args.normalized_by_gt, + standardize = args.standardize, + uncertainty_thresh = args.uncertainty_thresh, + dose_thresh = args.dose_thresh, + unet = unet, + add_ct = args.add_ct, + ct_norm = args.ct_norm, + high_dose_only = args.high_dose_only, + p1 = args.p1, + p2 = args.p2, + single_frame = args.single_frame, + mode = args.mode, + n_samples = args.n_samples, + depth = args.depth) + + val = MC3DInfinitePatchDataset(cases[n_train:n_train+5], + n_frames = args.n_frames, + ct_path = args.ct_path, + patch_size = args.patch_size, + all_channels = args.all_channels, + normalized_by_gt = args.normalized_by_gt, + standardize = args.standardize, + uncertainty_thresh = args.uncertainty_thresh, + dose_thresh = args.dose_thresh, + unet = unet, + add_ct = args.add_ct, + ct_norm = args.ct_norm, + high_dose_only = args.high_dose_only, + p1 = args.p1, + p2 = args.p2, + single_frame = args.single_frame, + mode = 'finite', + n_samples = args.n_samples, + depth = args.depth) + + + train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=6) + val_loader = DataLoader(val, batch_size=args.batch_size, shuffle=False, num_workers=6) + return train_loader, val_loader diff --git a/__pycache__/bionet3d.cpython-38.pyc b/__pycache__/bionet3d.cpython-38.pyc new file mode 100644 index 0000000..d905a78 Binary files /dev/null and b/__pycache__/bionet3d.cpython-38.pyc differ diff --git a/__pycache__/convlstm3D.cpython-38.pyc b/__pycache__/convlstm3D.cpython-38.pyc new file mode 100644 index 0000000..c989c66 Binary files /dev/null and b/__pycache__/convlstm3D.cpython-38.pyc differ diff --git a/__pycache__/losses.cpython-38.pyc b/__pycache__/losses.cpython-38.pyc new file mode 100644 index 0000000..41e7928 Binary files /dev/null and b/__pycache__/losses.cpython-38.pyc differ diff --git a/__pycache__/lunet4_bn_leakyrelu_3d.cpython-38.pyc b/__pycache__/lunet4_bn_leakyrelu_3d.cpython-38.pyc new file mode 100644 index 0000000..315f24b Binary files /dev/null and b/__pycache__/lunet4_bn_leakyrelu_3d.cpython-38.pyc differ diff --git a/__pycache__/mc_dataset_infinite_patch3D.cpython-38.pyc b/__pycache__/mc_dataset_infinite_patch3D.cpython-38.pyc new file mode 100644 index 0000000..5835679 Binary files /dev/null and b/__pycache__/mc_dataset_infinite_patch3D.cpython-38.pyc differ diff --git a/__pycache__/stack_convlstm3D.cpython-38.pyc b/__pycache__/stack_convlstm3D.cpython-38.pyc new file mode 100644 index 0000000..87c55b3 Binary files /dev/null and b/__pycache__/stack_convlstm3D.cpython-38.pyc differ diff --git a/__pycache__/training_functions.cpython-38.pyc b/__pycache__/training_functions.cpython-38.pyc new file mode 100644 index 0000000..b3e7cc3 Binary files /dev/null and b/__pycache__/training_functions.cpython-38.pyc differ diff --git a/__pycache__/utils.cpython-38.pyc b/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000..63d3934 Binary files /dev/null and b/__pycache__/utils.cpython-38.pyc differ diff --git a/lunet4_bn_leakyrelu_3d.py b/lunet4_bn_leakyrelu_3d.py index d5d1503..e9d6c53 100644 --- a/lunet4_bn_leakyrelu_3d.py +++ b/lunet4_bn_leakyrelu_3d.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from models.ndrplz_convlstm3D import * +from convlstm3D import * from copy import deepcopy class DownBlock(nn.Module): diff --git a/mc_dataset_infinite_patch3D.py b/mc_dataset_infinite_patch3D.py index 8a2cda3..ad30626 100644 --- a/mc_dataset_infinite_patch3D.py +++ b/mc_dataset_infinite_patch3D.py @@ -74,9 +74,10 @@ def __init__(self, # Particles to path dictionnary self.dict_particles = {case_path: self.get_particles_to_path(case_path) for case_path in tqdm(self.train_list)} - self.dict_case_path = {os.path.basename(case_path): case_path for case_path in self.train_list} + self.dict_case_path = {os.path.basename(case_path): case_path for case_path in self.train_list} + self.dict_ct = {case_path: - np.load(self.ct_path + "ct_{}.npy".format(os.path.basename(case_path)), allow_pickle=True) + np.load(self.ct_path + "cropped_{}.npy".format(os.path.basename(case_path)), allow_pickle=True) for case_path in tqdm(self.train_list)} @@ -313,7 +314,6 @@ def create_pair(self, path, channel=0, idx=None, patch=True): # Normalize by the max dose of the complete sequence (including ground truth) m = np.max(ground_truth) if self.normalized_by_gt: - print("Normalized by gt") sequence /= m ground_truth /= m # Else, scale between -1 and 1 diff --git a/pytorch_ssim/.ipynb_checkpoints/__init__-checkpoint.py b/pytorch_ssim/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000..738e803 --- /dev/null +++ b/pytorch_ssim/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py new file mode 100644 index 0000000..738e803 --- /dev/null +++ b/pytorch_ssim/__init__.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/pytorch_ssim/__pycache__/__init__.cpython-35.pyc b/pytorch_ssim/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000..3501dfe Binary files /dev/null and b/pytorch_ssim/__pycache__/__init__.cpython-35.pyc differ diff --git a/pytorch_ssim/__pycache__/__init__.cpython-37.pyc b/pytorch_ssim/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..6a611b8 Binary files /dev/null and b/pytorch_ssim/__pycache__/__init__.cpython-37.pyc differ diff --git a/pytorch_ssim/__pycache__/__init__.cpython-38.pyc b/pytorch_ssim/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..7cba296 Binary files /dev/null and b/pytorch_ssim/__pycache__/__init__.cpython-38.pyc differ diff --git a/train.py b/train.py index 97c7ad6..99e32bf 100644 --- a/train.py +++ b/train.py @@ -19,12 +19,14 @@ import datetime from configparser import ConfigParser + + # Parser parameters args = parse_args() # Get the cases paths cases = list_cases(args.simu_path, exclude=[]) # Instantiate the selected model and indicate whether it's UNet based or not -model, unet = instantiate_model(args.model_name) +model, unet = instantiate_model(args) # Create train and validation dataloaders train_loader, val_loader = create_dataloaders(args, cases, unet) # Get the optimizer @@ -34,12 +36,13 @@ # Get the loss function loss = create_loss(args) # Write training configuration file -create_saving_framework(args) +save_path = create_saving_framework(args) # Instantitate tensorboard writer to write results for training monitoring writer = SummaryWriter(save_path) - +torch.cuda.set_device(args.gpu_number) +model.cuda() if args.mode == "infinite": print("Infinite training (mode: {})".format(args.mode)) diff --git a/training_functions.py b/training_functions.py index 2607656..7551ae2 100644 --- a/training_functions.py +++ b/training_functions.py @@ -9,7 +9,7 @@ def validate(model, criterion, dataloader, n_val, unet=False): l1_loss = torch.nn.L1Loss() ssim_loss = pytorch_ssim.SSIM() # Validation - count, count_batch =, 0 0 + count, count_batch = 0, 0 with torch.no_grad(): for i, data in enumerate(dataloader, 0): diff --git a/utils.py b/utils.py index c643d39..2e2d199 100644 --- a/utils.py +++ b/utils.py @@ -1,14 +1,15 @@ import os import torch +from torch.utils.data import DataLoader from glob import glob import argparse from configparser import ConfigParser -import datetime +from datetime import datetime from mc_dataset_infinite_patch3D import * from convlstm3D import * from stack_convlstm3D import * from losses import * -from lunet4_bn_leakyrely_3d import * +from lunet4_bn_leakyrelu_3d import * from bionet3d import * @@ -17,8 +18,10 @@ def parse_args(): description='3D ConvLSTM training', add_help=True) - parser.add_argument("--simu_path", '-simu', default='.', type=str, - help="Path pointing to the folder where the MC simulations are kept in .npy format") + parser.add_argument("--simu_path", '-simu', default='/workspace/workstation/lstm_denoising/numpy_simulations/', type=str, + help="Path pointing to the folder where the MC simulations are kept in .npy format") + parser.add_argument("--ct_path", '-ctp', default='/workspace/workstation/numpy_cropped_images/', type=str, + help="Path to CT images") parser.add_argument("--n_samples", '-n', default=0, type=float, help='Number of training samples') parser.add_argument("--patch_size", '-ps', default=64, type=int, @@ -37,9 +40,9 @@ def parse_args(): help='Name of the optimizer') parser.add_argument("--save_path", '-save', default='.', type=str, help="Path to save the model's weigths and results") - parser.add_argument('--gpu_number', '-g', default=0, type=int, + parser.add_argument('--gpu_number', '-g', default=4, type=int, help='GPU identifier (default 0)') - parser.add_argument('--all_channels', '-ac', default=False action='store_false', + parser.add_argument('--all_channels', '-ac', default=False, action='store_false', help='Whether to select patches from all views') parser.add_argument('--normalized_by_gt', '-ngt', default=True, action='store_true', help='Whether to normalize input data by ground-truth maximum') @@ -65,7 +68,7 @@ def parse_args(): help='Probability above which you draw patches from high dose regions') parser.add_argument("--single_frame", '-sf', action='store_false', help='Whether you train on single frame instead of sequence') - parser.add_argument("--mode", '-m', default='infinite', type=str, + parser.add_argument("--mode", '-mode', default='infinite', type=str, help="Whether to do finite or infinite training") parser.add_argument("--lr_scheduler", '-lrs', default='plateau', type=str, help="Name of the learning rate scheduler to use") @@ -107,7 +110,7 @@ def instantiate_model(args): unet = True else: print("Unrecognized model") - break + for i in range(1):break print("Model: ", args.model_name) return model, unet @@ -128,7 +131,7 @@ def create_loss(args): elif args.loss_name == "ssim": loss = ssim_loss elif args.loss_name == "ssim-smoothl1": loss = ssim_smoothl1 elif args.loss_name == "ssim-mse": loss = ssim_mse - print("Loss used: ", loss_name) + print("Loss used: ", args.loss_name) return loss @@ -180,12 +183,14 @@ def create_saving_framework(args): # Write configuration file with open(save_path + '/config.ini', 'w') as conf: config_object.write(conf) + return save_path def create_dataloaders(args, cases, unet): n_train = args.n_train train = MC3DInfinitePatchDataset(cases[:n_train], n_frames = args.n_frames, + ct_path = args.ct_path, patch_size = args.patch_size, all_channels = args.all_channels, normalized_by_gt = args.normalized_by_gt, @@ -205,6 +210,7 @@ def create_dataloaders(args, cases, unet): val = MC3DInfinitePatchDataset(cases[n_train:n_train+5], n_frames = args.n_frames, + ct_path = args.ct_path, patch_size = args.patch_size, all_channels = args.all_channels, normalized_by_gt = args.normalized_by_gt, @@ -224,5 +230,5 @@ def create_dataloaders(args, cases, unet): train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=6) - val_loader = DataLoader(val, batch_size=val_batch_size, shuffle=False, num_workers=6) + val_loader = DataLoader(val, batch_size=args.batch_size, shuffle=False, num_workers=6) return train_loader, val_loader