Skip to content

Commit

Permalink
Move default config registration and instantiation into connectomics.…
Browse files Browse the repository at this point in the history
…subvolume_processor.

PiperOrigin-RevId: 696215416
  • Loading branch information
timblakely authored and copybara-github committed Nov 16, 2024
1 parent 9137510 commit 6a4308b
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion connectomics/volume/subvolume_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import enum
import importlib
import inspect
from typing import Any, List, Optional, Tuple, Type, Union
import logging
from typing import Any, Callable, List, Optional, Tuple, Type, TypeVar, Union

from connectomics.common import array
from connectomics.common import bounding_box
Expand Down Expand Up @@ -325,3 +326,50 @@ def get_processor(config: SubvolumeProcessorConfig) -> SubvolumeProcessor:
processor = getattr(package, name)
args = {} if not config.args else config.args
return processor(**args)


class DefaultConfigType(enum.Enum):
EM_2D = 'em_2d'
# TODO(timblakely): Support additional configuration.
# EM_3D = 'em_3d'
# LICONN = 'liconn'


_KNOWN_DEFAULT_PROCESSOR_CONFIGS = {}


T = TypeVar('T', bound=utils.IsDataclass)


def register_default_config(
config_type: DefaultConfigType,
config_class: Type[utils.IsDataclass],
config_fn: Callable[[dict[str, Any] | None], utils.IsDataclass],
):
"""Registers a default configuration for a given config type and class."""
config_type_map = _KNOWN_DEFAULT_PROCESSOR_CONFIGS.setdefault(config_type, {})
if config_class in config_type_map:
logging.warning(
'Default config for %s already registered for %s overwriting.',
config_class,
config_type,
)
config_type_map[config_class] = config_fn


def default_config(
config_class: Type[T],
config_type: DefaultConfigType | None = None,
overrides: dict[str, Any] | None = None,
fallback_to_em_2d: bool = True,
) -> T:
"""Returns a default configuration for a given config type and class."""
if config_type is None and fallback_to_em_2d:
logging.warning('No default config type specified, falling back to EM_2D.')
config_type = DefaultConfigType.EM_2D
if config_type not in _KNOWN_DEFAULT_PROCESSOR_CONFIGS:
raise ValueError(f'No default configurations available for {config_type}')
default_map = _KNOWN_DEFAULT_PROCESSOR_CONFIGS[config_type]
if config_class not in default_map:
raise ValueError(f'No default config for {config_class} for {config_type}')
return default_map[config_class](overrides)

0 comments on commit 6a4308b

Please sign in to comment.