Smooth interpolation transform for multi-label images #6960

Closed
dyollb opened this issue Sep 7, 2023 · 3 comments

@dyollb
Contributor

dyollb commented Sep 7, 2023

Is your feature request related to a problem? Please describe.
Upsampling a segmentation with nearest-neighbor interpolation yields non-smooth, staircase-like label boundaries.
This has been reported before, e.g. in #3178.

Describe the solution you'd like
I would like a transform that can resample a segmentation to a different resolution (pixdim), similar to the Spacing transform, but with a controllable parameter sigma that smooths the result.

Describe alternatives you've considered
I currently use ITK/SimpleITK with the LabelGaussian interpolation option, but this becomes quite slow. I would like a solution that runs on GPUs or otherwise performs better.

Using existing MONAI transforms, I could devise a workaround (see test_Existing_Workaround), but it will likely use too much GPU memory, because the full one-hot volume has to be smoothed and resampled at once.

Additional context
I propose opening a PR that adds a transform similar to the draft below:

# postpone annotation evaluation so that `X | Y` type hints also work on Python < 3.10
from __future__ import annotations

from pathlib import Path
from typing import Sequence

import numpy as np
import SimpleITK as sitk
import torch

from monai.data.meta_obj import get_track_meta
from monai.utils import (
    GridSampleMode,
    GridSamplePadMode,
    convert_data_type,
    convert_to_dst_type,
    convert_to_tensor,
)
from monai.config.type_definitions import NdarrayTensor
from monai.networks import one_hot
from monai.networks.layers import GaussianFilter
from monai.transforms import (
    AsDiscrete,
    EnsureType,
    LoadImage,
    GaussianSmooth,
    Spacing,
    SaveImage,
    Transform,
)
from monai.utils.enums import TransformBackends
from torch.testing import assert_close


class LabelGaussianResample(Transform):
    """
    Resample a label map to a new voxel spacing with smooth label boundaries:
    one-hot encode, Gaussian-smooth the class channels, resample, then take the argmax.

    Args:
        pixdim: output voxel spacing. if providing a single number, will use it for the first dimension.
            items of the pixdim sequence map to the spatial dimensions of input image, if length
            of pixdim sequence is longer than image spatial dimensions, will ignore the longer part,
            if shorter, will pad with the last value. For example, for 3D image if pixdim is [1.0, 2.0] it
            will be padded to [1.0, 2.0, 2.0]
            if the components of the `pixdim` are non-positive values, the transform will use the
            corresponding components of the original pixdim, which is computed from the `affine`
            matrix of input image.
        channel_wise: if True, smooth the channels one at a time instead of all at once.
            Defaults to ``True``.
        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
            Defaults to ``None``.
        dim: the dimension to be converted to `num_classes` channels from `1` channel, also used for the
            final argmax, should be non-negative number. Defaults to ``0``.
        keepdim: whether the argmax output has ``dim`` retained or not. Defaults to ``True``.
        sigma: if a list of values, must match the count of spatial dimensions of input data,
            and apply every value in the list to 1 spatial dimension. if only 1 value provided,
            use it for all spatial dimensions.
        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".
            see also :py:meth:`monai.networks.layers.GaussianFilter`.
    """

    backend = [TransformBackends.TORCH]

    def __init__(
        self,
        pixdim: Sequence[float] | float | np.ndarray,
        channel_wise: bool = True,
        to_onehot: int | None = None,
        dim: int = 0,
        keepdim: bool = True,
        sigma: Sequence[float] | float = 1.0,
        approx: str = "erf",
    ) -> None:
        self.channel_wise = channel_wise
        self.to_onehot = to_onehot
        self.dim = dim
        self.keepdim = keepdim
        self.sigma = sigma
        self.approx = approx
        self.resample = Spacing(pixdim=pixdim, mode=GridSampleMode.BILINEAR)

    def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
        img = convert_to_tensor(img, track_meta=get_track_meta())

        # convert to one-hot if necessary
        if self.to_onehot is not None:
            img_t, *_ = convert_data_type(img, torch.Tensor)
            if not isinstance(self.to_onehot, int):
                raise ValueError(
                    f"the number of classes for One-Hot must be an integer, got {type(self.to_onehot)}."
                )
            img_t = one_hot(
                img_t,
                num_classes=self.to_onehot,
                dim=self.dim,
                dtype=torch.float,
            )
        else:
            img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)

        # perform gaussian smoothing
        sigma: Sequence[torch.Tensor] | torch.Tensor
        if isinstance(self.sigma, Sequence):
            sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma]
        else:
            sigma = torch.as_tensor(self.sigma, device=img_t.device)
        gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx)
        
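        if self.channel_wise:
            # GaussianFilter expects [batch, channel, *spatial], so each channel is
            # passed as a [1, 1, *spatial] tensor and written back in place
            for i in range(img_t.shape[0]):
                img_t[i] = gaussian_filter(img_t[i][None, None])[0, 0]
        else:
            img_t = gaussian_filter(img_t.unsqueeze(0)).squeeze(0)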

        img_t, *_ = convert_to_dst_type(img_t, dst=img, dtype=img_t.dtype)

        # resample
        img_t = self.resample(img_t)

        # argmax
        img_t = torch.argmax(
            img_t,
            dim=self.dim,
            keepdim=self.keepdim,
        )

        # convert to output type
        out, *_ = convert_to_dst_type(img_t, dst=img, dtype=img_t.dtype)

        return out


def log_tensor(op, data):
    print(f"{op}: {type(data)} {data.shape}, {data.dtype}")


def test_Existing_Workaround():
    """Implement the same thing with existing transforms."""
    labels = sitk.Image(100, 110, 80, sitk.sitkUInt16)
    labels[:] = 0
    labels[3:40, 80:90, 25:60] = 1
    labels[15:35, 30:95, 50:70] = 2
    labels[30:35, 30:35, 30:35] = 3

    num_classes = 4
    file_path = Path.cwd() / "labels.nii.gz"
    sitk.WriteImage(labels, file_path)

    reader = LoadImage(reader="ITKReader", ensure_channel_first=True, image_only=False)
    to_tensor = EnsureType(dtype=torch.half)
    to_one_hot = AsDiscrete(to_onehot=num_classes, dtype=torch.half)

    labels_tuple = reader(filename=[file_path])

    labels_tensor = to_tensor(labels_tuple[0])
    log_tensor("input", labels_tensor)

    onehot_tensor = to_one_hot(labels_tensor)
    log_tensor("onehot", onehot_tensor)

    smooth = GaussianSmooth(sigma=1.5)
    smooth_onehot_tensor = smooth(onehot_tensor)
    log_tensor("smooth", smooth_onehot_tensor)

    resample = Spacing(pixdim=0.5)
    hires_onehot_tensor = resample(smooth_onehot_tensor)
    log_tensor("hires", hires_onehot_tensor)

    to_argmax = AsDiscrete(argmax=True)
    hires_argmax_tensor = to_argmax(hires_onehot_tensor)
    log_tensor("argmax", hires_argmax_tensor)

    saver = SaveImage(
        resample=False,
        output_postfix="hr",
        output_dir=Path.cwd(),
        separate_folder=False,
    )
    saver(hires_argmax_tensor)


def test_LabelGaussianResample():
    labels = sitk.Image(100, 110, 80, sitk.sitkUInt16)
    labels[:] = 0
    labels[3:40, 80:90, 25:60] = 1
    labels[15:35, 30:95, 50:70] = 2
    labels[30:35, 30:35, 30:35] = 3

    labels.SetSpacing([0.9, 0.95, 1.0])
    labels.SetOrigin([5.0, 6.0, -7.0])
    labels.SetDirection([0, 0, -1.0, 0.99237, -0.12324, 0, 0.12324, 0.99237, 0])

    num_classes = 4
    file_path = Path.cwd() / "labels.nii.gz"
    sitk.WriteImage(labels, file_path)

    reader = LoadImage(reader="ITKReader", image_only=False, ensure_channel_first=True)
    to_tensor = EnsureType(dtype=torch.half, track_meta=True)

    resample = LabelGaussianResample(pixdim=0.5, to_onehot=num_classes)

    labels_tuple = reader(filename=[file_path])

    labels_tensor = to_tensor(labels_tuple[0])
    print(labels_tensor.affine)

    output = resample(labels_tensor)
    print(output.affine)

    saver = SaveImage(
        writer="ITKWriter",
        resample=False,
        output_postfix="hr",
        output_dir=Path.cwd(),
        separate_folder=False,
    )
    saver(output)
@wyli
Contributor

wyli commented Sep 8, 2023

Thanks, this makes sense; labelling it as contribution wanted. I think the logic could be optimized to run the smoothing only when downsampling, with an automatic sigma if none is specified:

anti_aliasing_sigma = torch.maximum(torch.zeros(factors.shape), (factors - 1) / 2).tolist()
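
As a rough illustration of the automatic sigma suggested above, here is a minimal pure-Python sketch. It assumes that `factors` in the one-liner is the per-axis downsampling ratio (target spacing divided by source spacing), which the comment does not define; the helper name `auto_sigma` and its arguments are hypothetical and not part of MONAI.

```python
from typing import List, Sequence


def auto_sigma(src_pixdim: Sequence[float], dst_pixdim: Sequence[float]) -> List[float]:
    """Per-axis sigma: (factor - 1) / 2 when downsampling, 0 otherwise.

    `factor` is assumed to be dst / src, so a value > 1 means the output
    grid is coarser than the input grid along that axis.
    """
    sigmas = []
    for src, dst in zip(src_pixdim, dst_pixdim):
        factor = dst / src
        # same rule as max(0, (factors - 1) / 2), written per axis
        sigmas.append(max(0.0, (factor - 1.0) / 2.0))
    return sigmas


# downsample along the first axis, upsample along the second, keep the third
sigmas = auto_sigma([1.0, 1.0, 1.0], [2.0, 0.5, 1.0])
```

With this rule, an axis that is upsampled or left unchanged gets a sigma of zero, so smoothing is only applied along axes where resolution is actually reduced.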

@dyollb
Contributor Author

dyollb commented Sep 8, 2023

OK, thanks @wyli, I will work on it. I also need to change the logic to process channels one by one (if channel_wise=True); otherwise memory use will quickly exceed what is available on a GPU.
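
One possible way to process channels one by one is a running argmax: smooth and resample a single class mask at a time and keep only the best score and label per voxel, so the full one-hot volume never exists at the target resolution. The sketch below is an assumption about how this could look, not the planned implementation. It uses plain PyTorch (`conv3d` with a sampled Gaussian kernel and trilinear `interpolate`) rather than MONAI's `GaussianFilter` and `Spacing`, and the function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_1d(sigma: float, truncated: float = 4.0) -> torch.Tensor:
    """Sampled, normalized 1D Gaussian kernel with an odd number of taps."""
    radius = max(int(truncated * sigma + 0.5), 1)
    x = torch.arange(-radius, radius + 1, dtype=torch.float32)
    kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    return kernel / kernel.sum()


def smooth_3d(vol: torch.Tensor, sigma: float) -> torch.Tensor:
    """Separable Gaussian smoothing of a [1, 1, D, H, W] volume (zero padding)."""
    k = gaussian_kernel_1d(sigma).to(vol)
    pad = k.numel() // 2
    vol = F.conv3d(vol, k.view(1, 1, -1, 1, 1), padding=(pad, 0, 0))
    vol = F.conv3d(vol, k.view(1, 1, 1, -1, 1), padding=(0, pad, 0))
    return F.conv3d(vol, k.view(1, 1, 1, 1, -1), padding=(0, 0, pad))


def streaming_label_resample(
    labels: torch.Tensor, num_classes: int, scale_factor: float, sigma: float = 1.0
) -> torch.Tensor:
    """Resample an integer [D, H, W] label map, one class at a time."""
    best_score = None
    best_label = None
    for c in range(num_classes):
        # binary mask of the current class as a [1, 1, D, H, W] float volume
        mask = (labels == c).to(torch.float32)[None, None]
        score = F.interpolate(
            smooth_3d(mask, sigma), scale_factor=scale_factor, mode="trilinear", align_corners=False
        )[0, 0]
        if best_score is None:
            best_score = score
            best_label = torch.zeros_like(score, dtype=torch.long)
        else:
            # keep the label with the highest smoothed score so far
            better = score > best_score
            best_label[better] = c
            best_score = torch.where(better, score, best_score)
    return best_label


labels = torch.zeros(8, 8, 8, dtype=torch.long)
labels[2:6, 2:6, 2:6] = 1
hires = streaming_label_resample(labels, num_classes=2, scale_factor=2.0, sigma=1.0)
```

Peak memory here is roughly two float volumes plus one integer volume at the target resolution, independent of the number of classes, at the cost of one smoothing and resampling pass per class.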

@vikashg

vikashg commented Jan 4, 2024

Closing because of inactivity and the very specific use case. Please reopen if needed.

@vikashg vikashg closed this as completed Jan 4, 2024