From 796271ccd4ee48065af3b2afd56bf46960a53112 Mon Sep 17 00:00:00 2001 From: Hsin-Tong Hsieh <75067139+ctongh@users.noreply.github.com> Date: Sun, 13 Oct 2024 05:55:46 +0800 Subject: [PATCH] Add options to skip operations for RestoreLabeld Transform (#8125) Fixes #6380 ### Description Four new bool parameters are added into `RestoreLabeld` to allow users to selectively enable or disable each restoration operation as needed, and a corresponding test case is added to verify that the function runs correctly. This design allows users to selectively enable or disable each restoration operation as needed, providing greater flexibility. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Hsin Tong Signed-off-by: Hsin-Tong Hsieh <75067139+ctongh@users.noreply.github.com> Signed-off-by: kbbbbkb <139567836+Kbinn@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: kbbbbkb <139567836+Kbinn@users.noreply.github.com> --- monai/apps/deepgrow/transforms.py | 69 ++++++++++++++-------- tests/test_deepgrow_transforms.py | 95 ++++++++++++++++++++++++++++++- 2 files changed, 140 insertions(+), 24 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index c2f97091fd..721c0db489 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -803,6 +803,14 @@ class RestoreLabeld(MapTransform): original_shape_key: key that records original shape for foreground. cropped_shape_key: key that records cropped shape for foreground. allow_missing_keys: don't raise exception if key is missing. + restore_resizing: used to enable or disable resizing restoration, default is True. + If True, the transform will resize the items back to its original shape. + restore_cropping: used to enable or disable cropping restoration, default is True. + If True, the transform will restore the items to its uncropped size. + restore_spacing: used to enable or disable spacing restoration, default is True. + If True, the transform will resample the items back to the spacing it had before being altered. + restore_slicing: used to enable or disable slicing restoration, default is True. + If True, the transform will reassemble the full volume by restoring the slices to their original positions. """ def __init__( @@ -819,6 +827,10 @@ def __init__( original_shape_key: str = "foreground_original_shape", cropped_shape_key: str = "foreground_cropped_shape", allow_missing_keys: bool = False, + restore_resizing: bool = True, + restore_cropping: bool = True, + restore_spacing: bool = True, + restore_slicing: bool = True, ) -> None: super().__init__(keys, allow_missing_keys) self.ref_image = ref_image @@ -833,6 +845,10 @@ def __init__( self.end_coord_key = end_coord_key self.original_shape_key = original_shape_key self.cropped_shape_key = cropped_shape_key + self.restore_resizing = restore_resizing + self.restore_cropping = restore_cropping + self.restore_spacing = restore_spacing + self.restore_slicing = restore_slicing def __call__(self, data: Any) -> dict: d = dict(data) @@ -842,38 +858,45 @@ def __call__(self, data: Any) -> dict: image = d[key] # Undo Resize - current_shape = image.shape - cropped_shape = meta_dict[self.cropped_shape_key] - if np.any(np.not_equal(current_shape, cropped_shape)): - resizer = Resize(spatial_size=cropped_shape[1:], mode=mode) - image = resizer(image, mode=mode, align_corners=align_corners) + if self.restore_resizing: + current_shape = image.shape + cropped_shape = meta_dict[self.cropped_shape_key] + if np.any(np.not_equal(current_shape, cropped_shape)): + resizer = Resize(spatial_size=cropped_shape[1:], mode=mode) + image = resizer(image, mode=mode, align_corners=align_corners) # Undo Crop - original_shape = meta_dict[self.original_shape_key] - result = np.zeros(original_shape, dtype=np.float32) - box_start = meta_dict[self.start_coord_key] - box_end = meta_dict[self.end_coord_key] - - spatial_dims = min(len(box_start), len(image.shape[1:])) - slices = tuple( - [slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])] - ) - result[slices] = image + if self.restore_cropping: + original_shape = meta_dict[self.original_shape_key] + result = np.zeros(original_shape, dtype=np.float32) + box_start = meta_dict[self.start_coord_key] + box_end = meta_dict[self.end_coord_key] + + spatial_dims = min(len(box_start), len(image.shape[1:])) + slices = tuple( + [slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])] + ) + result[slices] = image + else: + result = image # Undo Spacing - current_size = result.shape[1:] - # change spatial_shape from HWD to DHW - spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1)) - spatial_size = spatial_shape[-len(current_size) :] + if self.restore_spacing: + current_size = result.shape[1:] + # change spatial_shape from HWD to DHW + spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1)) + spatial_size = spatial_shape[-len(current_size) :] - if np.any(np.not_equal(current_size, spatial_size)): - resizer = Resize(spatial_size=spatial_size, mode=mode) - result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore + if np.any(np.not_equal(current_size, spatial_size)): + resizer = Resize(spatial_size=spatial_size, mode=mode) + result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore # Undo Slicing slice_idx = meta_dict.get("slice_idx") final_result: NdarrayOrTensor - if slice_idx is None or self.slice_only: + if not self.restore_slicing: # do nothing if restore slicing isn't requested + final_result = result + elif slice_idx is None or self.slice_only: final_result = result if len(result.shape) <= 3 else result[0] else: slice_idx = meta_dict["slice_idx"][0] diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py index a491a8004b..091d00afcd 100644 --- a/tests/test_deepgrow_transforms.py +++ b/tests/test_deepgrow_transforms.py @@ -141,6 +141,21 @@ DATA_12 = {"image": np.arange(27).reshape(3, 3, 3), PostFix.meta("image"): {}, "guidance": [[0, 0, 0], [0, 1, 1], 1]} +DATA_13 = { + "image": np.arange(64).reshape((1, 4, 4, 4)), + PostFix.meta("image"): { + "spatial_shape": [8, 8, 4], + "foreground_start_coord": np.array([1, 1, 1]), + "foreground_end_coord": np.array([3, 3, 3]), + "foreground_original_shape": (1, 4, 4, 4), + "foreground_cropped_shape": (1, 2, 2, 2), + "original_affine": np.array( + [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]] + ), + }, + "pred": np.array([[[[10, 20], [30, 40]], [[50, 60], [70, 80]]]]), +} + FIND_SLICE_TEST_CASE_1 = [{"label": "label", "sids": "sids"}, DATA_1, [0]] FIND_SLICE_TEST_CASE_2 = [{"label": "label", "sids": "sids"}, DATA_2, [0, 1]] @@ -329,6 +344,74 @@ RESTORE_LABEL_TEST_CASE_2 = [{"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, DATA_11, RESULT] +RESTORE_LABEL_TEST_CASE_3_RESULT = np.zeros((10, 20, 20)) +RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 0:10] = 1 +RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 10:20] = 2 +RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 0:10] = 3 +RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 10:20] = 4 +RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 0:10] = 5 +RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 10:20] = 6 +RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 0:10] = 7 +RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 10:20] = 8 + +RESTORE_LABEL_TEST_CASE_3 = [ + {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_cropping": False}, + DATA_11, + RESTORE_LABEL_TEST_CASE_3_RESULT, +] + +RESTORE_LABEL_TEST_CASE_4_RESULT = np.zeros((4, 8, 8)) +RESTORE_LABEL_TEST_CASE_4_RESULT[1, 2:6, 2:6] = np.array( + [[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]] +) +RESTORE_LABEL_TEST_CASE_4_RESULT[2, 2:6, 2:6] = np.array( + [[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]] +) + +RESTORE_LABEL_TEST_CASE_4 = [ + {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_resizing": False}, + DATA_13, + RESTORE_LABEL_TEST_CASE_4_RESULT, +] + +RESTORE_LABEL_TEST_CASE_5_RESULT = np.zeros((4, 4, 4)) +RESTORE_LABEL_TEST_CASE_5_RESULT[1, 1:3, 1:3] = np.array([[10.0, 20.0], [30.0, 40.0]]) +RESTORE_LABEL_TEST_CASE_5_RESULT[2, 1:3, 1:3] = np.array([[50.0, 60.0], [70.0, 80.0]]) + +RESTORE_LABEL_TEST_CASE_5 = [ + {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_spacing": False}, + DATA_13, + RESTORE_LABEL_TEST_CASE_5_RESULT, +] + +RESTORE_LABEL_TEST_CASE_6_RESULT = np.zeros((1, 4, 8, 8)) +RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 1, 2:6, 2:6] = np.array( + [[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]] +) +RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 2, 2:6, 2:6] = np.array( + [[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]] +) + +RESTORE_LABEL_TEST_CASE_6 = [ + {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_slicing": False}, + DATA_13, + RESTORE_LABEL_TEST_CASE_6_RESULT, +] + +RESTORE_LABEL_TEST_CASE_7 = [ + { + "keys": ["pred"], + "ref_image": "image", + "mode": "nearest", + "restore_resizing": False, + "restore_cropping": False, + "restore_spacing": False, + "restore_slicing": False, + }, + DATA_11, + np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]), +] + FETCH_2D_SLICE_TEST_CASE_1 = [ {"keys": ["image"], "guidance": "guidance"}, DATA_12, @@ -445,7 +528,17 @@ def test_correct_results(self, arguments, input_data, expected_result): class TestRestoreLabeld(unittest.TestCase): - @parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2]) + @parameterized.expand( + [ + RESTORE_LABEL_TEST_CASE_1, + RESTORE_LABEL_TEST_CASE_2, + RESTORE_LABEL_TEST_CASE_3, + RESTORE_LABEL_TEST_CASE_4, + RESTORE_LABEL_TEST_CASE_5, + RESTORE_LABEL_TEST_CASE_6, + RESTORE_LABEL_TEST_CASE_7, + ] + ) def test_correct_results(self, arguments, input_data, expected_result): result = RestoreLabeld(**arguments)(input_data) np.testing.assert_allclose(result["pred"], expected_result)