Commit

Add defect and resin detector
trivoldus28 committed Nov 21, 2023
1 parent 95a31de commit ae3fd32
Showing 3 changed files with 189 additions and 0 deletions.
2 changes: 2 additions & 0 deletions zetta_utils/alignment/__init__.py
@@ -5,3 +5,5 @@
from .base_coarsener import BaseCoarsener
from .encoding_coarsener import EncodingCoarsener
from .misalignment_detector import MisalignmentDetector
from .defect_detector import DefectDetector
from .resin_detector import ResinDetector
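
Re-exporting the two classes here makes them importable directly from zetta_utils.alignment (both are also registered with the builder as "DefectDetector" and "ResinDetector"). A minimal sketch of the resulting import path; the model paths below are placeholders, not part of this commit:

from zetta_utils.alignment import DefectDetector, ResinDetector

# Both detectors are lightweight attrs objects; the underlying network is only
# loaded (and cached) on the first __call__, not at construction time.
defect_detector = DefectDetector(model_path="gs://bucket/defect_model")  # placeholder path
resin_detector = ResinDetector(model_path="gs://bucket/resin_model")  # placeholder path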
80 changes: 80 additions & 0 deletions zetta_utils/alignment/defect_detector.py
@@ -0,0 +1,80 @@
import attrs
import einops
import torch
from typeguard import typechecked

from zetta_utils import builder, convnet


@builder.register("DefectDetector")
@typechecked
@attrs.mutable
class DefectDetector:
    # Input: uint8 [0 .. 255]
    # Output: uint8 prediction [0 .. 255]

# Don't create the model during initialization for efficient serialization
model_path: str
tile_pad_in: int = 32
tile_size: int = 448

def __call__(self, src: torch.Tensor) -> torch.Tensor:
if (src != 0).sum() == 0:
result = torch.zeros_like(src).float()
else:
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

# load model during the call _with caching_
model = convnet.utils.load_model(self.model_path, device=device, use_cache=True)

if src.dtype == torch.uint8:
data_in = src.float() / 255.0 # [0.0 .. 1.0]
else:
raise ValueError(f"Unsupported src dtype: {src.dtype}")

data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y")
data_in = data_in.to(device=device)
with torch.no_grad():
result = torch.zeros_like(
data_in[
...,
: data_in.shape[-2],
: data_in.shape[-1],
]
).float()

tile_pad_out = self.tile_pad_in

for x in range(
self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size
):
x_start = x - self.tile_pad_in
x_end = x + self.tile_size + self.tile_pad_in
for y in range(
self.tile_pad_in, data_in.shape[-1] - self.tile_pad_in, self.tile_size
):
y_start = y - self.tile_pad_in
y_end = y + self.tile_size + self.tile_pad_in
tile = data_in[:, :, x_start:x_end, y_start:y_end]
if (tile != 0).sum() > 0.0:
tile_result = model(tile)
if tile_pad_out > 0:
tile_result = tile_result[
:, :, tile_pad_out:-tile_pad_out, tile_pad_out:-tile_pad_out
]

result[
:,
:,
x : x + tile_result.shape[-2],
y : y + tile_result.shape[-1],
] = tile_result

result = einops.rearrange(result, "Z C X Y -> C X Y Z")
result = 255.0 * torch.sigmoid(result)
result[src == 0.0] = 0.0

return result.round().clamp(0, 255).type(torch.uint8)
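
DefectDetector runs the network over overlapping tiles: each tile_size window is read with a tile_pad_in halo on every side, the halo is cropped from the model output, and the cropped prediction is written back, so tile borders do not leak into the result. A hedged usage sketch; the checkpoint path and the 1 x 1024 x 1024 x 1 section shape are illustrative only, the hard requirement in __call__ being a uint8 tensor in C X Y Z layout:

import torch

from zetta_utils.alignment import DefectDetector

detector = DefectDetector(
    model_path="gs://bucket/defect_model",  # placeholder checkpoint path
    tile_pad_in=32,
    tile_size=448,
)

# One uint8 EM section in C X Y Z layout, matching the einops rearrange in __call__.
section = torch.randint(0, 256, (1, 1024, 1024, 1), dtype=torch.uint8)

defects = detector(section)
assert defects.dtype == torch.uint8    # sigmoid output scaled to [0 .. 255]
assert defects.shape == section.shape  # same layout as the input; zero-valued input pixels stay zero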
107 changes: 107 additions & 0 deletions zetta_utils/alignment/resin_detector.py
@@ -0,0 +1,107 @@
import attrs
import cc3d
import cv2
import einops
import fastremap
import numpy as np
import torch
from typeguard import typechecked

from zetta_utils import builder, convnet


@builder.register("ResinDetector")
@typechecked
@attrs.mutable
class ResinDetector:
    # Input: uint8 [0 .. 255]
    # Output: uint8 prediction [0 .. 255]

# Don't create the model during initialization for efficient serialization
model_path: str
tile_pad_in: int = 32
tile_size: int = 448
tissue_filter_threshold: int = 1000
resin_filter_threshold: int = 1000

def __call__(self, src: torch.Tensor) -> torch.Tensor:
if (src != 0).sum() == 0:
return torch.full_like(src, 255).type(torch.uint8)
else:
if torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"

# load model during the call _with caching_
model = convnet.utils.load_model(self.model_path, device=device, use_cache=True)

if src.dtype == torch.uint8:
data_in = src.float() / 255.0 # [0.0 .. 1.0]
else:
raise ValueError(f"Unsupported src dtype: {src.dtype}")

data_in = einops.rearrange(data_in, "C X Y Z -> Z C X Y")
data_in = data_in.to(device=device)
with torch.no_grad():
result = torch.zeros_like(
data_in[
...,
: data_in.shape[-2],
: data_in.shape[-1],
]
).float()

tile_pad_out = self.tile_pad_in

for x in range(
self.tile_pad_in, data_in.shape[-2] - self.tile_pad_in, self.tile_size
):
x_start = x - self.tile_pad_in
x_end = x + self.tile_size + self.tile_pad_in
for y in range(
self.tile_pad_in, data_in.shape[-1] - self.tile_pad_in, self.tile_size
):
y_start = y - self.tile_pad_in
y_end = y + self.tile_size + self.tile_pad_in
tile = data_in[:, :, x_start:x_end, y_start:y_end]
if (tile != 0).sum() > 0.0:
tile_result = model(tile)
if tile_pad_out > 0:
tile_result = tile_result[
:, :, tile_pad_out:-tile_pad_out, tile_pad_out:-tile_pad_out
]

result[
:,
:,
x : x + tile_result.shape[-2],
y : y + tile_result.shape[-1],
] = tile_result

result = einops.rearrange(result, "Z C X Y -> C X Y Z")
result = torch.sigmoid(result)
pred = (((result > 250. / 255.) * 255).to(dtype=torch.uint8, device='cpu'))

# Background is resin
pred[src == 0.0] = 255

# Filter small islands of tissue
tissue = (255 - pred).squeeze().numpy()
tissue = cv2.morphologyEx(tissue, cv2.MORPH_CLOSE, np.ones((3,3), np.uint8))
tissue = cv2.morphologyEx(tissue, cv2.MORPH_OPEN, np.ones((3,3), np.uint8))
if self.tissue_filter_threshold > 0:
cc = cc3d.connected_components(tissue)
uniq, counts = fastremap.unique(cc, return_counts=True)
cc = fastremap.mask(cc, [lbl for lbl, cnt in zip(uniq, counts) if cnt < self.tissue_filter_threshold])
tissue[cc==0] = 0

# Filter small islands of resin
resin = 255 - tissue
if self.resin_filter_threshold > 0:
cc = cc3d.connected_components(resin)
uniq, counts = fastremap.unique(cc, return_counts=True)
cc = fastremap.mask(cc, [lbl for lbl, cnt in zip(uniq, counts) if cnt < self.resin_filter_threshold])
resin[cc==0] = 0


return torch.from_numpy(resin).reshape(pred.shape)
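
ResinDetector shares the tiling logic above but post-processes the thresholded prediction on the CPU: background (zero-valued) pixels are labeled as resin, then connected components of tissue and of resin smaller than tissue_filter_threshold and resin_filter_threshold pixels are removed with cc3d and fastremap. A usage sketch under the same assumptions as the defect example (placeholder checkpoint path, illustrative shape):

import torch

from zetta_utils.alignment import ResinDetector

detector = ResinDetector(
    model_path="gs://bucket/resin_model",  # placeholder checkpoint path
    tissue_filter_threshold=1000,  # drop tissue islands smaller than this many pixels
    resin_filter_threshold=1000,   # drop resin islands smaller than this many pixels
)

section = torch.randint(0, 256, (1, 1024, 1024, 1), dtype=torch.uint8)
resin_mask = detector(section)  # uint8: 255 where resin (and background), 0 where tissue
assert resin_mask.shape == section.shape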
