Skip to content

Commit

Permalink
feat: add filter_cc3d()
Browse files Browse the repository at this point in the history
  • Loading branch information
trivoldus28 authored and supersergiy committed Jun 22, 2024
1 parent 6f43304 commit 7eb6d0f
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 0 deletions.
86 changes: 86 additions & 0 deletions tests/unit/tensor_ops/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,92 @@ def test_filter_cc_big():
assert_array_equal(result, expected)


def test_filter_cc3d_small():
a = torch.Tensor(
[
[
[1, 1, 0, 1],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 1],
],
[
[1, 1, 0, 1],
[1, 1, 0, 0],
[0, 0, 0, 0],
[1, 0, 0, 0],
],
]
).unsqueeze(0)

expected = torch.Tensor(
[
[
[0, 0, 0, 1],
[0, 0, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 1],
],
[
[0, 0, 0, 1],
[0, 0, 0, 0],
[0, 0, 0, 0],
[1, 0, 0, 0],
],
]
).unsqueeze(0)

result = mask.filter_cc3d(
a,
mode="keep_small",
thr=2,
)
assert_array_equal(result, expected)


def test_filter_cc3d_large():
a = torch.Tensor(
[
[
[1, 1, 0, 1],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 1, 1],
],
[
[1, 1, 0, 1],
[1, 1, 0, 0],
[0, 0, 0, 0],
[1, 0, 0, 0],
],
]
).unsqueeze(0)

expected = torch.Tensor(
[
[
[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
],
[
[1, 1, 0, 0],
[1, 1, 0, 0],
[0, 0, 0, 0],
[0, 0, 0, 0],
],
]
).unsqueeze(0)

result = mask.filter_cc3d(
a,
mode="keep_large",
thr=2,
)
assert_array_equal(result, expected)


def test_kornia_closing():
a = np.expand_dims(
np.array(
Expand Down
44 changes: 44 additions & 0 deletions zetta_utils/tensor_ops/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,50 @@ def filter_cc(
return result


@builder.register("filter_cc3d") # type: ignore # TODO: pyright
@supports_dict
@skip_on_empty_data
@typechecked
def filter_cc3d(
data: TensorTypeVar,
mode: MaskFilteringModes = "keep_small",
thr: int = 100,
connectivity_3d: Literal[6, 18, 26] = 6,
) -> TensorTypeVar:
"""
Remove 3D connected components from the given input tensor_ops.
Clustering is performed based on non-zero values.
:param data: Input tensor (CXYZ).
:param mode: Filtering mode.
:param thr: Pixel size threshold.
:return: Tensor with the filtered clusters removed.
"""
data_np = convert.to_np(data)

data_np = einops.rearrange(data_np, "1 X Y Z -> X Y Z")

result_raw = np.zeros_like(data_np)

if (data_np != 0).sum() > 0:
cc_labels = cc3d.connected_components(data_np != 0, connectivity=connectivity_3d)
segids, counts = np.unique(cc_labels, return_counts=True)
if mode == "keep_large":
segids = [segid for segid, ct in zip(segids, counts) if ct > thr]
else:
segids = [segid for segid, ct in zip(segids, counts) if ct <= thr]

filtered_mask = fastremap.mask_except(cc_labels, segids, in_place=True) != 0

result_raw = copy.copy(data_np)
result_raw[filtered_mask == 0] = 0

result_raw = einops.rearrange(result_raw, "X Y Z -> 1 X Y Z")
result = convert.astype(result_raw, data)
return result


def _normalize_kernel(
kernel: Union[Tensor, str], width: int, device: torch.types.Device
) -> torch.Tensor:
Expand Down

0 comments on commit 7eb6d0f

Please sign in to comment.