Skip to content

Commit

Permalink
4636 4637 backward compatible types (#4638)
Browse files Browse the repository at this point in the history
Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Jul 7, 2022
1 parent 4ddd2bc commit e270741
Show file tree
Hide file tree
Showing 23 changed files with 152 additions and 99 deletions.
43 changes: 36 additions & 7 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from copy import deepcopy
from typing import Any, Sequence

import numpy as np
import torch

from monai.config.type_definitions import NdarrayTensor
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import PostFix
from monai.utils.type_conversion import convert_to_tensor
from monai.utils.type_conversion import convert_data_type, convert_to_tensor

__all__ = ["MetaTensor"]

Expand Down Expand Up @@ -307,6 +309,33 @@ def as_dict(self, key: str) -> dict:
PostFix.transforms(key): deepcopy(self.applied_operations),
}

def astype(self, dtype, device=None, *unused_args, **unused_kwargs):
"""
Cast to ``dtype``, sharing data whenever possible.
Args:
dtype: dtypes such as np.float32, torch.float, "np.float32", float.
device: the device if `dtype` is a torch data type.
unused_args: additional args (currently unused).
unused_kwargs: additional kwargs (currently unused).
Returns:
data array instance
"""
if isinstance(dtype, str):
mod_str, *dtype = dtype.split(".", 1)
dtype = mod_str if not dtype else dtype[0]
else:
mod_str = getattr(dtype, "__module__", "torch")
mod_str = look_up_option(mod_str, {"torch", "numpy", "np"}, default="numpy")
if mod_str == "torch":
out_type = torch.Tensor
elif mod_str in ("numpy", "np"):
out_type = np.ndarray
else:
out_type = None
return convert_data_type(self, output_type=out_type, device=device, dtype=dtype, wrap_sequence=True)[0]

@property
def affine(self) -> torch.Tensor:
"""Get the affine."""
Expand Down Expand Up @@ -334,7 +363,7 @@ def new_empty(self, size, dtype=None, device=None, requires_grad=False):
)

