Skip to content

Commit

Permalink
Extend build_mask to support arbitrary data loading
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668024765
  • Loading branch information
timblakely authored and copybara-github committed Aug 27, 2024
1 parent d242f8f commit 9a1aa4b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 41 deletions.
81 changes: 44 additions & 37 deletions connectomics/volume/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@

import dataclasses
import threading
from typing import Optional, Sequence, Union
import typing
from typing import Any, Callable, Sequence

from absl import logging
from connectomics.common import bounding_box
from connectomics.volume import descriptor
from connectomics.volume import metadata
import dataclasses_json
import numpy as np
Expand Down Expand Up @@ -53,16 +54,16 @@ class MaskChannelConfig(dataclasses_json.DataClassJsonMixin):
# 'x', 'y', and 'z' indicating spatial coordinates like in
# CoordinateExpressionOptions. If specified, will be used instead of
# min_value/max_value/values.
expression: Optional[str] = None
expression: str | None = None

# Value to substitute nan's with.
nan_value: Optional[float] = None
nan_value: float | None = None

# A voxel will be considered masked if any voxels within the FOV centered
# around it are masked. The FOV start coordinates are relative to current
# position. If unset, defaults to a FOV of (1, 1, 1), i.e. masked voxels
# corresponding to the mask exactly.
fov: Optional[bounding_box.BoundingBox] = None
fov: bounding_box.BoundingBox | None = None

invert: bool = False

Expand All @@ -80,7 +81,7 @@ class ImageMaskOptions(dataclasses_json.DataClassJsonMixin):
# SECURITY WARNING: This gets passed to Python's eval, which will allow
# execution of arbitrary code. This option is for internal use only. The
# unnormalized 3d image ndarray is accessible under 'image'.
expression: Optional[str] = None
expression: str | None = None
channels: list[MaskChannelConfig] = dataclasses.field(default_factory=list)


Expand All @@ -98,9 +99,9 @@ class CoordinateExpressionOptions(dataclasses_json.DataClassJsonMixin):
@dataclasses.dataclass
class MaskConfig(dataclasses_json.DataClassJsonMixin):
"""Top-level configuration for creating masks."""
volume: Optional[VolumeMaskOptions] = None
image: Optional[ImageMaskOptions] = None
coordinate_expression: Optional[CoordinateExpressionOptions] = None
volume: VolumeMaskOptions | None = None
image: ImageMaskOptions | None = None
coordinate_expression: CoordinateExpressionOptions | None = None

invert: bool = False

Expand All @@ -120,21 +121,26 @@ class MaskConfigs(dataclasses_json.DataClassJsonMixin):

