Add Probabilistic UNet logger #358

Draft: wants to merge 5 commits into main
104 changes: 55 additions & 49 deletions torch_em/model/probabilistic_unet.py
@@ -1,6 +1,8 @@
# This code is based on the original TensorFlow implementation: https://github.com/SimonKohl/probabilistic_unet
# The below implementation is from: https://github.com/stefanknegt/Probabilistic-Unet-Pytorch

from typing import Union

import numpy as np

import torch
@@ -21,15 +23,15 @@ def truncated_normal_(tensor, mean=0, std=1):


def init_weights(m):
if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
if isinstance(m, Union[nn.Conv2d, nn.ConvTranspose2d]):
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
# nn.init.normal_(m.weight, std=0.001)
# nn.init.normal_(m.bias, std=0.001)
truncated_normal_(m.bias, mean=0, std=0.001)


def init_weights_orthogonal_normal(m):
if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
if isinstance(m, Union[nn.Conv2d, nn.ConvTranspose2d]):
nn.init.orthogonal_(m.weight)
truncated_normal_(m.bias, mean=0, std=0.001)
# nn.init.normal_(m.bias, std=0.001)
@@ -126,13 +128,13 @@ def __init__(
self.name = 'Prior'

self.encoder = Encoder(
self.input_channels,
self.num_filters,
self.no_convs_per_block,
initializers,
posterior=self.posterior,
num_classes=num_classes
)
self.input_channels,
self.num_filters,
self.no_convs_per_block,
initializers,
posterior=self.posterior,
num_classes=num_classes
)

self.conv_layer = nn.Conv2d(num_filters[-1], 2 * self.latent_dim, (1, 1), stride=1)
self.show_img = 0
@@ -198,7 +200,6 @@ def __init__(
num_filters,
latent_dim,
num_output_channels,
num_classes,
no_convs_fcomb,
initializers,
use_tile=True,
@@ -207,8 +208,7 @@

super().__init__()

self.num_channels = num_output_channels
self.num_classes = num_classes
self.num_output_channels = num_output_channels
self.channel_axis = 1
self.spatial_axes = [2, 3]
self.num_filters = num_filters
@@ -235,7 +235,7 @@ def __init__(

self.layers = nn.Sequential(*layers)

self.last_layer = nn.Conv2d(self.num_filters[0], self.num_classes, kernel_size=1)
self.last_layer = nn.Conv2d(self.num_filters[0], self.num_output_channels, kernel_size=1)

if initializers['w'] == 'orthogonal':
self.layers.apply(init_weights_orthogonal_normal)
@@ -254,8 +254,8 @@ def tile(self, a, dim, n_tile):
repeat_idx[dim] = n_tile
a = a.repeat(*(repeat_idx))
order_index = torch.LongTensor(
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
).to(self.device)
np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
).to(self.device)
return torch.index_select(a, dim, order_index)

def forward(self, feature_map, z):
@@ -282,16 +282,18 @@ class ProbabilisticUNet(nn.Module):

The following elements are initialized to get our desired network:
input_channels: the number of channels in the image (1 for grayscale and 3 for RGB)
num_classes: the number of classes to predict
output_channels: the number of channels to predict.
num_classes: the number of classes (raters) for the posterior
Collaborator Author (inline comment): num_raters
num_filters: a list with the number of filters per layer
latent_dim: dimension of the latent space
no_cons_per_block: no convs per block in the (convolutional) encoder of prior and posterior
no_convs_per_block: no convs per block in the (convolutional) encoder of prior and posterior
beta: KL and reconstruction loss are weighted using a KL weighting factor (β)
consensus_masking: activates consensus masking in the reconstruction loss
rl_swap: switches the reconstruction loss to dice loss from the default (binary cross-entropy loss)

Parameters:
input_channels [int] - (default: 1)
output_channels [int] - (default: 1)
num_classes [int] - (default: 1)
num_filters [list] - (default: [32, 64, 128, 192])
latent_dim [int] - (default: 6)
@@ -305,6 +307,7 @@ class ProbabilisticUNet(nn.Module):
def __init__(
self,
input_channels=1,
output_channels=1,
num_classes=1,
num_filters=[32, 64, 128, 192],
latent_dim=6,
@@ -318,6 +321,7 @@ def __init__(
super().__init__()

self.input_channels = input_channels
self.output_channels = output_channels
self.num_classes = num_classes
self.num_filters = num_filters
self.latent_dim = latent_dim
@@ -335,40 +339,39 @@ def __init__(
self.device = device

self.unet = UNet2d(
in_channels=self.input_channels,
out_channels=None,
depth=len(self.num_filters),
initial_features=num_filters[0]
).to(self.device)
in_channels=self.input_channels,
out_channels=None,
depth=len(self.num_filters),
initial_features=num_filters[0]
).to(self.device)

self.prior = AxisAlignedConvGaussian(
self.input_channels,
self.num_filters,
self.no_convs_per_block,
self.latent_dim,
self.initializers
).to(self.device)
self.input_channels,
self.num_filters,
self.no_convs_per_block,
self.latent_dim,
self.initializers
).to(self.device)

self.posterior = AxisAlignedConvGaussian(
self.input_channels,
self.num_filters,
self.no_convs_per_block,
self.latent_dim,
self.initializers,
posterior=True,
num_classes=num_classes
).to(self.device)
self.input_channels,
self.num_filters,
self.no_convs_per_block,
self.latent_dim,
self.initializers,
posterior=True,
num_classes=num_classes
).to(self.device)

self.fcomb = Fcomb(
self.num_filters,
self.latent_dim,
self.input_channels,
self.num_classes,
self.no_convs_fcomb,
{'w': 'orthogonal', 'b': 'normal'},
use_tile=True,
device=self.device
).to(self.device)
self.num_filters,
self.latent_dim,
self.output_channels,
self.no_convs_fcomb,
{'w': 'orthogonal', 'b': 'normal'},
use_tile=True,
device=self.device
).to(self.device)

def _check_shape(self, patch):
spatial_shape = tuple(patch.shape)[2:]
@@ -449,12 +452,15 @@ def elbo(self, segm, consm=None, analytic_kl=True, reconstruct_posterior_mean=False):
z_posterior = self.posterior_latent_space.rsample()

self.kl = torch.mean(
self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
)
self.kl_divergence(analytic=analytic_kl, calculate_posterior=False, z_posterior=z_posterior)
)

# Here we use the posterior sample sampled above
self.reconstruction = self.reconstruct(use_posterior_mean=reconstruct_posterior_mean,
calculate_posterior=False, z_posterior=z_posterior)
self.reconstruction = self.reconstruct(
use_posterior_mean=reconstruct_posterior_mean,
calculate_posterior=False,
z_posterior=z_posterior
)

if self.consensus_masking is True and consm is not None:
reconstruction_loss = criterion(self.reconstruction * consm, segm * consm)
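
For reference, a minimal usage sketch of the updated constructor, with output_channels (prediction channels) now separate from num_classes (the number of raters seen by the posterior). The shapes, the number of raters, and the defaults of the omitted keyword arguments are illustrative assumptions, not part of this diff.

import torch
from torch_em.model.probabilistic_unet import ProbabilisticUNet

# Hypothetical setup: one prediction channel, annotations from four raters.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ProbabilisticUNet(
    input_channels=1,
    output_channels=1,
    num_classes=4,
    num_filters=[32, 64, 128, 192],
    latent_dim=6,
    device=device,
)

x = torch.rand(2, 1, 256, 256, device=device)   # image batch
y = torch.rand(2, 4, 256, 256, device=device)   # one label channel per rater
model.forward(x, y)                              # encodes prior and posterior
elbo = model.elbo(y[:, :1, ...])                 # ELBO against the first rater only
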
2 changes: 1 addition & 1 deletion torch_em/self_training/__init__.py
@@ -1,4 +1,4 @@
from .logger import SelfTrainingTensorboardLogger
from .logger import SelfTrainingTensorboardLogger, ProbabilisticUNetTrainerLogger
from .loss import DefaultSelfTrainingLoss, DefaultSelfTrainingLossAndMetric, ProbabilisticUNetLoss, \
ProbabilisticUNetLossAndMetric
from .mean_teacher import MeanTeacherTrainer
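
With the updated __init__.py, the new logger is exported alongside the existing self-training utilities; assuming the package is installed from this branch, the import reads:

from torch_em.self_training import (
    ProbabilisticUNetTrainerLogger,
    ProbabilisticUNetLoss,
    ProbabilisticUNetLossAndMetric,
)
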
29 changes: 29 additions & 0 deletions torch_em/self_training/logger.py
@@ -82,3 +82,32 @@ def log_validation(self, step, metric, loss, xt, xt1, xt2, y, z, gt, samples, gt
self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
if gt_metric is not None:
self.tb.add_scalar(tag="validation/gt_metric", scalar_value=gt_metric, global_step=step)


class ProbabilisticUNetTrainerLogger(torch_em.trainer.logger_base.TorchEmLogger):
def __init__(self, trainer, save_root, **unused_kwargs):
super().__init__(trainer, save_root)
self.log_dir = f"./logs/{trainer.name}" if save_root is None else\
os.path.join(save_root, "logs", trainer.name)
os.makedirs(self.log_dir, exist_ok=True)

self.tb = torch.utils.tensorboard.SummaryWriter(self.log_dir)
self.log_image_interval = trainer.log_image_interval

def add_image(self, x, y, samples, name, step):
# NOTE: we only show the first tensor per batch for all images
self.tb.add_image(tag=f"{name}/input", img_tensor=x[0], global_step=step)
self.tb.add_image(tag=f"{name}/target", img_tensor=y[0], global_step=step)
sample_grid = make_grid([sample[0] for sample in samples], nrow=4, padding=4)
self.tb.add_image(tag=f"{name}/samples", img_tensor=sample_grid, global_step=step)

def log_train(self, step, loss, lr, x, y, samples):
self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step)
self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step)
if step % self.log_image_interval == 0:
self.add_image(x, y, samples, "train", step)

def log_validation(self, step, metric, loss, x, y, samples):
self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step)
self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step)
self.add_image(x, y, samples, "validation", step)
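
The logging itself relies on standard TensorBoard calls; below is a small self-contained sketch of the same grid construction that add_image performs, with random tensors standing in for input, target, and prior samples. Names, shapes, and the log directory are assumptions for illustration, and make_grid is taken here from torchvision.utils (the corresponding import in logger.py is not shown in this hunk).

import torch
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

writer = SummaryWriter("./logs/probabilistic_unet_demo")    # hypothetical log directory

# Dummy batch; like the logger, we only visualize the first element of the batch.
x = torch.rand(4, 1, 128, 128)                              # input images
y = torch.rand(4, 1, 128, 128)                              # targets
samples = [torch.rand(4, 1, 128, 128) for _ in range(8)]    # prior samples

step = 0
writer.add_image(tag="train/input", img_tensor=x[0], global_step=step)
writer.add_image(tag="train/target", img_tensor=y[0], global_step=step)
sample_grid = make_grid([s[0] for s in samples], nrow=4, padding=4)
writer.add_image(tag="train/samples", img_tensor=sample_grid, global_step=step)
writer.close()
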
23 changes: 21 additions & 2 deletions torch_em/self_training/loss.py
@@ -78,13 +78,19 @@ class ProbabilisticUNetLoss(nn.Module):
# TODO : Implement a generic utility function for all Probabilistic UNet schemes (ELBO, GECO, etc.)
loss [nn.Module] - the loss function to be used. (default: None)
"""
def __init__(self, loss=None):
def __init__(self, loss=None, output_channels=None):
super().__init__()
self.loss = loss
self.output_channels = output_channels

def __call__(self, model, input_, labels, label_filter=None):
model.forward(input_, labels)

# NOTE: 'output_channels' ensures to compute loss over only one label set (in case of multi-rater annotation).
# In the current experiment, we consider the first label out of the bunch.
if self.output_channels is not None:
labels = labels[:, :self.output_channels, ...]

if self.loss is None:
elbo = model.elbo(labels, label_filter)
reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
@@ -105,16 +111,29 @@ class ProbabilisticUNetLossAndMetric(nn.Module):
activation [nn.Module, callable] - the activation function to be applied to the prediction
before evaluating the average predictions. (default: None)
"""
def __init__(self, loss=None, metric=DiceLoss(), activation=torch.nn.Sigmoid(), prior_samples=16):
def __init__(
self,
loss=None,
metric=DiceLoss(),
activation=torch.nn.Sigmoid(),
prior_samples=16,
output_channels=None,
):
super().__init__()
self.activation = activation
self.metric = metric
self.loss = loss
self.prior_samples = prior_samples
self.output_channels = output_channels

def __call__(self, model, input_, labels, label_filter=None):
model.forward(input_, labels)

# NOTE: 'output_channels' ensures to compute loss over only one label set (in case of multi-rater annotation).
# In the current experiment, we consider the first label out of the bunch.
if self.output_channels is not None:
labels = labels[:, :self.output_channels, ...]

if self.loss is None:
elbo = model.elbo(labels, label_filter)
reg_loss = l2_regularisation(model.posterior) + l2_regularisation(model.prior) + \
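
The output_channels slicing in both loss classes simply restricts a multi-rater label stack to its leading channel(s) before the ELBO is computed; a minimal illustration with assumed shapes:

import torch

output_channels = 1
labels = torch.rand(2, 4, 256, 256)   # batch of 2, four raters stacked along the channel axis

# Same slicing as in ProbabilisticUNetLoss / ProbabilisticUNetLossAndMetric:
# only the first label set enters the reconstruction term.
labels = labels[:, :output_channels, ...]
print(labels.shape)                   # torch.Size([2, 1, 256, 256])
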
15 changes: 9 additions & 6 deletions torch_em/self_training/probabilistic_unet_trainer.py
@@ -21,12 +21,12 @@ class ProbabilisticUNetTrainer(torch_em.trainer.DefaultTrainer):
"""

def __init__(
self,
clipping_value=None,
prior_samples=16,
loss=None,
loss_and_metric=None,
**kwargs
self,
clipping_value=None,
prior_samples=16,
loss=None,
loss_and_metric=None,
**kwargs
):
super().__init__(loss=loss, metric=DummyLoss(), **kwargs)
assert loss is not None and loss_and_metric is not None
@@ -76,7 +76,9 @@ def _train_epoch_impl(self, progress, forward_context, backprop):

if self.logger is not None:
lr = [pm["lr"] for pm in self.optimizer.param_groups][0]
# We only sample if we log images in this iteration.
samples = self._sample() if self._iteration % self.log_image_interval == 0 else None
y = y[:, :self.model.output_channels, ...]
self.logger.log_train(self._iteration, loss, lr, x, y, samples)

self._iteration += 1
@@ -109,6 +111,7 @@ def _validate_impl(self, forward_context):

if self.logger is not None:
samples = self._sample()
y = y[:, :self.model.output_channels, ...]
self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, samples)

return metric_val