From 52e3f579637a4277ecc1f78343577c6d1204917f Mon Sep 17 00:00:00 2001
From: Tim Blakely <blakely@google.com>
Date: Wed, 21 Aug 2024 11:45:59 -0700
Subject: [PATCH] Create DecoratorSpecs definitions that are compatible with
 TensorStore.

PiperOrigin-RevId: 665968766
---
 connectomics/common/import_util.py        |  11 ++-
 connectomics/volume/base.py               |  29 ++++---
 connectomics/volume/base_test.py          |  29 +++----
 connectomics/volume/decorators.py         |  41 +++++++++
 connectomics/volume/decorators_test.py    | 101 ++++++++++++++++++++++
 connectomics/volume/descriptor.py         |   2 +-
 connectomics/volume/metadata.py           |  24 +++--
 connectomics/volume/metadata_test.py      |   2 +
 connectomics/volume/tensorstore.py        |   2 +-
 connectomics/volume/tsv_decorator.py      |  18 ++--
 connectomics/volume/tsv_decorator_test.py |  63 ++++++--------
 11 files changed, 239 insertions(+), 83 deletions(-)

diff --git a/connectomics/common/import_util.py b/connectomics/common/import_util.py
index 5bea0e3..6d76369 100644
--- a/connectomics/common/import_util.py
+++ b/connectomics/common/import_util.py
@@ -29,14 +29,21 @@ def import_symbol(
   Args:
     specifier: full path specifier in format
       [<packages>.]<module_name>.<model_class>, if packages is missing
-      ``default_packages`` is used.
+      ``default_packages`` is used. Alternatively, the specifier can be just a
+      class name within a module specified by default_packages.
     default_packages: chain of packages before module in format
       <top_pack>.<sub_pack>.<subsub_pack> etc.
 
   Returns:
     symbol: object from module
   """
-  module_path, symbol_name = specifier.rsplit('.', 1)
+
+  try:
+    module_path, symbol_name = specifier.rsplit('.', 1)
+  except ValueError as _:
+    module_path = default_packages
+    symbol_name = specifier
+
   try:
     logging.info(
         'Importing symbol %s from %s.%s',
diff --git a/connectomics/volume/base.py b/connectomics/volume/base.py
index db6cf5f..fd042bf 100644
--- a/connectomics/volume/base.py
+++ b/connectomics/volume/base.py
@@ -21,6 +21,7 @@
 
 from connectomics.common import array
 from connectomics.common import bounding_box
+from connectomics.volume import metadata
 from connectomics.volume import subvolume
 import numpy as np
 
@@ -34,10 +35,10 @@ def slice_to_bbox(ind: array.CanonicalSlice) -> bounding_box.BoundingBox:
 
 class VolumeIndexer:
   """Interface for indexing supporting point lookups and slices."""
-  _volume: 'BaseVolume'
+  _volume: 'Volume'
   slices: array.CanonicalSlice
 
-  def __init__(self, volume: 'BaseVolume'):
+  def __init__(self, volume: 'Volume'):
     self._volume = volume
 
   def __getitem__(
@@ -64,9 +65,14 @@ def __getitem__(self, ind: array.IndexExpOrPointLookups) -> np.ndarray:
 
 # TODO(timblakely): Make generic-typed so it exposes both VolumeInfo and
 # Tensorstore via .descriptor.
-class BaseVolume:
+class Volume:
   """Common interface to multiple volume backends for Decorators."""
 
+  meta: metadata.VolumeMetadata
+
+  def __init__(self, meta: metadata.VolumeMetadata):
+    self.meta = meta
+
   def __getitem__(
       self, ind: array.IndexExpOrPointLookups) -> Union[np.ndarray, Subvolume]:
     return VolumeIndexer(self)[ind]
@@ -110,32 +116,33 @@ def write(self, subvol: subvolume.Subvolume):
   @property
   def volume_size(self) -> array.Tuple3i:
     """Volume size in voxels, XYZ."""
-    raise NotImplementedError
+    return self.meta.volume_size
 
   @property
-  def voxel_size(self) -> array.Tuple3f:
+  def pixel_size(self) -> array.Tuple3f:
     """Size of an individual voxels in physical dimensions (Nanometers)."""
-    raise NotImplementedError
+    return self.meta.pixel_size
 
   @property
   def shape(self) -> array.Tuple4i:
     """Shape of the volume in voxels, CZYX."""
-    raise NotImplementedError
+    return (self.meta.num_channels,) + self.volume_size[::-1]
 
   @property
   def ndim(self) -> int:
     """Number of dimensions in this volume."""
-    raise NotImplementedError
+    # TODO(timblakely): Support 3D volumes?
+    return 4
 
   @property
   def dtype(self) -> np.dtype:
     """Datatype of the underlying data."""
-    raise NotImplementedError
+    return self.meta.dtype
 
   @property
   def bounding_boxes(self) -> list[bounding_box.BoundingBox]:
     """List of bounding boxes contained in this volume."""
-    raise NotImplementedError
+    return self.meta.bounding_boxes
 
   @property
   def chunk_size(self) -> array.Tuple4i:
@@ -156,7 +163,7 @@ def clip_box_to_volume(
 
 
 def get_bounding_boxes_or_full(
-    volume: BaseVolume,
+    volume: Volume,
     bounding_boxes: Optional[Sequence[bounding_box.BoundingBoxBase]] = None,
     clip: bool = False,
 ) -> list[bounding_box.BoundingBox]:
diff --git a/connectomics/volume/base_test.py b/connectomics/volume/base_test.py
index ef030e5..d27e24c 100644
--- a/connectomics/volume/base_test.py
+++ b/connectomics/volume/base_test.py
@@ -15,37 +15,32 @@
 """Tests for base."""
 
 from absl.testing import absltest
-from connectomics.common import array
 from connectomics.common import bounding_box
 from connectomics.volume import base
+from connectomics.volume import metadata
 import numpy as np
 import numpy.testing as npt
 
 Box = bounding_box.BoundingBox
 
 
-class ShimVolume(base.BaseVolume):
+class ShimVolume(base.Volume):
 
-  def __init__(self, *args, **kwargs):
-    super(*args, **kwargs)
-    self.called = False
+  default_metadata = metadata.VolumeMetadata(
+      volume_size=(10, 11, 12),
+      pixel_size=(1, 2, 3),
+      bounding_boxes=[Box([0, 0, 0], [10, 20, 30])],
+      num_channels=1,
+      dtype=np.float32,
+  )
 
-  @property
-  def shape(self) -> array.Tuple4i:
-    return (1, 12, 11, 10)
+  def __init__(self):
+    super().__init__(self.default_metadata)
+    self.called = False
 
 
 class BaseVolumeTest(absltest.TestCase):
 
-  def test_not_implemented(self):
-    v = base.BaseVolume()
-
-    for field in [
-        'volume_size', 'voxel_size', 'shape', 'ndim', 'dtype', 'bounding_boxes'
-    ]:
-      with self.assertRaises(NotImplementedError):
-        _ = getattr(v, field)
-
   def test_get_points(self):
     tself = self
 
diff --git a/connectomics/volume/decorators.py b/connectomics/volume/decorators.py
index ed17cea..09e51e2 100644
--- a/connectomics/volume/decorators.py
+++ b/connectomics/volume/decorators.py
@@ -20,13 +20,16 @@
 """
 
 import copy
+import dataclasses
 import enum
 import pprint
 from typing import Any, Iterable, Mapping, MutableMapping, Optional, Sequence, Union
 
 from absl import logging
 from connectomics.common import counters
+from connectomics.common import import_util
 from connectomics.common import metadata_utils
+import dataclasses_json
 import gin
 import jax
 import numpy as np
@@ -1477,3 +1480,41 @@ def debug_string(self) -> str:
         'multiscale_spec:\n' +
         pprint.pformat(self.multiscale_spec)
     )
