Skip to content

Commit

Permalink
Closing in on a solution. Mechanism for getting Compose with the appr…
Browse files Browse the repository at this point in the history
…opriate subset implemented. Tests for said mechanism. Renaming of methods for clarity

Signed-off-by: Ben Murray <[email protected]>
  • Loading branch information
atbenmurray committed Nov 3, 2023
1 parent 8c3e573 commit 9c99bb1
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 66 deletions.
144 changes: 96 additions & 48 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@
"OneOf",
"RandomOrder",
"SomeOf",
"compose_iterator",
"ranged_compose_iterator",
"transform_iterator",
"ranged_transform_iterator",
"execute_compose",
]


def compose_iterator(compose, step_into_all=False):
def transform_iterator(compose, step_into_all=False):
"""
    ``transform_iterator`` is a function that returns an iterator over the transforms in a ``Compose``
    instance.
Expand Down Expand Up @@ -95,9 +95,37 @@ def compose_iterator(compose, step_into_all=False):
for i in range(len(transforms)):
tx = transforms[i]
if type(tx) is Compose or (step_into_all is True and isinstance(tx, Compose)):
yield from compose_iterator(tx, step_into_all=step_into_all)
yield from transform_iterator(tx, step_into_all=step_into_all)
else:
yield parent, tx
yield tx


def ranged_transform_iterator(compose, start=None, end=None, step_into_all=False):
    """
    ``ranged_transform_iterator`` is a function that returns an iterator over a sub-range of the
    transforms in a ``Compose`` instance. Transforms before index ``start`` are skipped, and
    iteration stops once index ``end`` is reached.
    It follows the same rules as ``transform_iterator`` in terms of how it iterates into nested
    ``Compose`` instances.
    Args:
        compose: A ``Compose`` instance that contains the transforms to be iterated over
        start: An optional integer that indicates the index to start returning transforms from.
            This value is inclusive. If not set, iteration happens from the start of the list
        end: An optional integer that indicates the index to stop returning transforms at. This
            value is exclusive. If not set, iteration happens until the end of the list
        step_into_all: A boolean flag that indicates whether to step into randomised ``Compose``
            sub-classes such as ``OneOf``, ``SomeOf``, and ``RandomOrder``
    """

    # i is the flattened index of the current transform across nested Compose instances
    i = 0

    for tx in transform_iterator(compose, step_into_all):
        if start is None or i >= start:
            if end is None or i < end:
                yield tx
            else:
                # the exclusive upper bound has been reached; stop iterating
                break
        i += 1


def generate_subcompose(data, start, end):
Expand All @@ -122,7 +150,7 @@ def __generate_subcompose(data, start, end, i=0):
result = list() if result is None else result
result.append(r)
else:
if i >= start and i < end:
if i >= start and (end is None or i < end):
# print(f"including {data} as {i} is in range")
result = data
# else:
Expand All @@ -134,34 +162,6 @@ def __generate_subcompose(data, start, end, i=0):
return result


def ranged_compose_iterator(compose, start=None, end=None, step_into_all=False):
    """
    ``ranged_compose_iterator`` is a function that returns an iterator over a sub-range of the
    transforms in a ``Compose`` instance. Transforms before index ``start`` are skipped, and
    iteration stops once index ``end`` is reached.
    It follows the same rules as ``compose_iterator`` in terms of how it iterates into nested
    ``Compose`` instances.
    Args:
        compose: A ``Compose`` instance that contains the transforms to be iterated over
        step_into_all: A boolean flag that indicates whether to step into randomised ``Compose``
            sub-classes such as ``OneOf``, ``SomeOf``, and ``RandomOrder``
        start: An optional integer that indicates the index to start returning transforms from.
            This value is inclusive. If not set, iteration happens from the start of the list
        end: An optional integer that indicates the index to stop returning transforms at. This
            value is exclusive. If not set, iteration happens until the end of the list
    """

    # i is the flattened index of the current transform across nested Compose instances
    i = 0

    for cp, tx in compose_iterator(compose, step_into_all):
        if start is None or i >= start:
            if end is None or i < end:
                # yields (parent Compose, transform) pairs, matching compose_iterator
                yield cp, tx
            else:
                # the exclusive upper bound has been reached; stop iterating
                break
        i += 1


def execute_compose(
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
transforms: Sequence[Any],
Expand All @@ -173,6 +173,7 @@ def execute_compose(
overrides: dict | None = None,
threading: bool = False,
log_stats: bool | str = False,
is_inner: bool = False
) -> NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor]:
"""
``execute_compose`` provides the implementation that the ``Compose`` class uses to execute a sequence
Expand Down Expand Up @@ -204,7 +205,8 @@ def execute_compose(
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
Setting this to False disables logging. Setting it to True enables logging to the default loggers.
Setting a string overrides the logger name to which logging is performed.
is_inner: this optional parameter should not be set by users but indicates to the Compose instance being called
that it being called by another Compose instance as part of its own execution.
Returns:
A tensorlike, sequence of tensorlikes or dict of tensorlists containing the result of running
`data`` through the sequence of ``transforms``.
Expand All @@ -223,15 +225,25 @@ def execute_compose(
if start == end:
return data

for _compose, _transform in ranged_compose_iterator(transforms, start, end):
_lazy = _compose.lazy if _compose is not None and lazy is None else lazy
_overrides = _compose.overrides if _compose is not None and overrides is None else overrides
    # trim the set of transforms to be executed according to start and end
    # parameter values
if start != 0 or end != None:
transforms_ = generate_subcompose(transforms, start, end)
else:
transforms_ = transforms

for _transform in transforms_:
if threading:
_transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
data = apply_transform(
_transform, data, map_items, unpack_items, lazy=_lazy, overrides=_overrides, log_stats=log_stats
_transform, data, map_items, unpack_items, lazy=lazy, overrides=overrides, log_stats=log_stats
)
data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)

    # in the case of nested Compose instances, it is the responsibility of the outermost
    # instance to ensure that all pending transforms have been applied
if is_inner is False:
data = apply_pending_transforms(data, None, overrides, logger_name=log_stats)

return data


Expand Down Expand Up @@ -425,7 +437,7 @@ def get_index_of_first(self, predicate):
True. None if no transform satisfies the ``predicate``
"""
for i, (_, tx) in enumerate(compose_iterator(self.transform)):
for i, tx in enumerate(transform_iterator(self.transform)):
if predicate(tx):
return i

Expand All @@ -439,7 +451,7 @@ def flatten(self):
will result in the equivalent of `t1 = Compose([x, x, x, x, x, x, x, x])`.
"""
return Compose([tx[1] for tx in compose_iterator(self)])
return Compose(tuple(transform_iterator(self)))

def sub_range(self, start, end):
"""
Expand All @@ -449,9 +461,17 @@ def sub_range(self, start, end):

def __len__(self):
"""Return number of transformations."""
return len(list(compose_iterator(self)))
return len(tuple(transform_iterator(self)))

def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):
def __call__(
self,
input_,
start=0,
end=None,
threading=False,
lazy: bool | None = None,
is_inner: bool = False
):
_lazy = self._lazy if lazy is None else lazy
result = execute_compose(
input_,
Expand All @@ -464,14 +484,15 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None
overrides=self.overrides,
threading=threading,
log_stats=self.log_stats,
is_inner=is_inner
)

return result

def inverse(self, data):
self._raise_if_not_invertible(data)

invertible_transforms = [t[1] for t in compose_iterator(self) if isinstance(t, InvertibleTransform)]
invertible_transforms = [t for t in transform_iterator(self) if isinstance(t, InvertibleTransform)]
if not invertible_transforms:
warnings.warn("inverse has been called but no invertible transforms have been supplied")

Expand Down Expand Up @@ -581,7 +602,15 @@ def flatten(self):
weights.append(w)
return OneOf(transforms, weights, self.map_items, self.unpack_items)

def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None):
def __call__(
self,
data,
start=0,
end=None,
threading=False,
lazy: bool | None = None,
is_inner: bool = False
):
if start != 0:
raise ValueError(f"OneOf requires 'start' parameter to be 0 (start set to {start})")
if end is not None:
Expand All @@ -605,6 +634,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None =
overrides=self.overrides,
threading=threading,
log_stats=self.log_stats,
is_inner=is_inner
)

# if the data is a mapping (dictionary), append the OneOf transform to the end
Expand Down Expand Up @@ -677,7 +707,15 @@ def __init__(
super().__init__(transforms, map_items, unpack_items, log_stats, lazy, overrides)
self.log_stats = log_stats

def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None = None):
def __call__(
self,
input_,
start=0,
end=None,
threading=False,
lazy: bool | None = None,
is_inner: bool = False
):
if start != 0:
raise ValueError(f"RandomOrder requires 'start' parameter to be 0 (start set to {start})")
if end is not None:
Expand All @@ -700,6 +738,7 @@ def __call__(self, input_, start=0, end=None, threading=False, lazy: bool | None
lazy=_lazy,
threading=threading,
log_stats=self.log_stats,
is_inner=is_inner
)

# if the data is a mapping (dictionary), append the RandomOrder transform to the end
Expand Down Expand Up @@ -843,7 +882,15 @@ def _normalize_probabilities(self, weights):

return ensure_tuple(list(weights))

def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None):
def __call__(
self,
data,
start=0,
end=None,
threading=False,
lazy: bool | None = None,
is_inner: bool = False
):
if start != 0:
raise ValueError(f"SomeOf requires 'start' parameter to be 0 (start set to {start})")
if end is not None:
Expand All @@ -867,6 +914,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None =
overrides=self.overrides,
threading=threading,
log_stats=self.log_stats,
is_inner=is_inner
)
if isinstance(data, monai.data.MetaTensor):
self.push_transform(data, extra_info={"applied_order": applied_order})
Expand Down
12 changes: 9 additions & 3 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,20 @@ def _apply_transform(
Returns:
ReturnType: The return type of `transform`.
"""
from monai.transforms.compose import Compose
from monai.transforms.lazy.functional import apply_pending_transforms_in_order

data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name)

kwargs = {}
if isinstance(transform, LazyTrait):
kwargs['lazy'] = lazy
if isinstance(transform, Compose):
kwargs['is_inner'] = True
if isinstance(data, tuple) and unpack_parameters:
return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
return transform(*data, **kwargs)

return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
return transform(data, **kwargs)


def apply_transform(
Expand All @@ -105,7 +111,7 @@ def apply_transform(
unpack_items: bool = False,
log_stats: bool | str = False,
lazy: bool | None = None,
overrides: dict | None = None,
overrides: dict | None = None
) -> list[ReturnType] | ReturnType:
"""
Transform `data` with `transform`.
Expand Down
Loading

0 comments on commit 9c99bb1

Please sign in to comment.