Skip to content

Commit

Permalink
Create DecoratorSpecs definitions that are compatible with TensorStore.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665968766
  • Loading branch information
timblakely authored and copybara-github committed Aug 23, 2024
1 parent 63cd4d8 commit 52e3f57
Show file tree
Hide file tree
Showing 11 changed files with 239 additions and 83 deletions.
11 changes: 9 additions & 2 deletions connectomics/common/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@ def import_symbol(
Args:
specifier: full path specifier in format
[<packages>.]<module_name>.<model_class>, if packages is missing
``default_packages`` is used.
``default_packages`` is used. Alternatively, the specifier can be just a
class name within a module specified by default_packages.
default_packages: chain of packages before module in format
<top_pack>.<sub_pack>.<subsub_pack> etc.
Returns:
symbol: object from module
"""
module_path, symbol_name = specifier.rsplit('.', 1)

try:
module_path, symbol_name = specifier.rsplit('.', 1)
except ValueError as _:
module_path = default_packages
symbol_name = specifier

try:
logging.info(
'Importing symbol %s from %s.%s',
Expand Down
29 changes: 18 additions & 11 deletions connectomics/volume/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from connectomics.common import array
from connectomics.common import bounding_box
from connectomics.volume import metadata
from connectomics.volume import subvolume
import numpy as np

Expand All @@ -34,10 +35,10 @@ def slice_to_bbox(ind: array.CanonicalSlice) -> bounding_box.BoundingBox:

class VolumeIndexer:
"""Interface for indexing supporting point lookups and slices."""
_volume: 'BaseVolume'
_volume: 'Volume'
slices: array.CanonicalSlice

def __init__(self, volume: 'BaseVolume'):
def __init__(self, volume: 'Volume'):
self._volume = volume

def __getitem__(
Expand All @@ -64,9 +65,14 @@ def __getitem__(self, ind: array.IndexExpOrPointLookups) -> np.ndarray:

# TODO(timblakely): Make generic-typed so it exposes both VolumeInfo and
# Tensorstore via .descriptor.
class BaseVolume:
class Volume:
"""Common interface to multiple volume backends for Decorators."""

meta: metadata.VolumeMetadata

def __init__(self, meta: metadata.VolumeMetadata):
self.meta = meta

def __getitem__(
self, ind: array.IndexExpOrPointLookups) -> Union[np.ndarray, Subvolume]:
return VolumeIndexer(self)[ind]
Expand Down Expand Up @@ -110,32 +116,33 @@ def write(self, subvol: subvolume.Subvolume):
@property
def volume_size(self) -> array.Tuple3i:
"""Volume size in voxels, XYZ."""
raise NotImplementedError
return self.meta.volume_size

@property
def voxel_size(self) -> array.Tuple3f:
def pixel_size(self) -> array.Tuple3f:
"""Size of an individual voxels in physical dimensions (Nanometers)."""
raise NotImplementedError
return self.meta.pixel_size

@property
def shape(self) -> array.Tuple4i:
"""Shape of the volume in voxels, CZYX."""
raise NotImplementedError
return (self.meta.num_channels,) + self.volume_size[::-1]

@property
def ndim(self) -> int:
"""Number of dimensions in this volume."""
raise NotImplementedError
# TODO(timblakely): Support 3D volumes?
return 4

@property
def dtype(self) -> np.dtype:
"""Datatype of the underlying data."""
raise NotImplementedError
return self.meta.dtype

@property
def bounding_boxes(self) -> list[bounding_box.BoundingBox]:
"""List of bounding boxes contained in this volume."""
raise NotImplementedError
return self.meta.bounding_boxes

@property
def chunk_size(self) -> array.Tuple4i:
Expand All @@ -156,7 +163,7 @@ def clip_box_to_volume(


def get_bounding_boxes_or_full(
volume: BaseVolume,
volume: Volume,
bounding_boxes: Optional[Sequence[bounding_box.BoundingBoxBase]] = None,
clip: bool = False,
) -> list[bounding_box.BoundingBox]:
Expand Down
29 changes: 12 additions & 17 deletions connectomics/volume/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,37 +15,32 @@
"""Tests for base."""

from absl.testing import absltest
from connectomics.common import array
from connectomics.common import bounding_box
from connectomics.volume import base
from connectomics.volume import metadata
import numpy as np
import numpy.testing as npt

Box = bounding_box.BoundingBox


class ShimVolume(base.BaseVolume):
class ShimVolume(base.Volume):

def __init__(self, *args, **kwargs):
super(*args, **kwargs)
self.called = False
default_metadata = metadata.VolumeMetadata(
volume_size=(10, 11, 12),
pixel_size=(1, 2, 3),
bounding_boxes=[Box([0, 0, 0], [10, 20, 30])],
num_channels=1,
dtype=np.float32,
)

@property
def shape(self) -> array.Tuple4i:
return (1, 12, 11, 10)
def __init__(self):
super().__init__(self.default_metadata)
self.called = False


class BaseVolumeTest(absltest.TestCase):

def test_not_implemented(self):
v = base.BaseVolume()

for field in [
'volume_size', 'voxel_size', 'shape', 'ndim', 'dtype', 'bounding_boxes'
]:
with self.assertRaises(NotImplementedError):
_ = getattr(v, field)

def test_get_points(self):
tself = self

Expand Down
41 changes: 41 additions & 0 deletions connectomics/volume/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
"""

import copy
import dataclasses
import enum
import pprint
from typing import Any, Iterable, Mapping, MutableMapping, Optional, Sequence, Union

from absl import logging
from connectomics.common import counters
from connectomics.common import import_util
from connectomics.common import metadata_utils
import dataclasses_json
import gin
import jax
import numpy as np
Expand Down Expand Up @@ -1477,3 +1480,41 @@ def debug_string(self) -> str:
'multiscale_spec:\n' +
pprint.pformat(self.multiscale_spec)
)