+
+
+@dataclasses_json.dataclass_json(undefined=dataclasses_json.Undefined.INCLUDE)
+@dataclasses.dataclass(frozen=True)
+class DecoratorArgs:
+  """Empty dataclass to allow automatic parsing of decorator args.
+
+  This precludes the need to define a dataclass for each decorator. All
+  undefined fields are included in the resulting python object.
+  """
+
+  values: dataclasses_json.CatchAll
+
+
+@dataclasses_json.dataclass_json
+@dataclasses.dataclass(frozen=True)
+class DecoratorSpec:
+  """Decorator specification.
+
+  Attributes:
+    name: Name of the decorator.
+    args: Arguments for decorator's constructor.
+    package: Package where the decorator is defined.
+  """
+
+  name: str
+  args: DecoratorArgs | None = None
+  package: str | None = None
+
+
+def build_decorator(spec: DecoratorSpec) -> Decorator:
+  """Builds a Decorator from a DecoratorSpec."""
+  package = spec.package
+  if package is None:
+    package = 'connectomics.volume.decorators'
+  decorator_cls = import_util.import_symbol(spec.name, package)
+  args = spec.args.values if spec.args else {}
+  return decorator_cls(**args)
diff --git a/connectomics/volume/decorators_test.py b/connectomics/volume/decorators_test.py
index 62358d5..f247ca7 100644
--- a/connectomics/volume/decorators_test.py
+++ b/connectomics/volume/decorators_test.py
@@ -15,6 +15,7 @@
 """Tests for decorators."""
 
 import copy
