From 97678fa75a851c411c713658059bac3b2ce5cf70 Mon Sep 17 00:00:00 2001
From: Suraj Pai
Date: Mon, 25 Mar 2024 22:13:56 -0400
Subject: [PATCH 1/2] Remove nested error propagation on `ConfigComponent`
 instantiate (#7569)

Fixes #7451

### Description
Reduces the length of error messages and stops the same error from being propagated twice. This makes debugging easier when long `ConfigComponent`s are being instantiated. Refer to issue #7451 for more details.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: Suraj Pai
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
---
 monai/bundle/config_item.py | 5 +----
 monai/utils/module.py | 2 +-
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py
index 844d5b30bf5..e5122bf3de3 100644
--- a/monai/bundle/config_item.py
+++ b/monai/bundle/config_item.py
@@ -289,10 +289,7 @@ def instantiate(self, **kwargs: Any) -> object:
         mode = self.get_config().get("_mode_", CompInitMode.DEFAULT)
         args = self.resolve_args()
         args.update(kwargs)
-        try:
-            return instantiate(modname, mode, **args)
-        except Exception as e:
-            raise RuntimeError(f"Failed to instantiate {self}") from e
+        return instantiate(modname, mode, **args)


 class ConfigExpression(ConfigItem):
diff --git a/monai/utils/module.py b/monai/utils/module.py
index 5e058c105bf..6f301d8067d 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -272,7 +272,7 @@ def instantiate(__path: str, __mode: str, **kwargs: Any) -> Any:
             return pdb.runcall(component, **kwargs)
     except Exception as e:
         raise RuntimeError(
-            f"Failed to instantiate component '{__path}' with kwargs: {kwargs}"
+            f"Failed to instantiate component '{__path}' with keywords: {','.join(kwargs.keys())}"
            f"\n set '_mode_={CompInitMode.DEBUG}' to enter the debugging mode."
         ) from e

From e5bebfc7cd2e48499159743365230115d1a7f460 Mon Sep 17 00:00:00 2001
From: Juampa <1523654+juampatronics@users.noreply.github.com>
Date: Tue, 26 Mar 2024 03:57:36 +0100
Subject: [PATCH 2/2] 2872 implementation of mixup, cutmix and cutout (#7198)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #2872

### Description
Implementation of mixup, cutmix and cutout as described in the original papers. The current implementation supports both dictionary-based batches and tuples of tensors.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.
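A minimal usage sketch of the new API (the tensor shapes and dictionary keys below are illustrative only, not part of the patch):

```python
import torch

from monai.transforms import CutMixd, MixUp

# tensor-based usage: the same mixing weights are applied to images and labels
mixup = MixUp(batch_size=8, alpha=1.0)
images = torch.rand(8, 3, 32, 32)  # batch x channels x height x width
labels = torch.rand(8, 3, 32, 32)  # image-like labels, e.g. one-hot maps
aug_images, aug_labels = mixup(images, labels)

# dictionary-based usage: entries in `keys` are mixed with random crops, while
# `label_keys` are mixed with the same weights but without cropping
data = {"image": images, "label": labels}
augmented = CutMixd(keys=("image",), batch_size=8, label_keys=("label",))(data)
```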
--------- Signed-off-by: Juan Pablo de la Cruz Gutiérrez Signed-off-by: monai-bot Signed-off-by: elitap Signed-off-by: Felix Schnabel Signed-off-by: YanxuanLiu Signed-off-by: ytl0623 Signed-off-by: Dženan Zukić Signed-off-by: KumoLiu Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Ishan Dutta Signed-off-by: dependabot[bot] Signed-off-by: kaibo Signed-off-by: heyufan1995 Signed-off-by: binliu Signed-off-by: axel.vlaminck Signed-off-by: Ibrahim Hadzic Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: Timothy Baker Signed-off-by: Mathijs de Boer Signed-off-by: Fabian Klopfer Signed-off-by: Lucas Robinet Signed-off-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Signed-off-by: chaoliu Signed-off-by: cxlcl Signed-off-by: chaoliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: monai-bot <64792179+monai-bot@users.noreply.github.com> Co-authored-by: elitap Co-authored-by: Felix Schnabel Co-authored-by: YanxuanLiu <104543031+YanxuanLiu@users.noreply.github.com> Co-authored-by: ytl0623 Co-authored-by: Dženan Zukić Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Ishan Dutta Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Kaibo Tang Co-authored-by: Yufan He <59374597+heyufan1995@users.noreply.github.com> Co-authored-by: binliunls <107988372+binliunls@users.noreply.github.com> Co-authored-by: Ben Murray Co-authored-by: axel.vlaminck Co-authored-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com> Co-authored-by: Ibrahim Hadzic Co-authored-by: Dr. Behrooz Hashemian <3968947+drbeh@users.noreply.github.com> Co-authored-by: Timothy J. Baker <62781117+tim-the-baker@users.noreply.github.com> Co-authored-by: Mathijs de Boer <8137653+MathijsdeBoer@users.noreply.github.com> Co-authored-by: Mathijs de Boer Co-authored-by: Fabian Klopfer Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Co-authored-by: Lucas Robinet <67736918+Lucas-rbnt@users.noreply.github.com> Co-authored-by: Lucas Robinet Co-authored-by: cxlcl --- docs/source/transforms.rst | 42 +++++ docs/source/transforms_idx.rst | 10 + monai/transforms/__init__.py | 12 ++ monai/transforms/regularization/__init__.py | 10 + monai/transforms/regularization/array.py | 173 ++++++++++++++++++ monai/transforms/regularization/dictionary.py | 97 ++++++++++ tests/test_regularization.py | 90 +++++++++ 7 files changed, 434 insertions(+) create mode 100644 monai/transforms/regularization/__init__.py create mode 100644 monai/transforms/regularization/array.py create mode 100644 monai/transforms/regularization/dictionary.py create mode 100644 tests/test_regularization.py diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8990e7991dc..bd3feb3497b 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -661,6 +661,27 @@ Post-processing :members: :special-members: __call__ +Regularization +^^^^^^^^^^^^^^ + +`CutMix` +"""""""" +.. autoclass:: CutMix + :members: + :special-members: __call__ + +`CutOut` +"""""""" +.. autoclass:: CutOut + :members: + :special-members: __call__ + +`MixUp` +""""""" +.. 
autoclass:: MixUp + :members: + :special-members: __call__ + Signal ^^^^^^^ @@ -1707,6 +1728,27 @@ Post-processing (Dict) :members: :special-members: __call__ +Regularization (Dict) +^^^^^^^^^^^^^^^^^^^^^ + +`CutMixd` +""""""""" +.. autoclass:: CutMixd + :members: + :special-members: __call__ + +`CutOutd` +""""""""" +.. autoclass:: CutOutd + :members: + :special-members: __call__ + +`MixUpd` +"""""""" +.. autoclass:: MixUpd + :members: + :special-members: __call__ + Signal (Dict) ^^^^^^^^^^^^^ diff --git a/docs/source/transforms_idx.rst b/docs/source/transforms_idx.rst index f4d02a483f7..650d45db716 100644 --- a/docs/source/transforms_idx.rst +++ b/docs/source/transforms_idx.rst @@ -74,6 +74,16 @@ Post-processing post.array post.dictionary +Regularization +^^^^^^^^^^^^^^ + +.. autosummary:: + :toctree: _gen + :nosignatures: + + regularization.array + regularization.dictionary + Signal ^^^^^^ diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2aa8fbf8a18..349533fb3e8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -336,6 +336,18 @@ VoteEnsembled, VoteEnsembleDict, ) +from .regularization.array import CutMix, CutOut, MixUp +from .regularization.dictionary import ( + CutMixd, + CutMixD, + CutMixDict, + CutOutd, + CutOutD, + CutOutDict, + MixUpd, + MixUpD, + MixUpDict, +) from .signal.array import ( SignalContinuousWavelet, SignalFillEmpty, diff --git a/monai/transforms/regularization/__init__.py b/monai/transforms/regularization/__init__.py new file mode 100644 index 00000000000..1e97f894078 --- /dev/null +++ b/monai/transforms/regularization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py new file mode 100644 index 00000000000..6c9022d6473 --- /dev/null +++ b/monai/transforms/regularization/array.py @@ -0,0 +1,173 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import abstractmethod +from math import ceil, sqrt + +import torch + +from ..transform import RandomizableTransform + +__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"] + + +class Mixer(RandomizableTransform): + def __init__(self, batch_size: int, alpha: float = 1.0) -> None: + """ + Mixer is a base class providing the basic logic for the mixup-class of + augmentations. 
In all cases, we need to sample the mixing weights for each
+        sample (lambda in the notation used in the papers). Also, pairs of samples
+        being mixed are picked by randomly shuffling the batch samples.
+
+        Args:
+            batch_size (int): number of samples per batch. That is, samples are expected to
+                be of size batch_size x channels [x depth] x height x width.
+            alpha (float, optional): mixing weights are sampled from the Beta(alpha, alpha)
+                distribution. Defaults to 1.0, the uniform distribution.
+        """
+        super().__init__()
+        if alpha <= 0:
+            raise ValueError(f"Expected positive number, but got {alpha = }")
+        self.alpha = alpha
+        self.batch_size = batch_size
+
+    @abstractmethod
+    def apply(self, data: torch.Tensor):
+        raise NotImplementedError()
+
+    def randomize(self, data=None) -> None:
+        """
+        Sometimes you may need to apply the same transform to different tensors.
+        The idea is to sample the parameters once and then apply them with apply() as often
+        as needed. You need to call this method every time you apply the transform to a new
+        batch.
+        """
+        self._params = (
+            torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
+            self.R.permutation(self.batch_size),
+        )
+
+
+class MixUp(Mixer):
+    """MixUp as described in:
+    Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz.
+    mixup: Beyond Empirical Risk Minimization, ICLR 2018
+
+    Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
+    documentation for details on the constructor parameters.
+    """
+
+    def apply(self, data: torch.Tensor):
+        weight, perm = self._params
+        nsamples, *dims = data.shape
+        if len(weight) != nsamples:
+            raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
+
+        if len(dims) not in [3, 4]:
+            raise ValueError("Unexpected number of dimensions")
+
+        mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
+        return mixweight * data + (1 - mixweight) * data[perm, ...]
+
+    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+        self.randomize()
+        if labels is None:
+            return self.apply(data)
+        return self.apply(data), self.apply(labels)
+
+
+class CutMix(Mixer):
+    """CutMix augmentation as described in:
+    Sangdoo Yun, Dongyoon Han, Seong Joon Oh, Sanghyuk Chun, Junsuk Choe, Youngjoon Yoo.
+    CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features,
+    ICCV 2019
+
+    Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
+    documentation for details on the constructor parameters. Here, alpha not only determines
+    the mixing weight but also the size of the random rectangles used during mixing.
+    Please refer to the paper for details.
+
+    The most common use case is something close to:
+
+    .. code-block:: python
+
+        cm = CutMix(batch_size=8, alpha=0.5)
+        for batch in loader:
+            images, labels = batch
+            augimg, auglabels = cm(images, labels)
+            output = model(augimg)
+            loss = loss_function(output, auglabels)
+            ...
+
+    """
+
+    def apply(self, data: torch.Tensor):
+        weights, perm = self._params
+        nsamples, _, *dims = data.shape
+        if len(weights) != nsamples:
+            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+        mask = torch.ones_like(data)
+        for s, weight in enumerate(weights):
+            coords = [torch.randint(0, d, size=(1,)) for d in dims]
+            lengths = [d * sqrt(1 - weight) for d in dims]
+            idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
+            mask[s][idx] = 0
+
+        return mask * data + (1 - mask) * data[perm, ...]
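+
+    # The cut box side along each spatial dim is d * sqrt(1 - weight): the removed
+    # volume fraction is about (1 - weight) for 2D inputs and (1 - weight)**1.5 for
+    # 3D ones (less when the box is clipped at a border), while apply_on_labels
+    # mixes the labels with the raw `weight`.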
+
+    def apply_on_labels(self, labels: torch.Tensor):
+        weights, perm = self._params
+        nsamples, *dims = labels.shape
+        if len(weights) != nsamples:
+            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+        mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
+        return mixweight * labels + (1 - mixweight) * labels[perm, ...]
+
+    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
+        self.randomize()
+        augmented = self.apply(data)
+        return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
+
+
+class CutOut(Mixer):
+    """Cutout as described in the paper:
+    Terrance DeVries, Graham W. Taylor.
+    Improved Regularization of Convolutional Neural Networks with Cutout,
+    arXiv:1708.04552
+
+    Class derived from :py:class:`monai.transforms.Mixer`. See corresponding
+    documentation for details on the constructor parameters. Here, alpha determines
+    the size of the random rectangles being cut out.
+    Please refer to the paper for details.
+    """
+
+    def apply(self, data: torch.Tensor):
+        weights, _ = self._params
+        nsamples, _, *dims = data.shape
+        if len(weights) != nsamples:
+            raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
+
+        mask = torch.ones_like(data)
+        for s, weight in enumerate(weights):
+            coords = [torch.randint(0, d, size=(1,)) for d in dims]
+            lengths = [d * sqrt(1 - weight) for d in dims]
+            idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
+            mask[s][idx] = 0
+
+        return mask * data
+
+    def __call__(self, data: torch.Tensor):
+        self.randomize()
+        return self.apply(data)
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
new file mode 100644
index 00000000000..373913da991
--- /dev/null
+++ b/monai/transforms/regularization/dictionary.py
@@ -0,0 +1,97 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from monai.config import KeysCollection
+from monai.utils.misc import ensure_tuple
+
+from ..transform import MapTransform
+from .array import CutMix, CutOut, MixUp
+
+__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
+
+
+class MixUpd(MapTransform):
+    """
+    Dictionary-based version of :py:class:`monai.transforms.MixUp`.
+
+    Notice that the mixup transformation will be the same for all entries
+    for consistency, i.e. images and labels must have the same augmentation applied.
+    """
+
+    def __init__(
+        self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.mixup = MixUp(batch_size, alpha)
+
+    def __call__(self, data):
+        self.mixup.randomize()
+        result = dict(data)
+        for k in self.keys:
+            result[k] = self.mixup.apply(data[k])
+        return result
+
+
+class CutMixd(MapTransform):
+    """
+    Dictionary-based version of :py:class:`monai.transforms.CutMix`.
+
+    Notice that the mixture weights will be the same for all entries
+    for consistency, i.e. images and labels must be aggregated with the same weights,
+    but the random crops are not.
+    """
+
+    def __init__(
+        self,
+        keys: KeysCollection,
+        batch_size: int,
+        label_keys: KeysCollection | None = None,
+        alpha: float = 1.0,
+        allow_missing_keys: bool = False,
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.mixer = CutMix(batch_size, alpha)
+        self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
+
+    def __call__(self, data):
+        self.mixer.randomize()
+        result = dict(data)
+        for k in self.keys:
+            result[k] = self.mixer.apply(data[k])
+        for k in self.label_keys:
+            result[k] = self.mixer.apply_on_labels(data[k])
+        return result
+
+
+class CutOutd(MapTransform):
+    """
+    Dictionary-based version of :py:class:`monai.transforms.CutOut`.
+
+    Notice that the cutout is different for every entry in the dictionary.
+    """
+
+    def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bool = False) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.cutout = CutOut(batch_size)
+
+    def __call__(self, data):
+        result = dict(data)
+        self.cutout.randomize()
+        for k in self.keys:
+            result[k] = self.cutout(data[k])
+        return result
+
+
+MixUpD = MixUpDict = MixUpd
+CutMixD = CutMixDict = CutMixd
+CutOutD = CutOutDict = CutOutd
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
new file mode 100644
index 00000000000..d381ea72ca3
--- /dev/null
+++ b/tests/test_regularization.py
@@ -0,0 +1,90 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
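+
+# These tests cover both the array- and dictionary-based variants. Since a single
+# draw of the mixing weights can leave a batch (nearly) unchanged, a transform is
+# applied several times and only one application is required to differ from the input.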
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd
+
+
+class TestMixup(unittest.TestCase):
+    def test_mixup(self):
+        for dims in [2, 3]:
+            shape = (6, 3) + (32,) * dims
+            sample = torch.rand(*shape, dtype=torch.float32)
+            mixup = MixUp(6, 1.0)
+            output = mixup(sample)
+            self.assertEqual(output.shape, sample.shape)
+            self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10)))
+
+        with self.assertRaises(ValueError):
+            MixUp(6, -0.5)
+
+        mixup = MixUp(6, 0.5)
+        for dims in [2, 3]:
+            with self.assertRaises(ValueError):
+                shape = (5, 3) + (32,) * dims
+                sample = torch.rand(*shape, dtype=torch.float32)
+                mixup(sample)
+
+    def test_mixupd(self):
+        for dims in [2, 3]:
+            shape = (6, 3) + (32,) * dims
+            t = torch.rand(*shape, dtype=torch.float32)
+            sample = {"a": t, "b": t}
+            mixup = MixUpd(["a", "b"], 6)
+            output = mixup(sample)
+            self.assertTrue(torch.allclose(output["a"], output["b"]))
+
+        with self.assertRaises(ValueError):
+            MixUpd(["k1", "k2"], 6, -0.5)
+
+
+class TestCutMix(unittest.TestCase):
+    def test_cutmix(self):
+        for dims in [2, 3]:
+            shape = (6, 3) + (32,) * dims
+            sample = torch.rand(*shape, dtype=torch.float32)
+            cutmix = CutMix(6, 1.0)
+            output = cutmix(sample)
+            self.assertEqual(output.shape, sample.shape)
+            self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))
+
+    def test_cutmixd(self):
+        for dims in [2, 3]:
+            shape = (6, 3) + (32,) * dims
+            t = torch.rand(*shape, dtype=torch.float32)
+            label = torch.randint(0, 2, shape)  # high=2 so the labels are non-constant
+            sample = {"a": t, "b": t, "lbl1": label, "lbl2": label}
+            cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2"))
+            output = cutmix(sample)
+            # croppings are different on each application
+            self.assertFalse(torch.allclose(output["a"], output["b"]))
+            # but the mixing of labels is not affected by them
+            self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"]))
+
+
+class TestCutOut(unittest.TestCase):
+    def test_cutout(self):
+        for dims in [2, 3]:
+            shape = (6, 3) + (32,) * dims
+            sample = torch.rand(*shape, dtype=torch.float32)
+            cutout = CutOut(6, 1.0)
+            output = cutout(sample)
+            self.assertEqual(output.shape, sample.shape)
+            self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10)))
+
+
+if __name__ == "__main__":
+    unittest.main()
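For reviewers, the dictionary-based transforms compose in the usual MapTransform style; a short sketch (key names and shapes are illustrative only):

```python
import torch

from monai.transforms import CutOutd, MixUpd

# a hypothetical batch as produced by a dictionary-based data loader
batch = {"image": torch.rand(6, 3, 32, 32), "label": torch.rand(6, 3, 32, 32)}

# MixUpd mixes all listed keys with the same weights and permutation,
# keeping images and labels consistent with each other
mixed = MixUpd(keys=("image", "label"), batch_size=6, alpha=1.0)(batch)

# CutOutd zeroes a random box per sample; every key draws its own box
cut = CutOutd(keys=("image",), batch_size=6)(batch)
```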