@dataclasses_json.dataclass_json(undefined=dataclasses_json.Undefined.INCLUDE)
@dataclasses.dataclass(frozen=True)
class DecoratorArgs:
  """Empty dataclass to allow automatic parsing of decorator args.

  This precludes the need to define a dataclass for each decorator. All
  undefined fields are included in the resulting python object, where they are
  collected in `values`.
  """

  # Catch-all mapping: with Undefined.INCLUDE, every key that is not a declared
  # field is stored here instead of being rejected.
  values: dataclasses_json.CatchAll


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class DecoratorSpec:
  """Decorator specification.

  Attributes:
    name: Name of the decorator.
    args: Arguments for decorator's constructor. Defaults to None when the
      spec carries no arguments.
    package: Package where the decorator is defined. Defaults to None when the
      spec does not name one.
  """

  name: str
  args: DecoratorArgs | None = None
  package: str | None = None


def build_decorator(spec: DecoratorSpec) -> Decorator:
  """Instantiates the Decorator described by a DecoratorSpec.

  Args:
    spec: Specification giving the decorator's name, its constructor
      arguments and, optionally, the package to import it from.

  Returns:
    The object produced by calling the imported symbol with the spec's
    arguments as keyword arguments.
  """
  # Specs that do not name a package fall back to the default decorators
  # package.
  if spec.package is None:
    source_package = 'connectomics.volume.decorators'
  else:
    source_package = spec.package

  decorator_cls = import_util.import_symbol(spec.name, source_package)

  # A spec without args means the constructor is called with no keywords.
  if spec.args:
    kwargs = spec.args.values
  else:
    kwargs = {}
  return decorator_cls(**kwargs)
