
Rotate with nearest filter does not return same dtype #3831

Closed
razorx89 opened this issue Feb 21, 2022 · 3 comments
Labels: bug (Something isn't working)

@razorx89 (Contributor)

Describe the bug
When rotating with Rotate/Rotated/RandRotate/RandRotated using nearest neighbor sampling, the output dtype should match the input dtype, but it is currently forced to float32.

To Reproduce

import numpy as np
from monai.transforms import Rotate

arr = np.ones((1, 32, 32), dtype=np.uint8)
transform = Rotate(np.deg2rad(10), mode="nearest")
res = transform(arr)
print(arr.dtype, res.dtype)
# np.uint8 np.float32

transform = Rotate(np.deg2rad(10), mode="nearest", dtype=None)
res = transform(arr)
print(arr.dtype, res.dtype)
# Traceback (most recent call last):
#   File ".../test.py", line 10, in <module>
#     res = transform(arr)
#   File ".../.venv/lib/python3.9/site-packages/monai/transforms/spatial/array.py", line 555, in __call__
#     output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0)
#   File ".../.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
#     return forward_call(*input, **kwargs)
#   File ".../.venv/lib/python3.9/site-packages/monai/networks/layers/spatial_transforms.py", line 527, in forward
#     theta = to_norm_affine(
#   File ".../.venv/lib/python3.9/site-packages/monai/networks/utils.py", line 189, in to_norm_affine
#     return src_xform @ affine @ torch.inverse(dst_xform)
# RuntimeError: "inverse_cpu" not implemented for 'Byte'

Expected behavior
For nearest neighbor sampling there is no need to convert the dtype to float32. Especially for label maps, converting from uint8 to float32 only increases the required memory. Other spatial transforms, e.g. Zoom/RandZoom, correctly return the desired dtype.

from monai.transforms import Zoom
transform = Zoom(1.2, mode="nearest")
res = transform(arr)
print(arr.dtype, res.dtype)
# uint8 uint8

Environment
MONAI 0.8.0

@wyli (Contributor) commented Feb 21, 2022

As your error log correctly noted, the image is converted to float32 internally for the torch backend (which currently doesn't support uint8 here), so specifying dtype=uint8 wouldn't actually reduce the memory footprint. You can optionally cast the output type with monai.transforms.CastToType.
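
A minimal sketch of that workaround, composing the rotation with a cast back to the input dtype (reusing the array from the reproduction above):

import numpy as np
from monai.transforms import CastToType, Compose, Rotate

arr = np.ones((1, 32, 32), dtype=np.uint8)

# Rotate upcasts to float32 internally; CastToType restores the input dtype afterwards.
transform = Compose([
    Rotate(np.deg2rad(10), mode="nearest"),
    CastToType(dtype=np.uint8),
])
res = transform(arr)
print(arr.dtype, res.dtype)
# uint8 uint8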

@razorx89 (Contributor, Author)

Yes, that is currently my workaround, though it feels a bit inconsistent with the rest of the API. I only spotted the dtype change because I used the label for plotting in the training loop. If someone does not spot it, everything in the transform pipeline after a Rotate/RandRotate(/Affine/RandAffine) will operate on an array with a larger memory footprint and may even be slower to compute.
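
To illustrate the footprint increase described above (using a hypothetical 512x512 label map; uint8 takes one byte per element, float32 takes four):

import numpy as np

label = np.ones((1, 512, 512), dtype=np.uint8)
print(label.nbytes)                     # 262144 bytes
print(label.astype(np.float32).nbytes)  # 1048576 bytes, 4x larger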

Too bad that this is a limitation of torch.nn.functional.grid_sample. For reference, here is a related open issue:
pytorch/vision#2289

@vikashg commented Jan 5, 2024

Closing because of inactivity.

@vikashg closed this as completed Jan 5, 2024