Skip to content

Commit

Permalink
add check_for_errors_in_dkt_values function
Browse files Browse the repository at this point in the history
  • Loading branch information
Czaki committed Apr 8, 2022
1 parent 80ce3f0 commit fb63691
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/nme/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
rename_key,
update_argument,
)
from ._serialize_hooks import NMEEncoder, nme_object_encoder, nme_object_hook
from ._serialize_hooks import NMEEncoder, check_for_errors_in_dkt_values, nme_object_encoder, nme_object_hook
from .version import version as __version__


Expand Down Expand Up @@ -47,6 +47,7 @@ def nme_cbor_decoder(decoder, value):

__all__ = (
"class_to_str",
"check_for_errors_in_dkt_values",
"register_class",
"nme_object_hook",
"rename_key",
Expand Down
14 changes: 14 additions & 0 deletions src/nme/_serialize_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ def default(self, o):
return val


def check_for_errors_in_dkt_values(dkt: dict) -> typing.List[str]:
"""
Function checking if any of values in dict contains ``"__error__"`` key.
:param dkt: dictionary to check.
:return: list of keys that value is dict containing ``"__error__"`` key.
"""
return [key for key, value in dkt.items() if isinstance(value, dict) and "__error__" in value]


def nme_object_hook(dkt: dict) -> typing.Any:
"""
Function restoring supported types from :py:func:`nme_object_encoder` function output.
Expand All @@ -133,6 +143,10 @@ def nme_object_hook(dkt: dict) -> typing.Any:
cls_str = dkt.pop("__class__")
version_dkt = dkt.pop("__class_version_dkt__") if "__class_version_dkt__" in dkt else {cls_str: "0.0.0"}
dkt = {"__values__": dkt, "__class__": cls_str, "__class_version_dkt__": version_dkt}
problematic_fields = check_for_errors_in_dkt_values(dkt["__values__"])
if problematic_fields:
dkt["__error__"] = f"Error in fields: {', '.join(problematic_fields)}"
return dkt
try:
dkt_migrated = REGISTER.migrate_data(dkt["__class__"], dkt["__class_version_dkt__"], dkt["__values__"])
cls = REGISTER.get_class(dkt["__class__"])
Expand Down
28 changes: 27 additions & 1 deletion src/tests/test_json_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest
from napari.utils import Colormap
from napari.utils.notifications import NotificationSeverity
from pydantic import BaseModel, dataclasses
from pydantic import BaseModel, Extra, dataclasses

from nme import NMEEncoder, nme_object_hook, register_class, rename_key
from nme._class_register import class_to_str
Expand Down Expand Up @@ -230,6 +230,32 @@ class MainClass(BaseClass):
ob = json.loads(data_str, object_hook=nme_object_hook)
assert isinstance(ob, MainClass)

def test_error_in_object_restore(self, clean_register):
@register_class
class SubClass(BaseModel, extra=Extra.forbid):
field: int = 1

@register_class
class MainClass(BaseModel):
field: int = 1
sub11: SubClass = SubClass()

data_str = """
{"__class__": "test_json_hooks.TestNMEObjectHook.test_error_in_object_restore.<locals>.MainClass",
"__class_version_dkt__":
{"test_json_hooks.TestNMEObjectHook.test_error_in_object_restore.<locals>.MainClass": "0.0.0"},
"__values__": {"field": 1,
"sub11": {"__class__": "test_json_hooks.TestNMEObjectHook.test_error_in_object_restore.<locals>.SubClass",
"__class_version_dkt__":
{"test_json_hooks.TestNMEObjectHook.test_error_in_object_restore.<locals>.SubClass": "0.0.0"},
"__values__": {"field": 1, "eee": 1}}}}
"""

ob = json.loads(data_str, object_hook=nme_object_hook)
assert isinstance(ob, dict)
assert "__error__" in ob
assert ob["__error__"] == "Error in fields: sub11"


class DummyClassForTest:
class DummySubClassForTest:
Expand Down

0 comments on commit fb63691

Please sign in to comment.