From c2c75895be5ad1d6fe786c9464596852f7a27541 Mon Sep 17 00:00:00 2001 From: Sonia Date: Fri, 11 Jun 2021 17:12:35 +0000 Subject: [PATCH] new modif --- .ipynb_checkpoints/README-checkpoint.md | 6 + .../lunet4_bn_leakyrelu_3d-checkpoint.py | 189 ++++++++ .../mc_dataset_infinite_patch3D-checkpoint.py | 422 ++++++++++++++++++ .ipynb_checkpoints/train-checkpoint.py | 247 ++++++++++ .../training_functions-checkpoint.py | 73 +++ .ipynb_checkpoints/utils-checkpoint.py | 234 ++++++++++ __pycache__/bionet3d.cpython-38.pyc | Bin 0 -> 4269 bytes __pycache__/convlstm3D.cpython-38.pyc | Bin 0 -> 6198 bytes __pycache__/losses.cpython-38.pyc | Bin 0 -> 830 bytes .../lunet4_bn_leakyrelu_3d.cpython-38.pyc | Bin 0 -> 5165 bytes ...mc_dataset_infinite_patch3D.cpython-38.pyc | Bin 0 -> 10882 bytes __pycache__/stack_convlstm3D.cpython-38.pyc | Bin 0 -> 823 bytes __pycache__/training_functions.cpython-38.pyc | Bin 0 -> 1740 bytes __pycache__/utils.cpython-38.pyc | Bin 0 -> 8067 bytes lunet4_bn_leakyrelu_3d.py | 2 +- mc_dataset_infinite_patch3D.py | 6 +- .../.ipynb_checkpoints/__init__-checkpoint.py | 73 +++ pytorch_ssim/__init__.py | 73 +++ .../__pycache__/__init__.cpython-35.pyc | Bin 0 -> 2919 bytes .../__pycache__/__init__.cpython-37.pyc | Bin 0 -> 2612 bytes .../__pycache__/__init__.cpython-38.pyc | Bin 0 -> 2707 bytes train.py | 9 +- training_functions.py | 2 +- utils.py | 26 +- 24 files changed, 1344 insertions(+), 18 deletions(-) create mode 100644 .ipynb_checkpoints/README-checkpoint.md create mode 100644 .ipynb_checkpoints/lunet4_bn_leakyrelu_3d-checkpoint.py create mode 100644 .ipynb_checkpoints/mc_dataset_infinite_patch3D-checkpoint.py create mode 100644 .ipynb_checkpoints/train-checkpoint.py create mode 100644 .ipynb_checkpoints/training_functions-checkpoint.py create mode 100644 .ipynb_checkpoints/utils-checkpoint.py create mode 100644 __pycache__/bionet3d.cpython-38.pyc create mode 100644 __pycache__/convlstm3D.cpython-38.pyc create mode 100644 __pycache__/losses.cpython-38.pyc create mode 100644 __pycache__/lunet4_bn_leakyrelu_3d.cpython-38.pyc create mode 100644 __pycache__/mc_dataset_infinite_patch3D.cpython-38.pyc create mode 100644 __pycache__/stack_convlstm3D.cpython-38.pyc create mode 100644 __pycache__/training_functions.cpython-38.pyc create mode 100644 __pycache__/utils.cpython-38.pyc create mode 100644 pytorch_ssim/.ipynb_checkpoints/__init__-checkpoint.py create mode 100644 pytorch_ssim/__init__.py create mode 100644 pytorch_ssim/__pycache__/__init__.cpython-35.pyc create mode 100644 pytorch_ssim/__pycache__/__init__.cpython-37.pyc create mode 100644 pytorch_ssim/__pycache__/__init__.cpython-38.pyc diff --git a/.ipynb_checkpoints/README-checkpoint.md b/.ipynb_checkpoints/README-checkpoint.md new file mode 100644 index 0000000..cbbbb08 --- /dev/null +++ b/.ipynb_checkpoints/README-checkpoint.md @@ -0,0 +1,6 @@ +# 3D-ConvLSTMs-for-Monte-Carlo + +This is the official repository for the article ***High-particle simulation of Monte-Carlo dose distribution with 3D ConvLSTMs*** presented in MICCAI 2021 (Strasbourg). + +![](https://github.com/soniamartinot/3D-ConvLSTMs-for-Monte-Carlo/blob/master/case_3339.gif) +![](https://github.com/soniamartinot/3D-ConvLSTMs-for-Monte-Carlo/blob/master/case_3115.gif) diff --git a/.ipynb_checkpoints/lunet4_bn_leakyrelu_3d-checkpoint.py b/.ipynb_checkpoints/lunet4_bn_leakyrelu_3d-checkpoint.py new file mode 100644 index 0000000..e9d6c53 --- /dev/null +++ b/.ipynb_checkpoints/lunet4_bn_leakyrelu_3d-checkpoint.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +from convlstm3D import * +from copy import deepcopy + +class DownBlock(nn.Module): + def __init__(self, in_channels, out_channels, to_bottleneck=False): + super(DownBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.to_bottleneck = to_bottleneck + self.conv1 = nn.Conv3d(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu1 = nn.LeakyReLU() + self.bn1 = nn.BatchNorm3d(self.out_channels) + self.conv2 = nn.Conv3d(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu2 = nn.LeakyReLU() + self.bn2 = nn.BatchNorm3d(self.out_channels) + self.maxpool = nn.MaxPool3d(kernel_size=2, + stride=2) + self.clstm = ConvLSTM3DCell(input_dim=self.out_channels, + hidden_dim=self.out_channels, + kernel_size=(3, 3, 3), + bias=True) + + + def forward(self, input, cur_state): + a1 = self.bn1(self.relu1(self.conv1(input))) + a2 = self.bn2(self.relu2(self.conv2(a1))) + h, c = self.clstm(a2, cur_state) + if not self.to_bottleneck: + going_down = self.maxpool(a2) + else: going_down = a2 + # h will be concatenated with a skip connection + return going_down, h, c + + + + +class Bottleneck(nn.Module): + def __init__(self, in_channels, out_channels): + super(Bottleneck, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.conv1 = nn.Conv3d(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu1 = nn.LeakyReLU() + self.bn1 = nn.BatchNorm3d(self.out_channels) + self.conv2 = nn.Conv3d(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.relu2 = nn.LeakyReLU() + self.bn2 = nn.BatchNorm3d(self.out_channels) + + def forward(self, input): + a1 = self.bn1(self.relu1(self.conv1(input))) + a2 = self.bn2(self.relu2(self.conv2(a1))) + return a2 + + + + +class UpBlock(nn.Module): + def __init__(self, in_channels, out_channels): + super(UpBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.in_channels = in_channels + self.conv1 = nn.Conv3d(in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1) + self.bn1 = nn.BatchNorm3d(self.out_channels) + self.deconv2 = nn.ConvTranspose3d(in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=2, + stride=2, + padding=0) + self.bn2 = nn.BatchNorm3d(self.out_channels) + + def forward(self, input_sequence): + a1 = self.bn1(self.conv1(input_sequence)) + a2 = self.bn2(self.deconv2(a1)) + return a2 + + + +def crop_weird_sizes(img): + h, w = img.shape[-2], img.shape[-1] + if h % 2 != 0: img = img[..., :-1, :] + if w % 2 != 0: img = img[..., :-1] + return img + + + +def adjust_crop(img_a, img_b): + H, W = img_a.shape[-2], img_a.shape[-1] + h_out, w_out = img_b.shape[-2], img_a.shape[-1] + + if h_out < H: + diff = H - h_out + img_a = img_a[..., :-diff, :] + if w_out < W: + diff = W - w_out + img_a = img_a[..., :-diff] + + try: assert img_b.shape[-2] == img_a.shape[-2] + except: print("Dimension mismatch (height): input {} vs output {} vs output after crop {}".format(H, h_out, img_a.shape[-2])) + try: assert img_b.shape[-1] == img_a.shape[-1] + except: print("Dimension mismatch (width): input {} vs output {} vs output after crop {}".format(W, w_out, img_a.shape[-1])) + return img_a + + + +class LUNet4BNLeaky3D(nn.Module): + def __init__(self, return_last=True): + super(LUNet4BNLeaky3D, self).__init__() + self.model_name = "lunet4" + self.return_last = return_last + + self.down1 = DownBlock(1, 64) + self.down2 = DownBlock(64, 128) + self.down3 = DownBlock(128, 256) + self.down4 = DownBlock(256, 512, to_bottleneck=True) + self.up1 = UpBlock(512, 256) + self.up2 = UpBlock(512, 128) + self.up3 = UpBlock(256, 1) + + self.Encoder = [self.down1, self.down2, self.down3, self.down4] + self.Bottleneck = Bottleneck(512, 512) + self.Decoder = [self.up1, self.up2, self.up3] + + def forward(self, input_sequence): + + # Initialize the lstm cells + b, _, _, H, W, D = input_sequence.size() + hidden_states = [] + for i, block in enumerate(self.Encoder): + height, width, depth = H // (2**i), W // (2**i), D // (2**i) +# height, width, depth = H, W // (2**i), D // (2**i) + h_t, c_t = block.clstm.init_hidden(b, (height, width, depth)) + hidden_states += [(h_t, c_t)] + + # Forward + time_outputs = [] + seq_len = input_sequence.shape[1] + for t in range(seq_len): + skip_inputs = [] + frame = input_sequence[:, t, ...] + + # Forward through encoder + for i, block in enumerate(self.Encoder): + h_t, c_t = hidden_states[i] + frame, h_t, c_t = block(frame, [h_t, c_t]) + hidden_states[i] = (h_t, c_t) + skip_inputs += [h_t] + + # We are at the bottleneck. + bottleneck = self.Bottleneck(h_t) + + # Forward through decoder + skip_inputs.reverse() + + for i, block in enumerate(self.Decoder): + # Concat with skipconnections + if i == 0: + decoded = block(bottleneck) + else: + skipped = skip_inputs[i] + concat = torch.cat([decoded, skipped], 1) + decoded = block(concat) + + if self.return_last and t == seq_len - 1: time_outputs = decoded + elif not self.return_last: time_outputs += [decoded] + + return time_outputs \ No newline at end of file diff --git a/.ipynb_checkpoints/mc_dataset_infinite_patch3D-checkpoint.py b/.ipynb_checkpoints/mc_dataset_infinite_patch3D-checkpoint.py new file mode 100644 index 0000000..ad30626 --- /dev/null +++ b/.ipynb_checkpoints/mc_dataset_infinite_patch3D-checkpoint.py @@ -0,0 +1,422 @@ +import torch +from torch.utils.data import Dataset +from torchvision import transforms +import numpy as np +import SimpleITK as sitk +import time, os, random +from tqdm import tqdm +from glob import glob +import multiprocessing +import matplotlib.pyplot as plt +from copy import copy + + + +class MC3DInfinitePatchDataset(Dataset): + def __init__(self, + train_list, + n_frames, + ct_path, + patch_size=80, + all_channels=False, + normalized_by_gt=True, + standardize=False, + uncertainty_thresh=0.02, + dose_thresh=0.8, + seed=1, + transform=True, + unet=False, + verbose=False, + add_ct=False, + ct_norm=False, + high_dose_only=False, + p1=0.5, + p2=0.2, + single_frame=False, + depth=8, + mode="infinite", + n_samples=15000, + raw=False): + + # Set attributes + self.train_list = train_list # list - List comprising the paths to the cases used in the dataset + self.n_frames = n_frames # int - Number of noisy frames to input the model + self.ct_path = ct_path # string - Path to CT images corresponding to cases in the train_list + self.patch_size = patch_size # int - Size of the patch + self.all_channels = all_channels # bool - Whether to create patches with respect to each dimension + self.normalized_by_gt = normalized_by_gt # bool - Whether to normalize the data beforehand + self.standardize = standardize # bool - Whether to standardize the data + self.uncertainty_thresh = uncertainty_thresh # float - Set the uncertainty threshold below which we select the training samples + self.dose_thresh = dose_thresh # float - Set the dose threshold above which we select the training samples + self.seed = seed # int - Set the random seed + self.transform = transform # bool - Whether to add basic data augmentation + self.unet = unet # bool - Whether the model is unet like + self.add_ct = add_ct # bool - Whether to add the corresponding CT slice to the model's input + self.ct_norm = ct_norm # bool - Whether to normalize the CT volume + self.high_dose_only = high_dose_only # bool - Whether to train on high dose regions only + self.p1 = p1 # float - Probability below which patches are drawn from low dose regions + self.p2 = p2 # float - Probability above which patches are drawn from high dose regions + self.single_frame = single_frame # bool - Whether to train on a single frame instead of a sequence + self.depth = depth # int - Depth of a patch + self.mode = mode # bool - Whether to train in "finite" (looping over the dataset) or "infinite" mode + self.n_samples = n_samples # float - Number of samples to train on + self.raw = raw # bool - Whether to train on raw data with no normalization whatsoever + + + a = time.time() + if verbose: print("\nLoading dataset...") + + # If we want a variable size dataset + random.seed(self.seed) + np.random.seed(self.seed) + torch.manual_seed(self.seed) + + + # Particles to path dictionnary + self.dict_particles = {case_path: self.get_particles_to_path(case_path) for case_path in tqdm(self.train_list)} + self.dict_case_path = {os.path.basename(case_path): case_path for case_path in self.train_list} + + self.dict_ct = {case_path: + np.load(self.ct_path + "cropped_{}.npy".format(os.path.basename(case_path)), allow_pickle=True) + for case_path in tqdm(self.train_list)} + + + + if self.mode == "finite": + self.path_idx_dict = {idx: random.choice(self.train_list) for idx in range(50)} + + if self.ct_norm: + self.ct_max = 3071 + self.ct_min = -1000 + for case_path, ct in self.dict_ct.items(): + self.dict_ct[case_path] = (ct - self.ct_min) / (self.ct_max - self.ct_min) + + if self.mode != "infinite": + print("Initialization of finite mode") + # Here code hard mining when cases are too hard + self.path_mapping = {idx: random.choice(self.train_list) for idx in range(self.n_samples)} + self.path_to_idx = {} + for idx, case_path in self.path_mapping.items(): + self.path_to_idx[case_path] = self.path_to_idx.get(case_path, []) + [idx] + + # Dictionnary mapping indexes to slice number + if self.all_channels: self.channel_mapping = {idx: np.random.randint(3) for idx in self.path_mapping} + else: self.channel_mapping = {idx: 0 for idx in self.path_mapping} + # init number of slices + self.init_slice_numbers() + + if verbose: print("Loading dataset with {} samples took: {:.2f} minutes.\n".format(self.n_samples, (time.time() - a)/60)) + + + def __len__(self): + if self.mode == 'infinite': return int(1e6) + else: return self.n_samples + + def init_slice_numbers(self): + self.slice_mapping = {} + for case_path, idx_list in tqdm(self.path_to_idx.items()): + particles_to_path = self.dict_particles[case_path] + particles = sorted(list(particles_to_path.keys())) + + # Choose the slices where the uncertainty is the lowest + case = os.path.basename(case_path) + if os.path.isfile(case_path + "/{}_uncertainty_{}_0.npy".format(case, particles[-1])): + relunc = np.load(case_path + "/{}_uncertainty_{}_0.npy".format(case, particles[-1]), allow_pickle=True) + else: + relunc = np.load(case_path + "/{}_uncertainty_{}.npy".format(case, particles[-1]), allow_pickle=True) + + # Choose where the dose is the highest + dose = particles_to_path[particles[-1]][0] + + # Probability + p = np.random.rand() + if self.high_dose_only: + if p > self.p1: thresh = 0.6 + elif self.p1 >= p > self.p2: thresh = 0.2 + else: thresh = 0 + else: + thresh = self.dose_thresh + x_gt, y_gt, z_gt = np.where(dose > thresh * np.max(dose)) + x_unc, y_unc, z_unc = np.where(relunc < self.uncertainty_thresh) + + x_thresh = self.common_member(x_gt, x_unc) + y_thresh = self.common_member(y_gt, y_unc) + z_thresh = self.common_member(z_gt, z_unc) + + x_shape, y_shape, z_shape = dose.shape + half_patch_size = int(self.patch_size / 2) + half_depth = int(self.depth / 2) + for idx in idx_list: + channel = self.channel_mapping[idx] + if channel == 0: + a = np.arange(half_depth, x_shape-half_depth) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 1: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_depth, y_shape-half_depth) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 2: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_depth, z_shape-half_depth) + + a = self.common_member(x_thresh, a) + b = self.common_member(y_thresh, b) + c = self.common_member(z_thresh, c) + self.slice_mapping[idx] = (np.random.randint(np.min(a), np.max(a)), + np.random.randint(np.min(b), np.max(b)), + np.random.randint(np.min(c), np.max(c))) + + + def get_particles_to_path(self, case_path): + particles_to_path = {} + for p in glob(case_path + "/*"): + if not 'uncertainty' in p and not 'squared' in p: + n = int(os.path.basename(p).split("/")[-1].split('_')[1].split('.')[0]) + particles_to_path[n] = particles_to_path.get(n, []) + [p] + particles = sorted(list(particles_to_path)) + particles_to_path[particles[-1]] = [np.load(particles_to_path[particles[-1]][0], allow_pickle=True)] + return particles_to_path + + + + def create_pair(self, path, channel=0, idx=None, patch=True): + particles_to_path = self.dict_particles[path] + particles = sorted(list(particles_to_path)) + gt = particles_to_path[particles[-1]][0] + + # Get patch + half_patch_size = int(self.patch_size / 2) + half_depth = int(self.depth / 2) + + if self.mode == "infinite": + # Probability + p = np.random.rand() + if self.high_dose_only: + if p > self.p1: thresh = 0.6 + elif self.p1 >= p > self.p2: thresh = 0.2 + else: thresh = 0 + else: + if p > 0.5: thresh = 0.3 + else: thresh = 0. + + x_gt, y_gt, z_gt = np.where(gt >= np.max(gt) * thresh) + + x_shape, y_shape, z_shape = gt.shape + if channel == 0: + a = np.arange(half_depth, x_shape-half_depth) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 1: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_depth, y_shape-half_depth) + c = np.arange(half_patch_size, z_shape-half_patch_size) + elif channel == 2: + a = np.arange(half_patch_size, x_shape-half_patch_size) + b = np.arange(half_patch_size, y_shape-half_patch_size) + c = np.arange(half_depth, z_shape-half_depth) + + + a = self.common_member(x_gt, a) + b = self.common_member(y_gt, b) + c = self.common_member(z_gt, c) + + + # Get slice numbers + x = random.randint(np.min(a), np.max(a)) + y = random.randint(np.min(b), np.max(b)) + z = random.randint(np.min(c), np.max(c)) + + elif idx is not None: + # Get slice_number + x, y, z = self.slice_mapping[idx] + + if patch: + # Get ground-truth + if channel == 0: ground_truth = copy(gt[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size]) + elif channel == 1: ground_truth = copy(gt[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size]) + elif channel == 2: ground_truth = copy(gt[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth]) + h, w, d = ground_truth.shape + + # If only a single input frame for example in the case of UNet + if self.single_frame: + # Create sequence with added CT in first place + if self.add_ct: + sequence = np.empty((2, h, w, d)) + n_particles = particles[int(self.n_frames -1)] + ind = np.random.randint(len(particles_to_path[n_particles])) + path = particles_to_path[n_particles][ind] + if channel == 0: + sequence[1] = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + sequence[1] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + sequence[1] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + case = os.path.basename(path).split("_")[0] + sequence[0] = self.dict_ct[self.dict_case_path[case]][x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + # Create sequence without CT + else: + sequence = np.empty((1, h, w, d)) + n_particles = particles[int(self.n_frames -1)] + ind = np.random.randint(len(particles_to_path[n_particles])) + path = particles_to_path[n_particles][ind] + if channel == 0: + sequence[0] = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + sequence[0] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + sequence[0] = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + + # If several input frames + else: + # Create sequence with added CT in first place + if self.add_ct: + sequence = np.empty((self.n_frames+1, h, w, d)) + for i, n in enumerate(particles[:self.n_frames]): + ind = np.random.randint(len(particles_to_path[n])) + path = particles_to_path[n][ind] + if channel == 0: + frame = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + sequence[i+1] = frame + case = os.path.basename(path).split("_")[0] + sequence[0] = self.dict_ct[self.dict_case_path[case]][x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + # Create sequence without CT + else: + sequence = np.empty((self.n_frames, h, w, d)) + for i, n in enumerate(particles[:self.n_frames]): + ind = np.random.randint(len(particles_to_path[n])) + path = particles_to_path[n][ind] + if channel == 0: + frame = np.load(path, allow_pickle=True)[x-half_depth:x+half_depth, y-half_patch_size:y+half_patch_size, z-half_patch_size:z+half_patch_size] + elif channel == 1: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_depth:y+half_depth, z-half_patch_size:z+half_patch_size] + elif channel == 2: + frame = np.load(path, allow_pickle=True)[x-half_patch_size:x+half_patch_size, y-half_patch_size:y+half_patch_size, z-half_depth:z+half_depth] + sequence[i] = frame + else: + ground_truth = gt + h, w, d = ground_truth.shape + sequence = np.empty((self.n_frames, h, w, d)) + for i, n in enumerate(particles[:self.n_frames]): + ind = np.random.randint(len(particles_to_path[n])) + path = particles_to_path[n][ind] + sequence[i] = np.load(path, allow_pickle=True) + + # Reshape + a, b, c, d = sequence.shape + sequence = sequence.reshape((a, 1, b, c, d)) + ground_truth = ground_truth.reshape((1, 1, b, c, d)) + # Normalize by the max dose of the complete sequence (including ground truth) + m = np.max(ground_truth) + if self.normalized_by_gt: + sequence /= m + ground_truth /= m + # Else, scale between -1 and 1 + elif self.standardize: + sequence = (sequence - np.mean(sequence)) / np.std(sequence) + ground_truth = (ground_truth - np.mean(ground_truth)) / np.std(ground_truth) + # Raw data + elif self.raw: + sequence = sequence + ground_truth = ground_truth + # Else put every frame between 0 and 1 + else: + sequence /= np.ndarray.max(sequence, axis=(2, 3, 4))[:, np.newaxis, np.newaxis, np.newaxis] + ground_truth /= m + return sequence, ground_truth + + + + def common_member(self, a, b): + a_set = set(a) + b_set = set(b) + + if (a_set & b_set): + return list(a_set & b_set) + else: + print("No common elements") + return b + + + + def crop_and_adapt(self, img): + r_h, r_w, r_d = None, None, None + H, W, D = img.shape[-3], img.shape[-2], img.shape[-1] + if H % 2**3 != 0: r_h = - (H % 2**3) + if W % 2**3 != 0: r_w = - (W % 2**3) + if D % 2**3 != 0: r_d = - (D % 2**3) + return img[..., :r_h, :r_w, :r_d] + + + def __getitem__(self, idx): + + if self.mode == "infinite": + # Get path to random case + path = random.choice(self.train_list) + else: + # Get path of precise case +# path = self.path_idx_dict[idx] +# path = random.choice(self.train_list) + path = self.path_mapping[idx] + + # Get sequence et the frame to predict + sequence, next_frame = self.create_pair(path, patch=True, idx=idx, + channel=0) + + # Turn into tensors + sequence = torch.from_numpy(sequence) + next_frame = torch.from_numpy(next_frame) + + # Apply transformations + p = np.random.rand() + if self.transform and p > 0.5: + torch.manual_seed(idx) + composed = transforms.Compose([transforms.RandomHorizontalFlip(p=1), + transforms.RandomVerticalFlip(p=1)]) + + # Concat to transform + all_seq = torch.cat([sequence, next_frame], axis=0) + all_seq = composed(all_seq) + sequence = all_seq[:-1] + next_frame = all_seq[-1:] + + if self.unet: + # Crop to be processed by UNet + sequence = self.crop_and_adapt(sequence) + next_frame = self.crop_and_adapt(next_frame) + t, c, h, w, d = sequence.shape + sequence = torch.reshape(sequence, (t, h, w, d)) + next_frame = torch.reshape(next_frame, (1, h, w, d)) + return sequence.float(), next_frame.float() + + + def get_volumes(self, case_path): + particles_to_path = self.dict_particles[case_path] + particles = sorted(list(particles_to_path.keys())) + sequence, gt = self.create_pair(case_path, + channel=0, + patch=False) + sequence = torch.from_numpy(sequence) + sequence = self.crop_and_adapt(sequence) + gt = self.crop_and_adapt(gt) + H, W, D = sequence.shape[-3], sequence.shape[-2], sequence.shape[-1] + + if self.unet and not self.single_frame: + sequence = self.crop_and_adapt(sequence) + gt = self.crop_and_adapt(gt) + t, _, h, w, d = sequence.shape + sequence = torch.reshape(sequence, (t, h, w, d)) + elif self.unet and self.single_frame: +# sequence = self.crop_and_adapt(sequence)[-1] +# gt = self.crop_and_adapt(gt) + sequence = sequence[-1] + _, h, w, d = sequence.shape + sequence = torch.reshape(sequence, (1, h, w, d)) + + gt = np.reshape(gt, (H, W, D)) + return sequence, gt \ No newline at end of file diff --git a/.ipynb_checkpoints/train-checkpoint.py b/.ipynb_checkpoints/train-checkpoint.py new file mode 100644 index 0000000..99e32bf --- /dev/null +++ b/.ipynb_checkpoints/train-checkpoint.py @@ -0,0 +1,247 @@ +from losses import * +from utils import * +from training_functions import * +from mc_dataset_infinite_patch3D import * +from convlstm3D import * +from stack_convlstm3D import * + + +import torch +from torch.utils.data import Dataset, DataLoader +from torchvision import transforms +from torch.utils.tensorboard import SummaryWriter + +import numpy as np +from glob import glob +from tqdm import tqdm +import os, sys, time, random, argparse +import SimpleITK as sitk +import datetime +from configparser import ConfigParser + + + +# Parser parameters +args = parse_args() +# Get the cases paths +cases = list_cases(args.simu_path, exclude=[]) +# Instantiate the selected model and indicate whether it's UNet based or not +model, unet = instantiate_model(args) +# Create train and validation dataloaders +train_loader, val_loader = create_dataloaders(args, cases, unet) +# Get the optimizer +optimizer = instantiate_optimizer(args, model) +# Get the learning rate scheduler +my_lr_scheduler = create_lr_scheduler(args, optimizer) +# Get the loss function +loss = create_loss(args) +# Write training configuration file +save_path = create_saving_framework(args) +# Instantitate tensorboard writer to write results for training monitoring +writer = SummaryWriter(save_path) + + +torch.cuda.set_device(args.gpu_number) +model.cuda() +if args.mode == "infinite": + print("Infinite training (mode: {})".format(args.mode)) + + # Limited number of iterations + iter_limit = int(6e5 / args.batch_size) + + + count_no_improvement = 0 + best_val, best_train = np.inf, np.inf + model.train() + val_step = 10 + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + a = time.time() + for iteration, data in enumerate(train_loader, 0): + if iteration > iter_limit: + print("Stopped training at 1e5 iterations.") + break + + sequence, target = data + sequence = sequence.float().cuda() + target = target.float().cuda() + + loss_, ssim_, mse_, l1_ = train(sequence, target, model, loss, optimizer, unet) + + loss_train += loss_ / val_step + ssim_train += ssim_ / val_step + mse_train += mse_ / val_step + l1_train += l1_ / val_step + + + # Validation step + if iteration % val_step == 0: + loss_val, mse_val, ssim_val, l1_val, pred, gt = validate(model, loss, val_loader, n_val=n_val, unet=unet) + + # Decrease learning rate when needed + if lr_scheduler == "plateau": + my_lr_scheduler.step(loss_val) + else: + my_lr_scheduler.step() + + # Writing to tensorboard + writer.add_scalars("Loss: {}".format(loss_name), {"train":loss_train, "validation":loss_val}, iteration) + writer.add_scalars("SSIM", {"train":ssim_train, "validation":ssim_val}, iteration) + writer.add_scalars("MSE", {"train":mse_train, "validation":mse_val}, iteration) + writer.add_scalars("L1", {"train":l1_train, "validation":l1_val}, iteration) + writer.add_scalar("Learning rate", get_lr(optimizer), iteration) + + # Create figure of samples to visualize + if iteration % 20 == 0: + idx = int(target.shape[1] / 2) + for k in range(len(pred)): + fig = plt.figure(figsize=(12, 6)) + plt.subplot(121) + plt.title("Prediction") + plt.axis('off') + plt.imshow(pred[k, 0, idx], cmap="magma") + plt.subplot(122) + plt.title("Ground-truth") + plt.axis('off') + plt.imshow(gt[k, idx], cmap="magma") + writer.add_figure("Sample {}".format(k), fig, global_step=iteration, close=True) + writer.flush() + + print("Iteration {} {:.2f} sec:\tLoss train: {:.2e} \tLoss val: {:.2e} \tL1 train: {:.2e} \tL1 val: {:.2e} \tSSIM train: {:.2e} \tSSIM val: {:.2e}".format( + iteration, + time.time() - a, + loss_train, loss_val, + l1_train, l1_val, + ssim_train, ssim_val)) + # Save models when reaching new best on validation + if loss_val < best_val: + count_no_improvement = 0 + best_val = loss_val + torch.save({ + 'epoch': iteration, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path + "/best_val_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_val_model.pt") + elif count_no_improvement > 5000: + print("\nEarly stopping") + break + elif iteration > 500: + count_no_improvement += 1 + + if loss_train < best_train: + best_train = loss_train + torch.save({ + 'epoch': iteration, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path+ "/best_train_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_train_model.pt") + + # Reset + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + a = time.time() + + +else: + print('Finite training (mode: {})'.format(args.mode)) + n_epochs = 100 + count_no_improvement = 0 + model.train() + val_step = 10 + iter_limit = 1e5 + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + best_val, best_train = np.inf, np.inf + a = time.time() + + for epoch in range(n_epochs): + for iteration, data in enumerate(train_loader, 0): + if iteration + len(train_loader) * epoch > iter_limit: break + sequence, target = data + sequence = sequence.float().cuda() + target = target.float().cuda() + + loss_, ssim_, mse_, l1_ = train(sequence, target, model, loss, optimizer, unet) + + loss_train += loss_ / val_step + ssim_train += ssim_ / val_step + mse_train += mse_ / val_step + l1_train += l1_ / val_step + + # Validation + if iteration % val_step == 0: + loss_val, mse_val, ssim_val, l1_val, pred, gt = validate(model, loss, val_loader, n_val=n_val, unet=unet) + + # Decrease learning rate when needed + if lr_scheduler == "plateau": + my_lr_scheduler.step(loss_val) + else: + my_lr_scheduler.step() + + writer.add_scalars("Loss: {}".format(loss_name), {"train":loss_train, "validation":loss_val}, iteration + len(train_loader) * epoch) + writer.add_scalars("SSIM", {"train":ssim_train, "validation":ssim_val}, iteration + len(train_loader) * epoch) + writer.add_scalars("MSE", {"train":mse_train, "validation":mse_val}, iteration + len(train_loader) * epoch) + writer.add_scalars("L1", {"train":l1_train, "validation":l1_val}, iteration + len(train_loader) * epoch) + writer.add_scalar("Learning rate", get_lr(optimizer), iteration + len(train_loader) * epoch) + + if iteration % 20 == 0: + idx = int(args.patch_size / 2) + for k in range(len(pred)): + fig = plt.figure(figsize=(12, 6)) + plt.subplot(121) + plt.title("Prediction") + plt.axis('off') + plt.imshow(pred[k, 0, idx], cmap="magma") + plt.subplot(122) + plt.title("Ground-truth") + plt.axis('off') + plt.imshow(gt[k, idx], cmap="magma") + + writer.add_figure("Sample {}".format(k), fig, global_step=iteration + len(train_loader) * epoch, close=True) + writer.flush() + + print("Iteration {} {:.2f} sec:\tLoss train: {:.2e} \tLoss val: {:.2e} \tL1 train: {:.2e} \tL1 val: {:.2e} \tSSIM train: {:.2e} \tSSIM val: {:.2e}".format( + iteration + len(train_loader) * epoch, + time.time() - a, + loss_train, loss_val, + l1_train, l1_val, + ssim_train, ssim_val)) + # Save models + if loss_val < best_val: + count_no_improvement = 0 + best_val = loss_val + torch.save({ + 'epoch': iteration + len(train_loader) * epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path+ "/best_val_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_val_model.pt") + elif count_no_improvement > 2000: + print("\nEarly stopping") + break + else: + count_no_improvement += 1 + + if loss_train < best_train: + best_train = loss_train + torch.save({ + 'epoch': iteration + len(train_loader) * epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': loss}, + save_path+ "/best_train_settings.pt") + torch.save( + model.state_dict(), + save_path + "/best_train_model.pt") + + # Reset + loss_train, ssim_train, mse_train, l1_train = 0, 0, 0, 0 + a = time.time() diff --git a/.ipynb_checkpoints/training_functions-checkpoint.py b/.ipynb_checkpoints/training_functions-checkpoint.py new file mode 100644 index 0000000..7551ae2 --- /dev/null +++ b/.ipynb_checkpoints/training_functions-checkpoint.py @@ -0,0 +1,73 @@ +import torch +import pytorch_ssim + + +def validate(model, criterion, dataloader, n_val, unet=False): + running_val_loss, running_mse_loss, running_ssim_loss, running_l1_loss = 0, 0, 0, 0 + # Losses + mse_loss = torch.nn.MSELoss() + l1_loss = torch.nn.L1Loss() + ssim_loss = pytorch_ssim.SSIM() + # Validation + count, count_batch = 0, 0 + with torch.no_grad(): + for i, data in enumerate(dataloader, 0): + + sequence, target = data + sequence = sequence.float().cuda() + target = target.float().cuda() + + outputs = model(sequence) + if unet: target = target[:, 0] + else: target = target[:, 0, 0] + + # Training loss + loss = criterion(outputs[:, 0], target) + + # Evaluation losses + mse_ = mse_loss(outputs[:, 0], target).item() + ssim_ = ssim_loss(outputs[:, 0], target).item() + l1 = l1_loss(outputs[:, 0], target).item() + + running_val_loss += loss.item() + running_l1_loss += l1_ + running_mse_loss += mse_ + running_ssim_loss += ssim_ + + if count > n_val: break + else: + count_batch += 1 + count += len(target) + + # Get the average loss per batch + running_val_loss /= count_batch + running_mse_loss /= count_batch + running_ssim_loss /= count_batch + running_l1_loss /= count_batch + return running_val_loss, running_mse_loss, running_ssim_loss, running_l1_loss, outputs.detach().cpu().numpy()[:5], target.detach().cpu().numpy()[:5] + +def train(sequence, target, model, loss, optimizer, unet=False): + + # Losses + mse_loss = torch.nn.MSELoss() + ssim_loss = pytorch_ssim.SSIM() + l1_loss = torch.nn.L1Loss() + + # Prediction + outputs = model(sequence) + + if unet: target = target[:, 0] + else: target = target[:, 0, 0] + loss_value = loss(outputs[:, 0], target) + + # Backpropagation + loss_value.backward() + optimizer.step() + optimizer.zero_grad() + + # print statistics + loss_ = loss_value.item() + ssim_ = ssim_loss(outputs[:, 0], target).item() + mse_ = mse_loss(outputs[:, 0], target).item() + l1_ = l1_loss(outputs[:, 0], target).item() + return loss_, ssim_, mse_, l1_ \ No newline at end of file diff --git a/.ipynb_checkpoints/utils-checkpoint.py b/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000..2e2d199 --- /dev/null +++ b/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,234 @@ +import os +import torch +from torch.utils.data import DataLoader +from glob import glob +import argparse +from configparser import ConfigParser +from datetime import datetime +from mc_dataset_infinite_patch3D import * +from convlstm3D import * +from stack_convlstm3D import * +from losses import * +from lunet4_bn_leakyrelu_3d import * +from bionet3d import * + + +def parse_args(): + parser = argparse.ArgumentParser( + description='3D ConvLSTM training', + add_help=True) + + parser.add_argument("--simu_path", '-simu', default='/workspace/workstation/lstm_denoising/numpy_simulations/', type=str, + help="Path pointing to the folder where the MC simulations are kept in .npy format") + parser.add_argument("--ct_path", '-ctp', default='/workspace/workstation/numpy_cropped_images/', type=str, + help="Path to CT images") + parser.add_argument("--n_samples", '-n', default=0, type=float, + help='Number of training samples') + parser.add_argument("--patch_size", '-ps', default=64, type=int, + help='Size the height and width of a patch') + parser.add_argument("--n_train", '-nt', default=40, type=int, + help='Number of cases to use in training set') + parser.add_argument("--model_name", '-m', default='stack3D_deep', type=str, + help='Name of the model') + parser.add_argument("--loss_name", '-l', default='ssim-smoothl1', type=str, + help='Name of the loss') + parser.add_argument("--learning_rate", '-lr', default=5e-5, type=float, + help='Initial learning rate') + parser.add_argument("--weight_decay", default=1e-6, type=float, + help='Initial weight decay') + parser.add_argument("--optimizer_name", default='adamw', type=str, + help='Name of the optimizer') + parser.add_argument("--save_path", '-save', default='.', type=str, + help="Path to save the model's weigths and results") + parser.add_argument('--gpu_number', '-g', default=4, type=int, + help='GPU identifier (default 0)') + parser.add_argument('--all_channels', '-ac', default=False, action='store_false', + help='Whether to select patches from all views') + parser.add_argument('--normalized_by_gt', '-ngt', default=True, action='store_true', + help='Whether to normalize input data by ground-truth maximum') + parser.add_argument("--standardize", '-st', default=False, action='store_false', + help='Whether to standardize input data') + parser.add_argument("--uncertainty_thresh", '-u', default=0.2, type=float, + help='Uncertainty threshold') + parser.add_argument("--dose_thresh", '-dt', default=0.2, type=float, + help='Dose threshold') + parser.add_argument("--n_frames", '-nf', default=3, type=int, + help='Number of frames in the input sequence') + parser.add_argument("--batch_size", '-bs', default=8, type=int, + help='Batch size') + parser.add_argument("--add_ct", '-ct', default=False, action='store_false', + help='Whether to add CT in input sequence') + parser.add_argument("--ct_norm", '-nct', default=True, action='store_true', + help='Whether to normalize the CT') + parser.add_argument("--high_dose_only", '-hd', default=False, action='store_false', + help='Whether to train only on high dose regions') + parser.add_argument("--p1", '-p1', default=0.1, type=float, + help="Probability below which you draw patches from low dose regions") + parser.add_argument("--p2", '-p2', default=0.6, type=float, + help='Probability above which you draw patches from high dose regions') + parser.add_argument("--single_frame", '-sf', action='store_false', + help='Whether you train on single frame instead of sequence') + parser.add_argument("--mode", '-mode', default='infinite', type=str, + help="Whether to do finite or infinite training") + parser.add_argument("--lr_scheduler", '-lrs', default='plateau', type=str, + help="Name of the learning rate scheduler to use") + parser.add_argument("--depth", '-d', default='64', type=int, + help='Depth of an input patch') + parser.add_argument("--raw", '-r', default=False, action='store_false', + help='Whether to train on raw data with no preprocessing whatsoever') + parser.add_argument("--n_layers", '-nl', default=3, type=int, + help="Number of layers for BiONet3D") + args = parser.parse_args() + return args + + +def instantiate_model(args): + + if args.model_name == "stack3D": + unet = False + model = stack_model_3D() + elif args.model_name == "stack3D_deep": + unet = False + model = stack_model_3D_deep() + elif args.model_name == "lunet4-bn-leaky3D": + model = LUNet4BNLeaky3D() + unet = False + elif args.model_name == "lunet4-bn-leaky3D_big": + model = LUNet4BNLeaky3D_big() + unet = False + elif args.model_name == "bionet3d": + if args.single_frame and args.add_ct: model = BiONet3D(input_channels=2, num_layers=args.num_layers) + elif args.single_frame: model = BiONet3D(input_channels=1, num_layers=args.num_layers) + elif not args.single_frame and args.add_ct: model = BiONet3D(input_channels=args.n_frames + 1, num_layers=args.num_layers) + else: model = BiONet3D(input_channels=args.n_frames, num_layers=args.num_layers) + unet = True + elif args.model_name == "unet3d": + if args.single_frame and args.add_ct: model = UNet3D(n_frames=2, num_layers=args.num_layers) + elif args.single_frame: model = UNet3D(n_frames=1, num_layers=args.num_layers) + elif not args.single_frame and args.add_ct: model = UNet3D(n_frames=args.n_frames + 1, num_layers=args.num_layers) + else: model = UNet3D(n_frames=args.n_frames, num_layers=args.num_layers) + unet = True + else: + print("Unrecognized model") + for i in range(1):break + print("Model: ", args.model_name) + return model, unet + + +def instantiate_optimizer(args, model): + params = [p for p in model.parameters() if p.requires_grad] + if args.optimizer_name == "adam": optimizer = torch.optim.Adam(params, lr=args.learning_rate, weight_decay=wargs.eight_decay) + elif args.optimizer_name == "adamw": optimizer = torch.optim.AdamW(params, lr=args.learning_rate) + elif args.optimizer_name == "sgd": optimizer = torch.optim.SGD(params, lr=args.learning_rate, weight_decay=args.weight_decay) + else: print('Unrecognized optimizer', args.optimizer_name) + return optimizer + + +def create_loss(args): + if args.loss_name == "mse": loss = nn.MSELoss() + elif args.loss_name == "l1": loss = nn.L1Loss() + elif args.loss_name == "l1smooth": loss = nn.SmoothL1Loss() + elif args.loss_name == "ssim": loss = ssim_loss + elif args.loss_name == "ssim-smoothl1": loss = ssim_smoothl1 + elif args.loss_name == "ssim-mse": loss = ssim_mse + print("Loss used: ", args.loss_name) + return loss + + + +def create_lr_scheduler(args, optimizer): + # Learning rate decay + if args.lr_scheduler == "plateau": + decayRate = 0.8 + my_lr_scheduler= torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, + mode='min', + factor=decayRate, + patience=15, + threshold=1e-2, + threshold_mode='rel', + cooldown=5, + min_lr=1e-7, + eps=1e-08, + verbose=True) + elif args.lr_scheduler == "cosine": + # Learning rate update: cosine anneeling + my_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, + T_max=500, + eta_min=1e-8, + verbose=True) + return my_lr_scheduler + +def list_cases(path, exclude=[]): + return [p for p in glob(path + "*") if len(os.path.basename(p)) == 4 and not os.path.basename(p) in exclude] + + +def get_lr(optimizer): + for param_group in optimizer.param_groups: + return param_group['lr'] + + +def create_saving_framework(args): + # Don't forget to save all intermediate results and model + now = str(datetime.now()) + training_name = "train_{}_{}".format(now.split(' ')[0], now.split(' ')[-1]) + os.system("mkdir {save_path}/{training_name}".format(save_path=args.save_path , + training_name=training_name)) + save_path = args.save_path + "/" + training_name + print("Training:", training_name) + + parameters = {arg: getattr(args, arg) for arg in vars(args)} + config_object = ConfigParser() + config_object['Parameters for {}'.format(training_name)] = parameters + + # Write configuration file + with open(save_path + '/config.ini', 'w') as conf: + config_object.write(conf) + return save_path + + +def create_dataloaders(args, cases, unet): + n_train = args.n_train + train = MC3DInfinitePatchDataset(cases[:n_train], + n_frames = args.n_frames, + ct_path = args.ct_path, + patch_size = args.patch_size, + all_channels = args.all_channels, + normalized_by_gt = args.normalized_by_gt, + standardize = args.standardize, + uncertainty_thresh = args.uncertainty_thresh, + dose_thresh = args.dose_thresh, + unet = unet, + add_ct = args.add_ct, + ct_norm = args.ct_norm, + high_dose_only = args.high_dose_only, + p1 = args.p1, + p2 = args.p2, + single_frame = args.single_frame, + mode = args.mode, + n_samples = args.n_samples, + depth = args.depth) + + val = MC3DInfinitePatchDataset(cases[n_train:n_train+5], + n_frames = args.n_frames, + ct_path = args.ct_path, + patch_size = args.patch_size, + all_channels = args.all_channels, + normalized_by_gt = args.normalized_by_gt, + standardize = args.standardize, + uncertainty_thresh = args.uncertainty_thresh, + dose_thresh = args.dose_thresh, + unet = unet, + add_ct = args.add_ct, + ct_norm = args.ct_norm, + high_dose_only = args.high_dose_only, + p1 = args.p1, + p2 = args.p2, + single_frame = args.single_frame, + mode = 'finite', + n_samples = args.n_samples, + depth = args.depth) + + + train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=6) + val_loader = DataLoader(val, batch_size=args.batch_size, shuffle=False, num_workers=6) + return train_loader, val_loader diff --git a/__pycache__/bionet3d.cpython-38.pyc b/__pycache__/bionet3d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d905a78f0ee14012548575cd60b846d0d2897ec8 GIT binary patch literal 4269 zcmcInOOM>f5$|KmTYf=q`2jxk%?OZWZ@I(=A&FAeoR_4v0Ru4!7TML+ zRn=d2R~0{MHa!B@pMLk}U+sL4kiTMM_A;UIK0NttAe?X-lZ5&-VLnR?-$+c~Of26@ zY~N05el2l)Cvkn3qAZKOr0&;~hTov%L&6Pi?h|ebIyC$yx9*UReFR!b$5r;NjQ3-q zygOpIFVZ3kV`X2<(vPojRsSF?y7xCSnLuL>nRjKF=DjQzz%YLyZhoZPTj8VISr!AM zb|>0SvWP49b|}L{6hd~GvI~*snN&s>l%dht%NaHDtMDO+PdV`!ZwN!*CQysppV^<1 zEy`=$*>C!`sEHbPKV#hE^?lQKJ|!mcU6@i+*&jqd-w?$S1!YGJ$lojV|9tNV(Y(>I zRVzw+{UYez57Sh{xvHoABlO8)0bJ1Vz~j~m;Whwba_7J*nnO~w#(a1?+&y4gT1$?5 zzJ(EUxCp$#n+L|w125;d3D5;i7d2hdw5?}k;%uI>d;XwS&2ZGdF>uhw(t%mk+hl;V z*R28Q#4ki*g$(LtjaxhHDV2XH8l_u$(-{pK!0}$;G&QF_BZC+%9G)l}WwUH;8P6ys zMICwCGGJOES)x^q;iS)evu|h*=4M9MClaW!T7g+~%)$HhuluHWP5LHyRT9VGH8chb8jpnu9#f1KC%lCb zMs1DJLWR*&y88}1%cE1`@udnWd{1}4%t@*GXHUvE8L^D3@P>^MZpT|0!mbWh^p2@L zHRbhJ&ZpVd>-y@sPS3m|^R1tHh$a|O+6QKFdPs(FCWG&A0)h3!#`fd(A3ywJZ~wh6 zhSu?g5!Qxh37-5ckZn?ud#4Aq(D|yQOGM9$wjqZ7mhGCLiyf_kpu|uKxxs*7Do?=o zbEeqZMMb0MRGGPmx1N(1<%Nft+{t@kS5#kx4$T+hyhs8r(k#lObo;`U>p>6GRFHJ1 z++b7WMUZ41V*G_9>UP5@xP0;QC6GRkdFvlhMu3Z^ zrUV%0knsZ~Zv)voH_2>|FAl0y_W05u5NQ`6l0g?B$@k8Fy%5hHrX91=j7qh7)ynz> zX4H|_UuB7lSxc+OTN(@Xdq*$f>agiFI%kxX_j^LhCYV;v7=4L}SGJMhK9VPq;DY%} zn|fEIP`$uLsPI!3)Gyn@Z)`;|r20I-6_a{_aD63abZH^M!bC0iqzDQud0?kNnE)sS zo3PqDeqD-wE-KuVGlDVqn-v@90e~uJG@)GCVXr4rt~^*<+?H5z*{C!uBpxAEOCru= zkxLg_D^Ow;$6D?z%D2!zJ#9RA$ITszlyM=WEM2$wUX+fD9;gc=K5fAIjT_2Bzb{p< z%hl^j^~#lr>#tb92S?Nms+uVXrlpWiYx-O}a?W%+SV{ zc5^HQkaCIhd&v>S|>FIOphX@d_AVc7Ry)0lePrslu)=3YDe5X0yrWlK@?W!lAnIKHsB z$4zt%o_rI|Cx_y0AIid#JfKxsH-KXA4IY&06CNAMnI$%5J{Dd zp3@Bpzg|S1MxwtoS;gk}fehi9M*g#-#fNkIFqHhb!2iwc%lFaH*Ltu&>%sqB`)_8o zSNY=|bV=tWstg__9Z&ukSvs)lIQuR(&m(Cdxr}5D35H#H1qh7MK?;BBlmXEZG>NY7 zg>f8I>B2We5cviVBKZ>>^;00e`3O?#Ys;W&BeZ`F5;@u;aFX%%ATpS=1u?^+Ci@xC zMd}%z*`hzJz9B9+5IH=$*Lro~I`&Ko*7+1Shg<-1qvK(2N#XA`2$UC$Vl{j>f?&5F z##N780a;~j#aUP=a}$1lb@Y>U)N>vCnxhEW6)@*uW>V`YvWf%~8RnWCv2qOfOB~-X zAo@)27xRy`>_(@hoFbFm`)g?`G1kmvE}gV6lTLe}Aebu?jr&UUzxT<_?mypAC( zQM=DN=I6{{9Sbt+MeAx3e|+u6&Fhyou7G6SWSg}WxK2PCH{YS@hJOU{_%j<2`lD2x1Fv>&>YJ-Pw5xK@pL*Oz`j z%tDz~U8~hqbw%-Dn0X@Vue(v4Royhp*08@FN;lbYgRURPVK1e&dcn=YI8EeqvTh_q z82>+V$~W9J^RsX&=XNMDk(Xj$>+U(MXoZ%gNc-60ZqUxLoU9wV0S-Co4YMeT{hnKj zn2Owv{B+&jPLkd+0UvZjyo>EgH%>A)OI%+FC5~R0zD`cmxYTJd`^@5=+%s~C| zsNg#VPYS+USb-OZpJavARsRpkc-v;th+=P|!lIc#-_=uNX2V5ofF(rnZeh?8w%QAY znGOe`lr(!`dtQV?_q?KB9a`Z`6)%@CN}eEclE$}ghv5LPc>S=S$a{q~{@3ETu*haE ziNYY&wXMQP!`_a3n%;eeh!RCoy*e@Ebaf|@ ztJjk_3swu8*77mav5u?;Z`w$Ew(g^Te!c^Egdbc9sm8yuD7_&4f_3O z|I}oXG9s)^lp0VJm|Ubp6Zciz=@&TmT>DTP=_4a!dmM2=-_-=d&jW1)|2))k^N5cu z@bsMg)=)hCLt|7M*`xZ%$=IQp*D$7`S~glu)vBYlfR>dzXtnZYP7$Dyx0cx{Hd@R} z-r^BoW}_vF2*2m@clfu3ZnD#eLWlJucIe=49I;*gGD6v9iXgTy6yJCv-&Fj@6Zw|n zH=oF_DgMF}`L^P>p2)8&{^Aq)j^ZypmcNuQ?r4aJjV(zI2J6%jza$-)q?j(rHl$62 z!n>ptCL2VWAnj&hW{C{Cg&z2s{0_R{W`kaoDbp`%X|&rv?B2ZLc@ zi7@l=t{D&egL~~0@+`5Mr9+lMv8W?Dcq*<<*HqX+(%+8aP;?e1Esw0aV+4U0DY~O* zqUi3oZ$cs&vy!;9fEpa;t}n$!Jf!D97%cuIcla6H>ZdL7mS%Hz?zjKUoIi{O9dn0C zqEzm%3E^-Do!nt$jva#SX-JUr!l0$I{OqIsGjsZw8QkWZ?HYyQ7VH`XuKYQNHdcI3TTp|?V$$(SRwXL;&XrESn0UnQC6&|{V?g|faprA#C zi3nTOfrXYtP1yH&#}swp02NtbO*BOFK0DwY8|?+rQti4zOAg%WMBz;L?q~WKY)oOq zdX+z`{ImWzkPyWsNT`NSbJXnkF#wpP!fNzPMqGC%X};`)+}UJ51vc3hr&%_H-rPb2_^q(&eEQCBAI6E?i?~eYysFoeseFlQCcDdGQTl$n95K% z<8{!I<8D(ED<-Xu;Yl@EhS^ZYo)3eVtouFHrR4XbG*h@M=>wgq|>Z?3pkSKPN&+z(dV4O$pF&^_%_g5DWs$0XeZszSQQA{1i+ zv{aW`?-U=s^Io+n%s(9TN?3L`acLI{GWNoFH|x67LvqhyA>-bQZnWcmLmi~)O~03h z?deQ}NiVpor2@QiWu8*i_*f=to1ge5t-m~VwVi6gB=Mq>2uz|P_T%M|CKP*WarEls zQli6N2s=^7MwoQ6E`O|(bKMm;XsbgU8@>|g`mu0JVCJZ0ZQcwhX-zK5Sd_oujV%2q z$Xt!W4)t6=;;2^GAyp~H5parXYpShfs9Fvy5kd5jEa$$S*+)@8DvL%cMz)+ms`D1qS@TnZQ2(v5;hOlIA z?Vo--4idOviUgaLnk|iIO@VWT0@~&i;cNS8`4luKsYX?yxp)o$JPdAoGogl(aXCY9 zuXx4-iVtJq!RNgGu$M*3=nEaw7S8*AZy3HIWg_RMES#x{7mb;iQUdZ5->`CpE(%ln z@ordH{$PMPRoIuOvDU)5uF}gjY#gA1k_TB;yQz}k+}!dt+TH6!6mGvtEduZIO(Fz| zfII+9s@N@6_EA8>r6Gi>cX6i|K$zCj9KFe#{5ftDzMlPQ24CW5fhOojsM@jipPJ^5 z3R7<#LlFe335C(h(hCHt*&a*J2&57uJzAf;mzN{-LEX22jSx8o~GguXPxt zmDhH3Q3LqaV1_oCB8?Xu5ZcKJWTOP%GY>e-@{3UuM%_Hk0N>5g0>-w$brkO!05%oc z;N}fs3;Q=(zF7W$lkQj;W#$V2{`Ch`3_x1mIz*wb%=XjFhC#Qm4+ra7%9mjBHn>a5 zHS08v-c@6_F{=Y0Z^M=9goAP$nHjUgQP5Mh>(s`l1>857u;B_#{}j|#MXzaQuDYuF zezoeFg{JE$b460ID1N=Py5Zhb#k?CvWkHPO(e($Rdk%H?Hlo&wi`aE{#T9MJALq~o z#!PP|ugqZ67&%@|#i#T!6J}SAMRJloH{jQ+7jMz+o}0@>(>yf&D4?Q184oB}J|rR+ zs71J{t;rXu;}Xag9)3g*D@2|LdH4udyP@0%Rd^`-!z@HmO!%@{R^4ULmK*enay8{v zZ_wilASfXSU`R6VaaEw^=rqURM@9RBZx`Gvc(>qp3obe*R9Kl!Wu3)29dv4Gc*jHB z=`51>mM^F?E$ZV3@kzmh4t-akPdO33)X0Koo#w3Y8PMs{nhBi#7I*p*2y-lLSzAW& zY(KJeIJD;AZfQ->rnbhdM-Kj5s(*>xUIm8zlQS;=xUzxtPH~5HqJj$CPtVBhk#pdj zXGh3bsGiBOjnkCBc&;6lPSVQt8>sBI5Z0$ESK3y3Cs+*@6$8_A?UcyAOw?ulbi7QG z*D?Mtw3-s~8@z=AY4(%A+U#KmaJK5AHkp*#44D#bB;yBKx~5v%fhK>D;RB)2?t`-d zoPEH-GX&D_XZUIf|24LbiJV~m8<3XLXO~^Pc&jzy5XedNAGABhjk3&ZYbbQ^RRx=i zgAf%g0WEZDYZnTGpvT++7tUqM6(}_aSN=}&scx{C#`a&CIkj(3gKL+2NdOFX1&xWU z{nd|VlIBNle+X!_9ZrBBTo@x2T%-l2IHO-;V)-kO!XRFu1H;N!@!aO+XQsI`dIhY% z;!X+rFj%7fNaIV7G;P2AKd(F>)lM#IS0_i6e06fmIX;@?^l=4>hoV0K6R|<9_?BuH z=77GUmL5O1&?^}658SF`qlG$r|M}S^J{Ci?jm?{zn{8NjL;e_V%byTg2PtZvCz1dK zB4y&Td?bS{nq<*YfN#0oL`c$)q`Bo!iA)!z(z#201a#iToziY?%^=&N>rv}@=aOw{ z+%Yw-M&Z++tv-gV#j(O1WiYc!;88*m1)HYRsFn+9g?+g!f3J{lC@nM*8czQFzpn@Y Ap8x;= literal 0 HcmV?d00001 diff --git a/__pycache__/losses.cpython-38.pyc b/__pycache__/losses.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41e79280dbb70656e4e1f71aab26839eb8a95377 GIT binary patch literal 830 zcmbtS&2AGh5VpOWWK$AAs1itAxhzt;q`e}9lp=wYJ0w-p?L@fuz0b8CKkL~&SKEEE1cNmUupTEEP!WsKPo0~)6<~c5S zi9#~TYu50L^UfBTK%c0CBif5P{H=Sa0~vl|Szq>KAN4>EWQ01BJZ154gVESB<}H6* zQi|;$JH$1^1&>fj?${Oo$}ZU@|0s4b$$jR;Rj~CuU@2P%%grv{-CzSO4qKshHcA^~ z-06HeHvsKub>WOHPji5(X~Q&qwOETRFz2onCvMH`EqwX&L-O9(GguX+>MU2d%IKs9 z*W^-ZQ$eLql7qv1RX7aRl-J(;tpb-fMyeW;rYg&#%J-k_Px1C-X7syxdbEJav9Xhd z(N0Zfg{@6eV{r=hRu}DyMWBaJG!6JP@xLC?#BN9Ur zghD4qXtj~AsR2b_WLwT{wEtetJv{xBG9oAzOWVVfSo|a7&}Sr6sOxRYJM?szMV`Vv zt$hc&$18@+oJIsl#` Vq35-FMjs34?ubX-6j2ZbzW|-twS@ox literal 0 HcmV?d00001 diff --git a/__pycache__/lunet4_bn_leakyrelu_3d.cpython-38.pyc b/__pycache__/lunet4_bn_leakyrelu_3d.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..315f24bcdc4f4f3d5883927dd34a20f9ee3dff5e GIT binary patch literal 5165 zcmb7ITXP&o74Dwf&fc`El`PAS0g(%MLo7v-fe>6FSjq*k1%+i2whY5+W>y+$cXrh? zYsF%=Dnw2tMU|^uJn$?%ng^cv2mAvb`;{kE@e3%3`A*MXeKhrx39lGXKxpa zIfn1cKmYxAxl4@wlQQF5IMv;jJ&MB~QxiQ(=uyTRu-Ol2NId&FhdaolD-iX7(=`}ZH04%%Wb>eO97O~v&% zXopLr!ok|xL4=)KT3vTKG(D%~j!K;^KZ>1J+rv35wF1|z2hPfyE6ZqoWiY4cN~0~V z+-!%je`T#Mn(d`#H}vDH&Q|C&X$xra&Z@WAxhE=^AHg(w3xe?}ZmFNm4X)AeKem}e z>Dj!+wa2z*rLF#ZSQ%b|#lgwAgr31ir$nQyAw+i|5_U`hFh152exiOoK>wHmKzM8> zdcTx%R$?UDiFQZpnLP_|%|sh0wG&)~UM4gWqnAlCiJ4dp{T8-+qnZ^LF^ZTcF-PKS zBtTe9yhP%gBT3tI6c~-e#F6U2vTJ zZoQe-s526aG)l1#BNvzh_(yYi1+{`EqV9)0Pm(fTOl&c*bK+~FF_({G$Ecf^G*{C$3Z*8D+h{}}J36Ciz6&AtrLnIn05=R70 zk5C_faw@e6B7eW@hpzv8DE^iPPoY=_=^mB_Q~tj=r0xD0^-lS+P!Uw((uoW{!r*X) zA0vZ9g=KpojkOC*V?Ex(JLc(guw?gi@m2y`)`O+{xdzET*Bbf-)?*1wx_*nzQHusN zppJ|#Yvc5*$B!SMB0L@;V-aSn8ST_NepQ!x(ApMnU=oraoVk>;?22~BIq(DFDQhh9 z(Jm@L&B8@O3wNUhn8h$-Co}a%?283D7?dCWvTDdR?=uSjjk$YtLfd&GKm!ut1|Om=)#*2>iLzKbvz%)!DFNI7B@+}LgFgA@l z=?@4I8AkmFL@U6|Er5Adtbu6w;ryB7d;0T1b{6rE;-1g2`h*{|6L!byfpPEy!}P3e z*2|FhysI7Yb@rQI!4a@~*#v%aV$ZOPEHQ^EYmm~3mUF#)uaIPSbvPATl1uVQp`j~B zqS%{)%TP?F8rnAhKJ@Ruj8RHpARI3Uxg1+bS&jCu-ZW||pm#;($+V{{I-ePv(L4Pa zWfJ=WLz=W_lNr*#4Y1tNdo!v`jfZjVYb>cGMKy9||8ICAF&8*9vq^=_VDKERHmk-T z=(h2+Xb-t~1Da=#c<)@e038^wHy7vdzreXbR$@WWprb_WPI;{#{>yq<)SL(u~`}-I2Omj}^MA^C{APK$RQ>Hiaw?3`JH&Drtdg$O?M#*F=(6k4U`e?o}pwP5-z+~Eqp ztj+VXZtxuL0P=bT@hsg7J*V4RMa%IV7!)@}n`3acSwTBIfA*#fUBu028_+}upCSw# zy=g%uup)0GlD69B#V{1)+{hdmh4lZz>bf-O-!DP5DXp8z<5bL%dPd6nBm+Saxl%c{ g>fzF*eJyoX-yuUOB1*8aMg*-=tz4>Hv~7Fle_#G@4*&oF literal 0 HcmV?d00001 diff --git a/__pycache__/mc_dataset_infinite_patch3D.cpython-38.pyc b/__pycache__/mc_dataset_infinite_patch3D.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5835679c43d8c2c3ee19ed26ae93a295a45a6b69 GIT binary patch literal 10882 zcmcIqTWlQHd7j(O&fdu7a!D>Pl2)=Tdu5xl<;F=|)lzJgzQksnN=e+U)yZN%uVJKxFI^V}!ZE_4d@ zLehVvKEmxrJH`4KvGA?O9Evh4ZE+9M5>etd9%*f>z6HV`bW6O^?#&8(|~x zLS(a{Z1}-)S9SuFtZ+kgC|zlHmndy^dz($_w)a;?;T)de6#!4G>q4s=qAK#D@TOij z#fTVvQ>$B|D8`VcL`jSzwZ%~}DW=dOEh=If7)Q*AS)>^;Ck`Raio@av(wvyUa*suo zm!5h2{PX^D%Ws9=%Z;$Pnyh>KWlX$q@q^dOfBo#g_TN7H!S}yMzwNWzIx0TEgWtdG zp#S0Qwu5ryw34B=-+fZkniXypgE=HN9urUSsG(^}|EA_@$^g`r31}z_&{QcvOWA-a zl?Jqx1DIABKu2W(Gb#s|Re8Xi8Uf6!QNV&K0*W=y0`lV`w{G4rq=TL2X{S76A)XQS z<=r{&Z=tVy0=WB=S&n0t*}*h0Cx#EG=@o6uP_rxg6wPrWnd3*uTUhBYz@0G#?oTN2 zpkETB2XWujhP5qC*Hi&(dKhae4o_l)a(pJLQqjcNEast=3Ju&^TI`Q!U+gwS%U`LA z_;g>dyfBQoY6jWSK#-xz^g8hGtxIfem$`R!9zyK+6~HJVz`@|VaFAo z-)#k0;Hk&YQ{U#Q+iBj!x=UUVx}C1@+QF$#tJ!R{+|!SqehjT2d8X@MdGW%M>xbSW&opGadx}Pgx9jfQuXyao&(HTZqwE>cYKF~jr}xy0v~9X&wr1e3|EW7& z?|BzPwL;gOKhy3u8|~n!`J1EMLeUgPgh1AU833keX`qMNnm#m$qAH;&)2C)eMmLBo zqGaSOfrfsg<4HOYu8#p7jhqH}3D53|v)V0FlSbgfd=wd;`awEOnjFkx2yJMF)>>+4 z(?YhukOJ58x)C38&FqgfWw+P!g!|fH-tTQ*d{2+^jdr_x)$O&K>us-=jtsvSS#9uD zlx`9`VBI!J>qaPPzjBrUiLZDQlawP)kh`1s0!r~F4C6yXM|YN|w`_(^hBvd9j~`0mtX5xpa{l!4pxSBq zmqRa@&u*W=K(&YD14tsrcb6q@QXr2~F&hh)pw;)7q21O2u;c8!+Ey6k&lWH*Fxn<3v3 znjsolZLO@j+_USqHzT9>7@pHnp2T0e--b8?C{LFXms3D8l=Qv^tS$;SxLRg{kt_zeP|BJc!(PXk2hE1q1!X~@qI z>azr3Non#o2@pTX-y-l7fwKh85%?THq&MUQrA9N9HNrA$^8F)86*)UHzlc&mOsdV; z`X~M~9o^C#Gp~>6d22>58Ck>m(6RG+#c+6(l3q3Px?{|sRMHQdj#fr1$FQ~j(T~+b z`@{`hG1BDE;0Z|bDp*rOUpHhagm`}grZzCfh1f1I7g?Us&DLIjbN+J>2t*mkMsU?R zWe76_W2)f~NFy~|R$!@iXsLT^%0ywbnJnn`51TBy_dhUOOA zg*S{#KZO0H4-8-}rB7=EGfaiHGKXnu@rFLILT4>A%#sotq&T00(9wr^l~R^4reOz7 zLGa1f!@}Cgtu?|F)^%fG`-hYI)c&43^~#J#o7RMlRix2+bXbH1>|kYM^1l?NrFBdG zRHO%)jbBkoITnoe^_2WKl*Uo_A3oeG-xdzF5cVMB+o&lIGGSpjpD%@_IPa92NXDAu zofp*%t&#I%Y8Lrf&X?32@^gG{M=){~?R2D*lx`V&@UsaX=|22if=9Y#ESsQ|rYh_s zyuMLpJ3F`7ubg^q;O~>A3V;#pP!WlG~K4lEMa|U zwWCoX7MzJNWx+1VZj~<(SOAFZpesWtA&X3%$XfR{g9D~mWVeFlR@-B_8|(inb!Q{M zA|v3S*d$>arLL}e(u+)J0T#hg0cLZj>$@F~#CVhnRvSG}UZkN|Y_ptQq^@>@b&7nL z3MO=P?NBW3N#@oaQZl^Sv0e4F<4;7HT@<{17{qm3dTme_{yQXnr1v6wNB3SMeQmMU zw))7|Q`hJm>Z#2*>2u;-OO$op+0o&3r@y1X)7M_5DCZOk401d*vqG$CY6vG@8z+tm!W`WCUwmyf) zHf;T{J_*HuKfn?EW%;kD7ojH5b5B}~+hIifeb|5KjAUF5YD+5y)%3M#&qA=q$)I(#<99qiU(uXMl(IryT8ge#rY0f#wWu{5= zSjvFUVX-4@k|V4TzBr?86vK4rD5xm3NrjoUEbu8B2O<|eMONgl8!P$__6+*LwaqW~ zjZ-J4+ish^T$&PTe-S18iw`HIE2u$Ggw(U@iTyDQ-4Ub_`aa%T(Y4zj^9p5B~^HK-Pn%TYPTwIH$aRT<_1{@uY6+ zV)F(To3!$sA9)OEJbYw7$;3LJb!CbUMx~gfHuRw_pg}=OXhmo+l3<9qus(3Ogh&gQ zGN`eI#nF}j4AX1QFau4O1vV?+fzHv^a?obJNDXoue?=vgO*GxtQOcw4MxyC#Y9Wk4 z&M&E~NTcR2q5X2cB{G9tm>w2X8gPVkp2|V%rC7h^Ny9;RjY4;g4vNrehA?(&Y4Nru-UO&O+ejME<&QZ?sZ#B6~*j-$|%t%AXz>>mI0|h_d;;13lM_fu>%dO$fm^Xt ztZ4THQu5^kW8K>c&{#Vs;7IIpA7|_z>*G!!rquh)vuSyTD26;Z6e@5J=Z421r!}g1 z3+F>B7nXS*UQl!VkJARa!thli2h$HTUK`oK!>fr5i${+}fLYs+sP7-#>EY1BL3 zdW^G>@UAU?cGq$4O1Z%?bxh5vD$xnNmfHB)|05s2zjG=ddm_#~ok&{#(*X+ndMDz@ zxdY?;dMDz@rMr&viB1Hkk=|JHLzRZ#xJNpFG8$ew^T{Z*pV)s+yz9vIL=rvO2tCBG5tAi1^f9U>dzfQ}h@pB7KQe<*#PPl~#dKw3GfMm3 zRo?RP*r4YZ7Lxi=7!dJQdJh#__+;5f#!4uARJ}ofjDFc7uts2=K$}2^fKPy8Wb6=1 z^7~{65b4(%F77!>iEAJ=^}ChZ{xCuSaUt+?+Mu=C?_PHZ=|>6 zCZ#D3Ye#zLfX5N#n$iP3(77f57A=Ec03b>SkE3Y9k2LbeeoxbZziAnV;Y05CDtEk0 z@=@_;k#{+R@MY1kjjZ+=_^R|gq32Bpq0AEg%4nC>pH7~0`lP;K-ikk;axl_d()ua1 zrsoViBWhp1P0fg5VdSzt!+j>Pi?UwEE*$CxPo9zoTEZhd51>2$THC zr=l|PO+Y~pa(iL@nn>Z6!f{0m@SG-(BSy(Nil&A}{O+5aQH}u_i~Z7Kw;D%Bs$Sdc zczzh*-JXeW8e%_^Ca9%ij{73=%p3Jo1Mzx?pxds|p0VRRp8knD z?{jZ}W}lPqU<|r>ywh(XBPgn93{31%N4ywz2dA+9IfZO85mp|g2KFFLPWdFFpzEf5 zI?m@1B7n`ZgMB5=18b`EF4m9pz+w)300Nr61v#iph8BDc%mR7mEY{4$#l`LaM!0PI zLx5QPjbN251Ltcx%aoeK1gx8_&Pv^s?ka(+1Vp4i7wKP#^z#R(81XAj7a=aUAsW5# zRU#;P=X5oW>3IzoBzgJ|+?lL+EqZf|eRL&xyz7Y{!|2ciarC_s(i>lG4FQsY;gbg) zN3#i+f|Z1~!5nF}pc8Qlp*>7fFa^9vIr3;I&aJZ_om=&&$9 zEQI;B5n+q;bsf3UIJaV;za83ZMdI5P__V@7%sZ`Yl?Gp@aWV+8khRGth7Wx>Z(*J> ze3d}{Ybjn0=Xf>BTpNEiH9#agiBa1__??J?XH`z+;e%lxSXmiSb9)aCcW5}Vr74Rb z>$hzD4&ovoE_BjcbkO_~hUGBvdmR7NV|n`;At{cTEVdvdUs@^#s-bHUYZu_4G6N1ONfWdii(29No@yUD@j6GlE8Yq221$OELD# zc=^}D8-Kz@W)m;D>AjQu9s$;JuOkPcj2Gfztq@;t-Uw@z>!_?}eeYTrzu)8C)E)Zx z05gkv8t>L|^EV@dR2qPLF zWi12?k)xm?o{wy+1hGM(L6j|4yR-f&j?;h=FvvEZ8wE5A?0n>@uV z#xE{krwU@>+DOb>6h6Ocn@D0X`E>&C5cn>EcM1GCfhz<)hNURrc%|EhH5GiDSn4GD zkZ3F-DsAY-hr~|wcZ}@>!cUUs2YZ!`(%o2T&%m{^B<-_saY1HiNm&9pfQt+CV{4Y} z+6BoXf0XJeaJm55fJJGG7i)QXAxq!JxNekn>7$^_cx&!PdDlhYw;k8W3)J&V1Ssq- zO9WUhG$_Xurueq}3j#N97w71TNn&!EJ~#v5r0~JjC;8jXx{l?1*2!o8P@8u0PR2RO zBVpq#kc#|@e*SXUY6tW5agwC)dz54J6?_5J^;sxIshdJ8%3NsC_YTir{5-y=YK80a zY4nNhVitd7zH>A?8tq(ehpnFMHobs8Cy2)IHBPVH4co0HeDXxO$WBtT*A8O~oVI#9 c8br>Sn7U69>4QlCmLY6(Q2-lntqEHH1Mn>KW&i*H literal 0 HcmV?d00001 diff --git a/__pycache__/stack_convlstm3D.cpython-38.pyc b/__pycache__/stack_convlstm3D.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87c55b39aea2f88bda76792343e4637a59077144 GIT binary patch literal 823 zcmY+C&2AGh5XbG!$8I)FX!#Tu#0jxTX`s0xR4GwTw3miUFJ`lLTPyn|+W`fo5{F#5 z@e+9eo`6T-<}0Ut1bSj-!)k+9{%Jh6$20Ta551mGVEp+0^W!gvkl(o29}QTX!qm?I zoNyYG8Q~sp@aBeIQr_b24Y_olA_zKuI3d~A*6^a&EvuWVjj!g5VVO%EGzmc4a zw6wbmp;8euEx!osr^|%J;YuiNyYtXQH|$C(ZESNcLv4?gFiVxPFpjHJ{e=)YlPOC? zlBv~}T7R6SpWa{0&PJ2KQAhB+^&)V^+9X?|9)aU-GyI&Xh0eoB>?{)+nWaNK!yvOv zYMEXSM-!HpNhDG2WpkkoOENBEJxpX2g_51TK6wMD2i4PhaFwaSS(X|x7>6p(hT4SD zf>n8pCdc{8_IJ)Q6dO&jarz;E&?Y^gPW}C9d+G`JBiyV8cE1C&6z(M$d%`t9J`~Tb z#RKWwTe2plbgC4o36~V=quK};P#{1qSK&XQx`+cr579^5LmVRR0|Kfr<+80iJcgxx z@PBF0J(1&Wjd!rV!@9b=wyU0k_dKA}pke*GNoCRH-d=nZP?e8(Z=rR|SiHklXspI_ G{O(^0Xuj_N literal 0 HcmV?d00001 diff --git a/__pycache__/training_functions.cpython-38.pyc b/__pycache__/training_functions.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3e7cc333b149a2aa7bc9aceb3856f6db34f9843 GIT binary patch literal 1740 zcmZ`(OK%)S5bmDod3x6Sh~p%dfB@pL62>?PAw?)caSr8efth>C8GNO<*`1xKOg019Ty?9$QV+v?Q6i$K*NvZG{r) zDJ!#&totzk&ISEq(^hSuou;cgGyj;(0$AHo-E+RHi0T~Crl*+rRdwMqkYX{a$5BoQ?GbuAZ9tNoZbCH*Y4zoli;3wk( z%;G1>ShKp|a@Y}dQRKxQisCVx??N)E_*~LGv%!YoCH^C|IAu@XQJTTTsvu5 z7jn2@NW7a!I5N#O!&7O$;s-S^Qm$Q@h&?W~SJrY`%Stojo1t&*7(!02c(OJ4X1iJ)m11?fEF@puNW{2YlOoAYv_^H2;$GE<(?jh;>~gpY$E;WP z00@n&kcQN^d>YXugApJjJ=(LE?G*|?E22LBtdOAHP20>sBHjZJsf&J2zcx??!*q~; zfI06PuC!+iYS1$h>>9qZ6oG0Flq2pcdRoWSR6B)gKbK_GR6A9+3NBkI3e_H}5PJ?K zip{)J3cUs9q3JIRP(?4TGMWe~>}pL5>h5aIB6N^QwbbH>L9JOkrl$I#z&i_!s=#On z9|1e5@Bp;Go22_s6On0FNj}xVPh8CFM=XNA-xjZ0k ztRUP%xD7B2@!erb)KGth7jL8TEd)%4d4O|x7p73?_%hakt!@Jl8*)!W+qX={>GC|c wfF|=)tb$FSZsMQKZ)xs^i8s1S{_&{`vx)y1yw`fn?_rCnFe30`?C^vC0cxMJi~s-t literal 0 HcmV?d00001 diff --git a/__pycache__/utils.cpython-38.pyc b/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63d3934825b6be23ec25453e95d7941146fac601 GIT binary patch literal 8067 zcmbtZUu+!5dEedJ+uQp?@ZL4g8A`(B{Hyfx_yQ8W)lfdEEY z_xH`-9hZ__80bkeJM+yqf4~19Z;XwVH2fa?!Qa1Enb)+RQYZV%q4O$A{Ew)Z#`HjI zq+h+Eqi+O8H`mDN1j_|x*KC-otq1u=9xyYox`jqT!SaFKwHvl-TS0M^=2$^#@GTF% zmBDvx@Eymu!0fwPV}cb~3EwI!vkJb4*ccnfcalx8D!x-~ZCz)F*yLTkG0mpfG+>9> zVRi)HBkU+UhVN0PU)O5K_i!6+rKZc$k}KTh$YmaXs7tdQL~DRn7Nc<8Z(ntJ>~ZvM z<_b^vT@P*jREs7Le+C)_ltq*{Ku>5JI@6iKa`%i~ZAaTMn91_Y5;?lYy}X{na1~Yt ztdfD@I&4CiteS!0I&4bh*>nbmE3u=19m~LuvnSY-SnsI}>`Uxv_GQ4F4D1AZhCPct zp3A^~i9HX<3xI%l>cl5m4f9TAAYWlGvN~WhnR&D9G@HXY&SYTo>@0hUz06)=Ume}` zYwR4}**hQRVF|*zhyQE;3GqBg$8|+Q&`^y>Fudroy1+bM2 z>@9Xx?eWdb9@p4)b(Yl(>;`*V;oZ!@-eJGW8aT_ZW!C!^dzbw>V9gBdJ?7%JYpgYT zJH|Y=j+bg@usY0V8-Q(QU;*p05GRf@upWC~orGsj5;LK0+aKcHw;pJ_`VKMjt(9bY ze#rs5-deuCdc_gk^+P{wuhtA%W?tOleoy#O*sxv3njJ6bCFOcO_PhON&lQ~{uiD91 zXSX808TZ_lm$rmUvu1->beqfzBR|HXvths6+iueAK#j+<$?{b!m;{v?t^@QVx&P>?b#suDVMJ*=_%v*Q+fsm%y?s?KwM=sXu zEs-ABM5~v)_&MiI57^>SujjF*-*wv_&Ui?j5l35Gby65TLfDMmZZGiS#Hfc{hxjH> ztn|BUICr!@dPrx8RH@hTVyzCYpLj{G-i!BMr8&=CM_=8sVm@ zlM>t@Y%1blYAE(j($weow`;kv7t`zYV~-wXNJl63M1rlm5%Yp(=ypAvwks>Ka9f-6 zOQ4O{OAf67t`3c-SD19aAd2JRjNp^881G+?yHO-M!CX?!P9e->0y8|96Z14Vv>nbF z@OFLisek{;%cTp+;WtBH_-^0~ksO6IUaxPdYvUv>ce{P>KmO|e&;Q{s7Lw^9emd-^ z;iOuxN8raUo{X!r$-K+l?pAU*d!7-B;#K#SH{w;ar9P9qI3#EcfO>m%>lb2&b`zag z5rKO#7%8UbYxnw1@CIm$=WFkoirBATz2W#^ZsD){Ao|Gx=Q*cq39jV^L9^9y!_W(2 zyoKAk+$hH);$Cyz4Pq}jd9&kz)j1I@_5!aZQc{DA>pbc@nD5;3y{)*G<3;QzEw@E7 z#mlt$B=~_?KEO+`=Gu0%EfTXHwuLOEn+V?bl2=iz{@9vlXw5V2dG-70h2%nZO?NGV zr2JoNK7LRL0crgJ98W~U^6`4|((tBqvf)+n`h*9@Cq?9kvG80*3Ta5an!%?)td#ey zRGq>0kkm@shpPqU;w81=+7KxgIVseMI943AhiZ`oG~ms+)$v$AfK;34#+z`NLHEKoZgpG{N8T+QSBXO4ZhKsbLXdoE-xQ=^vWU(_|E(2I%rDgn z(sp^fNA9Q`e~q{MU8t?pJCzk;d5kKU%*BGs{CW_%LgpaFvZQ7aLJXh4QavtB>Oa)@ zGoZ>R+n-}iAF=*0H>m~ew)TC3=9(9arc%H0Y}aqKT)%nd^qINY`K9{MD~ap4O#Mm} z3a`HC@*tY+3qOcwdfWWV*uRPr{~an#H+4fd{>L&ay78%HeDS)Z@YDK>&-&t+Uu0!l zSGQ@ACo%kyp`oIJgKuL5&NX7FO@kj}#;z{(JKC=CLp^+3!E=D0R&ZkgUr=xp@Ha$` zU^EK+=M_8;_FHl?j7os%SwW4-ip)uf01D@`i z!e36|@j04>@pGbt*8bipRyBoU#AXVj_I_HF+1MRj*F=RC@90a~yOmwUqCeE%4^iI3 zsCtn*2ErO+M_)JKGEb~D3Iq4)G9L$Vl0!j1^u*cvT1YO?=Jxzja`-V&bIor{dkrca z63O-`?1MawD=Rf}6l4irMl+S*Jz|8_hDB@7GvL}o>&IW9T_S(-%^EE754x8nBob*VZyJ6d_nfwF>rAfDvd38pq`$~_b)w48wh-`oh&kE`| zJ@yTNPSG1~bp;WXZI$C-Kt*UOLR8xpy15yH>n+C@}M0ZZ{V$D-05= zYYM|oF_;Ag&V@5-w182eMvLhvpryh53K-5NEX8m+Fk5lufIgrDlh<<{B|d_x4f%R^ zoPxOxjhj3A?-=i!*Hg8ORKpmD_v6FIYPPw{9vae^IW6_xqpw{$ANa9oMcv-Q7A;f9YnJ@|P#KERUz-1-t)N@kMHJyJQAJ~^ROO701-O;kO zin2EF0R#+NV|+9)aP{Fmw>mPfz|wpT+cn09##zN2jr((N{o81a-u?^>q9ODRV?*1>?LfMB z2H7^swNbi|mG6}xyN;Ss!HZrSU_;oDJ|v9sd**I#hop64g&)J}1HXLlg3NXOaF1TJ zeS7-PUjNtREnhSiFov_1d7U{H`D$*(pM z=E;I5+$QjO6KARAWhLtm^9um-S5e7B*Wep4w{*dYgp2M6mG<1B$S4I5kjO9;qQY< z9_V+Fxdtnk`;d>c&BDI^g|Z{HP%~vN@Iq-su{1%+4(vbjG$9lme;MNsbx`4hR?tVT z>lyUCbej3S=wLZ z0s1k$lcw`7>d)6z^nS$saE_FrVft|#N{0<%C^ch6DRx+!ic+I6HBrBZSb-BW$Z`hw zDD|zjC%_(m47`}g1)kE_@>5DZ8rz?2t!TEQLzdAjldSgj|MEVYl;#) z8f_6pEt4tDgRuC($#Ys%7j=C@Ku z;V*$Tauf_L`N}|&aad@CG~%b6aE4h&b19)pkv>^@6{p_bP0F(^^(T`Vgz{3~s*Q0{ z^PDarb78avm8ep@G9UMl<=`C}u;Oi`3c5TdR5I(*yoDH-lX{hA)PwBhw>VM*wL*H8 z+k_~nQd1VtYK)~DH>0%;B;gt+b<;b4K!{3kj(*PV{sb6G!Dt9J9oQb@Q`;EVOU4m+ zpX0fzKCSa_1E)pIJN%JcA4XY3iT@21Nf+1<#9?5X?-*cX{M82~hhn1SQ7jY{gyEZd zjuhsUQ7R~7C}i~~P^u`0P$rpvZwl;-b$4K&7?}P4w>GR}cICX7f>3B1)9K%*U0yaB>yEcq+TJx<|YS z`&rsgBg#VD>94N`o-9+`M*I$zzFa#j3xjOEJa%Pqe(B9Y{^=^+m;PXgth{;)eus9y zOVxL%`jD#cQuPs4zeyEcgZ~y)WEuJQsJchh=;ZwS)cpZfB>4PyQPu3!rOGRlkv_=p z11|GQhsfM5H;@(8hG#{8zLX_^2JJ^R_2hv6DLa*7$A62(;&)JKD$p_!j%EKD2YPDc zXV!G}G3Y@&)%%%scoGgZ{}I-%9aEC7{1I7I0l-X}{Gp?y*Co@zkE4>6RFq+I;8-dx zkSdTTDdUtrdvllzOmk6)PfM6|i&9l+SJS{a1F^{M^y0>mG@x&;!NH>V7+zezInVfW zv^iy=R79!Ve&v5Er$kv+Wzm!tP&djv?K3n@=@t<8 literal 0 HcmV?d00001 diff --git a/lunet4_bn_leakyrelu_3d.py b/lunet4_bn_leakyrelu_3d.py index d5d1503..e9d6c53 100644 --- a/lunet4_bn_leakyrelu_3d.py +++ b/lunet4_bn_leakyrelu_3d.py @@ -1,6 +1,6 @@ import torch import torch.nn as nn -from models.ndrplz_convlstm3D import * +from convlstm3D import * from copy import deepcopy class DownBlock(nn.Module): diff --git a/mc_dataset_infinite_patch3D.py b/mc_dataset_infinite_patch3D.py index 8a2cda3..ad30626 100644 --- a/mc_dataset_infinite_patch3D.py +++ b/mc_dataset_infinite_patch3D.py @@ -74,9 +74,10 @@ def __init__(self, # Particles to path dictionnary self.dict_particles = {case_path: self.get_particles_to_path(case_path) for case_path in tqdm(self.train_list)} - self.dict_case_path = {os.path.basename(case_path): case_path for case_path in self.train_list} + self.dict_case_path = {os.path.basename(case_path): case_path for case_path in self.train_list} + self.dict_ct = {case_path: - np.load(self.ct_path + "ct_{}.npy".format(os.path.basename(case_path)), allow_pickle=True) + np.load(self.ct_path + "cropped_{}.npy".format(os.path.basename(case_path)), allow_pickle=True) for case_path in tqdm(self.train_list)} @@ -313,7 +314,6 @@ def create_pair(self, path, channel=0, idx=None, patch=True): # Normalize by the max dose of the complete sequence (including ground truth) m = np.max(ground_truth) if self.normalized_by_gt: - print("Normalized by gt") sequence /= m ground_truth /= m # Else, scale between -1 and 1 diff --git a/pytorch_ssim/.ipynb_checkpoints/__init__-checkpoint.py b/pytorch_ssim/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000..738e803 --- /dev/null +++ b/pytorch_ssim/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/pytorch_ssim/__init__.py b/pytorch_ssim/__init__.py new file mode 100644 index 0000000..738e803 --- /dev/null +++ b/pytorch_ssim/__init__.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) + return gauss/gauss.sum() + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + +def _ssim(img1, img2, window, window_size, channel, size_average = True): + mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) + mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1*mu2 + + sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq + sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq + sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1).mean(1).mean(1) + +class SSIM(torch.nn.Module): + def __init__(self, window_size = 11, size_average = True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + +def ssim(img1, img2, window_size = 11, size_average = True): + (_, channel, _, _) = img1.size() + window = create_window(window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/pytorch_ssim/__pycache__/__init__.cpython-35.pyc b/pytorch_ssim/__pycache__/__init__.cpython-35.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3501dfe6a2528ef24390f7b49d1dee7d23cc18bc GIT binary patch literal 2919 zcmZ`*OLH4V5bjy+Lt0C*6B`no6o)_+ED#%uNAXC-P!5C?#8s4B90b(Xvb$E+Xm^#F zRZNUj9Bc?zxNzgbi3?|LT)6QM<_4Tn`~eCsd_8Mffg&SKZ_jjZ&rE;aJ^OH{<7a<= z6Ta;e{Z98a8~khN>N}7Ge?dc{{s}K=&!W8s^=%`yXy{PxQr{(UtnE?1X=Du=wkY?h z??VQ4hmK%rffAPz&+u(ZpxUCurz1iA4qZziZnF$DC9~EkI>d5|lyr>tDM}XU$fEud z@2wXs-NQy{I}p10UL>>VqhSiZk$y6YEta4)(Cy==dKpXw%YyC_-4(QJv7N-CFEDS> zmx2xjT}8)#YSF~}ao|Iu^Sz+49Sa!X#@3tY%joJ=kO9$z=)(<~3SMC%*chC$dJbfH zp)!LnjVG5^VEsJ$R~FrJaSCyd%>|a^y=rH7SVmP~X>09SEj|I`sBDl&dSO2+l5#&( z*+Kd}{dIkPzmy-VQ52`^)y-5!>#8iWXnm-vJWSG}%v4qk*1K1?%HsAry`3BDqdQe8 zy?xHh`;mw$PfsVXk>(G@@04l(b=!YaqnZ>b?oy|P2l7FIQc4kgM%Jg*3_M8hzDtSG;ujI@gBRV zNfYnBVT+}hS(*AXOUA7$|G(a(N#~{uGfk{1=s}t-XtrmXu=!w3ADXBZ9!*g#nlwf6 z0J=UbpQ5wOcEB#u*%holY0+7po-EL$jR#MkpDxkFA^DP+EFA|@ujWv@IELV~e=|L# zcrR#eYi|@K36f579BX$V%kfAJK3)0Y+OG#kn}Z)8{rr!2fBbe+a@81I+Pd(^Z!fNH z4!-{7v$uYl-q=J&i8r;2V7c8*bYoQRYdcS)A~>z>EFWwb+Lfz#xlW)*c$GGngD|>{ z0x(GXjeNYpq}zAFh3XbRyZrQc%z~DwCPr4XwYZ+`YHMo)z1wecl?d}_^xqJ4O#R@{ zQWro-Tof0rMX@M+>nVYs=QAAafo0y z*&gCns%IRkOcj81&*2wX&Et*r0&Cl#6*N(W$D>rrCt+MSXVtR?;&IGLM#JQvK`3Zw zTcyL@xdJYxRv5;^NU1QKr?h+;@3B$E1(L+FXqn$ZhmW~7Tj=DZoZBAy7P|Tw1jpLL zF^N8~31tmci>XOzK9Hb~Es{HYruy=(OCrC8>SxO1{R&fuig!(Igg(5&b%8L+`X0$6 zPK`Yz4u-8+orDyVK2LH@Fhh1B9ElPN23H~y49><3q%+IRwCT`Xf!jd`4nru{e&it6 zb`n*QwyQhvI~TU&(wV6+9w!hO>JQQi*WK+bPPNCXVWdplNk%1bxLwE}{=YCS93))}iM9ukY7t;Rvy0CTJwI2;T`l&4{+{V>eSWISZP z9fr5Y(Xf^nkmZXoIA1al@)A#QAe*Rq0*tu<+v1#OW@Lb=*aqRosGlIu*C6}v1bM*j z81QWq_{P%=9WaC%Lf{Qkxe*r_8rT^tNKwuWCiRdr!Ne37@)|DOoW~G(at+JYSZpxY zW%3*o&S&HHz?(aJ9{lnJmYWnY^^0SDo&!K0#pZR)sWu4TX6$SBxO@fU;I!=E(O)(n z9BW0fwmUB3Dl3akT8rvjKO$%xEgwX zRz0(%Ma&`x5|A!A#Ft!mk9mqbK;EE_2@pWY3k1%^U-j&ghTKF~S66lSY=2)>HGex8 z1Pm?v|KFlNLi;yOR?UX;5qfE3`RTMD7%B=o+2pUg*PS<-H;Z z1K4coT(EFWy3)hgmm9Jtg9{!Ge$V8Z>|Ze1Pua0JUO%_OYx25Wm)GFkko7P_+>>vy3vY|dY@m;6XbF2#l4h*$iaKVj#5!Dphy z-`rZTy9~!$`w7PmYIK~BeiOYOL5x|;zIb=W8`knyT(T{82Y0>6)>wntE2D4X(7Wh= z70+!P$6wPaL!sz5Y3C@f;$|qc*nXh-*HD}~8yB%&JIzX2okn$bmj0c6`|-|crM{{s zagy#d&r%id)K!_qJ9*s{kxa`ftFv;vGkUUDl_#Gb96sHdyl5(wJd5f&D|VtNE3+nw zwkI#X4IbuM-6T~pc|>=q>F2oY?A|z@Vr^V*Kg_En&g(}|Rttx=cFb|wJxt5GQrb~* zIZm}zPm7_YdN@PTY1%QzD2f=|kxZR)jhz1K8y;pg#SY{5xw?Vzk}9H!+n*ij>$ct|n+3~r*ybB&wf-G8TU(yTf8g!G z+h6Yep*MhcAg$TDv}e~OZv*UKzYP}K0kmtk>1%%5yKSuPTHA+atasULqjl7O+6_v- z14g5C!0z|t0kR5Qy8DU(2pQu|Vkr^*+8wKEI;qF+?cMwDzdqc0JpSgNfBwTiW=|h0 zN?61&|DATh;K@j8Yf_zRyGY}5cwO6BG2S(Dq_(iM3)vq*(|t<7C_X{TjnmL7rn@9Y zp$jdlpVK&^(W5yzdM1C3uTwjBXB~xNZ&!=GQRq`%Mnyb%GcFvHvBYN~1nmj2BQG`Mvy_f6==^^clUI-wiixO+UWHt*O z0O@P{;Nat@uPFo%6RJwoXk2wu>Kc@^mLUzr0`b>Ek1PZli)_%!BBw#pwd~kAC!1#| z68jzSK8o6Ok}CBM4%Pmu7N@UMjW6-IGvQGIE7 zH?FKRr~D9Ah8?@DeJ*AW{GR%}b>==!>n#}a%)u;ZDdV644ngNeyjcT8f-JNbvVZQ< zi?BZ&;KGQr4hM=rZOgccwcWe`Zd3}E#U`txWGdk>G8m@~s{curq}rpUQCyp#P=xBx zRX?YlKOyl`2+gCbFT5Fd-chAaVk-gyjJ>X+D{2TLxKG|yLjniqqsU%;;@6iu*Dz?gT! z=j)!am@g(a?9_I3qdur>vSZg&;8U>n*08`$bT1l4@P zmrz!}#(c;RFR{I@2C#%T&GWlmmfJ_uGHJ4^jB_o@QU_Oab)QzQcPry*Q;k(Dy9dlE zr^Vz2HcU|T;^vw8@58)%-R>u{#u@x2%^7a(4=|FGlt5j~zYz1=>)!Ta3j Nec{`_8{FF5`VNJSLbCt> literal 0 HcmV?d00001 diff --git a/pytorch_ssim/__pycache__/__init__.cpython-38.pyc b/pytorch_ssim/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cba296fb9e0cf1d042578d278a2fb8e047c9e3d GIT binary patch literal 2707 zcmZWrOK%*<5$^7J?Cfxtq9j5RC_xU8*zwvSDXtSG2SG|eA#d_RCI|*L#^dSXYUp{a zdS*$Bm_-gGAmxxla`1Wgs9)iOkNF3EOn?A_|A3KGzUtW(Ww{q!U0qe(JzZZ__5P{f z_ZixUUw-}RAHe;GCT}(ylV74&Cm@nZp0S`Ec`gLwEVlxilSE`r?gp;mdDaO$$gHfJ z`+*OcEu9M%tVmaOF!tn{>`MQF2mRkLxgvWPO!g9XERI*tt>B8hDp%zdXxAh^V1sL~ za5mc?aP9p$Qfc&AmVmdCFUGMT&7xTt9dsJ24?t=*2cI(jt(8L@_a^1+nXtl=-XWE<=j?tGoCuo|-$hF{0V zx6%J0p4!-wzoK0RLeckV=O`KOLZjYbUrNT~CmFC-Fm=J2Xs@rA?WP=5jI5o_R%W}3q@q$nUB!jNl`tW zB*|H#MV@P3EBX{gU~3mK85L6dF&vzZCgr3WIOfW_6F-THBFO?j+<6?fpo1SSdF{5c z)*r?x#vlDyB~hI$R*<|RXEDA8x7_0D4*0{^CDEc?#Yl7G8DuCtbIfr|(O8miu*R0+ z@AlL+QvQ~-)3`Ha3;gFbodwrz+@`b6rv73dw*J;$jdzPpR~Gy6?`gZx_7-cuul1qr zOKZ9+?dcWC8z1XeZvx4t5AMoMI?ZdkH;uGeX?oy{^fsHWHIDjkvqo8bz-W{X5PpLm zATfcbyRRsCk$s(UBqb6^yCYRj#?@$Z_s%!}zQ6Hk^v{3&^|xP5_a3Qrva~h-wRVBx z$xv!*T%Kw>PoiRQRoiJk+A(yfZo{+<=nZhXM>!luCy9ziNnquZ9U{ZP1s7INX&lnn zp*bnKrUV#W%eLxPIt;|_4*D?gC^f@88owPIj!9a=yP_~){06@veC~5k+(mTU6*jlU zJy1`m_n?o-=*>o6p;nPLVF7AcAY|t`J3^XB`=T4bnbJkYwx2Vb1rD6)Y5U;dcYCiW z6c1x6Zp%jevYV<|qdGPWDJkZN!v=h4A>deO16~$7H4?C4$Idy)ItD7S-vZ#FK2FAo zQa{F~+FRD8*4ytwrpV)UC$GHTdU3{Q{#nwy54Tyd9U)sZ-oX z)nUhOW1owu1ARyRrEz8+cI(a=a?ZibZz%h~!6kzB4SBr+Hwlu^TF~CPd%#}u{Xrj> zM$EOoplH;#jOs|+^)vX6Vq#Hz(khH65(-28QBvdCI7#C~cgQr1Dia`zJUejJhqUsC zM1BOKd3brk+cDU|dvY2nxeddLatD=3`~VTx6E2+VgSysP&p3BsiZ_t}I2^bHY!T&2 z7-~PHmur#{-wVU1lPGI71~Bz=EYST0o^g_8OH56r{0X$?`7q%^A!A++kFR!&#GDr| zqaA7k=W9sm`DoYzQhVwER2z{Cmjfu5AXQr^j23JFEZ}i5102L5hAFaL1CT|)s4Z+a z=qU0?ivj(EP#+QbDUqKMp^SskYbRd@uDVZp6W*r4m>U{m9t4}J!LMLa(Ujf+@-cZ@ zJ-~dx4`2KDs#=94xNhF#t)kdEniMhporp3mibDIBbM+2cuC|uZq%KD)lI^4B6q6i( z2ShPO5sd05<`0c|3)|b#O|-jbz-zrt-|(?pAD=&3@&Et; literal 0 HcmV?d00001 diff --git a/train.py b/train.py index 97c7ad6..99e32bf 100644 --- a/train.py +++ b/train.py @@ -19,12 +19,14 @@ import datetime from configparser import ConfigParser + + # Parser parameters args = parse_args() # Get the cases paths cases = list_cases(args.simu_path, exclude=[]) # Instantiate the selected model and indicate whether it's UNet based or not -model, unet = instantiate_model(args.model_name) +model, unet = instantiate_model(args) # Create train and validation dataloaders train_loader, val_loader = create_dataloaders(args, cases, unet) # Get the optimizer @@ -34,12 +36,13 @@ # Get the loss function loss = create_loss(args) # Write training configuration file -create_saving_framework(args) +save_path = create_saving_framework(args) # Instantitate tensorboard writer to write results for training monitoring writer = SummaryWriter(save_path) - +torch.cuda.set_device(args.gpu_number) +model.cuda() if args.mode == "infinite": print("Infinite training (mode: {})".format(args.mode)) diff --git a/training_functions.py b/training_functions.py index 2607656..7551ae2 100644 --- a/training_functions.py +++ b/training_functions.py @@ -9,7 +9,7 @@ def validate(model, criterion, dataloader, n_val, unet=False): l1_loss = torch.nn.L1Loss() ssim_loss = pytorch_ssim.SSIM() # Validation - count, count_batch =, 0 0 + count, count_batch = 0, 0 with torch.no_grad(): for i, data in enumerate(dataloader, 0): diff --git a/utils.py b/utils.py index c643d39..2e2d199 100644 --- a/utils.py +++ b/utils.py @@ -1,14 +1,15 @@ import os import torch +from torch.utils.data import DataLoader from glob import glob import argparse from configparser import ConfigParser -import datetime +from datetime import datetime from mc_dataset_infinite_patch3D import * from convlstm3D import * from stack_convlstm3D import * from losses import * -from lunet4_bn_leakyrely_3d import * +from lunet4_bn_leakyrelu_3d import * from bionet3d import * @@ -17,8 +18,10 @@ def parse_args(): description='3D ConvLSTM training', add_help=True) - parser.add_argument("--simu_path", '-simu', default='.', type=str, - help="Path pointing to the folder where the MC simulations are kept in .npy format") + parser.add_argument("--simu_path", '-simu', default='/workspace/workstation/lstm_denoising/numpy_simulations/', type=str, + help="Path pointing to the folder where the MC simulations are kept in .npy format") + parser.add_argument("--ct_path", '-ctp', default='/workspace/workstation/numpy_cropped_images/', type=str, + help="Path to CT images") parser.add_argument("--n_samples", '-n', default=0, type=float, help='Number of training samples') parser.add_argument("--patch_size", '-ps', default=64, type=int, @@ -37,9 +40,9 @@ def parse_args(): help='Name of the optimizer') parser.add_argument("--save_path", '-save', default='.', type=str, help="Path to save the model's weigths and results") - parser.add_argument('--gpu_number', '-g', default=0, type=int, + parser.add_argument('--gpu_number', '-g', default=4, type=int, help='GPU identifier (default 0)') - parser.add_argument('--all_channels', '-ac', default=False action='store_false', + parser.add_argument('--all_channels', '-ac', default=False, action='store_false', help='Whether to select patches from all views') parser.add_argument('--normalized_by_gt', '-ngt', default=True, action='store_true', help='Whether to normalize input data by ground-truth maximum') @@ -65,7 +68,7 @@ def parse_args(): help='Probability above which you draw patches from high dose regions') parser.add_argument("--single_frame", '-sf', action='store_false', help='Whether you train on single frame instead of sequence') - parser.add_argument("--mode", '-m', default='infinite', type=str, + parser.add_argument("--mode", '-mode', default='infinite', type=str, help="Whether to do finite or infinite training") parser.add_argument("--lr_scheduler", '-lrs', default='plateau', type=str, help="Name of the learning rate scheduler to use") @@ -107,7 +110,7 @@ def instantiate_model(args): unet = True else: print("Unrecognized model") - break + for i in range(1):break print("Model: ", args.model_name) return model, unet @@ -128,7 +131,7 @@ def create_loss(args): elif args.loss_name == "ssim": loss = ssim_loss elif args.loss_name == "ssim-smoothl1": loss = ssim_smoothl1 elif args.loss_name == "ssim-mse": loss = ssim_mse - print("Loss used: ", loss_name) + print("Loss used: ", args.loss_name) return loss @@ -180,12 +183,14 @@ def create_saving_framework(args): # Write configuration file with open(save_path + '/config.ini', 'w') as conf: config_object.write(conf) + return save_path def create_dataloaders(args, cases, unet): n_train = args.n_train train = MC3DInfinitePatchDataset(cases[:n_train], n_frames = args.n_frames, + ct_path = args.ct_path, patch_size = args.patch_size, all_channels = args.all_channels, normalized_by_gt = args.normalized_by_gt, @@ -205,6 +210,7 @@ def create_dataloaders(args, cases, unet): val = MC3DInfinitePatchDataset(cases[n_train:n_train+5], n_frames = args.n_frames, + ct_path = args.ct_path, patch_size = args.patch_size, all_channels = args.all_channels, normalized_by_gt = args.normalized_by_gt, @@ -224,5 +230,5 @@ def create_dataloaders(args, cases, unet): train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, num_workers=6) - val_loader = DataLoader(val, batch_size=val_batch_size, shuffle=False, num_workers=6) + val_loader = DataLoader(val, batch_size=args.batch_size, shuffle=False, num_workers=6) return train_loader, val_loader