Extend/correct behaviour of RandAffined to allow for different input shapes #4491

Closed
josegcpa opened this issue Jun 13, 2022 · 3 comments

Comments

@josegcpa

Hello!

Is your feature request related to a problem? Please describe.
When working with inputs of different shapes (e.g. predicting a downsampled segmentation map or an upsampled, high-resolution image), random affine transforms can be quite handy. However, the current implementation of MONAI's RandAffined does not let users apply these transforms to differently sized inputs, which often leads to shape mismatches, because MONAI assumes that the shape of every input corresponds to the shape of the first key's input.

Describe the solution you'd like
Ideally, I would like a random affine augmentation that can handle such cases, i.e. differently shaped inputs are rotated, scaled, translated and sheared in the same fashion, without altering the output shape of any of the input tensors.

Describe alternatives you've considered
The easiest alternative would be to reimplement this class, which is what I am presently doing. The main issue is the translation, which requires some additional thought: one cannot apply the same 10-voxel translation to an image and to its downsampled version (a 10-voxel shift along a 128-voxel axis corresponds to a 5-voxel shift along the 64-voxel axis of a 2x-downsampled version). Other than that, and assuming that rotations use the center of the image as the center of rotation, I believe this works relatively well (I may be wrong). The best part is that it can easily be built on top of monai.transforms.Affine, and the necessary spatial sizes are simply provided in the same order as the keys. The example below does this.

from typing import List, Tuple, Union

import numpy as np
import monai

class RandomAffined(monai.transforms.RandomizableTransform):
    def __init__(
        self,
        keys: List[str],
        spatial_sizes: List[Union[Tuple[int, int, int], Tuple[int, int]]],
        mode: List[str],
        prob: float = 0.1,
        rotate_range: Union[Tuple[int, int, int], Tuple[int, int]] = (0, 0, 0),
        shear_range: Union[Tuple[int, int, int], Tuple[int, int]] = (0, 0, 0),
        translate_range: Union[Tuple[int, int, int], Tuple[int, int]] = (0, 0, 0),
        scale_range: Union[Tuple[int, int, int], Tuple[int, int]] = (0, 0, 0),
        device: str = "cpu"):

        self.keys = keys
        self.spatial_sizes = [np.array(s, dtype=np.int32) for s in spatial_sizes]
        self.mode = mode
        self.prob = prob
        self.rotate_range = np.array(rotate_range)
        self.shear_range = np.array(shear_range)
        self.translate_range = np.array(translate_range)
        self.scale_range = np.array(scale_range)
        self.device = device

        # one Affine transform per key, each with its own output spatial size
        self.affine_trans = {
            k: monai.transforms.Affine(
                spatial_size=s,
                mode=m,
                device=self.device)
            for k, s, m in zip(self.keys, self.spatial_sizes, self.mode)}

        self.get_translation_adjustment()

    def get_random_parameters(self):
        # a single set of parameters is drawn and shared by all keys
        angle = self.R.uniform(
            -self.rotate_range, self.rotate_range)
        shear = self.R.uniform(
            -self.shear_range, self.shear_range)
        trans = self.R.uniform(
            -self.translate_range, self.translate_range)
        scale = self.R.uniform(
            1 - self.scale_range, 1 + self.scale_range)

        return angle, shear, trans, scale

    def get_translation_adjustment(self):
        # the translation has to be rescaled so that all inputs stay aligned;
        # the first key is taken as the reference
        ref_size = self.spatial_sizes[0]
        self.trans_adj = {
            k: s / ref_size
            for k, s in zip(self.keys, self.spatial_sizes)}

    def randomize(self):
        angle, shear, trans, scale = self.get_random_parameters()
        for k in self.affine_trans:
            # we only need to update the affine grid; the translation is
            # rescaled per key, all other parameters are shared as-is
            self.affine_trans[k].affine_grid = monai.transforms.AffineGrid(
                rotate_params=list(angle),
                shear_params=list(shear),
                translate_params=list(np.float32(trans * self.trans_adj[k])),
                scale_params=list(scale),
                device=self.device)

    def __call__(self, data):
        self.randomize()
        # the probability is drawn once for all keys so that they stay aligned
        if self.R.uniform() < self.prob:
            for k in self.keys:
                transform = self.affine_trans[k]
                data[k], data[k + "_affine"] = transform(data[k])
        return data

Additional context
This is mostly it: a random affine augmentation that can handle inputs of different sizes. Below is a reproducible example using the class above, together with the (in my opinion unexpected) output of MONAI's RandAffined and the output of the class above.


import time

import numpy as np
import torch
import monai

# RandomAffined is the class defined above

input_shape = [1,128,128,16]
input_shape_ds = [1,64,64,8]

data = {
    "a":torch.rand(input_shape),
    "b":torch.rand(input_shape_ds)}

t = monai.transforms.RandAffined(
    keys=["a","b"],
    mode=["bilinear","nearest"],
    prob=1.0,
    rotate_range=[np.pi/6,np.pi/6,np.pi/6],
    shear_range=[0,0,0],
    translate_range=[10,10,3],
    scale_range=[0.1,0.1,0.1])

t_1 = time.time()
o = t(data)
t_2 = time.time()

print("Result from MONAI:",o['a'].shape,o['b'].shape)
print("\tTime elapsed:",t_2-t_1)

t = RandomAffined(
    keys=["a","b"],
    spatial_sizes=[input_shape[1:],input_shape_ds[1:]],
    mode=["bilinear","nearest"],
    prob=1.0,
    rotate_range=[np.pi/6,np.pi/6,np.pi/6],
    shear_range=[0,0,0],
    translate_range=[10,10,3],
    scale_range=[0.1,0.1,0.1])

t_1 = time.time()
o = t(data)
t_2 = time.time()

print("Result from own implementaion",o['a'].shape,o['b'].shape)
print("\tTime elapsed:",t_2-t_1)

Output:

Result from MONAI: torch.Size([1, 128, 128, 16]) torch.Size([1, 128, 128, 16])
        Time elapsed: 0.0467381477355957
Result from own implementation: torch.Size([1, 128, 128, 16]) torch.Size([1, 64, 64, 8])
        Time elapsed: 0.02890753746032715

I understand that MONAI has specific coding conventions that I did not follow here (which is why I did not submit a pull request), but I hope this is clear enough.

@wyli
Contributor

wyli commented Jun 15, 2022

Thanks for the feature request. The root cause is that the current rotate_range and translate_range are defined in terms of image coordinates, whereas in this use case the assumption is that images at different scales share the same world coordinate system. RandAffine should have an option to interpret rotate_range and translate_range with respect to the world coordinates, so that, using the voxel-to-world transform (provided as image meta info), we can transform the voxels in a consistent manner. We are in the process of releasing the MetaTensor API, which tracks the voxel-to-world transform and will include this feature.
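
As an illustration of that idea (this is not the actual MetaTensor API; the function name and the affines below are illustrative), a translation specified in world units such as millimetres can be converted into per-image voxel units by dividing by each image's voxel spacing, so every key receives the same physical shift:

import numpy as np

def world_to_voxel_translation(translation_mm, affine):
    # voxel spacing along each axis = column norms of the 3x3 block of the
    # voxel-to-world affine; dividing a physical shift by the spacing gives
    # the equivalent shift in voxels for that particular image
    spacing = np.sqrt((np.asarray(affine)[:3, :3] ** 2).sum(axis=0))
    return np.asarray(translation_mm) / spacing

# a (10, 10, 4) mm shift applied to a full-resolution image (1x1x2 mm spacing)
# and to its 2x-downsampled counterpart (2x2x4 mm spacing)
affine_full = np.diag([1.0, 1.0, 2.0, 1.0])
affine_down = np.diag([2.0, 2.0, 4.0, 1.0])
print(world_to_voxel_translation([10, 10, 4], affine_full))  # [10. 10.  2.]
print(world_to_voxel_translation([10, 10, 4], affine_down))  # [5. 5. 1.]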

@josegcpa
Author

@wyli yes, that is correct, there is a strong underlying assumption behind this application: that the images share identical world coordinates (I must stress that this is also assumed when the images are of equal size). This can be useful when the user is absolutely certain that both images are co-registered or, as in my use case, when the smaller image simply corresponds to a downsampled segmentation map of the input. It can also be useful when an algorithm is being trained for super-resolution applications.

I agree that the more generic use case would involve using metadata to interpret world coordinates and ensure that transforms are consistent between images, but I would not consider this a necessity: in some cases there is prior knowledge that the world coordinates are identical (even if the images are under/upsampled). I am looking forward to the MetaTensor API; it does seem to greatly simplify how transforms currently handle this.
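
A quick sanity check of that prior knowledge, sketched here with illustrative spacings and a hypothetical helper, is to compare the physical field of view of the two arrays (voxel count times voxel size along each axis):

import numpy as np

def same_field_of_view(shape_a, spacing_a, shape_b, spacing_b, tol=1e-3):
    # co-registered images that only differ by down/upsampling cover the
    # same physical extent: voxel count * voxel size matches per axis
    extent_a = np.asarray(shape_a) * np.asarray(spacing_a)
    extent_b = np.asarray(shape_b) * np.asarray(spacing_b)
    return bool(np.allclose(extent_a, extent_b, atol=tol))

# 128x128x16 at 1x1x2 mm vs. its downsampled 64x64x8 map at 2x2x4 mm
print(same_field_of_view((128, 128, 16), (1, 1, 2), (64, 64, 8), (2, 2, 4)))  # True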

@vikashg

vikashg commented Jan 5, 2024

Closing because of inactivity.

vikashg closed this as completed Jan 5, 2024