From 81caf29556e3ecc7377a1d127a148dad70d62327 Mon Sep 17 00:00:00 2001
From: Chester Chen <512707+chesterxgchen@users.noreply.github.com>
Date: Mon, 26 Feb 2024 09:28:23 -0800
Subject: [PATCH] In process Client API Executor Part 1 (#2248)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* 1) fix issue with logging
  2) fix example code formatting

add queue.task_done()

1) add message bus
2) hide task func wrapper class
3) rename executor package
4) clean up some code

update meta info

remove unused code

optimize imports

fix message_bus import order change

rename the executor from ClientAPIMemExecutor to InProcessClientAPIExecutor

1) remove thread_pool
2) further loosely couple executor and client_api implementation

formatting

add unit tests

avoid duplicated constant TASK_NAME definition

split PR into two parts (besides message bus); this is part 1: only remove the example and job template changes

1. Replace MemPipe (Queues) with callback via EventManager
2. Simplified overall logic
3. Note: the param conversion doesn't quite work (need to fix later)
4. Removed some tests that are now invalid. Will need to add more unit tests later

fix bug where task_name is None

add a few unit tests

code format

update to conform with new databus changes

* rebase

* conform with recent changes

* clean up, support main func

* fix format

* update unit tests

* databus updates, enhance module parsing

* address comments

* add docstrings, address comments

---------

Co-authored-by: Sean Yang
Co-authored-by: Yuan-Ting Hsieh (謝沅廷)
---
 .../cifar10/code/fl/train_with_mlflow.py      |   5 +-
 .../sag_pt_in_proc/config_fed_client.conf     |  69 +++++
 .../sag_pt_in_proc/config_fed_server.conf     | 127 ++++++++
 job_templates/sag_pt_in_proc/info.conf        |   5 +
 job_templates/sag_pt_in_proc/info.md          |  11 +
 job_templates/sag_pt_in_proc/meta.conf        |  10 +
 nvflare/apis/fl_constant.py                   |   1 +
 .../executors/client_api_launcher_executor.py |   5 +-
 .../executors/exec_task_fn_wrapper.py         |  79 +++++
 .../in_process_client_api_executor.py         | 212 +++++++++++++
 .../launchers/subprocess_launcher.py          |   2 +
 .../widgets/external_configurator.py          |   9 +-
 nvflare/app_opt/lightning/api.py              |  17 +-
 .../pt/in_process_client_api_executor.py      |  64 ++++
 nvflare/client/__init__.py                    |   1 -
 nvflare/client/api.py                         | 293 ++++--------------
 nvflare/client/api_spec.py                    | 279 +++++++++++++++++
 nvflare/client/config.py                      |   5 +-
 nvflare/client/constants.py                   |   4 +-
 nvflare/client/decorator.py                   |  32 +-
 nvflare/client/ex_process/__init__.py         |  13 +
 nvflare/client/ex_process/api.py              | 204 ++++++++++++
 nvflare/client/in_process/__init__.py         |  13 +
 nvflare/client/in_process/api.py              | 214 +++++++++++++
 nvflare/fuel/message/__init__.py              |  13 +
 nvflare/fuel/utils/function_utils.py          |  56 ++++
 .../executors/exec_task_fn_wrapper_test.py    |  89 ++++++
 tests/unit_test/client/__init__.py            |  13 +
 tests/unit_test/client/in_process/__init__.py |  13 +
 tests/unit_test/client/in_process/api_test.py |  91 ++++++
 .../fuel/utils/function_utils_test.py         |  39 +++
 31 files changed, 1709 insertions(+), 279 deletions(-)
 create mode 100644 job_templates/sag_pt_in_proc/config_fed_client.conf
 create mode 100644 job_templates/sag_pt_in_proc/config_fed_server.conf
 create mode 100644 job_templates/sag_pt_in_proc/info.conf
 create mode 100644 job_templates/sag_pt_in_proc/info.md
 create mode 100644 job_templates/sag_pt_in_proc/meta.conf
 create mode 100644 nvflare/app_common/executors/exec_task_fn_wrapper.py
 create mode 100644 nvflare/app_common/executors/in_process_client_api_executor.py
 create mode 100644
nvflare/app_opt/pt/in_process_client_api_executor.py create mode 100644 nvflare/client/api_spec.py create mode 100644 nvflare/client/ex_process/__init__.py create mode 100644 nvflare/client/ex_process/api.py create mode 100644 nvflare/client/in_process/__init__.py create mode 100644 nvflare/client/in_process/api.py create mode 100644 nvflare/fuel/message/__init__.py create mode 100644 nvflare/fuel/utils/function_utils.py create mode 100644 tests/unit_test/app_common/executors/exec_task_fn_wrapper_test.py create mode 100644 tests/unit_test/client/__init__.py create mode 100644 tests/unit_test/client/in_process/__init__.py create mode 100644 tests/unit_test/client/in_process/api_test.py create mode 100644 tests/unit_test/fuel/utils/function_utils_test.py diff --git a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py index 898d5aa507..1d43b88d4d 100644 --- a/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py +++ b/examples/hello-world/step-by-step/cifar10/code/fl/train_with_mlflow.py @@ -95,7 +95,7 @@ def evaluate(input_weights): # (4) receive FLModel from NVFlare input_model = flare.receive() - client_id = flare.system_info().get("site_name", None) + client_id = flare.get_site_name() # Based on different "task" we will do different things # for "train" task (flare.is_train()) we use the received model to do training and/or evaluation @@ -106,7 +106,7 @@ def evaluate(input_weights): # for "submit_model" task (flare.is_submit_model()) we just need to send back the local model # (5) performing train task on received model if flare.is_train(): - print(f"({client_id}) round={input_model.current_round}/{input_model.total_rounds-1}") + print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}") # (5.1) loads model from NVFlare net.load_state_dict(input_model.params) @@ -167,7 +167,6 @@ def evaluate(input_weights): # (5.5) send model back to NVFlare flare.send(output_model) - # (6) performing evaluate task on received model elif flare.is_evaluate(): accuracy = evaluate(input_model.params) diff --git a/job_templates/sag_pt_in_proc/config_fed_client.conf b/job_templates/sag_pt_in_proc/config_fed_client.conf new file mode 100644 index 0000000000..bda33c0826 --- /dev/null +++ b/job_templates/sag_pt_in_proc/config_fed_client.conf @@ -0,0 +1,69 @@ +{ + # version of the configuration + format_version = 2 + + fn_path = "train.main" + fn_args = { + batch_size = 6 + dataset_path = "/tmp/nvflare/data/cifar10" + num_workers = 2 + } + + # Client Computing Executors. 
+  executors = [
+    {
+      # tasks the executors are defined to handle
+      tasks = ["train"]
+
+      # This particular executor
+      executor {
+
+        path = "nvflare.app_opt.pt.in_process_client_api_executor.PTInProcessClientAPIExecutor"
+        args {
+          # if the task_fn_path is main, task_fn_args are passed as sys.argv
+          # if the task_fn_path is a function, task_fn_args are passed as the function args
+          # (Note: task_fn_path must be of the form {module}.{func_name})
+          task_fn_path = "{fn_path}"
+          task_fn_args = "{fn_args}"
+
+          # if the transfer_type is FULL, the parameters are sent directly
+          # if the transfer_type is DIFF, the difference vs. the received
+          # parameters is calculated and sent
+          params_transfer_type = "DIFF"
+
+          # if train_with_evaluation is true, the executor expects the custom code
+          # to send back both the trained parameters and the evaluation metric;
+          # otherwise only trained parameters are expected
+          train_with_evaluation = true
+
+          # time interval in seconds to wait before checking whether the local task has submitted the result.
+          # if the local task takes a long time, you can increase this interval to a larger number.
+          # uncomment to overwrite the default; the default is 0.5 seconds
+          result_pull_interval = 0.5
+
+          # time interval in seconds to wait before checking whether the training code has logged metrics
+          # (such as TensorBoard, MLflow, or Weights & Biases logs). The result will be streamed
+          # to the server side and then to the corresponding tracking system.
+          # if the log is not needed, you can set this to a larger number.
+          # uncomment to overwrite the default; the default is None, which disables the log streaming feature.
+          log_pull_interval = 0.1
+
+        }
+      }
+    }
+  ],
+
+  # this defines an array of task data filters. If provided, it will control the data from server controller to client executor
+  task_data_filters = []
+
+  # this defines an array of task result filters. If provided, it will control the result from client executor to server controller
+  task_result_filters = []
+
+  components = [
+    {
+      "id": "event_to_fed",
+      "name": "ConvertToFedEvent",
+      "args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."}
+    }
+  ]
+}
diff --git a/job_templates/sag_pt_in_proc/config_fed_server.conf b/job_templates/sag_pt_in_proc/config_fed_server.conf
new file mode 100644
index 0000000000..ab5691c4b7
--- /dev/null
+++ b/job_templates/sag_pt_in_proc/config_fed_server.conf
@@ -0,0 +1,127 @@
+{
+  # version of the configuration
+  format_version = 2
+
+  # task data filter: if filters are provided, they will filter the data flowing out of the server to the clients.
+  task_data_filters =[]
+
+  # task result filter: if filters are provided, they will filter the results flowing out of the clients to the server.
+  task_result_filters = []
+
+  # This assumes that there will be a "net.py" file with class name "Net".
+  # If your model code is not in "net.py" and the class name is not "Net", please modify here
+  model_class_path = "net.Net"
+
+  # workflows: Array of workflows that control the Federated Learning workflow lifecycle.
+  # One can specify multiple workflows. NVFLARE will run them in the order specified.
+  workflows = [
+    {
+      # 1st workflow
+      id = "scatter_and_gather"
+
+      # name = ScatterAndGather, path is the class path of the ScatterAndGather controller.
+      path = "nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather"
+      args {
+        # arguments of the ScatterAndGather class.
+        # min number of clients required for the ScatterAndGather controller to move to the next round
+        # during the workflow cycle. The controller will wait until min_clients results are returned from clients
+        # before moving to the next step.
+        min_clients = 2
+
+        # number of global rounds of training.
+        num_rounds = 5
+
+        # starting round is 0-based
+        start_round = 0
+
+        # after receiving the min number of clients' results,
+        # how much longer should we wait before moving to the next step
+        wait_time_after_min_received = 0
+
+        # For ScatterAndGather, the server will aggregate the weights based on the clients' results.
+        # the aggregator component id is named here. One can use this ID to find the corresponding
+        # aggregator component listed below
+        aggregator_id = "aggregator"
+
+        # The ScatterAndGather controller uses a persistor to load and save the model.
+        # The persistor component can be identified by the component ID specified here.
+        persistor_id = "persistor"
+
+        # A Shareable is a communication message, i.e. shared between clients and server.
+        # The shareable generator is a component responsible for converting the model to/from this communication message: Shareable.
+        # The component can be identified via "shareable_generator_id"
+        shareable_generator_id = "shareable_generator"
+
+        # train task name: the client side needs to have an executor that handles this task
+        train_task_name = "train"
+
+        # train timeout in seconds. If zero, there is no timeout.
+        train_timeout = 0
+      }
+    }
+  ]
+
+  # List of components used in the server side workflow.
+  components = [
+    {
+      # This is the persistence component used in the above workflow.
+      # PTFileModelPersistor is a PyTorch persistor which saves/reads the model to/from a file.
+
+      id = "persistor"
+      path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"
+
+      # the persistor class takes the model class as an argument.
+      # This implies that the model is initialized on the server side.
+      # The initialized model will be broadcast to all the clients to start the training.
+      args.model.path = "{model_class_path}"
+    },
+    {
+      # This is the generator that converts the model to the shareable communication message structure used in the workflow
+      id = "shareable_generator"
+      path = "nvflare.app_common.shareablegenerators.full_model_shareable_generator.FullModelShareableGenerator"
+      args = {}
+    },
+    {
+      # This is the aggregator that performs the weighted average aggregation.
+      # the aggregation is "in-time": it doesn't wait for all client results, but aggregates as soon as it receives the data.
+      id = "aggregator"
+      path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"
+      args.expected_data_kind = "WEIGHT_DIFF"
+    },
+    {
+      # This component is not directly used in the workflow.
+      # it selects the best model based on the incoming global validation metrics.
+ id = "model_selector" + path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector" + # need to make sure this "key_metric" match what server side received + args.key_metric = "accuracy" + }, + { + id = "receiver" + path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver" + args.events = ["fed.analytix_log_stats"] + }, + + { + id = "mlflow_receiver" + path = "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver" + args { + # tracking_uri = "http://0.0.0.0:5000" + tracking_uri = "" + kwargs { + experiment_name = "nvflare-sag-pt-experiment" + run_name = "nvflare-sag-pt-with-mlflow" + experiment_tags { + "mlflow.note.content": "## **NVFlare SAG PyTorch experiment with MLflow**" + } + run_tags { + "mlflow.note.content" = "## Federated Experiment tracking with MLflow \n### Example of using **[NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html)** to train an image classifier using federated averaging ([FedAvg]([FedAvg](https://arxiv.org/abs/1602.05629))) and [PyTorch](https://pytorch.org/) as the deep learning training framework. This example also highlights the NVFlare streaming capability from the clients to the server.\n\n> **_NOTE:_** \n This example uses the *[CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html)* dataset and will load its data within the trainer code.\n" + } + } + artifact_location = "artifacts" + events = ["fed.analytix_log_stats"] + } + } + ] + +} diff --git a/job_templates/sag_pt_in_proc/info.conf b/job_templates/sag_pt_in_proc/info.conf new file mode 100644 index 0000000000..56be30170f --- /dev/null +++ b/job_templates/sag_pt_in_proc/info.conf @@ -0,0 +1,5 @@ +{ + description = "scatter & gather workflow using pytorch with in_process executor" + client_category = "client_api" + controller_type = "server" +} \ No newline at end of file diff --git a/job_templates/sag_pt_in_proc/info.md b/job_templates/sag_pt_in_proc/info.md new file mode 100644 index 0000000000..bda88cacdf --- /dev/null +++ b/job_templates/sag_pt_in_proc/info.md @@ -0,0 +1,11 @@ +# Job Template Information Card + +## sag_pt_in_proc + name = "sag_pt_in_proc" + description = "Scatter and Gather Workflow using pytorch with in_process executor" + class_name = "ScatterAndGather" + controller_type = "server" + executor_type = "in_process_client_api_executor" + contributor = "NVIDIA" + init_publish_date = "2024-02-8" + last_updated_date = "2024-02-8" # yyyy-mm-dd diff --git a/job_templates/sag_pt_in_proc/meta.conf b/job_templates/sag_pt_in_proc/meta.conf new file mode 100644 index 0000000000..d543facc8f --- /dev/null +++ b/job_templates/sag_pt_in_proc/meta.conf @@ -0,0 +1,10 @@ +{ + name = "sag_pt_in_proc" + resource_spec = {} + deploy_map { + # change deploy map as needed. 
+ app = ["@ALL"] + } + min_clients = 2 + mandatory_clients = [] +} diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 4f1f32b0ce..af15ac6504 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -414,6 +414,7 @@ class FLMetaKey: FILTER_HISTORY = "filter_history" CONFIGS = "configs" VALIDATE_TYPE = "validate_type" + START_ROUND = "start_round" CURRENT_ROUND = "current_round" TOTAL_ROUNDS = "total_rounds" JOB_ID = "job_id" diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index 776c5a0206..f15dad07b8 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -15,6 +15,7 @@ import os from typing import Optional +from nvflare.apis.fl_constant import FLMetaKey from nvflare.apis.fl_context import FLContext from nvflare.app_common.executors.launcher_executor import LauncherExecutor from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file @@ -123,8 +124,8 @@ def prepare_config_for_launch(self, fl_ctx: FLContext): config_data = { ConfigKey.TASK_EXCHANGE: task_exchange_attributes, - ConfigKey.SITE_NAME: fl_ctx.get_identity_name(), - ConfigKey.JOB_ID: fl_ctx.get_job_id(), + FLMetaKey.SITE_NAME: fl_ctx.get_identity_name(), + FLMetaKey.JOB_ID: fl_ctx.get_job_id(), } config_file_path = self._get_external_config_file_path(fl_ctx) diff --git a/nvflare/app_common/executors/exec_task_fn_wrapper.py b/nvflare/app_common/executors/exec_task_fn_wrapper.py new file mode 100644 index 0000000000..b85c862b3b --- /dev/null +++ b/nvflare/app_common/executors/exec_task_fn_wrapper.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import sys +import traceback +from typing import Dict + +from nvflare.fuel.utils.function_utils import find_task_fn, require_arguments + + +class ExecTaskFuncWrapper: + def __init__(self, task_fn_path: str, task_fn_args: Dict = None): + """Wrapper for function given function path and args + + Args: + task_fn_path (str): function path (ex: train.main, custom/train.main, custom.train.main). + task_fn_args (Dict, optional): function arguments to pass in. 
+ """ + self.task_fn_path = task_fn_path + self.task_fn_args = task_fn_args + self.client_api = None + self.logger = logging.getLogger(self.__class__.__name__) + + self.task_fn = find_task_fn(task_fn_path) + require_args, args_size, args_default_size = require_arguments(self.task_fn) + self.check_fn_inputs(task_fn_path, require_args, args_size, args_default_size) + self.task_fn_require_args = require_args + + def run(self): + """Call the task_fn with any required arguments.""" + msg = f"\n start task run() with {self.task_fn_path}" + msg = msg if not self.task_fn_require_args else msg + f", {self.task_fn_args}" + self.logger.info(msg) + try: + if self.task_fn.__name__ == "main": + args_list = [] + for k, v in self.task_fn_args.items(): + args_list.extend(["--" + str(k), str(v)]) + + curr_argv = sys.argv + sys.argv = [self.task_fn_path.rsplit(".", 1)[0].replace(".", "/") + ".py"] + args_list + self.task_fn() + sys.argv = curr_argv + elif self.task_fn_require_args: + self.task_fn(**self.task_fn_args) + else: + self.task_fn() + except Exception as e: + msg = traceback.format_exc() + self.logger.error(msg) + if self.client_api: + self.client_api.exec_queue.ask_abort(msg) + raise e + + def check_fn_inputs(self, task_fn_path, require_args: bool, required_args_size: int, args_default_size: int): + """Check if the provided task_fn_args are compatible with the task_fn.""" + if require_args: + if not self.task_fn_args: + raise ValueError(f"function '{task_fn_path}' requires arguments, but none provided") + elif len(self.task_fn_args) < required_args_size - args_default_size: + raise ValueError( + f"function '{task_fn_path}' requires {required_args_size} " + f"arguments, but {len(self.task_fn_args)} provided" + ) + else: + if self.task_fn_args and self.task_fn.__name__ != "main": + msg = f"function '{task_fn_path}' does not require arguments, {self.task_fn_args} will be ignored" + self.logger.warning(msg) diff --git a/nvflare/app_common/executors/in_process_client_api_executor.py b/nvflare/app_common/executors/in_process_client_api_executor.py new file mode 100644 index 0000000000..2740f605ad --- /dev/null +++ b/nvflare/app_common/executors/in_process_client_api_executor.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import threading
+import time
+from typing import Dict, Optional
+
+from nvflare.apis.event_type import EventType
+from nvflare.apis.executor import Executor
+from nvflare.apis.fl_constant import FLMetaKey, ReturnCode
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import Shareable, make_reply
+from nvflare.apis.signal import Signal
+from nvflare.apis.utils.analytix_utils import create_analytic_dxo
+from nvflare.app_common.abstract.params_converter import ParamsConverter
+from nvflare.app_common.executors.exec_task_fn_wrapper import ExecTaskFuncWrapper
+from nvflare.app_common.tracking.tracker_types import ANALYTIC_EVENT_TYPE
+from nvflare.app_common.widgets.streaming import send_analytic_dxo
+from nvflare.client.api_spec import CLIENT_API_KEY
+from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType
+from nvflare.client.in_process.api import (
+    TOPIC_ABORT,
+    TOPIC_GLOBAL_RESULT,
+    TOPIC_LOCAL_RESULT,
+    TOPIC_LOG_DATA,
+    TOPIC_STOP,
+    InProcessClientAPI,
+)
+from nvflare.fuel.data_event.data_bus import DataBus
+from nvflare.fuel.data_event.event_manager import EventManager
+from nvflare.fuel.utils.validation_utils import check_object_type
+from nvflare.security.logging import secure_format_traceback
+
+
+class InProcessClientAPIExecutor(Executor):
+    def __init__(
+        self,
+        task_fn_path: str,
+        task_fn_args: Dict = None,
+        task_wait_time: Optional[float] = None,
+        result_pull_interval: float = 0.5,
+        log_pull_interval: Optional[float] = None,
+        params_exchange_format: str = ExchangeFormat.NUMPY,
+        params_transfer_type: TransferType = TransferType.FULL,
+        from_nvflare_converter_id: Optional[str] = None,
+        to_nvflare_converter_id: Optional[str] = None,
+        train_with_evaluation: bool = True,
+        train_task_name: str = "train",
+        evaluate_task_name: str = "evaluate",
+        submit_model_task_name: str = "submit_model",
+    ):
+        super(InProcessClientAPIExecutor, self).__init__()
+        self._result_pull_interval = result_pull_interval
+        self._log_pull_interval = log_pull_interval
+        self._params_exchange_format = params_exchange_format
+        self._params_transfer_type = params_transfer_type
+        self._task_fn_path = task_fn_path
+        self._task_fn_args = task_fn_args
+        self._task_wait_time = task_wait_time
+
+        # flags to indicate whether the launcher side will send back trained model and/or metrics
+        self._train_with_evaluation = train_with_evaluation
+        self._train_task_name = train_task_name
+        self._evaluate_task_name = evaluate_task_name
+        self._submit_model_task_name = submit_model_task_name
+
+        self._from_nvflare_converter_id = from_nvflare_converter_id
+        self._from_nvflare_converter: Optional[ParamsConverter] = None
+        self._to_nvflare_converter_id = to_nvflare_converter_id
+        self._to_nvflare_converter: Optional[ParamsConverter] = None
+
+        # note: ExecTaskFuncWrapper takes only the function path and args
+        self._task_fn_wrapper = ExecTaskFuncWrapper(task_fn_path=self._task_fn_path, task_fn_args=self._task_fn_args)
+        self._engine = None
+        self._task_fn_thread = None
+        self._log_thread = None
+        self._data_bus = DataBus()
+        self._event_manager = EventManager(self._data_bus)
+        self._data_bus.subscribe([TOPIC_LOCAL_RESULT], self.local_result_callback)
+        self._data_bus.subscribe([TOPIC_LOG_DATA], self.log_result_callback)
+        self.local_result = None
+        self._fl_ctx = None
+
+    def handle_event(self, event_type: str, fl_ctx: FLContext):
+        if event_type == EventType.START_RUN:
+            super().handle_event(event_type, fl_ctx)
+            self._engine = fl_ctx.get_engine()
+            self._fl_ctx = fl_ctx
+            self._init_converter(fl_ctx)
+
+            self._task_fn_thread = threading.Thread(target=self._task_fn_wrapper.run)
+            self._task_fn_thread.start()
+
+        elif event_type == EventType.END_RUN:
+            self._event_manager.fire_event(TOPIC_STOP, "END_RUN received")
+            if self._task_fn_thread:
+                self._task_fn_thread.join()
+
+    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
+        self.log_info(fl_ctx, f"execute for task ({task_name})")
+        try:
+            fl_ctx.set_prop("abort_signal", abort_signal)
+
+            meta = self._prepare_task_meta(fl_ctx, task_name)
+            client_api = InProcessClientAPI(task_metadata=meta, result_check_interval=0.5)
+            client_api.init()
+            self._data_bus.put_data(CLIENT_API_KEY, client_api)
+
+            shareable.set_header(FLMetaKey.JOB_ID, fl_ctx.get_job_id())
+            shareable.set_header(FLMetaKey.SITE_NAME, fl_ctx.get_identity_name())
+            if self._from_nvflare_converter is not None:
+                shareable = self._from_nvflare_converter.process(task_name, shareable, fl_ctx)
+
+            self.log_info(fl_ctx, "send data to peer")
+
+            self.send_data_to_peer(shareable, fl_ctx)
+
+            # wait for result
+            self.log_info(fl_ctx, "Waiting for result from peer")
+            while True:
+                if abort_signal.triggered:
+                    # notify peer that the task is aborted
+                    self._event_manager.fire_event(TOPIC_ABORT, f"'{task_name}' is aborted: abort_signal triggered")
+                    return make_reply(ReturnCode.TASK_ABORTED)
+
+                if self.local_result:
+                    result = self.local_result
+                    self.local_result = None
+                    if self._to_nvflare_converter is not None:
+                        result = self._to_nvflare_converter.process(task_name, result, fl_ctx)
+                    return result
+                else:
+                    self.log_debug(fl_ctx, f"waiting for result, sleep for {self._result_pull_interval} secs")
+                    time.sleep(self._result_pull_interval)
+
+        except Exception:
+            self.log_error(fl_ctx, secure_format_traceback())
+            self._event_manager.fire_event(TOPIC_ABORT, f"'{task_name}' failed: {secure_format_traceback()}")
+            return make_reply(ReturnCode.EXECUTION_EXCEPTION)
+
+    def _prepare_task_meta(self, fl_ctx, task_name):
+        job_id = fl_ctx.get_job_id()
+        site_name = fl_ctx.get_identity_name()
+        meta = {
+            FLMetaKey.SITE_NAME: site_name,
+            FLMetaKey.JOB_ID: job_id,
+            ConfigKey.TASK_NAME: task_name,
+            ConfigKey.TASK_EXCHANGE: {
+                ConfigKey.TRAIN_WITH_EVAL: self._train_with_evaluation,
+                ConfigKey.EXCHANGE_FORMAT: self._params_exchange_format,
+                ConfigKey.TRANSFER_TYPE: self._params_transfer_type,
+                ConfigKey.TRAIN_TASK_NAME: self._train_task_name,
+                ConfigKey.EVAL_TASK_NAME: self._evaluate_task_name,
+                ConfigKey.SUBMIT_MODEL_TASK_NAME: self._submit_model_task_name,
+            },
+        }
+        return meta
+
+    def send_data_to_peer(self, shareable, fl_ctx: FLContext):
+        self.log_info(fl_ctx, "sending payload to peer")
+        self._event_manager.fire_event(TOPIC_GLOBAL_RESULT, shareable)
+
+    def _init_converter(self, fl_ctx: FLContext):
+        engine = fl_ctx.get_engine()
+        from_nvflare_converter: ParamsConverter = engine.get_component(self._from_nvflare_converter_id)
+        if from_nvflare_converter is not None:
+            check_object_type(self._from_nvflare_converter_id, from_nvflare_converter, ParamsConverter)
+            self._from_nvflare_converter = from_nvflare_converter
+
+        to_nvflare_converter: ParamsConverter = engine.get_component(self._to_nvflare_converter_id)
+        if to_nvflare_converter is not None:
+            check_object_type(self._to_nvflare_converter_id, to_nvflare_converter, ParamsConverter)
+            self._to_nvflare_converter = to_nvflare_converter
+
+    def check_output_shareable(self, task_name: str, shareable, fl_ctx: FLContext):
+        """Checks output shareable after execute."""
+        if not isinstance(shareable, Shareable):
+            msg = f"bad task result from peer: expect Shareable but got {type(shareable)}"
+            self.log_error(fl_ctx, msg)
+            raise ValueError(msg)
+
+    def local_result_callback(self, topic, data, databus):
+        if not isinstance(data, Shareable):
+            msg = f"bad task result from peer: expect Shareable but got {type(data)}"
+            self.logger.error(msg)
+            raise ValueError(msg)
+
+        self.local_result = data
+
+    def log_result_callback(self, topic, data, databus):
+        result = data
+        if result and not isinstance(result, dict):
+            raise ValueError(f"invalid result format, expecting Dict, but got {type(result)}")
+
+        if "key" in result:
+            result["tag"] = result.pop("key")
+        dxo = create_analytic_dxo(**result)
+
+        # fire_fed_event=True without a fed_event_converter somehow did not work
+        with self._engine.new_context() as fl_ctx:
+            send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=ANALYTIC_EVENT_TYPE, fire_fed_event=False)
diff --git a/nvflare/app_common/launchers/subprocess_launcher.py b/nvflare/app_common/launchers/subprocess_launcher.py
index 8555747004..6884b6d6a0 100644
--- a/nvflare/app_common/launchers/subprocess_launcher.py
+++ b/nvflare/app_common/launchers/subprocess_launcher.py
@@ -61,6 +61,8 @@ def _start_external_process(self):
         if self._process is None:
             command = self._script
             env = os.environ.copy()
+            env["CLIENT_API_TYPE"] = "EX_PROCESS_API"
+
             command_seq = shlex.split(command)
             self._process = subprocess.Popen(
diff --git a/nvflare/app_common/widgets/external_configurator.py b/nvflare/app_common/widgets/external_configurator.py
index f9d4b53ee6..b9b8c5b14d 100644
--- a/nvflare/app_common/widgets/external_configurator.py
+++ b/nvflare/app_common/widgets/external_configurator.py
@@ -16,8 +16,9 @@
 from typing import List
 
 from nvflare.apis.event_type import EventType
+from nvflare.apis.fl_constant import FLMetaKey
 from nvflare.apis.fl_context import FLContext
-from nvflare.client.config import ConfigKey, write_config_to_file
+from nvflare.client.config import write_config_to_file
 from nvflare.client.constants import CLIENT_API_CONFIG
 from nvflare.fuel.utils.attributes_exportable import ExportMode, export_components
 from nvflare.fuel.utils.validation_utils import check_object_type
@@ -46,8 +47,8 @@ def __init__(
     def handle_event(self, event_type: str, fl_ctx: FLContext):
         if event_type == EventType.ABOUT_TO_START_RUN:
             components_data = self._export_all_components(fl_ctx)
-            components_data[ConfigKey.SITE_NAME] = fl_ctx.get_identity_name()
-            components_data[ConfigKey.JOB_ID] = fl_ctx.get_job_id()
+            components_data[FLMetaKey.SITE_NAME] = fl_ctx.get_identity_name()
+            components_data[FLMetaKey.JOB_ID] = fl_ctx.get_job_id()
 
             config_file_path = self._get_external_config_file_path(fl_ctx)
             write_config_to_file(config_data=components_data, config_file_path=config_file_path)
@@ -64,5 +65,5 @@ def _export_all_components(self, fl_ctx: FLContext) -> dict:
         engine = fl_ctx.get_engine()
         all_components = engine.get_all_components()
         components = {i: all_components.get(i) for i in self._component_ids}
-        reserved_keys = [ConfigKey.SITE_NAME, ConfigKey.JOB_ID]
+        reserved_keys = [FLMetaKey.SITE_NAME, FLMetaKey.JOB_ID]
         return export_components(components=components, reserved_keys=reserved_keys, export_mode=ExportMode.PEER)
diff --git a/nvflare/app_opt/lightning/api.py b/nvflare/app_opt/lightning/api.py
index 9035b2fe7c..567fe1b0a1 100644
--- a/nvflare/app_opt/lightning/api.py
+++ b/nvflare/app_opt/lightning/api.py
@@ -19,17 +19,7 @@
 from torch import Tensor
 
 from
nvflare.app_common.abstract.fl_model import FLModel, MetaKey -from nvflare.client.api import ( - clear, - get_config, - get_model_registry, - init, - is_evaluate, - is_submit_model, - is_train, - receive, - send, -) +from nvflare.client.api import clear, get_config, init, is_evaluate, is_submit_model, is_train, receive, send from nvflare.client.config import ConfigKey from .callbacks import RestoreState @@ -196,7 +186,6 @@ def _receive_and_update_model(self, trainer, pl_module): def _receive_model(self, trainer) -> FLModel: """Receives model from NVFlare.""" - registry = get_model_registry() model = None _is_training = False _is_evaluation = False @@ -208,8 +197,6 @@ def _receive_model(self, trainer) -> FLModel: _is_submit_model = is_submit_model() model = trainer.strategy.broadcast(model, src=0) - task_name = trainer.strategy.broadcast(registry.task_name, src=0) - registry.set_task_name(task_name) self._is_training = trainer.strategy.broadcast(_is_training, src=0) self._is_evaluation = trainer.strategy.broadcast(_is_evaluation, src=0) self._is_submit_model = trainer.strategy.broadcast(_is_submit_model, src=0) @@ -217,7 +204,7 @@ def _receive_model(self, trainer) -> FLModel: def _send_model(self, output_model: FLModel): try: - send(output_model, clear_registry=False) + send(output_model, clear_cache=False) except Exception as e: raise RuntimeError(f"failed to send FL model: {e}") diff --git a/nvflare/app_opt/pt/in_process_client_api_executor.py b/nvflare/app_opt/pt/in_process_client_api_executor.py new file mode 100644 index 0000000000..165fcee77e --- /dev/null +++ b/nvflare/app_opt/pt/in_process_client_api_executor.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
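The clear_registry -> clear_cache rename in the Lightning hunk above keeps the behavior: the cache survives the send so the received task can still be consulted afterwards. A minimal sketch of the renamed call (illustrative only; the FLModel payload is a placeholder):

    from nvflare.app_common.abstract.fl_model import FLModel
    from nvflare.client.api import clear, send

    output_model = FLModel(params={})      # placeholder payload for illustration
    send(output_model, clear_cache=False)  # keep the cache after sending
    clear()                                # clear it explicitly once done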
+from typing import Dict, Optional + +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.executors.in_process_client_api_executor import InProcessClientAPIExecutor +from nvflare.app_opt.pt.decomposers import TensorDecomposer +from nvflare.app_opt.pt.params_converter import NumpyToPTParamsConverter, PTToNumpyParamsConverter +from nvflare.client.config import ExchangeFormat, TransferType +from nvflare.fuel.utils import fobs + + +class PTInProcessClientAPIExecutor(InProcessClientAPIExecutor): + def __init__( + self, + task_fn_path: str, + task_fn_args: Dict = None, + task_wait_time: Optional[float] = None, + result_pull_interval: float = 0.5, + log_pull_interval: Optional[float] = None, + params_transfer_type: TransferType = TransferType.FULL, + from_nvflare_converter_id: Optional[str] = None, + to_nvflare_converter_id: Optional[str] = None, + train_with_evaluation: bool = True, + train_task_name: str = "train", + evaluate_task_name: str = "evaluate", + submit_model_task_name: str = "submit_model", + ): + super(PTInProcessClientAPIExecutor, self).__init__( + task_fn_path=task_fn_path, + task_fn_args=task_fn_args, + task_wait_time=task_wait_time, + result_pull_interval=result_pull_interval, + train_with_evaluation=train_with_evaluation, + train_task_name=train_task_name, + evaluate_task_name=evaluate_task_name, + submit_model_task_name=submit_model_task_name, + from_nvflare_converter_id=from_nvflare_converter_id, + to_nvflare_converter_id=to_nvflare_converter_id, + params_exchange_format=ExchangeFormat.PYTORCH, + params_transfer_type=params_transfer_type, + log_pull_interval=log_pull_interval, + ) + fobs.register(TensorDecomposer) + + if self._from_nvflare_converter is None: + self._from_nvflare_converter = NumpyToPTParamsConverter( + [AppConstants.TASK_TRAIN, AppConstants.TASK_VALIDATION] + ) + if self._to_nvflare_converter is None: + self._to_nvflare_converter = PTToNumpyParamsConverter( + [AppConstants.TASK_TRAIN, AppConstants.TASK_SUBMIT_MODEL] + ) diff --git a/nvflare/client/__init__.py b/nvflare/client/__init__.py index f910c2c8cf..0e0583816a 100644 --- a/nvflare/client/__init__.py +++ b/nvflare/client/__init__.py @@ -19,7 +19,6 @@ from nvflare.app_common.abstract.fl_model import FLModel as FLModel from nvflare.app_common.abstract.fl_model import ParamsType as ParamsType -from .api import clear as clear from .api import get_config as get_config from .api import get_job_id as get_job_id from .api import get_site_name as get_site_name diff --git a/nvflare/client/api.py b/nvflare/client/api.py index f9aaf886bb..2cb43fdbcb 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -11,58 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
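The PT executor above wires in the PyTorch specifics: it registers the tensor decomposer and falls back to the Numpy/PT params converters when none are configured. The equivalent standalone decomposer registration, as a sketch (imports taken from the patch itself):

    from nvflare.app_opt.pt.decomposers import TensorDecomposer
    from nvflare.fuel.utils import fobs

    # allow torch.Tensor objects to cross the FOBS serialization boundary
    fobs.register(TensorDecomposer)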
- -import importlib import os -from typing import Any, Dict, Optional, Tuple +from enum import Enum +from typing import Any, Dict, Optional from nvflare.apis.analytix import AnalyticsDataType -from nvflare.apis.utils.analytix_utils import create_analytic_dxo from nvflare.app_common.abstract.fl_model import FLModel -from nvflare.fuel.utils import fobs -from nvflare.fuel.utils.import_utils import optional_import -from nvflare.fuel.utils.pipe.pipe import Pipe - -from .config import ClientConfig, ConfigKey, ExchangeFormat, from_file -from .constants import CLIENT_API_CONFIG -from .flare_agent import FlareAgentException -from .flare_agent_with_fl_model import FlareAgentWithFLModel -from .model_registry import ModelRegistry - -PROCESS_MODEL_REGISTRY = None - - -def _create_client_config(config: str) -> ClientConfig: - if isinstance(config, str): - client_config = from_file(config_file=config) - else: - raise ValueError("config should be a string.") - return client_config +from nvflare.fuel.data_event.data_bus import DataBus +from .api_spec import CLIENT_API_KEY, CLIENT_API_TYPE_KEY, APISpec +from .ex_process.api import ExProcessClientAPI -def _create_pipe_using_config(client_config: ClientConfig, section: str) -> Tuple[Pipe, str]: - pipe_class_name = client_config.get_pipe_class(section) - module_name, _, class_name = pipe_class_name.rpartition(".") - module = importlib.import_module(module_name) - pipe_class = getattr(module, class_name) - pipe_args = client_config.get_pipe_args(section) - pipe = pipe_class(**pipe_args) - pipe_channel_name = client_config.get_pipe_channel_name(section) - return pipe, pipe_channel_name +class ClientAPIType(Enum): + IN_PROCESS_API = "IN_PROCESS_API" + EX_PROCESS_API = "EX_PROCESS_API" -def _register_tensor_decomposer(): - tensor_decomposer, ok = optional_import(module="nvflare.app_opt.pt.decomposers", name="TensorDecomposer") - if ok: - fobs.register(tensor_decomposer) - else: - raise RuntimeError(f"Can't import TensorDecomposer for format: {ExchangeFormat.PYTORCH}") +client_api: Optional[APISpec] = None +data_bus = DataBus() -def init( - rank: Optional[str] = None, -) -> None: +def init(rank: Optional[str] = None): """Initializes NVFlare Client API environment. Args: @@ -71,60 +41,16 @@ def init( Returns: None - - Example: - - .. code-block:: python - - nvflare.client.init() - - """ - global PROCESS_MODEL_REGISTRY # Declare PROCESS_MODEL_REGISTRY as global - - if rank is None: - rank = os.environ.get("RANK", "0") - - if PROCESS_MODEL_REGISTRY: - print("Warning: called init() more than once. 
The subsequence calls are ignored") - return - - client_config = _create_client_config(config=f"config/{CLIENT_API_CONFIG}") - - flare_agent = None - try: - if rank == "0": - if client_config.get_exchange_format() == ExchangeFormat.PYTORCH: - _register_tensor_decomposer() - - pipe, task_channel_name = _create_pipe_using_config( - client_config=client_config, section=ConfigKey.TASK_EXCHANGE - ) - metric_pipe, metric_channel_name = None, "" - if ConfigKey.METRICS_EXCHANGE in client_config.config: - metric_pipe, metric_channel_name = _create_pipe_using_config( - client_config=client_config, section=ConfigKey.METRICS_EXCHANGE - ) - - flare_agent = FlareAgentWithFLModel( - pipe=pipe, - task_channel_name=task_channel_name, - metric_pipe=metric_pipe, - metric_channel_name=metric_channel_name, - ) - flare_agent.start() - - PROCESS_MODEL_REGISTRY = ModelRegistry(client_config, rank, flare_agent) - except Exception as e: - print(f"flare.init failed: {e}") - raise e - + api_type_name = os.environ.get(CLIENT_API_TYPE_KEY, ClientAPIType.IN_PROCESS_API.value) + api_type = ClientAPIType(api_type_name) + global client_api + if api_type == ClientAPIType.IN_PROCESS_API: + client_api = data_bus.get_data(CLIENT_API_KEY) + else: + client_api = ExProcessClientAPI() -def get_model_registry() -> ModelRegistry: - """Gets the ModelRegistry.""" - if PROCESS_MODEL_REGISTRY is None: - raise RuntimeError("needs to call init method first") - return PROCESS_MODEL_REGISTRY + client_api.init(rank) def receive(timeout: Optional[float] = None) -> Optional[FLModel]: @@ -132,50 +58,20 @@ def receive(timeout: Optional[float] = None) -> Optional[FLModel]: Returns: An FLModel received. - - Example: - - .. code-block:: python - - nvflare.client.receive() - """ - model_registry = get_model_registry() - return model_registry.get_model(timeout) + global client_api + return client_api.receive(timeout) -def send(fl_model: FLModel, clear_registry: bool = True) -> None: +def send(model: FLModel, clear_cache: bool = True) -> None: """Sends the model to NVFlare side. Args: - fl_model (FLModel): Sends a FLModel object. - clear_registry (bool): To clear the registry or not. - - Example: - - .. code-block:: python - - nvflare.client.send(fl_model=FLModel(...)) - - """ - model_registry = get_model_registry() - model_registry.submit_model(model=fl_model) - if clear_registry: - clear() - - -def clear(): - """Clears the model registry. - - Example: - - .. code-block:: python - - nvflare.client.clear() - + model (FLModel): Sends a FLModel object. + clear_cache: clear cache after send """ - model_registry = get_model_registry() - model_registry.clear() + global client_api + return client_api.send(model, clear_cache) def system_info() -> Dict: @@ -190,15 +86,9 @@ def system_info() -> Dict: Returns: A dict of system information. - Example: - - .. code-block:: python - - sys_info = nvflare.client.system_info() - """ - model_registry = get_model_registry() - return model_registry.get_sys_info() + global client_api + return client_api.system_info() def get_config() -> Dict: @@ -206,16 +96,9 @@ def get_config() -> Dict: Returns: A dict of the configuration used in Client API. - - Example: - - .. code-block:: python - - config = nvflare.client.get_config() - """ - model_registry = get_model_registry() - return model_registry.config.config + global client_api + return client_api.get_config() def get_job_id() -> str: @@ -223,16 +106,9 @@ def get_job_id() -> str: Returns: The current job id. - - Example: - - .. 
code-block:: python - - job_id = nvflare.client.get_job_id() - """ - sys_info = system_info() - return sys_info.get(ConfigKey.JOB_ID, "") + global client_api + return client_api.get_job_id() def get_site_name() -> str: @@ -240,16 +116,19 @@ def get_site_name() -> str: Returns: The site name of this client. + """ + global client_api + return client_api.get_site_name() - Example: - - .. code-block:: python - site_name = nvflare.client.get_site_name() +def get_task_name() -> str: + """Gets task name. + Returns: + The task name. """ - sys_info = system_info() - return sys_info.get(ConfigKey.SITE_NAME, "") + global client_api + return client_api.get_task_name() def is_running() -> bool: @@ -257,21 +136,9 @@ def is_running() -> bool: Returns: True, if the system is up and running. False, otherwise. - - Example: - - .. code-block:: python - - while nvflare.client.is_running(): - # receive model, perform task, send model, etc. - ... - """ - try: - receive() - return True - except FlareAgentException: - return False + global client_api + return client_api.is_running() def is_train() -> bool: @@ -279,20 +146,9 @@ def is_train() -> bool: Returns: True, if the current task is a training task. False, otherwise. - - Example: - - .. code-block:: python - - if nvflare.client.is_train(): - # perform train task on received model - ... - """ - model_registry = get_model_registry() - if model_registry.rank != "0": - raise RuntimeError("only rank 0 can call is_train!") - return model_registry.task_name == model_registry.config.get_train_task() + global client_api + return client_api.is_train() def is_evaluate() -> bool: @@ -300,20 +156,9 @@ def is_evaluate() -> bool: Returns: True, if the current task is an evaluate task. False, otherwise. - - Example: - - .. code-block:: python - - if nvflare.client.is_evaluate(): - # perform evaluate task on received model - ... - """ - model_registry = get_model_registry() - if model_registry.rank != "0": - raise RuntimeError("only rank 0 can call is_evaluate!") - return model_registry.task_name == model_registry.config.get_eval_task() + global client_api + return client_api.is_evaluate() def is_submit_model() -> bool: @@ -321,23 +166,12 @@ def is_submit_model() -> bool: Returns: True, if the current task is a submit_model. False, otherwise. - - Example: - - .. code-block:: python - - if nvflare.client.is_submit_model(): - # perform submit_model task to obtain the best local model - ... - """ - model_registry = get_model_registry() - if model_registry.rank != "0": - raise RuntimeError("only rank 0 can call is_submit_model!") - return model_registry.task_name == model_registry.config.get_submit_model_task() + global client_api + return client_api.is_submit_model() -def log(key: str, value: Any, data_type: AnalyticsDataType, **kwargs) -> bool: +def log(key: str, value: Any, data_type: AnalyticsDataType, **kwargs): """Logs a key value pair. We suggest users use the high-level APIs in nvflare/client/tracking.py @@ -350,25 +184,12 @@ def log(key: str, value: Any, data_type: AnalyticsDataType, **kwargs) -> bool: Returns: whether the key value pair is logged successfully - - Example: - - .. 
code-block:: python
-
-            log(
-                key=tag,
-                value=scalar,
-                data_type=AnalyticsDataType.SCALAR,
-                global_step=global_step,
-                writer=LogWriterName.TORCH_TB,
-                **kwargs,
-            )
-
     """
-    model_registry = get_model_registry()
-    if model_registry.rank != "0":
-        raise RuntimeError("only rank 0 can call log!")
+    global client_api
+    return client_api.log(key, value, data_type, **kwargs)
 
-    flare_agent = model_registry.flare_agent
-    dxo = create_analytic_dxo(tag=key, value=value, data_type=data_type, **kwargs)
-    return flare_agent.log(dxo)
+
+
+def clear():
+    """Clears the cache."""
+    global client_api
+    return client_api.clear()
diff --git a/nvflare/client/api_spec.py b/nvflare/client/api_spec.py
new file mode 100644
index 0000000000..fd13c333ae
--- /dev/null
+++ b/nvflare/client/api_spec.py
@@ -0,0 +1,279 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+
+from nvflare.apis.analytix import AnalyticsDataType
+from nvflare.app_common.abstract.fl_model import FLModel
+
+CLIENT_API_KEY = "CLIENT_API"
+CLIENT_API_TYPE_KEY = "CLIENT_API_TYPE"
+
+
+class APISpec(ABC):
+    @abstractmethod
+    def init(self, rank: Optional[str] = None):
+        """Initializes NVFlare Client API environment.
+
+        Args:
+            rank (str): local rank of the process.
+                It is only useful when the training script has multiple worker processes. (for example multi GPU)
+
+        Returns:
+            None
+
+        Example:
+
+            .. code-block:: python
+
+                nvflare.client.init()
+
+        """
+        pass
+
+    @abstractmethod
+    def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]:
+        """Receives model from NVFlare side.
+
+        Returns:
+            An FLModel received.
+
+        Example:
+
+            .. code-block:: python
+
+                nvflare.client.receive()
+
+        """
+        pass
+
+    @abstractmethod
+    def send(self, model: FLModel, clear_cache: bool = True) -> None:
+        """Sends the model to NVFlare side.
+
+        Args:
+            model (FLModel): the FLModel object to send.
+            clear_cache (bool): whether to clear the cache after sending.
+
+        Example:
+
+            .. code-block:: python
+
+                nvflare.client.send(model=FLModel(...))
+
+        """
+        pass
+
+    @abstractmethod
+    def system_info(self) -> Dict:
+        """Gets NVFlare system information.
+
+        System information will be available after a valid FLModel is received.
+        It does not retrieve information actively.
+
+        Note:
+            system information includes job id and site name.
+
+        Returns:
+            A dict of system information.
+
+        Example:
+
+            .. code-block:: python
+
+                sys_info = nvflare.client.system_info()
+
+        """
+        pass
+
+    @abstractmethod
+    def get_config(self) -> Dict:
+        """Gets the ClientConfig dictionary.
+
+        Returns:
+            A dict of the configuration used in Client API.
+
+        Example:
+
+            .. code-block:: python
+
+                config = nvflare.client.get_config()
+
+        """
+        pass
+
+    @abstractmethod
+    def get_job_id(self) -> str:
+        """Gets job id.
+
+        Returns:
+            The current job id.
+
+        Example:
+
+            ..
code-block:: python + + job_id = nvflare.client.get_job_id() + + """ + pass + + @abstractmethod + def get_site_name(self) -> str: + """Gets site name. + + Returns: + The site name of this client. + + Example: + + .. code-block:: python + + site_name = nvflare.client.get_site_name() + + """ + pass + + @abstractmethod + def get_task_name(self) -> str: + """Gets task name. + + Returns: + The task name. + + Example: + + .. code-block:: python + + task_name = nvflare.client.get_task_name() + + """ + pass + + @abstractmethod + def is_running(self) -> bool: + """Returns whether the NVFlare system is up and running. + + Returns: + True, if the system is up and running. False, otherwise. + + Example: + + .. code-block:: python + + while nvflare.client.is_running(): + # receive model, perform task, send model, etc. + ... + + """ + pass + + @abstractmethod + def is_train(self) -> bool: + """Returns whether the current task is a training task. + + Returns: + True, if the current task is a training task. False, otherwise. + + Example: + + .. code-block:: python + + if nvflare.client.is_train(): + # perform train task on received model + ... + + """ + pass + + @abstractmethod + def is_evaluate(self) -> bool: + """Returns whether the current task is an evaluate task. + + Returns: + True, if the current task is an evaluate task. False, otherwise. + + Example: + + .. code-block:: python + + if nvflare.client.is_evaluate(): + # perform evaluate task on received model + ... + + """ + pass + + @abstractmethod + def is_submit_model(self) -> bool: + """Returns whether the current task is a submit_model task. + + Returns: + True, if the current task is a submit_model. False, otherwise. + + Example: + + .. code-block:: python + + if nvflare.client.is_submit_model(): + # perform submit_model task to obtain the best local model + ... + + """ + pass + + @abstractmethod + def log(self, key: str, value: Any, data_type: AnalyticsDataType, **kwargs): + """Logs a key value pair. + + We suggest users use the high-level APIs in nvflare/client/tracking.py + + Args: + key (str): key string. + value (Any): value to log. + data_type (AnalyticsDataType): the data type of the "value". + kwargs: additional arguments to be included. + + Returns: + whether the key value pair is logged successfully + + Example: + + .. code-block:: python + + log( + key=tag, + value=scalar, + data_type=AnalyticsDataType.SCALAR, + global_step=global_step, + writer=LogWriterName.TORCH_TB, + **kwargs, + ) + + """ + pass + + @abstractmethod + def clear(self): + """Clears the cache. + + Example: + + .. 
code-block:: python
+
+                nvflare.client.clear()
+
+        """
+        pass
diff --git a/nvflare/client/config.py b/nvflare/client/config.py
index e85d3ab837..c47326b562 100644
--- a/nvflare/client/config.py
+++ b/nvflare/client/config.py
@@ -41,8 +41,7 @@ class ConfigKey:
     PIPE = "pipe"
     CLASS_NAME = "CLASS_NAME"
     ARG = "ARG"
-    SITE_NAME = "SITE_NAME"
-    JOB_ID = "JOB_ID"
+    TASK_NAME = "TASK_NAME"
     TASK_EXCHANGE = "TASK_EXCHANGE"
     METRICS_EXCHANGE = "METRICS_EXCHANGE"
@@ -121,7 +120,7 @@ def __init__(self, config: Optional[Dict] = None):
             config = {}
         self.config = config
 
-    def get_config(self):
+    def get_config(self) -> Dict:
         return self.config
 
     def get_pipe_channel_name(self, section: str) -> str:
diff --git a/nvflare/client/constants.py b/nvflare/client/constants.py
index 9d74cdf648..4458d07ee6 100644
--- a/nvflare/client/constants.py
+++ b/nvflare/client/constants.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .config import ConfigKey
+from nvflare.apis.fl_constant import FLMetaKey
 
-SYS_ATTRS = (ConfigKey.JOB_ID, ConfigKey.SITE_NAME)
+SYS_ATTRS = (FLMetaKey.JOB_ID, FLMetaKey.SITE_NAME)
 CLIENT_API_CONFIG = "client_api_config.json"
diff --git a/nvflare/client/decorator.py b/nvflare/client/decorator.py
index a70e0bba18..e00e4423f5 100644
--- a/nvflare/client/decorator.py
+++ b/nvflare/client/decorator.py
@@ -17,7 +17,7 @@
 
 from nvflare.app_common.abstract.fl_model import FLModel
 
-from .api import get_model_registry, is_train
+from .api import is_train, receive, send
 
 
 def _replace_func_args(func, kwargs, model: FLModel):
@@ -26,6 +26,13 @@ def _replace_func_args(func, kwargs, model: FLModel):
         kwargs[first_params.name] = model
 
 
+class ObjectHolder:
+    metrics = None  # metrics cached by the evaluate decorator for the train decorator to send
+
+
+object_holder = ObjectHolder()
+
+
 def train(
     _func=None,
     **root_kwargs,
@@ -50,9 +57,7 @@ def my_train(input_model=None, device="cuda:0"):
     def decorator(train_fn):
         @functools.wraps(train_fn)
         def wrapper(*args, **kwargs):
-            model_registry = get_model_registry()
-            input_model = model_registry.get_model()
-
+            input_model = receive()
             # Replace func arguments
             _replace_func_args(train_fn, kwargs, input_model)
             return_value = train_fn(**kwargs)
@@ -62,11 +67,13 @@ def wrapper(*args, **kwargs):
             elif not isinstance(return_value, FLModel):
                 raise RuntimeError("return value needs to be an FLModel.")
 
-            if model_registry.metrics is not None:
-                return_value.metrics = model_registry.metrics
+            global object_holder
+
+            if object_holder.metrics is not None:
+                return_value.metrics = object_holder.metrics
+                object_holder = ObjectHolder()  # reset so stale metrics are not reused
 
-            model_registry.submit_model(model=return_value)
-            model_registry.clear()
+            send(model=return_value)
 
             return return_value
 
@@ -104,20 +111,19 @@ def my_eval(input_model, device="cuda:0"):
     def decorator(eval_fn):
         @functools.wraps(eval_fn)
         def wrapper(*args, **kwargs):
-            model_registry = get_model_registry()
-            input_model = model_registry.get_model()
+            input_model = receive()
 
             _replace_func_args(eval_fn, kwargs, input_model)
             return_value = eval_fn(**kwargs)
 
             if return_value is None:
                 raise RuntimeError("return value is None!")
 
+            global object_holder
             if is_train():
-                model_registry.metrics = return_value
+                object_holder.metrics = return_value
             else:
-                model_registry.submit_model(model=FLModel(metrics=return_value))
-                model_registry.clear()
+                send(model=FLModel(metrics=return_value))
 
             return return_value
 
diff --git a/nvflare/client/ex_process/__init__.py b/nvflare/client/ex_process/__init__.py
new file mode 100644
index 0000000000..d9155f923f
--- /dev/null
+++
b/nvflare/client/ex_process/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py new file mode 100644 index 0000000000..912aaea03a --- /dev/null +++ b/nvflare/client/ex_process/api.py @@ -0,0 +1,204 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import os +from typing import Any, Dict, Optional, Tuple + +from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.fl_constant import FLMetaKey +from nvflare.apis.utils.analytix_utils import create_analytic_dxo +from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.client.api_spec import APISpec +from nvflare.client.config import ClientConfig, ConfigKey, ExchangeFormat, from_file +from nvflare.client.constants import CLIENT_API_CONFIG +from nvflare.client.flare_agent import FlareAgentException +from nvflare.client.flare_agent_with_fl_model import FlareAgentWithFLModel +from nvflare.client.model_registry import ModelRegistry +from nvflare.fuel.utils import fobs +from nvflare.fuel.utils.import_utils import optional_import +from nvflare.fuel.utils.pipe.pipe import Pipe + + +def _create_client_config(config: str) -> ClientConfig: + if isinstance(config, str): + client_config = from_file(config_file=config) + else: + raise ValueError("config should be a string but got: {type(config)}") + return client_config + + +def _create_pipe_using_config(client_config: ClientConfig, section: str) -> Tuple[Pipe, str]: + pipe_class_name = client_config.get_pipe_class(section) + module_name, _, class_name = pipe_class_name.rpartition(".") + module = importlib.import_module(module_name) + pipe_class = getattr(module, class_name) + + pipe_args = client_config.get_pipe_args(section) + pipe = pipe_class(**pipe_args) + pipe_channel_name = client_config.get_pipe_channel_name(section) + return pipe, pipe_channel_name + + +def _register_tensor_decomposer(): + tensor_decomposer, ok = optional_import(module="nvflare.app_opt.pt.decomposers", name="TensorDecomposer") + if ok: + fobs.register(tensor_decomposer) + else: + raise RuntimeError(f"Can't import TensorDecomposer for format: {ExchangeFormat.PYTORCH}") + + +class ExProcessClientAPI(APISpec): + def __init__(self): + self.process_model_registry = None + + def get_model_registry(self) -> ModelRegistry: + """Gets the ModelRegistry.""" + if 
+class ExProcessClientAPI(APISpec): + def __init__(self): + self.process_model_registry = None + + def get_model_registry(self) -> ModelRegistry: + """Gets the ModelRegistry.""" + if self.process_model_registry is None: + raise RuntimeError("Client API is not initialized; call init() first") + return self.process_model_registry + + def init(self, rank: Optional[str] = None): + """Initializes NVFlare Client API environment. + + Args: + rank (str): local rank of the process. + It is only useful when the training script has multiple worker processes. (for example, multi-GPU) + """ + + if rank is None: + rank = os.environ.get("RANK", "0") + + if self.process_model_registry: + print("Warning: called init() more than once. Subsequent calls are ignored") + return + + client_config = _create_client_config(config=f"config/{CLIENT_API_CONFIG}") + + flare_agent = None + try: + if rank == "0": + if client_config.get_exchange_format() == ExchangeFormat.PYTORCH: + _register_tensor_decomposer() + + pipe, task_channel_name = _create_pipe_using_config( + client_config=client_config, section=ConfigKey.TASK_EXCHANGE + ) + metric_pipe, metric_channel_name = None, "" + if ConfigKey.METRICS_EXCHANGE in client_config.config: + metric_pipe, metric_channel_name = _create_pipe_using_config( + client_config=client_config, section=ConfigKey.METRICS_EXCHANGE + ) + + flare_agent = FlareAgentWithFLModel( + pipe=pipe, + task_channel_name=task_channel_name, + metric_pipe=metric_pipe, + metric_channel_name=metric_channel_name, + ) + flare_agent.start() + + self.process_model_registry = ModelRegistry(client_config, rank, flare_agent) + except Exception as e: + print(f"flare.init failed: {e}") + raise e + + def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]: + """Receives model from NVFlare side. + + Returns: + An FLModel received. + """ + model_registry = self.get_model_registry() + return model_registry.get_model(timeout) + + def send(self, model: FLModel, clear_cache: bool = True) -> None: + """Sends the model to Controller side. + Args: + model (FLModel): the FLModel object to send. + clear_cache (bool): whether to clear the cache after sending. + """ + model_registry = self.get_model_registry() + model_registry.submit_model(model=model) + if clear_cache: + self.clear() + + def system_info(self) -> Dict: + """Gets NVFlare system information. + + System information will be available after a valid FLModel is received. + It does not retrieve information actively. + + Returns: + A dict of system information.
+ """ + model_registry = self.get_model_registry() + return model_registry.get_sys_info() + + def get_config(self) -> Dict: + model_registry = self.get_model_registry() + return model_registry.config.config + + def get_job_id(self) -> str: + sys_info = self.system_info() + return sys_info.get(FLMetaKey.JOB_ID, "") + + def get_site_name(self) -> str: + sys_info = self.system_info() + return sys_info.get(FLMetaKey.SITE_NAME, "") + + def get_task_name(self) -> str: + model_registry = self.get_model_registry() + if model_registry.rank != "0": + raise RuntimeError("only rank 0 can call get_task_name!") + return model_registry.get_task().task_name + + def is_running(self) -> bool: + try: + self.receive() + return True + except FlareAgentException: + return False + + def is_train(self) -> bool: + model_registry = self.get_model_registry() + if model_registry.rank != "0": + raise RuntimeError("only rank 0 can call is_train!") + return model_registry.task_name == model_registry.config.get_train_task() + + def is_evaluate(self) -> bool: + model_registry = self.get_model_registry() + if model_registry.rank != "0": + raise RuntimeError("only rank 0 can call is_evaluate!") + return model_registry.task_name == model_registry.config.get_eval_task() + + def is_submit_model(self) -> bool: + model_registry = self.get_model_registry() + if model_registry.rank != "0": + raise RuntimeError("only rank 0 can call is_submit_model!") + return model_registry.task_name == model_registry.config.get_submit_model_task() + + def log(self, key: str, value: Any, data_type: AnalyticsDataType, **kwargs): + model_registry = self.get_model_registry() + if model_registry.rank != "0": + raise RuntimeError("only rank 0 can call log!") + + flare_agent = model_registry.flare_agent + dxo = create_analytic_dxo(tag=key, value=value, data_type=data_type, **kwargs) + flare_agent.log(dxo) + + def clear(self): + """Clears the model registry.""" + model_registry = self.get_model_registry() + model_registry.clear() diff --git a/nvflare/client/in_process/__init__.py b/nvflare/client/in_process/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/nvflare/client/in_process/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/client/in_process/api.py b/nvflare/client/in_process/api.py new file mode 100644 index 0000000000..0a09ec80d8 --- /dev/null +++ b/nvflare/client/in_process/api.py @@ -0,0 +1,214 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import time +from typing import Any, Dict, Optional + +from nvflare.apis.analytix import AnalyticsDataType +from nvflare.apis.fl_constant import FLMetaKey +from nvflare.apis.shareable import Shareable +from nvflare.app_common.abstract.fl_model import FLModel, ParamsType +from nvflare.app_common.utils.fl_model_utils import FLModelUtils +from nvflare.client.api_spec import APISpec +from nvflare.client.config import ClientConfig, ConfigKey +from nvflare.client.constants import SYS_ATTRS +from nvflare.client.utils import DIFF_FUNCS +from nvflare.fuel.data_event.data_bus import DataBus +from nvflare.fuel.data_event.event_manager import EventManager + +TOPIC_LOG_DATA = "LOG_DATA" +TOPIC_STOP = "STOP" +TOPIC_ABORT = "ABORT" +TOPIC_LOCAL_RESULT = "LOCAL_RESULT" +TOPIC_GLOBAL_RESULT = "GLOBAL_RESULT" + + +class InProcessClientAPI(APISpec): + def __init__(self, task_metadata: dict, result_check_interval: float = 2.0): + """Initializes the InProcessClientAPI. + + Args: + task_metadata (dict): task metadata, added to client_config. + result_check_interval (float): how often to check if a result is available. + """ + self.data_bus = DataBus() + self.data_bus.subscribe([TOPIC_GLOBAL_RESULT], self.__receive_callback) + self.data_bus.subscribe([TOPIC_ABORT, TOPIC_STOP], self.__ask_to_abort) + + self.meta = task_metadata + self.result_check_interval = result_check_interval + + self.start_round = None + self.fl_model = None + self.sys_info = {} + self.client_config: Optional[ClientConfig] = None + self.current_round = None + self.total_rounds = None + self.logger = logging.getLogger(self.__class__.__name__) + self.event_manager = EventManager(self.data_bus) + self.abort_reason = "" + self.stop_reason = "" + self.abort = False + self.stop = False +
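The constructor wires the API into the DataBus: it subscribes to the topics the executor publishes on, and fires its own results back through an EventManager. A condensed sketch of that pub/sub round trip, using the same calls as the code above (the payload is simplified; the real executor side sends a Shareable):

.. code-block:: python

    from nvflare.fuel.data_event.data_bus import DataBus
    from nvflare.fuel.data_event.event_manager import EventManager

    bus = DataBus()  # behaves as a bus shared within the process

    def on_global_result(topic, data, databus):
        # same callback signature as InProcessClientAPI.__receive_callback
        print(f"{topic}: received {data}")

    bus.subscribe(["GLOBAL_RESULT"], on_global_result)
    # the executor side fires events through an EventManager
    EventManager(bus).fire_event("GLOBAL_RESULT", {"params": {}})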
+ def init(self, config: Optional[Dict] = None, rank: Optional[str] = None): + """Initializes NVFlare Client API environment. + + Args: + config (Dict): config dictionary. + rank (str): local rank of the process. + It is only useful when the training script has multiple worker processes. (for example, multi-GPU) + """ + config = {} if config is None else config + self.prepare_client_config(config) + + for k, v in self.client_config.config.items(): + if k in SYS_ATTRS: + self.sys_info[k] = v + + def prepare_client_config(self, config): + if isinstance(config, dict): + client_config = ClientConfig(config=config) + else: + raise ValueError("config should be a dictionary.") + + if client_config.config: + client_config.config.update(self.meta) + else: + client_config.config = self.meta + self.client_config = client_config + + def receive(self, timeout: Optional[float] = None) -> Optional[FLModel]: + if self.fl_model: + return self.fl_model + + while True: + if not self.__continue_job(): + break + + if self.fl_model is None: + self.logger.debug(f"no global result message available, sleep {self.result_check_interval} sec") + time.sleep(self.result_check_interval) + else: + break + + return self.fl_model + + def send(self, model: FLModel, clear_cache: bool = True) -> None: + if self.__continue_job(): + self.logger.info("sending local model back to peer") + + if self.client_config.get_transfer_type() == "DIFF": + model = self._prepare_param_diff(model) + + shareable = FLModelUtils.to_shareable(model) + self.event_manager.fire_event(TOPIC_LOCAL_RESULT, shareable) + + if clear_cache: + self.fl_model = None + + def system_info(self) -> Dict: + return self.sys_info + + def get_config(self) -> Dict: + return self.client_config.get_config() + + def get_job_id(self) -> str: + return self.meta[FLMetaKey.JOB_ID] + + def get_site_name(self) -> str: + return self.meta[FLMetaKey.SITE_NAME] + + def get_task_name(self) -> str: + return self.meta[ConfigKey.TASK_NAME] + + def is_running(self) -> bool: + if not self.__continue_job(): + return False + else: + self.receive() + + if self.fl_model: + self.current_round = self.fl_model.current_round + self.total_rounds = self.fl_model.total_rounds + self.start_round = self.fl_model.meta.get(FLMetaKey.START_ROUND, 0) + else: + return False + + return self.current_round < self.start_round + self.total_rounds + + def is_train(self) -> bool: + return self.meta.get(ConfigKey.TASK_NAME) == self.client_config.get_train_task() + + def is_evaluate(self) -> bool: + return self.meta.get(ConfigKey.TASK_NAME) == self.client_config.get_eval_task() + + def is_submit_model(self) -> bool: + return self.meta.get(ConfigKey.TASK_NAME) == self.client_config.get_submit_model_task() + + def log(self, key: str, value: Any, data_type: AnalyticsDataType, **kwargs): + msg = dict(key=key, value=value, data_type=data_type, **kwargs) + self.event_manager.fire_event(TOPIC_LOG_DATA, msg) + + def clear(self): + self.fl_model = None + + def _prepare_param_diff(self, model: FLModel) -> FLModel: + exchange_format = self.client_config.get_exchange_format() + diff_func = DIFF_FUNCS.get(exchange_format, None) + + if diff_func is None: + raise RuntimeError(f"no default params diff function for {exchange_format}") + elif self.fl_model is None: + raise RuntimeError("no received model") + elif self.fl_model.params is not None: + if model.params_type == ParamsType.FULL: + try: + model.params = diff_func(original=self.fl_model.params, new=model.params) + model.params_type = ParamsType.DIFF + except Exception as e: + raise RuntimeError(f"params diff function failed: {e}") + + if model.params is None and model.metrics is None: + raise RuntimeError("the model to send has neither params nor metrics") + + return model + + def __receive_callback(self, topic, data, databus): + + if topic ==
TOPIC_GLOBAL_RESULT and not isinstance(data, Shareable): + raise ValueError(f"expecting a Shareable, but got '{type(data)}'") + + fl_model = FLModelUtils.from_shareable(data) + self.fl_model = fl_model + + def __ask_to_abort(self, topic, msg, databus): + if topic == TOPIC_ABORT: + self.abort = True + self.abort_reason = msg + self.logger.error(f"ask to abort job: reason: {msg}") + elif topic == TOPIC_STOP: + self.stop = True + self.stop_reason = msg + self.logger.warning(f"ask to stop job: reason: {msg}") + + def __continue_job(self) -> bool: + if self.abort: + raise RuntimeError(f"request to abort the job for reason {self.abort_reason}") + if self.stop: + self.logger.warning(f"request to stop the job for reason {self.stop_reason}") + self.fl_model = None + return False + + return True diff --git a/nvflare/fuel/message/__init__.py b/nvflare/fuel/message/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/fuel/message/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/fuel/utils/function_utils.py b/nvflare/fuel/utils/function_utils.py new file mode 100644 index 0000000000..6555b7ce3b --- /dev/null +++ b/nvflare/fuel/utils/function_utils.py @@ -0,0 +1,56 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import inspect +from typing import Callable + + +def find_task_fn(task_fn_path) -> Callable: + """Return function given a function path. + + Args: + task_fn_path (str): function path + + Returns: + function + + ex: train.main -> main + custom/train.main -> main + custom.train.main -> main + """ + # Split the text by the last dot + tokens = task_fn_path.rsplit(".", 1) + module_name = tokens[0].replace("/", ".") + fn_name = tokens[1] if len(tokens) > 1 else "" + module = importlib.import_module(module_name) + fn = getattr(module, fn_name) + return fn + + +def require_arguments(func): + """Inspect function to get required arguments. 
+ + Args: + func: function + + Returns: + require_args (bool), args_size (int), args_default_size (int) + """ + signature = inspect.signature(func) + parameters = signature.parameters + req = any(p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD for p in parameters.values()) + size = len(parameters) + args_with_defaults = [param for param in parameters.values() if param.default != inspect.Parameter.empty] + default_args_size = len(args_with_defaults) + return req, size, default_args_size diff --git a/tests/unit_test/app_common/executors/exec_task_fn_wrapper_test.py b/tests/unit_test/app_common/executors/exec_task_fn_wrapper_test.py new file mode 100644 index 0000000000..10e7b5340a --- /dev/null +++ b/tests/unit_test/app_common/executors/exec_task_fn_wrapper_test.py @@ -0,0 +1,89 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +from nvflare.app_common.executors.exec_task_fn_wrapper import ExecTaskFuncWrapper +from nvflare.fuel.data_event.data_bus import DataBus + + +class TestExecTaskFuncWrapper(unittest.TestCase): + def test_init_with_required_args(self): + # Test initialization with a function that requires arguments + task_fn_path = "nvflare.fuel.utils.class_utils.instantiate_class" + task_fn_args = {"class_path": "foo", "init_params": {}} + wrapper = ExecTaskFuncWrapper(task_fn_path, task_fn_args) + + self.assertEqual(wrapper.task_fn_path, task_fn_path) + self.assertEqual(wrapper.task_fn_args, task_fn_args) + self.assertTrue(wrapper.task_fn_require_args) + + def test_init_with_optional_args(self): + # Test initialization with a function that does not require arguments + task_fn_path = "nvflare.utils.cli_utils.get_home_dir" + task_fn_args = {"class_path": "foo", "init_params": {}} + wrapper = ExecTaskFuncWrapper(task_fn_path, task_fn_args) + + self.assertEqual(wrapper.task_fn_path, task_fn_path) + self.assertEqual(wrapper.task_fn_args, task_fn_args) + self.assertFalse(wrapper.task_fn_require_args) + + def test_init_with_missing_required_args(self): + # Test initialization with a function that requires arguments but none are provided + task_fn_path = "nvflare.fuel.utils.class_utils.instantiate_class" + # task_fn_args = {"class_path": "foo", "init_params": {}} + + with self.assertRaises(ValueError) as context: + wrapper = ExecTaskFuncWrapper(task_fn_path) + + expected_msg = f"function '{task_fn_path}' requires arguments, but none provided" + self.assertEqual(str(context.exception), expected_msg) + + def test_init_with_partial_missing_required_args(self): + # Test initialization with a function that requires arguments but only partially are provided + task_fn_path = "nvflare.fuel.utils.class_utils.instantiate_class" + task_fn_args = {"init_params": {}} + + with self.assertRaises(ValueError) as context: + wrapper = ExecTaskFuncWrapper(task_fn_path, task_fn_args) + + expected_msg = f"function '{task_fn_path}' requires 2 arguments, but 1 provided" + 
self.assertEqual(str(context.exception), expected_msg) + + def test_init_with_partial_missing_required_args_with_default(self): + # Test initialization with a function that requires arguments but only partially are provided + # the missing arg has default value + # def augment(to_dict: dict, from_dict: dict, from_override_to=False, append_list="components") + task_fn_path = "nvflare.fuel.utils.dict_utils.augment" + task_fn_args = {"to_dict": {}, "from_dict": {}} + wrapper = ExecTaskFuncWrapper(task_fn_path, task_fn_args) + + self.assertEqual(wrapper.task_fn_path, task_fn_path) + self.assertEqual(wrapper.task_fn_args, task_fn_args) + self.assertTrue(wrapper.task_fn_require_args) + + def test_run(self): + message_bus = DataBus() + message_bus.put_data("task_metadata", {}) + message_bus.put_data("mem_pipe", {}) + + # Test the run method + task_fn_path = "nvflare.fuel.utils.dict_utils.augment" + task_fn_args = {"to_dict": {}, "from_dict": {}} + wrapper = ExecTaskFuncWrapper(task_fn_path, task_fn_args) + + with patch.object(wrapper, "run") as mock_task_fn: + wrapper.run() + mock_task_fn.assert_called_once_with() diff --git a/tests/unit_test/client/__init__.py b/tests/unit_test/client/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/client/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/client/in_process/__init__.py b/tests/unit_test/client/in_process/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/client/in_process/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/client/in_process/api_test.py b/tests/unit_test/client/in_process/api_test.py new file mode 100644 index 0000000000..af2a62bc44 --- /dev/null +++ b/tests/unit_test/client/in_process/api_test.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from nvflare.apis.fl_constant import FLMetaKey +from nvflare.client.config import ConfigKey +from nvflare.client.in_process.api import ( + TOPIC_ABORT, + TOPIC_GLOBAL_RESULT, + TOPIC_LOCAL_RESULT, + TOPIC_LOG_DATA, + TOPIC_STOP, + InProcessClientAPI, +) +from nvflare.fuel.data_event.data_bus import DataBus + + +class TestInProcessClientAPI(unittest.TestCase): + def setUp(self): + # Create a mock task_metadata for testing + self.task_metadata = { + FLMetaKey.JOB_ID: "123", + FLMetaKey.SITE_NAME: "site-1", + "TASK_NAME": "train", + ConfigKey.TASK_EXCHANGE: { + ConfigKey.TRAIN_WITH_EVAL: "train_with_eval", + ConfigKey.EXCHANGE_FORMAT: "pytorch", + ConfigKey.TRANSFER_TYPE: "DIFF", + ConfigKey.TRAIN_TASK_NAME: "train", + ConfigKey.EVAL_TASK_NAME: "evaluate", + ConfigKey.SUBMIT_MODEL_TASK_NAME: "submit_model", + }, + } + + def test_init(self): + # Test the initialization of InProcessClientAPI + client_api = InProcessClientAPI(self.task_metadata) + client_api.init() + assert client_api.get_site_name() == "site-1" + assert client_api.get_task_name() == "train" + assert client_api.get_job_id() == "123" + assert client_api.is_train() is True + assert client_api.is_evaluate() is False + assert client_api.is_submit_model() is False + + assert client_api.sys_info == {FLMetaKey.JOB_ID: "123", FLMetaKey.SITE_NAME: "site-1"} + + def test_init_with_custom_interval(self): + # Test initialization with a custom result_check_interval + client_api = InProcessClientAPI(self.task_metadata, result_check_interval=5.0) + self.assertEqual(client_api.result_check_interval, 5.0) + + def test_init_subscriptions(self): + client_api = InProcessClientAPI(self.task_metadata) + assert list(client_api.data_bus.subscribers.keys()) == [TOPIC_GLOBAL_RESULT, TOPIC_ABORT, TOPIC_STOP] + + def local_result_callback(self, data, topic): + pass + + def log_result_callback(self, data, topic): + pass + + def test_init_subscriptions2(self): + data_bus = DataBus() + data_bus.subscribers.clear() + + data_bus.subscribe([TOPIC_LOCAL_RESULT], self.local_result_callback) + data_bus.subscribe([TOPIC_LOG_DATA], self.log_result_callback) + assert list(data_bus.subscribers.keys()) == [TOPIC_LOCAL_RESULT, TOPIC_LOG_DATA] + client_api = InProcessClientAPI(self.task_metadata) + assert list(client_api.data_bus.subscribers.keys()) == [ + TOPIC_LOCAL_RESULT, + TOPIC_LOG_DATA, + TOPIC_GLOBAL_RESULT, + TOPIC_ABORT, + TOPIC_STOP, + ] + + # Add more test methods for other functionalities in the class diff --git a/tests/unit_test/fuel/utils/function_utils_test.py b/tests/unit_test/fuel/utils/function_utils_test.py new file mode 100644 index 0000000000..400664f7e2 --- /dev/null +++ b/tests/unit_test/fuel/utils/function_utils_test.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from unittest.mock import MagicMock, patch + +from nvflare.fuel.utils.function_utils import find_task_fn + + +class TestFindTaskFn(unittest.TestCase): + @patch("importlib.import_module") + def test_find_task_fn_with_module(self, mock_import_module): + # Test find_task_fn when a module is specified in task_fn_path + task_fn_path = "nvflare.utils.cli_utils.get_home_dir" + mock_module = MagicMock() + mock_import_module.return_value = mock_module + + result = find_task_fn(task_fn_path) + + mock_import_module.assert_called_once_with("nvflare.utils.cli_utils") + self.assertTrue(callable(result)) + + def test_find_task_fn_without_module(self): + # Test find_task_fn when no module is specified in task_fn_path + task_fn_path = "get_home_dir" + with self.assertRaises(ModuleNotFoundError) as context: + result = find_task_fn(task_fn_path)
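Taken together, these tests pin down the resolution rules of find_task_fn: a dotted path imports the module part and returns the named attribute, while a bare name has no module part and fails at import time. A short usage sketch (the expected output assumes get_home_dir takes no required arguments, as the wrapper tests above rely on):

.. code-block:: python

    from nvflare.fuel.utils.function_utils import find_task_fn, require_arguments

    # "custom/train.main" and "custom.train.main" both resolve to main()
    # in module custom.train; "main" alone raises ModuleNotFoundError
    fn = find_task_fn("nvflare.utils.cli_utils.get_home_dir")
    req, n_args, n_defaults = require_arguments(fn)
    print(req, n_args, n_defaults)  # expected: False 0 0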