Is your feature request related to a problem? Please describe.
Upsampling a segmentation using nearest neighbor interpolation yields non-smooth labels.
This issue was already reported e.g. here #3178
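To illustrate (a minimal sketch; the shapes, spacing values, and variable names are made up for the example): nearest-neighbor upsampling only replicates voxels, so the low-resolution stair-steps of the label boundaries survive at the higher resolution.

import torch
from monai.data import MetaTensor
from monai.transforms import Spacing

# a 2 mm isotropic label map containing a single cubic structure
affine = torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0]))
labels = MetaTensor(torch.zeros(1, 20, 20, 20), affine=affine)
labels[0, 5:15, 5:15, 5:15] = 1

# nearest-neighbor resampling to 0.5 mm keeps the labels discrete,
# but the class boundaries remain blocky instead of smooth
upsampled = Spacing(pixdim=0.5, mode="nearest")(labels)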
Describe the solution you'd like
I would like a transform that can resample a segmentation to a different resolution (`pixdim`), similar to the `Spacing` transform, but with a controllable `sigma` parameter to smooth the result.
Describe alternatives you've considered
I currently use ITK/SimpleITK with the LabelGaussian interpolation option, but this gets quite slow. I would like a solution that runs on the GPU and performs better.
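For reference, the CPU-bound approach being replaced looks roughly like this (a sketch; the target spacing is arbitrary):

import SimpleITK as sitk

labels = sitk.ReadImage("labels.nii.gz")
new_spacing = [0.5, 0.5, 0.5]
new_size = [
    int(round(size * spacing / target))
    for size, spacing, target in zip(labels.GetSize(), labels.GetSpacing(), new_spacing)
]
# sitkLabelGaussian interpolates each label separately, which avoids the
# stair-step artifact but is slow on large volumes with many labels
resampled = sitk.Resample(
    labels,
    new_size,
    sitk.Transform(),
    sitk.sitkLabelGaussian,
    labels.GetOrigin(),
    new_spacing,
    labels.GetDirection(),
    0,
    labels.GetPixelID(),
)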
Using existing MONAI transforms, I could devise a workaround (see `test_Existing_Workaround` below), but it will likely use too much GPU memory.
Additional context
I propose to create a PR adding a transform similar to the draft below:
from pathlib import Path
from typing import Sequence

import numpy as np
import SimpleITK as sitk
import torch
from torch.testing import assert_close

from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import get_track_meta
from monai.networks import one_hot
from monai.networks.layers import GaussianFilter
from monai.transforms import (
    AsDiscrete,
    EnsureType,
    GaussianSmooth,
    LoadImage,
    SaveImage,
    Spacing,
    Transform,
)
from monai.utils import (
    GridSampleMode,
    GridSamplePadMode,
    convert_data_type,
    convert_to_dst_type,
    convert_to_tensor,
)
from monai.utils.enums import TransformBackends


class LabelGaussianResample(Transform):
""" Args: pixdim: output voxel spacing. if providing a single number, will use it for the first dimension. items of the pixdim sequence map to the spatial dimensions of input image, if length of pixdim sequence is longer than image spatial dimensions, will ignore the longer part, if shorter, will pad with the last value. For example, for 3D image if pixdim is [1.0, 2.0] it will be padded to [1.0, 2.0, 2.0] if the components of the `pixdim` are non-positive values, the transform will use the corresponding components of the original pixdim, which is computed from the `affine` matrix of input image. to_onehot: if not None, convert input data into the one-hot format with specified number of classes. Defaults to ``None``. dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number. keepdim: whether the output tensor has dim retained or not. Ignored if dim=None kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`. currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored. These default to ``0``, ``True``, ``torch.float`` respectively. sigma: if a list of values, must match the count of spatial dimensions of input data, and apply every value in the list to 1 spatial dimension. if only 1 value provided, use it for all spatial dimensions. approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace". see also :py:meth:`monai.networks.layers.GaussianFilter`. """backend= [TransformBackends.TORCH]
    def __init__(
        self,
        pixdim: Sequence[float] | float | np.ndarray,
        channel_wise: bool = True,
        to_onehot: int | None = None,
        dim: int = 0,
        keepdim: bool = True,
        sigma: Sequence[float] | float = 1.0,
        approx: str = "erf",
    ) -> None:
        self.channel_wise = channel_wise
        self.to_onehot = to_onehot
        self.dim = dim
        self.keepdim = keepdim
        self.sigma = sigma
        self.approx = approx
        # bilinear resampling of the smoothed (one-hot) channels; the argmax
        # in __call__ restores discrete labels afterwards
        self.resample = Spacing(pixdim=pixdim, mode=GridSampleMode.BILINEAR)
    def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
        img = convert_to_tensor(img, track_meta=get_track_meta())
        # convert to one-hot if necessary
        if self.to_onehot is not None:
            img_t, *_ = convert_data_type(img, torch.Tensor)
            if not isinstance(self.to_onehot, int):
                raise ValueError(
                    f"the number of classes for One-Hot must be an integer, got {type(self.to_onehot)}."
                )
            img_t = one_hot(
                img_t,
                num_classes=self.to_onehot,
                dim=self.dim,
                dtype=torch.float,
            )
        else:
            img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
        # perform gaussian smoothing
        sigma: Sequence[torch.Tensor] | torch.Tensor
        if isinstance(self.sigma, Sequence):
            sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma]
        else:
            sigma = torch.as_tensor(self.sigma, device=img_t.device)
        gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx)
        if self.channel_wise:
            # smooth one channel at a time to keep peak memory low;
            # GaussianFilter expects a (batch, channel, *spatial) input
            for i in range(img_t.shape[0]):
                img_t[i, ...] = gaussian_filter(img_t[i][None, None])[0, 0]
        else:
            img_t = gaussian_filter(img_t.unsqueeze(0)).squeeze(0)
        img_t, *_ = convert_to_dst_type(img_t, dst=img, dtype=img_t.dtype)
        # resample the smoothed channels to the target spacing
        img_t = self.resample(img_t)
        # argmax back to discrete labels
        img_t = torch.argmax(
            img_t,
            dim=self.dim,
            keepdim=self.keepdim,
        )
        # convert to the output type
        out, *_ = convert_to_dst_type(img_t, dst=img, dtype=img_t.dtype)
        return out


def log_tensor(op, data):
    print(f"{op}: {type(data)} {data.shape}, {data.dtype}")
def test_Existing_Workaround():
    """Implement the same thing with existing transforms."""
    # build a synthetic label image with a few box-shaped structures
    labels = sitk.Image(100, 110, 80, sitk.sitkUInt16)
    labels[:] = 0
    labels[3:40, 80:90, 25:60] = 1
    labels[15:35, 30:95, 50:70] = 2
    labels[30:35, 30:35, 30:35] = 3
    num_classes = 4
    file_path = Path.cwd() / "labels.nii.gz"
    sitk.WriteImage(labels, file_path)
    reader = LoadImage(reader="ITKReader", ensure_channel_first=True, image_only=False)
    to_tensor = EnsureType(dtype=torch.half)
    to_one_hot = AsDiscrete(to_onehot=num_classes, dtype=torch.half)
    labels_tuple = reader(filename=[file_path])
    labels_tensor = to_tensor(labels_tuple[0])
    log_tensor("input", labels_tensor)
    # one-hot -> smooth -> resample -> argmax: the full one-hot stack is held
    # in memory at the target resolution, which is what makes this expensive
    onehot_tensor = to_one_hot(labels_tensor)
    log_tensor("onehot", onehot_tensor)
    smooth = GaussianSmooth(sigma=1.5)
    smooth_onehot_tensor = smooth(onehot_tensor)
    log_tensor("smooth", smooth_onehot_tensor)
    resample = Spacing(pixdim=0.5)
    hires_onehot_tensor = resample(smooth_onehot_tensor)
    log_tensor("hires", hires_onehot_tensor)
    to_argmax = AsDiscrete(argmax=True)
    hires_argmax_tensor = to_argmax(hires_onehot_tensor)
    log_tensor("argmax", hires_argmax_tensor)
    saver = SaveImage(
        resample=False,
        output_postfix="hr",
        output_dir=Path.cwd(),
        separate_folder=False,
    )
    saver(hires_argmax_tensor)
def test_LabelGaussianResample():
    # same synthetic label image, but with a non-trivial spacing, origin, and
    # direction to check that the metadata is carried through correctly
    labels = sitk.Image(100, 110, 80, sitk.sitkUInt16)
    labels[:] = 0
    labels[3:40, 80:90, 25:60] = 1
    labels[15:35, 30:95, 50:70] = 2
    labels[30:35, 30:35, 30:35] = 3
    labels.SetSpacing([0.9, 0.95, 1.0])
    labels.SetOrigin([5.0, 6.0, -7.0])
    labels.SetDirection([0, 0, -1.0, 0.99237, -0.12324, 0, 0.12324, 0.99237, 0])
    num_classes = 4
    file_path = Path.cwd() / "labels.nii.gz"
    sitk.WriteImage(labels, file_path)
    reader = LoadImage(reader="ITKReader", image_only=False, ensure_channel_first=True)
    to_tensor = EnsureType(dtype=torch.half, track_meta=True)
    resample = LabelGaussianResample(pixdim=0.5, to_onehot=num_classes)
    labels_tuple = reader(filename=[file_path])
    labels_tensor = to_tensor(labels_tuple[0])
    print(labels_tensor.affine)
    output = resample(labels_tensor)
    print(output.affine)
    saver = SaveImage(
        writer="ITKWriter",
        resample=False,
        output_postfix="hr",
        output_dir=Path.cwd(),
        separate_folder=False,
    )
    saver(output)
thanks, it makes sense, labelling this as contribution wanted... I think the logic could be optimized to run on downsampling only and with an auto sigma if it's not specified
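One possible reading of the auto-sigma suggestion (an interpretation, not something specified in the thread; `auto_sigma` is a hypothetical helper) is the anti-aliasing heuristic used by scikit-image's `resize`: derive the per-axis sigma from the downsampling factor and skip smoothing on axes that are upsampled.

import numpy as np

def auto_sigma(in_pixdim, out_pixdim):
    # per-axis factor: > 1 means the output grid is coarser (downsampling)
    factors = np.asarray(out_pixdim, dtype=float) / np.asarray(in_pixdim, dtype=float)
    # scikit-image-style anti-aliasing sigma; zero (no smoothing) when upsampling
    return np.maximum(0.0, (factors - 1.0) / 2.0)

auto_sigma([1.0, 1.0, 2.0], [0.5, 1.0, 4.0])  # -> array([0. , 0. , 0.5])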
ok, thanks @wyli. I will work on it. I also need to change the logic to process channels one-by-one (if `channel_wise=True`), otherwise the memory will quickly exceed what is available on a GPU.
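A sketch of how the channel-by-channel logic could look (hypothetical; `streaming_label_resample` and its shapes are illustrative, and the input is assumed to be a MetaTensor so `Spacing` can read its affine): instead of materializing the full smoothed one-hot stack at the target resolution, keep only a running best score and best label per voxel.

import torch
from monai.networks.layers import GaussianFilter
from monai.transforms import Spacing

def streaming_label_resample(label_img, num_classes, pixdim=0.5, sigma=1.0):
    # label_img: (1, *spatial) integer MetaTensor; one class is processed at a
    # time, so peak memory is ~2 channels at the target resolution, not num_classes
    resample = Spacing(pixdim=pixdim, mode="bilinear")
    smooth = GaussianFilter(label_img.ndim - 1, sigma)
    best_score = best_label = None
    for c in range(num_classes):
        chan = (label_img == c).float()              # binary mask for class c
        chan = smooth(chan.unsqueeze(0)).squeeze(0)  # smooth just this class
        chan = resample(chan)                        # to the target grid
        if best_score is None:
            best_score = chan.clone()
            best_label = torch.zeros_like(chan, dtype=torch.long)
        else:
            better = chan > best_score
            best_score[better] = chan[better]
            best_label[better] = c
    return best_label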