diff --git a/connectomics/volume/mask.py b/connectomics/volume/mask.py index 6f12d79..b7ce35b 100644 --- a/connectomics/volume/mask.py +++ b/connectomics/volume/mask.py @@ -16,10 +16,11 @@ import dataclasses import threading -from typing import Optional, Sequence, Union +import typing +from typing import Any, Callable, Sequence +from absl import logging from connectomics.common import bounding_box -from connectomics.volume import descriptor from connectomics.volume import metadata import dataclasses_json import numpy as np @@ -53,16 +54,16 @@ class MaskChannelConfig(dataclasses_json.DataClassJsonMixin): # 'x', 'y', and 'z' indicating spatial coordinates like in # CoordinateExpressionOptions. If specified, will be used instead of # min_value/max_value/values. - expression: Optional[str] = None + expression: str | None = None # Value to substitute nan's with. - nan_value: Optional[float] = None + nan_value: float | None = None # A voxel will be considered masked if any voxels within the FOV centered # around it are masked. The FOV start coordinates are relative to current # position. If unset, defaults to a FOV of (1, 1, 1), i.e. masked voxels # corresponding to the mask exactly. - fov: Optional[bounding_box.BoundingBox] = None + fov: bounding_box.BoundingBox | None = None invert: bool = False @@ -80,7 +81,7 @@ class ImageMaskOptions(dataclasses_json.DataClassJsonMixin): # SECURITY WARNING: This gets passed to Python's eval, which will allow # execution of arbitrary code. This option is for internal use only. The # unnormalized 3d image ndarray is accessible under 'image'. - expression: Optional[str] = None + expression: str | None = None channels: list[MaskChannelConfig] = dataclasses.field(default_factory=list) @@ -98,9 +99,9 @@ class CoordinateExpressionOptions(dataclasses_json.DataClassJsonMixin): @dataclasses.dataclass class MaskConfig(dataclasses_json.DataClassJsonMixin): """Top-level configuration for creating masks.""" - volume: Optional[VolumeMaskOptions] = None - image: Optional[ImageMaskOptions] = None - coordinate_expression: Optional[CoordinateExpressionOptions] = None + volume: VolumeMaskOptions | None = None + image: ImageMaskOptions | None = None + coordinate_expression: CoordinateExpressionOptions | None = None invert: bool = False @@ -120,21 +121,26 @@ class MaskConfigs(dataclasses_json.DataClassJsonMixin): # TODO(timblakely): Return a Subvolume. def build_mask( - masks: Union[Sequence[MaskConfig], MaskConfigs], + masks: Sequence[MaskConfig] | MaskConfigs, box: bounding_box.BoundingBoxBase, - mask_volume_map=None, - # TODO(timblakely): Is this used anymore? - # volume_decorator_fn=lambda x: - image=None): + decorated_volume_loader: Callable[[metadata.DecoratedVolume], np.ndarray], + mask_volume_map: dict[str, Any] | None = None, + image: np.ndarray | None = None, + volume_decorator_fn: Callable[[np.ndarray], np.ndarray] = lambda x: x, +) -> np.ndarray: """Builds a boolean mask. Args: masks: iterable of MaskConfig proto or MaskConfigs proto box: bounding box defining the area for which to build the mask + decorated_volume_loader: Function to load a DecoratedVolume object, return + an ndarray-like object. mask_volume_map: optional dict mapping volume proto hashes to open volumes; use this as a cache to avoid opening volumes multiple times. image: 3d image ndarray; only needed if the mask config uses the image as input + volume_decorator_fn: callable taking a volume object and returning another + object supporting the same interface. Returns: boolean mask built according to the specified config @@ -147,49 +153,43 @@ def build_mask( if isinstance(masks, MaskConfigs): invert = masks.invert masks = masks.masks + masks = typing.cast(Sequence[MaskConfig], masks) final_mask = None if mask_volume_map is None: mask_volume_map = {} - src_box = box - - z, y, x = np.mgrid[src_box.to_slice3d()] # pylint:disable=unused-variable - final_mask: Optional[np.ndarray] = None + z, y, x = np.mgrid[box.to_slice3d()] # pylint:disable=unused-variable + mask = None for config in masks: curr_mask = np.zeros(box.size[::-1], dtype=bool) - channels: list[MaskChannelConfig] = [] - mask: Optional[np.ndarray] = None - if config.coordinate_expression is not None: bool_mask = eval(config.coordinate_expression.expression) # pylint: disable=eval-used - curr_mask = np.array(bool_mask) + curr_mask |= bool_mask else: - if config.image: + if config.image is not None: assert image is not None - if config.image.expression: + if config.image.expression is not None: channels = [] assert not config.image.channels curr_mask |= eval(config.image.expression) # pylint: disable=eval-used else: channels = config.image.channels mask = image[np.newaxis, ...] - elif config.volume: + elif config.volume is not None: channels = config.volume.channels volume_key = config.volume.mask.to_json() if volume_key not in mask_volume_map: - mask_volume_map[volume_key] = descriptor.open_descriptor( - config.volume.mask) + mask_volume_map[volume_key] = volume_decorator_fn( + decorated_volume_loader(config.volume.mask) + ) volume = mask_volume_map[volume_key] - # TODO(timblakely): Have these handle subvolumes. - mask = np.asarray(volume[box.to_slice4d()].data) + mask = volume[box.to_slice()] else: - raise ValueError('Unsupported mask configuration') - - assert mask is not None + logging.fatal('Unsupported mask source: %s', config.to_json()) for chan_config in channels: channel_mask = mask[chan_config.channel, ...] @@ -200,22 +200,28 @@ def build_mask( if chan_config.expression: bool_mask = eval(chan_config.expression) # pylint: disable=eval-used elif chan_config.values: - bool_mask = np.in1d(channel_mask, - chan_config.values).reshape(channel_mask.shape) + bool_mask = np.in1d(channel_mask, chan_config.values).reshape( + channel_mask.shape + ) else: - bool_mask = ((channel_mask >= chan_config.min_value) & - (channel_mask <= chan_config.max_value)) + assert chan_config.max_value >= chan_config.min_value + bool_mask = (channel_mask >= chan_config.min_value) & ( + channel_mask <= chan_config.max_value + ) if chan_config.invert: bool_mask = np.logical_not(bool_mask) if chan_config.fov is not None: fov = chan_config.fov + # TODO(timblakely): Type checker appears to be confused without this + # check...? assert fov is not None with _MASK_SEM: bool_mask = ndimage.maximum_filter( bool_mask, size=fov.size[::-1], - origin=fov.size[::-1] // 2 + fov.start[::-1]) + origin=fov.size[::-1] // 2 + fov.start[::-1], + ) curr_mask |= bool_mask if config.invert: @@ -227,6 +233,7 @@ def build_mask( final_mask |= curr_mask assert final_mask is not None + if invert: return np.logical_not(final_mask) else: diff --git a/connectomics/volume/mask_test.py b/connectomics/volume/mask_test.py index b24dc0e..9a2e417 100644 --- a/connectomics/volume/mask_test.py +++ b/connectomics/volume/mask_test.py @@ -39,7 +39,12 @@ def test_build_mask(self): box = bounding_box.BoundingBox(start=(0, 0, 0), size=subvol_size[::-1]) - mask = m.build_mask([mask_config], box, image=image) + mask = m.build_mask( + [mask_config], + box, + decorated_volume_loader=lambda x: x, + image=image, + ) np.testing.assert_array_equal(mask, image >= 0.5) @@ -47,10 +52,17 @@ def test_build_mask(self): image = np.random.randint(0, 10, subvol_size, dtype=np.uint8) chan_config.values = [1, 5, 8] - mask = m.build_mask([mask_config], box, image=image) + self.called = False + mask = m.build_mask( + [mask_config], + box, + decorated_volume_loader=lambda x: x, + image=image, + ) - np.testing.assert_array_equal(mask, - (image == 1) | (image == 5) | (image == 8)) + np.testing.assert_array_equal( + mask, (image == 1) | (image == 5) | (image == 8) + ) if __name__ == '__main__':