@staticmethod
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict, simple_keys: bool = False):
"""
Convert the image to `torch.Tensor`. If `affine` is in the `meta` dictionary,
convert that to `torch.Tensor`, too. Remove any superfluous metadata.
Expand All @@ -353,12 +382,12 @@ def ensure_torch_and_prune_meta(im: NdarrayTensor, meta: dict):
if not get_track_meta() or meta is None:
return img

# ensure affine is of type `torch.Tensor`
if "affine" in meta:
meta["affine"] = convert_to_tensor(meta["affine"])

# remove any superfluous metadata.
remove_extra_metadata(meta)
if simple_keys:
# ensure affine is of type `torch.Tensor`
if "affine" in meta:
meta["affine"] = convert_to_tensor(meta["affine"]) # bc-breaking
remove_extra_metadata(meta) # bc-breaking

# return the `MetaTensor`
return MetaTensor(img, meta=meta)
Expand Down
7 changes: 5 additions & 2 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ class LoadImage(Transform):
def __init__(
self,
reader=None,
image_only: bool = True,
image_only: bool = False,
dtype: DtypeLike = np.float32,
ensure_channel_first: bool = False,
simple_keys: bool = False,
*args,
**kwargs,
) -> None:
Expand All @@ -127,6 +128,7 @@ def __init__(
dtype: if not None convert the loaded image to this data type.
ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert
the image array shape to `channel first`. default to `False`.
simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.
args: additional parameters for reader if providing a reader name.
kwargs: additional parameters for reader if providing a reader name.
Expand All @@ -145,6 +147,7 @@ def __init__(
self.image_only = image_only
self.dtype = dtype
self.ensure_channel_first = ensure_channel_first
self.simple_keys = simple_keys

self.readers: List[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
Expand Down Expand Up @@ -255,7 +258,7 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option
meta_data = switch_endianness(meta_data, "<")

meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data)
img = MetaTensor.ensure_torch_and_prune_meta(img_array, meta_data, self.simple_keys)
if self.ensure_channel_first:
img = EnsureChannelFirst()(img)
if self.image_only:
Expand Down
6 changes: 4 additions & 2 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,9 @@ def __init__(
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = DEFAULT_POST_FIX,
overwriting: bool = False,
image_only: bool = True,
image_only: bool = False,
ensure_channel_first: bool = False,
simple_keys: bool = False,
allow_missing_keys: bool = False,
*args,
**kwargs,
Expand Down Expand Up @@ -103,12 +104,13 @@ def __init__(
dictionary containing image data array and header dict per input key.
ensure_channel_first: if `True` and loaded both image array and metadata, automatically convert
the image array shape to `channel first`. default to `False`.
simple_keys: whether to remove redundant metadata keys, default to False for backward compatibility.
allow_missing_keys: don't raise exception if key is missing.
args: additional parameters for reader if providing a reader name.
kwargs: additional parameters for reader if providing a reader name.
"""
super().__init__(keys, allow_missing_keys)
self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs)
self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, simple_keys, *args, **kwargs)
if not isinstance(meta_key_postfix, str):
raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.")
self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
Expand Down
15 changes: 13 additions & 2 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,8 @@ class EnsureType(Transform):
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor".
If False, the output data type will be `torch.Tensor`. Default to the return value of ``get_track_meta``.
"""

Expand All @@ -446,11 +448,13 @@ def __init__(
dtype: Optional[Union[DtypeLike, torch.dtype]] = None,
device: Optional[torch.device] = None,
wrap_sequence: bool = True,
track_meta: Optional[bool] = None,
) -> None:
self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"})
self.dtype = dtype
self.device = device
self.wrap_sequence = wrap_sequence
self.track_meta = get_track_meta() if track_meta is None else bool(track_meta)

def __call__(self, data: NdarrayOrTensor):
"""
Expand All @@ -461,10 +465,17 @@ def __call__(self, data: NdarrayOrTensor):
if applicable and `wrap_sequence=False`.
"""
output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray
if self.data_type == "tensor":
output_type = MetaTensor if self.track_meta else torch.Tensor
else:
output_type = np.ndarray # type: ignore
out: NdarrayOrTensor
out, *_ = convert_data_type(
data=data, output_type=output_type, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence
data=data,
output_type=output_type, # type: ignore
dtype=self.dtype,
device=self.device,
wrap_sequence=self.wrap_sequence,
)
return out

Expand Down
22 changes: 8 additions & 14 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
from monai.transforms.utils_pytorch_numpy_unification import concatenate
from monai.utils import convert_to_numpy, deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep
from monai.utils import deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix, TraceKeys, TransformBackends
from monai.utils.type_conversion import convert_to_dst_type

Expand Down Expand Up @@ -519,7 +519,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
return d


class EnsureTyped(MapTransform, InvertibleTransform):
class EnsureTyped(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.EnsureType`.
Expand All @@ -541,6 +541,7 @@ def __init__(
dtype: Union[DtypeLike, torch.dtype] = None,
device: Optional[torch.device] = None,
wrap_sequence: bool = True,
track_meta: Optional[bool] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Expand All @@ -552,28 +553,21 @@ def __init__(
device: for Tensor data type, specify the target device.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
track_meta: whether to convert to `MetaTensor` when `data_type` is "tensor".
If False, the output data type will be `torch.Tensor`. Default to the return value of `get_track_meta`.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence)
self.converter = EnsureType(
data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
self.push_transform(d, key)
d[key] = self.converter(d[key])
return d

def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
d = deepcopy(dict(data))
for key in self.key_iterator(d):
# FIXME: currently, only convert tensor data to numpy array or scalar number,
# need to also invert numpy array but it's not easy to determine the previous data type
d[key] = convert_to_numpy(d[key])
# Remove the applied transform
self.pop_transform(d, key)
return d


class ToNumpyd(MapTransform):
"""
Expand Down
9 changes: 5 additions & 4 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
"""
if isinstance(data, torch.Tensor):
data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy()
data = np.asarray(data.detach().to(device="cpu").numpy(), dtype=get_equivalent_dtype(dtype, np.ndarray))
elif has_cp and isinstance(data, cp_ndarray):
data = cp.asnumpy(data).astype(dtype, copy=False)
elif isinstance(data, (np.ndarray, float, int, bool)):
Expand Down Expand Up @@ -235,12 +235,13 @@ def convert_data_type(
wrap_sequence: bool = False,
) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
"""
Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` etc.
Convert to `MetaTensor`, `torch.Tensor` or `np.ndarray` from `MetaTensor`, `torch.Tensor`,
`np.ndarray`, `float`, `int`, etc.
Args:
data: data to be converted
output_type: `torch.Tensor` or `np.ndarray` (if `None`, unchanged)
device: if output is `torch.Tensor`, select device (if `None`, unchanged)
output_type: `monai.data.MetaTensor`, `torch.Tensor`, or `np.ndarray` (if `None`, unchanged)
device: if output is `MetaTensor` or `torch.Tensor`, select device (if `None`, unchanged)
dtype: dtype of output data. Converted to correct library type (e.g.,
`np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
If left blank, it remains unchanged.
Expand Down
14 changes: 7 additions & 7 deletions tests/test_arraydataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@
from monai.transforms import AddChannel, Compose, LoadImage, RandAdjustContrast, RandGaussianNoise, Spacing

TEST_CASE_1 = [
Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]),
Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]),
Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]),
Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]),
(0, 1),
(1, 128, 128, 128),
]

