From 80c5ab90dbfad6744affa47f17f88048cff75037 Mon Sep 17 00:00:00 2001 From: Chris Yeung Date: Fri, 12 May 2023 17:32:47 -0400 Subject: [PATCH] Re #46: add support for multi-class segmentation --- UltrasoundSegmentation/train.py | 89 +++++++++++++++--------- UltrasoundSegmentation/train_config.yaml | 11 +-- 2 files changed, 60 insertions(+), 40 deletions(-) diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index 8d4ae6e..337b099 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -23,7 +23,7 @@ from datetime import datetime from torch.nn import BCEWithLogitsLoss from torch.optim import Adam -from torch.optim.lr_scheduler import StepLR +from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR from monai.data import DataLoader from monai.data.utils import decollate_batch @@ -127,32 +127,54 @@ def main(args): if config["transforms"]["general"]: for tfm in config["transforms"]["general"]: try: - train_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"])) - val_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"])) + train_transform_list.append( + getattr(transforms, tfm["name"])(**tfm["params"]) + ) + val_transform_list.append( + getattr(transforms, tfm["name"])(**tfm["params"]) + ) except KeyError: # Apply transform to both image and label by default - train_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"])) - val_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"])) + train_transform_list.append( + getattr(transforms, tfm["name"])(keys=["image", "label"]) + ) + val_transform_list.append( + getattr(transforms, tfm["name"])(keys=["image", "label"]) + ) if config["transforms"]["train"]: for tfm in config["transforms"]["train"]: try: - train_transform_list.append(getattr(transforms, tfm["name"])(**tfm["params"])) + train_transform_list.append( + getattr(transforms, tfm["name"])(**tfm["params"]) + ) except KeyError: - train_transform_list.append(getattr(transforms, tfm["name"])(keys=["image", "label"])) + train_transform_list.append( + getattr(transforms, tfm["name"])(keys=["image", "label"]) + ) train_transform = Compose(train_transform_list) val_transform = Compose(val_transform_list) # Create dataloaders using UltrasoundDataset train_dataset = UltrasoundDataset(args.train_data_folder, transform=train_transform) - train_dataloader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False, generator=g) + train_dataloader = DataLoader( + train_dataset, + batch_size=config["batch_size"], + shuffle=True, + generator=g + ) val_dataset = UltrasoundDataset(args.val_data_folder, transform=val_transform) - val_dataloader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, generator=g) + val_dataloader = DataLoader( + val_dataset, + batch_size=config["batch_size"], + shuffle=False, + generator=g + ) # Construct model if config["model_name"] == "monai_unet": model = monai.networks.nets.UNet( spatial_dims=2, - in_channels=1, - out_channels=1, + in_channels=config["in_channels"], + out_channels=config["out_channels"], channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, @@ -162,7 +184,7 @@ def main(args): # Construct loss function if config["loss_function"] == "monai_dice": - loss_function = monai.losses.DiceLoss(sigmoid=True) + loss_function = monai.losses.DiceLoss(to_onehot_y=True, softmax=True) else: loss_function = BCEWithLogitsLoss() @@ -171,7 +193,7 @@ def main(args): # from torchinfo import summary # summary(model, input_size=(1, config["in_channels"], 128, 128)) - optimizer = Adam(model.parameters(), config["learning_rate"]) + optimizer = Adam(model.parameters(), config["learning_rate"], weight_decay=config["weight_decay"]) # Set up learning rate decay try: @@ -184,19 +206,20 @@ def main(args): learning_rate_decay_factor = 1.0 # No decay logging.info(f"Learning rate decay frequency: {learning_rate_decay_frequency}") logging.info(f"Learning rate decay factor: {learning_rate_decay_factor}") + scheduler = StepLR(optimizer, step_size=learning_rate_decay_frequency, gamma=learning_rate_decay_factor) # Metrics - dice_metric = DiceMetric(include_background=True, reduction="mean") - iou_metric = MeanIoU(include_background=True, reduction="mean") + dice_metric = DiceMetric(include_background=False, reduction="mean") + iou_metric = MeanIoU(include_background=False, reduction="mean") confusion_matrix_metric = ConfusionMatrixMetric( - include_background=True, + include_background=False, metric_name=["accuracy", "precision", "sensitivity", "specificity", "f1_score"], reduction="mean" ) - post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) + post_pred = Compose([AsDiscrete(argmax=True, to_onehot=config["out_channels"])]) + post_label = Compose([AsDiscrete(to_onehot=config["out_channels"])]) # Train model - scheduler = StepLR(optimizer, step_size=learning_rate_decay_frequency, gamma=learning_rate_decay_factor) for epoch in range(config["num_epochs"]): logging.info(f"Epoch {epoch+1}/{config['num_epochs']}") model.train() @@ -230,10 +253,11 @@ def main(args): val_loss += loss.item() # Compute metrics for current iteration - val_post_preds = [post_pred(i) for i in decollate_batch(val_outputs)] - dice_metric(y_pred=val_post_preds, y=val_labels) - iou_metric(y_pred=val_post_preds, y=val_labels) - confusion_matrix_metric(y_pred=val_post_preds, y=val_labels) + val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)] + val_labels = [post_label(i) for i in decollate_batch(val_labels)] + dice_metric(y_pred=val_outputs, y=val_labels) + iou_metric(y_pred=val_outputs, y=val_labels) + confusion_matrix_metric(y_pred=val_outputs, y=val_labels) val_loss /= val_step dice = dice_metric.aggregate().item() @@ -261,14 +285,14 @@ def main(args): inputs = torch.stack([val_dataset[i]["image"] for i in sample]) labels = torch.stack([val_dataset[i]["label"] for i in sample]) with torch.no_grad(): - logits = model(inputs.to(device=device)) - outputs = torch.sigmoid(logits) + outputs = model(inputs.to(device=device)) fig, axes = plt.subplots(3, 3, figsize=(9, 9)) for i in range(3): axes[i, 0].imshow(inputs[i, 0, :, :], cmap="gray") axes[i, 1].imshow(labels[i].squeeze(), cmap="gray") - im = axes[i, 2].imshow(outputs[i].squeeze().cpu().detach().numpy(), vmin=0, vmax=1, cmap="viridis") + # im = axes[i, 2].imshow(torch.argmax(outputs[i], dim=0).detach().cpu(), vmin=0, vmax=1, cmap="viridis") + im = axes[i, 2].imshow(torch.argmax(outputs[i], dim=0).detach().cpu(), cmap="gray") # Create an additional axis for the colorbar cax = fig.add_axes([axes[i, 2].get_position().x1 + 0.01, @@ -276,6 +300,11 @@ def main(args): 0.02, axes[i, 2].get_position().height]) fig.colorbar(im, cax=cax) + + # Log current learning rate + for param_group in optimizer.param_groups: + current_lr = param_group["lr"] + logging.info(f"Current learning rate: {current_lr}") # Log metrics and examples together to maintain global step == epoch run.log({ @@ -288,23 +317,19 @@ def main(args): "sensitivity": sen, "specificity": spe, "f1_score": f1, + "lr": current_lr, "examples": wandb.Image(fig)}) plt.close(fig) - # Log current learning rate - for param_group in optimizer.param_groups: - current_lr = param_group["lr"] - logging.info(f"Current learning rate: {current_lr}") - # Save model after every Nth epoch as specified in the config file if config["save_frequency"]>0 and (epoch + 1) % config["save_frequency"] == 0: torch.save(model.state_dict(), os.path.join(args.output_dir, f"model_{epoch+1}.pt")) # Log final metrics metric_table = wandb.Table( - columns=["acc", "pre", "sen", "spe", "f1", "dice", "iou"], - data=[[acc, pre, sen, spe, f1, dice, iou]] + columns=["run", "acc", "pre", "sen", "spe", "f1", "dice", "iou"], + data=[[experiment_name, acc, pre, sen, spe, f1, dice, iou]] ) run.log({"metrics": metric_table}) diff --git a/UltrasoundSegmentation/train_config.yaml b/UltrasoundSegmentation/train_config.yaml index 20f35be..21f2437 100644 --- a/UltrasoundSegmentation/train_config.yaml +++ b/UltrasoundSegmentation/train_config.yaml @@ -3,13 +3,14 @@ model_name: "monai_unet" loss_function: "monai_dice" in_channels: !!int 1 -out_channels: !!int 1 +out_channels: !!int 16 save_frequency: !!int 0 -num_epochs: !!int 40 +num_epochs: !!int 100 batch_size: !!int 32 learning_rate: !!float 0.004 learning_rate_decay_factor: !!float 0.5 learning_rate_decay_frequency: !!int 10 +weight_decay: 0.00005 seed: !!int 42 transforms: general: @@ -30,9 +31,3 @@ transforms: keys: ["image", "label"] spatial_size: [128, 128] train: - - name: "RandGaussianNoised" - params: - keys: ["image"] - prob: 0.15 - mean: 0.0 - std: 0.1