Skip to content

Commit

Permalink
Merge pull request #32 from brainglobe/add-cropping-util
Browse files Browse the repository at this point in the history
implement a cropping+padding utility function
  • Loading branch information
alessandrofelder authored Aug 20, 2024
2 parents 9cc9611 + c3f6a66 commit afe54b1
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 0 deletions.
48 changes: 48 additions & 0 deletions brainglobe_template_builder/preproc/cropping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np


def crop_to_mask(
stack: np.ndarray, mask: np.ndarray, padding: np.uint8 = 0
) -> tuple[np.ndarray, np.ndarray]:
"""
Crop stack and mask to the mask extent, and pad with zeros.
Args:
Stack (np.ndarray): Stack
Mask (np.ndarray): Mask
padding (np.uint8):
number of pixels to pad with on all sides. Default is 0.
Returns:
tuple[np.ndarray, np.ndarray]: the cropped, padded stack and mask.
"""
assert (
stack.shape == mask.shape
), "Stack and mask must have the same shape."
assert not np.all(
mask == 0
), "The mask is invalid because it does not contain foreground."
# Find the bounding box of the mask
mask_indices = np.nonzero(mask)
min_z = np.min(mask_indices[0])
max_z = np.max(mask_indices[0])
min_y = np.min(mask_indices[1])
max_y = np.max(mask_indices[1])
min_x = np.min(mask_indices[2])
max_x = np.max(mask_indices[2])

# Crop the stack and mask to the bounding box
stack = stack[min_z : max_z + 1, min_y : max_y + 1, min_x : max_x + 1]
mask = mask[min_z : max_z + 1, min_y : max_y + 1, min_x : max_x + 1]
if padding:
stack = np.pad(
stack,
((padding, padding), (padding, padding), (padding, padding)),
mode="constant",
)
mask = np.pad(
mask,
((padding, padding), (padding, padding), (padding, padding)),
mode="constant",
)
return stack, mask
85 changes: 85 additions & 0 deletions tests/test_unit/test_cropping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import numpy as np
import pytest

from brainglobe_template_builder.preproc.cropping import crop_to_mask


def test_crop_to_mask_invalid_stack_and_mask():
stack = np.zeros((10, 10, 10))
mask = np.zeros((20, 20, 20))
with pytest.raises(AssertionError) as e:
_ = crop_to_mask(stack, mask)
assert str(e.value) == "Stack and mask must have the same shape."


def test_crop_to_mask_invalid_mask():
stack = np.ones((10, 10, 10))
mask = np.zeros((10, 10, 10))
with pytest.raises(AssertionError) as e:
_ = crop_to_mask(stack, mask)
assert (
str(e.value)
== "The mask is invalid because it does not contain foreground."
)


def test_simple_crop_to_mask():
stack = np.ones((10, 10, 10))
mask = np.zeros((10, 10, 10))
mask[3:7, 3:7, 3:7] = 1
cropped_stack, cropped_mask = crop_to_mask(stack, mask)
assert cropped_stack.shape == (4, 4, 4)
assert cropped_mask.shape == (4, 4, 4)
assert np.all(cropped_stack == stack[3:7, 3:7, 3:7])
assert np.all(cropped_mask == mask[3:7, 3:7, 3:7])


@pytest.mark.parametrize("padding", [1, 5, 10])
def test_padding(padding):
stack = np.ones((10, 10, 10))
mask = np.ones((10, 10, 10))
cropped_stack, cropped_mask = crop_to_mask(stack, mask, padding=padding)
assert cropped_stack.shape == tuple([s + 2 * padding for s in stack.shape])
assert cropped_mask.shape == tuple([s + 2 * padding for s in stack.shape])
assert np.all(
cropped_stack[padding:-padding, padding:-padding, padding:-padding]
== stack
)
assert np.all(
cropped_mask[padding:-padding, padding:-padding, padding:-padding]
== mask
)
assert np.all(cropped_mask[0:padding, :, :] == 0)
assert np.all(cropped_mask[-padding:, :, :] == 0)
assert np.all(cropped_mask[:, 0:padding, :] == 0)
assert np.all(cropped_mask[:, -padding:, :] == 0)
assert np.all(cropped_mask[:, :, 0:padding] == 0)
assert np.all(cropped_mask[:, :, -padding:] == 0)


def test_crop_to_full_mask_does_nothing():
stack = np.ones((10, 10, 10))
mask = np.ones((10, 10, 10))
cropped_stack, cropped_mask = crop_to_mask(stack, mask)
assert cropped_stack.shape == (10, 10, 10)
assert cropped_mask.shape == (10, 10, 10)
assert np.all(cropped_stack == stack)
assert np.all(cropped_mask == mask)


def test_crop_to_mask_with_padding():
stack = np.ones((10, 10, 10))
mask = np.zeros((10, 10, 10))
mask[3:7, 3:7, 3:7] = 1
padding = 2
cropped_stack, cropped_mask = crop_to_mask(stack, mask, padding=padding)
assert cropped_stack.shape == (8, 8, 8)
assert cropped_mask.shape == (8, 8, 8)
assert np.all(
cropped_stack[padding:-padding, padding:-padding, padding:-padding]
== stack[3:7, 3:7, 3:7]
)
assert np.all(
cropped_mask[padding:-padding, padding:-padding, padding:-padding]
== mask[3:7, 3:7, 3:7]
)

0 comments on commit afe54b1

Please sign in to comment.