+import json
 
 from absl.testing import absltest
 from connectomics.volume import decorators
@@ -723,5 +724,105 @@ def test_clobber_with_input_spec(self):
     })
 
 
+class TestDecorator(decorators.Decorator):
+
+  def __init__(self, foo: int, bar: str):
+    self.foo = foo
+    self.bar = bar
+
+
+class DecoratorSpecTest(absltest.TestCase):
+
+  def test_decorator_args_unknown(self):
+    expected_args = {
+        'downsample_factors': [1, 2],
+        'method': 'max',
+    }
+    args = decorators.DecoratorArgs.from_json(json.dumps(expected_args))
+    self.assertEqual(args.values['downsample_factors'], [1, 2])
+    self.assertEqual(args.values['method'], 'max')
+    self.assertEqual(args.to_dict(), expected_args)
+
+  def test_decorator_spec(self):
+    expected_spec = {
+        'name': 'Downsample',
+    }
+    spec = decorators.DecoratorSpec.from_json(json.dumps(expected_spec))
+    self.assertEqual(spec.name, 'Downsample')
+    self.assertIsNone(spec.args)
+    self.assertIsNone(spec.package)
+
+    expected_spec = {
+        'name': 'Downsample',
+        'args': {
+            'downsample_factors': [1, 2],
+            'method': 'max',
+        },
+        'package': 'foo.bar.baz',
+    }
+    spec = decorators.DecoratorSpec.from_json(json.dumps(expected_spec))
+    args = decorators.DecoratorArgs.from_dict(expected_spec['args'])
+    self.assertEqual(args.to_dict(), expected_spec['args'])
+    self.assertEqual(spec.to_dict(), expected_spec)
+
+  def test_build_decorator(self):
+    spec = decorators.DecoratorSpec.from_json(
+        json.dumps({
+            'name': 'Downsample',
+            'args': {
+                'downsample_factors': [2, 4],
+                'method': 'max',
+            },
+        })
+    )
+    decorator = decorators.build_decorator(spec)
+    self.assertIsInstance(decorator, decorators.Downsample)
+    self.assertEqual(decorator._downsample_factors, [2, 4])
+    self.assertEqual(decorator._method, 'max')
+
+  def test_build_decorator_with_bad_args(self):
+    spec = decorators.DecoratorSpec.from_json(
+        json.dumps({
+            'name': 'Downsample',
+            'args': {
+                'downsample_factors': [2, 4],
+                'method': 'max',
+                'BAD_ARG': 'very_bad',
+            },
+        })
+    )
+    with self.assertRaises(TypeError):
+      decorators.build_decorator(spec)
+
+    spec = decorators.DecoratorSpec.from_json(
+        json.dumps({
+            'name': 'Downsample',
+            'args': {
+                'downsample_factors': [2, 4],
+                # missing method
+            },
+        })
+    )
+    with self.assertRaises(TypeError):
+      decorators.build_decorator(spec)
+
+  def test_build_decorator_with_package(self):
+    spec = decorators.DecoratorSpec.from_json(
+        json.dumps({
+            'name': 'TestDecorator',
+            'args': {
+                'foo': 1,
+                'bar': 'baz',
+            },
+            'package': 'connectomics.volume.decorators_test',
+        })
+    )
+    decorator = decorators.build_decorator(spec)
+    # Can't use assertIsInstance because the package is imported.
+    self.assertEqual(decorator.__class__.__name__, TestDecorator.__name__)
+    self.assertEqual(decorator.foo, 1)
+    self.assertEqual(decorator.bar, 'baz')
+
+
 if __name__ == '__main__':
   absltest.main()
