Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TorchIO transforms wrapper analogous to TorchVision transfo… #7579

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2b37b94
Implement TorchIO transforms wrapper analogous to TorchVision transfo…
SomeUserName1 Mar 25, 2024
96955c6
Add torchio to dependencies
SomeUserName1 Mar 25, 2024
e478ef2
Fixup import order in test
SomeUserName1 Mar 25, 2024
a3cfde1
Add skipUnless annotation to torchio transform wrapper test
SomeUserName1 Mar 26, 2024
61d6987
Merge branch 'dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Mar 26, 2024
e19dc3c
fixup imports
SomeUserName1 Mar 26, 2024
ed88571
Merge branch '7499-torchio-transforms-wrapper' of github.com:SomeUser…
SomeUserName1 Mar 26, 2024
c100076
Merge branch 'dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Mar 27, 2024
de491af
add TorchIOd wrapper, add Transform and RandomizableTrait as base cla…
SomeUserName1 Mar 28, 2024
a4e8161
Merge remote-tracking branch 'main/dev' into 7499-torchio-transforms-…
SomeUserName1 Mar 28, 2024
60dccfb
Merge branch 'Project-MONAI:dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Mar 28, 2024
813ebaf
Merge remote-tracking branch 'refs/remotes/origin/7499-torchio-transf…
SomeUserName1 Mar 28, 2024
50cd7ec
add TorchIO and TorchIOd to docs
SomeUserName1 Mar 28, 2024
701c83e
add flag to apply the same random transform to all elements in the di…
SomeUserName1 Mar 28, 2024
74b6b41
Remove trailing quotes docs/source/transforms.rst
SomeUserName1 Jun 10, 2024
b896604
Remove trailing quotes docs/source/transforms.rst
SomeUserName1 Jun 10, 2024
472c747
Merge branch 'Project-MONAI:dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Nov 13, 2024
09d1099
TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision …
SomeUserName1 Nov 18, 2024
1c9334c
Fixup alias for RandTorchIOd
SomeUserName1 Nov 18, 2024
a64d2a1
Merge branch 'dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Nov 18, 2024
63d7579
TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision …
SomeUserName1 Nov 18, 2024
688dd06
rebase and merge
SomeUserName1 Nov 18, 2024
d27bb78
remove duplicate export
SomeUserName1 Nov 18, 2024
99bc993
fix formatting
SomeUserName1 Nov 18, 2024
98e8275
remove apply same flag from test and remove redundant test, fix type …
SomeUserName1 Nov 18, 2024
f05aab5
fixup
SomeUserName1 Nov 18, 2024
27bd7fe
Finally...
SomeUserName1 Nov 19, 2024
0697f62
Merge branch 'dev' into 7499-torchio-transforms-wrapper
ericspod Nov 21, 2024
133d391
add docs
SomeUserName1 Nov 22, 2024
34fbc5f
Merge branch '7499-torchio-transforms-wrapper' of github.com:SomeUser…
SomeUserName1 Nov 22, 2024
a1046d9
correct indentation of docs
SomeUserName1 Nov 25, 2024
2a7842d
apply autofix and validate that docs still build
SomeUserName1 Nov 26, 2024
8b286ce
Merge branch 'dev' into 7499-torchio-transforms-wrapper
ericspod Nov 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ tests/testing_data/nrrd_example.nrrd
# clang format tool
.clang-format-bin/

# ctags
tags

# VSCode
.vscode/
*.zip
Expand Down
24 changes: 24 additions & 0 deletions docs/source/transforms.rst
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,18 @@ Utility
:members:
:special-members: __call__

`TorchIO`
"""""""""
.. autoclass:: TorchIO
:members:
:special-members: __call__

`RandTorchIO`
"""""""""""""
.. autoclass:: RandTorchIO
:members:
:special-members: __call__

`MapLabelValue`
"""""""""""""""
.. autoclass:: MapLabelValue
Expand Down Expand Up @@ -2253,6 +2265,18 @@ Utility (Dict)
:members:
:special-members: __call__

`TorchIOd`
""""""""""
.. autoclass:: TorchIOd
:members:
:special-members: __call__

`RandTorchIOd`
""""""""""""""
.. autoclass:: RandTorchIOd
:members:
:special-members: __call__

`MapLabelValued`
""""""""""""""""
.. autoclass:: MapLabelValued
Expand Down
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ channels:
dependencies:
- numpy>=1.24,<2.0
- pytorch>=1.9
- torchio
- torchvision
- pytorch-cuda>=11.6
- pip
Expand Down
9 changes: 9 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,8 @@
RandIdentity,
RandImageFilter,
RandLambda,
RandTorchIO,
RandTorchVision,
RemoveRepeatedChannel,
RepeatChannel,
SimulateDelay,
Expand All @@ -540,6 +542,7 @@
ToDevice,
ToNumpy,
ToPIL,
TorchIO,
TorchVision,
ToTensor,
Transpose,
Expand Down Expand Up @@ -620,6 +623,9 @@
RandLambdad,
RandLambdaD,
RandLambdaDict,
RandTorchIOd,
RandTorchIOD,
RandTorchIODict,
RandTorchVisiond,
RandTorchVisionD,
RandTorchVisionDict,
Expand Down Expand Up @@ -653,6 +659,9 @@
ToPILd,
ToPILD,
ToPILDict,
TorchIOd,
TorchIOD,
TorchIODict,
TorchVisiond,
TorchVisionD,
TorchVisionDict,
Expand Down
109 changes: 103 additions & 6 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import sys
import time
import warnings
from collections.abc import Mapping, Sequence
from collections.abc import Hashable, Mapping, Sequence
from copy import deepcopy
from functools import partial
from typing import Any, Callable
from typing import Any, Callable, Union

