Skip to content

Commit

Permalink
feat: keep buit specs in registry
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Nov 2, 2023
1 parent 3147f19 commit 2916c70
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 20 deletions.
6 changes: 2 additions & 4 deletions tests/unit/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ def test_register(register_dummy_a):
def test_build_unversioned(spec: dict, expected: Any, register_dummy_a, register_dummy_b):
result = builder.build(spec)
assert result == expected
if hasattr(result, "__dict__"):
assert result.__built_with_spec == spec
assert builder.get_initial_builder_spec(result) == spec


@pytest.mark.parametrize(
Expand All @@ -166,8 +165,7 @@ def test_build_unversioned(spec: dict, expected: Any, register_dummy_a, register
def test_build_versioned(spec: dict, expected: Any, register_dummy_a_v0, register_dummy_a_v2):
result = builder.build(spec)
assert result == expected
if hasattr(result, "__dict__"):
assert result.__built_with_spec == spec
assert builder.get_initial_builder_spec(result) == spec


def test_build_partial(register_dummy_a, register_dummy_c):
Expand Down
2 changes: 1 addition & 1 deletion zetta_utils/builder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Building objects from specs"""
from . import constants
from .registry import REGISTRY, register, get_matching_entry, unregister
from .build import SPECIAL_KEYS, build, BuilderPartial
from .build import SPECIAL_KEYS, build, BuilderPartial, get_initial_builder_spec
from . import built_in_registrations
41 changes: 31 additions & 10 deletions zetta_utils/builder/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from zetta_utils import parsing
from zetta_utils.common import ctx_managers
from zetta_utils.parsing import json
from zetta_utils.typing import JsonDict

from . import constants
from .registry import get_matching_entry
Expand All @@ -19,6 +20,19 @@
"version": "@version",
}

BUILT_OBJECT_ID_REGISTRY: dict[int, JsonDict] = {}


def get_initial_builder_spec(obj: Any) -> JsonDict | None:
"""Returns the builder spec that the object was initially built with.
Note that mutations to the object after it was built will not be
reflected in the spec. Returns `None` if the object was not built with
builder
"""
# breakpoint()
result = BUILT_OBJECT_ID_REGISTRY.get(id(obj), None)
return result


@typechecked
def build(
Expand All @@ -41,12 +55,18 @@ def build(

# error check the spec
_traverse_spec(
final_spec, _check_type_value, name_prefix="spec", version=constants.DEFAULT_VERSION
final_spec,
_check_type_value,
name_prefix="spec",
version=constants.DEFAULT_VERSION,
)

# build the spec
result = _traverse_spec(
final_spec, _build_dict_spec, name_prefix="spec", version=constants.DEFAULT_VERSION
final_spec,
_build_dict_spec,
name_prefix="spec",
version=constants.DEFAULT_VERSION,
)

return result
Expand All @@ -64,7 +84,10 @@ def _traverse_spec(spec: Any, apply_fn: Callable, name_prefix: str, version: str
elif isinstance(spec, list):
result = [
_traverse_spec(
spec=e, apply_fn=apply_fn, name_prefix=f"{name_prefix}[{i}]", version=version
spec=e,
apply_fn=apply_fn,
name_prefix=f"{name_prefix}[{i}]",
version=version,
)
for i, e in enumerate(spec)
]
Expand Down Expand Up @@ -93,7 +116,7 @@ def _traverse_spec(spec: Any, apply_fn: Callable, name_prefix: str, version: str
}
else:
result = spec

BUILT_OBJECT_ID_REGISTRY[id(result)] = spec
return result


Expand All @@ -106,7 +129,10 @@ def _check_type_value(spec: dict[str, Any], name_prefix: str, version: str) -> A
get_matching_entry(this_type, version=version)
for k, v in spec.items():
_traverse_spec(
v, apply_fn=_check_type_value, name_prefix=f"{name_prefix}.{k}", version=version
v,
apply_fn=_check_type_value,
name_prefix=f"{name_prefix}.{k}",
version=version,
)


Expand Down Expand Up @@ -153,11 +179,6 @@ def _build_dict_spec(spec: dict[str, Any], name_prefix: str, version: str) -> An
else:
raise ValueError(f"Unsupported mode: {this_mode}")

# save the spec that was used to create the object if possible
# slotted classes won't allow adding new attributes
if hasattr(result, "__dict__"):
object.__setattr__(result, "__built_with_spec", spec)

return result


Expand Down
17 changes: 12 additions & 5 deletions zetta_utils/training/lightning/trainers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from pytorch_lightning.strategies import ddp

from zetta_utils import builder, log
from zetta_utils.builder.build import get_initial_builder_spec
from zetta_utils.parsing import json
from zetta_utils.typing import JsonSerializableValue

logger = log.get_logger("zetta_utils")
ONNX_OPSET_VERSION = 17
Expand Down Expand Up @@ -133,15 +135,20 @@ def save_checkpoint(

regime = self.lightning_module
for k, v in regime._modules.items(): # pylint: disable=protected-access
if hasattr(v, "__built_with_spec"):
model_spec = getattr(v, "__built_with_spec") # pylint: disable=protected-access
while "@type" in model_spec and model_spec["@type"] == "load_weights_file":
model_spec = model_spec["model"]
model_spec: JsonSerializableValue = get_initial_builder_spec(v)
if model_spec is not None:
unrolled_spec: JsonSerializableValue = model_spec
while (
isinstance(unrolled_spec, dict)
and "@type" in unrolled_spec
and unrolled_spec["@type"] == "load_weights_file"
):
unrolled_spec = unrolled_spec["model"]

spec = {
"@type": "load_weights_file",
"@version": importlib.metadata.version("zetta_utils"),
"model": model_spec,
"model": unrolled_spec,
"ckpt_path": filepath,
"component_names": [k],
"remove_component_prefix": True,
Expand Down
13 changes: 13 additions & 0 deletions zetta_utils/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,16 @@ def ensure_seq_of_seq(x, length):
result = [x] * length

return result


JsonSerializableValue = Union[
str,
int,
float,
bool,
None,
list["JsonSerializableValue"],
dict[str, "JsonSerializableValue"],
]

JsonDict = dict[str, JsonSerializableValue]

0 comments on commit 2916c70

Please sign in to comment.