# TODO(timblakely): Return a Subvolume.
def build_mask(
masks: Union[Sequence[MaskConfig], MaskConfigs],
masks: Sequence[MaskConfig] | MaskConfigs,
box: bounding_box.BoundingBoxBase,
mask_volume_map=None,
# TODO(timblakely): Is this used anymore?
# volume_decorator_fn=lambda x:
image=None):
decorated_volume_loader: Callable[[metadata.DecoratedVolume], np.ndarray],
mask_volume_map: dict[str, Any] | None = None,
image: np.ndarray | None = None,
volume_decorator_fn: Callable[[np.ndarray], np.ndarray] = lambda x: x,
) -> np.ndarray:
"""Builds a boolean mask.
Args:
masks: sequence of MaskConfig objects, or a single MaskConfigs object
box: bounding box defining the area for which to build the mask
decorated_volume_loader: Function that takes a DecoratedVolume object and
returns the loaded ndarray-like object.
mask_volume_map: optional dict mapping the JSON serialization of a volume
mask descriptor (`config.volume.mask.to_json()`) to the opened volume; use
this as a cache to avoid opening volumes multiple times.
image: 3d image ndarray; only needed if the mask config uses the image as
input
volume_decorator_fn: callable taking a volume object and returning another
object supporting the same interface.
Returns:
boolean mask built according to the specified config
Expand All @@ -147,49 +153,43 @@ def build_mask(
if isinstance(masks, MaskConfigs):
invert = masks.invert
masks = masks.masks
masks = typing.cast(Sequence[MaskConfig], masks)

final_mask = None
if mask_volume_map is None:
mask_volume_map = {}

src_box = box

z, y, x = np.mgrid[src_box.to_slice3d()] # pylint:disable=unused-variable
final_mask: Optional[np.ndarray] = None
z, y, x = np.mgrid[box.to_slice3d()] # pylint:disable=unused-variable
mask = None
for config in masks:
curr_mask = np.zeros(box.size[::-1], dtype=bool)

channels: list[MaskChannelConfig] = []
mask: Optional[np.ndarray] = None

if config.coordinate_expression is not None:
bool_mask = eval(config.coordinate_expression.expression) # pylint: disable=eval-used
curr_mask = np.array(bool_mask)
curr_mask |= bool_mask
else:
if config.image:
if config.image is not None:
assert image is not None

if config.image.expression:
if config.image.expression is not None:
channels = []
assert not config.image.channels
curr_mask |= eval(config.image.expression) # pylint: disable=eval-used
else:
channels = config.image.channels
mask = image[np.newaxis, ...]
elif config.volume:
elif config.volume is not None:
channels = config.volume.channels

volume_key = config.volume.mask.to_json()
if volume_key not in mask_volume_map:
mask_volume_map[volume_key] = descriptor.open_descriptor(
config.volume.mask)
mask_volume_map[volume_key] = volume_decorator_fn(
decorated_volume_loader(config.volume.mask)
)
volume = mask_volume_map[volume_key]
# TODO(timblakely): Have these handle subvolumes.
mask = np.asarray(volume[box.to_slice4d()].data)
mask = volume[box.to_slice()]
else:
raise ValueError('Unsupported mask configuration')

assert mask is not None
logging.fatal('Unsupported mask source: %s', config.to_json())

for chan_config in channels:
channel_mask = mask[chan_config.channel, ...]
Expand All @@ -200,22 +200,28 @@ def build_mask(
if chan_config.expression:
bool_mask = eval(chan_config.expression) # pylint: disable=eval-used
elif chan_config.values:
bool_mask = np.in1d(channel_mask,
chan_config.values).reshape(channel_mask.shape)
bool_mask = np.in1d(channel_mask, chan_config.values).reshape(
channel_mask.shape
)
else:
bool_mask = ((channel_mask >= chan_config.min_value) &
(channel_mask <= chan_config.max_value))
assert chan_config.max_value >= chan_config.min_value
bool_mask = (channel_mask >= chan_config.min_value) & (
channel_mask <= chan_config.max_value
)
if chan_config.invert:
bool_mask = np.logical_not(bool_mask)

if chan_config.fov is not None:
fov = chan_config.fov
# TODO(timblakely): Type checker appears to be confused without this
# check...?
assert fov is not None
with _MASK_SEM:
bool_mask = ndimage.maximum_filter(
bool_mask,
size=fov.size[::-1],
origin=fov.size[::-1] // 2 + fov.start[::-1])
origin=fov.size[::-1] // 2 + fov.start[::-1],
)
curr_mask |= bool_mask

if config.invert:
Expand All @@ -227,6 +233,7 @@ def build_mask(
final_mask |= curr_mask

assert final_mask is not None

if invert:
return np.logical_not(final_mask)
else:
Expand Down
20 changes: 16 additions & 4 deletions connectomics/volume/mask_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,18 +39,30 @@ def test_build_mask(self):

box = bounding_box.BoundingBox(start=(0, 0, 0), size=subvol_size[::-1])

mask = m.build_mask([mask_config], box, image=image)
mask = m.build_mask(
[mask_config],
box,
decorated_volume_loader=lambda x: x,
image=image,
)

np.testing.assert_array_equal(mask, image >= 0.5)

# Test with an int image, and masking specific values only.
image = np.random.randint(0, 10, subvol_size, dtype=np.uint8)
chan_config.values = [1, 5, 8]

mask = m.build_mask([mask_config], box, image=image)
self.called = False
mask = m.build_mask(
[mask_config],
box,
decorated_volume_loader=lambda x: x,
image=image,
)

np.testing.assert_array_equal(mask,
(image == 1) | (image == 5) | (image == 8))
np.testing.assert_array_equal(
mask, (image == 1) | (image == 5) | (image == 8)
)


if __name__ == '__main__':
Expand Down

0 comments on commit 9a1aa4b

Please sign in to comment.