From d242f8f970159a6aae80968acde21884ffe09610 Mon Sep 17 00:00:00 2001 From: Tim Blakely Date: Fri, 23 Aug 2024 11:36:19 -0700 Subject: [PATCH] Create DecoratorSpecs definitions that are compatible with TensorStore. PiperOrigin-RevId: 666868050 --- connectomics/common/import_util.py | 11 ++- connectomics/volume/decorators.py | 41 ++++++++++ connectomics/volume/decorators_test.py | 101 +++++++++++++++++++++++++ connectomics/volume/metadata.py | 9 ++- 4 files changed, 157 insertions(+), 5 deletions(-) diff --git a/connectomics/common/import_util.py b/connectomics/common/import_util.py index 5bea0e3..6d76369 100644 --- a/connectomics/common/import_util.py +++ b/connectomics/common/import_util.py @@ -29,14 +29,21 @@ def import_symbol( Args: specifier: full path specifier in format [.]., if packages is missing - ``default_packages`` is used. + ``default_packages`` is used. Alternatively, the specifier can be just a + class name within a module specified by default_packages. default_packages: chain of packages before module in format .. etc. Returns: symbol: object from module """ - module_path, symbol_name = specifier.rsplit('.', 1) + + try: + module_path, symbol_name = specifier.rsplit('.', 1) + except ValueError as _: + module_path = default_packages + symbol_name = specifier + try: logging.info( 'Importing symbol %s from %s.%s', diff --git a/connectomics/volume/decorators.py b/connectomics/volume/decorators.py index ed17cea..09e51e2 100644 --- a/connectomics/volume/decorators.py +++ b/connectomics/volume/decorators.py @@ -20,13 +20,16 @@ """ import copy +import dataclasses import enum import pprint from typing import Any, Iterable, Mapping, MutableMapping, Optional, Sequence, Union from absl import logging from connectomics.common import counters +from connectomics.common import import_util from connectomics.common import metadata_utils +import dataclasses_json import gin import jax import numpy as np @@ -1477,3 +1480,41 @@ def debug_string(self) -> str: 'multiscale_spec:\n' + pprint.pformat(self.multiscale_spec) ) + + +@dataclasses_json.dataclass_json(undefined=dataclasses_json.Undefined.INCLUDE) +@dataclasses.dataclass(frozen=True) +class DecoratorArgs: + """Empty dataclass to allow automatic parsing of decorator args. + + This precludes the need to define a dataclass for each decorator. All + undefined fields are included in the resulting python object. + """ + + values: dataclasses_json.CatchAll + + +@dataclasses_json.dataclass_json +@dataclasses.dataclass(frozen=True) +class DecoratorSpec: + """Decorator specification. + + Attributes: + name: Name of the decorator. + args: Arguments for decorator's constructor. + package: Package where the decorator is defined. + """ + + name: str + args: DecoratorArgs | None = None + package: str | None = None + + +def build_decorator(spec: DecoratorSpec) -> Decorator: + """Builds a Decorator from a DecoratorSpec.""" + package = spec.package + if package is None: + package = 'connectomics.volume.decorators' + decorator_cls = import_util.import_symbol(spec.name, package) + args = spec.args.values if spec.args else {} + return decorator_cls(**args) diff --git a/connectomics/volume/decorators_test.py b/connectomics/volume/decorators_test.py index 62358d5..f247ca7 100644 --- a/connectomics/volume/decorators_test.py +++ b/connectomics/volume/decorators_test.py @@ -15,6 +15,7 @@ """Tests for decorators.""" import copy +import json from absl.testing import absltest from connectomics.volume import decorators @@ -723,5 +724,105 @@ def test_clobber_with_input_spec(self): }) +class TestDecorator(decorators.Decorator): + + def __init__(self, foo: int, bar: str): + self.foo = foo + self.bar = bar + + +class DecoratorSpecTest(absltest.TestCase): + + def test_decorator_args_unknown(self): + expected_args = { + 'downsample_factors': [1, 2], + 'method': 'max', + } + args = decorators.DecoratorArgs.from_json(json.dumps(expected_args)) + self.assertEqual(args.values['downsample_factors'], [1, 2]) + self.assertEqual(args.values['method'], 'max') + self.assertEqual(args.to_dict(), expected_args) + + def test_decorator_spec(self): + expected_spec = { + 'name': 'Downsample', + } + spec = decorators.DecoratorSpec.from_json(json.dumps(expected_spec)) + self.assertEqual(spec.name, 'Downsample') + self.assertIsNone(spec.args) + self.assertIsNone(spec.package) + + expected_spec = { + 'name': 'Downsample', + 'args': { + 'downsample_factors': [1, 2], + 'method': 'max', + }, + 'package': 'foo.bar.baz', + } + spec = decorators.DecoratorSpec.from_json(json.dumps(expected_spec)) + args = decorators.DecoratorArgs.from_dict(expected_spec['args']) + self.assertEqual(args.to_dict(), expected_spec['args']) + self.assertEqual(spec.to_dict(), expected_spec) + + def test_build_decorator(self): + spec = decorators.DecoratorSpec.from_json( + json.dumps({ + 'name': 'Downsample', + 'args': { + 'downsample_factors': [2, 4], + 'method': 'max', + }, + }) + ) + decorator = decorators.build_decorator(spec) + self.assertIsInstance(decorator, decorators.Downsample) + self.assertEqual(decorator._downsample_factors, [2, 4]) + self.assertEqual(decorator._method, 'max') + + def test_build_decorator_with_bad_args(self): + spec = decorators.DecoratorSpec.from_json( + json.dumps({ + 'name': 'Downsample', + 'args': { + 'downsample_factors': [2, 4], + 'method': 'max', + 'BAD_ARG': 'very_bad', + }, + }) + ) + with self.assertRaises(TypeError): + decorators.build_decorator(spec) + + spec = decorators.DecoratorSpec.from_json( + json.dumps({ + 'name': 'Downsample', + 'args': { + 'downsample_factors': [2, 4], + # missing method + }, + }) + ) + with self.assertRaises(TypeError): + decorators.build_decorator(spec) + + def test_build_decorator_with_package(self): + spec = decorators.DecoratorSpec.from_json( + json.dumps({ + 'name': 'TestDecorator', + 'args': { + 'foo': 1, + 'bar': 'baz', + }, + 'package': 'connectomics.volume.decorators_test', + }) + ) + decorator = decorators.build_decorator(spec) + # Can't use assertIsInstance because the package is imported. + self.assertEqual(decorator.__class__.__name__, TestDecorator.__name__) + self.assertEqual(decorator.foo, 1) + self.assertEqual(decorator.bar, 'baz') + + if __name__ == '__main__': absltest.main() diff --git a/connectomics/volume/metadata.py b/connectomics/volume/metadata.py index c9aaab3..269ce14 100644 --- a/connectomics/volume/metadata.py +++ b/connectomics/volume/metadata.py @@ -20,6 +20,7 @@ from connectomics.common import bounding_box from connectomics.common import file +from connectomics.volume import decorators import dataclasses_json import numpy as np import numpy.typing as npt @@ -112,9 +113,11 @@ class DecoratedVolume: Attributes: path: The path to the volume. - decorator_specs: A JSON string of decorator specs. + decorator_specs: A JSON string of decorator specs, or one or more + DecoratorSpec objects. """ path: pathlib.Path - # TODO(timblakely): This should be a list of DecoratorSpec dataclasses. - decorator_specs: str + decorator_specs: ( + str | decorators.DecoratorSpec | list[decorators.DecoratorSpec] + )