import numpy as np
import torch
Expand Down Expand Up @@ -99,11 +99,14 @@
"ConvertToMultiChannelBasedOnBratsClasses",
"AddExtremePointsChannel",
"TorchVision",
"TorchIO",
"MapLabelValue",
"IntensityStats",
"ToDevice",
"CuCIM",
"RandCuCIM",
"RandTorchIO",
"RandTorchVision",
"ToCupy",
"ImageFilter",
"RandImageFilter",
Expand Down Expand Up @@ -1136,12 +1139,44 @@ def __call__(
return concatenate((img, points_image), axis=0)


class TorchVision:
class TorchVision(Transform):
"""
This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.
This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args.
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchVision package.
args: parameters for the TorchVision transform.
kwargs: parameters for the TorchVision transform.

"""
super().__init__()
self.name = name
transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: PyTorch Tensor data for the TorchVision transform.

"""
img_t, *_ = convert_data_type(img, torch.Tensor)

out = self.trans(img_t)
out, *_ = convert_to_dst_type(src=out, dst=img)
return out


class RandTorchVision(Transform, RandomizableTrait):
"""
This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args.
Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
"""

backend = [TransformBackends.TORCH]
Expand Down Expand Up @@ -1172,6 +1207,68 @@ def __call__(self, img: NdarrayOrTensor):
return out


class TorchIO(Transform):
"""
This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args.
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchIO package.
args: parameters for the TorchIO transform.
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved
kwargs: parameters for the TorchIO transform.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
"""
Args:
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
or dict containing 4D tensors as values

"""
return self.trans(img)


class RandTorchIO(Transform, RandomizableTrait):
"""
This is a wrapper for TorchIO randomized transforms based on the specified transform name and args.
See https://torchio.readthedocs.io/transforms/transforms.html for more details.
Use this wrapper for all TorchIO transform inheriting from RandomTransform:
https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchIO package.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.
"""
super().__init__()
self.name = name
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
"""
Args:
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
or dict containing 4D tensors as values

"""
return self.trans(img)
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved


class MapLabelValue:
"""
Utility to map label values to another set of values.
Expand Down
67 changes: 67 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
ToDevice,
ToNumpy,
ToPIL,
TorchIO,
TorchVision,
ToTensor,
Transpose,
Expand Down Expand Up @@ -136,6 +137,9 @@
"RandLambdaD",
"RandLambdaDict",
"RandLambdad",
"RandTorchIOd",
"RandTorchIOD",
"RandTorchIODict",
"RandTorchVisionD",
"RandTorchVisionDict",
"RandTorchVisiond",
Expand Down Expand Up @@ -172,6 +176,9 @@
"ToTensorD",
"ToTensorDict",
"ToTensord",
"TorchIOD",
"TorchIODict",
"TorchIOd",
"TorchVisionD",
"TorchVisionDict",
"TorchVisiond",
Expand Down Expand Up @@ -1445,6 +1452,64 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class TorchIOd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms.
For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`.
"""

backend = TorchIO.backend

def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
name: The transform name in TorchIO package.
allow_missing_keys: don't raise exception if key is missing.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.

"""
super().__init__(keys, allow_missing_keys)
self.name = name
kwargs["include"] = self.keys

self.trans = TorchIO(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
return dict(self.trans(data))


class RandTorchIOd(MapTransform, RandomizableTrait):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms.
For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`.
"""

backend = TorchIO.backend

def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
name: The transform name in TorchIO package.
allow_missing_keys: don't raise exception if key is missing.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.

"""
super().__init__(keys, allow_missing_keys)
self.name = name
kwargs["include"] = self.keys

self.trans = TorchIO(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
return dict(self.trans(data))


class MapLabelValued(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
Expand Down Expand Up @@ -1871,8 +1936,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
ConvertToMultiChannelBasedOnBratsClassesd
)
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchIOD = TorchIODict = TorchIOd
TorchVisionD = TorchVisionDict = TorchVisiond
RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
RandTorchIOD = RandTorchIODict = RandTorchIOd
RandLambdaD = RandLambdaDict = RandLambdad
MapLabelValueD = MapLabelValueDict = MapLabelValued
IntensityStatsD = IntensityStatsDict = IntensityStatsd
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows"
types-setuptools
mypy>=1.5.0, <1.12.0
ninja
torchio
torchvision
psutil
cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ all =
tensorboard
gdown>=4.7.3
pytorch-ignite==0.4.11
torchio
torchvision
itk>=5.2
tqdm>=4.47.0
Expand Down Expand Up @@ -102,6 +103,8 @@ gdown =
gdown>=4.7.3
ignite =
pytorch-ignite==0.4.11
torchio =
torchio
torchvision =
torchvision
itk =
Expand Down
Loading
Loading