Skip to content

Commit

Permalink
- Rename processor configs from ProcessorNameConfig to Config.
Browse files Browse the repository at this point in the history
- Add explicit defaults for 2D EM data to flow_config.
- Remove defaults from the respective Configs' data

PiperOrigin-RevId: 693837439
  • Loading branch information
timblakely authored and copybara-github committed Nov 6, 2024
1 parent f989bcc commit e40ceca
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions connectomics/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
"""Various utility functions."""

import contextlib
import dataclasses
import itertools
import re
import time
from typing import Any, Protocol, TypeVar

from absl import logging
import dataclasses_json
Expand Down Expand Up @@ -108,3 +110,48 @@ def report_time(name):
def pascal_to_kebab(name: str) -> str:
"""Converts a PascalCase name to kebab-case."""
return _PASCAL_TO_KEBAB_RE.sub('-', name).lower()


# This is a kludge, since dataclasses do not export their typeshed for dataclass
# externally.
class IsDataclass(Protocol):
__dataclass_fields__: dict[str, Any]


D = TypeVar('D', bound=IsDataclass)


# TODO(timblakely): Move to a more appropriate location.
# TODO(timblakely): Support Sequences of Dataclasses.
def update_dataclass(
source: D,
overrides: dict[str, Any],
apply_recursive: bool = True,
) -> D:
"""Recursively updates a dataclass with overrides.
Contrary to dataclasses.replace, this function will only update attributes
that are present in the source dataclass, and will apply recursively to
sub-dataclasses.
Args:
source: The dataclass to update.
overrides: A mapping of attribute name to value to override.
apply_recursive: Whether to apply the overrides recursively to
sub-dataclasses or not.
Returns:
A new dataclass with the overrides applied.
"""
params = {}
for k, v in overrides.items():
if not hasattr(source, k):
raise ValueError(f'Attribute {k} not found in {source}')
attr = getattr(source, k)
if dataclasses.is_dataclass(attr) and apply_recursive:
params[k] = update_dataclass(attr, v)
else:
params[k] = v

source = dataclasses.replace(source, **params)
return source

0 comments on commit e40ceca

Please sign in to comment.