Skip to content

Commit

Permalink
Improve json serialization to accomodate numpy float32
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh committed Oct 10, 2024
1 parent 57d843a commit 178b518
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions nvflare/app_common/widgets/validation_json_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@

import json
import os.path
from functools import singledispatch

import numpy as np

from nvflare.apis.dxo import DataKind, from_shareable, get_leaf_dxos
from nvflare.apis.event_type import EventType
Expand All @@ -23,6 +26,17 @@
from nvflare.widgets.widget import Widget


@singledispatch
def to_serializable(val):
"""Default json serializable method."""
return str(val)


@to_serializable.register(np.float32)
def ts_float32(val):
return np.float64(val)


class ValidationJsonGenerator(Widget):
def __init__(self, results_dir=AppConstants.CROSS_VAL_DIR, json_file_name="cross_val_results.json"):
"""Catches VALIDATION_RESULT_RECEIVED event and generates a results.json containing accuracy of each
Expand Down Expand Up @@ -58,7 +72,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
if val_results:
try:
dxo = from_shareable(val_results)
dxo.validate()

if dxo.data_kind == DataKind.METRICS:
if data_client not in self._val_results:
Expand All @@ -71,7 +84,6 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
for err in errors:
self.log_error(fl_ctx, f"Bad result from {data_client}: {err}")
for _sub_data_client, _dxo in leaf_dxos.items():
_dxo.validate()
if _sub_data_client not in self._val_results:
self._val_results[_sub_data_client] = {}
self._val_results[_sub_data_client][model_owner] = _dxo.data
Expand All @@ -93,4 +105,4 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):

res_file_path = os.path.join(cross_val_res_dir, self._json_file_name)
with open(res_file_path, "w") as f:
json.dump(self._val_results, f)
json.dump(self._val_results, f, default=to_serializable)

0 comments on commit 178b518

Please sign in to comment.