diff --git a/connectomics/common/utils.py b/connectomics/common/utils.py index cf8b7c9..6e391ae 100644 --- a/connectomics/common/utils.py +++ b/connectomics/common/utils.py @@ -15,9 +15,11 @@ """Various utility functions.""" import contextlib +import dataclasses import itertools import re import time +from typing import Any, Protocol, TypeVar from absl import logging import dataclasses_json @@ -108,3 +110,48 @@ def report_time(name): def pascal_to_kebab(name: str) -> str: """Converts a PascalCase name to kebab-case.""" return _PASCAL_TO_KEBAB_RE.sub('-', name).lower() + + +# This is a kludge, since dataclasses do not export their typeshed for dataclass +# externally. +class IsDataclass(Protocol): + __dataclass_fields__: dict[str, Any] + + +D = TypeVar('D', bound=IsDataclass) + + +# TODO(timblakely): Move to a more appropriate location. +# TODO(timblakely): Support Sequences of Dataclasses. +def update_dataclass( + source: D, + overrides: dict[str, Any], + apply_recursive: bool = True, +) -> D: + """Recursively updates a dataclass with overrides. + + Contrary to dataclasses.replace, this function will only update attributes + that are present in the source dataclass, and will apply recursively to + sub-dataclasses. + + Args: + source: The dataclass to update. + overrides: A mapping of attribute name to value to override. + apply_recursive: Whether to apply the overrides recursively to + sub-dataclasses or not. + + Returns: + A new dataclass with the overrides applied. + """ + params = {} + for k, v in overrides.items(): + if not hasattr(source, k): + raise ValueError(f'Attribute {k} not found in {source}') + attr = getattr(source, k) + if dataclasses.is_dataclass(attr) and apply_recursive: + params[k] = update_dataclass(attr, v) + else: + params[k] = v + + source = dataclasses.replace(source, **params) + return source