Skip to content

Commit

Permalink
Implement (minimally tested) replace_jsonpickle_decode
Browse files Browse the repository at this point in the history
  • Loading branch information
roomrys committed Sep 12, 2024
1 parent ab4b9d9 commit fc563c1
Showing 1 changed file with 271 additions and 1 deletion.
272 changes: 271 additions & 1 deletion sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,276 @@ def matches(self, other: "Node") -> bool:
return other.name == self.name and other.weight == self.weight


def replace_jsonpickle_decode(json_str: str) -> Any:
"""Replace jsonpickle.decode with own decoder.
This function will decode the following from their encoded format:
`Node` objects from
{
"py/object": "sleap.skeleton.Node",
"py/state": { "py/tuple": ["thorax1", 1.0] }
}
to `Node(name="thorax1", weight=1.0)`
`EdgeType` objects from
{
"py/reduce": [
{ "py/type": "sleap.skeleton.EdgeType" },
{ "py/tuple": [1] }
]
}
to `EdgeType(1)`
`bytes` from
{
"py/b64": "aVZC..."
}
to `b"iVBO..."`
and any repeated objects from
{
"py/id": 1
}
to the object with the same reconstruction id (from top to bottom).
"""

def decode_id(
id: int, objects: List[Union[Node, EdgeType]]
) -> Union[Node, EdgeType]:
"""Decode the object with the given `py/id` value of `id`.
Args:
id: The `py/id` value to decode (1-indexed).
objects: The dictionary of objects that have already been decoded.
Returns:
The object with the given `py/id` value.
"""
return objects[id - 1]

def decode_state(state: dict) -> Node:
"""Reconstruct the `Node` object from 'py/state' key in the serialized nx_graph.
We support states in either dictionary or tuple format:
{
"py/state": { "py/tuple": ["thorax1", 1.0] }
}
or
{
"py/state": {"name": "thorax1", "weight": 1.0}
}
Args:
state: The state to decode, i.e. state = dict["py/state"]
Returns:
The `Node` object reconstructed from the state.
"""

if "py/tuple" in state:
return Node(*state["py/tuple"])

return Node(**state)

def decode_object_dict(object_dict) -> Node:
"""Decode dict containing `py/object` key in the serialized nx_graph.
Args:
object_dict: The dict to decode, i.e.
object_dict = {"py/object": ..., "py/state":...}
Raises:
ValueError: If object_dict does not have 'py/object' and 'py/state' keys.
ValueError: If object_dict['py/object'] is not 'sleap.skeleton.Node'.
Returns:
The decoded `Node` object.
"""

if object_dict["py/object"] != "sleap.skeleton.Node":
raise ValueError("Only 'sleap.skeleton.Node' objects are supported.")

node: Node = decode_state(object_dict["py/state"])
return node