diff --git a/connectomics/volume/descriptor.py b/connectomics/volume/descriptor.py
index 6678982..a4c6588 100644
--- a/connectomics/volume/descriptor.py
+++ b/connectomics/volume/descriptor.py
@@ -60,7 +60,7 @@ def load_descriptor(spec: Union[str, VolumeDescriptor]) -> VolumeDescriptor:
 
 def open_descriptor(
     spec: Union[str, VolumeDescriptor],
-    context: Optional[dict[str, Any]] = None) -> base.BaseVolume:
+    context: Optional[dict[str, Any]] = None) -> base.Volume:
   """Open a volume from a volume descriptor.
 
   Args:
diff --git a/connectomics/volume/metadata.py b/connectomics/volume/metadata.py
index 9fed778..269ce14 100644
--- a/connectomics/volume/metadata.py
+++ b/connectomics/volume/metadata.py
@@ -20,7 +20,10 @@
 
 from connectomics.common import bounding_box
 from connectomics.common import file
+from connectomics.volume import decorators
 import dataclasses_json
+import numpy as np
+import numpy.typing as npt
 
 
 @dataclasses_json.dataclass_json
@@ -32,13 +35,20 @@ class VolumeMetadata:
     volume_size: Volume size in voxels. XYZ order.
     pixel_size: Pixel size in nm. XYZ order.
     bounding_boxes: Bounding boxes associated with the volume.
+    num_channels: Number of channels in the volume.
+    dtype: Datatype of the volume. Must be numpy compatible.
   """
   volume_size: tuple[int, int, int]
   pixel_size: tuple[float, float, float]
   bounding_boxes: list[bounding_box.BoundingBox]
