From b0ca8eb06d8dfc04def806f195f3a6e053f73da7 Mon Sep 17 00:00:00 2001
From: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
Date: Mon, 8 Jan 2024 16:13:22 -0800
Subject: [PATCH] Address few misc. bugs (#2252)

* address few misc. bugs
  1) Step-by-step example train_with_mlflow.py should use global_step for logging metrics
  2) fl_model_utils.update_model() losing metrics after update

* update fl_model_utils.py

* address comments

* import re-order

* remove logger for now
---
 .../cifar10/code/fl/train_with_mlflow.py    |  3 ++-
 nvflare/app_common/utils/fl_model_utils.py  | 20 +++++++++++++-------
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py
index c161a7c117..898d5aa507 100644
--- a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py
+++ b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py
@@ -139,7 +139,8 @@ def evaluate(input_weights):
                 running_loss += loss.item()
                 if i % 2000 == 1999:  # print every 2000 mini-batches
                     print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
-                    mlflow.log_metric("loss", running_loss / 2000, i)
+                    global_step = input_model.current_round * local_epochs * batch_size + epoch * batch_size + i
+                    mlflow.log_metric("loss", running_loss / 2000, global_step)
                     running_loss = 0.0
 
         print(f"({client_id}) Finished Training")
diff --git a/nvflare/app_common/utils/fl_model_utils.py b/nvflare/app_common/utils/fl_model_utils.py
index 4ce9d01d6c..204928fd13 100644
--- a/nvflare/app_common/utils/fl_model_utils.py
+++ b/nvflare/app_common/utils/fl_model_utils.py
@@ -55,8 +55,9 @@ def to_shareable(fl_model: FLModel) -> Shareable:
             raise ValueError("FLModel without params and metrics is NOT supported.")
         elif fl_model.params is not None:
             if fl_model.params_type is None:
-                raise ValueError(f"Invalid ParamsType: ({fl_model.params_type}).")
-            data_kind = params_type_to_data_kind.get(fl_model.params_type)
+                fl_model.params_type = ParamsType.FULL
+
+            data_kind = params_type_to_data_kind.get(fl_model.params_type.value)
             if data_kind is None:
                 raise ValueError(f"Invalid ParamsType: ({fl_model.params_type}).")
 
@@ -103,11 +104,15 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) ->
                 metrics = dxo.data
             else:
                 params_type = data_kind_to_params_type.get(dxo.data_kind)
+                params = dxo.data
                 if params_type is None:
-                    raise ValueError(f"Invalid shareable with dxo that has data kind: {dxo.data_kind}")
+                    if params is None:
+                        raise ValueError(f"Invalid shareable with dxo that has data kind: {dxo.data_kind}")
+                    else:
+                        params_type = ParamsType.FULL
+
                 params_type = ParamsType(params_type)
-                params = dxo.data
 
             if MetaKey.INITIAL_METRICS in meta:
                 metrics = meta[MetaKey.INITIAL_METRICS]
         except:
@@ -197,14 +202,15 @@ def get_configs(model: FLModel) -> Optional[dict]:
     @staticmethod
     def update_model(model: FLModel, model_update: FLModel, replace_meta: bool = True) -> FLModel:
         if model.params_type != ParamsType.FULL:
-            raise RuntimeError(
-                f"params_type {model_update.params_type} of `model` not supported! Expected `ParamsType.FULL`."
-            )
+            raise RuntimeError(f"params_type {model.params_type} of `model` not supported! Expected `ParamsType.FULL`.")
 
         if replace_meta:
             model.meta = model_update.meta
         else:
             model.meta.update(model_update.meta)
+
+        model.metrics = model_update.metrics
+
         if model_update.params_type == ParamsType.FULL:
             model.params = model_update.params
         elif model_update.params_type == ParamsType.DIFF:
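
Reviewer note (illustrative, not part of the patch): the first hunk replaces the per-epoch batch index `i` with a global step so the MLflow "loss" curve keeps advancing across FL rounds instead of restarting at step 0 every round. Below is a minimal sketch of that bookkeeping, assuming the same names as the example script (current_round, local_epochs, batch_size, epoch, i); the helper name compute_global_step and the sample values are made up for illustration.

    def compute_global_step(current_round: int, epoch: int, i: int, local_epochs: int, batch_size: int) -> int:
        # Same arithmetic the patch adds before mlflow.log_metric():
        # offset of completed rounds + offset of completed epochs + batch index within this epoch.
        return current_round * local_epochs * batch_size + epoch * batch_size + i

    # e.g. round 1, epoch 0, batch 1999 with local_epochs=2, batch_size=4:
    # 1 * 2 * 4 + 0 * 4 + 1999 = 2007
    assert compute_global_step(1, 0, 1999, local_epochs=2, batch_size=4) == 2007

The update_model() change in the second file additionally copies model_update.metrics onto the base model, so metrics carried by the incoming update are no longer lost after the update is applied (the second bug listed in the commit message).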