diff --git a/mlp_pipeline/model_mlp.py b/mlp_pipeline/model_mlp.py
new file mode 100644
index 0000000..874292a
--- /dev/null
+++ b/mlp_pipeline/model_mlp.py
@@ -0,0 +1,25 @@
+from torch import nn
+
+class MLP(nn.Module):
+    def __init__(self, input_num_features, neurons):
+        super().__init__()
+        self.layers = self._create_layers(input_num_features, neurons)
+
+    def _create_layers(self, input_features, neurons):
+        if not isinstance(neurons, list):
+            raise TypeError("neurons must be a list of layer sizes")
+        task_layers = []
+        for index, arg in enumerate(neurons):
+            if index == len(neurons) - 1:
+                # Final layer: raw logits, no activation/normalization
+                task_layers += [nn.Linear(input_features, arg)]
+            else:
+                task_layers += [nn.Linear(input_features, arg), nn.ReLU(), nn.BatchNorm1d(arg), nn.Dropout(p=0.2)]
+            input_features = arg
+        return nn.Sequential(*task_layers)
+
+    def forward(self, x):
+        x = self.layers(x)
+        return x
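For quick reference, a minimal smoke test for this model. The layer sizes are borrowed from the __main__ block of train_mlp.py below; the batch size is illustrative:

# Hypothetical smoke test for MLP; sizes match the hyper-params in train_mlp.py.
import torch
from model_mlp import MLP

model = MLP(input_num_features=21, neurons=[128, 64, 32, 3]).eval()
x = torch.randn(8, 21)        # batch of 8 feature vectors
with torch.no_grad():
    logits = model(x)         # hidden blocks: Linear -> ReLU -> BatchNorm1d -> Dropout
print(logits.shape)           # torch.Size([8, 3]) -- raw logits for CrossEntropyLoss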
diff --git a/mlp_pipeline/train_mlp.py b/mlp_pipeline/train_mlp.py
new file mode 100644
index 0000000..13d4496
--- /dev/null
+++ b/mlp_pipeline/train_mlp.py
@@ -0,0 +1,325 @@
+import pandas as pd
+import torch.utils.data as data_utils
+import numpy as np
+from sklearn.preprocessing import StandardScaler
+import matplotlib.pyplot as plt
+import torch.nn.functional as F
+from model_mlp import MLP
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import roc_curve, roc_auc_score, auc
+from sklearn.metrics import confusion_matrix as cm
+import torch.optim as optim
+from keras.utils import np_utils
+from itertools import cycle
+from mlxtend.plotting import plot_confusion_matrix
+import torch
+
+class TrainModel:
+    def __init__(self, input_filter, hidden_neurons, num_epochs, batch_size, weights_path):
+        self.input_neuron = input_filter
+        self.hidden_layers = hidden_neurons
+        self.epochs = num_epochs
+        self.batch = batch_size
+        self.thresh = 0.5
+        self.path_to_weights = weights_path
+        self.device = self._get_device()
+
+    def _get_device(self):
+        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    def start_training(self, path_to_csv):
+        df = pd.read_csv(path_to_csv)
+        shuffled_df = df.sample(frac=1)
+        # Get numpy data from csv
+        data_y = shuffled_df.iloc[:, -1].to_numpy()
+        data_x = shuffled_df.drop(['slice_id', 'label'], axis=1).to_numpy()
+        # Split the dataset
+        x_train, x_valid, y_train, y_valid = train_test_split(data_x, data_y, test_size=0.2, random_state=42)
+        # Standardize: fit the scaler on the training data only and reuse it
+        # for validation, so no validation statistics leak into training
+        sc = StandardScaler()
+        x_train = sc.fit_transform(x_train)
+        x_valid = sc.transform(x_valid)
+        # Convert to tensors
+        x_train, y_train = torch.from_numpy(x_train), torch.from_numpy(y_train)
+        x_valid, y_valid = torch.from_numpy(x_valid), torch.from_numpy(y_valid)
+        # Create datasets
+        train = data_utils.TensorDataset(x_train.float(), y_train)
+        validation = data_utils.TensorDataset(x_valid.float(), y_valid)
+        trainloader = data_utils.DataLoader(train, batch_size=self.batch, shuffle=True, drop_last=True)
+        validloader = data_utils.DataLoader(validation, batch_size=self.batch, shuffle=True, drop_last=True)
+
+        # Instantiate model and other parameters
+        model = MLP(self.input_neuron, self.hidden_layers).to(self.device)
+        optimizer = optim.SGD(model.parameters(), lr=0.001)
+        criterion = torch.nn.CrossEntropyLoss()
+        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
+        # Variables to track
+        train_losses, val_losses, aucs = [], [], []
+        valid_loss_min = np.inf
+        # Training loop
+        metrics = {'accuracy': {0: [], 1: [], 2: []},
+                   'sensitivity': {0: [], 1: [], 2: []},
+                   'specificity': {0: [], 1: [], 2: []}
+                   }
+
+        for epoch in range(self.epochs):
+            running_train_loss, running_val_loss = 0.0, 0.0
+            epoch_loss = []
+            model.train()
+            for images, labels in trainloader:
+                images, labels = images.to(self.device), labels.to(self.device, dtype=torch.long)
+                optimizer.zero_grad()
+                output = model(images)
+                loss = criterion(output, labels)
+                loss.backward()
+                optimizer.step()
+                # Accumulate the per-sample loss
+                running_train_loss += float(loss.item()) * images.size(0)
+                epoch_loss.append(float(loss.item()) * images.size(0))
+
+            scheduler.step(np.mean(epoch_loss))
+            # Validation loop
+            with torch.no_grad():
+                model.eval()
+                y_truth, y_prediction, scores = [], [], []
+                for images, labels in validloader:
+                    images, labels = images.to(self.device), labels.to(self.device, dtype=torch.long)
+                    output = model(images)
+                    loss = criterion(output, labels)
+                    running_val_loss += float(loss.item()) * images.size(0)
+                    output_pb = F.softmax(output.cpu(), dim=1)
+                    top_ps, top_class = output_pb.topk(1, dim=1)
+                    y_prediction.extend(list(top_class.flatten().numpy()))
+                    y_truth.extend(list(labels.cpu().flatten().numpy()))
+                    scores.extend(output_pb.numpy().tolist())
+
+            # The running sums hold per-sample losses, so divide by the number
+            # of samples seen (drop_last=True keeps every batch full)
+            avg_train_loss = running_train_loss / (len(trainloader) * self.batch)
+            avg_val_loss = running_val_loss / (len(validloader) * self.batch)
+            cnf_matrix = cm(y_truth, y_prediction, labels=[0, 1, 2])
+
+            # Compute evaluations
+            FP = cnf_matrix.sum(axis=0) - np.diag(cnf_matrix)
+            FN = cnf_matrix.sum(axis=1) - np.diag(cnf_matrix)
+            TP = np.diag(cnf_matrix)
+            TN = cnf_matrix.sum() - (FP + FN + TP)
+            # Convert to float
+            f_p = FP.astype(float)
+            f_n = FN.astype(float)
+            t_p = TP.astype(float)
+            t_n = TN.astype(float)
+
+            # Compute metrics
+            accuracy = (t_p + t_n) / (f_p + f_n + t_p + t_n)
+            recall_sensitivity = t_p / (t_p + f_n)
+            specificity = t_n / (t_n + f_p)
+            precision = t_p / (t_p + f_p)
+            one_hot_true = np_utils.to_categorical(y_truth, num_classes=3)
+            model_auc = roc_auc_score(one_hot_true, scores, average='weighted')
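+            # Worked example (illustrative numbers): for a 3x3 confusion
+            # matrix [[50, 2, 3], [4, 45, 1], [2, 3, 40]], class 0 gives
+            # TP = 50 (diagonal), FP = 4 + 2 = 6 (rest of column 0),
+            # FN = 2 + 3 = 5 (rest of row 0), TN = 150 - 50 - 6 - 5 = 89,
+            # so sensitivity = 50 / 55 and specificity = 89 / 95.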
+
+            # Append losses and track metrics
+            train_losses.append(avg_train_loss)
+            val_losses.append(avg_val_loss)
+            for index in range(3):
+                metrics['accuracy'][index].append(accuracy[index])
+                metrics['sensitivity'][index].append(recall_sensitivity[index])
+                metrics['specificity'][index].append(specificity[index])
+            aucs.append(model_auc)
+            print("Epoch:{}/{} - Training Loss:{:.6f} | Validation Loss: {:.6f}".format(
+                epoch + 1, self.epochs, avg_train_loss, avg_val_loss))
+            print("Accuracy:{}\nPrecision:{}\nSensitivity:{}\nSpecificity:{}\nAUC:{}".format(
+                accuracy, precision, recall_sensitivity, specificity, model_auc))
+
+            # Save model whenever the validation loss improves
+            if avg_val_loss <= valid_loss_min:
+                print("Validation loss decreased ({:.6f} --> {:.6f}). Saving model ...".format(valid_loss_min,
+                                                                                               avg_val_loss))
+                print("-" * 40)
+                torch.save(model.state_dict(), "..\\checkpoints\\MLP_Covid_Viral_Normal.pth")
+                # Update minimum loss
+                valid_loss_min = avg_val_loss
+
+        # Save plots
+        plt.plot(train_losses, label='Training loss')
+        plt.plot(val_losses, label='Validation loss')
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss')
+        plt.legend(frameon=False)
+        plt.savefig('losses.png')
+        plt.clf()
+
+        # Per-class curves: one figure each for accuracy, sensitivity, specificity
+        for name in ('accuracy', 'sensitivity', 'specificity'):
+            plt.plot(metrics[name][0], label='Normal')
+            plt.plot(metrics[name][1], label='Viral')
+            plt.plot(metrics[name][2], label='Covid')
+            plt.xlabel('Epoch')
+            plt.ylabel('Score')
+            plt.legend(frameon=False)
+            plt.savefig('{}.png'.format(name))
+            plt.clf()
+
+        plt.plot(aucs, label='AUCs')
+        plt.xlabel('Epoch')
+        plt.ylabel('AUC')
+        plt.legend(frameon=False)
+        plt.savefig('aucs.png')
+        plt.clf()
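Note that start_training and run_testset below each fit their own StandardScaler, so the test set is scaled with statistics the model never saw during training. A minimal sketch of one way to share the fitted scaler between the two, assuming joblib is available (the file path is illustrative):

# Hypothetical scaler hand-off between training and testing.
import joblib
from sklearn.preprocessing import StandardScaler

# In start_training, after fitting on x_train:
sc = StandardScaler().fit(x_train)
joblib.dump(sc, "checkpoints/scaler.pkl")    # illustrative path

# Later, in run_testset:
sc = joblib.load("checkpoints/scaler.pkl")
data_x = sc.transform(data_x)                # reuse training mean/std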
+
+    def run_testset(self, path_to_csv):
+        df = pd.read_csv(path_to_csv)
+        shuffled_df = df.sample(frac=1)
+        # Get numpy data from csv
+        data_y = shuffled_df.iloc[:, -1].to_numpy()
+        data_x = shuffled_df.drop(['slice_id', 'label'], axis=1).to_numpy()
+        # Standardize the test data. Note this fits a fresh scaler on the test
+        # set; ideally the scaler fitted on the training data would be
+        # persisted and reused here (see the sketch above)
+        sc = StandardScaler()
+        data_x = sc.fit_transform(data_x)
+
+        # Convert to tensors
+        x_test, y_test = torch.from_numpy(data_x), torch.from_numpy(data_y)
+        # Create datasets
+        test = data_utils.TensorDataset(x_test.float(), y_test)
+        testloader = data_utils.DataLoader(test, batch_size=self.batch, shuffle=True, drop_last=True)
+
+        # Instantiate model and other parameters
+        model = MLP(self.input_neuron, self.hidden_layers)
+        if self.path_to_weights:
+            print("=" * 40)
+            print("Model Weights Loaded")
+            print("=" * 40)
+            weights = torch.load(self.path_to_weights)
+            model.load_state_dict(weights)
+        model.to(self.device)
+
+        criterion = torch.nn.CrossEntropyLoss()
+        with torch.no_grad():
+            model.eval()
+            y_truth, y_prediction, scores = [], [], []
+            running_test_loss = 0.0
+            for images, labels in testloader:
+                images, labels = images.to(self.device), labels.to(self.device, dtype=torch.long)
+                output = model(images)
+                loss = criterion(output, labels)
+                running_test_loss += float(loss.item()) * images.size(0)
+                output_pb = F.softmax(output.cpu(), dim=1)
+                top_ps, top_class = output_pb.topk(1, dim=1)
+                y_prediction.extend(list(top_class.flatten().numpy()))
+                y_truth.extend(list(labels.cpu().flatten().numpy()))
+                scores.extend(output_pb.numpy().tolist())
+
+        avg_test_loss = running_test_loss / (len(testloader) * self.batch)
+        cnf_matrix = cm(y_truth, y_prediction, labels=[0, 1, 2])
+        fig, ax = plot_confusion_matrix(conf_mat=cnf_matrix)
+        plt.show()
+        # Compute evaluations
+        FP = cnf_matrix.sum(axis=0) - np.diag(cnf_matrix)
+        FN = cnf_matrix.sum(axis=1) - np.diag(cnf_matrix)
+        TP = np.diag(cnf_matrix)
+        TN = cnf_matrix.sum() - (FP + FN + TP)
+        # Convert to float
+        f_p = FP.astype(float)
+        f_n = FN.astype(float)
+        t_p = TP.astype(float)
+        t_n = TN.astype(float)
+
+        # Compute metrics
+        accuracy = (t_p + t_n) / (f_p + f_n + t_p + t_n)
+        recall_sensitivity = t_p / (t_p + f_n)
+        specificity = t_n / (t_n + f_p)
+        precision = t_p / (t_p + f_p)
+        one_hot_true = np_utils.to_categorical(y_truth, num_classes=3)
+        model_auc = roc_auc_score(one_hot_true, scores, average='weighted')
+
+        print("Test loss:{:.6f}".format(avg_test_loss))
+        print("Accuracy:{}\nPrecision:{}\nSensitivity:{}\nSpecificity:{}\nAUC:{}".format(accuracy,
+                                                                                         precision,
+                                                                                         recall_sensitivity,
+                                                                                         specificity, model_auc))
+
+        # Draw ROC Curve
+        fpr = dict()
+        tpr = dict()
+        roc_auc = dict()
+        # Compute per-class FPR, TPR
+        scores = np.array(scores)
+        for i in range(3):
+            fpr[i], tpr[i], _ = roc_curve(one_hot_true[:, i], scores[:, i])
+            roc_auc[i] = auc(fpr[i], tpr[i])
+        """
+        # Compute micro-averaged FPR, TPR
+        fpr["micro"], tpr["micro"], _ = roc_curve(one_hot_true.ravel(), scores.ravel())
+        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
+        """
+        # Compute macro-averaged FPR, TPR: aggregate all false positive rates
+        all_fpr = np.unique(np.concatenate([fpr[i] for i in range(3)]))
+        # Then interpolate all ROC curves at these points
+        mean_tpr = np.zeros_like(all_fpr)
+        for i in range(3):
+            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
+        # Average and compute AUC
+        mean_tpr /= 3
+        fpr["macro"] = all_fpr
+        tpr["macro"] = mean_tpr
+        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
+
+        # Plot the ROC
+        plt.figure()
+        # Plot the macro-average
+        plt.plot(fpr["macro"], tpr["macro"], label='Average AUC Curve (area = {0:0.2f})'.format(roc_auc["macro"]),
+                 color='deeppink', linestyle=':', linewidth=4)
+        colors = cycle(['aqua', 'darkorange', 'teal'])
+        # Plot the different classes
+        for i, color in zip(range(3), colors):
+            plt.plot(fpr[i], tpr[i], color=color, lw=2,
+                     label='ROC curve of class {0} (area = {1:0.2f})'.format(i, roc_auc[i]))
+
+        # Plot the figure
+        plt.plot([0, 1], [0, 1], 'k--', lw=2)  # plot the chance diagonal
+        plt.xlim([0.0, 1.0])
+        plt.ylim([0.0, 1.05])
+        plt.xlabel('False Positive Rate')
+        plt.ylabel('True Positive Rate')
+        plt.title('ROC - Malignancy')
+        plt.legend(loc="lower right")
+        plt.savefig('multiclass_auc.png')
+        plt.clf()
+
+
+if __name__ == "__main__":
+    # Hyper-params
+    in_channel = 21
+    hidden_outputs = [128, 64, 32, 3]
+    num_epochs = 500
+    batches = 50
+    path_to_weights = None  # Required during predictions
+    train_csv = "path_to_train.csv"
+    test_csv = "path_to_test.csv"
+    train_obj = TrainModel(in_channel, hidden_outputs, num_epochs, batches, path_to_weights)
+    """
+    Note: the validation set is created out of the training set
+    """
+    train_obj.start_training(train_csv)
+    #train_obj.run_testset(test_csv)
diff --git a/run_maskrcnn.py b/run_maskrcnn.py
index 9d3164e..8e63ee0 100644
--- a/run_maskrcnn.py
+++ b/run_maskrcnn.py
@@ -1,6 +1,6 @@
 from model_maskrcnn import MaskRCNN
-path_to_weights = "path_to_weights"  # Required when doing predictions
+path_to_weights = None  # Required when doing predictions
 path_to_images = "path_to_train_images\\"
 path_to_masks = "path_to_train_masks\\"
 path_to_test_images = "path_to_test_images\\"
diff --git a/unet_pipeline/__init__.py b/unet_pipeline/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/unet_pipeline/__pycache__/build_unet_model.cpython-37.pyc b/unet_pipeline/__pycache__/build_unet_model.cpython-37.pyc
new file mode 100644
index 0000000..5924ce4
Binary files /dev/null and b/unet_pipeline/__pycache__/build_unet_model.cpython-37.pyc differ
diff --git a/unet_pipeline/__pycache__/custom_loss.cpython-37.pyc b/unet_pipeline/__pycache__/custom_loss.cpython-37.pyc
new file mode 100644
index 0000000..7ef7eaa
Binary files /dev/null and b/unet_pipeline/__pycache__/custom_loss.cpython-37.pyc differ
diff --git a/unet_pipeline/__pycache__/dataloader_unet.cpython-37.pyc b/unet_pipeline/__pycache__/dataloader_unet.cpython-37.pyc
new file mode 100644
index 0000000..5b78596
Binary files /dev/null and b/unet_pipeline/__pycache__/dataloader_unet.cpython-37.pyc differ
diff --git a/unet_pipeline/__pycache__/inferences.cpython-37.pyc b/unet_pipeline/__pycache__/inferences.cpython-37.pyc
new file mode 100644
index 0000000..40a3ece
Binary files /dev/null and b/unet_pipeline/__pycache__/inferences.cpython-37.pyc differ
diff --git a/unet_pipeline/__pycache__/unet_parts.cpython-37.pyc b/unet_pipeline/__pycache__/unet_parts.cpython-37.pyc
new file mode 100644
index 0000000..f434f6c
Binary files /dev/null and b/unet_pipeline/__pycache__/unet_parts.cpython-37.pyc differ
diff --git a/unet_pipeline/build_unet_model.py b/unet_pipeline/build_unet_model.py
new file mode 100644
index 0000000..37cf87c
--- /dev/null
+++ b/unet_pipeline/build_unet_model.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from unet_parts import *  # inconv, down, up, outconv for UNet
+
+class SegNet(nn.Module):
+    def __init__(self, input_nbr, label_nbr):
+        super(SegNet, self).__init__()
+        # Encoder (VGG16-style): 3x3 convs with BatchNorm
+        self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
+        self.bn11 = nn.BatchNorm2d(64)
+        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.bn12 = nn.BatchNorm2d(64)
+
+        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
+        self.bn21 = nn.BatchNorm2d(128)
+        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.bn22 = nn.BatchNorm2d(128)
+
+        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
+        self.bn31 = nn.BatchNorm2d(256)
+        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.bn32 = nn.BatchNorm2d(256)
+        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.bn33 = nn.BatchNorm2d(256)
+
+        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
+        self.bn41 = nn.BatchNorm2d(512)
+        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn42 = nn.BatchNorm2d(512)
+        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn43 = nn.BatchNorm2d(512)
+
+        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn51 = nn.BatchNorm2d(512)
+        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn52 = nn.BatchNorm2d(512)
+        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn53 = nn.BatchNorm2d(512)
+
+        # Decoder: mirrors the encoder
+        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn53d = nn.BatchNorm2d(512)
+        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn52d = nn.BatchNorm2d(512)
+        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn51d = nn.BatchNorm2d(512)
+
+        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn43d = nn.BatchNorm2d(512)
+        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
+        self.bn42d = nn.BatchNorm2d(512)
+        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
+        self.bn41d = nn.BatchNorm2d(256)
+
+        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.bn33d = nn.BatchNorm2d(256)
+        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
+        self.bn32d = nn.BatchNorm2d(256)
+        self.conv31d = nn.Conv2d(256, 128, kernel_size=3, padding=1)
+        self.bn31d = nn.BatchNorm2d(128)
+
+        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
+        self.bn22d = nn.BatchNorm2d(128)
+        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+        self.bn21d = nn.BatchNorm2d(64)
+
+        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+        self.bn12d = nn.BatchNorm2d(64)
+        self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)
+        self.Dropout = nn.Dropout(0.5)
+
+    def forward(self, x):
+        # Stage 1
+        x11 = F.relu(self.bn11(self.conv11(x)))
+        x11 = self.Dropout(x11)
+        x12 = F.relu(self.bn12(self.conv12(x11)))
+        x1p, id1 = F.max_pool2d(x12, kernel_size=2, stride=2, return_indices=True)
+
+        # Stage 2
+        x21 = F.relu(self.bn21(self.conv21(x1p)))
+        x22 = F.relu(self.bn22(self.conv22(x21)))
+        x2p, id2 = F.max_pool2d(x22, kernel_size=2, stride=2, return_indices=True)
+
+        # Stage 3
+        x31 = F.relu(self.bn31(self.conv31(x2p)))
+        x31 = self.Dropout(x31)
+        x32 = F.relu(self.bn32(self.conv32(x31)))
+        x33 = F.relu(self.bn33(self.conv33(x32)))
+        x3p, id3 = F.max_pool2d(x33, kernel_size=2, stride=2, return_indices=True)
+
+        # Stage 4
+        x41 = F.relu(self.bn41(self.conv41(x3p)))
+        x42 = F.relu(self.bn42(self.conv42(x41)))
+        x43 = F.relu(self.bn43(self.conv43(x42)))
+        x4p, id4 = F.max_pool2d(x43, kernel_size=2, stride=2, return_indices=True)
+
+        # Stage 5
+        x51 = F.relu(self.bn51(self.conv51(x4p)))
+        x51 = self.Dropout(x51)
+        x52 = F.relu(self.bn52(self.conv52(x51)))
+        x53 = F.relu(self.bn53(self.conv53(x52)))
+        x5p, id5 = F.max_pool2d(x53, kernel_size=2, stride=2, return_indices=True)
+
+        # Stage 5d: unpool with the stored indices, then decode
+        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
+        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
+        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
+        x51d = F.relu(self.bn51d(self.conv51d(x52d)))
+
+        # Stage 4d
+        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
+        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
+        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
+        x41d = F.relu(self.bn41d(self.conv41d(x42d)))
+
+        # Stage 3d
+        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
+        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
+        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
+        x31d = F.relu(self.bn31d(self.conv31d(x32d)))
+
+        # Stage 2d
+        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
+        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
+        x21d = F.relu(self.bn21d(self.conv21d(x22d)))
+
+        # Stage 1d
+        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
+        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
+        x11d = self.conv11d(x12d)
+        return x11d
+
+
+class UNet(nn.Module):
+    def __init__(self, n_channels, n_classes):
+        super(UNet, self).__init__()
+        self.inc = inconv(n_channels, 64)
+        self.down1 = down(64, 128)
+        self.down2 = down(128, 256)
+        self.down3 = down(256, 512)
+        self.down4 = down(512, 512)
+        self.up1 = up(1024, 256)
+        self.up2 = up(512, 128)
+        self.up3 = up(256, 64)
+        self.up4 = up(128, 64)
+        self.outc = outconv(64, n_classes)
+        self.Dropout = nn.Dropout(0.5)
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x3 = self.Dropout(x3)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        # Decoder with skip connections
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        x = self.outc(x)
+        return x
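Since both models rely on repeated 2x pooling, input height and width should divide evenly: by 16 for UNet (four downsamples) and by 32 for SegNet (five). A quick shape check with illustrative sizes; the channel/class counts match train_unet.py below:

# Shape smoke test for both models (sizes illustrative).
import torch
from build_unet_model import UNet, SegNet

x = torch.randn(1, 3, 224, 224)
unet = UNet(n_channels=3, n_classes=1).eval()
segnet = SegNet(input_nbr=3, label_nbr=1).eval()
with torch.no_grad():
    print(unet(x).shape)    # torch.Size([1, 1, 224, 224])
    print(segnet(x).shape)  # torch.Size([1, 1, 224, 224])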
diff --git a/unet_pipeline/custom_loss.py b/unet_pipeline/custom_loss.py
new file mode 100644
index 0000000..486df23
--- /dev/null
+++ b/unet_pipeline/custom_loss.py
@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+class SoftDiceLoss(nn.Module):
+    """
+    Soft Dice Loss
+    """
+    def __init__(self):
+        super(SoftDiceLoss, self).__init__()
+
+    def forward(self, logits, targets):
+        smooth = 1.
+        logits = torch.sigmoid(logits)
+        iflat = logits.view(-1)
+        tflat = targets.view(-1)
+        intersection = (iflat * tflat).sum()
+        return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))
+
+class InvSoftDiceLoss(nn.Module):
+    """
+    Inverted Soft Dice Loss: the same overlap measure computed on the
+    background (1 - p, 1 - t)
+    """
+    def __init__(self):
+        super(InvSoftDiceLoss, self).__init__()
+
+    def forward(self, logits, targets):
+        smooth = 1.
+        logits = torch.sigmoid(logits)
+        iflat = 1 - logits.view(-1)
+        tflat = 1 - targets.view(-1)
+        intersection = (iflat * tflat).sum()
+        return 1 - ((2. * intersection + smooth) / (iflat.sum() + tflat.sum() + smooth))
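To make the Dice computation concrete, a small worked example with hand-checkable numbers (the tensors are illustrative):

# Illustrative check of SoftDiceLoss on a tiny example.
import torch
from custom_loss import SoftDiceLoss

# Logits chosen so sigmoid saturates to ~0 or ~1; prediction overlaps the
# target on 2 of its 3 positives
logits = torch.tensor([[10., 10., 10., -10.]])
target = torch.tensor([[1., 1., 0., 0.]])
loss = SoftDiceLoss()(logits, target)
# intersection ~= 2, |pred| ~= 3, |target| = 2, smooth = 1
# loss ~= 1 - (2*2 + 1) / (3 + 2 + 1) = 1 - 5/6 ~= 0.167
print(loss)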
diff --git a/unet_pipeline/dataloader_unet.py b/unet_pipeline/dataloader_unet.py
new file mode 100644
index 0000000..f28b44d
--- /dev/null
+++ b/unet_pipeline/dataloader_unet.py
@@ -0,0 +1,28 @@
+from torch.utils.data.dataset import Dataset
+from PIL import Image
+from os import listdir
+
+class DataProcessor(Dataset):
+    def __init__(self, imgs_dir, masks_dir, transformation=None):
+        self.imgs_dir = imgs_dir
+        self.masks_dir = masks_dir
+        self.transformations = transformation
+        # Sort both listings so image i lines up with mask i
+        self.imgs_ids = sorted(listdir(imgs_dir))
+        self.mask_ids = sorted(listdir(masks_dir))
+
+    def __getitem__(self, i):
+        img_idx = self.imgs_ids[i]
+        mask_idx = self.mask_ids[i]
+        image_path = self.imgs_dir + img_idx
+        mask_path = self.masks_dir + mask_idx
+        image = Image.open(image_path).convert("RGB")
+        mask = Image.open(mask_path).convert('L')
+        if self.transformations is not None:
+            # Applied in separate calls, so random transforms would
+            # desynchronize the pair; deterministic transforms are safe here
+            image = self.transformations(image)
+            mask = self.transformations(mask)
+        return image, mask
+
+    def __len__(self):
+        return len(self.imgs_ids)
diff --git a/unet_pipeline/default_tranforms.py b/unet_pipeline/default_tranforms.py
new file mode 100644
index 0000000..9b3784e
--- /dev/null
+++ b/unet_pipeline/default_tranforms.py
@@ -0,0 +1,75 @@
+import numpy as np
+from skimage.transform import rotate
+from skimage.util import random_noise, img_as_ubyte
+import random
+import cv2
+import torch
+from skimage.measure import label
+from skimage import measure
+
+def make_binary(mask):
+    coords = np.where(mask != 0)
+    mask[coords] = 1
+    return mask
+
+def assign_labels(mask):
+    # Binarize, then keep only the two largest connected regions and relabel
+    coords = np.where(mask != 0)
+    mask[coords] = 1
+    labeled = label(mask)
+    regions = measure.regionprops(labeled)
+    if len(regions) > 2:
+        areas = [r.area for r in measure.regionprops(labeled)]
+        areas.sort()
+        for region in measure.regionprops(labeled):
+            if region.area < areas[-2]:
+                for coordinates in region.coords:
+                    labeled[coordinates[0], coordinates[1]] = 0
+        coords = np.where(labeled != 0)
+        labeled[coords] = 1
+        new_labeled = label(labeled)
+        return new_labeled
+    else:
+        return labeled
+
+class RandomRotate:
+    def __call__(self, data):
+        rotation_angle = random.randint(-180, 180)
+        img, mask = data['img'], data['mask']
+        img = rotate(img, rotation_angle, mode='reflect').astype(float)
+        mask = assign_labels(rotate(img_as_ubyte(mask), rotation_angle, mode='reflect')).astype(float)
+        return {"img": img, "mask": mask}
+
+class HorizontalFlip:
+    def __call__(self, data):
+        img, mask = data['img'], data['mask']
+        h_img = np.fliplr(img).astype(float)
+        h_mask = np.fliplr(mask).astype(float)
+        return {"img": h_img, "mask": h_mask}
+
+class VerticalFlip:
+    def __call__(self, data):
+        img, mask = data['img'], data['mask']
+        v_img = np.flipud(img).astype(float)
+        v_mask = np.flipud(mask).astype(float)
+        return {"img": v_img, "mask": v_mask}
+
+class RandomNoise:
+    def __call__(self, data):
+        img, mask = data['img'], data['mask']
+        noised_img = random_noise(img).astype(float)
+        mask = mask.astype(float)
+        return {"img": noised_img, "mask": mask}
+
+class RandomBlur(object):
+    def __call__(self, data):
+        img, mask = data['img'], data['mask']
+        blur_factor = random.randrange(1, 10, 2)  # odd kernel size for GaussianBlur
+        blurred_img = cv2.GaussianBlur(img, (blur_factor, blur_factor), 0)
+        return {"img": blurred_img, "mask": mask}
+
+class ToTensor(object):
+    def __call__(self, data):
+        img, mask = data['img'], data['mask']
+        # Convert both arrays so downstream code gets a consistent pair
+        return {"img": torch.from_numpy(img), "mask": torch.from_numpy(mask)}
diff --git a/unet_pipeline/inferences.py b/unet_pipeline/inferences.py
new file mode 100644
index 0000000..728cad8
--- /dev/null
+++ b/unet_pipeline/inferences.py
@@ -0,0 +1,25 @@
+import numpy as np
+
+def _get_iou_vector(target, prediction):
+    """Mean IoU over a batch: target and prediction are (N, ...) arrays."""
+    avg_iou = 0.0
+    for i in range(target.shape[0]):
+        target_i = target[i]
+        prediction_i = prediction[i]
+        intersection = np.logical_and(target_i, prediction_i)
+        union = np.logical_or(target_i, prediction_i)
+        avg_iou += np.sum(intersection) / np.sum(union)
+        # Delete variables to save memory
+        del intersection, union
+    return avg_iou / target.shape[0]
+
+def IoU(y_true, y_pred):
+    """Returns Intersection over Union score for ground truth and predicted masks."""
+    assert y_true.dtype == bool and y_pred.dtype == bool
+    y_true_f = y_true.flatten()
+    y_pred_f = y_pred.flatten()
+    intersection = np.logical_and(y_true_f, y_pred_f).sum()
+    union = np.logical_or(y_true_f, y_pred_f).sum()
+    return (intersection + 1) * 1. / (union + 1)
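A hand-checkable example of the smoothed IoU (values illustrative):

# Illustrative use of IoU from inferences.py on two tiny boolean masks.
import numpy as np
from inferences import IoU

y_true = np.array([[True, True], [False, False]])
y_pred = np.array([[True, False], [True, False]])
# intersection = 1, union = 3 -> (1 + 1) / (3 + 1) = 0.5 with the +1 smoothing
print(IoU(y_true, y_pred))  # 0.5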
diff --git a/unet_pipeline/train_unet.py b/unet_pipeline/train_unet.py
new file mode 100644
index 0000000..5b1cda0
--- /dev/null
+++ b/unet_pipeline/train_unet.py
@@ -0,0 +1,193 @@
+from build_unet_model import *
+from custom_loss import *
+from inferences import _get_iou_vector
+from dataloader_unet import DataProcessor
+import matplotlib.pyplot as plt
+import torchvision.transforms as transforms
+from tqdm import tqdm
+from torch.utils.data import DataLoader, random_split
+import torch
+import torch.nn as nn
+import numpy as np
+
+class TrainModel:
+    def __init__(self, image_channel, num_out_classes, num_epochs, batch_size, weights_path):
+        self.image_channel = image_channel
+        self.num_out_classes = num_out_classes
+        self.epochs = num_epochs
+        self.batch = batch_size
+        self.path_to_weights = weights_path
+        self.device = self._get_device()
+
+    def _get_device(self):
+        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    def _get_default_transforms(self):
+        # Optional augmentation pipeline. Note: random transforms applied
+        # independently to image and mask will desynchronize the pair; see
+        # the paired transforms in default_tranforms.py
+        my_transforms = transforms.Compose([transforms.ToPILImage(), transforms.RandomHorizontalFlip(p=0.5),
+                                            transforms.ColorJitter(brightness=0.5), transforms.RandomCrop((224, 224)),
+                                            transforms.RandomRotation(degrees=45), transforms.RandomVerticalFlip(p=0.2),
+                                            transforms.ToTensor()])
+        return my_transforms
+
+    def start_training(self, path_to_images, path_to_masks, transformation=None, val_percent=0.2, lr_rate=1e-4):
+        if transformation is None:
+            transformations_train = transforms.Compose([transforms.ToTensor()])
+            dataset = DataProcessor(imgs_dir=path_to_images, masks_dir=path_to_masks, transformation=transformations_train)
+        else:
+            dataset = DataProcessor(imgs_dir=path_to_images, masks_dir=path_to_masks, transformation=transformation)
+        n_val = int(len(dataset) * val_percent)
+        n_train = len(dataset) - n_val
+        train, val = random_split(dataset, [n_train, n_val])
+        print("=" * 40)
+        print("Images for Training:", n_train)
+        print("Images for Validation:", n_val)
+        print("=" * 40)
+        trainloader = DataLoader(train, batch_size=self.batch, shuffle=True, drop_last=True)
+        validloader = DataLoader(val, batch_size=self.batch, shuffle=True, drop_last=True)
+
+        # Instantiate model and other parameters
+        model = UNet(self.image_channel, self.num_out_classes).to(self.device)
+        # Load weights if available (note: read the path from the instance,
+        # not from a module-level global)
+        if self.path_to_weights:
+            weights = torch.load(self.path_to_weights, map_location=self.device)
+            model.load_state_dict(weights)
+        # Define three losses
+        criterion1 = nn.BCEWithLogitsLoss().to(self.device)
+        criterion2 = SoftDiceLoss().to(self.device)
+        criterion3 = InvSoftDiceLoss().to(self.device)
+        optimizer = torch.optim.Adam(model.parameters(), lr=lr_rate)
+        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
+        # Variables to track
+        running_bce, running_dice, running_invtdice, running_loss_comb = [], [], [], []
+        val_bce, val_dice, val_invtdice, val_loss_comb, ious = [], [], [], [], []
+        global_avg_iou = 0.0
+
+        for epoch in tqdm(range(self.epochs)):
+            running_bce_loss, running_dice_loss, running_invtdice_loss, running_train_loss = 0.0, 0.0, 0.0, 0.0
+            val_bce_loss, val_dice_loss, val_invtdice_loss, val_loss = 0.0, 0.0, 0.0, 0.0
+            epoch_loss, avg_iou = [], 0.0
+            model.train()
+            for images, masks in tqdm(trainloader):
+                images = images.to(self.device)
+                masks = masks.to(self.device)
+                optimizer.zero_grad()
+                outputs = model(images)
+                bce_loss, dice_loss, invt_dice_loss = criterion1(outputs, masks), criterion2(outputs, masks), criterion3(outputs, masks)
+                loss = bce_loss + dice_loss + invt_dice_loss
+                loss.backward()
+                optimizer.step()
+                # Track per-sample train losses
+                running_bce_loss += float(bce_loss.item()) * images.size(0)
+                running_dice_loss += float(dice_loss.item()) * images.size(0)
+                running_invtdice_loss += float(invt_dice_loss.item()) * images.size(0)
+                running_train_loss += float(loss.item()) * images.size(0)
+                epoch_loss.append(float(loss.item()) * images.size(0))  # For scheduler
+
+            scheduler.step(np.mean(epoch_loss))
+            with torch.no_grad():
+                model.eval()
+                for images, masks in tqdm(validloader):
+                    images = images.to(self.device)
+                    masks = masks.to(self.device)
+                    outputs = model(images)
+                    output_prob = torch.sigmoid(outputs).detach().cpu().numpy()
+                    output_gt = masks.detach().cpu().numpy()
+                    output_prob_thresh = (output_prob > 0.5) * 1
+                    avg_iou += _get_iou_vector(output_gt, output_prob_thresh)
+                    # Calculate losses
+                    bce_loss, dice_loss, invt_dice_loss = criterion1(outputs, masks), criterion2(outputs, masks), criterion3(outputs, masks)
+                    loss = bce_loss + dice_loss + invt_dice_loss
+                    # Track per-sample val losses
+                    val_bce_loss += float(bce_loss.item()) * images.size(0)
+                    val_dice_loss += float(dice_loss.item()) * images.size(0)
+                    val_invtdice_loss += float(invt_dice_loss.item()) * images.size(0)
+                    val_loss += float(loss.item()) * images.size(0)
+
+            # Average the metrics
+            avg_iou_batch = avg_iou / len(validloader)  # IoU
+            # The running sums are per-sample, so divide by the number of
+            # samples (drop_last=True keeps every batch full)
+            n_train_samples = len(trainloader) * self.batch
+            n_val_samples = len(validloader) * self.batch
+            # Average train metrics
+            avg_train_bce_loss = running_bce_loss / n_train_samples
+            avg_train_dice_loss = running_dice_loss / n_train_samples
+            avg_train_invtdice_loss = running_invtdice_loss / n_train_samples
+            avg_train_loss = running_train_loss / n_train_samples
+
+            # Average val metrics
+            avg_val_bce_loss = val_bce_loss / n_val_samples
+            avg_val_dice_loss = val_dice_loss / n_val_samples
+            avg_val_invtdice_loss = val_invtdice_loss / n_val_samples
+            avg_val_loss_combined = val_loss / n_val_samples
+
+            # Append metrics for tracking
+            running_bce.append(avg_train_bce_loss)
+            running_dice.append(avg_train_dice_loss)
+            running_invtdice.append(avg_train_invtdice_loss)
+            running_loss_comb.append(avg_train_loss)
+            val_bce.append(avg_val_bce_loss)
+            val_dice.append(avg_val_dice_loss)
+            val_invtdice.append(avg_val_invtdice_loss)
+            val_loss_comb.append(avg_val_loss_combined)
+            ious.append(avg_iou_batch)
+
+            print("Epoch {}, Training Loss(BCE+Dice+InvtDice): {}, Validation Loss(BCE): {}, "
+                  "Validation Loss(Dice): {}, Validation Loss(InvDice): {}, Average IoU: {}".format(
+                      epoch + 1, avg_train_loss, avg_val_bce_loss, avg_val_dice_loss,
+                      avg_val_invtdice_loss, avg_iou_batch))
+
+            # Save the model whenever the validation IoU improves
+            if avg_iou_batch > global_avg_iou:
+                print("Average mask IoU increased: ({:.6f} --> {:.6f}). Saving model ...".format(global_avg_iou,
+                                                                                                 avg_iou_batch))
+                print("-" * 40)
+                torch.save(model.state_dict(), 'checkpoints/LungMask_UNet.pth')
+                global_avg_iou = avg_iou_batch
+
+        # Save plots
+        # Train loss
+        plt.plot(running_bce, label='BCE loss')
+        plt.plot(running_dice, label='Dice loss')
+        plt.plot(running_invtdice, label='Invt Dice loss')
+        plt.plot(running_loss_comb, label='Combined loss')
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss')
+        plt.legend(frameon=False)
+        plt.savefig('train_losses.png')
+        plt.clf()
+
+        # Val loss
+        plt.plot(val_bce, label='BCE loss')
+        plt.plot(val_dice, label='Dice loss')
+        plt.plot(val_invtdice, label='Invt Dice loss')
+        plt.plot(val_loss_comb, label='Combined loss')
+        plt.xlabel('Epoch')
+        plt.ylabel('Loss')
+        plt.legend(frameon=False)
+        plt.savefig('val_losses.png')
+        plt.clf()
+
+        # IoU
+        plt.plot(ious, label='Mask IoU')
+        plt.xlabel('Epoch')
+        plt.ylabel('IoU')
+        plt.legend(frameon=False)
+        plt.savefig('val_iou.png')
+        plt.clf()
+
+
+if __name__ == "__main__":
+    # Hyper-params
+    img_channel = 3
+    out_num_class = 1
+    num_epochs = 100
+    batches = 5
+    path_to_weights = None
+    train_images = "path_to_train_images\\"
+    train_masks = "path_to_train_masks\\"
+    train_obj = TrainModel(img_channel, out_num_class, num_epochs, batches, path_to_weights)
+    """
+    Note: the validation set is created out of the training images
+    """
+    train_obj.start_training(train_images, train_masks)
diff --git a/unet_pipeline/unet.png b/unet_pipeline/unet.png
new file mode 100644
index 0000000..312c59f
Binary files /dev/null and b/unet_pipeline/unet.png differ
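A caveat on augmentation: DataProcessor applies its transformation to image and mask in separate calls, so random torchvision transforms would desynchronize the pair. The dict-based classes in default_tranforms.py avoid this by transforming both together. A minimal, hypothetical compose wrapper (ComposePaired is not part of this diff) shows the intended usage:

# Hypothetical glue for the paired transforms in default_tranforms.py.
import numpy as np
from default_tranforms import HorizontalFlip, VerticalFlip

class ComposePaired:
    """Chain transforms that take and return an {'img': ..., 'mask': ...} dict."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, data):
        for t in self.transforms:
            data = t(data)
        return data

aug = ComposePaired([HorizontalFlip(), VerticalFlip()])
sample = {"img": np.random.rand(224, 224, 3),
          "mask": np.random.randint(0, 2, (224, 224)).astype(float)}
out = aug(sample)  # img and mask flipped together, so the pair stays aligned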
diff --git a/unet_pipeline/unet_parts.py b/unet_pipeline/unet_parts.py
new file mode 100644
index 0000000..be58b3a
--- /dev/null
+++ b/unet_pipeline/unet_parts.py
@@ -0,0 +1,73 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class double_conv(nn.Module):
+    '''(conv => BN => ReLU) * 2'''
+    def __init__(self, in_ch, out_ch):
+        super(double_conv, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, 3, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(out_ch, out_ch, 3, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class inconv(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(inconv, self).__init__()
+        self.conv = double_conv(in_ch, out_ch)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x
+
+
+class down(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(down, self).__init__()
+        self.mpconv = nn.Sequential(
+            nn.MaxPool2d(2),
+            double_conv(in_ch, out_ch)
+        )
+
+    def forward(self, x):
+        x = self.mpconv(x)
+        return x
+
+
+class up(nn.Module):
+    def __init__(self, in_ch, out_ch, bilinear=True):
+        super(up, self).__init__()
+        if bilinear:
+            self.up = nn.UpsamplingBilinear2d(scale_factor=2)
+        else:
+            self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
+
+        self.conv = double_conv(in_ch, out_ch)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        # Pad the skip connection so it matches the upsampled tensor; for
+        # odd differences the second side takes the extra pixel
+        diffX = x1.size()[2] - x2.size()[2]
+        diffY = x1.size()[3] - x2.size()[3]
+        x2 = F.pad(x2, (diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2))
+        x = torch.cat([x2, x1], dim=1)
+        x = self.conv(x)
+        return x
+
+class outconv(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super(outconv, self).__init__()
+        self.conv = nn.Conv2d(in_ch, out_ch, 1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        return x