-  # TODO(timblakely): In the event we want to enforce the assumption that volumes
-  # are XYZC (i.e. processing happens differently for spatial and channel axes),
-  # add num_channels to this class to record any changes in channel counts.
+  num_channels: int = 1
+  dtype: npt.DTypeLike = dataclasses.field(
+      metadata=dataclasses_json.config(
+          decoder=np.dtype,
+          encoder=lambda x: np.dtype(x).name,
+      ),
+      default=np.uint8,
+  )
 
   def scale(
       self, scale_factors: float | Sequence[float]
@@ -103,9 +113,11 @@ class DecoratedVolume:
 
   Attributes:
     path: The path to the volume.
-    decorator_specs: A JSON string of decorator specs.
+    decorator_specs: A JSON string of decorator specs, or one or more
+      DecoratorSpec objects.
   """
 
   path: pathlib.Path
-  # TODO(timblakely): This should be a list of DecoratorSpec dataclasses.
-  decorator_specs: str
+  decorator_specs: (
+      str | decorators.DecoratorSpec | list[decorators.DecoratorSpec]
+  )
diff --git a/connectomics/volume/metadata_test.py b/connectomics/volume/metadata_test.py
index 99dd53b..bd25d84 100644
--- a/connectomics/volume/metadata_test.py
+++ b/connectomics/volume/metadata_test.py
@@ -20,6 +20,7 @@
 from absl.testing import absltest
 from connectomics.common import bounding_box
 from connectomics.volume import metadata
+import numpy as np
 
 
 FLAGS = flags.FLAGS
@@ -83,6 +84,7 @@ def test_volume_save_metadata(self):
         volume_size=(100, 100, 100),
         pixel_size=(8, 8, 30),
         bounding_boxes=[BBOX([10, 10, 10], [100, 100, 100])],
+        dtype=np.uint64,
     )
     temp_path = pathlib.Path(self.create_tempdir().full_path)
     vol = metadata.Volume(path=temp_path / 'foo.volinfo', meta=meta)
diff --git a/connectomics/volume/tensorstore.py b/connectomics/volume/tensorstore.py
index e49edf6..d22799c 100644
--- a/connectomics/volume/tensorstore.py
+++ b/connectomics/volume/tensorstore.py
@@ -59,7 +59,7 @@ class TensorstoreConfig(utils.NPDataClassJsonMixin):
           decoder=file.dataclass_loader(TensorstoreMetadata)))
 
 
-class TensorstoreVolume(base.BaseVolume):
+class TensorstoreVolume(base.Volume):
   """Tensorstore-backed Volume."""
 
   _store: ts.TensorStore
diff --git a/connectomics/volume/tsv_decorator.py b/connectomics/volume/tsv_decorator.py
index 2716804..fc821dd 100644
--- a/connectomics/volume/tsv_decorator.py
+++ b/connectomics/volume/tsv_decorator.py
@@ -34,7 +34,7 @@
 class DecoratorFactory:
   """Constructs a VolumeDecorator based on a name and arguments."""
 
-  def make_decorator(self, wrapped_volume: base.BaseVolume, name: str,
+  def make_decorator(self, wrapped_volume: base.Volume, name: str,
                      *args: list[Any],
                      **kwargs: dict[str, Any]) -> 'VolumeDecorator':
     raise NotImplementedError()
@@ -43,14 +43,14 @@ def make_decorator(self, wrapped_volume: base.BaseVolume, name: str,
 class GlobalsDecoratorFactory:
   """Loads VolumeDecorators from globals()."""
 
-  def make_decorator(self, wrapped_volume: base.BaseVolume, name: str,
+  def make_decorator(self, wrapped_volume: base.Volume, name: str,
                      *args: list[Any],
                      **kwargs: dict[str, Any]) -> 'VolumeDecorator':
     decorator_ctor = globals()[name]
     return decorator_ctor(wrapped_volume, *args, **kwargs)
 
 
-def from_specs(volume: base.BaseVolume,
+def from_specs(volume: base.Volume,
                specs: list[DecoratorSpec],
                decorator_factory: Optional[DecoratorFactory] = None):
   """Decorates the given volume from the given specs.
@@ -85,12 +85,12 @@ def from_specs(volume: base.BaseVolume,
   return volume
 
 
-class VolumeDecorator(base.BaseVolume):
+class VolumeDecorator(base.Volume):
   """Delegates to wrapped volumes, optionally applying transforms."""
 
-  wrapped: base.BaseVolume
+  wrapped: base.Volume
 
-  def __init__(self, wrapped: base.BaseVolume):
+  def __init__(self, wrapped: base.Volume):
     self._wrapped = wrapped
 
   def get_points(self, points: array.PointLookups) -> np.ndarray:
@@ -105,7 +105,7 @@ def volume_size(self) -> array.Tuple3i:
 
   @property
   def voxel_size(self) -> array.Tuple3f:
-    return self._wrapped.voxel_size
+    return self._wrapped.pixel_size
 
   @property
   def shape(self) -> array.Tuple4i:
@@ -133,7 +133,7 @@ class Upsample(VolumeDecorator):
 
   scale_zyx: np.ndarray
 
-  def __init__(self, wrapped: base.BaseVolume, scale: array.ArrayLike3d):
+  def __init__(self, wrapped: base.Volume, scale: array.ArrayLike3d):
     """Initializes the wrapper.
 
     Args:
@@ -154,7 +154,7 @@ def volume_size(self) -> array.Tuple3i:
 
   @property
   def voxel_size(self) -> array.Tuple3i:
-    return tuple(self._wrapped.voxel_size / self.scale_zyx[::-1])
+    return tuple(self._wrapped.pixel_size / self.scale_zyx[::-1])
 
   @property
   def shape(self) -> array.Tuple4i:
diff --git a/connectomics/volume/tsv_decorator_test.py b/connectomics/volume/tsv_decorator_test.py
index 3e84b48..7025466 100644
--- a/connectomics/volume/tsv_decorator_test.py
+++ b/connectomics/volume/tsv_decorator_test.py
@@ -15,15 +15,17 @@
 """Tests for tsv_decorator."""
 
 import typing
-from typing import Any, Sequence, Tuple
+from typing import Any
 
 from absl.testing import absltest
 from connectomics.common import array
 from connectomics.common import bounding_box
 from connectomics.volume import base as base_volume
 from connectomics.volume import descriptor as vd
+from connectomics.volume import metadata
 from connectomics.volume import tsv_decorator
 import numpy as np
+import numpy.typing as nptyping
 import numpy.testing as npt
 
 BBox = bounding_box.BoundingBox
@@ -33,13 +35,24 @@
 
 # TODO(timblakely): Create an common in-memory volume implementation. Would be
 # useful in both tests and in temporary volume situations.
-class DummyVolume(base_volume.BaseVolume):
-
-  def __init__(self, volume_size: Sequence[int], voxel_size: Sequence[int],
-               bounding_boxes: list[BBox], data: np.ndarray):
-    self._volume_size = tuple(volume_size)
-    self._voxel_size = tuple(voxel_size)
-    self._bounding_boxes = bounding_boxes
+class DummyVolume(base_volume.Volume):
+
+  def __init__(
+      self,
+      volume_size: tuple[int, int, int],
+      voxel_size: tuple[int, int, int],
+      bounding_boxes: list[BBox],
+      data: np.ndarray,
+      dtype: nptyping.DTypeLike,
+  ):
+    super().__init__(
+        metadata.VolumeMetadata(
+            volume_size=volume_size,
+            pixel_size=voxel_size,
+            bounding_boxes=bounding_boxes,
+            dtype=dtype,
+        )
+    )
     self._data = data
 
   def __getitem__(self, ind):
@@ -56,39 +69,17 @@ def get_points(self, points: array.PointLookups) -> np.ndarray:
   def get_slices(self, slices: array.CanonicalSlice) -> np.ndarray:
     return self._data[slices]
 
-  @property
-  def volume_size(self) -> array.Tuple3i:
-    return self._volume_size
-
-  @property
-  def voxel_size(self) -> array.Tuple3i:
-    return self._voxel_size
-
-  @property
-  def shape(self) -> array.Tuple4i:
-    return (1,) + tuple(self._volume_size[::-1])
-
-  @property
-  def ndim(self) -> int:
-    return len(self._data.shape)
-
-  @property
-  def dtype(self) -> np.dtype:
-    return self._data.dtype
-
-  @property
-  def bounding_boxes(self) -> list[BBox]:
-    return self._bounding_boxes
-
 
-def _make_dummy_vol() -> Tuple[DummyVolume, BBox, np.ndarray]:
+def _make_dummy_vol() -> tuple[DummyVolume, BBox, np.ndarray]:
   bbox = BBox([100, 200, 300], [20, 50, 100])
   data = np.zeros(bbox.size)
   data[0] = 1
   data = np.cumsum(
       np.cumsum(np.cumsum(data, axis=0), axis=1), axis=2, dtype=np.uint64)
   data = data[np.newaxis]
-  vol = DummyVolume([3000, 2000, 1000], (8, 8, 33), [bbox], data)
+  vol = DummyVolume(
+      (3000, 2000, 1000), (8, 8, 33), [bbox], data, dtype=np.uint64
+  )
   return vol, bbox, data
 
 
@@ -99,7 +90,7 @@ def test_dummy_volume(self):
     self.assertEqual((3000, 2000, 1000), vol.volume_size)
     self.assertLen(vol.bounding_boxes, 1)
     self.assertEqual([bbox], vol.bounding_boxes)
-    self.assertEqual((8, 8, 33), vol.voxel_size)
+    self.assertEqual((8, 8, 33), vol.pixel_size)
     self.assertEqual(np.uint64, vol.dtype)
     self.assertEqual(4, vol.ndim)
     self.assertEqual((1, 1000, 2000, 3000), vol.shape)
@@ -137,7 +128,7 @@ class CustomDecoratorFactory(tsv_decorator.DecoratorFactory):
   def __init__(self, *args, **kwargs):
     self.called = False
 
-  def make_decorator(self, wrapped_volume: base_volume.BaseVolume, name: str,
+  def make_decorator(self, wrapped_volume: base_volume.Volume, name: str,
                      *args: list[Any],
                      **kwargs: dict[str, Any]) -> tsv_decorator.VolumeDecorator:
     if name == 'CustomDecorator':