Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using pure transforms for image and geometry data #7403

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@
from .lazy.array import ApplyPending
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
from .lazy.functional import apply_pending
from .lazy.utils import combine_transforms, resample
from .lazy.utils import combine_transforms, resample_image
from .meta_utility.dictionary import (
FromMetaTensord,
FromMetaTensorD,
Expand Down
72 changes: 68 additions & 4 deletions monai/transforms/lazy/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@

from __future__ import annotations

from typing import Any, Mapping, Sequence
from typing import Any, Mapping, Sequence, Tuple

import copy

import torch

from monai.apps.utils import get_logger
from monai.config import NdarrayOrTensor
from monai.data.meta_tensor import MetaTensor
from monai.data.meta_obj import get_track_meta
from monai.data.utils import to_affine_nd
from monai.transforms.lazy.utils import (
affine_from_pending,
combine_transforms,
is_compatible_apply_kwargs,
kwargs_from_pending,
resample,
resample_image,
resample_points,
)
from monai.transforms.traits import LazyTrait
from monai.transforms.transform import MapTransform
Expand Down Expand Up @@ -80,6 +84,63 @@ def _log_applied_info(data: Any, key=None, logger_name: bool | str = False):
logger.info(f"Pending transforms applied: {key_str}applied_operations: {len(data.applied_operations)}")


def lazily_apply_op(
    tensor,
    op,
    lazy_evaluation,
    track_meta=True
) -> MetaTensor | tuple[torch.Tensor, dict | None]:
    """
    Apply ``op`` to ``tensor``, either eagerly or by deferring it as a pending operation.

    This function is intended for use only by developers of spatial functional transforms that
    can be lazily executed.

    This function will immediately apply the op to the given tensor if `lazy_evaluation` is set to
    False. Its precise behaviour depends on whether it is passed a Tensor or MetaTensor:

    If passed a Tensor, `lazily_apply_op` returns a tuple of Tensor and operation description:
        - if `lazy_evaluation` is False, the transformed tensor and op is returned
        - if `lazy_evaluation` is True, the tensor and op is returned

    If passed a MetaTensor, only the tensor itself is returned:
        - if `lazy_evaluation` is False, the transformed tensor is returned, with the op added to
          the applied operations
        - if `lazy_evaluation` is True, the untransformed tensor is returned, with the op added to
          the pending operations

    Args:
        tensor: the tensor to have the operation lazily applied to
        op: the operation description containing the transform and metadata
        lazy_evaluation: a boolean flag indicating whether to apply the operation lazily
        track_meta: forwarded to ``apply_pending`` when the op is resolved immediately;
            controls whether metadata is tracked during resolution

    Returns:
        a ``MetaTensor`` (possibly transformed), or a plain tensor optionally paired with
        ``op`` — the pairing only happens when global meta tracking is enabled.
    """
    if isinstance(tensor, MetaTensor):
        # MetaTensor carries its own pending-operation stack, so the op is always pushed
        tensor.push_pending_operation(op)
        if lazy_evaluation:
            return tensor
        # eager path: resolve all pending operations now
        response = apply_pending(tensor, track_meta=track_meta)
        result, _ = response if isinstance(response, tuple) else (response, None)
        return result
    # plain tensors cannot record pending ops, so the op travels alongside the tensor
    if lazy_evaluation:
        return (tensor, op) if get_track_meta() else tensor
    response = apply_pending(tensor, [op], track_meta=track_meta)
    result, _ = response if isinstance(response, tuple) else (response, None)
    return (result, op) if get_track_meta() else result


def invert(
    data: torch.Tensor | MetaTensor,
    lazy_evaluation=True
):
    """
    Invert the most recently applied operation on ``data``.

    The last entry of ``data.applied_operations`` is popped, a deep copy of it is
    inverted, and that inverse is then applied (possibly lazily) via ``lazily_apply_op``.

    Args:
        data: a tensor carrying an ``applied_operations`` stack (typically a ``MetaTensor``)
        lazy_evaluation: whether to defer execution of the inverse operation

    Returns:
        the value returned by ``lazily_apply_op`` for the inverted operation
    """
    metadata = data.applied_operations.pop()
    # deep copy so that inverting the op does not mutate the popped record
    inv_metadata = copy.deepcopy(metadata)
    inv_metadata.invert()
    return lazily_apply_op(data, inv_metadata, lazy_evaluation, track_meta=False)


def apply_pending_transforms(
data: NdarrayOrTensor | Sequence[Any | NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
keys: tuple | None,
Expand Down Expand Up @@ -279,7 +340,7 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
# carry out an intermediate resample here due to incompatibility between arguments
_cur_kwargs = cur_kwargs.copy()
_cur_kwargs.update(override_kwargs)
data = resample(data.to(device), cumulative_xform, _cur_kwargs)
data = resample_image(data.to(device), cumulative_xform, _cur_kwargs)

next_matrix = affine_from_pending(p)
if next_matrix.shape[0] == 3:
Expand All @@ -288,7 +349,10 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
cur_kwargs.update(new_kwargs)
cur_kwargs.update(override_kwargs)
data = resample(data.to(device), cumulative_xform, cur_kwargs)
if data.kind() == 'pixel':
data = resample_image(data.to(device), cumulative_xform, cur_kwargs)
elif data.kind() == 'point':
data = resample_points(data.to(device), cumulative_xform, cur_kwargs)
if isinstance(data, MetaTensor):
for p in pending:
data.push_applied_operation(p)
Expand Down
65 changes: 8 additions & 57 deletions monai/transforms/lazy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,65 +19,11 @@
import monai
from monai.config import NdarrayOrTensor
from monai.data.utils import AFFINE_TOL
from monai.transforms.utils import Affine
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import LazyAttr, convert_to_numpy, convert_to_tensor, look_up_option

__all__ = ["resample", "combine_transforms"]


class Affine:
    """Lightweight wrapper marking its payload as an affine transform matrix."""

    __slots__ = ("data",)

    def __init__(self, data):
        # the wrapped matrix (or array-like)
        self.data = data

    @staticmethod
    def is_affine_shaped(data):
        """Return whether ``data`` can be treated as an affine matrix."""
        if isinstance(data, Affine):
            return True
        if isinstance(data, DisplacementField):
            return False
        if not hasattr(data, "shape"):
            return False
        shape = data.shape
        if len(shape) < 2:
            return False
        # affine matrices are square with a trailing dimension of 3 or 4
        return shape[-1] in (3, 4) and shape[-1] == shape[-2]


class DisplacementField:
    """Lightweight wrapper marking its payload as a dense displacement field."""

    __slots__ = ("data",)

    def __init__(self, data):
        # the wrapped displacement-field array
        self.data = data

    @staticmethod
    def is_ddf_shaped(data):
        """Return whether ``data`` can be treated as a dense displacement field (DDF)."""
        if isinstance(data, DisplacementField):
            return True
        if isinstance(data, Affine):
            return False
        if not hasattr(data, "shape"):
            return False
        if len(data.shape) < 3:
            return False
        # anything with >= 3 dims that is not affine-shaped is taken to be a DDF
        return not Affine.is_affine_shaped(data)


def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
    """Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)"""
    if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right):
        # linear case: composition of affines is the matrix product
        lhs = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True)
        rhs = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True)
        return torch.matmul(lhs, rhs)
    is_ddf = DisplacementField.is_ddf_shaped
    if is_ddf(left) and is_ddf(right):
        # dense displacement fields compose by addition
        # (open question from the original author: is metadata needed for metatensor input?)
        lhs = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True)
        rhs = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True)
        return lhs + rhs
    raise NotImplementedError
__all__ = ["resample_image", "combine_transforms"]


def affine_from_pending(pending_item):
Expand Down Expand Up @@ -145,7 +91,7 @@ def requires_interp(matrix, atol=AFFINE_TOL):
__override_lazy_keywords = {*list(LazyAttr), "atol"}


def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
def resample_image(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
"""
Resample `data` using the affine transformation defined by ``matrix``.

Expand Down Expand Up @@ -227,3 +173,8 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None =
resampler.lazy = False # resampler is a lazytransform
with resampler.trace_transform(False): # don't track this transform in `img`
return resampler(img=img, **call_kwargs)


def resample_points(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
    """
    Resample point (geometry) data using the affine transformation defined by ``matrix``.

    Counterpart of ``resample_image`` for non-pixel ("point") data: point coordinates
    would be transformed directly rather than interpolated on a voxel grid.

    Args:
        data: the point data to be transformed
        matrix: the affine transformation to apply
        kwargs: optional keyword arguments describing how the transform is applied

    Raises:
        NotImplementedError: point resampling is not yet implemented.
    """
    # Handle all point resampling here
    raise NotImplementedError()
Loading