forked from Project-MONAI/MONAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Expose ITK Image to MONAI MetaTensor conversion (Project-MONAI#5897)
Fixes Project-MONAI#5708 Fixes Project-MONAI#4117 ### Description This is a migration of the PR (by @ntatsisk https://github.com/ntatsisk) InsightSoftwareConsortium/itk-torch-transform-bridge#6 into MONAI. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Felix Schnabel <[email protected]>
- Loading branch information
1 parent
f5708ea
commit 2a8c8cd
Showing
7 changed files
with
861 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,338 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, cast | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from monai.config.type_definitions import DtypeLike | ||
from monai.data import ITKReader, ITKWriter | ||
from monai.data.meta_tensor import MetaTensor | ||
from monai.transforms import EnsureChannelFirst | ||
from monai.utils import convert_to_dst_type, optional_import | ||
|
||
if TYPE_CHECKING: | ||
import itk | ||
|
||
has_itk = True | ||
else: | ||
itk, has_itk = optional_import("itk") | ||
|
||
__all__ = [ | ||
"itk_image_to_metatensor", | ||
"metatensor_to_itk_image", | ||
"itk_to_monai_affine", | ||
"monai_to_itk_affine", | ||
"get_itk_image_center", | ||
"monai_to_itk_ddf", | ||
] | ||
|
||
|
||
def itk_image_to_metatensor(
    image, channel_dim: str | int | None = None, dtype: DtypeLike | torch.dtype = float
) -> MetaTensor:
    """
    Convert an ITK image into a MONAI MetaTensor in ChannelFirst layout.

    Args:
        image: the ITK image to convert.
        channel_dim: channel dimension of the input image, default is None.
            Stored as original_channel_dim in the metadata, which EnsureChannelFirst reads.
            When None, the channel dimension is inferred automatically.
            If the input array has no channel dim, pass ``'no_channel'``.
        dtype: output dtype, defaults to the Python built-in `float`.

    Returns:
        A MetaTensor holding the array data and metadata in ChannelFirst format.
    """
    itk_reader = ITKReader(affine_lps_to_ras=False, channel_dim=channel_dim)
    array, meta = itk_reader.get_data(image)
    # Convert in place to the requested dtype while keeping the array's container type.
    array = convert_to_dst_type(array, dst=array, dtype=dtype)[0]
    tensor = MetaTensor.ensure_torch_and_prune_meta(array, meta)
    tensor = EnsureChannelFirst(channel_dim=channel_dim)(tensor)
    return cast(MetaTensor, tensor)
|
||
|
||
def metatensor_to_itk_image(
    meta_tensor: MetaTensor, channel_dim: int | None = 0, dtype: DtypeLike = np.float32, **kwargs
):
    """
    Convert a MetaTensor (expected in ChannelFirst format) into an ITK image.

    Args:
        meta_tensor: the MetaTensor to convert.
        channel_dim: channel dimension of the data array, defaults to ``0`` (Channel-first).
            ``None`` indicates no channel dimension; a non-``None`` value creates a Vector Image.
        dtype: output data type, defaults to `np.float32`.
        kwargs: additional keyword arguments; `itk.GetImageFromArray` reads ``ttype`` from here.

    Returns:
        The ITK image.

    See also: :py:func:`ITKWriter.create_backend_obj`
    """
    itk_writer = ITKWriter(output_dtype=dtype, affine_lps_to_ras=False)
    itk_writer.set_data_array(data_array=meta_tensor.data, channel_dim=channel_dim, squeeze_end_dims=True)
    return itk_writer.create_backend_obj(
        itk_writer.data_obj,
        channel_dim=itk_writer.channel_dim,
        affine=meta_tensor.affine,
        affine_lps_to_ras=False,  # False if the affine is in itk convention
        dtype=itk_writer.output_dtype,
        kwargs=kwargs,
    )
|
||
|
||
def itk_to_monai_affine(image, matrix, translation, center_of_rotation=None, reference_image=None) -> torch.Tensor:
    """
    Convert an ITK affine transform (2x2 for 2D or 3x3 for 3D matrix plus a
    translation vector) into a MONAI affine matrix.

    Args:
        image: the ITK image object; its spacing and direction are folded into the result.
        matrix: the 2x2 or 3x3 ITK affine matrix.
        translation: the 2-element or 3-element ITK affine translation vector.
        center_of_rotation: optional center of rotation. When given, the affine is
            adjusted for the difference between the center of the image and the
            center of rotation.
        reference_image: the coordinate space that matrix and translation were defined
            with respect to. When omitted, the coordinate space of ``image`` is used.

    Returns:
        A 4x4 MONAI affine matrix.
    """
    _assert_itk_regions_match_array(image)
    ndim = image.ndim

    # With a reference image, first build the affine mapping image space onto
    # the reference image space; otherwise the identity.
    if reference_image:
        reference_affine = _compute_reference_space_affine_matrix(image, reference_image)
    else:
        reference_affine = torch.eye(ndim + 1, dtype=torch.float64)

    # Homogeneous affine carrying both the linear part and the translation.
    affine = torch.eye(ndim + 1, dtype=torch.float64)
    affine[:ndim, :ndim] = torch.tensor(matrix, dtype=torch.float64)
    affine[:ndim, ndim] = torch.tensor(translation, dtype=torch.float64)

    # Re-center when the rotation center differs from the image center.
    if center_of_rotation:
        offset, inverse_offset = _compute_offset_matrix(image, center_of_rotation)
        affine = inverse_offset @ affine @ offset

    # Fold in the direction cosines.
    direction, inverse_direction = _compute_direction_matrix(image)
    affine = inverse_direction @ affine @ direction

    # Fold in the spacing. Required because MONAI does not update the pixel
    # array according to the spacing after a transformation: e.g. a 90deg
    # rotation of an image with anisotropic spacing would just rotate the
    # array without also scaling accordingly.
    spacing, inverse_spacing = _compute_spacing_matrix(image)
    affine = inverse_spacing @ affine @ spacing

    return affine @ reference_affine
|
||
|
||
def monai_to_itk_affine(image, affine_matrix, center_of_rotation=None):
    """
    Convert a MONAI affine matrix into an ITK affine transform (2x2 for 2D or
    3x3 for 3D matrix plus a translation vector). See also 'itk_to_monai_affine'.

    Args:
        image: the ITK image object; its spacing and direction are factored out.
        affine_matrix: the 3x3 (2D) or 4x4 (3D) MONAI affine matrix.
        center_of_rotation: optional center of rotation. When given, the affine is
            adjusted for the difference between the center of the image and the
            center of rotation.

    Returns:
        The ITK matrix and the translation vector.
    """
    _assert_itk_regions_match_array(image)

    # Factor the spacing back out (inverse of the adjustment in itk_to_monai_affine).
    spacing, inverse_spacing = _compute_spacing_matrix(image)
    affine_matrix = spacing @ affine_matrix @ inverse_spacing

    # Factor the direction cosines back out.
    direction, inverse_direction = _compute_direction_matrix(image)
    affine_matrix = direction @ affine_matrix @ inverse_direction

    # Undo the re-centering applied for a non-default rotation center.
    if center_of_rotation:
        offset, inverse_offset = _compute_offset_matrix(image, center_of_rotation)
        affine_matrix = offset @ affine_matrix @ inverse_offset

    ndim = image.ndim
    return affine_matrix[:ndim, :ndim].numpy(), affine_matrix[:ndim, ndim].tolist()
|
||
|
||
def get_itk_image_center(image):
    """
    Compute the center of an ITK image from its origin, size, spacing and direction.
    This matches the implicit image center that MONAI uses.

    Args:
        image: the ITK image.

    Returns:
        The center of the image as a list of coordinates.
    """
    # Half-extent of the voxel grid, measured in index space (center of the grid).
    half_extent = np.asarray(image.GetLargestPossibleRegion().GetSize(), np.float32) / 2 - 0.5
    scaled = half_extent * np.asarray(image.GetSpacing())
    # Map the grid center into physical space via the direction cosines and origin.
    world_center = image.GetDirection() @ scaled + np.asarray(image.GetOrigin())
    return world_center.tolist()
|
||
|
||
def _assert_itk_regions_match_array(image):
    """
    Check that all three ITK regions are zero-indexed and that their sizes match
    the backing array's shape; raise AssertionError otherwise.
    """
    regions = [
        image.GetLargestPossibleRegion(),
        image.GetBufferedRegion(),
        image.GetRequestedRegion(),
    ]
    region_indices = [np.array(r.GetIndex()) for r in regions]
    region_sizes = [np.array(r.GetSize()) for r in regions]
    # The array is indexed in reverse axis order relative to ITK region sizes.
    expected_size = np.array(image.shape)[::-1]

    if not all(np.all(index == 0) for index in region_indices):
        raise AssertionError("ITK-MONAI bridge: non-zero ITK region indices encountered")
    if not all(np.array_equal(expected_size, size) for size in region_sizes):
        raise AssertionError("ITK-MONAI bridge: ITK regions should be of the same shape")
|
||
|
||
def _compute_offset_matrix(image, center_of_rotation) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the homogeneous translation matrices (forward and inverse) for the
    offset between the image center and the given center of rotation."""
    ndim = image.ndim
    offset = np.asarray(get_itk_image_center(image)) - np.asarray(center_of_rotation)
    forward = torch.eye(ndim + 1, dtype=torch.float64)
    backward = torch.eye(ndim + 1, dtype=torch.float64)
    forward[:ndim, ndim] = torch.tensor(offset, dtype=torch.float64)
    backward[:ndim, ndim] = -torch.tensor(offset, dtype=torch.float64)
    return forward, backward
|
||
|
||
def _compute_spacing_matrix(image) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the homogeneous scaling matrices (forward and inverse) for the image spacing."""
    ndim = image.ndim
    diagonal = torch.tensor(np.asarray(image.GetSpacing(), dtype=np.float64), dtype=torch.float64)
    scale = torch.eye(ndim + 1, dtype=torch.float64)
    inverse_scale = torch.eye(ndim + 1, dtype=torch.float64)
    scale[:ndim, :ndim] = torch.diag(diagonal)
    inverse_scale[:ndim, :ndim] = torch.diag(1.0 / diagonal)
    return scale, inverse_scale
|
||
|
||
def _compute_direction_matrix(image) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the homogeneous matrices for the image direction cosines and their inverse."""
    ndim = image.ndim
    forward = torch.eye(ndim + 1, dtype=torch.float64)
    forward[:ndim, :ndim] = torch.tensor(itk.array_from_matrix(image.GetDirection()), dtype=torch.float64)
    backward = torch.eye(ndim + 1, dtype=torch.float64)
    backward[:ndim, :ndim] = torch.tensor(itk.array_from_matrix(image.GetInverseDirection()), dtype=torch.float64)
    return forward, backward
|
||
|
||
def _compute_reference_space_affine_matrix(image, ref_image) -> torch.Tensor:
    """
    Build the MONAI affine matrix that maps the coordinate space of ``image``
    onto the coordinate space of ``ref_image``, accounting for the differences
    in spacing, direction and origin between the two images.
    """
    ndim = ref_image.ndim

    # Spacing and direction as plain ndim x ndim numpy matrices (drop the homogeneous row/col).
    spacing_matrix, inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in _compute_spacing_matrix(image))
    ref_spacing_matrix, ref_inv_spacing_matrix = (m[:ndim, :ndim].numpy() for m in _compute_spacing_matrix(ref_image))

    direction_matrix, inv_direction_matrix = (m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(image))
    ref_direction_matrix, ref_inv_direction_matrix = (
        m[:ndim, :ndim].numpy() for m in _compute_direction_matrix(ref_image)
    )

    # Linear part: reference direction/spacing composed with the inverse of the image's.
    matrix = ref_direction_matrix @ ref_spacing_matrix @ inv_spacing_matrix @ inv_direction_matrix

    # Offset calculation: align the centers of the two voxel grids
    # (size - 1 index steps span the grid, hence the -1 pixel offset).
    pixel_offset = -1
    image_size = np.asarray(ref_image.GetLargestPossibleRegion().GetSize(), np.float32)
    translation = (
        (ref_direction_matrix @ ref_spacing_matrix - direction_matrix @ spacing_matrix)
        @ (image_size + pixel_offset)
        / 2
    )
    translation += np.asarray(ref_image.GetOrigin()) - np.asarray(image.GetOrigin())

    # Convert the ITK matrix and translation to a MONAI affine matrix.
    ref_affine_matrix = itk_to_monai_affine(image, matrix=matrix, translation=translation)

    return ref_affine_matrix
|
||
|
||
def monai_to_itk_ddf(image, ddf):
    """
    Convert a dense displacement field (DDF) from MONAI space to an ITK
    displacement field image.

    Args:
        image: itk image of array shape 2D: (H, W) or 3D: (D, H, W)
        ddf: numpy array of shape 2D: (2, H, W) or 3D: (3, D, H, W)

    Returns:
        displacement_field: itk image of the corresponding displacement field
    """
    ndim = image.ndim
    # Move the vector-component axis last: (3, D, H, W) -> (D, H, W, 3).
    ddf = ddf.transpose(tuple(list(range(1, ndim + 1)) + [0]))
    # Reverse the component order: (x, y, z) -> (z, y, x).
    ddf = ddf[..., ::-1]

    # Correct for spacing. NOTE: the transpose/slice above return *views* into the
    # caller's array, so the original in-place `*=` silently mutated the input ddf;
    # use an out-of-place multiply to leave the caller's array untouched.
    spacing = np.asarray(image.GetSpacing(), dtype=np.float64)
    ddf = ddf * np.array(spacing, ndmin=ndim + 1)

    # Correct for direction
    direction = np.asarray(image.GetDirection(), dtype=np.float64)
    ddf = np.einsum("ij,...j->...i", direction, ddf, dtype=np.float64).astype(np.float32)

    # Initialise the displacement field image type: one float vector per pixel.
    vector_component_type = itk.F
    vector_pixel_type = itk.Vector[vector_component_type, ndim]
    displacement_field_type = itk.Image[vector_pixel_type, ndim]
    displacement_field = itk.GetImageFromArray(ddf, ttype=displacement_field_type)

    # Copy the spatial metadata from the source image.
    displacement_field.SetSpacing(image.GetSpacing())
    displacement_field.SetOrigin(image.GetOrigin())
    displacement_field.SetDirection(image.GetDirection())

    return displacement_field
Oops, something went wrong.