Skip to content

Commit

Permalink
Sort encoded Skeleton dictionary for backwards compatibility (#1975)
Browse files Browse the repository at this point in the history
* Add failing test to check that encoded Skeleton is sorted

* Sort Skeleton dictionary before encoding

* Remove unused import
  • Loading branch information
roomrys authored Oct 2, 2024
1 parent 10aae76 commit 1339f0d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
24 changes: 21 additions & 3 deletions sleap/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import attr
import cattr
import h5py
import jsonpickle
import networkx as nx
import numpy as np
from networkx.readwrite import json_graph
Expand Down Expand Up @@ -421,11 +420,30 @@ def encode(cls, data: Dict[str, Any]) -> str:
Returns:
json_str: The JSON string representation of the data.
"""

# This is required for backwards compatibility with SLEAP <=1.3.4
sorted_data = cls._recursively_sort_dict(data)

encoder = cls()
encoded_data = encoder._encode(data)
encoded_data = encoder._encode(sorted_data)
json_str = json.dumps(encoded_data)
return json_str

@staticmethod
def _recursively_sort_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively sorts the dictionary by keys."""
sorted_dict = dict(sorted(dictionary.items()))
for key, value in sorted_dict.items():
if isinstance(value, dict):
sorted_dict[key] = SkeletonEncoder._recursively_sort_dict(value)
elif isinstance(value, list):
for i, item in enumerate(value):
if isinstance(item, dict):
sorted_dict[key][i] = SkeletonEncoder._recursively_sort_dict(
item
)
return sorted_dict

def _encode(self, obj: Any) -> Any:
"""Recursively encodes the input object.
Expand Down Expand Up @@ -1477,7 +1495,7 @@ def to_json(self, node_to_idx: Optional[Dict[Node, int]] = None) -> str:
Returns:
A string containing the JSON representation of the skeleton.
"""
jsonpickle.set_encoder_options("simplejson", sort_keys=True, indent=4)

if node_to_idx is not None:
# Map Nodes to int
indexed_node_graph = nx.relabel_nodes(G=self._graph, mapping=node_to_idx)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@ def test_decoded_encoded_Skeleton(skeleton_fixture_name, request):
# Encode the graph as a json string to test .encode method
encoded_json_str = SkeletonEncoder.encode(graph)

# Assert that the encoded json has keys in sorted order (backwards compatibility)
encoded_dict = json.loads(encoded_json_str)
sorted_keys = sorted(encoded_dict.keys())
assert list(encoded_dict.keys()) == sorted_keys
for key, value in encoded_dict.items():
if isinstance(value, dict):
assert list(value.keys()) == sorted(value.keys())
elif isinstance(value, list):
for item in value:
if isinstance(item, dict):
assert list(item.keys()) == sorted(item.keys())

# Get the skeleton from the encoded json string
decoded_skeleton = Skeleton.from_json(encoded_json_str)

Expand Down

0 comments on commit 1339f0d

Please sign in to comment.