Skip to content

Commit

Permalink
[fix] RandCropByLabelClasses now uses image size to correct its gener…
Browse files Browse the repository at this point in the history
…ated centers
  • Loading branch information
ltetrel committed Nov 23, 2023
1 parent 8e134b8 commit 599aa20
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
6 changes: 4 additions & 2 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,7 @@ def randomize(
label: torch.Tensor | None = None,
indices: list[NdarrayOrTensor] | None = None,
image: torch.Tensor | None = None,
img: torch.Tensor | None = None,
) -> None:
indices_ = self.indices if indices is None else indices
if indices_ is None:
Expand All @@ -1344,8 +1345,9 @@ def randomize(
_shape = image.peek_pending_shape() if isinstance(image, MetaTensor) else image.shape[1:]
if _shape is None:
raise ValueError("label or image must be provided to infer the output spatial shape.")
img_shape = img.peek_pending_shape() if isinstance(image, MetaTensor) else img.shape[1:]
self.centers = generate_label_classes_crop_centers(
self.spatial_size, self.num_samples, _shape, indices_, self.ratios, self.R, self.allow_smaller, self.warn
self.spatial_size, self.num_samples, _shape, indices_, img_shape, self.ratios, self.R, self.allow_smaller, self.warn
)

@LazyTransform.lazy.setter # type: ignore
Expand Down Expand Up @@ -1381,7 +1383,7 @@ def __call__(
if randomize:
if label is None:
label = self.label
self.randomize(label, indices, image)
self.randomize(label, indices, image, img)
results: list[torch.Tensor] = []
if self.centers is not None:
img_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
Expand Down
9 changes: 8 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,7 @@ def generate_label_classes_crop_centers(
num_samples: int,
label_spatial_shape: Sequence[int],
indices: Sequence[NdarrayOrTensor],
img_spatial_shape: Sequence[int] | None = None,
ratios: list[float | int] | None = None,
rand_state: np.random.RandomState | None = None,
allow_smaller: bool = False,
Expand All @@ -600,6 +601,7 @@ def generate_label_classes_crop_centers(
spatial_size: spatial size of the ROIs to be sampled.
num_samples: total sample centers to be generated.
label_spatial_shape: spatial shape of the original label data to unravel selected centers.
img_spatial_shape: spatial shape of the original image data to correct crop centers.
indices: sequence of pre-computed foreground indices of every class in 1 dimension.
ratios: ratios of every class in the label to generate crop centers, including background class.
if None, every class will have the same ratio to generate crop centers.
Expand All @@ -613,6 +615,11 @@ def generate_label_classes_crop_centers(
if rand_state is None:
rand_state = np.random.random.__self__ # type: ignore

if img_spatial_shape is None:
img_spatial_shape = label_spatial_shape
if warn:
warnings.warn(f"img_spatial_shape not defined, expect invalid shape for samples.")

if num_samples < 1:
raise ValueError(f"num_samples must be an int number and greater than 0, got {num_samples}.")
ratios_: list[float | int] = list(ensure_tuple([1] * len(indices) if ratios is None else ratios))
Expand All @@ -637,7 +644,7 @@ def generate_label_classes_crop_centers(
random_int = rand_state.randint(len(indices_to_use))
center = unravel_index(indices_to_use[random_int], label_spatial_shape).tolist()
# shift center to range of valid centers
centers.append(correct_crop_centers(center, spatial_size, label_spatial_shape, allow_smaller))
centers.append(correct_crop_centers(center, spatial_size, img_spatial_shape, allow_smaller))

return ensure_tuple(centers)

Expand Down

0 comments on commit 599aa20

Please sign in to comment.