From b805725cb7246f2630d6e67bef4d1d9fad4662d0 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Thu, 20 Jul 2023 16:58:03 -0700 Subject: [PATCH] temp hold --- nvflare/client/cache.py | 76 ------------------- nvflare/{ => client}/lightning/__init__.py | 3 +- .../module.py => client/lightning/api.py} | 36 +++------ 3 files changed, 14 insertions(+), 101 deletions(-) delete mode 100644 nvflare/client/cache.py rename nvflare/{ => client}/lightning/__init__.py (90%) rename nvflare/{lightning/module.py => client/lightning/api.py} (84%) diff --git a/nvflare/client/cache.py b/nvflare/client/cache.py deleted file mode 100644 index 3162b5c298..0000000000 --- a/nvflare/client/cache.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional - -from nvflare.app_common.abstract.fl_model import FLModel, ParamsType -from nvflare.app_common.model_exchange.model_exchanger import ModelExchanger - -from .config import ClientConfig -from .utils import copy_fl_model_attributes, get_meta_from_fl_model, numerical_params_diff, set_fl_model_with_meta - -IN_ATTRS = ("optimizer_params", "current_round") -SYS_ATTRS = ("job_id", "site_name", "total_rounds") - -DIFF_MAP = {"numerical_params_diff": numerical_params_diff} - - -class Cache: - """This class is used to remember attributes that need to share for a user code. - - For example, after "global_evaluate" we should remember the "metrics" value. - And set that into the model that we want to submit after "train". - - For each user file: - - we only need 1 model exchanger. - - we only need to pull global model once - - """ - - def __init__(self, model_exchanger: ModelExchanger, config: ClientConfig): - self.model_exchanger = model_exchanger - self.input_model: Optional[FLModel] = None - self.meta = None - self.sys_meta = None - - self.config = config - self.initial_metrics = None # get from evaluate on "global model" - self._get_model() - - def _get_model(self): - self.input_model: FLModel = self.model_exchanger.receive_model() - self.meta = get_meta_from_fl_model(self.input_model, IN_ATTRS) - self.sys_meta = get_meta_from_fl_model(self.input_model, SYS_ATTRS) - - def construct_fl_model(self, params): - fl_model = FLModel(params_type=ParamsType.FULL, params=params) - if self.initial_metrics is not None: - fl_model.metrics = self.initial_metrics - - # model difference - params_diff_func_name = self.config.get_params_diff_func() - if params_diff_func_name is not None: - if params_diff_func_name not in DIFF_MAP: - raise RuntimeError(f"params_diff_func {params_diff_func_name} is not pre-defined.") - params_diff_func = DIFF_MAP[params_diff_func_name] - fl_model.params = params_diff_func(self.input_model.params, fl_model.params) - fl_model.params_type = ParamsType.DIFF - - set_fl_model_with_meta(fl_model, self.meta, IN_ATTRS) - copy_fl_model_attributes(self.input_model, fl_model) - fl_model.meta = self.meta - return fl_model - - def __str__(self): - return f"Cache(model_exchanger: {self.model_exchanger}, initial_metrics: {self.initial_metrics})" diff --git a/nvflare/lightning/__init__.py b/nvflare/client/lightning/__init__.py similarity index 90% rename from nvflare/lightning/__init__.py rename to nvflare/client/lightning/__init__.py index 8d081690d5..72cc90332a 100644 --- a/nvflare/lightning/__init__.py +++ b/nvflare/client/lightning/__init__.py @@ -12,4 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .module import LightningModule as LightningModule +from .api import to_fl as to_fl +from .api import init as init diff --git a/nvflare/lightning/module.py b/nvflare/client/lightning/api.py similarity index 84% rename from nvflare/lightning/module.py rename to nvflare/client/lightning/api.py index a8ab63a9f0..18196d8644 100644 --- a/nvflare/lightning/module.py +++ b/nvflare/client/lightning/api.py @@ -22,6 +22,18 @@ import nvflare.client as flare +def init(): + config_file = "nvf_lightning.json" + config = { + "exchange_path": "./", + "exchange_format": "pytorch", + "params_type": "FULL" + } + with open(config_file, "w") as f: + json.dump(config, f) + flare.init(config=config_file) + + def unflatten(global_weights): "Unflattens the params from NVFlare." result = {} @@ -69,17 +81,6 @@ def on_train_end(self): super().on_train_end() print("\n *****nvflare****** on_train_end ********** \n") self._fl_train_end() - - def _fl_init(self): - config_file = "nvf_lightning.json" - config = { - "exchange_path": "./", - "exchange_format": "pytorch", - "params_type": "FULL" - } - with open(config_file, "w") as f: - json.dump(config, f) - flare.init(config=config_file) def _fl_train_start(self): print("ZZZZZ calling _fl_train_start ZZZZZ") @@ -99,19 +100,6 @@ def _fl_train_end(self): flare.submit_model(weights) print("ZZZZZ ending _fl_train_end ZZZZZ") - @staticmethod - def fit_start(_func): - """ Decorator factory. """ - - def decorator(func): - @functools.wraps(func) - def wrapper(self, *args, **kwargs): - self._fl_init() - return func(self, *args, **kwargs) - - return wrapper - return decorator(_func) - @staticmethod def train_start(_func): """ Decorator factory. """