def decode_node(
encoded_node: dict, decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[Node, List[Union[Node, EdgeType]]]:
"""Decode an item believed to be an encoded `Node` object.
Args:
encoded_node: The encoded node to decode.
decoded_objects: The list of decoded objects so far.
Returns:
The decoded node and the updated list of decoded objects.
"""

if "py/object" in encoded_node:
decoded_node: Node = decode_object_dict(encoded_node)
decoded_objects.append(decoded_node)
elif "py/id" in encoded_node:
decoded_node: Node = decode_id(encoded_node["py/id"], decoded_objects)

return decoded_node, decoded_objects

def decode_nodes(
encoded_nodes: List[dict], decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[List[Dict[str, Node]], List[Union[Node, EdgeType]]]:
"""Decode the 'nodes' key in the serialized nx_graph.
The encoded_nodes is a list of dictionary of two types:
- A dictionary with 'py/object' and 'py/state' keys.
- A dictionary with 'py/id' key.
Args:
encoded_nodes: The list of encoded nodes to decode.
decoded_objects: The list of decoded objects so far.
Returns:
The decoded nodes and the updated list of decoded objects.
"""

decoded_nodes: List[Dict[str, Node]] = []
for e_node_dict in encoded_nodes:
e_node = e_node_dict["id"]
d_node, decoded_objects = decode_node(e_node, decoded_objects)
decoded_nodes.append({"id": d_node})

return decoded_nodes, decoded_objects

def decode_reduce_dict(
reduce_dict: Dict[str, List[dict]], decoded_objects: List[Union[Node, EdgeType]]
) -> EdgeType:
"""Decode the 'reduce' key in the serialized nx_graph.
The reduce_dict is a dictionary in the following format:
{
"py/reduce": [
{ "py/type": "sleap.skeleton.EdgeType" },
{ "py/tuple": [1] }
]
}
Args:
reduce_dict: The dictionary to decode i.e. reduce_dict = {"py/reduce": ...}
decoded_objects: The list of decoded objects so far.
Returns:
The decoded `EdgeType` object.
"""

reduce_list = reduce_dict["py/reduce"]
has_py_type = has_py_tuple = False
for reduce_item in reduce_list:
if (
"py/type" in reduce_item
and reduce_item["py/type"] == "sleap.skeleton.EdgeType"
):
has_py_type = True
elif "py/tuple" in reduce_item:
edge_type: int = reduce_item["py/tuple"][0]
has_py_tuple = True

if not has_py_type or not has_py_tuple:
raise ValueError(
"Only 'sleap.skeleton.EdgeType' objects are supported. "
"The 'py/reduce' list must have dictionaries with 'py/type' and "
"'py/tuple' keys."
f"\n\tHas py/type: {has_py_type}\n\tHas py/tuple: {has_py_tuple}"
)

edge = EdgeType(edge_type)
decoded_objects.append(edge)

return edge, decoded_objects

def decode_edge_type(
encoded_edge_type: dict, decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[EdgeType, List[Union[Node, EdgeType]]]:
"""Decode the 'type' key in the serialized nx_graph.
Args:
encoded_edge_type: a dictionary with either 'py/id' or 'py/reduce' key.
decoded_objects: The list of decoded objects so far.
Returns:
A tuple including the decoded `EdgeType` object and the updated list of
decoded objects.
"""

if "py/reduce" in encoded_edge_type:
edge_type, decoded_objects = decode_reduce_dict(
encoded_edge_type, decoded_objects=decoded_objects
)
else:
# Expect a "py/id" instead of "py/reduce"
edge_type = decode_id(encoded_edge_type["py/id"], decoded_objects)
return edge_type, decoded_objects

def decode_links(
links: List[dict], decoded_objects: List[Union[Node, EdgeType]]
) -> Tuple[
List[Dict[str, Union[int, Node, EdgeType]]], List[Union[Node, EdgeType]]
]:
"""Decode the 'links' key in the serialized nx_graph.
The links are the edges in the graph and will have the following keys:
- source: The source node of the edge.
- target: The destination node of the edge.
- type: The type of the edge (e.g. BODY, SYMMETRY).
and more.
Args:
encoded_links: The list of encoded links to decode.
decoded_objects: The list of decoded objects so far.
"""

for link in links:
for key, value in link.items():
if key == "source":
link[key], decoded_objects = decode_node(value, decoded_objects)
elif key == "target":
link[key], decoded_objects = decode_node(value, decoded_objects)
elif key == "type":
link[key], decoded_objects = decode_edge_type(
value, decoded_objects
)

return links, decoded_objects

dicts = json.loads(json_str)

# Enforce same format across template and non-template skeletons
if "nx_graph" not in dicts:
# Non-template skeletons use the dicts as the "nx_graph"
dicts = {"nx_graph": dicts}

# Decode the graph
nx_graph = dicts["nx_graph"]

decoded_objects = []
for key, value in nx_graph.items():
if key == "nodes":
nx_graph[key], decoded_objects = decode_nodes(
value, decoded_objects=decoded_objects
)
elif key == "links":
nx_graph[key], decoded_objects = decode_links(
value, decoded_objects=decoded_objects
)

# Decode the preview image (if it exists)
preview_image = dicts.get("preview_image", None)
if preview_image is not None:
dicts["preview_image"] = decode_preview_image(
preview_image["py/b64"], return_bytes=True
)

return dicts


class Skeleton:
"""The main object for representing animal skeletons.
Expand Down Expand Up @@ -1071,7 +1341,7 @@ def from_json(
Returns:
An instance of the `Skeleton` object decoded from the JSON.
"""
dicts = jsonpickle.decode(json_str)
dicts: dict = replace_jsonpickle_decode(json_str)
nx_graph = dicts.get("nx_graph", dicts)
graph = json_graph.node_link_graph(nx_graph)

Expand Down

0 comments on commit fc563c1

Please sign in to comment.