Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Geometric transform implementations #7955

Draft
wants to merge 31 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
00436d0
Add `kind` property in `MetaTensor` (#7488)
KumoLiu Apr 12, 2024
b2ab07e
Adding apply_to_geometry function
atbenmurray Jul 11, 2024
7c7ec30
Adding load_geometry function
atbenmurray Jul 11, 2024
8077759
Adding missing import for apply_to_geometry
atbenmurray Jul 11, 2024
7edb44e
Adding KindKeys to __all__ in monai.utils.enums.py
atbenmurray Jul 11, 2024
61cb5c3
Adding tests for apply_to_geometry
atbenmurray Jul 11, 2024
1e83ec8
Adding apply_to_geometry to __all__ in lazy/functional.py
atbenmurray Jul 11, 2024
e503d6a
Adding flip functionality for geometric_tensors
atbenmurray Jul 11, 2024
f345fc9
Adding KindKeys to utils/__init__
atbenmurray Jul 11, 2024
edc4bb7
Adding | None to meta_info for apply_to_geometry
atbenmurray Jul 11, 2024
b8baae9
Adding resize functionality for geometric tensors
atbenmurray Jul 11, 2024
e59b4a9
Adding rotate functionality for geometric tensors
atbenmurray Jul 11, 2024
a0b768e
Fixing line endings for monai/transforms/io/functional.py
atbenmurray Jul 11, 2024
4e6ae70
Fixed rotate to that output_shape is returned. Made resize consistent…
atbenmurray Jul 19, 2024
f763c8f
Work towards geometry tests for rotate
atbenmurray Jul 19, 2024
fc8e0b8
Fixed KindKey types for flip functionality
atbenmurray Jul 19, 2024
0857ac7
Fix to handle 2d data being multiplied by a 3d transform from the met…
atbenmurray Jul 26, 2024
aa83926
Bug fixes to make all related unit tests pass
atbenmurray Jul 26, 2024
55edf77
Removed ndims (not ndim) from MetaTensor
atbenmurray Jul 26, 2024
7928f32
Adding tests for rotate / flip
atbenmurray Jul 26, 2024
1a216e9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2024
7004598
load_geometry functionality and tests
atbenmurray Aug 2, 2024
51d8d28
Resolving conflict
atbenmurray Aug 2, 2024
62f506b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2024
4740533
Adding spatial_dims_from_tensorlike function to handle point tensors
atbenmurray Aug 2, 2024
2ff1d73
Zoom modified to use spatial_dims_from_tensorlike
atbenmurray Aug 2, 2024
c2d351f
Merge branch 'geometric_all' of github.com:atbenmurray/monai into geo…
atbenmurray Aug 2, 2024
645e0c6
Added and tested point zoom functionality
atbenmurray Aug 2, 2024
f2e2426
Adding traced_no_op function for use in transforms that no-op geometr…
atbenmurray Aug 2, 2024
4898b88
Adding geometry support to Spacing transform
atbenmurray Aug 2, 2024
2cda79a
Adding save_geometry function plus test
atbenmurray Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.data.meta_obj import MetaObj, get_track_meta
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
from monai.utils import look_up_option
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
from monai.utils.enums import KindKeys, LazyAttr, MetaKeys, PostFix, SpaceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor

__all__ = ["MetaTensor"]
Expand Down Expand Up @@ -345,6 +345,10 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
def get_default_affine(dtype=torch.float64) -> torch.Tensor:
return torch.eye(4, device=torch.device("cpu"), dtype=dtype)

@staticmethod
def get_default_kind() -> str:
return KindKeys.PIXEL

def as_tensor(self) -> torch.Tensor:
"""
Return the `MetaTensor` as a `torch.Tensor`.
Expand Down Expand Up @@ -469,13 +473,30 @@ def affine(self, d: NdarrayTensor) -> None:
"""Set the affine."""
self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)

@property
def kind(self) -> str:
"""Get the data kind. Defaults to ``KindKeys.PIXEL``"""
return self.meta.get(MetaKeys.KIND, self.get_default_kind()) # type: ignore

@kind.setter
def kind(self, d: str) -> None:
"""Set the data kind."""
self.meta[MetaKeys.KIND] = d

@property
def pixdim(self):
"""Get the spacing"""
if self.is_batch:
return [affine_to_spacing(a) for a in self.affine]
return affine_to_spacing(self.affine)

# @property
# def ndims(self):
# # TODO: this will be wrong when there are batches; review
# if self.kind == KindKeys.POINT:
# return self.shape[2] - 1
# return len(self.shape)

def peek_pending_shape(self):
"""
Get the currently expected spatial shape as if all the pending operations are executed.
Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from monai.utils import GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
from monai.utils.enums import KindKeys, MetaKeys

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
Expand Down Expand Up @@ -280,6 +281,7 @@ def __call__(self, filename: Sequence[PathLike] | PathLike, reader: ImageReader

img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
meta_data[MetaKeys.KIND] = KindKeys.PIXEL
img_array = convert_to_dst_type(img_array, dst=img_array, dtype=self.dtype)[0]
if not isinstance(meta_data, dict):
raise ValueError(f"`meta_data` must be a dict, got type {type(meta_data)}.")
Expand Down
84 changes: 84 additions & 0 deletions monai/transforms/io/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import json

import numpy as np

import torch

from monai.data.meta_tensor import MetaTensor

from monai.utils.enums import KindKeys


def load_geometry(file, image, origin):
"""
Load geometry from a file and optionally map it to another coordinate space.
"""
geometry = json.load(file)
geometry_schema = geometry.get("schema", None)
if geometry_schema is None:
raise ValueError("Geometry import issue: missing 'schema' entry")
elif "geometry" not in geometry_schema:
raise ValueError(f"Geometry import issue: 'schema' entry must contain 'geometry' key, got: {geometry_schema}")

if "points" not in geometry:
raise ValueError("Geometry import issue: missing 'points' entry")

points = geometry["points"]
if not isinstance(points, list):
raise ValueError(f"Geometry import issue: 'points' entry must be a list, got: {type(points)}")

if len(points) > 0:
first_len = None
for p in points:
if first_len is None:
first_len = len(p)
if len(p) != first_len:
raise ValueError("Geometry import issue: 'points' entry contains inconsistent point lengths")

points = np.asarray(points)
points = np.concatenate((points, np.ones((points.shape[0], 1))), axis=1)
points = torch.as_tensor(points, dtype=torch.float32)
points = MetaTensor(points)
points.kind = KindKeys.POINT

return points

"""
{
"schema": {
"geometry": "point"
},
"points": [
[0, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1],
[1, 0, 0],
[1, 1, 0],
[1, 0, 1],
[1, 1, 1],
]
}
"""

def save_geometry(data, file, image, origin):
"""
Load geometry from a file and optionally map it to another coordinate space.
"""
if not isinstance(data, MetaTensor):
raise ValueError(f"Geometry export issue: data must be a MetaTensor, got: {type(data)}")
if data.kind != KindKeys.POINT:
raise ValueError(f"Geometry export issue: geometry must be a point {KindKeys.POINT}")
geometry = data.detach().cpu().numpy()
geometry = geometry[:, :-1].tolist()

schema = {
"schema": {
"geometry": "point"
},
"points":
geometry
}

geometry = json.dump(schema, file)
return None
58 changes: 56 additions & 2 deletions monai/transforms/lazy/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from monai.apps.utils import get_logger
from monai.config import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta, MetaObj
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms.lazy.utils import (
Expand All @@ -28,9 +29,9 @@
)
from monai.transforms.traits import LazyTrait
from monai.transforms.transform import MapTransform
from monai.utils import LazyAttr, look_up_option
from monai.utils import LazyAttr, MetaKeys, convert_to_tensor, look_up_option

__all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending"]
__all__ = ["apply_pending_transforms", "apply_pending_transforms_in_order", "apply_pending", "apply_to_geometry"]

__override_keywords = {"mode", "padding_mode", "dtype", "align_corners", "resample_mode", "device"}

Expand Down Expand Up @@ -293,3 +294,56 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
for p in pending:
data.push_applied_operation(p)
return data, pending


def apply_to_geometry(
data: torch.Tensor,
meta_info: dict | MetaObj | None = None,
transform: torch.Tensor | None = None,
):
"""
Apply an affine geometric transform or deformation field to geometry.
At present this is limited to the transformation of points.

The points must be provided as a tensor and must be compatible with a homogeneous
transform. This means that:
- 2D points are of the form (x, y, 1)
- 3D points are of the form (x, y, z, 1)

The affine transform or deformation field is applied to the the points and a tensor of
the same shape as the input tensor is returned.

Args:
data: the tensor of points to be transformed.
meta_info: the metadata containing the affine transformation
"""

if meta_info is None and transform is None:
raise ValueError("either meta_info or transform must be provided")
if meta_info is not None and transform is not None:
raise ValueError("only one of meta_info or transform can be provided")

if not isinstance(data, (torch.Tensor, MetaTensor)):
raise TypeError(f"data {type(data)} must be a torch.Tensor or MetaTensor")

data = convert_to_tensor(data, track_meta=get_track_meta())

if meta_info is not None:
transform_ = meta_info.meta[MetaKeys.AFFINE]
else:
transform_ = transform

if transform_.dtype != data.dtype:
transform_ = transform_.to(data.dtype)
if data.shape[-1] == 3 and transform_.shape[0] == 4:
transform_[2, 0:2] = transform_[3, 0:2]
transform_[2, 2] = transform_[3, 3]
transform_[0:2, 2] = transform_[0:2, 3]
transform_ = transform_[:-1, :-1]

if data.shape[-1] != transform_.shape[0]:
raise ValueError(f"final element of data.shape {data.shape} must match transform shape {transform_.shape}")

result = torch.matmul(data, transform_.T)

return result
14 changes: 11 additions & 3 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
)
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.transform import LazyTransform, Randomizable, RandomizableTransform, Transform
from monai.transforms.utility.functional import traced_no_op
from monai.transforms.utils import (
create_control_grid,
create_grid,
Expand All @@ -54,6 +55,7 @@
map_spatial_axes,
resolves_modes,
scale_affine,
spatial_dims_from_tensorlike,
)
from monai.transforms.utils_pytorch_numpy_unification import argsort, argwhere, linalg_inv, moveaxis
from monai.utils import (
Expand All @@ -72,7 +74,7 @@
issequenceiterable,
optional_import,
)
from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends
from monai.utils.enums import GridPatchSort, KindKeys, PatchKeys, TraceKeys, TransformBackends
from monai.utils.misc import ImageMetaKey as Key
from monai.utils.module import look_up_option
from monai.utils.type_conversion import convert_data_type, get_equivalent_dtype, get_torch_dtype_from_string
Expand Down Expand Up @@ -482,6 +484,12 @@ def __call__(
data tensor or MetaTensor (resampled into `self.pixdim`).

"""
lazy_ = self.lazy if lazy is None else lazy
if isinstance(data_array, MetaTensor) and data_array.kind == KindKeys.POINT:
warnings.warn("Spacing transform is not applied to point data.")
data_array = traced_no_op(data_array, lazy_, self.get_transform_info())
return data_array

original_spatial_shape = (
data_array.peek_pending_shape() if isinstance(data_array, MetaTensor) else data_array.shape[1:]
)
Expand Down Expand Up @@ -521,7 +529,6 @@ def __call__(
new_affine[:sr, -1] = offset[:sr]

actual_shape = list(output_shape) if output_spatial_shape is None else output_spatial_shape
lazy_ = self.lazy if lazy is None else lazy
data_array = self.sp_resample(
data_array,
dst_affine=torch.as_tensor(new_affine),
Expand Down Expand Up @@ -1089,7 +1096,8 @@ def __call__(
during initialization for this call. Defaults to None.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
_zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim
spatial_dims = spatial_dims_from_tensorlike(img)
_zoom = ensure_tuple_rep(self.zoom, spatial_dims) # match the spatial image dim
_mode = self.mode if mode is None else mode
_padding_mode = padding_mode or self.padding_mode
_align_corners = self.align_corners if align_corners is None else align_corners
Expand Down
Loading
Loading