Skip to content

Commit

Permalink
Create DecoratorSpecs definitions that are compatible with TensorStore.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 665968766
  • Loading branch information
timblakely authored and copybara-github committed Aug 23, 2024
1 parent 214ffac commit 9f16d2e
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 5 deletions.
11 changes: 9 additions & 2 deletions connectomics/common/import_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@ def import_symbol(
Args:
specifier: full path specifier in format
[<packages>.]<module_name>.<model_class>, if packages is missing
``default_packages`` is used.
``default_packages`` is used. Alternatively, the specifier can be just a
class name within a module specified by default_packages.
default_packages: chain of packages before module in format
<top_pack>.<sub_pack>.<subsub_pack> etc.
Returns:
symbol: object from module
"""
module_path, symbol_name = specifier.rsplit('.', 1)

try:
module_path, symbol_name = specifier.rsplit('.', 1)
except ValueError as _:
module_path = default_packages
symbol_name = specifier

try:
logging.info(
'Importing symbol %s from %s.%s',
Expand Down
41 changes: 41 additions & 0 deletions connectomics/volume/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
"""

import copy
import dataclasses
import enum
import pprint
from typing import Any, Iterable, Mapping, MutableMapping, Optional, Sequence, Union

from absl import logging
from connectomics.common import counters
from connectomics.common import import_util
from connectomics.common import metadata_utils
import dataclasses_json
import gin
import jax
import numpy as np
Expand Down Expand Up @@ -1477,3 +1480,41 @@ def debug_string(self) -> str:
'multiscale_spec:\n' +
pprint.pformat(self.multiscale_spec)
)


@dataclasses_json.dataclass_json(undefined=dataclasses_json.Undefined.INCLUDE)
@dataclasses.dataclass(frozen=True)
class DecoratorArgs:
"""Empty dataclass to allow automatic parsing of decorator args.
This precludes the need to define a dataclass for each decorator. All
undefined fields are included in the resulting python object.
"""

values: dataclasses_json.CatchAll


@dataclasses_json.dataclass_json
@dataclasses.dataclass(frozen=True)
class DecoratorSpec:
"""Decorator specification.
Attributes:
name: Name of the decorator.
args: Arguments for decorator's constructor.
package: Package where the decorator is defined.
"""

name: str
args: DecoratorArgs | None = None
package: str | None = None


def build_decorator(spec: DecoratorSpec) -> Decorator:
"""Builds a Decorator from a DecoratorSpec."""
package = spec.package
if package is None:
package = 'connectomics.volume.decorators'
decorator_cls = import_util.import_symbol(spec.name, package)
args = spec.args.values if spec.args else {}
return decorator_cls(**args)
101 changes: 101 additions & 0 deletions connectomics/volume/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for decorators."""

import copy
import json

from absl.testing import absltest
from connectomics.volume import decorators
Expand Down Expand Up @@ -723,5 +724,105 @@ def test_clobber_with_input_spec(self):
})


class TestDecorator(decorators.Decorator):

def __init__(self, foo: int, bar: str):
self.foo = foo
self.bar = bar


class DecoratorSpecTest(absltest.TestCase):

def test_decorator_args_unknown(self):
expected_args = {
'downsample_factors': [1, 2],
'method': 'max',
}
args = decorators.DecoratorArgs.from_json(json.dumps(expected_args))
self.assertEqual(args.values['downsample_factors'], [1, 2])
self.assertEqual(args.values['method'], 'max')
self.assertEqual(args.to_dict(), expected_args)

def test_decorator_spec(self):
expected_spec = {
'name': 'Downsample',
}
spec = decorators.DecoratorSpec.from_json(json.dumps(expected_spec))
self.assertEqual(spec.name, 'Downsample')
self.assertIsNone(spec.args)
self.assertIsNone(spec.package)

expected_spec = {
'name': 'Downsample',
'args': {
'downsample_factors': [1, 2],
'method': 'max',
},
'package': 'foo.bar.baz',
}
spec = decorators.DecoratorSpec.from_json(json.dumps(expected_spec))
args = decorators.DecoratorArgs.from_dict(expected_spec['args'])
self.assertEqual(args.to_dict(), expected_spec['args'])
self.assertEqual(spec.to_dict(), expected_spec)

def test_build_decorator(self):
spec = decorators.DecoratorSpec.from_json(
json.dumps({
'name': 'Downsample',
'args': {
'downsample_factors': [2, 4],
'method': 'max',
},
})
)
decorator = decorators.build_decorator(spec)
self.assertIsInstance(decorator, decorators.Downsample)
self.assertEqual(decorator._downsample_factors, [2, 4])
self.assertEqual(decorator._method, 'max')

def test_build_decorator_with_bad_args(self):
spec = decorators.DecoratorSpec.from_json(
json.dumps({
'name': 'Downsample',
'args': {
'downsample_factors': [2, 4],
'method': 'max',
'BAD_ARG': 'very_bad',
},
})
)
with self.assertRaises(TypeError):
decorators.build_decorator(spec)

spec = decorators.DecoratorSpec.from_json(
json.dumps({
'name': 'Downsample',
'args': {
'downsample_factors': [2, 4],
# missing method
},
})
)
with self.assertRaises(TypeError):
decorators.build_decorator(spec)

def test_build_decorator_with_package(self):
spec = decorators.DecoratorSpec.from_json(
json.dumps({
'name': 'TestDecorator',
'args': {
'foo': 1,
'bar': 'baz',
},
'package': 'connectomics.volume.decorators_test',
})
)
decorator = decorators.build_decorator(spec)
# Can't use assertIsInstance because the package is imported.
self.assertEqual(decorator.__class__.__name__, TestDecorator.__name__)
self.assertEqual(decorator.foo, 1)
self.assertEqual(decorator.bar, 'baz')


if __name__ == '__main__':
absltest.main()
9 changes: 6 additions & 3 deletions connectomics/volume/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from connectomics.common import bounding_box
from connectomics.common import file
from connectomics.volume import decorators
import dataclasses_json
import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -112,9 +113,11 @@ class DecoratedVolume:
Attributes:
path: The path to the volume.
decorator_specs: A JSON string of decorator specs.
decorator_specs: A JSON string of decorator specs, or one or more
DecoratorSpec objects.
"""

path: pathlib.Path
# TODO(timblakely): This should be a list of DecoratorSpec dataclasses.
decorator_specs: str
decorator_specs: (
str | decorators.DecoratorSpec | list[decorators.DecoratorSpec]
)

0 comments on commit 9f16d2e

Please sign in to comment.