101 changes: 101 additions & 0 deletions connectomics/volume/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for decorators."""

import copy
import json

from absl.testing import absltest
from connectomics.volume import decorators
Expand Down Expand Up @@ -723,5 +724,105 @@ def test_clobber_with_input_spec(self):
})


class TestDecorator(decorators.Decorator):
  """Minimal decorator that just records its constructor arguments."""

  def __init__(self, foo: int, bar: str):
    self.foo, self.bar = foo, bar


class DecoratorSpecTest(absltest.TestCase):
  """Tests for DecoratorArgs, DecoratorSpec and build_decorator."""

  def test_decorator_args_unknown(self):
    """Undeclared keys are kept in `values` and survive a round trip."""
    raw_args = {
        'downsample_factors': [1, 2],
        'method': 'max',
    }
    parsed = decorators.DecoratorArgs.from_json(json.dumps(raw_args))
    self.assertEqual(parsed.values['downsample_factors'], [1, 2])
    self.assertEqual(parsed.values['method'], 'max')
    self.assertEqual(parsed.to_dict(), raw_args)

  def test_decorator_spec(self):
    """Specs parse with only a name, and round-trip when fully populated."""
    # Name only: the optional fields stay unset.
    minimal = {
        'name': 'Downsample',
    }
    minimal_spec = decorators.DecoratorSpec.from_json(json.dumps(minimal))
    self.assertEqual(minimal_spec.name, 'Downsample')
    self.assertIsNone(minimal_spec.args)
    self.assertIsNone(minimal_spec.package)

    # All fields populated.
    full = {
        'name': 'Downsample',
        'args': {
            'downsample_factors': [1, 2],
            'method': 'max',
        },
        'package': 'foo.bar.baz',
    }
    full_spec = decorators.DecoratorSpec.from_json(json.dumps(full))
    parsed_args = decorators.DecoratorArgs.from_dict(full['args'])
    self.assertEqual(parsed_args.to_dict(), full['args'])
    self.assertEqual(full_spec.to_dict(), full)

  def test_build_decorator(self):
    """A spec without a package builds a decorator from the default package."""
    serialized = json.dumps({
        'name': 'Downsample',
        'args': {
            'downsample_factors': [2, 4],
            'method': 'max',
        },
    })
    spec = decorators.DecoratorSpec.from_json(serialized)
    built = decorators.build_decorator(spec)
    self.assertIsInstance(built, decorators.Downsample)
    self.assertEqual(built._downsample_factors, [2, 4])
    self.assertEqual(built._method, 'max')

  def test_build_decorator_with_bad_args(self):
    """Unexpected or missing constructor arguments raise TypeError."""
    unexpected_arg = {
        'downsample_factors': [2, 4],
        'method': 'max',
        'BAD_ARG': 'very_bad',
    }
    # 'method' is deliberately left out.
    missing_arg = {
        'downsample_factors': [2, 4],
    }
    for bad_args in (unexpected_arg, missing_arg):
      serialized = json.dumps({'name': 'Downsample', 'args': bad_args})
      spec = decorators.DecoratorSpec.from_json(serialized)
      with self.assertRaises(TypeError):
        decorators.build_decorator(spec)

  def test_build_decorator_with_package(self):
    """An explicit package is used to locate the decorator class."""
    serialized = json.dumps({
        'name': 'TestDecorator',
        'args': {
            'foo': 1,
            'bar': 'baz',
        },
        'package': 'connectomics.volume.decorators_test',
    })
    spec = decorators.DecoratorSpec.from_json(serialized)
    built = decorators.build_decorator(spec)
    # Can't use assertIsInstance because the package is imported, so compare
    # class names instead.
    self.assertEqual(built.__class__.__name__, TestDecorator.__name__)
    self.assertEqual(built.foo, 1)
    self.assertEqual(built.bar, 'baz')


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion connectomics/volume/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def load_descriptor(spec: Union[str, VolumeDescriptor]) -> VolumeDescriptor:

def open_descriptor(
spec: Union[str, VolumeDescriptor],
context: Optional[dict[str, Any]] = None) -> base.BaseVolume:
context: Optional[dict[str, Any]] = None) -> base.Volume:
"""Open a volume from a volume descriptor.
Args:
Expand Down
24 changes: 18 additions & 6 deletions connectomics/volume/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

from connectomics.common import bounding_box
from connectomics.common import file
from connectomics.volume import decorators
import dataclasses_json
import numpy as np
import numpy.typing as npt


@dataclasses_json.dataclass_json
Expand All @@ -32,13 +35,20 @@ class VolumeMetadata:
volume_size: Volume size in voxels. XYZ order.
pixel_size: Pixel size in nm. XYZ order.
bounding_boxes: Bounding boxes associated with the volume.
num_channels: Number of channels in the volume.
dtype: Datatype of the volume. Must be numpy compatible.
"""
volume_size: tuple[int, int, int]
pixel_size: tuple[float, float, float]
bounding_boxes: list[bounding_box.BoundingBox]
# TODO(timblakely): In the event we want to enforce the assumption that volumes
# are XYZC (i.e. processing happens differently for spatial and channel axes),
# add num_channels to this class to record any changes in channel counts.
num_channels: int = 1
dtype: npt.DTypeLike = dataclasses.field(
metadata=dataclasses_json.config(
decoder=np.dtype,
encoder=lambda x: np.dtype(x).name,
),
default=np.uint8,
)

def scale(
self, scale_factors: float | Sequence[float]
Expand Down Expand Up @@ -103,9 +113,11 @@ class DecoratedVolume:
Attributes:
path: The path to the volume.
decorator_specs: A JSON string of decorator specs.
decorator_specs: A JSON string of decorator specs, or one or more
DecoratorSpec objects.
"""

path: pathlib.Path
# TODO(timblakely): This should be a list of DecoratorSpec dataclasses.
decorator_specs: str
decorator_specs: (
str | decorators.DecoratorSpec | list[decorators.DecoratorSpec]
)
Loading

0 comments on commit 52e3f57

Please sign in to comment.