TEST_CASE_2 = [
Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]),
Compose([LoadImage(), AddChannel(), RandAdjustContrast(prob=1.0)]),
Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]),
Compose([LoadImage(image_only=True), AddChannel(), RandAdjustContrast(prob=1.0)]),
(0, 1),
(1, 128, 128, 128),
]
Expand All @@ -48,13 +48,13 @@ def __call__(self, input_):


TEST_CASE_3 = [
TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
TestCompose([LoadImage(), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
TestCompose([LoadImage(image_only=True), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
TestCompose([LoadImage(image_only=True), AddChannel(), Spacing(pixdim=(2, 2, 4)), RandAdjustContrast(prob=1.0)]),
(0, 2),
(1, 64, 64, 33),
]

TEST_CASE_4 = [Compose([LoadImage(), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)]
TEST_CASE_4 = [Compose([LoadImage(image_only=True), AddChannel(), RandGaussianNoise(prob=1.0)]), (1, 128, 128, 128)]


class TestArrayDataset(unittest.TestCase):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_decollate.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_decollation_dict(self, *transforms):
t_compose = Compose([AddChanneld(KEYS), Compose(transforms), ToTensord(KEYS)])
# If nibabel present, read from disk
if has_nib:
t_compose = Compose([LoadImaged("image"), t_compose])
t_compose = Compose([LoadImaged("image", image_only=True), t_compose])

dataset = CacheDataset(self.data_dict, t_compose, progress=False)
self.check_decollate(dataset=dataset)
Expand All @@ -141,7 +141,7 @@ def test_decollation_tensor(self, *transforms):
t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()])
# If nibabel present, read from disk
if has_nib:
t_compose = Compose([LoadImage(), t_compose])
t_compose = Compose([LoadImage(image_only=True), t_compose])

dataset = Dataset(self.data_list, t_compose)
self.check_decollate(dataset=dataset)
Expand All @@ -151,7 +151,7 @@ def test_decollation_list(self, *transforms):
t_compose = Compose([AddChannel(), Compose(transforms), ToTensor()])
# If nibabel present, read from disk
if has_nib:
t_compose = Compose([LoadImage(), t_compose])
t_compose = Compose([LoadImage(image_only=True), t_compose])

dataset = Dataset(self.data_list, t_compose)
self.check_decollate(dataset=dataset)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_ensure_channel_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def test_load_nifti(self, input_param, filenames, original_channel_dim):
filenames[i] = os.path.join(tempdir, name)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filenames[i])

result = LoadImage(**input_param)(filenames)
result = LoadImage(image_only=True, **input_param)(filenames)
result = EnsureChannelFirst()(result)
self.assertEqual(result.shape[0], len(filenames))

@parameterized.expand([TEST_CASE_7])
def test_itk_dicom_series_reader(self, input_param, filenames, _):
result = LoadImage(**input_param)(filenames)
result = LoadImage(image_only=True, **input_param)(filenames)
result = EnsureChannelFirst()(result)
self.assertEqual(result.shape[0], 1)

Expand All @@ -68,7 +68,7 @@ def test_load_png(self):
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_image.png")
Image.fromarray(test_image.astype("uint8")).save(filename)
result = LoadImage()(filename)
result = LoadImage(image_only=True)(filename)
result = EnsureChannelFirst()(result)
self.assertEqual(result.shape[0], 3)

Expand Down
7 changes: 4 additions & 3 deletions tests/test_ensure_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
import torch

from monai.data import MetaTensor
from monai.transforms import EnsureType
from tests.utils import assert_allclose

Expand Down Expand Up @@ -59,9 +60,9 @@ def test_string(self):

def test_list_tuple(self):
for dtype in ("tensor", "numpy"):
result = EnsureType(data_type=dtype, wrap_sequence=False)([[1, 2], [3, 4]])
result = EnsureType(data_type=dtype, wrap_sequence=False, track_meta=True)([[1, 2], [3, 4]])
self.assertTrue(isinstance(result, list))
self.assertTrue(isinstance(result[0][1], torch.Tensor if dtype == "tensor" else np.ndarray))
self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray))
torch.testing.assert_allclose(result[1][0], torch.as_tensor(3))
# tuple of numpy arrays
result = EnsureType(data_type=dtype, wrap_sequence=False)((np.array([1, 2]), np.array([3, 4])))
Expand All @@ -77,7 +78,7 @@ def test_dict(self):
"extra": None,
}
for dtype in ("tensor", "numpy"):
result = EnsureType(data_type=dtype)(test_data)
result = EnsureType(data_type=dtype, track_meta=False)(test_data)
self.assertTrue(isinstance(result, dict))
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
torch.testing.assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]))
Expand Down
Loading

0 comments on commit e270741

Please sign in to comment.