diff --git a/connectomics/volume/subvolume_processor.py b/connectomics/volume/subvolume_processor.py index 9abad4c..40af9f3 100644 --- a/connectomics/volume/subvolume_processor.py +++ b/connectomics/volume/subvolume_processor.py @@ -19,7 +19,8 @@ import enum import importlib import inspect -from typing import Any, List, Optional, Tuple, Type, Union +import logging +from typing import Any, Callable, List, Optional, Tuple, Type, TypeVar, Union from connectomics.common import array from connectomics.common import bounding_box @@ -325,3 +326,50 @@ def get_processor(config: SubvolumeProcessorConfig) -> SubvolumeProcessor: processor = getattr(package, name) args = {} if not config.args else config.args return processor(**args) + + +class DefaultConfigType(enum.Enum): + EM_2D = 'em_2d' + # TODO(timblakely): Support additional configuration. + # EM_3D = 'em_3d' + # LICONN = 'liconn' + + +_KNOWN_DEFAULT_PROCESSOR_CONFIGS = {} + + +T = TypeVar('T', bound=utils.IsDataclass) + + +def register_default_config( + config_type: DefaultConfigType, + config_class: Type[utils.IsDataclass], + config_fn: Callable[[dict[str, Any] | None], utils.IsDataclass], +): + """Registers a default configuration for a given config type and class.""" + config_type_map = _KNOWN_DEFAULT_PROCESSOR_CONFIGS.setdefault(config_type, {}) + if config_class in config_type_map: + logging.warning( + 'Default config for %s already registered for %s overwriting.', + config_class, + config_type, + ) + config_type_map[config_class] = config_fn + + +def default_config( + config_class: Type[T], + config_type: DefaultConfigType | None = None, + overrides: dict[str, Any] | None = None, + fallback_to_em_2d: bool = True, +) -> T: + """Returns a default configuration for a given config type and class.""" + if config_type is None and fallback_to_em_2d: + logging.warning('No default config type specified, falling back to EM_2D.') + config_type = DefaultConfigType.EM_2D + if config_type not in _KNOWN_DEFAULT_PROCESSOR_CONFIGS: + raise ValueError(f'No default configurations available for {config_type}') + default_map = _KNOWN_DEFAULT_PROCESSOR_CONFIGS[config_type] + if config_class not in default_map: + raise ValueError(f'No default config for {config_class} for {config_type}') + return default_map[config_class](overrides)