diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py
index 87fbc0655e..0226e810b0 100644
--- a/monai/transforms/compose.py
+++ b/monai/transforms/compose.py
@@ -46,13 +46,13 @@
     "OneOf",
     "RandomOrder",
     "SomeOf",
-    "compose_iterator",
-    "ranged_compose_iterator",
+    "transform_iterator",
+    "ranged_transform_iterator",
     "execute_compose",
 ]
 
 
-def compose_iterator(compose, step_into_all=False):
+def transform_iterator(compose, step_into_all=False):
     """
-    ``compose_iterator`` is a function that returns an iterator over the transforms in a
+    ``transform_iterator`` is a function that returns an iterator over the transforms in a
     ``Compose`` instance.
@@ -95,9 +95,37 @@ def compose_iterator(compose, step_into_all=False):
     for i in range(len(transforms)):
         tx = transforms[i]
         if type(tx) is Compose or (step_into_all is True and isinstance(tx, Compose)):
-            yield from compose_iterator(tx, step_into_all=step_into_all)
+            yield from transform_iterator(tx, step_into_all=step_into_all)
         else:
-            yield parent, tx
+            yield tx
+
+
+def ranged_transform_iterator(compose, start=None, end=None, step_into_all=False):
+    """
+    ``ranged_transform_iterator`` is a function that returns an iterator of a sub-range of the
+    transforms in a ``Compose`` instance. It iterates over transforms until it reaches the
+    transform at index ``start``, iterating until it reaches index ``end``.
+    It follows the same rules as ``transform_iterator`` in terms of how it iterates into nested
+    ``Compose`` instances.
+    Args:
+        compose: A ``Compose`` instance that contains the transforms to be iterated over
+        step_into_all: A boolean flag that indicates whether to step into randomised ``Compose``
+            sub-classes such as ``OneOf``, ``SomeOf``, and ``RandomOrder``
+        start: An optional integer that indicates the index to start returning transforms from.
+            This value is inclusive. If not set, iteration happens from the start of the list
+        end: An optional integer that indicates the index to stop returning transforms from. This
+            value is exclusive. If not set, iteration happens until the end of the list
+    """
+
+    i = 0
+
+    for tx in transform_iterator(compose, step_into_all):
+        if start is None or i >= start:
+            if end is None or i < end:
+                yield tx
+            else:
+                break
+        i += 1
 
 
 def generate_subcompose(data, start, end):
@@ -122,7 +150,7 @@ def __generate_subcompose(data, start, end, i=0):
             result = list() if result is None else result
             result.append(r)
         else:
-            if i >= start and i < end:
+            if i >= start and (end is None or i < end):
                 # print(f"including {data} as {i} is in range")
                 result = data
             # else:
@@ -134,34 +162,6 @@
-def ranged_compose_iterator(compose, start=None, end=None, step_into_all=False):
-    """
-    ``ranged_compose_iterator`` is a function that returns an iterator of a a sub-range of the
-    transforms in a ``Compose`` instance. It iterates over transforms until it reaches the
-    transform at index ``start``, iterating until it reaches index ``end``.
-    It follows the same rules as ``compose_iterator`` in terms of how it iterates into nested
-    ``Compose`` instances
-    Args:
-        compose: A ``Compose`` instance that contains the transforms to be iterated over
-        step_into_all: A boolean flag that indicates whether to step into randomised ``Compose``
-            sub-classes such as ``OneOf``, ``SomeOf``, and ``RandomOrder``
-        start: An optional integer that indicates the index to start returning transforms from.
-            This value is inclusive. If not set, iteration happens from the start of the list
-        end: An optional integer that indicates the index to stop returning transforms from. This
-            value is exclusive. 
If not set, iteration happens until the end of the list - """ - - i = 0 - - for cp, tx in compose_iterator(compose, step_into_all): - if start is None or i >= start: - if end is None or i < end: - yield cp, tx - else: - break - i += 1 - - def execute_compose( data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor], transforms: Sequence[Any], @@ -173,6 +173,7 @@ def execute_compose( overrides: dict | None = None, threading: bool = False, log_stats: bool | str = False, + is_inner: bool = False ) -> NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor]: """ ``execute_compose`` provides the implementation that the ``Compose`` class uses to execute a sequence @@ -204,7 +205,8 @@ def execute_compose( log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution. Setting this to False disables logging. Setting it to True enables logging to the default loggers. Setting a string overrides the logger name to which logging is performed. - + is_inner: this optional parameter should not be set by users but indicates to the Compose instance being called + that it being called by another Compose instance as part of its own execution. Returns: A tensorlike, sequence of tensorlikes or dict of tensorlists containing the result of running `data`` through the sequence of ``transforms``. 
@@ -223,15 +225,25 @@ def execute_compose(
     if start == end:
         return data
 
-    for _compose, _transform in ranged_compose_iterator(transforms, start, end):
-        _lazy = _compose.lazy if _compose is not None and lazy is None else lazy
-        _overrides = _compose.overrides if _compose is not None and overrides is None else overrides
+    # trim the set of transforms to be executed according to start and end
+    # parameter values
+    if start != 0 or end != None:
+        transforms_ = generate_subcompose(transforms, start, end)
+    else:
+        transforms_ = transforms
+
+    for _transform in transforms_:
         if threading:
             _transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
         data = apply_transform(
-            _transform, data, map_items, unpack_items, lazy=_lazy, overrides=_overrides, log_stats=log_stats
+            _transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats
         )
-        data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)
+
+    # in the case of nested Compose instances, it is the responsibility of the outermost
+    # instance to ensure that all pending transforms have been applied
+    if is_inner is False:
+        data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)
+
     return data
 
 
@@ -425,7 +437,7 @@ def get_index_of_first(self, predicate):
             True. None if no transform satisfies the ``predicate``
 
         """
-        for i, (_, tx) in enumerate(compose_iterator(self.transform)):
+        for i, tx in enumerate(transform_iterator(self.transform)):
             if predicate(tx):
                 return i
 
@@ -439,7 +451,7 @@ def flatten(self):
            will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`. 
""" - return Compose([tx[1] for tx in compose_iterator(self)]) + return Compose(tuple(transform_iterator(self))) def sub_range(self, start, end): """ @@ -449,9 +461,17 @@ def sub_range(self, start, end): def __len__(self): """Return number of transformations.""" - return len(list(compose_iterator(self))) + return len(tuple(transform_iterator(self))) - def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None): + def __call__( + self, + input_, + start=0, + end=None, + threading=False, + lazy: bool | None = None, + is_inner: bool = False + ): _lazy = self._lazy if lazy is None else lazy result = execute_compose( input_, @@ -464,6 +484,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None overrides=self.overrides, threading=threading, log_stats=self.log_stats, + is_inner=is_inner ) return result @@ -471,7 +492,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None def inverse(self, data): self._raise_if_not_invertible(data) - invertible_transforms = [t[1] for t in compose_iterator(self) if isinstance(t, InvertibleTransform)] + invertible_transforms = [t for t in transform_iterator(self) if isinstance(t, InvertibleTransform)] if not invertible_transforms: warnings.warn("inverse has been called but no invertible transforms have been supplied") @@ -581,7 +602,15 @@ def flatten(self): weights.append(w) return OneOf(transforms, weights, self.map_items, self.unpack_items) - def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None): + def __call__( + self, + data, + start=0, + end=None, + threading=False, + lazy: bool | None = None, + is_inner: bool = False + ): if start != 0: raise ValueError(f"OneOf requires 'start' parameter to be 0 (start set to {start})") if end is not None: @@ -605,6 +634,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = overrides=self.overrides, threading=threading, log_stats=self.log_stats, + 
is_inner=is_inner ) # if the data is a mapping (dictionary), append the OneOf transform to the end @@ -677,7 +707,15 @@ def __init__( super().__init__(transforms, map_items, unpack_items, log_stats, lazy, overrides) self.log_stats = log_stats - def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None): + def __call__( + self, + input_, + start=0, + end=None, + threading=False, + lazy: bool | None = None, + is_inner: bool = False +): if start != 0: raise ValueError(f"RandomOrder requires 'start' parameter to be 0 (start set to {start})") if end is not None: @@ -700,6 +738,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None lazy=_lazy, threading=threading, log_stats=self.log_stats, + is_inner=is_inner ) # if the data is a mapping (dictionary), append the RandomOrder transform to the end @@ -843,7 +882,15 @@ def _normalize_probabilities(self, weights): return ensure_tuple(list(weights)) - def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None): + def __call__( + self, + data, + start=0, + end=None, + threading=False, + lazy: bool | None = None, + is_inner: bool = False + ): if start != 0: raise ValueError(f"SomeOf requires 'start' parameter to be 0 (start set to {start})") if end is not None: @@ -867,6 +914,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = overrides=self.overrides, threading=threading, log_stats=self.log_stats, + is_inner=is_inner ) if isinstance(data, monai.data.MetaTensor): self.push_transform(data, extra_info={"applied_order": applied_order}) diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 3d09cea545..52c9c76821 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -88,14 +88,20 @@ def _apply_transform( Returns: ReturnType: The return type of `transform`. 
""" + from monai.transforms.compose import Compose from monai.transforms.lazy.functional import apply_pending_transforms_in_order data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name) + kwargs = {} + if isinstance(transform, LazyTrait): + kwargs['lazy'] = lazy + if isinstance(transform, Compose): + kwargs['is_inner'] = True if isinstance(data, tuple) and unpack_parameters: - return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data) + return transform(*data, **kwargs) - return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data) + return transform(data, **kwargs) def apply_transform( @@ -105,7 +111,7 @@ def apply_transform( unpack_items: bool = False, log_stats: bool | str = False, lazy: bool | None = None, - overrides: dict | None = None, + overrides: dict | None = None ) -> list[ReturnType] | ReturnType: """ Transform `data` with `transform`. diff --git a/tests/test_compose.py b/tests/test_compose.py index 619576379f..f2bb98ee15 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -23,7 +23,7 @@ import monai.transforms as mt from monai.data import DataLoader, Dataset -from monai.transforms.compose import compose_iterator, execute_compose, ranged_compose_iterator +from monai.transforms.compose import transform_iterator, execute_compose, ranged_transform_iterator from monai.transforms.transform import Randomizable from monai.utils import set_determinism @@ -258,6 +258,53 @@ def test_backwards_compatible_imports(self): from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401 +class TestNestedCompose(unittest.TestCase): + + def _test_nested_is_equivalent_impl(self, pipelines, data, lazy): + + p0_result = pipelines[0](data, lazy=lazy) + for p in pipelines[1:]: + p_result = p(data, lazy=lazy) + + self.assertTrue(torch.allclose(p0_result, p_result)) + + def test_nested_is_equivalent(self): + data = data_from_keys(None, 
12, 16) + + print(data.shape) + t1 = mt.Rotate(torch.pi / 8) + t2 = mt.Flip(0) + t3 = mt.Resize((64, 64)) + t4 = mt.Rotate(-torch.pi / 16) + uc = mt.Compose([t1, t2, t3, t4]) + n1c = mt.Compose([mt.Compose([t1, t2]), mt.Compose([t3, t4])]) + n2c = mt.Compose([t1, mt.Compose([t2, t3]), t4]) + + self._test_nested_is_equivalent_impl((uc, n1c, n2c), data, lazy=False) + self._test_nested_is_equivalent_impl((uc, n1c, n2c), data, lazy=True) + + def test_nested_is_equivalent_one_of(self): + t1 = mt.Rotate(torch.pi / 8) + t2 = mt.Flip(0) + t3a = mt.Rotate(torch.pi / 4) + t4a = mt.Rotate(torch.pi / 4) + t3b = mt.Rotate(torch.pi / 4, lazy=False) + t4b = mt.Rotate(torch.pi / 4, lazy=False) + t5 = mt.Zoom(0.8) + uo = mt.OneOf([t3a, t4a]) + uo.set_random_state(seed=123456789) + uc = mt.Compose([t1, t2, uo, t5]) + n1o = mt.OneOf([t3b, t4b]) + n1o.set_random_state(seed=123456789) + n1c = mt.Compose([t1, mt.Compose([t2, n1o]), t5]) + + data = data_from_keys(None, 12, 16) + self._test_nested_is_equivalent_impl((uc, n1c), data, lazy=False) + + data = data_from_keys(None, 12, 16) + self._test_nested_is_equivalent_impl((uc, n1c), data, lazy=True) + + class TestComposeIterator(unittest.TestCase): def test_compose_iterator(self): t1 = mt.Rotate(torch.pi / 8) @@ -269,12 +316,12 @@ def test_compose_iterator(self): c = mt.Compose([c1, c2]) for m in (False, True): - expected = [(c1, t1), (c1, t2), (c2, t3), (c2, t4)] - self.assertListEqual(list(compose_iterator(c, step_into_all=m)), expected) - expected = [(c1, t2), (c2, t3)] - self.assertListEqual(list(ranged_compose_iterator(c, start=1, end=3, step_into_all=m)), expected) - expected = [(c1, t1), (c1, t2), (c2, t3), (c2, t4)] - self.assertListEqual(list(ranged_compose_iterator(c, step_into_all=m)), expected) + expected = [t1, t2, t3, t4] + self.assertListEqual(list(transform_iterator(c, step_into_all=m)), expected) + expected = [t2, t3] + self.assertListEqual(list(ranged_transform_iterator(c, start=1, end=3, step_into_all=m)), expected) + 
expected = [t1, t2, t3, t4] + self.assertListEqual(list(ranged_transform_iterator(c, step_into_all=m)), expected) def test_compose_iterator_oneof(self): t1 = mt.Rotate(torch.pi / 8) @@ -289,20 +336,20 @@ def test_compose_iterator_oneof(self): for m in (False, True): if m is False: - expected = [(c1, t1), (c1, t2), (c, c2), (c, t5)] + expected = [t1, t2, c2, t5] else: - expected = [(c1, t1), (c1, t2), (c2, t3), (c2, t4), (c, t5)] - actual = list(compose_iterator(c, step_into_all=m)) + expected = [t1, t2, t3, t4, t5] + actual = list(transform_iterator(c, step_into_all=m)) self.assertListEqual(actual, expected) - expected = [(c1, t2), (c, c2)] if m is False else [(c1, t2), (c2, t3)] - self.assertListEqual(list(ranged_compose_iterator(c, start=1, end=3, step_into_all=m)), expected) + expected = [t2, c2] if m is False else [t2, t3] + self.assertListEqual(list(ranged_transform_iterator(c, start=1, end=3, step_into_all=m)), expected) if m is False: - expected = [(c1, t1), (c1, t2), (c, c2), (c, t5)] + expected = [t1, t2, c2, t5] else: - expected = [(c1, t1), (c1, t2), (c2, t3), (c2, t4), (c, t5)] - self.assertListEqual(list(ranged_compose_iterator(c, step_into_all=m)), expected) + expected = [t1, t2, t3, t4, t5] + self.assertListEqual(list(ranged_transform_iterator(c, step_into_all=m)), expected) class TestComposeSubRange(unittest.TestCase):