Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement TorchIO transforms wrapper analogous to TorchVision transfo… #7579

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2b37b94
Implement TorchIO transforms wrapper analogous to TorchVision transfo…
SomeUserName1 Mar 25, 2024
96955c6
Add torchio to dependencies
SomeUserName1 Mar 25, 2024
e478ef2
Fixup import order in test
SomeUserName1 Mar 25, 2024
a3cfde1
Add skipUnless annotation to torchio transform wrapper test
SomeUserName1 Mar 26, 2024
61d6987
Merge branch 'dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Mar 26, 2024
e19dc3c
fixup imports
SomeUserName1 Mar 26, 2024
ed88571
Merge branch '7499-torchio-transforms-wrapper' of github.com:SomeUser…
SomeUserName1 Mar 26, 2024
c100076
Merge branch 'dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Mar 27, 2024
de491af
add TorchIOd wrapper, add Transform and RandomizableTrait as base cla…
SomeUserName1 Mar 28, 2024
a4e8161
Merge remote-tracking branch 'main/dev' into 7499-torchio-transforms-…
SomeUserName1 Mar 28, 2024
60dccfb
Merge branch 'Project-MONAI:dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Mar 28, 2024
813ebaf
Merge remote-tracking branch 'refs/remotes/origin/7499-torchio-transf…
SomeUserName1 Mar 28, 2024
50cd7ec
add TorchIO and TorchIOd to docs
SomeUserName1 Mar 28, 2024
701c83e
add flag to apply the same random transform to all elements in the di…
SomeUserName1 Mar 28, 2024
74b6b41
Remove trailing quotes docs/source/transforms.rst
SomeUserName1 Jun 10, 2024
b896604
Remove trailing quotes docs/source/transforms.rst
SomeUserName1 Jun 10, 2024
472c747
Merge branch 'Project-MONAI:dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Nov 13, 2024
09d1099
TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision …
SomeUserName1 Nov 18, 2024
1c9334c
Fixup alias for RandTorchIOd
SomeUserName1 Nov 18, 2024
a64d2a1
Merge branch 'dev' into 7499-torchio-transforms-wrapper
SomeUserName1 Nov 18, 2024
63d7579
TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision …
SomeUserName1 Nov 18, 2024
688dd06
rebase and merge
SomeUserName1 Nov 18, 2024
d27bb78
remove duplicate export
SomeUserName1 Nov 18, 2024
99bc993
fix formatting
SomeUserName1 Nov 18, 2024
98e8275
remove apply same flag from test and remove redundant test, fix type …
SomeUserName1 Nov 18, 2024
f05aab5
fixup
SomeUserName1 Nov 18, 2024
27bd7fe
Finally...
SomeUserName1 Nov 19, 2024
0697f62
Merge branch 'dev' into 7499-torchio-transforms-wrapper
ericspod Nov 21, 2024
133d391
add docs
SomeUserName1 Nov 22, 2024
34fbc5f
Merge branch '7499-torchio-transforms-wrapper' of github.com:SomeUser…
SomeUserName1 Nov 22, 2024
a1046d9
correct indentation of docs
SomeUserName1 Nov 25, 2024
2a7842d
apply autofix and validate that docs still build
SomeUserName1 Nov 26, 2024
8b286ce
Merge branch 'dev' into 7499-torchio-transforms-wrapper
ericspod Nov 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ channels:
dependencies:
- numpy>=1.20
- pytorch>=1.9
- torchio
- torchvision
- pytorch-cuda=11.6
- pip
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@
ToDevice,
ToNumpy,
ToPIL,
TorchIO,
TorchVision,
ToTensor,
Transpose,
Expand Down Expand Up @@ -627,6 +628,9 @@
ToPILd,
ToPILD,
ToPILDict,
TorchIOd,
TorchIOD,
TorchIODict,
TorchVisiond,
TorchVisionD,
TorchVisionDict,
Expand Down
37 changes: 33 additions & 4 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
"ConvertToMultiChannelBasedOnBratsClasses",
"AddExtremePointsChannel",
"TorchVision",
"TorchIO",
"MapLabelValue",
"IntensityStats",
"ToDevice",
Expand Down Expand Up @@ -1127,12 +1128,10 @@ def __call__(
return concatenate((img, points_image), axis=0)


class TorchVision:
class TorchVision(Transform):
"""
This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.

Data is converted to a torch.tensor before applying the transform and then converted back to the original data type.
"""

backend = [TransformBackends.TORCH]
Expand Down Expand Up @@ -1163,6 +1162,36 @@ def __call__(self, img: NdarrayOrTensor):
return out


class TorchIO(Transform, RandomizableTrait):
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved
"""
This is a wrapper transform for TorchIO transforms based on the specified transform name and args.
"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
name: The transform name in TorchIO package.
args: parameters for the TorchIO transform.
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved
kwargs: parameters for the TorchIO transform.

"""
super().__init__()
self.name = name
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
or dict containing 4D tensors as values

"""
return self.trans(img)
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved


class MapLabelValue:
"""
Utility to map label values to another set of values.
Expand Down
35 changes: 35 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
ToDevice,
ToNumpy,
ToPIL,
TorchIO,
TorchVision,
ToTensor,
Transpose,
Expand Down Expand Up @@ -171,6 +172,9 @@
"ToTensorD",
"ToTensorDict",
"ToTensord",
"TorchIOD",
"TorchIODict",
"TorchIOd",
"TorchVisionD",
"TorchVisionDict",
"TorchVisiond",
Expand Down Expand Up @@ -1419,6 +1423,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class TorchIOd(MapTransform, RandomizableTrait):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for transforms.
All transforms in TorchIO can be applied randomly with probability p by specifying the `p=` argument.
"""

backend = TorchIO.backend

def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
name: The transform name in TorchIO package.
allow_missing_keys: don't raise exception if key is missing.
args: parameters for the TorchIO transform.
kwargs: parameters for the TorchIO transform.

"""
super().__init__(keys, allow_missing_keys)
self.name = name
self.trans = TorchIO(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.trans(d[key])
SomeUserName1 marked this conversation as resolved.
Show resolved Hide resolved
return d


class MapLabelValued(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
Expand Down Expand Up @@ -1771,6 +1805,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
)
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
TorchIOD = TorchIODict = TorchIOd
RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond
RandLambdaD = RandLambdaDict = RandLambdad
MapLabelValueD = MapLabelValueDict = MapLabelValued
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows"
types-pkg_resources
mypy>=1.5.0
ninja
torchio
torchvision
psutil
cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10"
Expand Down
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ all =
tensorboard
gdown>=4.7.3
pytorch-ignite==0.4.11
torchio
torchvision
itk>=5.2
tqdm>=4.47.0
Expand Down Expand Up @@ -100,6 +101,8 @@ gdown =
gdown>=4.7.3
ignite =
pytorch-ignite==0.4.11
torchio =
torchio
torchvision =
torchvision
itk =
Expand Down
56 changes: 56 additions & 0 deletions tests/test_torchio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from unittest import skipUnless

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import TorchIO
from monai.utils import optional_import, set_determinism

_, has_torchio = optional_import("torchio")

TEST_DIMS = [3, 128, 160, 160]
TESTS = [
[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)],
[{"name": "ZNormalization"}, torch.rand(TEST_DIMS)],
[{"name": "RandomAffine"}, torch.rand(TEST_DIMS)],
[{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)],
[{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)],
[{"name": "RandomMotion"}, torch.rand(TEST_DIMS)],
[{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)],
[{"name": "RandomSpike"}, torch.rand(TEST_DIMS)],
[{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)],
[{"name": "RandomBlur"}, torch.rand(TEST_DIMS)],
[{"name": "RandomNoise"}, torch.rand(TEST_DIMS)],
[{"name": "RandomSwap"}, torch.rand(TEST_DIMS)],
[{"name": "RandomGamma"}, torch.rand(TEST_DIMS)],
]


@skipUnless(has_torchio, "Requires torchio")
class TestTorchIO(unittest.TestCase):

@parameterized.expand(TESTS)
def test_value(self, input_param, input_data):
set_determinism(seed=0)
result = TorchIO(**input_param)(input_data)
self.assertIsNotNone(result)
self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed")


if __name__ == "__main__":
unittest.main()
48 changes: 48 additions & 0 deletions tests/test_torchiod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.transforms import TorchIOd
from monai.utils import optional_import, set_determinism
from tests.utils import assert_allclose

_, has_torchio = optional_import("torchio")

TEST_DIMS = [3, 128, 160, 160]
TEST_TENSOR = torch.rand(TEST_DIMS)
TESTS = [
[
{"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)},
{"img": TEST_TENSOR},
((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42,
]
]


@skipUnless(has_torchio, "Requires torchio")
class TestTorchVisiond(unittest.TestCase):

@parameterized.expand(TESTS)
def test_value(self, input_param, input_data, expected_value):
set_determinism(seed=0)
result = TorchIOd(**input_param)(input_data)
assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4, type_test=False)


if __name__ == "__main__":
unittest.main()
Loading