Describe the bug
When rotating with Rotate/Rotated/RandRotate/RandRotated using nearest neighbor sampling, the output dtype should match the input dtype, but it is currently forced to float32.
To Reproduce
```python
import numpy as np
from monai.transforms import Rotate

arr = np.ones((1, 32, 32), dtype=np.uint8)
transform = Rotate(np.deg2rad(10), mode="nearest")
res = transform(arr)
print(arr.dtype, res.dtype)
# np.uint8 np.float32

transform = Rotate(np.deg2rad(10), mode="nearest", dtype=None)
res = transform(arr)
print(arr.dtype, res.dtype)
```

```
Traceback (most recent call last):
  File ".../test.py", line 10, in <module>
    res = transform(arr)
  File ".../.venv/lib/python3.9/site-packages/monai/transforms/spatial/array.py", line 555, in __call__
    output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0)
  File ".../.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File ".../.venv/lib/python3.9/site-packages/monai/networks/layers/spatial_transforms.py", line 527, in forward
    theta = to_norm_affine(
  File ".../.venv/lib/python3.9/site-packages/monai/networks/utils.py", line 189, in to_norm_affine
    return src_xform @ affine @ torch.inverse(dst_xform)
RuntimeError: "inverse_cpu" not implemented for 'Byte'
```
Expected behavior
For nearest neighbor sampling there is no need to convert the dtype to float32. Especially for label maps, converting from uint8 to float32 only increases the required memory. Other spatial transforms, e.g. Zoom/RandZoom, correctly return the desired dtype.
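The memory cost mentioned above is easy to quantify with plain NumPy: casting a uint8 label map to float32 quadruples its size (the shape below matches the repro snippet).

```python
import numpy as np

# A uint8 label map of the same shape as in the repro above.
label = np.ones((1, 32, 32), dtype=np.uint8)
as_float = label.astype(np.float32)

print(label.nbytes)     # 1024 bytes
print(as_float.nbytes)  # 4096 bytes: 4x the memory for the same label values
```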
As your error log correctly noted, the input is internally converted to float32 for the torch backend (it currently doesn't support uint8), so specifying dtype=uint8 wouldn't actually reduce the memory footprint. You can optionally cast the output type with monai.transforms.CastToType.
Yes, that is currently my workaround, though it feels a bit inconsistent with the rest of the API. I only spotted this dtype change because I used the label in the training loop for plotting. If someone does not spot it, the rest of the transform pipeline after a Rotate/RandRotate(/Affine/RandAffine) will operate on an array with a larger memory footprint and may even be slower to compute.
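For readers hitting the same issue, the cast-back workaround can be sketched with plain NumPy; here `rotated` is a hypothetical stand-in for the float32 output of Rotate, and in a MONAI pipeline monai.transforms.CastToType performs the equivalent cast as a transform step.

```python
import numpy as np

arr = np.ones((1, 32, 32), dtype=np.uint8)

# Placeholder for the float32 array a nearest-neighbor Rotate would return;
# with mode="nearest" the values are exact input values, so the cast back
# to the input dtype is lossless.
rotated = arr.astype(np.float32)
restored = rotated.astype(arr.dtype)

print(restored.dtype)  # uint8
```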
Too bad that this is a limitation of torch.nn.functional.grid_sample. For reference I link to an open issue: pytorch/vision#2289
Environment
MONAI 0.8.0