diff --git a/.github/workflows/markdown-links-check.yml b/.github/workflows/markdown-links-check.yml index 0f0c3f4505..186c14d178 100644 --- a/.github/workflows/markdown-links-check.yml +++ b/.github/workflows/markdown-links-check.yml @@ -17,10 +17,7 @@ name: Check Markdown links on: push: - branches: [ "main", "dev" ] pull_request: - # The branches below must be a subset of the branches above - branches: [ "main", "dev" ] jobs: markdown-link-check: diff --git a/.github/workflows/premerge.yml b/.github/workflows/premerge.yml index bc9048752a..932275df7e 100644 --- a/.github/workflows/premerge.yml +++ b/.github/workflows/premerge.yml @@ -17,8 +17,6 @@ name: pre-merge on: # quick tests for pull requests and the releasing branches push: - branches: - - dev pull_request: workflow_dispatch: diff --git a/docs/_static/css/additions.css b/docs/_static/css/additions.css index 999ff74614..a8490da9b4 100644 --- a/docs/_static/css/additions.css +++ b/docs/_static/css/additions.css @@ -1,3 +1,6 @@ .wy-menu-vertical li.toctree-l4.current li.toctree-l5>a{display:block;background:#b1b1b1;padding:.4045em 7.3em} .wy-menu-vertical li.toctree-l5.current li.toctree-l6>a{display:block;background:#a9a9a9;padding:.4045em 8.8em} -.wy-menu-vertical li.toctree-l5{font-size: .9em;} \ No newline at end of file +.wy-menu-vertical li.toctree-l5{font-size: .9em;} +.wy-menu > .caption > span.caption-text { + color: #76b900; + } \ No newline at end of file diff --git a/docs/conf.py b/docs/conf.py index fa388e92eb..57a8f9e1c7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -44,7 +44,7 @@ def resolve_xref(self, env, fromdocname, builder, typ, target, node, contnode): # -- Project information ----------------------------------------------------- project = "NVIDIA FLARE" -copyright = "2023, NVIDIA" +copyright = "2024, NVIDIA" author = "NVIDIA" # The full version, including alpha/beta/rc tags @@ -114,6 +114,7 @@ def resolve_xref(self, env, fromdocname, builder, typ, target, node, contnode): html_scaled_image_link = False html_show_sourcelink = True html_favicon = "favicon.ico" +html_logo = "resources/nvidia_logo.png" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, diff --git a/docs/example_applications_algorithms.rst b/docs/example_applications_algorithms.rst index 7f4b5c17e4..3b07038e97 100644 --- a/docs/example_applications_algorithms.rst +++ b/docs/example_applications_algorithms.rst @@ -6,13 +6,7 @@ Example Applications NVIDIA FLARE has several tutorials and examples to help you get started with federated learning and to explore certain features in the :github_nvflare_link:`examples directory `. -1. Step-By-Step Example Series -============================== - - * :github_nvflare_link:`Step-by-Step CIFAR-10 Examples (GitHub) ` - Step-by-step examples series with CIFAR-10 (image data) to showcase to showcase different FLARE features, workflows, and APIs. - * :github_nvflare_link:`Step-by-Step HIGGS Examples (GitHub) ` - Step-by-step examples series with HIGGS (tabular data) to showcase to showcase different FLARE features, workflows, and APIs. - -2. Hello World Examples +1. Hello World Examples ======================= Can be run from the :github_nvflare_link:`hello_world notebook `. @@ -22,27 +16,58 @@ Can be run from the :github_nvflare_link:`hello_world notebook ` - Example for converting Deep Learning (DL) to Federated Learning (FL) using the Client API. -2.2. Workflows +1.2. 
Workflows
 --------------

   * :ref:`Hello Scatter and Gather ` - Example using the Scatter And Gather (SAG) workflow with a Numpy trainer
-  * :ref:`Hello Cross-Site Validation ` - Example using the Cross Site Model Eval workflow with a Numpy trainer
+  * :ref:`Hello Cross-Site Validation ` - Example using the Cross Site Model Eval workflow with a Numpy trainer; also demonstrates running cross-site validation using the previous training results.
   * :github_nvflare_link:`Hello Cyclic Weight Transfer (GitHub) ` - Example using the CyclicController workflow to implement `Cyclic Weight Transfer `_ with TensorFlow as the deep learning training framework
   * :github_nvflare_link:`Swarm Learning ` - Example using Swarm Learning and Client-Controlled Cross-site Evaluation workflows.
   * :github_nvflare_link:`Client-Controlled Cyclic Weight Transfer ` - Example using the Client-Controlled Cyclic workflow with the Client API.

-2.3. Deep Learning
+1.3. Deep Learning
 ------------------

   * :ref:`Hello PyTorch ` - Example image classifier using FedAvg and PyTorch as the deep learning training framework
   * :ref:`Hello TensorFlow ` - Example image classifier using FedAvg and TensorFlow as the deep learning training framework
+
+2. Step-By-Step Example Series
+==============================
+
+:github_nvflare_link:`Step-by-Step Examples (GitHub) ` - Step-by-step example series with CIFAR-10 (image data) and HIGGS (tabular data) to showcase different FLARE features, workflows, and APIs.
+
+2.1 CIFAR-10 Image Data Examples
+--------------------------------
+
+  * :github_nvflare_link:`image_stats ` - federated statistics (histograms) of CIFAR10.
+  * :github_nvflare_link:`sag ` - scatter and gather (SAG) workflow with PyTorch using the Client API.
+  * :github_nvflare_link:`sag_deploy_map ` - scatter and gather workflow with deploy_map configuration for deployment of apps to different sites using the Client API.
+  * :github_nvflare_link:`sag_model_learner ` - scatter and gather workflow illustrating how to write client code using the ModelLearner.
+  * :github_nvflare_link:`sag_executor ` - scatter and gather workflow demonstrating how to write client-side executors.
+  * :github_nvflare_link:`sag_mlflow ` - MLflow experiment tracking logs with the Client API in scatter & gather workflows.
+  * :github_nvflare_link:`sag_he ` - homomorphic encryption using the Client API and POC -he mode.
+  * :github_nvflare_link:`cse ` - cross-site evaluation using the Client API.
+  * :github_nvflare_link:`cyclic ` - cyclic weight transfer workflow with server-side controller.
+  * :github_nvflare_link:`cyclic_ccwf ` - client-controlled cyclic weight transfer workflow with client-side controller.
+  * :github_nvflare_link:`swarm ` - swarm learning and client-side cross-site evaluation with the Client API.
+
+2.2 HIGGS Tabular Data Examples
+-------------------------------
+
+  * :github_nvflare_link:`tabular_stats ` - federated statistics (histogram) calculation on tabular data.
+  * :github_nvflare_link:`sklearn_linear ` - federated linear model (logistic regression on binary classification) learning on tabular data.
+  * :github_nvflare_link:`sklearn_svm ` - federated SVM model learning on tabular data.
+  * :github_nvflare_link:`sklearn_kmeans ` - federated k-Means clustering on tabular data.
+  * :github_nvflare_link:`xgboost ` - federated horizontal XGBoost learning on tabular data with bagging collaboration.
+
+
 3. Tutorial Notebooks
 =====================
diff --git a/docs/examples/fl_experiment_tracking_mlflow.rst b/docs/examples/fl_experiment_tracking_mlflow.rst
index 14b6e36860..d9a9b63891 100644
--- a/docs/examples/fl_experiment_tracking_mlflow.rst
+++ b/docs/examples/fl_experiment_tracking_mlflow.rst
@@ -53,10 +53,10 @@ Adding MLflow Logging to Configurations
 
 Inside the config folder there are two files, ``config_fed_client.json`` and ``config_fed_server.json``.
 
-.. literalinclude:: ../../examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_client.json
-   :language: json
+.. literalinclude:: ../../examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_client.conf
+   :language: none
    :linenos:
-   :caption: config_fed_client.json
+   :caption: config_fed_client.conf
 
 Take a look at the components section of the client config at line 24.
 The first component is the ``pt_learner`` which contains the initialization, training, and validation logic.
@@ -69,10 +69,10 @@ within NVFlare with the information to track.
 
 Finally, :class:`ConvertToFedEvent` converts local events to federated events.
 This changes the event ``analytix_log_stats`` into a fed event ``fed.analytix_log_stats``,
 which will then be streamed from the clients to the server.
 
-.. literalinclude:: ../../examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.json
-   :language: json
+.. literalinclude:: ../../examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf
+   :language: none
    :linenos:
-   :caption: config_fed_server.json
+   :caption: config_fed_server.conf
 
 Under the component section in the server config, we have the
 :class:`MLflowReceiver`.
 This component receives
diff --git a/docs/fl_introduction.rst b/docs/fl_introduction.rst
new file mode 100644
index 0000000000..04cb9a9cd5
--- /dev/null
+++ b/docs/fl_introduction.rst
@@ -0,0 +1,64 @@
+.. _fl_introduction:
+
+###########################
+What is Federated Learning?
+###########################
+
+Federated Learning is a distributed learning paradigm where training occurs across multiple clients, each with their own local datasets.
+This enables the creation of common robust models without sharing sensitive local data, helping solve issues of data privacy and security.
+
+How does Federated Learning Work?
+=================================
+The federated learning (FL) server orchestrates the collaboration of multiple clients by first sending an initial model to the FL clients.
+The clients perform training on their local datasets, then send the model updates back to the FL server for aggregation to form a global model.
+This process forms a single round of federated learning, and after a number of rounds a robust global model can be developed.
+
+.. image:: resources/fl_diagram.png
+   :height: 500px
+   :align: center
+
+FL Terms and Definitions
+========================
+
+- FL server: manages job lifecycle, orchestrates workflow, assigns tasks to clients, performs aggregation
+- FL client: executes tasks, performs local computation/learning with local dataset, submits result back to FL server
+- FL algorithms: FedAvg, FedOpt, FedProx etc. implemented as workflows
+
+.. note::
+
+   Here we describe the centralized version of FL, where the FL server has the role of the aggregator node. However, in a decentralized version such as
+   swarm learning, FL clients can serve as the aggregator node instead.
+
+- Types of FL
+
+  - horizontal FL: clients hold different data samples over the same features
+  - vertical FL: clients hold different features over an overlapping set of data samples
+  - swarm learning: a decentralized subset of FL where orchestration and aggregation are performed by the clients
+
+Main Benefits
+=============
+
+Enhanced Data Privacy and Security
+----------------------------------
+Federated learning facilitates data privacy and data locality by ensuring that the data remains at each site.
+Additionally, privacy-preserving techniques such as homomorphic encryption and differential privacy filters can also be leveraged to further protect the transferred data.
+
+Improved Accuracy and Diversity
+-------------------------------
+By training with a variety of data sources across different clients, a robust and generalizable global model can be developed to better represent heterogeneous datasets.
+
+Scalability and Network Efficiency
+----------------------------------
+With the ability to perform training at the edge, federated learning can be highly scalable across the globe.
+Additionally, only needing to transfer the model weights rather than entire datasets enables efficient use of network resources.
+
+Applications
+============
+An important application of federated learning is in the healthcare sector, where data privacy regulations and patient record confidentiality make training models challenging.
+Federated learning can help break down these healthcare data silos to allow hospitals and medical institutions to collaborate and pool their medical knowledge without the need to share their data.
+Some common use cases involve classification and detection tasks, drug discovery with federated protein LLMs, and federated analytics on medical devices.
+
+Furthermore, there are many other areas and industries such as financial fraud detection, autonomous vehicles, HPC, mobile applications, etc.,
+where the ability to use distributed data silos while maintaining data privacy is essential for the development of better models.
+
+Read on to learn how FLARE is built as a flexible federated computing framework to enable federated learning from research to production.
\ No newline at end of file
diff --git a/docs/flare_overview.rst b/docs/flare_overview.rst
index c2f6ecfb91..15eaafa8d3 100644
--- a/docs/flare_overview.rst
+++ b/docs/flare_overview.rst
@@ -26,7 +26,7 @@ Built for productivity
 FLARE is designed for maximum productivity, providing a range of tools to enhance user experience and research efficiency at different stages of the development process:
 
 - **FLARE Client API:** Enables users to transition seamlessly from ML/DL to FL with just a few lines of code changes.
-- **Simulator CLI:** Allows users to simulate federated learning or computing jobs in multi-thread settings within a single computer, offering quick response and debugging. The same job can be deployed directly to production.
+- **Simulator CLI:** Allows users to simulate federated learning or computing jobs in multi-process settings within a single computer, offering quick response and debugging. The same job can be deployed directly to production.
 - **POC CLI:** Facilitates the simulation of federated learning or computing jobs in multi-process settings within one computer. Different processes represent server, clients, and an admin console, providing users with a realistic sense of the federated network. It also allows users to simulate project deployment on a single host.
 - **Job CLI:** Permits users to create and submit jobs directly in POC or production environments.
 - **FLARE API:** Enables users to run jobs directly from Python code or notebooks.
diff --git a/docs/getting_started.rst b/docs/getting_started.rst
index 1abb87c58d..cee26e8f9e 100644
--- a/docs/getting_started.rst
+++ b/docs/getting_started.rst
@@ -22,14 +22,14 @@ Clone NVFLARE repo to get examples, switch main branch (latest stable branch)
 
    $ git clone https://github.com/NVIDIA/NVFlare.git
    $ cd NVFlare
-   $ git switch main
+   $ git switch 2.4
 
 Note on branches:
 
 * The `main `_ branch is the default (unstable) development branch
-* The 2.0, 2.1, 2.2, and 2.3 etc. branches are the branches for each major release and minor patches
+* The 2.1, 2.2, 2.3, and 2.4 etc. branches are the branches for each major release and minor patches
 
 
 Quick Start with Simulator
@@ -63,6 +63,14 @@ establishing a secure, distributed FL workflow.
 Installation
 =============
 
+.. note::
+   The server and client versions of nvflare must match; cross-version compatibility is not supported.
+
+Supported Operating Systems
+---------------------------
+- Linux
+- OSX (Note: some optional dependencies are not compatible, such as tenseal and openmined.psi)
+
 Python Version
 --------------
 
@@ -117,7 +125,6 @@ You may find that the pip and setuptools versions in the venv need updating:
 
   (nvflare-env) $ python3 -m pip install -U pip
  (nvflare-env) $ python3 -m pip install -U setuptools
-
 Install Stable Release
 ----------------------
 
@@ -127,6 +134,11 @@ Stable releases are available on `NVIDIA FLARE PyPI `
+for modules and components with optional dependencies.
 
 .. _containerized_deployment:
@@ -210,7 +222,7 @@ Production mode is secure with TLS certificates - depending the choice the deplo
 
   - HA or non-HA
   - Local or remote
-  - On-premise or on cloud
+  - On-premise or on cloud (See :ref:`cloud_deployment`)
 
 Using non-HA, secure, local mode (all clients and server running on the same host), production mode is very similar to POC mode except it is secure.
diff --git a/docs/index.rst b/docs/index.rst
index 78ea857358..1c8527e62c 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -5,15 +5,29 @@ NVIDIA FLARE
 .. toctree::
    :maxdepth: -1
    :hidden:
+   :caption: Introduction
 
+   fl_introduction
    flare_overview
    whats_new
    getting_started
+
+.. toctree::
+   :maxdepth: -1
+   :hidden:
+   :caption: Guides
+
    example_applications_algorithms
    real_world_fl
    user_guide
    programming_guide
    best_practices
+
+.. toctree::
+   :maxdepth: -1
+   :hidden:
+   :caption: Miscellaneous
+
    faq
    publications_and_talks
    contributing
@@ -21,7 +35,7 @@ NVIDIA FLARE
    glossary
 
 NVIDIA FLARE (NVIDIA Federated Learning Application Runtime Environment) is a domain-agnostic, open-source, extensible SDK that allows
-researchers and data scientists to adaptexisting ML/DL workflows (PyTorch, RAPIDS, Nemo, TensorFlow) to a federated paradigm; and enables
+researchers and data scientists to adapt existing ML/DL workflows (PyTorch, RAPIDS, Nemo, TensorFlow) to a federated paradigm; and enables
 platform developers to build a secure, privacy preserving offering for a distributed multi-party collaboration.
 
 NVIDIA FLARE is built on a componentized architecture that gives you the flexibility to take federated learning workloads from research
@@ -34,18 +48,21 @@ and simulation to real-world production deployment.
Some of the key components
 
 - **Management tools** for secure provisioning and deployment, orchestration, and management
 - **Specification-based API** for extensibility
 
-Learn more in the :ref:`FLARE Overview `, :ref:`Key Features `, :ref:`What's New `, and the
-:ref:`User Guide ` and :ref:`Programming Guide `.
+Learn more about FLARE features in the :ref:`FLARE Overview ` and :ref:`What's New `.
 
 Getting Started
 ===============
-For first-time users and FL researchers, FLARE provides the :ref:`fl_simulator` that allows you to build, test, and deploy applications locally.
-The :ref:`Getting Started guide ` covers installation and walks through an example application using the FL Simulator.
+For first-time users and FL researchers, FLARE provides the :ref:`FL Simulator ` that allows you to build, test, and deploy applications locally.
+The :ref:`Getting Started ` guide covers installation and walks through an example application using the FL Simulator.
+Additional examples can be found in the :ref:`Example Applications `, which showcase different federated learning workflows and algorithms on various machine learning and deep learning tasks.
 
+FLARE for Users
+===============
+If you want to learn how to interact with the FLARE system, please refer to the :ref:`User Guide `.
 When you are ready for a secure, distributed deployment, the :ref:`Real World Federated Learning ` section
 covers the tools and process required to deploy and operate a secure, real-world FLARE project.
 
 FLARE for Developers
 ====================
-When you're ready to build your own application, the :ref:`Programming Best Practices `, :ref:`FAQ`, and
-:ref:`Programming Guide ` give an in depth look at the FLARE platform and APIs.
+When you're ready to build your own application, the :ref:`Programming Guide `, :ref:`Programming Best Practices `, :ref:`FAQ`, and :ref:`API Reference `
+give an in-depth look at the FLARE platform and APIs.
diff --git a/docs/programming_guide.rst b/docs/programming_guide.rst
index b181cbe7c9..28e8b7992b 100644
--- a/docs/programming_guide.rst
+++ b/docs/programming_guide.rst
@@ -36,7 +36,8 @@ Please refer to :ref:`application` for more details.
    :maxdepth: 1
 
    programming_guide/workflows_and_controllers
-   programming_guide/fl_clients
+   programming_guide/execution_api_type
+   programming_guide/fl_model
    programming_guide/shareable
    programming_guide/data_exchange_object
    programming_guide/fl_context
diff --git a/docs/programming_guide/controllers/controllers.rst b/docs/programming_guide/controllers/controllers.rst
index dd611b3631..cf4a8f4368 100644
--- a/docs/programming_guide/controllers/controllers.rst
+++ b/docs/programming_guide/controllers/controllers.rst
@@ -73,7 +73,9 @@ The Controller's Task Manager manages the task's lifecycle:
 
 .. note::
 
-   In NVIDIA FLARE 2.0, the underlying communication is by gRPC: the client always initiates communication by sending
-   a request to the server and a receiving response. When we say "server sends task to the client", it is only
-   conceptual. With gRPC, the client sends the "ask for next task" request to the server, and the server responds with
-   the task data.
+   In NVIDIA FLARE, the underlying communication is facilitated through gRPC:
+   the client always initiates communication by sending a request to the server and receiving a response.
+   When referring to the scenario where the "server sends a task to the client,"
+   it is important to note that this is a conceptual representation.
+   In reality, with gRPC, the client initiates the interaction by sending a "request for the next task" to the server,
+   and the server responds by providing the task data.
diff --git a/docs/programming_guide/controllers/cross_site_model_evaluation.rst b/docs/programming_guide/controllers/cross_site_model_evaluation.rst
index 456e8fc138..75936806d5 100644
--- a/docs/programming_guide/controllers/cross_site_model_evaluation.rst
+++ b/docs/programming_guide/controllers/cross_site_model_evaluation.rst
@@ -23,7 +23,7 @@ example that implements the :class:`cross site model evaluation workflow` to write the results to a JSON file on the server.
 
-Example with Cross Site Model Evaluation / Federated Evaluation Workflow
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-See the :github_nvflare_link:`Hello Numpy Cross-Site Validation ` for an example application with
-the cross site model evaluation / federated evaluation workflow.
+Examples with Cross Site Model Evaluation / Federated Evaluation Workflow
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+See :github_nvflare_link:`Hello Numpy Cross-Site Validation ` and
+:github_nvflare_link:`Step-by-step Cross-site Evaluation ` for examples using server-controlled cross-site evaluation workflows.
diff --git a/docs/programming_guide/controllers/initialize_global_weights.rst b/docs/programming_guide/controllers/initialize_global_weights.rst
index aed6b035d0..6634d3133e 100644
--- a/docs/programming_guide/controllers/initialize_global_weights.rst
+++ b/docs/programming_guide/controllers/initialize_global_weights.rst
@@ -26,7 +26,7 @@ Two changes are needed:
 
 The updated file should look like the following:
 
-.. literalinclude:: ../resources/init_weights_1_config_fed_server.json
+.. literalinclude:: ../../resources/init_weights_1_config_fed_server.json
    :language: json
diff --git a/docs/programming_guide/controllers/scatter_and_gather_workflow.rst b/docs/programming_guide/controllers/scatter_and_gather_workflow.rst
index ad5b1d9507..44c9d232a8 100644
--- a/docs/programming_guide/controllers/scatter_and_gather_workflow.rst
+++ b/docs/programming_guide/controllers/scatter_and_gather_workflow.rst
@@ -7,7 +7,7 @@ of NVIDIA FLARE with a Server aggregating results from Clients that have produce
 
 At the core, the control_flow of :class:`nvflare.app_common.workflows.scatter_and_gather.ScatterAndGather` is a for loop:
 
-.. image:: ../resources/fed_sag_round.png
+.. image:: ../../resources/fed_sag_round.png
    :height: 400px
 
 Trainer
diff --git a/docs/programming_guide/execution_api_type.rst b/docs/programming_guide/execution_api_type.rst
new file mode 100644
index 0000000000..77baf7806a
--- /dev/null
+++ b/docs/programming_guide/execution_api_type.rst
@@ -0,0 +1,98 @@
+.. _execution_api_type:
+
+#######################
+From Local to Federated
+#######################
+
+In the FLARE system, a federated learning algorithm is defined in a Job format
+(for details, please refer to :ref:`job`).
+
+A Job consists of multiple "workflows" and "executors."
+
+The simplified job execution flow is as follows:
+
+- The workflow schedules a task for the FL clients.
+- Each FL client performs the received task and sends the result back.
+- The workflow receives the results and determines if it is done.
+- If it is not done, it schedules a new task.
+- If it is done, it proceeds to the next workflow in the Job.
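+
+As a purely conceptual sketch (plain Python placeholders, not actual FLARE classes or APIs), the
+flow above can be pictured like this:
+
+.. code-block:: python
+
+    # Conceptual illustration only; "workflow" and "clients" are hypothetical objects.
+    def run_job(workflows, clients):
+        for workflow in workflows:
+            while not workflow.is_done():
+                task = workflow.schedule_task()  # the workflow schedules a task
+                # each FL client performs the task and sends its result back
+                results = [client.execute(task) for client in clients]
+                # the workflow receives the results and decides whether it is done
+                workflow.process_results(results)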
+ +Users need to adapt their local training or computing logic into FLARE's task +execution abstractions to make their training or computing federated. + +We offer various levels of abstraction for writing task execution code, +catering to use cases that span from complete customizability to easy user adaptation. + +Execution API Type +================== + +Below is a general overview of the key ideas and use cases for each type: + +Client API +---------- + +The :ref:`client_api` provides the most straightforward way to write FL code, +and can easily be used to convert centralized code with minimal code changes. +The Client API uses the :class:`FLModel` +object for data transfer and supports common tasks such as train, validate, and submit_model. +Additionally, options for using decorators or PyTorch Lightning are also available. + +We recommend users start with the Client API, and to consider the other types +for more specific cases as required. + +ModelLearner +------------ + +The :ref:`model_learner` is designed to simplify writing learning logic by +minimizing FLARE-specific concepts. +The :class:`ModelLearner` +defines familiar learning functions for training and validation, +and uses the :class:`FLModel` +object for transferring learning information. +The ModelLearner also contains several convenient capabilities, +such as lifecycle and logging information. + +The ModelLearner is best used when working with standard machine learning code +that can fit well into the train and validate methods and can be easily adapted +to the ModelLearner subclass and method structure. + +Executor +-------- + +:ref:`executor` is the most flexible for defining custom logic and tasks, +as with a custom executor and controller, any form of computation can be performed. +However, Executors must deal directly with FLARE-specific communication concepts +such as :class:`Shareable`, :class:`DXO`, +and :class:`FLContext`. +As a result, many higher-level APIs are built on top of Executors in order to +abstract these concepts away for easier user adaptation. + +Overall, writing an Executor is most useful when implementing tasks and logic +that do not fit within the structure of higher-level APIs or other predefined Executors. + +3rd-Party System Integration +---------------------------- + +There are cases where users have a pre-existing ML/DL training system +infrastructure that cannot be easily adapted to the FLARE client. + +The :ref:`3rd_party_integration` pattern allows for a seamless integration +between the FLARE system and a third-party external training system. + +With the use of the :mod:`FlareAgent ` and +:mod:`TaskExchanger `, +users can easily enable any 3rd-party system to receive tasks and submit results back to the server. + +Please use the following chart to decide which abstraction to use: + +.. image:: ../resources/task_execution_decision_chart.png + +For more details about each type, refer to each page below. + +.. toctree:: + :maxdepth: 1 + + execution_api_type/3rd_party_integration + execution_api_type/client_api + execution_api_type/model_learner + execution_api_type/executor diff --git a/docs/programming_guide/execution_api_type/3rd_party_integration.rst b/docs/programming_guide/execution_api_type/3rd_party_integration.rst new file mode 100644 index 0000000000..093490897f --- /dev/null +++ b/docs/programming_guide/execution_api_type/3rd_party_integration.rst @@ -0,0 +1,350 @@ +.. 
_3rd_party_integration:
+
+############################
+3rd-Party System Integration
+############################
+
+NVFLARE supports a seamless integration between the FLARE system and a
+third-party external training system.
+This is especially useful with pre-existing ML/DL training system
+infrastructure that cannot be easily adapted to the FLARE client.
+
+The FL client uses the :class:`TaskExchanger`
+executor to receive tasks and submit results to the FLARE server.
+The 3rd-party system uses the :class:`FlareAgent` to
+interact with the TaskExchanger to get tasks and submit results.
+
+This integration pattern is illustrated in the diagram below:
+
+.. image:: ../../resources/3rd_party_integration_diagram.png
+   :height: 400px
+
+Requirements
+============
+
+- The key to enabling this integration is the "agent_id" that must be made known to both systems.
+  The FL client gets this information from the job's config_fed_client, and the
+  3rd-party trainer gets this from its own launch process.
+- It is assumed that the customer already has a way to dynamically generate the
+  "agent_id" for each job, and start its trainer process with this information.
+- Each FL client must be able to open an address (host:port) to allow the trainer to connect to.
+  Depending on where the trainer is running, the connection may or may not need to be in secure mode (TLS).
+- We will need to modify the "project.yml" for the NVFlare provision system
+  and generate new package folders for each participating site.
+- The trainer must be a Python program that can integrate with the NVFLARE library.
+- The trainer must be able to connect to the server, as well as the address that
+  is dynamically opened by the FL client.
+
+Prepare the Trainer
+===================
+
+Let's prepare the trainer code first; we will modify the "project.yml" in the
+next section for project setup.
+
+You need to modify your trainer code to integrate with the :class:`FlareAgent` API.
+This API provides simple ``get_task()`` and ``submit_result()`` methods to interact with the FL client.
+
+We will go through the steps one by one:
+
+1. Create Agent
+---------------
+
+The :class:`FlareAgent` is responsible
+for interacting with the FL client to exchange task data.
+
+If using FLModel, :class:`FlareAgentWithFLModel`
+subclasses FlareAgent and provides conversion from shareables to tasks using the FLModel data structure.
+
+If using CellPipe, a convenient class :class:`FlareAgentWithCellPipe`
+can be used.
+
+Please refer to their API pages for detailed explanations of each argument:
+
+  - :class:`FlareAgent`
+  - :class:`FlareAgentWithFLModel`
+  - :class:`FlareAgentWithCellPipe`
+
+You can create the FlareAgentWithCellPipe as shown in the following code:
+
+.. code-block:: python
+
+    from nvflare.client.flare_agent import FlareAgentWithCellPipe
+
+    agent = FlareAgentWithCellPipe(
+        root_url="grpc://server:8002",
+        site_name=args.site_name,
+        agent_id=args.agent_id,
+        workspace_dir=args.workspace,
+        secure_mode=True,
+        submit_result_timeout=2.0,
+        heartbeat_timeout=120.0,
+    )
+
+2. Start Agent
+--------------
+
+After we create the agent, we need to start it.
+We can call ``agent.start()`` to start the agent.
+This call must be made before trying to get tasks.
+
+For example:
+
+.. code-block:: python
+
+    agent.start()
+
+3. Process Tasks
+----------------
+
+The training is a continuous process of getting a task, executing the task,
+and submitting the task result.
+
+Call ``agent.get_task()`` to get a Task object from the FL client.
+This is a blocking call and returns only when a task is available.
+If there are no more tasks available (i.e. end of the job), an ``AgentClosed``
+exception will be raised, signaling the end of the training.
+
+The :class:`Task` object contains 3 pieces of
+information: task_name, task_id, and data.
+The task_name tells you what the task is (e.g. train).
+The task_id is a UUID of the task instance.
+The data contains model data to be trained on.
+
+Once the task is completed, the result can be submitted to the FL client by calling ``agent.submit_result()``.
+A return code (``rc``) must be provided to indicate whether the task was executed successfully.
+If the ``rc`` is not RC.OK, then the job will be aborted.
+
+For example:
+
+.. code-block:: python
+
+    while True:
+        print("getting task ...")
+        try:
+            task = agent.get_task()
+        except AgentClosed:
+            print("agent closed - exit")
+            break
+
+        print(f"got task: {task}")
+        rc, meta, result = train(task.data)  # perform train task
+        submitted = agent.submit_result(TaskResult(data=result, meta=meta, return_code=rc))
+        print(f"result submitted: {submitted}")
+
+4. Stop Agent
+-------------
+
+At the end of the training, ``agent.stop()`` must be called to end the program gracefully.
+If this call is missed, the program may not exit properly.
+
+.. code-block:: python
+
+    agent.stop()
+
+
+5. Putting It Together
+----------------------
+
+Now that we have covered all the necessary steps, we can put them together into the following
+example code of this usage pattern:
+
+.. literalinclude:: ../../resources/3rd_party_trainer.py
+   :language: python
+
+
+Notes:
+
+- This pattern of (``start``, ``get_task``, ``submit_result``, and ``stop``) is strictly enforced.
+  If the pattern is not followed (e.g. ``get_task``, then ``get_task`` again without ``submit_result``),
+  you will get a ``CallStateError`` exception.
+- The only way to know that the job has ended is the ``AgentClosed`` exception from the ``get_task`` call.
+  This exception is raised when the FL client tells the agent that the job is done;
+  or when the FL client is considered dead (missing heartbeats for the configured period of time).
+- If your training algorithm runs into an unrecoverable error and wants to end the job,
+  you should use a proper return code (e.g. ``RC.EXECUTION_EXCEPTION``).
+
+Project Setup
+=============
+
+After we prepare the trainer code we can follow the steps below to properly
+set up the project and jobs.
+
+Step One - Provision
+--------------------
+
+From the perspective of the trainer, the FL client site behaves like both a client and a server for connections.
+This requires the client site to have two sets of TLS credentials.
+Make sure to specify the "listening_host" for the client in the project.yml when provisioning the project.
+
+.. note::
+   We assume you understand NVFlare provision; if not, please read :ref:`provisioning`.
+
+An example looks like:
+
+.. code-block:: yaml
+
+    participants:
+      # change example.com to the FQDN of the server
+      - name: server
+        type: server
+        org: nvidia
+        fed_learn_port: 8002
+        admin_port: 8003
+      - name: site_1
+        type: client
+        org: nvidia
+        listening_host: localhost
+      - name: site_2
+        type: client
+        org: nvidia
+        listening_host: localhost
+
+Once the project is provisioned, check the "startup" kit generated for the clients.
+You should see the following files, among others:
+
+client.crt, client.key, server.crt, server.key, rootCA.pem
+
+Note that the specified listening_host of a site must be a hostname that
+the external trainer can reach over the network.
+
+Step Two - Prepare Job Configuration
+------------------------------------
+
+For each job, configure the config_fed_client.json to use
+:class:`TaskExchanger` as the executor.
+
+.. code-block:: json
+
+    {
+      "format_version": 2,
+      "executors": [
+        {
+          "tasks": [
+            "train"
+          ],
+          "executor": {
+            "path": "nvflare.app_common.executors.task_exchanger.TaskExchanger",
+            "args": {
+              "pipe_id": "pipe",
+              "peer_read_timeout": 30,
+              "heartbeat_timeout": 60
+            }
+          }
+        }
+      ],
+      "task_result_filters": [],
+      "task_data_filters": [],
+      "components": [
+        {
+          "id": "pipe",
+          "path": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe",
+          "args": {
+            "mode": "PASSIVE",
+            "site_name": "{SITE_NAME}",
+            "token": "{SITE_NAME}",
+            "root_url": "{ROOT_URL}",
+            "secure_mode": "{SECURE_MODE}",
+            "workspace_dir": "{WORKSPACE}"
+          }
+        }
+      ]
+    }
+
+Make sure the parameters of the :class:`TaskExchanger`
+are configured properly, and change the default values as needed.
+
+Please refer to the API page for a detailed explanation of each argument:
+:class:`TaskExchanger`
+
+Step Three - Trainer Setup
+--------------------------
+
+For each client site, you will have an FL client and a trainer process.
+
+To make our integration work, please follow the steps below to
+set up the trainer process on each client site:
+
+  - Make sure the trainer process has access to a local file system.
+  - Create a "workspace" folder that is going to be used by this trainer process.
+    This workspace will be used for all jobs.
+  - Copy the "startup" folder of the client site to this "workspace" folder.
+    If needed, any additional config files required by the trainer can also
+    be placed in this "workspace" folder.
+  - Create the trainer script following the steps in the above section.
+    Please set the FlareAgentWithCellPipe's "workspace_dir" to the path of
+    this "workspace" folder that you just created.
+    Please make sure the "agent_id" value of FlareAgentWithCellPipe is the same
+    as the "token" value in the configuration above.
+
+Verification
+============
+
+The FL client (TaskExchanger) and your trainer process (FlareAgentWithCellPipe)
+do not have to be started at exactly the same time.
+
+Whichever is started first will wait for the other for ``heartbeat_timeout`` seconds.
+Once both are started and connected, you can verify that they are directly
+connected using the Admin console's ``cells`` command.
+
+The following example shows two clients (site-1, site-2) connected to their
+external trainers via the agent_id/token "ext_trainer":
+
+.. code-block:: shell
+
+    > cells
+    server
+    server.10d1d3b7-fb50-4c83-9575-e510f32c5d21
+    site-1
+    site-1.10d1d3b7-fb50-4c83-9575-e510f32c5d21
+    site-2
+    site-2.10d1d3b7-fb50-4c83-9575-e510f32c5d21
+    site-1_ext_trainer_active
+    site-2_ext_trainer_active
+    site-2_ext_trainer_passive
+    site-1_ext_trainer_passive
+    Total Cells: 10
+
+
+The ``cells`` command lists all cells.
+
+Notice that the job ``10d1d3b7-fb50-4c83-9575-e510f32c5d21`` is running on both
+the "site-1" and "site-2" clients.
+
+Also notice that there are two pairs of corresponding cells
+(site-1_ext_trainer_active, site-1_ext_trainer_passive)
+and (site-2_ext_trainer_active, site-2_ext_trainer_passive).
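+
+For reference, assuming the example trainer script from "5. Putting It Together" is saved as
+``trainer.py`` (a hypothetical name) and accepts the ``--workspace``, ``--site_name``, and
+``--agent_id`` arguments shown earlier, the site-1 trainer above could be launched like this:
+
+.. code-block:: shell
+
+    # agent_id must match the token configured for the pipe on the FL client side
+    python trainer.py --workspace /path/to/workspace --site_name site-1 --agent_id ext_trainer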
+
+
+Optional - Setup for Adhoc Direct Connection between FL Client and Trainer
+==========================================================================
+
+The FL client and the trainer can always talk to each other via the server,
+but it could be slow, especially if the server is located far away.
+To enable adhoc direct connections between the FL client and the trainer,
+configure the comm_config.json on the client site as follows:
+
+.. code-block:: json
+
+    {
+      "allow_adhoc_conns": true,
+      "use_aio_grpc": true,
+      "adhoc": {
+        "scheme": "tcp",
+        "resources": {
+          "host": "localhost",
+          "secure": true
+        }
+      }
+    }
+
+This file must be placed into the site's "local" folder within its workspace.
+
+Pay attention to the following:
+
+- For most cases, the "scheme" should be set to "tcp" to get the best performance.
+  If "tcp" cannot be used, you can use "grpc".
+- In "resources":
+
+  - If the FL client and the trainer are within the same trusted network,
+    you can set "secure" to false; otherwise set it to true.
+  - The value of the "host" must match the "listening_host" value of the site used in provision.
diff --git a/docs/programming_guide/execution_api_type/client_api.rst b/docs/programming_guide/execution_api_type/client_api.rst
new file mode 100644
index 0000000000..aff5da50eb
--- /dev/null
+++ b/docs/programming_guide/execution_api_type/client_api.rst
@@ -0,0 +1,195 @@
+.. _client_api:
+
+##########
+Client API
+##########
+
+The FLARE Client API provides an easy way for users to convert their centralized,
+local training code into federated learning code with the following benefits:
+
+* Only requires a few lines of code changes, without the need to restructure the code or implement a new class
+* Reduces the number of new FLARE-specific concepts exposed to users
+* Easy adaptation from existing local training code using different frameworks
+  (PyTorch, PyTorch Lightning, HuggingFace)
+
+Core concept
+============
+
+The general structure of the popular federated learning (FL) workflow "FedAvg" is as follows:
+
+#. FL server initializes an initial model
+#. For each round (global iteration):
+
+   #. FL server sends the global model to clients
+   #. Each FL client starts with this global model and trains on their own data
+   #. Each FL client sends back their trained model
+   #. FL server aggregates all the models and produces a new global model
+
+On the client side, the training workflow is as follows:
+
+#. Receive the model from the FL server
+#. Perform local training on the received global model and/or evaluate the
+   received global model for model selection
+#. Send the new model back to the FL server
+
+To convert centralized training code to federated learning, we need to
+adapt the code to do the following steps:
+
+#. Obtain the required information from the received :ref:`fl_model`
+#. Run local training
+#. Put the results in a new :ref:`fl_model` to be sent back
+
+For a general use case, there are three essential methods for the Client API:
+
+* ``init()``: Initializes NVFlare Client API environment.
+* ``receive()``: Receives model from NVFlare side.
+* ``send()``: Sends the model to NVFlare side.
+
+Users can use the Client API to change their centralized training code to
+federated learning, for example:
+
+.. code-block:: python
+
+    import nvflare.client as flare
+
+    flare.init()  # 1. Initializes NVFlare Client API environment.
+    input_model = flare.receive()  # 2. Receives model from NVFlare side.
+    params = input_model.params  # 3. Obtain the required information from received FLModel
+
+    # original local training code begins
+    new_params = local_train(params)
+    # original local training code ends
+
+    output_model = flare.FLModel(params=new_params)  # 4. Put the results in a new FLModel
+    flare.send(output_model)  # 5. Sends the model to NVFlare side.
+
+With 5 lines of code changes, we convert the centralized training code to
+a federated learning setting.
+
+After this, we can utilize the job templates and the :ref:`job_cli`
+to generate a job so it can be run using the :ref:`fl_simulator`
+or submitted to a deployed NVFlare system.
+
+Below is a table overview of key Client APIs.
+
+.. list-table:: Client API
+   :widths: 25 25 50
+   :header-rows: 1
+
+   * - API
+     - Description
+     - API Doc Link
+   * - init
+     - Initializes NVFlare Client API environment.
+     - :func:`init`
+   * - receive
+     - Receives model from NVFlare side.
+     - :func:`receive`
+   * - send
+     - Sends the model to NVFlare side.
+     - :func:`send`
+   * - system_info
+     - Gets NVFlare system information.
+     - :func:`system_info`
+   * - get_job_id
+     - Gets job id.
+     - :func:`get_job_id`
+   * - get_site_name
+     - Gets site name.
+     - :func:`get_site_name`
+   * - is_running
+     - Returns whether the NVFlare system is up and running.
+     - :func:`is_running`
+   * - is_train
+     - Returns whether the current task is a training task.
+     - :func:`is_train`
+   * - is_evaluate
+     - Returns whether the current task is an evaluate task.
+     - :func:`is_evaluate`
+   * - is_submit_model
+     - Returns whether the current task is a submit_model task.
+     - :func:`is_submit_model`
+
+.. list-table:: Decorator APIs
+   :widths: 25 25 50
+   :header-rows: 1
+
+   * - API
+     - Description
+     - API Doc Link
+   * - train
+     - A decorator that wraps the training logic.
+     - :func:`train`
+   * - evaluate
+     - A decorator that wraps the evaluate logic.
+     - :func:`evaluate`
+
+.. list-table:: Lightning APIs
+   :widths: 25 25 50
+   :header-rows: 1
+
+   * - API
+     - Description
+     - API Doc Link
+   * - patch
+     - Patches the PyTorch Lightning Trainer for usage with FLARE.
+     - :func:`patch`
+
+.. list-table:: Metrics Logger
+   :widths: 25 25 50
+   :header-rows: 1
+
+   * - API
+     - Description
+     - API Doc Link
+   * - SummaryWriter
+     - SummaryWriter mimics the usage of Tensorboard's SummaryWriter.
+     - :class:`SummaryWriter`
+   * - WandBWriter
+     - WandBWriter mimics the usage of Weights & Biases.
+     - :class:`WandBWriter`
+   * - MLflowWriter
+     - MLflowWriter mimics the usage of MLflow.
+     - :class:`MLflowWriter`
+
+Please check the Client API Module :mod:`nvflare.client.api` for more in-depth
+information about all of the Client API functionalities.
+
+If you are using PyTorch Lightning in your training code, you can check the
+Lightning API Module :mod:`nvflare.app_opt.lightning.api`.
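+
+For instance, a PyTorch Lightning script can be adapted with the ``patch`` API. The following is a
+minimal sketch; it assumes ``LitModel`` and ``datamodule`` are already defined in your training script:
+
+.. code-block:: python
+
+    import pytorch_lightning as pl
+
+    import nvflare.client.lightning as flare
+
+    model = LitModel()  # your LightningModule, assumed defined elsewhere
+    trainer = pl.Trainer(max_epochs=1)
+    flare.patch(trainer)  # patch the trainer to receive and send models via FLARE
+    trainer.fit(model, datamodule=datamodule)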
+
+
+Configuration
+=============
+
+In the config_fed_client in the FLARE app, in order to launch the training
+script we use the
+:class:`SubprocessLauncher` component.
+The defined ``script`` is invoked, and ``launch_once`` can be set to either
+launch once for the whole job, or launch a process for each task received from the server.
+
+A corresponding :class:`LauncherExecutor`
+is used as the executor to handle the tasks and perform the data exchange using the pipe.
+For the Pipe component we provide implementations of :class:`FilePipe`
+and :class:`CellPipe`.
+
+.. literalinclude:: ../../../job_templates/sag_pt/config_fed_client.conf
+
+For example configurations, take a look at the :github_nvflare_link:`job_templates `
+directory for templates using the launcher and Client API.
+
+.. note::
+   In the case that the user does not need to launch the process and instead
+   has their own existing external training system, this would involve using
+   the :ref:`3rd_party_integration`, which is based on the same underlying mechanisms.
+
+Examples
+========
+
+For examples of using the Client API with different frameworks,
+please refer to :github_nvflare_link:`examples/hello-world/ml-to-fl `.
+
+For additional examples, also take a look at the
+:github_nvflare_link:`step-by-step series `
+that use the Client API to write the
+:github_nvflare_link:`train script `.
diff --git a/docs/programming_guide/fl_clients/executor.rst b/docs/programming_guide/execution_api_type/executor.rst
similarity index 92%
rename from docs/programming_guide/fl_clients/executor.rst
rename to docs/programming_guide/execution_api_type/executor.rst
index dac7bfba81..45d4bea8d3 100644
--- a/docs/programming_guide/fl_clients/executor.rst
+++ b/docs/programming_guide/execution_api_type/executor.rst
@@ -6,9 +6,9 @@ Executor
 
 .. image:: ../../resources/Executor.png
    :height: 300px
 
-An :class:`Executor` in NVIDIA FLARE is a type of FLComponent for FL clients that has an
-``execute`` method that produces a Shareable from an input Shareable. The ``execute`` method also takes a str for
-task_name, FLContext, and abort_signal.
+An :class:`Executor` is an FLComponent for FL clients used for executing tasks,
+wherein the ``execute`` method receives and returns a Shareable object given a task name,
+``FLContext``, and ``abort_signal``.
 
 .. literalinclude:: ../../../nvflare/apis/executor.py
    :language: python
diff --git a/docs/programming_guide/fl_clients/model_learner.rst b/docs/programming_guide/execution_api_type/model_learner.rst
similarity index 98%
rename from docs/programming_guide/fl_clients/model_learner.rst
rename to docs/programming_guide/execution_api_type/model_learner.rst
index 6a80fec437..292d0e78c3 100644
--- a/docs/programming_guide/fl_clients/model_learner.rst
+++ b/docs/programming_guide/execution_api_type/model_learner.rst
@@ -197,5 +197,6 @@ More Resources
 ==============
 In addition to the :github_nvflare_link:`ModelLearner ` and :github_nvflare_link:`FLModel ` APIs, also take a look at some examples using the ModelLearner:
+
 - :github_nvflare_link:`Step-by-step ModelLearner `
-- :github_nvflare_link:`CIFAR10 ModelLearner `
+- :github_nvflare_link:`CIFAR10 ModelLearner `
diff --git a/docs/programming_guide/experiment_tracking.rst b/docs/programming_guide/experiment_tracking.rst
index 06a274ee81..c0ccbbe939 100644
--- a/docs/programming_guide/experiment_tracking.rst
+++ b/docs/programming_guide/experiment_tracking.rst
@@ -37,6 +37,13 @@ provided examples, the Receiver is on the FL server, but it could also be on
 the
 - Server-side experiment tracking also can organize different clients' results into different experiment runs so they can be easily compared side-by-side.
 
+.. note::
+
+   This page covers experiment tracking using :class:`LogWriters `,
+   which are configured and used with :ref:`executor` or :ref:`model_learner` on the FLARE-side code.
+   However, if using the Client API, please refer to :ref:`client_api` and :ref:`nvflare.client.tracking` for adding experiment tracking to your custom training code.
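+
+   As a minimal sketch of that Client API route (an illustration, not a complete script;
+   it assumes ``flare.init()`` has already been called as shown in :ref:`client_api`):
+
+   .. code-block:: python
+
+      from nvflare.client.tracking import SummaryWriter
+
+      writer = SummaryWriter()
+      for step in range(100):
+          loss = 1.0 / (step + 1)  # placeholder for a real training loss
+          writer.add_scalar("train_loss", loss, step)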
+
+
 **************************************
 Tools, Sender, LogWriter and Receivers
 **************************************
 
@@ -60,9 +67,9 @@ where the actual experiment logs are recorded.
 
 The components that receive these logs are called Receivers based on
 :class:`AnalyticsReceiver `. The receiver component leverages the experiment
 tracking tool and records the logs during the experiment run.
 
-In a normal setting, we would have pairs of sender and receivers, such as:
+In a normal setting, we would have pairs of senders and receivers, with some provided implementations in :mod:`nvflare.app_opt.tracking`:
 
- - TBWriter <-> TBReceiver
+ - TBWriter <-> TBAnalyticsReceiver
  - MLflowWriter <-> MLflowReceiver
  - WandBWriter <-> WandBReceiver
 
@@ -94,13 +101,11 @@ There are three things to consider for developing a custom experiment tracking t
 
 Data Type
 =========
 
-Currently, the supported data types are metrics, params, and text. If you require other data types, may sure you add
-the type to :class:`AnalyticsDataType `.
+Currently, the supported data types are listed in :class:`AnalyticsDataType `, and other data types can be added as needed.
 
 Writer
 ======
-
-Implement LogWriter interface with the API syntax. For each tool, we mimic the API syntax of the underlying tool,
+Implement the :class:`LogWriter ` interface with the API syntax. For each tool, we mimic the API syntax of the underlying tool,
 so users can use what they are familiar with without learning a new API.
 For example, for Tensorboard, TBWriter uses add_scalar() and add_scalars(); for MLflow, the syntax is
 log_metric(), log_metrics(), log_parameter(), and log_parameters(); for W&B, the writer just has log().
 The data collected with these calls will all send to the AnalyticsSender to deli
@@ -109,7 +114,7 @@ Receiver
 ========
 
-Implement AnalyticsReceiver interface and determine how to represent different sites' logs. In all three implementations
+Implement the :class:`AnalyticsReceiver ` interface and determine how to represent different sites' logs. In all three implementations
 (Tensorboard, MLflow, WandB), each site's log is represented as one run. Depending on the individual tool, the
 implementation can be different. For example, for both Tensorboard and MLflow, we create different runs for each client
 and map to the site name. In the WandB implementation, we have to leverage multiprocess and let each run in a different process.
 
@@ -121,13 +126,19 @@ Examples Overview
 
 The :github_nvflare_link:`experiment tracking examples ` illustrate how to leverage different writers and
 receivers. All examples are based upon the hello-pt example.
 
+TensorBoard
+===========
 The example in the "tensorboard" directory shows how to use the Tensorboard Tracking Tool (for both the
 sender and receiver). See :ref:`tensorboard_streaming` for details.
 
+MLflow
+======
 Under the "mlflow" directory, the "hello-pt-mlflow" job shows how to use MLflow for tracking with both the MLflow
 sender and receiver. The "hello-pt-tb-mlflow" job shows how to use the Tensorboard Sender, while the receiver is MLflow.
 See :ref:`experiment_tracking_mlflow` for details.
 
+Weights & Biases
+================
 Under the :github_nvflare_link:`wandb ` directory, the "hello-pt-wandb" job shows how to use
 Weights and Biases for experiment tracking with the WandBWriter and WandBReceiver to log metrics.
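+
+For instance, pairing TBWriter with TBAnalyticsReceiver means adding the receiver as a component in
+config_fed_server. The snippet below is only an illustrative sketch: the component id is arbitrary,
+and the class path and the default ``fed.analytix_log_stats`` event name should be verified against the
+:mod:`nvflare.app_opt.tracking` API docs and the provided example jobs.
+
+.. code-block:: json
+
+    {
+      "id": "tb_analytics_receiver",
+      "path": "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver",
+      "args": {"events": ["fed.analytix_log_stats"]}
+    }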
diff --git a/docs/programming_guide/fl_clients.rst b/docs/programming_guide/fl_clients.rst deleted file mode 100644 index 754d787588..0000000000 --- a/docs/programming_guide/fl_clients.rst +++ /dev/null @@ -1,52 +0,0 @@ -.. _fl_clients: - -########## -FL Clients -########## - -FLARE Clients are workers in the FL system that perform tasks. -We provide different levels of abstraction for writing FL Client code to support use cases ranging from complete customizability to easy user adaption. -Here is a general overview of the key ideas and use cases of each FL Client type ordered from most FLARE-specific to least FLARE-specific: - -**Executor** - -An :ref:`executor` is an FLComponent for clients used for executing tasks, wherein the execute method receives and returns a Shareable object given a task name. -Executors are the most flexible for defining custom logic and tasks, as with a custom executor and controller, any form of computation can be performed. -However, Executors must deal directly with FLARE-specific communication concepts such as :class:`Shareable`, :class:`DXO`, and :class:`FLContext`. -As a result, many higher level APIs are built on top of Executors in order to abstract these concepts away for easier user adaption. - -Overall, writing an Executor is most useful when implementing tasks and logic that do not fit within the structure of higher level APIs or other predefined Executors. - -**Model Learner** - -The :ref:`model_learner` is designed to simplify writing learning logic by minimizing FLARE specific concepts. -The :class:`ModelLearner` defines familiar learning functions for training and validation, and uses the :class:`FLModel` object for transferring learning information. -The ModelLearner also contains serveral convenience capabilities, such as lifecycle and logging information. - -The Model Learner is best used when working with standard machine learning code that can fit well into the train and validate methods and can be easily adapated to the ModelLearner subclass and method structure. - -**Client API** - -The :ref:`client_api` provides the most straightforward way to write FL code, and can easily be used to convert centralized code with minimal code changes. -The client API uses the :class:`FLModel` object for data transfer, and supports common tasks such as train, validate, and submit_model. -Additionally, options for using decorators or PyTorch Lightning are also available. - -As of version 2.4.0, we recommend users start with the Client API, and to consider the other Client types for more specific cases as required. - -**3rd-Party System Integration** - -The :ref:`3rd_party_integration` pattern allows for a seamless integration between the FLARE system and a third-party external training system. -This is especially useful with pre-existing ML/DL training system infrastructure that cannot be easily adapted to the FLARE client. - -With the use of the :mod:`FlareAgent ` and :mod:`TaskExchanger `, users can easily enable any 3rd-party system to receive tasks and submit results back to the server. - - -For more details about each client type, refer to each page below. - -.. 
toctree:: - :maxdepth: 3 - - fl_clients/executor - fl_clients/model_learner - fl_clients/client_api - fl_clients/3rd_party_integration \ No newline at end of file diff --git a/docs/programming_guide/fl_clients/3rd_party_integration.rst b/docs/programming_guide/fl_clients/3rd_party_integration.rst deleted file mode 100644 index 4c83823f3b..0000000000 --- a/docs/programming_guide/fl_clients/3rd_party_integration.rst +++ /dev/null @@ -1,329 +0,0 @@ -.. _3rd_party_integration: - -############################ -3rd-Party System Integration -############################ - -NVFLARE 2.4.0 supports 3rd-party external systems to integrate with FL clients. - -The FL Client installs the :mod:`TaskExchanger` executor and -the 3rd-party system uses the :mod:`FlareAgent` to interact with the TaskExchanger to receive tasks, and submit results to the FLARE server. - -This integration pattern is illustrated in the diagram below: - -.. image:: ../resources/3rd_party_integration_diagram.png - :height: 400px - -Requirements -============ - -- The key to enabling this integration is the "agent_id" that must be made known to both systems. - The FL client gets this information from the job's client_config.json, and the 3rd-party trainer gets this from its own launch process or via the :class:`Piper`. -- It is assumed that the customer already has a way to dynamically generate the "agent_id" for each job, and start its trainer process with this information. -- Each FL client must be able to open an address (host:port) to allow the trainer to connect to. Depending on where the trainer is running, the connection may or may not need to be in secure mode (TLS). -- The trainer must be a Python program that can integrate with the NVFLARE library. -- The trainer must be able to connect to the server, as well as the address that is dynamically opened by the FL client. - -Prepare the Trainer -=================== - -You need to modify your trainer code to integrate with the FlareAgent API. -This API provides simple `get_task()` and `submit_result()` methods to interact with the FL client (FL client). -The following is an example of this usage pattern. - -.. code-block:: python - - import argparse - import logging - - from nvflare.client.defs import RC, AgentClosed, MetaKey - from nvflare.client.flare_agent import FlareAgentWithCellPipe - - NUMPY_KEY = "numpy_key" - - - def main(): - - logging.basicConfig() - logging.getLogger().setLevel(logging.INFO) - - parser = argparse.ArgumentParser() - parser.add_argument("--workspace", "-w", type=str, help="workspace folder", required=False, default=".") - parser.add_argument("--site_name", "-s", type=str, help="flare site name", required=True) - parser.add_argument("--agent_id", "-a", type=str, help="agent id", required=True) - - args = parser.parse_args() - - agent = FlareAgentWithCellPipe( - root_url="grpc://server:8002", - flare_site_name=args.site_name, - agent_id=args.agent_id, - workspace_dir=args.workspace, - secure_mode=True, - submit_result_timeout=2.0, - heartbeat_timeout=120.0, - ) - - agent.start() - - while True: - print("getting task ...") - try: - task = agent.get_task() - except AgentClosed: - print("agent closed - exit") - break - - print(f"got task: {task}") - rc, meta, result = train(task.data) # peform train task - submitted = agent.submit_result(TaskResult(data=result, meta=meta, return_code=rc)) - print(f"result submitted: {submitted}") - - agent.stop() - - - def train(model): - ... 
- - if __name__ == "__main__": - main() - -Create the Agent ---------------- - -The :class:`FlareAgent` is responsible for interacting with the FL client to exchange task data, and takes the following parameters: - -- ``pipe`` - component id of pipe for communication -- ``read_interval`` - how often to read from pipe -- ``heartbeat_interval`` - how often to send heartbeat to peer -- ``heartbeat_timeout`` - max amount of time to allow missing heartbeats before treating peer as dead -- ``resend_interval`` - how often to resend a message when failing to send -- ``max_resends`` - max number of resends. None means no limit -- ``submit_result_timeout`` - when submitting task result, how long to wait for response from the FL client -- ``metric_pipe`` - component id of pipe for metrics -- ``task_channel_name`` - the channel name for tasks (defaults to PipeChannelName.TASK) -- ``metric_channel_name`` - the channel name for metrics (defaults to PipeChannelName.METRIC) -- ``close_pipe`` - whether pipe needs to be closed (FilePipe: False, CellPipe: True) - -If using FLModel, :class:`FlareAgentWithFLModel` subclasses FlareAgent and provides conversion from shareables to tasks using the FLModel data structure. - -If using CellPipe, then :class:`FlareAgentWithCellPipe` subclasses FlareAgent and takes the parameters: - -- ``agent_id`` - this is the ID of the agent dynamically generated by your launch system -- ``site_name`` - this is the name of the FL client provisioned for the project -- ``root_url`` - this is the URL of the server. -- ``secure_mode`` - whether the trainer/FL client communication will be in secure mode (SSL) -- ``workspace_dir`` - this is the local folder that contains the "startup" kit of the FL client site. The trainer system and the FL client must share the same "startup" content. - -Start the Agent --------------- - -Call ``agent.start()`` to start the agent. This call must be made before trying to get tasks. - -Process Tasks ------------- - -The training is a continuous process of getting a task, executing the task, and submitting the task result. - -Call ``agent.get_task()`` to get a Task object from the FL client. This is a blocking call and returns only when a task is available. -If there are no more tasks available (i.e. end of the job), an ``AgentClosed`` exception will be raised, signaling the end of the training. - -The :class:`Task` object contains 3 pieces of information: task_name, task_id, and data. -The task_name tells you what the task is (e.g. train). The task_id is a UUID of the task instance. -The data contains model data to be trained on. - -Once the task is completed, the result can be submitted to the FL client by calling ``agent.submit_result()``. -A return code (``rc``) must be provided to indicate whether the task was executed successfully. -If the ``rc`` is not RC.OK, then the job will be aborted. - -Stop Agent ---------- - -At the end of the training, ``agent.stop()`` must be called to end the program gracefully. -If this call is missed, the program may not exit properly. - -Notes: - -- This pattern of (``start``, ``get_task``, ``submit_result``, and ``stop``) is strictly enforced. - If the pattern is not followed (e.g. ``get_task``, then ``get_task`` again without ``submit_result``), you will get a ``CallStateError`` exception. - The only way to know that the job is ended is the ``AgentClosed`` exception from the ``get_task`` call. 
- This exception is raised when the FL client tells the agent that the job is done, or when the FL client is considered dead (missing heartbeats for the configured period of time). -- If your training algorithm runs into an unrecoverable error and wants to end the job, you should use a proper return code (e.g. ``RC.EXECUTION_EXCEPTION``). - -Project Setup -============= - -The following steps show you how to properly set up your project and jobs. - -Step One - Provision -------------------- - -From the perspective of the trainer, the FL client behaves as both a client and a server for connections. -This requires the client site to have two sets of TLS credentials. -Make sure to specify the "listening_host" for the client in the project.yml when provisioning the project: - -.. code-block:: yaml - - participants: - # change example.com to the FQDN of the server - - name: server - type: server - org: nvidia - fed_learn_port: 8002 - admin_port: 8003 - - name: site_1 - type: client - org: nvidia - listening_host: site_1.maglev.nvidia.com - - name: site_2 - type: client - org: nvidia - listening_host: site_2.maglev.nvidia.com - -Once the project is provisioned, check the "startup" kit generated for the clients. You should see the following files, among others: - -client.crt, client.key, server.crt, server.key, rootCA.pem - -Note that the specified listening_port of a site must be accessible to the trainer of the site. - -Step Two - Setup for adhoc direct connection between FL Client and Trainer -------------------------------------------------------------------------- - -The FL client and the trainer can always talk to each other via the server, but this could be slow, especially if the server is located far away. -To enable adhoc direct connections between the FL client and the Trainer, configure the comm_config.json on the client site as follows: - -.. code-block:: json - - { - "allow_adhoc_conns": true, - "use_aio_grpc": true, - "adhoc": { - "scheme": "tcp", - "resources": { - "host": "nvclient", - "secure": true - } - } - } - -This file must be placed into the site's "local" folder within its workspace. - -Pay attention to the following: - -- For most cases, the "scheme" should be set to "tcp" to get the best performance. If "tcp" cannot be used, you can use "grpc". -- In "resources": - - - If the FL client and the Trainer are within the same trusted network, you can set "secure" to false; otherwise set it to true. - - The value of the "host" must match the "listening_host" value of the site used in provision. - -Step Three - Prepare job configuration -------------------------------------- - -For each job, configure the config_fed_client.json to use :mod:`TaskExchanger` as the executor. - -.. code-block:: json - - { - "format_version": 2, - "executors": [ - { - "tasks": [ - "train" - ], - "executor": { - "path": "nvflare.app_common.executors.task_exchanger.TaskExchanger", - "args": { - "pipe_id": "pipe", - "peer_read_timeout": 30, - "heartbeat_timeout": 60 - } - } - } - ], - "task_result_filters": [], - "task_data_filters": [], - "components": [ - ... 
- ] - } - -Make sure the parameters of the TaskExchanger are configured properly, and change the default values as needed: - -- ``pipe_id`` - component id of pipe -- ``read_interval`` - how often to read from pipe -- ``heartbeat_interval`` - how often to send heartbeat to peer -- ``heartbeat_timeout`` - max amount of time to allow missing heartbeats before treating peer as dead -- ``resend_interval`` - how often to resend a message when failing to send -- ``max_resends`` - max number of resends. None means no limit -- ``peer_read_timeout`` - time to wait for peer to accept sent message -- ``task_wait_time`` - how long to wait for a task to complete. None means waiting forever -- ``result_poll_interval`` - how often to poll task result -- ``pipe_channel_name`` - the channel name for sending task requests - -Step Four - Trainer Setup ------------------------- - -The trainer program must have access to a local file system, and you must create a "workspace" folder. This workspace should be used for all jobs. - -Copy the "startup" folder of the provisioned site, and put it in the designated workspace folder. -If needed, any additional config files required by the trainer can also be placed in the workspace folder. - -Ensure that the FlareAgent's "workspace_dir" is set to the workspace folder and that the correct "agent_id" value is passed to both the FL client and the training process. - -Verification -============ - -The FL client (TaskExchanger) and your trainer process (FlareAgent) do not have to be started at exactly the same time. Whichever is started first will wait for the other for ``heartbeat_timeout`` seconds. -Once both are started and connected, you can verify they are directly connected using the Admin's cell commands. - -The following example shows two clients (red, blue) connected to their external trainers via the agent_id "ext_trainer_1": - -.. code-block:: shell - - > cells - server - server.44c08365-e829-4bc1-a034-cda5a252fe73 - red - red.44c08365-e829-4bc1-a034-cda5a252fe73 - blue - blue.44c08365-e829-4bc1-a034-cda5a252fe73 - red--ext_trainer_1 - blue--ext_trainer_1 - Total Cells: 8 - Done [21695 usecs] 2023-10-16 19:28:37.523651 - -The ``cells`` command lists all cells. Notice that the job 44c08365-e829-4bc1-a034-cda5a252fe73 is running on both "blue" and "red" clients. -Also notice that there are two corresponding ext_trainer cells (red--ext_trainer_1 and blue--ext_trainer_1). - -.. code-block:: shell - - > peers blue--ext_trainer_1 - server - blue.44c08365-e829-4bc1-a034-cda5a252fe73 - Total Agents: 2 - Done [14526 usecs] 2023-10-16 19:28:44.407505 - -The ``peers`` command shows the cells directly connected to the specified cell. -Here you see that blue--ext_trainer_1 is directly connected to two cells: the server and the FL client (blue.44c08365-e829-4bc1-a034-cda5a252fe73). - -.. code-block:: shell - - > conns blue--ext_trainer_1 - { - "bb_ext_connector": { - "url": "grpc://server:8002", - "handle": "CH00001", - "type": "connector" - }, - "adhoc_connectors": { - "blue.44c08365-e829-4bc1-a034-cda5a252fe73": { - "url": "stcp://nvclient:11947", - "handle": "CH00002", - "type": "connector" - } - } - } - -The ``conns`` command shows the connectors on the specified cell. Here you see that blue--ext_trainer_1 has two connectors: -one connects to the server at ``grpc://server:8002``, and the other connects to ``blue.44c08365-e829-4bc1-a034-cda5a252fe73`` at ``stcp://nvclient:11947``. -Note that this port is opened by the FL client dynamically. 
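As a rough illustration of the get_task/submit_result pattern described above, a minimal sketch of a trainer loop that also reports an unrecoverable training error via the return code could look like the following (a sketch under stated assumptions: ``TaskResult`` is assumed to be importable alongside the agent classes shown earlier, and ``train_fn`` is a hypothetical user-supplied training function):

.. code-block:: python

    from nvflare.client.defs import RC
    from nvflare.client.flare_agent import AgentClosed, TaskResult

    def run_tasks(agent, train_fn):
        # keep asking the FL client for tasks until the job ends
        while True:
            try:
                task = agent.get_task()
            except AgentClosed:
                break  # job is done, or the FL client is considered dead
            try:
                result = train_fn(task.data)  # train_fn: hypothetical user-supplied training logic
                rc = RC.OK
            except Exception:
                # a non-OK return code tells the FL system to abort the job
                result, rc = {}, RC.EXECUTION_EXCEPTION
            agent.submit_result(TaskResult(data=result, return_code=rc))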
diff --git a/docs/programming_guide/fl_clients/client_api.rst b/docs/programming_guide/fl_clients/client_api.rst deleted file mode 100644 index ebfe051c4e..0000000000 --- a/docs/programming_guide/fl_clients/client_api.rst +++ /dev/null @@ -1,349 +0,0 @@ -.. _client_api: - -########## -Client API -########## - -The FLARE Client API provides an easy way for users to convert their centralized, local -training code into federated learning code with the following benefits: - -* Only requires a few lines of code changes, without the need to restructure the code or implement a new class -* Reduces the number of new FLARE specific concepts exposed to users -* Easy adaptation from existing local training code using different frameworks (PyTorch, PyTorch Lightning, HuggingFace) - -Core concept -============ - -The idea of federated learning is for each participating site to get a good model (better than -a locally trained model) without sharing its data. - -This is done by sites sharing model parameters or parameter differences with each other (certain filters can be used to -preserve privacy and protect against gradient inversion attacks). - -The aggregators will take in all the model parameters submitted by each site and produce a -new global model. - -We hope that this new global model will be better than any locally trained model since it -is conceptually trained on more data. - -One popular federated learning workflow, "FedAvg", works like this: - -The general structure of Federated Learning algorithms involves the following steps: - -#. controller site initializes an initial model -#. For each round (global iteration): - - #. controller sends the global model to clients - #. each client starts with this global model and trains on their own data - #. each client sends back their trained model - #. controller aggregates all the models and produces a new global model - -On the client side, the training workflow is: - -#. receive model from controller -#. perform local training on the received model, evaluate the global model for model selection -#. send the new model back to the controller - -To be able to support different training frameworks, we define a standard data structure called "FLModel" -for the local training code to exchange information with the FLARE system. - -We explain its attributes below: - -.. literalinclude:: ../../../nvflare/app_common/abstract/fl_model.py - :language: python - :lines: 41-67 - :linenos: - :caption: fl_model.py - -Users only need to obtain the required information from the received FLModel, -run local training, and put the results in a new FLModel to send back to the controller. - -For a general use case, there are three essential methods for the Client API: - -* `init()`: Initializes NVFlare Client API environment. -* `receive()`: Receives model from NVFlare side. -* `send()`: Sends the model to NVFlare side. - -Users can use these APIs to change their centralized training code to federated learning, for example: - -.. code-block:: python - - import nvflare.client as flare - - flare.init() - input_model = flare.receive() - new_params = local_train(input_model.params) - output_model = flare.FLModel(params=new_params) - flare.send(output_model) - -See below for more in-depth information about all of the Client API functionalities. 
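The client-side steps listed above map onto these methods almost one-to-one. As a slightly fuller sketch (an illustration under stated assumptions: ``evaluate`` and ``local_train`` are hypothetical user-supplied helpers), the same pattern can also report a global-model evaluation metric for model selection:

.. code-block:: python

    import nvflare.client as flare

    flare.init()
    input_model = flare.receive()                 # 1. receive the global model from the controller
    accuracy = evaluate(input_model.params)       # 2. evaluate the global model for model selection,
    new_params = local_train(input_model.params)  #    then perform local training on it
    output_model = flare.FLModel(
        params=new_params,
        metrics={"accuracy": accuracy},           # metrics travel back with the trained model
    )
    flare.send(output_model)                      # 3. send the new model back to the controller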
- -Client API Module -================= - -nvflare.client.init ------------------- - -- Description: initialize required environment variables for NVFlare ML2FL client API -- Arguments: - - - config (str or dict): the path to the config file or the config dictionary - - rank (str): local rank of the process. It is only useful when the training script has multiple worker processes. (for example, multi-GPU) - -- Returns: None - -Usage: - -``nvflare.client.init(config="./config.json")`` - -Config example: - -.. code-block:: json - - { - "exchange_path": "./", - "exchange_format": "pytorch", - "transfer_type": "FULL" - } - -``exchange_path`` is the file path where the model will be exchanged. -``exchange_format`` is the expected format of the model; the pre-defined formats are "raw", "numpy", and "pytorch". -``transfer_type`` is how to transfer the model: FULL means send the model as it is, while DIFF means send the difference between the new model and the initially received model. - -nvflare.client.receive ---------------------- -- Description: receive FLModel from NVFlare side -- Arguments: - - - timeout (Optional[float]): timeout to receive an FLModel - -- Returns: FLModel - -Usage: - -``model = nvflare.client.receive()`` - -nvflare.client.send ------------------- - -- Description: send back the FLModel to NVFlare side -- Arguments: - - - fl_model (FLModel): FLModel to be sent - - clear_registry (bool): whether to clear the model registry after send - -- Returns: None - -Usage: - -``nvflare.client.send(model=FLModel(xxx))`` - - -nvflare.client.system_info ------------------------- - -- Description: gets system's metadata -- Arguments: None -- Returns: A dictionary of system's metadata - -Usage: - -``sys_info = nvflare.client.system_info()`` - -System's metadata includes: - -- identity -- job_id - -nvflare.client.get_job_id ------------------------- - -- Description: gets the NVFlare job id -- Arguments: None -- Returns: JOB_ID (str) - -Usage: - -``job_id = nvflare.client.get_job_id()`` - -nvflare.client.get_identity -------------------------- -- Description: gets the NVFlare site name that this process is running on -- Arguments: None -- Returns: identity (str) - -Usage: - -``identity = nvflare.client.get_identity()`` - -nvflare.client.clear ------------------- - -- Description: clears the model registry -- Arguments: None -- Returns: None - -Usage: - -``nvflare.client.clear()`` - -nvflare.client.get_config ------------------------- - -- Description: gets the model registry config -- Arguments: None -- Returns: config (dict) - -Usage: - -``config = nvflare.client.get_config()`` - -nvflare.client.is_running ------------------------- - -- Description: check if the FLARE job is still running (in the case of launching once) -- Arguments: None -- Returns: bool - -Usage: - -.. code-block:: python - - while nvflare.client.is_running(): - # receive model, perform task, send model, etc. - -nvflare.client.is_train ---------------------- - -- Description: check if the current task is train -- Arguments: None -- Returns: bool - -Usage: - -.. code-block:: python - - if nvflare.client.is_train(): - # perform train task on received model - -nvflare.client.is_evaluate -------------------------- - -- Description: check if the current task is evaluate -- Arguments: None -- Returns: bool - -Usage: - -.. 
code-block:: python - - if nvflare.client.is_evaluate(): - # perform evaluate task on received model - -nvflare.client.is_submit_model ------------------------------ - -- Description: check if the current task is submit_model -- Arguments: None -- Returns: bool - -Usage: - -.. code-block:: python - - if nvflare.client.is_submit_model(): - # perform submit_model task to obtain best local model - -Client Decorator Module -======================= -nvflare.client.train ------------------- - -Use cases: - -.. code-block:: python - - @nvflare.client.train - def my_train(input_model=None, device="cuda:0"): - ... - return new_model - -NVFlare will pass the FLModel received from the NVFlare server side to the first argument of the "decorated" method. -The return value needs to be an FLModel object; it will be sent directly to the NVFlare server side. - - -nvflare.client.evaluate ---------------------- - -Use cases: - -.. code-block:: python - - @nvflare.client.evaluate - def my_eval(input_model, device="cuda:0"): - ... - return metrics - -NVFlare will pass the model received from the NVFlare server side to the first argument of the "decorated" method. -The return value needs to be a "float" metric. -The decorated "my_eval" method needs to be run BEFORE the training method, so that the metrics will be sent along with the trained output model. - -Lightning Integration -===================== -nvflare.client.lightning.patch ----------------------------- - -- Description: patch the PyTorch Lightning Trainer object -- Arguments: trainer -- Returns: None - -Usage: - -.. code-block:: python - - trainer = Trainer(max_epochs=1) - flare.patch(trainer) - -Advanced Usage: - -Note that if users want to pass additional information to the NVFlare server side via the Lightning API, they will need to set the information inside the attributes called ``__fl_meta__`` in their LightningModule. For example: - -.. code-block:: python - - class LitNet(LightningModule): - def __init__(self): - super().__init__() - self.save_hyperparameters() - self.model = Net() - self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES) - self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES) - self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"} - -Configuration and Installation -============================== - -In the client_config.json, the :class:`SubprocessLauncher` component is used to launch the training script. -The defined ``script`` is invoked, and ``launch_once`` can be set to either launch once for the whole job, or launch a process for each task received from the server. - -A corresponding :class:`LauncherExecutor` is used as the executor to handle the tasks and perform the data exchange using the pipe. -For the Pipe component we provide implementations of :class:`FilePipe` and :class:`CellPipe`. - -.. literalinclude:: ../../../job_templates/sag_pt/config_fed_client.conf - :language: json - -For example configurations, take a look at the :github_nvflare_link:`job_templates ` directory for templates using the launcher and Client API. - -.. note:: - In the case that the user does not need to launch the process via the SubprocessLauncher and instead has their own external training system, this would involve using - the :ref:`3rd_party_integration`, which is based on the same underlying mechanisms. - Rather than a LauncherExecutor, the parent class :class:`TaskExchanger` would be used to handle the tasks and enable pipe data exchange. 
- Additionally, the :class:`FlareAgent` would be used to communicate with the Flare Client Job Cell to get the tasks and submit the results. - -Examples -======== - -For examples of using Client API with different frameworks, -please refer to :github_nvflare_link:`examples/hello-world/ml-to-fl `. - -For additional examples, also take a look at the :github_nvflare_link:`step-by-step series ` -that use a :github_nvflare_link:`Client API trainer `. diff --git a/docs/programming_guide/fl_model.rst b/docs/programming_guide/fl_model.rst new file mode 100644 index 0000000000..6b2a9bad07 --- /dev/null +++ b/docs/programming_guide/fl_model.rst @@ -0,0 +1,17 @@ +.. _fl_model: + +FLModel +======= + +We define a standard data structure :mod:`FLModel` +that captures the common attributes needed for exchanging learning results. + +This is particularly useful when the NVFlare system needs to exchange learning +information with external training scripts/systems. + +The external training script/system only needs to extract the required +information from the received FLModel, run local training, and put the results +in a new FLModel to be sent back. + +For a detailed explanation of each attribute, please refer to the API doc: +:mod:`FLModel` diff --git a/docs/programming_guide/resources/3rd_party_integration_diagram.png b/docs/programming_guide/resources/3rd_party_integration_diagram.png deleted file mode 100644 index 43de0dbaa7..0000000000 Binary files a/docs/programming_guide/resources/3rd_party_integration_diagram.png and /dev/null differ diff --git a/docs/programming_guide/resources/te.py b/docs/programming_guide/resources/te.py deleted file mode 100644 index 3f511a6c34..0000000000 --- a/docs/programming_guide/resources/te.py +++ /dev/null @@ -1,9 +0,0 @@ -def _get_model_weights(self) -> Shareable: - # Get state dict and send as weights - new_weights = self.model.state_dict() - new_weights = {k: v.cpu().numpy() for k, v in new_weights.items()} - - outgoing_dxo = DXO( - data_kind=DataKind.WEIGHTS, data=new_weights, meta={MetaKey.NUM_STEPS_CURRENT_ROUND: self._n_iterations} - ) - return outgoing_dxo.to_shareable() diff --git a/docs/programming_guide/workflows_and_controllers.rst b/docs/programming_guide/workflows_and_controllers.rst index 052327365e..9a75c9901d 100644 --- a/docs/programming_guide/workflows_and_controllers.rst +++ b/docs/programming_guide/workflows_and_controllers.rst @@ -7,12 +7,13 @@ A workflow has one or more controllers, each implementing a specific coordinatio CrossSiteValidation controller implements a strategy to let every client site evaluate every other site's model. You can put together a workflow that uses any number of controllers. -Before version 2.4, all federating learning workflows (fed-average, cyclic controller, cross-site evaluation) were server controlled, -implemented with the server-side :ref:`controllers `. In these workflows, -FL clients get tasks assigned by the controller, execute the tasks, -and submit results back to the server. The first section covers the server-side -controller API for server-controlled workflows. The second section covers :ref:`client_controlled_workflows` for -workflows that are controlled by the clients. +We have implemented several server-controlled federated learning workflows (fed-average, cyclic controller, cross-site evaluation) with the server-side :ref:`controllers `. +In these workflows, FL clients get tasks assigned by the controller, execute the tasks, and submit results back to the server. 
+ +In certain cases, if the server cannot be trusted, it should not be involved in the communication of sensitive information. +To address this concern, NVFlare introduces Client Controlled Workflows (CCWF) to facilitate peer-to-peer communication among clients. + +Please refer to the following sections for more details. .. toctree:: :maxdepth: 3 diff --git a/docs/release_notes/flare_230.rst b/docs/release_notes/flare_230.rst index a042e43212..6721de4b86 100644 --- a/docs/release_notes/flare_230.rst +++ b/docs/release_notes/flare_230.rst @@ -41,7 +41,7 @@ Prior to FLARE 2.3.0, model initialization was performed on the server-side. The model was either initialized from a model file or custom model initiation code. Pre-defining a model file required extra steps of pre-generating and saving the model file and then sending it over to the server. Running custom model initialization code on the server could be a security risk. -FLARE 2.3.0 introuduces another way to initialize the model on the client side. The FL Server can select +FLARE 2.3.0 introduces another way to initialize the model on the client side. The FL Server can select the initial model based on a user-chosen strategy. Here is an example using client-side model initialization: https://github.com/NVIDIA/NVFlare/tree/main/examples/hello-world/hello-pt. You can read more about this feature in :ref:`initialize_global_weights_workflow`. @@ -67,7 +67,7 @@ Federated Private Set Intersection (PSI) In order to support vertical learning use cases such as secure user-id matching and feature over-lapping discovery, we have developed a multi-party private set intersection (PSI) operator that allows for the secure discovery of data intersections. Our approach leverages OpenMined's two-party -`Private Set Intersection Cardinality protocol `_, which is basedon ECDH and Bloom Filters, and we have +`Private Set Intersection Cardinality protocol `_, which is based on ECDH and Bloom Filters, and we have made this protocol available for multi-party use. More information on our approach and how to use the PSI operator can be found in the :github_nvflare_link:`PSI Example `. diff --git a/docs/release_notes/flare_240.rst b/docs/release_notes/flare_240.rst index 80eb65d067..4741b386f0 100644 --- a/docs/release_notes/flare_240.rst +++ b/docs/release_notes/flare_240.rst @@ -23,7 +23,7 @@ Here is a brief example of a common pattern when using the Client API for a clie # initialize NVFlare client API flare.init() - # run continously when launching once + # run continuously when launching once while flare.is_running(): # receive FLModel from NVFlare @@ -63,12 +63,12 @@ Furthermore, the Job CLI also offers users a convenient method for submitting jo ``nvflare job list_templates|create|submit|show_variables`` -Also explore the continously growing :github_nvflare_link:`Job Template directory ` we have created for commonly used configurations. +Also explore the continuously growing :github_nvflare_link:`Job Template directory ` we have created for commonly used configurations. For more in-depth information on Job Templates and the Job CLI, refer to the :ref:`job_cli` documentation and :github_nvflare_link:`tutorials `. ModelLearner ------------ -The ModelLearner is introduced for a simplifed user experience in cases requiring a Learner-pattern. +The ModelLearner is introduced for a simplified user experience in cases requiring a Learner-pattern. 
Users exclusively interact with the FLModel object, which includes weights, optimizer, metrics, and metadata, while FLARE-specific concepts remain hidden from users. The ModelLearner defines standard learning functions, such as ``train()``, ``validate()``, and ``submit_model()``, that can be overridden in a subclass for easy adaptation (see the sketch below). @@ -83,20 +83,21 @@ Each example will build upon previous ones to showcase different features, workf **CIFAR10 Examples:** -- stats: federated statistics (histograms) of CIFAR10. +- image_stats: federated statistics (histograms) of CIFAR10. - sag: scatter and gather (SAG) workflow with PyTorch with Client API. -- sag_with_deploy_map: scatter and gather workflow with deploy_map configuration, for deployment of apps to different sites using the Client API. -- cse: cross-site evaluation using the Client API. +- sag_deploy_map: scatter and gather workflow with deploy_map configuration, for deployment of apps to different sites using the Client API. - sag_model_learner: scatter and gather workflow illustrating how to write client code using the ModelLearner. - sag_executor: scatter and gather workflow demonstrating how to write client-side executors. +- sag_mlflow: MLflow experiment tracking logs with the Client API in scatter & gather workflows. +- sag_he: homomorphic encryption using Client API and POC -he mode. +- cse: cross-site evaluation using the Client API. - cyclic: cyclic weight transfer workflow with server-side controller. - cyclic_ccwf: client-controlled cyclic weight transfer workflow with client-side controller. - swarm: swarm learning and client-side cross-site evaluation with Client API. -- sag_with_mlflow: MLflow experiment tracking logs with the Client API in scatter & gather workflows. **HIGGS Examples:** -- tabular_stats: federated stats tabular histogram calculation. +- tabular_stats: federated statistics tabular histogram calculation. - scikit_learn: federated linear model (logistic regression on binary classification) learning on tabular data. - sklearn_svm: federated SVM model learning on tabular data. - sklearn_kmeans: federated k-Means clustering on tabular data. @@ -132,7 +133,7 @@ Client-side controlled workflow Three commonly used types of client-side controlled workflows are provided: - :ref:`ccwf_cyclic_learning`: the model is passed from client to client. -- :ref:`ccwf_swarm_learning`: randomly select clients as client-side controller and aggregrators, where then Scatter and Gather with FedAvg is performed. +- :ref:`ccwf_swarm_learning`: randomly select clients as client-side controllers and aggregators, after which Scatter and Gather with FedAvg is performed. - :ref:`ccwf_cross_site_evaluation`: allow clients to evaluate other sites' models. See :github_nvflare_link:`swarm learning ` and :github_nvflare_link:`client-controlled cyclic ` for examples using these client-controlled workflows. @@ -167,7 +168,7 @@ Improved Job Configuration File Processing - OS Environment Variables - OS environment variables can be referenced via the dollar sign - Parameterized Variable Definition - for creating configuration templates that can be reused and resolved into different concrete configurations -See more details in the :ref:`configurations` documentation. +See more details in the :ref:`configurations` documentation. 
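For the ModelLearner mentioned above, a rough subclassing sketch follows (an illustration under stated assumptions: the FLModel-in/FLModel-out signatures are assumed from the description in these notes, and ``my_train``/``my_validate`` are hypothetical user helpers):

.. code-block:: python

    from nvflare.app_common.abstract.fl_model import FLModel
    from nvflare.app_common.abstract.model_learner import ModelLearner

    class MyModelLearner(ModelLearner):
        def train(self, model: FLModel) -> FLModel:
            new_params = my_train(model.params)   # my_train: hypothetical local-training helper
            return FLModel(params=new_params)

        def validate(self, model: FLModel) -> FLModel:
            acc = my_validate(model.params)       # my_validate: hypothetical evaluation helper
            return FLModel(metrics={"accuracy": acc})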
POC Command Upgrade =================== @@ -201,13 +202,13 @@ FL HUB: Hierarchical Unification Bridge ======================================= The FL HUB is a new experimental feature designed to support multiple FLARE systems working together in a hierarchical manner. In Federated Computing, the number of edge devices is usually large with often just a single server, which can cause performance issues. -A solution to this problem is to use a hierachical FLARE system, where tiered FLARE systems connect together to form a tree-like structure. +A solution to this problem is to use a hierarchical FLARE system, where tiered FLARE systems connect together to form a tree-like structure. Each leaf tier of clients (edge devices) connects only to its server, and this server also serves as the client for the parent-tier FLARE system. One potential use case is with global studies, where client machines may be located across different regions. Rather than requiring every region's client machines to connect to only a single FL server, the FL HUB could enable a more performant tiered multi-server setup. -Learn more about the FL Hub in the :ref:`hierarchy_unification_bridge` documenation and the :github_nvflare_link:`code `. +Learn more about the FL Hub in the :ref:`hierarchy_unification_bridge` documentation and the :github_nvflare_link:`code `. Misc. Features ============== @@ -235,7 +236,7 @@ Misc. Features - We added the application-layer ping between the Client Job process and the Server parent process to replace the gRPC timeout. Previously, we noticed that if the gRPC timeout is set too long, the cloud provider (e.g. Azure Cloud) will kill the connection after 4 minutes. - If the timeout setup is too short (such as 2 mins), the underlying gRPC will report too many pings. + If the timeout setup is too short (such as 2 minutes), the underlying gRPC will report too many pings. The application-level ping avoids both issues and makes sure the server/client is aware of the status of the processes. - FLARE provides two drivers for gRPC-based communication: asyncio (AIO) and regular (non-AIO) versions of the gRPC library. One notable benefit of the AIO gRPC is its ability to handle many more concurrent connections on the server side. @@ -277,7 +278,7 @@ For this financial application, we use the `Elliptic++ `_. -Finanical Application Examples ------------------------------ +Financial Application Examples ------------------------------ To demonstrate how to perform Fraud Detection in financial applications, we introduced an :github_nvflare_link:`example ` illustrating how to use XGBoost in various ways to train a model in a federated manner with a `finance dataset `_. @@ -329,7 +330,7 @@ Here is the default meta.json which can be edited accordingly: FLARE API Parity ================ -In FLARE 2.3.0, an intial version of the FLARE API was implemented as a redesigend FLAdminAPI, however we only included a subset of the functions. +In FLARE 2.3.0, an initial version of the FLARE API was implemented as a redesigned FLAdminAPI; however, we only included a subset of the functions. In FLARE 2.4.0, the FLARE API has been enhanced to include the remaining functions of the FLAdminAPI, so that the FLAdminAPI can be sunset. See the :ref:`migrating_to_flare_api` for more details on the added functions. 
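For the FLARE API mentioned above, a minimal job-submission sketch might look like the following (an illustration under stated assumptions: it assumes the session helpers described in the FLARE API documentation; the username and paths are placeholders):

.. code-block:: python

    from nvflare.fuel.flare_api.flare_api import new_secure_session

    # open an admin session using a provisioned admin startup kit (placeholder paths)
    sess = new_secure_session("admin@nvidia.com", "/path/to/admin/startup_kit_dir")
    try:
        job_id = sess.submit_job("/path/to/job_folder")  # submit a job folder
        print(f"submitted job: {job_id}")
        sess.monitor_job(job_id)                         # block until the job finishes
    finally:
        sess.close()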
diff --git a/docs/resources/3rd_party_integration_diagram.png b/docs/resources/3rd_party_integration_diagram.png new file mode 100644 index 0000000000..5f99832968 Binary files /dev/null and b/docs/resources/3rd_party_integration_diagram.png differ diff --git a/docs/resources/3rd_party_trainer.py b/docs/resources/3rd_party_trainer.py new file mode 100644 index 0000000000..1ffdd085bb --- /dev/null +++ b/docs/resources/3rd_party_trainer.py @@ -0,0 +1,59 @@ +import argparse +import logging + +from nvflare.client.flare_agent import AgentClosed, FlareAgentWithCellPipe + +NUMPY_KEY = "numpy_key" + + +def main(): + + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", "-w", type=str, help="workspace folder", required=False, default=".") + parser.add_argument("--site_name", "-s", type=str, help="flare site name", required=True) + parser.add_argument("--agent_id", "-a", type=str, help="agent id", required=True) + + args = parser.parse_args() + + # 1. create the agent + agent = FlareAgentWithCellPipe( + root_url="grpc://server:8002", + site_name=args.site_name, + agent_id=args.agent_id, + workspace_dir=args.workspace, + secure_mode=True, + submit_result_timeout=2.0, + heartbeat_timeout=120.0, + ) + + # 2. start the agent + agent.start() + + # 3. processing tasks + while True: + print("getting task ...") + try: + task = agent.get_task() + except AgentClosed: + print("agent closed - exit") + break + + print(f"got task: {task}") + result = train(task.data) # perform train task + submitted = agent.submit_result(result) + print(f"result submitted: {submitted}") + + # 4. stop the agent + agent.stop() + + +def train(model): + print(f"training on {model}") + return model + + +if __name__ == "__main__": + main() diff --git a/docs/programming_guide/resources/fed_sag_round.png b/docs/resources/fed_sag_round.png similarity index 100% rename from docs/programming_guide/resources/fed_sag_round.png rename to docs/resources/fed_sag_round.png diff --git a/docs/resources/fl_diagram.png b/docs/resources/fl_diagram.png new file mode 100644 index 0000000000..cb5442732f Binary files /dev/null and b/docs/resources/fl_diagram.png differ diff --git a/docs/programming_guide/resources/init_weights_1_config_fed_server.json b/docs/resources/init_weights_1_config_fed_server.json similarity index 100% rename from docs/programming_guide/resources/init_weights_1_config_fed_server.json rename to docs/resources/init_weights_1_config_fed_server.json diff --git a/docs/resources/nvidia_logo.png b/docs/resources/nvidia_logo.png new file mode 100644 index 0000000000..578592cfc3 Binary files /dev/null and b/docs/resources/nvidia_logo.png differ diff --git a/docs/resources/task_execution_decision_chart.png b/docs/resources/task_execution_decision_chart.png new file mode 100644 index 0000000000..c40bae817e Binary files /dev/null and b/docs/resources/task_execution_decision_chart.png differ diff --git a/docs/user_guide.rst b/docs/user_guide.rst index 71c16314a2..0670440c9b 100644 --- a/docs/user_guide.rst +++ b/docs/user_guide.rst @@ -4,16 +4,12 @@ User Guide ########## -This user guide has information about various features in NVIDIA FLARE. +This user guide provides instructions on how to utilize various features in NVIDIA FLARE. For information about operating an FL system, see :ref:`Real-World Federated Learning `. 
-For more details on what you can do with apps with custom components and -the flexibility that the Controller and Worker APIs bring, see the :ref:`programming_guide`. - -In version 2.2, the commands for NVIDIA FLARE have been consolidated to be under the ``nvflare`` command for -better ease of use. This includes the FL Simulator, the POC command, ``provision``, and preflight check, all of -which are explained in more detail in their own sections linked below. +For a more in-depth exploration of the capabilities offered by apps with custom workflows and algorithms, +please refer to the :ref:`programming_guide`. .. toctree:: :maxdepth: 1 diff --git a/docs/user_guide/configurations/communication_configuration.rst b/docs/user_guide/configurations/communication_configuration.rst index 4cdd95f0ce..dbe1ad2a4b 100644 --- a/docs/user_guide/configurations/communication_configuration.rst +++ b/docs/user_guide/configurations/communication_configuration.rst @@ -126,6 +126,9 @@ This is done by setting use_aio_grpc to true: ``"use_aio_grpc": true`` +On the server side, if you use the non-AIO gRPC driver, the default maximum number of workers is 100, meaning that there can be at most 100 concurrent connections to the server. +If this is not enough, you will need to use the AIO gRPC driver. + Ad-hoc Connections ================== diff --git a/docs/user_guide/nvflare_cli.rst b/docs/user_guide/nvflare_cli.rst index 360e41ed06..d806a2aedd 100644 --- a/docs/user_guide/nvflare_cli.rst +++ b/docs/user_guide/nvflare_cli.rst @@ -4,9 +4,10 @@ NVFlare CLI ########################### -The commands for NVIDIA FLARE have been consolidated to be under the ``nvflare`` command for -better ease of use. This includes the FL Simulator, the POC command, ``provision``, and preflight check, all of -which are explained in more detail in their own sections: +Various NVIDIA FLARE command line interfaces are available to enhance usability. +These include the FL Simulator, the POC command, the provision command, the job command, +the preflight check command, and the dashboard command. +Detailed explanations for each can be found in their respective sections, linked below. .. toctree:: :maxdepth: 1 diff --git a/docs/user_guide/nvflare_cli/fl_simulator.rst b/docs/user_guide/nvflare_cli/fl_simulator.rst index b0ffe8fa10..f52fd47078 100644 --- a/docs/user_guide/nvflare_cli/fl_simulator.rst +++ b/docs/user_guide/nvflare_cli/fl_simulator.rst @@ -49,7 +49,7 @@ Command examples Run a single NVFlare app ======================== -This command will run the same ``hello-numpy-sag`` app on the server and 8 clients using 1 thread. The client names will be site-1, site-2, ... , site-8: +This command will run the same ``hello-numpy-sag`` app on the server and 8 clients using 1 process. The client names will be site-1, site-2, ... , site-8: .. code-block:: python @@ -829,22 +829,29 @@ application run. status = run_simulator(args) sys.exit(status) -**************************** -Threads, Clients, and Events -**************************** +****************************** +Processes, Clients, and Events +****************************** -Specifying threads -================== -The simulator ``-t`` option provides the ability to specify how many threads to run the simulator with. +Specifying number of processes +============================== +The simulator ``-t`` option provides the ability to specify how many processes to run the simulator with. 
-When you run the simulator with ``-t 1``, there is only one client active and running at a time, and the clients will be running in -turn. This is to enable the simulation of large number of clients using a single machine with limited resources. +.. note:: + + The ``-t`` and ``--threads`` options for the simulator were originally named for clients running in separate threads. + However, each client now actually runs in a separate process. This distinction does not affect the user experience. + +- N = number of clients (``-n``) +- T = number of processes (``-t``) -Note that if you have fewer threads than the number of clients, ClientRunner/learner object will go thorugh setup and -teardown in every round. +When running the simulator with fewer processes than clients (T < N), +the simulator will need to swap the clients in and out of the available processes, resulting in some of the clients running sequentially. +This will also cause the ClientRunner/learner objects to go through setup and teardown in every round. +Using T < N is only needed when trying to simulate a large number of clients using a single machine with limited resources. -With ``-t=num_client``, the simulator will run the number of clients in separate threads at the same time. Each -client will always be running in memory with no swap_in / swap_out, but it will require more resources available. +In most cases, run the simulator with the same number of processes as clients (T = N). The simulator will then run all of the clients in separate processes at the same time. Each +client will always be running in memory with no swap-in/out, but this will require more available resources. For the dataset / tensorboard initialization, you could make use of EventType.SWAP_IN and EventType.SWAP_OUT in the application. diff --git a/docs/user_guide/nvflare_cli/job_cli.rst b/docs/user_guide/nvflare_cli/job_cli.rst index 78852af370..94f0395198 100644 --- a/docs/user_guide/nvflare_cli/job_cli.rst +++ b/docs/user_guide/nvflare_cli/job_cli.rst @@ -45,22 +45,42 @@ the job_templates. The output should be similar to the following: -.. code-block::shell + +.. 
code-block:: none The following job templates are available: - ------------------------------------------------------------------------------------------------------------------------ - name Description Controller Type Client Category - ------------------------------------------------------------------------------------------------------------------------ - sag_cross_np scatter & gather and cross-site validation using numpy server client executor - sag_pt scatter & gather workflow using pytorch server client_api - sag_pt_ddp scatter & gather workflow using pytorch + ddp server client_api - sag_pt_deploy_map SAG workflow with pytorch, deploy_map, site-specific configs server client_api - sag_tf scatter & gather workflow using TensorFlow server client_api - stats_df FedStats: tabular data with pandas server stats executor - stats_image FedStats: image intensity histogram server stats executor - ------------------------------------------------------------------------------------------------------------------------ - + ---------------------------------------------------------------------------------------------------------------------- + name Description Controller Type Execution API Type + ---------------------------------------------------------------------------------------------------------------------- + cyclic_cc_pt client-controlled cyclic workflow with PyTorch ClientAPI tra client client_api + cyclic_pt server-controlled cyclic workflow with PyTorch ClientAPI tra server client_api + psi_csv private-set intersection for csv data server Executor + sag_cross_np scatter & gather and cross-site validation using numpy server client executor + sag_cse_pt scatter & gather workflow and cross-site evaluation with PyT server client_api + sag_gnn scatter & gather workflow for gnn learning server client_api + sag_nemo Scatter and Gather Workflow for NeMo server client_api + sag_np scatter & gather workflow using numpy server client_api + sag_np_cell_pipe scatter & gather workflow using numpy server client_api + sag_np_metrics scatter & gather workflow using numpy server client_api + sag_pt scatter & gather workflow using pytorch server client_api + sag_pt_deploy_map SAG workflow with pytorch, deploy_map, site-specific configs server client_api + sag_pt_executor scatter & gather workflow and cross-site evaluation with PyT server Executor + sag_pt_he scatter & gather workflow using pytorch and homomorphic encr server client_api + sag_pt_mlflow scatter & gather workflow using pytorch with MLflow tracking server client_api + sag_pt_model_learner scatter & gather workflow and cross-site evaluation with PyT server ModelLearner + sag_tf scatter & gather workflow using TensorFlow server client_api + sklearn_kmeans scikit-learn KMeans model server client_api + sklearn_linear scikit-learn linear model server client_api + sklearn_svm scikit-learn SVM model server client_api + stats_df FedStats: tabular data with pandas server stats executor + stats_image FedStats: image intensity histogram server stats executor + swarm_cse_pt Swarm Learning with Cross-Site Evaluation with PyTorch client client_api + swarm_cse_pt_model_l Swarm Learning with Cross-Site Evaluation with PyTorch Model client ModelLearner + vertical_xgb vertical federated xgboost server Executor + xgboost_tree xgboost horizontal tree-based collaboration model server client_api + ---------------------------------------------------------------------------------------------------------------------- + +View all the available templates at the 
:github_nvflare_link:`FLARE Job Template Registry `. Setting job_template path ------------------------- @@ -90,20 +110,18 @@ The options for usage are as follows: .. code-block:: - usage: nvflare job create [-h] [-j [JOB_FOLDER]] [-w [TEMPLATE]] [-s [SCRIPT]] [-sd [SCRIPT_DIR]] [-f [CONFIG_FILE ...]] [-debug] [-force] + usage: nvflare job create [-h] [-j [JOB_FOLDER]] [-w [TEMPLATE]] [-sd [SCRIPT_DIR]] [-f [CONFIG_FILE [CONFIG_FILE ...]]] [-debug] [-force] - options: + optional arguments: -h, --help show this help message and exit -j [JOB_FOLDER], --job_folder [JOB_FOLDER] job_folder path, default to ./current_job directory -w [TEMPLATE], --template [TEMPLATE] - template name or template folder. You can use list_templates to see available jobs from job templates, pick name such as 'sag_pt' as template name. Alternatively, you can use the path to the job template folder, such as - job_templates/sag_pt - -s [SCRIPT], --script [SCRIPT] - code script such as train.py + template name or template folder. You can use list_templates to see available jobs from job templates, pick name such as 'sag_pt' as template name. Alternatively, you can use the path to the job + template folder, such as job_templates/sag_pt -sd [SCRIPT_DIR], --script_dir [SCRIPT_DIR] script directory contains additional related files. All files or directories under this directory will be copied over to the custom directory. - -f [CONFIG_FILE ...], --config_file [CONFIG_FILE ...] + -f [CONFIG_FILE [CONFIG_FILE ...]], --config_file [CONFIG_FILE [CONFIG_FILE ...]] Training config file with corresponding optional key=value pairs. If key presents in the preceding config file, the value in the config file will be overwritten by the new value -debug, --debug debug is on -force, --force force create is on, if -force, overwrite existing configuration with newly created configurations diff --git a/docs/user_guide/nvflare_cli/poc_command.rst b/docs/user_guide/nvflare_cli/poc_command.rst index 9bc0d578bf..70927bc93a 100644 --- a/docs/user_guide/nvflare_cli/poc_command.rst +++ b/docs/user_guide/nvflare_cli/poc_command.rst @@ -1,11 +1,9 @@ .. _poc_command: ***************************************** -Command for Proof Of Concept (POC) Mode +Proof Of Concept (POC) Command ***************************************** -Introduction to the POC Command -=============================== The POC command allows users to try out the features of NVFlare in a proof of concept deployment on a single machine. diff --git a/examples/README.md b/examples/README.md index 7887865d53..9b9e0f25b9 100644 --- a/examples/README.md +++ b/examples/README.md @@ -72,14 +72,26 @@ Start a Jupyter Lab: When you open a notebook, select the kernel `nvflare_example` using the dropdown menu at the top right. ![Selecting a JupyterLab kernel](./jupyterlab_kernel.png) -## 1. Step-by-Step Examples -| Example | Dataset | Controller-Type | Client Category | Framework | Summary | +## 1. Hello World Examples +| Example | Framework | Summary | +|----------------------------------------------------------------------------------------------------------------------------------------|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------| +| [Notebook for Hello Examples](./hello-world/hello_world.ipynb) | - | Notebook for examples below. 
| +| [Hello Scatter and Gather](./hello-world/hello-numpy-sag/README.md) | Numpy | Example using [ScatterAndGather](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.scatter_and_gather.html) controller workflow. | +| [Hello Cross-Site Validation](./hello-world/hello-numpy-cross-val/README.md) | Numpy | Example using [CrossSiteModelEval](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.cross_site_model_eval.html) controller workflow, and an example of using previous training results without a training workflow. | +| [Hello Cyclic Weight Transfer](./hello-world/hello-cyclic/README.md) | PyTorch | Example using [CyclicController](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.cyclic_ctl.html) controller workflow to implement [Cyclic Weight Transfer](https://pubmed.ncbi.nlm.nih.gov/29617797/). | +| [Hello PyTorch](./hello-world/hello-pt/README.md) | PyTorch | Example of an image classifier using [FedAvg](https://arxiv.org/abs/1602.05629) and [PyTorch](https://pytorch.org/) as the deep learning training framework. | +| [Hello TensorFlow](./hello-world/hello-tf2/README.md) | TensorFlow2 | Example of an image classifier using [FedAvg](https://arxiv.org/abs/1602.05629) and [TensorFlow](https://tensorflow.org/) as the deep learning training framework. | + +## 2. Step-by-Step Examples | Example | Dataset | Controller-Type | Execution API Type | Framework | Summary | |---------|---------|-----------------|-----------------|-----------|---------| | [image_stats](./hello-world/step-by-step/cifar10/stats/image_stats.ipynb) | CIFAR10 | server | Executor | Pandas | Example for federated stats image histogram calculation. | | [sag](./hello-world/step-by-step/cifar10/sag/sag.ipynb) | CIFAR10 | server | Client API| PyTorch | Example for FedAvg with [ScatterAndGather](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.scatter_and_gather.html) controller workflow using the Client API. | -| [sag_with_deploy_map](./hello-world/step-by-step/cifar10/sag_with_deploy_map/sag_deploy_map.ipynb) | CIFAR10 | server | Client API | PyTorch | Example showcasing site-specific configurations and deploy_map. | -| [sag_executor](./hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb) | CIFAR10 | server | Executor | PyTorch | Example with [ScatterAndGather](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.scatter_and_gather.html) using an Executor. | +| [sag_deploy_map](./hello-world/step-by-step/cifar10/sag_deploy_map/sag_deploy_map.ipynb) | CIFAR10 | server | Client API | PyTorch | Example showcasing site-specific configurations and deploy_map. | | [sag_model_learner](./hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb) | CIFAR10 | server | ModelLearner | PyTorch | Example with [ScatterAndGather](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.scatter_and_gather.html) using a ModelLearner. | +| [sag_executor](./hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb) | CIFAR10 | server | Executor | PyTorch | Example with [ScatterAndGather](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.scatter_and_gather.html) using an Executor. 
| +| [sag_mlflow](./hello-world/step-by-step/cifar10/sag_mlflow/sag_mlflow.ipynb) | CIFAR10 | server | Client API | PyTorch | MLflow experiment tracking logs with [ScatterAndGather](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.scatter_and_gather.html) using the Client API. | +| [sag_he](./hello-world/step-by-step/cifar10/sag_he/sag_he.ipynb) | CIFAR10 | server | Client API | PyTorch | Example with homomorphic encryption using Client API and POC -he mode. | | [cse](./hello-world/step-by-step/cifar10/cse/cse.ipynb) | CIFAR10 | server | Client API| PyTorch | Example using [CrossSiteModelEval](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.cross_site_model_eval.html) controller workflow. | | [cyclic](./hello-world/step-by-step/cifar10/cyclic/cyclic.ipynb) | CIFAR10 | server | Client API | PyTorch | Example for cyclic weight transfer using [CyclicController](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.cyclic_ctl.html) controller workflow. | | [cyclic_ccwf](./hello-world/step-by-step/cifar10/cyclic_ccwf/cyclic_ccwf.ipynb) | CIFAR10 | client| Client API | PyTorch | Example for client-controlled cyclic weight transfer using [CyclicClientController](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.ccwf.cyclic_client_ctl.html) controller workflow. | @@ -90,16 +102,6 @@ When you open a notebook, select the kernel `nvflare_example` using the dropdown | [sklearn_kmeans](./hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans.ipynb) | HIGGS | server | Client API |sklearn | Example for federated k-Means clustering on tabular data. | | [xgboost](./hello-world/step-by-step/higgs/xgboost/xgboost_horizontal.ipynb) | HIGGS | server | Client API |XGBoost | Example for federated horizontal xgboost learning on tabular data with bagging collaboration. | -## 2. Hello World Examples -| Example | Framework | Summary | -|----------------------------------------------------------------------------------------------------------------------------------------|--------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------| -| [Notebook for Hello Examples](./hello-world/hello_world.ipynb) | - | Notebook for examples below. | -| [Hello Scatter and Gather](./hello-world/hello-numpy-sag/README.md) | Numpy | Example using [ScatterAndGather](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.scatter_and_gather.html) controller workflow. | -| [Hello Cross-Site Validation](./hello-world/hello-numpy-cross-val/README.md) | Numpy | Example using [CrossSiteModelEval](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.cross_site_model_eval.html) controller workflow. | -| [Hello Cyclic Weight Transfer](./hello-world/hello-cyclic/README.md) | PyTorch | Example using [CyclicController](https://nvflare.readthedocs.io/en/main/apidocs/nvflare.app_common.workflows.cyclic_ctl.html) controller workflow to implement [Cyclic Weight Transfer](https://pubmed.ncbi.nlm.nih.gov/29617797/). | -| [Hello PyTorch](./hello-world/hello-pt/README.md) | PyTorch | Example using an image classifier using [FedAvg](https://arxiv.org/abs/1602.05629) and [PyTorch](https://pytorch.org/) as the deep learning training framework. 
| -| [Hello TensorFlow](./hello-world/hello-tf2/README.md) | TensorFlow2 | Example of using an image classifier using [FedAvg](https://arxiv.org/abs/1602.05629) and [TensorFlow](https://tensorflow.org/) as the deep learning training framework. | - ## 3. Tutorial notebooks | Example | Summary | |----------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------| diff --git a/examples/advanced/brats18/requirements.txt b/examples/advanced/brats18/requirements.txt index b7dd1625cf..5757f5b0ea 100644 --- a/examples/advanced/brats18/requirements.txt +++ b/examples/advanced/brats18/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/cifar10/README.md b/examples/advanced/cifar10/README.md index 1b641fb953..4e75b03e8b 100644 --- a/examples/advanced/cifar10/README.md +++ b/examples/advanced/cifar10/README.md @@ -6,7 +6,7 @@ Please make sure you set up virtual environment and follows [example root readme This example includes instructions on running [FedAvg](https://arxiv.org/abs/1602.05629), [FedProx](https://arxiv.org/abs/1812.06127), [FedOpt](https://arxiv.org/abs/2003.00295), and [SCAFFOLD](https://arxiv.org/abs/1910.06378) algorithms using NVFlare's -[FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html). +[FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html). ### [Real-world Federated Learning with CIFAR-10](./cifar10-real-world/README.md) Real-world FL deployment requires secure provisioning and the admin API to submit jobs. diff --git a/examples/advanced/cifar10/cifar10-real-world/requirements.txt b/examples/advanced/cifar10/cifar10-real-world/requirements.txt index 9b6376874b..2fd21968dc 100644 --- a/examples/advanced/cifar10/cifar10-real-world/requirements.txt +++ b/examples/advanced/cifar10/cifar10-real-world/requirements.txt @@ -1,4 +1,4 @@ -nvflare[HE]>=2.3.0 +nvflare[HE]~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/cifar10/cifar10-sim/README.md b/examples/advanced/cifar10/cifar10-sim/README.md index 6feeaeb86e..230c7042e2 100644 --- a/examples/advanced/cifar10/cifar10-sim/README.md +++ b/examples/advanced/cifar10/cifar10-sim/README.md @@ -35,7 +35,7 @@ To speed up the following experiments, first download the [CIFAR-10](https://www ## 3. Run simulated FL experiments -We are using NVFlare's [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html) to run the following experiments. +We are using NVFlare's [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html) to run the following experiments. The output root of where to save the results is set in [./run_simulator.sh](./run_simulator.sh) as `RESULT_ROOT=/tmp/nvflare/sim_cifar10`. diff --git a/examples/advanced/cifar10/cifar10-sim/figs/plot_tensorboard_events.py b/examples/advanced/cifar10/cifar10-sim/figs/plot_tensorboard_events.py index 495dadf4c2..40284f3596 100644 --- a/examples/advanced/cifar10/cifar10-sim/figs/plot_tensorboard_events.py +++ b/examples/advanced/cifar10/cifar10-sim/figs/plot_tensorboard_events.py @@ -30,7 +30,7 @@ # 4.1 Central vs. 
FedAvg experiments = { - "cifar10_central": {"tag": "val_acc_local_model"}, + "cifar10_central": {"tag": "val_acc_local_model", "alpha": 0.0}, "cifar10_fedavg": {"tag": "val_acc_global_model", "alpha": 1.0}, } @@ -93,8 +93,10 @@ def main(): for config, exp in experiments.items(): config_name = config.split(" ")[0] alpha = exp.get("alpha", None) - if alpha: + if alpha is not None: config_name = config_name + f"*alpha{alpha}" + else: + raise ValueError(f"Expected an alpha value to be provided but got alpha={alpha}") eventfile = glob.glob( os.path.join(client_results_root, config_name, "**", "app_site-1", "events.*"), recursive=True ) @@ -116,7 +118,8 @@ def main(): try: xsite_data[k].append(xsite_results["site-1"][k]["val_accuracy"]) except Exception as e: - raise ValueError(f"No val_accuracy for {k} in {xsite_file}!") + xsite_data[k].append(None) + print(f"Warning: No val_accuracy for {k} in {xsite_file}!") print("Training TB data:") print(pd.DataFrame(data)) diff --git a/examples/advanced/cifar10/cifar10-sim/requirements.txt b/examples/advanced/cifar10/cifar10-sim/requirements.txt index 0804527963..3bbfea441b 100644 --- a/examples/advanced/cifar10/cifar10-sim/requirements.txt +++ b/examples/advanced/cifar10/cifar10-sim/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/cifar10/cifar10-sim/run_simulator.sh b/examples/advanced/cifar10/cifar10-sim/run_simulator.sh index 24a901a4ae..2fdd6eca45 100755 --- a/examples/advanced/cifar10/cifar10-sim/run_simulator.sh +++ b/examples/advanced/cifar10/cifar10-sim/run_simulator.sh @@ -8,12 +8,7 @@ n_clients=$4 # specify output workdir RESULT_ROOT=/tmp/nvflare/sim_cifar10 -if [ 1 -eq "$(echo "${alpha} > 0" | bc)" ] -then - out_workspace=${RESULT_ROOT}/${job}_alpha${alpha} -else - out_workspace=${RESULT_ROOT}/${job} -fi +out_workspace=${RESULT_ROOT}/${job}_alpha${alpha} # run FL simulator ./set_alpha.sh "${job}" "${alpha}" diff --git a/examples/advanced/custom_authentication/requirements.txt b/examples/advanced/custom_authentication/requirements.txt index d556c8d097..e4605852b5 100644 --- a/examples/advanced/custom_authentication/requirements.txt +++ b/examples/advanced/custom_authentication/requirements.txt @@ -1 +1 @@ -nvflare>=2.4.0 +nvflare~=2.4.0rc diff --git a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_client.json b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_client.conf similarity index 100% rename from examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_client.json rename to examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_client.conf diff --git a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.json b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf similarity index 100% rename from examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.json rename to examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-mlflow/app/config/config_fed_server.conf diff --git a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_client.json b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_client.conf similarity index 100% rename from 
examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_client.json rename to examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_client.conf diff --git a/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.json b/examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.conf similarity index 100% rename from examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.json rename to examples/advanced/experiment-tracking/mlflow/jobs/hello-pt-tb-mlflow/app/config/config_fed_server.conf diff --git a/examples/advanced/experiment-tracking/mlflow/requirements.txt b/examples/advanced/experiment-tracking/mlflow/requirements.txt index 811b9aa432..04a1d06c48 100644 --- a/examples/advanced/experiment-tracking/mlflow/requirements.txt +++ b/examples/advanced/experiment-tracking/mlflow/requirements.txt @@ -1,4 +1,4 @@ -nvflare[PT]>=2.3.0 +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/experiment-tracking/tensorboard/requirements.txt b/examples/advanced/experiment-tracking/tensorboard/requirements.txt index d96a108e51..3bbfea441b 100644 --- a/examples/advanced/experiment-tracking/tensorboard/requirements.txt +++ b/examples/advanced/experiment-tracking/tensorboard/requirements.txt @@ -1,2 +1,4 @@ -nvflare[PT]>=2.3.0 +nvflare~=2.4.0rc +torch +torchvision tensorboard diff --git a/examples/advanced/experiment-tracking/wandb/README.md b/examples/advanced/experiment-tracking/wandb/README.md index 27d06fa243..aaadc13b55 100644 --- a/examples/advanced/experiment-tracking/wandb/README.md +++ b/examples/advanced/experiment-tracking/wandb/README.md @@ -26,7 +26,9 @@ export PYTHONPATH=${PWD}/.. Import the W&B Python SDK and log in: ``` -wandb.login() +python3 +>>> import wandb +>>> wandb.login() ``` Provide your API key when prompted. 
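For scripted runs where no interactive prompt is available, W&B can also authenticate non-interactively. A minimal sketch, assuming the standard SDK behavior of reading the `WANDB_API_KEY` environment variable (the key value below is a placeholder):

```python
import os

import wandb

# Assumed behavior: wandb.login() falls back to the WANDB_API_KEY
# environment variable when it cannot prompt for input.
os.environ["WANDB_API_KEY"] = "<your-api-key>"  # placeholder; never commit a real key
wandb.login()
```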
diff --git a/examples/advanced/experiment-tracking/wandb/requirements.txt b/examples/advanced/experiment-tracking/wandb/requirements.txt
index 7ea490208f..ad3f6241d2 100644
--- a/examples/advanced/experiment-tracking/wandb/requirements.txt
+++ b/examples/advanced/experiment-tracking/wandb/requirements.txt
@@ -1,3 +1,5 @@
-nvflare[PT]>=2.3.0
+nvflare~=2.4.0rc
+torch
+torchvision
 tensorboard
 wandb
diff --git a/examples/advanced/federated-policies/requirements.txt b/examples/advanced/federated-policies/requirements.txt
index 3fdbf10587..e4605852b5 100644
--- a/examples/advanced/federated-policies/requirements.txt
+++ b/examples/advanced/federated-policies/requirements.txt
@@ -1 +1 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
diff --git a/examples/advanced/federated-statistics/df_stats/requirements.txt b/examples/advanced/federated-statistics/df_stats/requirements.txt
index dc5d8c6eaf..f897a1484a 100644
--- a/examples/advanced/federated-statistics/df_stats/requirements.txt
+++ b/examples/advanced/federated-statistics/df_stats/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 numpy
 pandas
 matplotlib
diff --git a/examples/advanced/federated-statistics/image_stats/requirements.txt b/examples/advanced/federated-statistics/image_stats/requirements.txt
index 45e20cc1ee..9e0a46f617 100644
--- a/examples/advanced/federated-statistics/image_stats/requirements.txt
+++ b/examples/advanced/federated-statistics/image_stats/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 numpy
 monai[itk]
 pandas
diff --git a/examples/advanced/finance/requirements.txt b/examples/advanced/finance/requirements.txt
index 5348abcbce..f8a60dc996 100644
--- a/examples/advanced/finance/requirements.txt
+++ b/examples/advanced/finance/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 openmined.psi==1.1.1
 pandas
 xgboost>=1.7.0
diff --git a/examples/advanced/gnn/README.md b/examples/advanced/gnn/README.md
index 7983e528d7..f8c0c415ea 100644
--- a/examples/advanced/gnn/README.md
+++ b/examples/advanced/gnn/README.md
@@ -31,7 +31,7 @@ python3 -m pip install -r requirements.txt
 ```
 To support functions of PyTorch Geometric necessary for this example, we need extra dependencies. Please refer to [installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and install accordingly:
 ```
-pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
+python3 -m pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html
 ```
 
 #### Job Template
@@ -46,8 +46,8 @@ nvflare job list_templates
 We can see the "sag_gnn" template is available
 
 #### Protein Classification
-The PPI dataset is directly available via torch_geometric library, we randomly split the dataset to 2 subsets, one for each client.
-First, we run the local training on each client, as well as the whole dataset.
+The PPI dataset is directly available via the torch_geometric library; we randomly split the dataset into 2 subsets, one for each client (`--client_id 1` and `--client_id 2`).
+First, we run the local training on each client, as well as the whole dataset with `--client_id 0`.
 ```
 python3 code/graphsage_protein_local.py --client_id 0
 python3 code/graphsage_protein_local.py --client_id 1
@@ -55,7 +55,7 @@ python3 code/graphsage_protein_local.py --client_id 2
 ```
 Then, we create NVFlare job based on GNN template. 
```
-nvflare job create -force -j "./jobs/gnn_protein" -w "sag_gnn" -sd "code" \
+nvflare job create -force -j "/tmp/nvflare/jobs/gnn_protein" -w "sag_gnn" -sd "code" \
 -f app_1/config_fed_client.conf app_script="graphsage_protein_fl.py" app_config="--client_id 1 --epochs 10" \
 -f app_2/config_fed_client.conf app_script="graphsage_protein_fl.py" app_config="--client_id 2 --epochs 10" \
 -f app_server/config_fed_server.conf num_rounds=7 key_metric="validation_f1" model_class_path="torch_geometric.nn.GraphSAGE" components[0].args.model.args.in_channels=50 components[0].args.model.args.hidden_channels=64 components[0].args.model.args.num_layers=2 components[0].args.model.args.out_channels=64
@@ -64,17 +64,17 @@ For client configs, we set client_ids for each client, and the number of local e
 For server configs, we set the number of rounds for federated training, the key metric for model selection, and the model class path with model hyperparameters.
 
-With the produced job, we run the federated training on both clients via FedAvg using NVFlare Simulator.
+With the produced job, we run the federated training on both clients via FedAvg using the NVFlare Simulator.
 ```
-nvflare simulator -w /tmp/nvflare/gnn/protein_fl_workspace -n 2 -t 2 ./jobs/gnn_protein
+nvflare simulator -w /tmp/nvflare/gnn/protein_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_protein
 ```
 
 #### Financial Transaction Classification
-We first download the Elliptic++ dataset to `data` folder. In this example, we will use the following three files:
+We first download the Elliptic++ dataset to the `/tmp/nvflare/datasets/elliptic_pp` folder. In this example, we will use the following three files:
 - `txs_classes.csv`: transaction id and its class (licit or illicit)
 - `txs_edgelist.csv`: connections for transaction ids
 - `txs_features.csv`: transaction id and its features
 
-Then, we run the local training on each client, as well as the whole dataset.
+Then, we run the local training on each client, as well as the whole dataset. Again, `--client_id 0` uses all data.
 ```
 python3 code/graphsage_finance_local.py --client_id 0
 python3 code/graphsage_finance_local.py --client_id 1
@@ -82,14 +82,14 @@ python3 code/graphsage_finance_local.py --client_id 2
 ```
 Similarly, we create NVFlare job based on GNN template. 
``` -nvflare simulator -w /tmp/nvflare/gnn/finance_fl_workspace -n 2 -t 2 ./jobs/gnn_finance +nvflare simulator -w /tmp/nvflare/gnn/finance_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_finance ``` ### Results diff --git a/examples/advanced/gnn/code/graphsage_finance_fl.py b/examples/advanced/gnn/code/graphsage_finance_fl.py index b2d8fa8a20..65ec991486 100644 --- a/examples/advanced/gnn/code/graphsage_finance_fl.py +++ b/examples/advanced/gnn/code/graphsage_finance_fl.py @@ -37,7 +37,7 @@ def main(): parser.add_argument( "--data_path", type=str, - default="./data", + default="/tmp/nvflare/datasets/elliptic_pp", ) parser.add_argument( "--epochs", diff --git a/examples/advanced/gnn/code/graphsage_finance_local.py b/examples/advanced/gnn/code/graphsage_finance_local.py index 351f9d4c1e..51404c3be2 100644 --- a/examples/advanced/gnn/code/graphsage_finance_local.py +++ b/examples/advanced/gnn/code/graphsage_finance_local.py @@ -34,7 +34,7 @@ def main(): parser.add_argument( "--data_path", type=str, - default="./data", + default="/tmp/nvflare/datasets/elliptic_pp", ) parser.add_argument( "--epochs", diff --git a/examples/advanced/gnn/gnn_examples.ipynb b/examples/advanced/gnn/gnn_examples.ipynb new file mode 100644 index 0000000000..54ab3ede98 --- /dev/null +++ b/examples/advanced/gnn/gnn_examples.ipynb @@ -0,0 +1,336 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cada310b-e776-4b9a-aabe-f111c31efcc2", + "metadata": { + "tags": [] + }, + "source": [ + "# Federated GNN on Graph Dataset using Inductive Learning" + ] + }, + { + "cell_type": "markdown", + "id": "0653cbf2-92f2-4a22-8317-69cfb0266e92", + "metadata": {}, + "source": [ + "## Introduction to GNN, Tasks, and federated GNN via Inductive Learning\n", + "### GNN\n", + "This example shows how to train a classification model using Graph Neural Network (GNN). GNNs show a promising future in research and industry, with potential applications in various domains, including social networks, e-commerce, recommendation systems, and more.\n", + "GNNs excel in learning, modeling, and leveraging complex relationships within graph-structured data. They combine local and global information, incorporate structural knowledge, adapt to diverse tasks, handle heterogeneous data, support transfer learning, scale for large graphs, offer interpretable insights, and achieve impressive performance. \n", + "\n", + "### Tasks\n", + "In this example, we provide two tasks:\n", + "1. **Protein Classification**:\n", + "The aim is to classify protein roles based on their cellular functions from gene ontology. The dataset we are using is PPI\n", + "([protein-protein interaction](http://snap.stanford.edu/graphsage/#code)) graphs, where each graph represents a specific human tissue. Protein-protein interaction (PPI) dataset is commonly used in graph-based machine-learning tasks, especially in the field of bioinformatics. This dataset represents interactions between proteins as graphs, where nodes represent proteins and edges represent interactions between them.\n", + "2. **Financial Transaction Classification**:\n", + "The aim is to classify whether a given transaction is licit or illicit. For this financial application, we use the [Elliptic++](https://github.com/git-disl/EllipticPlusPlus) dataset. It consists of 203k Bitcoin transactions and 822k wallet addresses to enable both the detection of fraudulent transactions and the detection of illicit addresses (actors) in the Bitcoin network by leveraging graph data. 
For more details, please refer to this [paper](https://arxiv.org/pdf/2306.06108.pdf).\n",
+    "\n",
+    "\n",
+    "### Federated GNN via Inductive Learning\n",
+    "Both tasks are for node classification. We used the inductive representation learning method [GraphSAGE](https://arxiv.org/pdf/1706.02216.pdf) based on [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric)'s examples. \n",
+    "[Pytorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.\n",
+    "\n",
+    "For the protein classification task, we used it in an unsupervised manner, following [PyG's unsupervised PPI example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_sage_unsup_ppi.py).\n",
+    "For the financial transaction classification task, we used it in a supervised manner, directly using the node labels with a supervised classification loss.\n",
+    "\n",
+    "Since the inductive learning mode is used, the locally learnt model (a representation encoding / classification network) does not depend on the particular local graph, so we are able to use the basic [FedAvg](https://arxiv.org/abs/1602.05629) as the federated learning algorithm. The workflow is Scatter and Gather (SAG).\n",
+    "\n",
+    "\n",
+    "Below we list the steps to run this example."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a5a0292c-78b6-4bde-96d6-699dae996173",
+   "metadata": {},
+   "source": [
+    "## 1. Setup NVFLARE\n",
+    "\n",
+    "Follow the [Getting_Started](https://nvflare.readthedocs.io/en/main/getting_started.html) to set up a virtual environment and install NVFLARE.\n",
+    "\n",
+    "We also provide a [Notebook](../../nvflare_setup.ipynb) for this setup process. \n",
+    "\n",
+    "Assuming you have already set up the venv, let's first install the required packages."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f4130b15-09e6-456f-a3c7-87c8ee9e07f0",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "%pip install -r requirements.txt"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1d872d8a-9e44-49dd-94b1-7862b3815ffe",
+   "metadata": {},
+   "source": [
+    "To support functions of PyTorch Geometric necessary for this example, we need extra dependencies. Please refer to the [installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and install accordingly:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f906a1c9-dce0-476c-be65-79ebd8ad5da9",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "%pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e93b1bf2-6157-4ab6-9766-819450304038",
+   "metadata": {},
+   "source": [
+    "## 2. Data preparation \n",
+    "This example uses two datasets: \n",
+    "- For Protein Classification, the PPI dataset is available from torch_geometric's dataset API. \n",
+    "- For Financial Transaction Classification, we first download the [Elliptic++](https://github.com/git-disl/EllipticPlusPlus) dataset to the `/tmp/nvflare/datasets/elliptic_pp` folder. 
In this example, we will use the following three files:\n",
+    " - `txs_classes.csv`: transaction id and its class (licit or illicit)\n",
+    " - `txs_edgelist.csv`: connections for transaction ids \n",
+    " - `txs_features.csv`: transaction id and its features"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "af257e69-2bb7-49b6-ac6c-f007b0e6618e",
+   "metadata": {},
+   "source": [
+    "## 3. Local Experiments\n",
+    "For comparison with federated learning results, we first perform local experiments on each client's data and the whole dataset. Here we simulate 2 clients with a uniform data split (client_id = 0 means the whole dataset). The 6 experiments will take a while to finish. The default epoch number is set to 70. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fdb7290a-48ff-4e80-be58-5e6b0e0f9379",
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! python3 code/graphsage_protein_local.py --client_id 0\n",
+    "! python3 code/graphsage_protein_local.py --client_id 1\n",
+    "! python3 code/graphsage_protein_local.py --client_id 2 "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "9a2d55cf-4f7a-4030-8cba-b1619fdf1614",
+   "metadata": {},
+   "source": [
+    "And for the finance experiment:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f4cf2c09-1f78-4d28-9b86-af9f9cf86479",
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! python3 code/graphsage_finance_local.py --client_id 0\n",
+    "! python3 code/graphsage_finance_local.py --client_id 1\n",
+    "! python3 code/graphsage_finance_local.py --client_id 2 "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d178c6dc-c180-4ca6-8dea-3b0fe147665b",
+   "metadata": {},
+   "source": [
+    "## 4. Prepare NVFlare job based on GNN template\n",
+    "We are using NVFlare's FL simulator to run the FL experiments. First, we create jobs using the GNN template. We reuse the job templates from [sag_gnn](../../../job_templates/sag_gnn); let's set the job template path with the following command."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "fd885e6b-ae4d-40aa-b89d-fe34217ad3da",
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! nvflare config -jt ../../../job_templates/"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f608a992-5096-4452-8775-b89987970a75",
+   "metadata": {},
+   "source": [
+    "Then we can check the available templates with the following command."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1f8041c7-7fae-4c8a-8e07-1c6a6d59e541",
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! nvflare job list_templates"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5bad4f55-d582-4f37-a523-927dc015e564",
+   "metadata": {},
+   "source": [
+    "We shall see `sag_gnn` in the output of the above command. We then create jobs using this template, setting local epochs to 10 with 7 rounds of FL to match the local experiments' default of 70 training epochs."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7971a800-70fc-4213-96ed-c157801b5a11",
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! 
nvflare job create -force -j \"/tmp/nvflare/jobs/gnn_protein\" -w \"sag_gnn\" -sd \"code\" \\\n",
+    " -f app_1/config_fed_client.conf app_script=\"graphsage_protein_fl.py\" app_config=\"--client_id 1 --epochs 10\" \\\n",
+    " -f app_2/config_fed_client.conf app_script=\"graphsage_protein_fl.py\" app_config=\"--client_id 2 --epochs 10\" \\\n",
+    " -f app_server/config_fed_server.conf num_rounds=7 key_metric=\"validation_f1\" model_class_path=\"torch_geometric.nn.GraphSAGE\" components[0].args.model.args.in_channels=50 components[0].args.model.args.hidden_channels=64 components[0].args.model.args.num_layers=2 components[0].args.model.args.out_channels=64 "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "675bff95-dcfa-4a47-9a05-460da16760ef",
+   "metadata": {},
+   "source": [
+    "And for the finance experiment:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a6d0b643-31f0-4d52-ae3c-1fafcd404072",
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! nvflare job create -force -j \"/tmp/nvflare/jobs/gnn_finance\" -w \"sag_gnn\" -sd \"code\" \\\n",
+    " -f app_1/config_fed_client.conf app_script=\"graphsage_finance_fl.py\" app_config=\"--client_id 1 --epochs 10\" \\\n",
+    " -f app_2/config_fed_client.conf app_script=\"graphsage_finance_fl.py\" app_config=\"--client_id 2 --epochs 10\" \\\n",
+    " -f app_server/config_fed_server.conf num_rounds=7 key_metric=\"validation_auc\" model_class_path=\"pyg_sage.SAGE\" components[0].args.model.args.in_channels=165 components[0].args.model.args.hidden_channels=256 components[0].args.model.args.num_layers=3 components[0].args.model.args.num_classes=2 \n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bd0713e2-e393-41c0-9da0-392535cf8a54",
+   "metadata": {},
+   "source": [
+    "## 5. Run simulated FL experiments\n",
+    "Now that we have the jobs ready, we run the experiments using the FL Simulator."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "9bb6cab4-9c24-400a-bc3c-f1e4a6d5a346",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! nvflare simulator -w /tmp/nvflare/gnn/protein_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_protein"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "98c64648-1d09-42da-bd48-9a6ac48587af",
+   "metadata": {},
+   "source": [
+    "And for the finance experiment:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a9f256a0-ae99-4a7e-8bc2-e7fc8de2e6f6",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "! nvflare simulator -w /tmp/nvflare/gnn/finance_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_finance"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "913e9ee2-e993-442d-a525-d2baf92af539",
+   "metadata": {},
+   "source": [
+    "## 6. Result visualization\n",
+    "Results from both local and federated experiments can be visualized in TensorBoard."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6814434-4e6d-4460-b480-709cb3e77cc8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext tensorboard\n", + "%tensorboard --logdir /tmp/nvflare/gnn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f6ae6cb-12df-4279-b6af-9c4d356e727e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "nvflare_example", + "language": "python", + "name": "nvflare_example" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/advanced/gnn/requirements.txt b/examples/advanced/gnn/requirements.txt index 129158b55d..c469afd535 100644 --- a/examples/advanced/gnn/requirements.txt +++ b/examples/advanced/gnn/requirements.txt @@ -1,4 +1,3 @@ -nvflare torch torch_geometric tensorboard diff --git a/examples/advanced/job-level-authorization/requirements.txt b/examples/advanced/job-level-authorization/requirements.txt index d556c8d097..e4605852b5 100644 --- a/examples/advanced/job-level-authorization/requirements.txt +++ b/examples/advanced/job-level-authorization/requirements.txt @@ -1 +1 @@ -nvflare>=2.4.0 +nvflare~=2.4.0rc diff --git a/examples/advanced/keycloak-site-authentication/requirements.txt b/examples/advanced/keycloak-site-authentication/requirements.txt index d556c8d097..e4605852b5 100644 --- a/examples/advanced/keycloak-site-authentication/requirements.txt +++ b/examples/advanced/keycloak-site-authentication/requirements.txt @@ -1 +1 @@ -nvflare>=2.4.0 +nvflare~=2.4.0rc diff --git a/examples/advanced/llm_hf/requirements.txt b/examples/advanced/llm_hf/requirements.txt index 9651b08324..171942c7af 100644 --- a/examples/advanced/llm_hf/requirements.txt +++ b/examples/advanced/llm_hf/requirements.txt @@ -1,4 +1,4 @@ -nvflare +nvflare~=2.4.0rc torch datasets tensorboard diff --git a/examples/advanced/nlp-ner/README.md b/examples/advanced/nlp-ner/README.md index 7ee52a3482..d110d944ba 100644 --- a/examples/advanced/nlp-ner/README.md +++ b/examples/advanced/nlp-ner/README.md @@ -52,7 +52,7 @@ Let's take a closer look at the word-label correspondence: As shown above, the task is to capture the keywords related to medical findings. ## Run automated experiments -We use the NVFlare [simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html) to run the FL training. +We use the NVFlare [simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html) to run the FL training. 
Set `PYTHONPATH` to include custom files of this example: ``` export PYTHONPATH=${PWD} diff --git a/examples/advanced/nlp-ner/requirements.txt b/examples/advanced/nlp-ner/requirements.txt index 0bcaa6d626..678ba0bcd7 100644 --- a/examples/advanced/nlp-ner/requirements.txt +++ b/examples/advanced/nlp-ner/requirements.txt @@ -1,4 +1,4 @@ -nvflare +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/prostate/requirements.txt b/examples/advanced/prostate/requirements.txt index b7dd1625cf..5757f5b0ea 100644 --- a/examples/advanced/prostate/requirements.txt +++ b/examples/advanced/prostate/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/psi/user_email_match/requirements.txt b/examples/advanced/psi/user_email_match/requirements.txt index 8c42c63f15..23c6c47d5d 100644 --- a/examples/advanced/psi/user_email_match/requirements.txt +++ b/examples/advanced/psi/user_email_match/requirements.txt @@ -1,3 +1,3 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc openmined.psi==1.1.1 pandas diff --git a/examples/advanced/random_forest/requirements.txt b/examples/advanced/random_forest/requirements.txt index 02353a0589..96c88f1ec4 100644 --- a/examples/advanced/random_forest/requirements.txt +++ b/examples/advanced/random_forest/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc pandas xgboost scikit-learn diff --git a/examples/advanced/sklearn-kmeans/README.md b/examples/advanced/sklearn-kmeans/README.md index 5b826609bd..a3d7cdb158 100644 --- a/examples/advanced/sklearn-kmeans/README.md +++ b/examples/advanced/sklearn-kmeans/README.md @@ -117,7 +117,7 @@ Below is a sample config for site-1, saved to `./jobs/sklearn_kmeans_3_uniform/a ``` ## Run experiment with FL simulator -The [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html) simulates FL experiments or debugging codes, +The [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html) simulates FL experiments or debugging codes, not for real-world FL deployment. We can run the FL simulator with 3 clients under the uniform data split with ```commandline diff --git a/examples/advanced/sklearn-kmeans/requirements.txt b/examples/advanced/sklearn-kmeans/requirements.txt index aeafe651e1..22de3c503b 100644 --- a/examples/advanced/sklearn-kmeans/requirements.txt +++ b/examples/advanced/sklearn-kmeans/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc pandas scikit-learn joblib diff --git a/examples/advanced/sklearn-linear/README.md b/examples/advanced/sklearn-linear/README.md index 1284360101..54b9905fe9 100644 --- a/examples/advanced/sklearn-linear/README.md +++ b/examples/advanced/sklearn-linear/README.md @@ -101,7 +101,7 @@ Below is a sample config for site-1, saved to `./jobs/sklearn_linear_5_uniform/a ``` ## Run experiment with FL simulator -[FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html) is used to simulate FL experiments or debug codes, not for real FL deployment. +[FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html) is used to simulate FL experiments or debug codes, not for real FL deployment. 
We can run the FL simulator with five clients under the uniform data split with ```commandline bash run_experiment_simulator.sh diff --git a/examples/advanced/sklearn-linear/requirements.txt b/examples/advanced/sklearn-linear/requirements.txt index aeafe651e1..22de3c503b 100644 --- a/examples/advanced/sklearn-linear/requirements.txt +++ b/examples/advanced/sklearn-linear/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc pandas scikit-learn joblib diff --git a/examples/advanced/sklearn-svm/README.md b/examples/advanced/sklearn-svm/README.md index b2263ea90f..fc17102bd8 100644 --- a/examples/advanced/sklearn-svm/README.md +++ b/examples/advanced/sklearn-svm/README.md @@ -34,7 +34,7 @@ Under this setting, federated learning can be formulated in two steps: Unlike other iterative federated algorithms, federated SVM only involves these two training steps. Hence, in the server config, we have -```json +``` "num_rounds": 2 ``` The first round is the training round, performing local training and global aggregation. @@ -116,7 +116,7 @@ Below is a sample config for site-1, saved to `./jobs/sklearn_svm_3_uniform/app_ ``` ## Run experiment with FL simulator -[FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html) is used to simulate FL experiments or debug codes, not for real FL deployment. +[FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html) is used to simulate FL experiments or debug codes, not for real FL deployment. We can run the FL simulator with three clients under the uniform data split with ```commandline bash run_experiment_simulator.sh diff --git a/examples/advanced/sklearn-svm/requirements.txt b/examples/advanced/sklearn-svm/requirements.txt index aeafe651e1..22de3c503b 100644 --- a/examples/advanced/sklearn-svm/requirements.txt +++ b/examples/advanced/sklearn-svm/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc pandas scikit-learn joblib diff --git a/examples/advanced/swarm_learning/requirements.txt b/examples/advanced/swarm_learning/requirements.txt index 0804527963..3bbfea441b 100644 --- a/examples/advanced/swarm_learning/requirements.txt +++ b/examples/advanced/swarm_learning/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/vertical_federated_learning/cifar10-splitnn/README.md b/examples/advanced/vertical_federated_learning/cifar10-splitnn/README.md index 181d104baa..02caf835d4 100644 --- a/examples/advanced/vertical_federated_learning/cifar10-splitnn/README.md +++ b/examples/advanced/vertical_federated_learning/cifar10-splitnn/README.md @@ -1,7 +1,7 @@ # Split Learning with CIFAR-10 This example includes instructions on how to run [split learning](https://arxiv.org/abs/1810.06060) (SL) using the -CIFAR-10 dataset and the [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/fl_simulator.html). +CIFAR-10 dataset and the [FL simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html). We assume one client holds the images, and the other client holds the labels to compute losses and accuracy metrics. Activations and corresponding gradients are being exchanged between the clients using NVFlare. 
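To make the activation/gradient exchange described above concrete, below is a minimal single-process sketch of one split-learning step. The two-part network, layer sizes, and optimizers are illustrative assumptions, not the example's actual model, and the real example transports the tensors between sites via NVFlare rather than local variables:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical split of a CIFAR-10 classifier: site-1 owns `bottom`, site-2 owns `top`.
bottom = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 128), nn.ReLU())
top = nn.Linear(128, 10)
opt_bottom = torch.optim.SGD(bottom.parameters(), lr=0.01)
opt_top = torch.optim.SGD(top.parameters(), lr=0.01)

x = torch.randn(8, 3, 32, 32)   # images held by site-1
y = torch.randint(0, 10, (8,))  # labels held by site-2

# site-1: forward through the bottom part; the detached activations are
# what would be sent to site-2 over NVFlare.
act = bottom(x)
act_received = act.detach().requires_grad_(True)

# site-2: finish the forward pass, compute the loss, backprop to the cut layer.
loss = F.cross_entropy(top(act_received), y)
opt_top.zero_grad()
loss.backward()
opt_top.step()

# site-1: receive the gradient at the cut layer and finish backprop.
opt_bottom.zero_grad()
act.backward(act_received.grad)
opt_bottom.step()
```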
diff --git a/examples/advanced/vertical_federated_learning/cifar10-splitnn/requirements.txt b/examples/advanced/vertical_federated_learning/cifar10-splitnn/requirements.txt index 58f9cf9de7..57c627000c 100644 --- a/examples/advanced/vertical_federated_learning/cifar10-splitnn/requirements.txt +++ b/examples/advanced/vertical_federated_learning/cifar10-splitnn/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.4.0 +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/advanced/vertical_xgboost/requirements.txt b/examples/advanced/vertical_xgboost/requirements.txt index 6bf2c8cbe1..a9a1d31eda 100644 --- a/examples/advanced/vertical_xgboost/requirements.txt +++ b/examples/advanced/vertical_xgboost/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc openmined.psi==1.1.1 pandas tensorboard diff --git a/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json b/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json index c456968c25..1d687b3a37 100755 --- a/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json +++ b/examples/advanced/xgboost/histogram-based/jobs/base/app/config/config_fed_client.json @@ -1,5 +1,6 @@ { "format_version": 2, + "num_rounds": 100, "executors": [ { "tasks": [ @@ -10,14 +11,14 @@ "name": "FedXGBHistogramExecutor", "args": { "data_loader_id": "dataloader", - "num_rounds": 100, + "num_rounds": "{num_rounds}", "early_stopping_rounds": 2, "xgb_params": { "max_depth": 8, "eta": 0.1, "objective": "binary:logistic", "eval_metric": "auc", - "tree_method": "gpu_hist", + "tree_method": "hist", "nthread": 16 } } diff --git a/examples/advanced/xgboost/histogram-based/requirements.txt b/examples/advanced/xgboost/histogram-based/requirements.txt index fcdcad4892..8311f62b9f 100644 --- a/examples/advanced/xgboost/histogram-based/requirements.txt +++ b/examples/advanced/xgboost/histogram-based/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc pandas xgboost>=2.0.0 scikit-learn diff --git a/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json b/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json index 124a2296f9..f35526c721 100755 --- a/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json +++ b/examples/advanced/xgboost/tree-based/jobs/bagging_base/app/config/config_fed_server.json @@ -1,5 +1,6 @@ { "format_version": 2, + "num_rounds": 101, "server": { "heart_beat_timeout": 600, @@ -34,7 +35,7 @@ "name": "ScatterAndGather", "args": { "min_clients": 5, - "num_rounds": 101, + "num_rounds": "{num_rounds}", "start_round": 0, "wait_time_after_min_received": 0, "aggregator_id": "aggregator", diff --git a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json index 319c26de1a..6b25f996bb 100755 --- a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json +++ b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_client.json @@ -13,7 +13,6 @@ "data_loader_id": "dataloader", "training_mode": "cyclic", "num_client_bagging": 1, - "lr_mode": "scaled", "local_model_path": "model.json", "global_model_path": "model_global.json", "learning_rate": 0.1, diff --git a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json 
b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json index 686042e1b5..3f331b862c 100755 --- a/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json +++ b/examples/advanced/xgboost/tree-based/jobs/cyclic_base/app/config/config_fed_server.json @@ -1,5 +1,6 @@ { "format_version": 2, + "num_rounds": 20, "server": { "heart_beat_timeout": 600, @@ -29,7 +30,7 @@ "id": "cyclic_ctl", "name": "CyclicController", "args": { - "num_rounds": 20, + "num_rounds": "{num_rounds}", "task_assignment_timeout": 60, "persistor_id": "persistor", "shareable_generator_id": "shareable_generator", diff --git a/examples/advanced/xgboost/tree-based/requirements.txt b/examples/advanced/xgboost/tree-based/requirements.txt index 02353a0589..96c88f1ec4 100644 --- a/examples/advanced/xgboost/tree-based/requirements.txt +++ b/examples/advanced/xgboost/tree-based/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc pandas xgboost scikit-learn diff --git a/examples/advanced/xgboost/utils/prepare_job_config.py b/examples/advanced/xgboost/utils/prepare_job_config.py index f970faa016..28e5652a63 100644 --- a/examples/advanced/xgboost/utils/prepare_job_config.py +++ b/examples/advanced/xgboost/utils/prepare_job_config.py @@ -20,6 +20,8 @@ from nvflare.apis.fl_constant import JobConstants +SCRIPT_PATH = pathlib.Path(os.path.realpath(__file__)) +XGB_EXAMPLE_ROOT = SCRIPT_PATH.parent.parent.absolute() JOB_CONFIGS_ROOT = "jobs" MODE_ALGO_MAP = {"bagging": "tree-based", "cyclic": "tree-based", "histogram": "histogram-based"} @@ -84,7 +86,7 @@ def _get_src_job_dir(training_mode): "cyclic": "cyclic_base", "histogram": "base", } - return pathlib.Path(MODE_ALGO_MAP[training_mode]) / JOB_CONFIGS_ROOT / base_job_map[training_mode] + return XGB_EXAMPLE_ROOT / MODE_ALGO_MAP[training_mode] / JOB_CONFIGS_ROOT / base_job_map[training_mode] def _gen_deploy_map(num_sites: int, site_name_prefix: str) -> dict: @@ -133,6 +135,7 @@ def _update_client_config(config: dict, args, lr_scale, site_name: str): num_client_bagging = args.site_num config["executors"][0]["executor"]["args"]["num_client_bagging"] = num_client_bagging else: + config["num_rounds"] = args.round_num config["components"][0]["args"]["data_split_filename"] = data_split_name config["executors"][0]["executor"]["args"]["xgb_params"]["nthread"] = args.nthread config["executors"][0]["executor"]["args"]["xgb_params"]["tree_method"] = args.tree_method @@ -140,10 +143,10 @@ def _update_client_config(config: dict, args, lr_scale, site_name: str): def _update_server_config(config: dict, args): if args.training_mode == "bagging": - config["workflows"][0]["args"]["num_rounds"] = args.round_num + 1 + config["num_rounds"] = args.round_num + 1 config["workflows"][0]["args"]["min_clients"] = args.site_num elif args.training_mode == "cyclic": - config["workflows"][0]["args"]["num_rounds"] = int(args.round_num / args.site_num) + config["num_rounds"] = int(args.round_num / args.site_num) def _copy_custom_files(src_job_path, src_app_name, dst_job_path, dst_app_name): @@ -198,7 +201,7 @@ def main(): src_job_path = _get_src_job_dir(args.training_mode) # create a new job - dst_job_path = pathlib.Path(MODE_ALGO_MAP[args.training_mode]) / JOB_CONFIGS_ROOT / job_name + dst_job_path = XGB_EXAMPLE_ROOT / MODE_ALGO_MAP[args.training_mode] / JOB_CONFIGS_ROOT / job_name if not os.path.exists(dst_job_path): os.makedirs(dst_job_path) diff --git a/examples/hello-world/hello-ccwf/requirements.txt 
b/examples/hello-world/hello-ccwf/requirements.txt new file mode 100644 index 0000000000..e4605852b5 --- /dev/null +++ b/examples/hello-world/hello-ccwf/requirements.txt @@ -0,0 +1 @@ +nvflare~=2.4.0rc diff --git a/examples/hello-world/hello-cyclic/README.md b/examples/hello-world/hello-cyclic/README.md index 84d81eda1f..60c8c632ea 100644 --- a/examples/hello-world/hello-cyclic/README.md +++ b/examples/hello-world/hello-cyclic/README.md @@ -27,8 +27,8 @@ bash ./prepare_data.sh Use nvflare simulator to run the hello-examples: -``` -nvflare simulator -w /tmp/nvflare/ -n 2 -t 2 hello-cyclic/jobs/hello-cyclic +```bash +nvflare simulator -w /tmp/nvflare/ -n 2 -t 2 ./jobs/hello-cyclic ``` ### 3. Access the logs and results @@ -40,3 +40,27 @@ $ ls /tmp/nvflare/simulate_job/ app_server app_site-1 app_site-2 log.txt ``` + +### 4. Notes on running with GPUs + +For running with GPUs, we recommend using +[NVIDIA TensorFlow docker](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tensorflow) + +If you choose to run the example using GPUs, it is important to note that, +by default, TensorFlow will attempt to allocate all available GPU memory at the start. +In scenarios where multiple clients are involved, you have a couple of options to address this. + +One approach is to include specific flags to prevent TensorFlow from allocating all GPU memory. +For instance: + +```bash +TF_FORCE_GPU_ALLOW_GROWTH=true nvflare simulator -w /tmp/nvflare/ -n 2 -t 2 ./jobs/hello-cyclic +``` + +If you possess more GPUs than clients, +an alternative strategy is to run one client on each GPU. +This can be achieved as illustrated below: + +```bash +TF_FORCE_GPU_ALLOW_GROWTH=true nvflare simulator -w /tmp/nvflare/ -n 2 -gpu 0,1 ./jobs/hello-cyclic +``` diff --git a/examples/hello-world/hello-cyclic/requirements.txt b/examples/hello-world/hello-cyclic/requirements.txt index 0f8ce21d90..8f8b6bc27b 100644 --- a/examples/hello-world/hello-cyclic/requirements.txt +++ b/examples/hello-world/hello-cyclic/requirements.txt @@ -1,2 +1,2 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc tensorflow diff --git a/examples/hello-world/hello-numpy-cross-val/requirements.txt b/examples/hello-world/hello-numpy-cross-val/requirements.txt index 3fdbf10587..e4605852b5 100644 --- a/examples/hello-world/hello-numpy-cross-val/requirements.txt +++ b/examples/hello-world/hello-numpy-cross-val/requirements.txt @@ -1 +1 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc diff --git a/examples/hello-world/hello-numpy-sag/requirements.txt b/examples/hello-world/hello-numpy-sag/requirements.txt index 3fdbf10587..e4605852b5 100644 --- a/examples/hello-world/hello-numpy-sag/requirements.txt +++ b/examples/hello-world/hello-numpy-sag/requirements.txt @@ -1 +1 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc diff --git a/examples/hello-world/hello-pt/requirements.txt b/examples/hello-world/hello-pt/requirements.txt index 05e7ad671e..265102f82c 100644 --- a/examples/hello-world/hello-pt/requirements.txt +++ b/examples/hello-world/hello-pt/requirements.txt @@ -1 +1,3 @@ -nvflare[PT]>=2.3.0 +nvflare~=2.4.0rc +torch +torchvision diff --git a/examples/hello-world/hello-tf2/README.md b/examples/hello-world/hello-tf2/README.md index 5bdeb6a0e5..3bad827f25 100644 --- a/examples/hello-world/hello-tf2/README.md +++ b/examples/hello-world/hello-tf2/README.md @@ -26,10 +26,10 @@ Prepare the data first: bash ./prepare_data.sh ``` -Use nvflare simulator to run the hello-examples: (TF2 does not allow multiple processes to be running on a single GPU at the same time. Need to set the simulator threads to 1. 
"-gpu" option can be used to run multiple concurrent clients.) +Use nvflare simulator to run the hello-examples: -``` -nvflare simulator -w /tmp/nvflare/ -n 2 -t 1 hello-tf2/jobs/hello-tf2 +```bash +nvflare simulator -w /tmp/nvflare/ -n 2 -t 2 ./jobs/hello-tf2 ``` ### 3. Access the logs and results @@ -41,3 +41,27 @@ $ ls /tmp/nvflare/simulate_job/ app_server app_site-1 app_site-2 log.txt ``` + +### 4. Notes on running with GPUs + +For running with GPUs, we recommend using +[NVIDIA TensorFlow docker](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tensorflow) + +If you choose to run the example using GPUs, it is important to note that, +by default, TensorFlow will attempt to allocate all available GPU memory at the start. +In scenarios where multiple clients are involved, you have a couple of options to address this. + +One approach is to include specific flags to prevent TensorFlow from allocating all GPU memory. +For instance: + +```bash +TF_FORCE_GPU_ALLOW_GROWTH=true nvflare simulator -w /tmp/nvflare/ -n 2 -t 2 ./jobs/hello-tf2 +``` + +If you possess more GPUs than clients, +an alternative strategy is to run one client on each GPU. +This can be achieved as illustrated below: + +```bash +TF_FORCE_GPU_ALLOW_GROWTH=true nvflare simulator -w /tmp/nvflare/ -n 2 -gpu 0,1 ./jobs/hello-tf2 +``` diff --git a/examples/hello-world/hello-tf2/requirements.txt b/examples/hello-world/hello-tf2/requirements.txt index 0f8ce21d90..8f8b6bc27b 100644 --- a/examples/hello-world/hello-tf2/requirements.txt +++ b/examples/hello-world/hello-tf2/requirements.txt @@ -1,2 +1,2 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc tensorflow diff --git a/examples/hello-world/ml-to-fl/np/README.md b/examples/hello-world/ml-to-fl/np/README.md index c6dea03fc7..1de78e0e5a 100644 --- a/examples/hello-world/ml-to-fl/np/README.md +++ b/examples/hello-world/ml-to-fl/np/README.md @@ -1,4 +1,4 @@ -# Configurations of NVFlare Client API +# NVFlare Client API We will demonstrate how to send back model parameters or model parameters differences in different approaches in the following examples: @@ -18,6 +18,25 @@ We demonstrate how to launch training script once and have training script keeps 1. 
[Launch once for the whole job](#launch-once-for-the-whole-job)
 
+## Software Requirements
+
+Please install the requirements first; it is suggested to install them inside a virtual environment:
+
+```bash
+pip install -r requirements.txt
+```
+
+Please also configure the job templates folder:
+
+```bash
+nvflare config -jt ../../../../job_templates/
+nvflare job list_templates
+```
+
+## Minimum Hardware Requirements
+
+1 CPU
+
 ## Send model parameters back to the NVFlare server
 
@@ -29,8 +48,6 @@ To send back the whole model parameters, we need to make sure the "params_transf
 Let reuse the job templates from [sag_np](../../../../job_templates/sag_np/):
 
 ```bash
-nvflare config -jt ../../../../job_templates/
-nvflare job list_templates
 nvflare job create -force -j ./jobs/np_param_full_transfer_full -w sag_np -sd ./code/ \
 -f config_fed_client.conf app_script=train_full.py params_transfer_type=FULL launch_once=false
 ```
diff --git a/examples/hello-world/ml-to-fl/np/requirements.txt b/examples/hello-world/ml-to-fl/np/requirements.txt
new file mode 100644
index 0000000000..e4605852b5
--- /dev/null
+++ b/examples/hello-world/ml-to-fl/np/requirements.txt
@@ -0,0 +1 @@
+nvflare~=2.4.0rc
diff --git a/examples/hello-world/ml-to-fl/pt/README.md b/examples/hello-world/ml-to-fl/pt/README.md
index 7789515df8..09c1a8b0db 100644
--- a/examples/hello-world/ml-to-fl/pt/README.md
+++ b/examples/hello-world/ml-to-fl/pt/README.md
@@ -1,5 +1,19 @@
 # PyTorch Deep Learning to Federated Learning transition with NVFlare
 
+We will demonstrate how to transform existing DL code into an FL application step-by-step:
+
+ 1. [Show a baseline training script](#the-baseline)
+ 2. [How to modify an existing training script using DL2FL Client API](#transform-cifar10-dl-training-code-to-fl-including-best-model-selection-using-client-api)
+ 3. [How to modify a structured script using DL2FL decorator](#the-decorator-use-case)
+ 4. [How to modify a PyTorch Lightning script using DL2FL Lightning Client API](#transform-cifar10-pytorch-lightning-training-code-to-fl-with-nvflare-client-lightning-integration-api)
+
+If you have multiple GPUs, please refer to the following examples:
+
+ 1. [How to modify a PyTorch DDP training script using DL2FL Client API](#transform-cifar10-pytorch--ddp-training-code-to-fl-using-client-api)
+ 2. [How to modify a PyTorch Lightning DDP training script using DL2FL Lightning Client API](#transform-cifar10-pytorch-lightning--ddp-training-code-to-fl-with-nvflare-client-lightning-integration-api)
+
+## Software Requirements
+
 Please install the requirements first, it is suggested to install inside a virtual environment:
 
 ```bash
@@ -13,17 +27,22 @@ nvflare config -jt ../../../../job_templates/
 nvflare job list_templates
 ```
 
-We will demonstrate how to transform an existing DL code into an FL application step-by-step:
+## Minimum Hardware Requirements
 
- 1. [Show a baseline training script](#the-baseline)
- 2. [How to modify an existing training script using DL2FL Client API](#transform-cifar10-dl-training-code-to-fl-including-best-model-selection-using-client-api)
- 3. [How to modify a structured script using DL2FL decorator](#the-decorator-use-case)
- 4. 
[How to modify a PyTorch Lightning script using DL2FL Lightning Client API](#transform-cifar10-pytorch-lightning-training-code-to-fl-with-nvflare-client-lightning-integration-api)
+Each example has different requirements:
 
-If you have multi GPU please refer to the following examples:
+| Example name | minimum requirements |
+| ------------ | -------------------- |
+| [Show a baseline training script](#the-baseline) | 1 CPU or 1 GPU* |
+| [How to modify an existing training script using DL2FL Client API](#transform-cifar10-dl-training-code-to-fl-including-best-model-selection-using-client-api) | 1 CPU or 1 GPU* |
+| [How to modify a structured script using DL2FL decorator](#the-decorator-use-case) | 1 CPU or 1 GPU* |
+| [How to modify a PyTorch Lightning script using DL2FL Lightning Client API](#transform-cifar10-pytorch-lightning-training-code-to-fl-with-nvflare-client-lightning-integration-api) | 1 CPU or 1 GPU* |
+| [How to modify a PyTorch DDP training script using DL2FL Client API](#transform-cifar10-pytorch--ddp-training-code-to-fl-using-client-api) | 2 GPUs |
+| [How to modify a PyTorch Lightning DDP training script using DL2FL Lightning Client API](#transform-cifar10-pytorch-lightning--ddp-training-code-to-fl-with-nvflare-client-lightning-integration-api) | 2 CPUs or 2 GPUs** |
 
- 1. [How to modify a PyTorch DDP training script using DL2FL Client API](#transform-cifar10-pytorch--ddp-training-code-to-fl-using-client-api)
- 2. [How to modify a PyTorch Lightning DDP training script using DL2FL Lightning Client API](#transform-cifar10-pytorch-lightning--ddp-training-code-to-fl-with-nvflare-client-lightning-integration-api)
+
+\* it depends on whether you use `device=cpu` or `device=cuda`
+\*\* it depends on whether `torch.cuda.is_available()` is True or not
 
 ## The baseline
 
@@ -200,8 +219,6 @@ nvflare simulator -n 2 -t 2 ./jobs/lightning -w lightning_workspace
 
 We follow the official [PyTorch documentation](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html#initialize-ddp-with-torch-distributed-run-torchrun) and write a [./code/cifar10_ddp_original.py](./code/cifar10_ddp_original.py).
 
-Note that this example requires at least 2 GPUs on your machine.
-
 Note that we wrap the evaluation logic into a method for better usability. 
It can be run using the torch distributed run:
diff --git a/examples/hello-world/ml-to-fl/pt/code/cifar10_lightning_ddp_fl.py b/examples/hello-world/ml-to-fl/pt/code/cifar10_lightning_ddp_fl.py
index 62a513f548..0398c9d3b0 100644
--- a/examples/hello-world/ml-to-fl/pt/code/cifar10_lightning_ddp_fl.py
+++ b/examples/hello-world/ml-to-fl/pt/code/cifar10_lightning_ddp_fl.py
@@ -72,7 +72,9 @@ def main():
     model = LitNet()
     cifar10_dm = CIFAR10DataModule()
 
-    trainer = Trainer(max_epochs=1, strategy="ddp", devices=2 if torch.cuda.is_available() else None)
+    trainer = Trainer(
+        max_epochs=1, strategy="ddp", devices=2, accelerator="gpu" if torch.cuda.is_available() else "cpu"
+    )
 
     # (2) patch the lightning trainer
     flare.patch(trainer)
diff --git a/examples/hello-world/ml-to-fl/pt/requirements.txt b/examples/hello-world/ml-to-fl/pt/requirements.txt
index 91244a6ec2..ea496a9976 100644
--- a/examples/hello-world/ml-to-fl/pt/requirements.txt
+++ b/examples/hello-world/ml-to-fl/pt/requirements.txt
@@ -1,3 +1,5 @@
-nvflare[PT]>=2.4.0
+nvflare~=2.4.0rc
+torch
+torchvision
 jsonargparse[signatures]>=4.17.0
 pytorch_lightning
diff --git a/examples/hello-world/ml-to-fl/tf/README.md b/examples/hello-world/ml-to-fl/tf/README.md
index 80b879425b..ce3845dfe9 100644
--- a/examples/hello-world/ml-to-fl/tf/README.md
+++ b/examples/hello-world/ml-to-fl/tf/README.md
@@ -1,18 +1,37 @@
 # TensorFlow Deep Learning to Federated Learning transition with NVFlare
 
+We will demonstrate how to transform existing DL code into an FL application step-by-step:
+
+1. [How to modify an existing training script using DL2FL Client API](#transform-cifar10-tensorflow-training-code-to-fl-with-nvflare-client-api)
+
+2. 
[How to modify an existing multi GPU training script using DL2FL Client API](#transform-cifar10-tensorflow-multi-gpu-training-code-to-fl-with-nvflare-client-api)
+
+## Software Requirements
+
 Please install the requirements first, it is suggested to install inside a virtual environment:
 
 ```bash
 pip install -r requirements.txt
 ```
 
-Note that for running with GPUs, we recommend using [NVIDIA TensorFlow docker](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tensorflow)
+Please also configure the job templates folder:
 
-We will demonstrate how to transform an existing DL code into an FL application step-by-step:
+```bash
+nvflare config -jt ../../../../job_templates/
+nvflare job list_templates
+```
 
-1. [How to modify an existing training script using DL2FL Client API](#transform-cifar10-tensorflow-training-code-to-fl-with-nvflare-client-api)
+## Minimum Hardware Requirements
 
-2. [How to modify an existing multi GPU training script using DL2FL Client API](#transform-cifar10-tensorflow-multi-gpu-training-code-to-fl-with-nvflare-client-api)
+| Example name | minimum requirements |
+| ------------ | -------------------- |
+| [How to modify an existing training script using DL2FL Client API](#transform-cifar10-tensorflow-training-code-to-fl-with-nvflare-client-api) | 1 CPU or 1 GPU* |
+| [How to modify an existing multi GPU training script using DL2FL Client API](#transform-cifar10-tensorflow-multi-gpu-training-code-to-fl-with-nvflare-client-api) | 2 CPUs or 2 GPUs* |
+
+\* depends on whether TF can find a GPU or not
+
+
+For running with GPUs, please check the [note](#notes-on-running-with-gpus).
 
 ## Transform CIFAR10 TensorFlow training code to FL with NVFLARE Client API
 
@@ -46,7 +65,6 @@ Please refer to [JOB CLI tutorial](../../../tutorials/job_cli.ipynb) on how to g
 We choose the [tensorflow job template](../../../../job_templates/sag_tf/) and run the following command to create the job:
 
 ```bash
-nvflare config -jt ../../../../job_templates
 nvflare job create -force -j ./jobs/tensorflow -w sag_tf -sd ./code/ -f config_fed_client.conf app_script=cifar10_tf_fl.py
 ```
 
@@ -82,7 +100,6 @@ Please refer to [JOB CLI tutorial](../../../tutorials/job_cli.ipynb) on how to g
 We choose the [tensorflow job template](../../../../job_templates/sag_tf/) and run the following command to create the job:
 
 ```bash
-nvflare config -jt ../../../../job_templates
 nvflare job create -force -j ./jobs/tensorflow_multi_gpu -w sag_tf -sd ./code/ -f config_fed_client.conf app_script=cifar10_tf_multi_gpu_fl.py
 ```
 
@@ -90,7 +107,27 @@ Then we can run the job using the simulator:
 
 ```bash
 bash ./prepare_data.sh
-TF_GPU_ALLOCATOR=cuda_malloc_async nvflare simulator -n 2 -t 2 ./jobs/tensorflow_multi_gpu -w tensorflow_multi_gpu_workspace
+nvflare simulator -n 2 -t 2 ./jobs/tensorflow_multi_gpu -w tensorflow_multi_gpu_workspace
+```
+
+## Notes on running with GPUs
+
+
+If you choose to run the example using GPUs, it is important to note that,
+by default, TensorFlow will attempt to allocate all available GPU memory at the start.
+In scenarios where multiple clients are involved, you have a couple of options to address this.
+
+One approach is to include specific flags to prevent TensorFlow from allocating all GPU memory.
+For instance:
+
+```bash
+TF_FORCE_GPU_ALLOW_GROWTH=true TF_GPU_ALLOCATOR=cuda_malloc_async nvflare simulator -n 2 -t 2 ./jobs/tensorflow_multi_gpu -w tensorflow_multi_gpu_workspace
 ```
 
-Note that the flag "TF_GPU_ALLOCATOR=cuda_malloc_async" is only needed if you are going to run more than one process in the same GPU.
+If you possess more GPUs than clients,
+an alternative strategy is to run one client on each GPU. 
diff --git a/examples/hello-world/ml-to-fl/tf/requirements.txt b/examples/hello-world/ml-to-fl/tf/requirements.txt index 573f902d09..8f8b6bc27b 100644 --- a/examples/hello-world/ml-to-fl/tf/requirements.txt +++ b/examples/hello-world/ml-to-fl/tf/requirements.txt @@ -1,2 +1,2 @@ -nvflare>=2.4.0 +nvflare~=2.4.0rc tensorflow diff --git a/examples/hello-world/step-by-step/README.md b/examples/hello-world/step-by-step/README.md index 4e563122b2..59a9c487b0 100644 --- a/examples/hello-world/step-by-step/README.md +++ b/examples/hello-world/step-by-step/README.md @@ -1,60 +1,22 @@ # Step-by-Step Examples -When given a machine learning problem, we probably wonder, where do we start to formulate the federated learning problem. +To run the notebooks in each example, please make sure you first set up a virtual environment, install "./requirements.txt", and set up JupyterLab following the [example root readme](../README.md). -* What does the data look like? -* How do we compare global statistics with the site's local data statistics? -* How to formulate the federated algorithms - * https://developer.download.nvidia.com/healthcare/clara/docs/federated_traditional_machine_learning_algorithms.pdf -* Given the formulation, how to convert the existing machine learning or deep learning code to Federated learning code. - * [ML to FL examples](https://github.com/NVIDIA/NVFlare/blob/main/examples/hello-world/ml-to-fl/README.md) -* For different types of federated learning workflows: Scatter and Gather, Cyclic Weight Transfer, Swarming learning, -Vertical learning, ..., what do we need to change ? -* Further how can apply the experiment log, so all sites' metrics and global metrics can be viewed -* in experiment tracking tools such as Weights & Biases, MLFLow, or simply Tensorboard - -In this "step-by-step" examples, we will dive these questions in two series of examples: - -## Multi-class classification with image data using CIFAR10 dataset - -The CIFAR10 dataset has the following 10 classes: ‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’. -The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 pixels in size. - -![image](cifar10/data/cifar10.png) - -We will use the [pytorch](https://pytorch.org/) deep learning framework to illustrate how to formulate and convert the deep learning training -program to a federated learning training program. The example will include: - -* Federated Histogram analysis with Federated Statistics -* Scatter and Gather (SAG) workflow with NVFLARE Client APIs -* Cyclic Weight Transfer workflow with NVFLARE Client APIs -* Swarm Learning Workflow with NVFLARE Client APIs -* SAG with NVFLARE model learner APIs -* SAG with NVFLARE Executor APIs -* SAG with NVFLARE Client APIs + MLflow +* [cifar10](cifar10) - Multi-class classification with image data using CIFAR10 dataset +* [higgs](higgs) - Binary classification with tabular data using HIGGS dataset +These step-by-step example series are aimed at helping users quickly get started and learn about FLARE. +For consistency, each example in the series uses the same dataset: CIFAR10 for image data and the HIGGS dataset for tabular data.
+The examples will build upon previous ones to showcase different features, workflows, or APIs, allowing users to gain a comprehensive understanding of FLARE functionalities (Note: each example is self-contained, so going through them in order is not required, but recommended). See the README in each directory for more details about each series. -## Binary classification with tabular data using HIGGS dataset +## Common Questions -### HIGGS Dataset +Here are some common questions we aim to cover in this example series when formulating a federated learning problem: -[HIGGS dataset](https://archive.ics.uci.edu/dataset/280/higgs) contains 11 million instances, each with 28 attributes, for binary classification to predict whether an event corresponds to the decayment of a Higgs boson or not. - -The first 21 features (columns 2-22) are kinematic properties measured by the particle detectors in the accelerator. -The data has been produced using Monte Carlo simulations. The first 21 features are kinematic properties measured by the particle detectors in the accelerator. The last 7 features are functions of the first 21 features; these are high-level features derived by physicists to help discriminate between the two classes. - -Please note that the [UCI's website](https://archive.ics.uci.edu/dataset/280/higgs) may experience occasional downtime. - -With the HIGGs Dataset, we like to demonstrate traditional machine learning techniques in federated learning. -These include: - -* Federated Statistics for tabular data -* Federated Logistic Regression -* Federated Kmeans -* Federated SVM -* Federated Horizontal XGBoost - -These examples demostrate: -* How to use the NVFlare Client APIs to convert the traditional machine learning code to federated learning code. Most of them contains local training scripts as baselines for comparison. -* How different machine learning methods can be applied to the same problem. Different behaviors and accuracies can be observed, as a reference for choosing the right method for the problem. -* How federated learning impacts different machine learning methods. Some methods are more sensitive to the federated learning process, and some are less. +* What does the data look like? +* How do we compare global statistics with the site's local data statistics? +* How to formulate the [federated algorithms](https://developer.download.nvidia.com/healthcare/clara/docs/federated_traditional_machine_learning_algorithms.pdf)? +* How do we convert the existing machine learning or deep learning code to federated learning code? [ML to FL examples](https://github.com/NVIDIA/NVFlare/blob/main/examples/hello-world/ml-to-fl/README.md) +* How do we use different types of federated learning workflows (e.g. Scatter and Gather, Cyclic Weight Transfer, Swarm Learning, Vertical learning) and what do we need to change?
+* How can we capture the experiment log, so all sites' metrics and global metrics can be viewed in experiment tracking tools such as Weights & Biases, MLflow, or TensorBoard diff --git a/examples/hello-world/step-by-step/cifar10/README.md b/examples/hello-world/step-by-step/cifar10/README.md index 245542096c..d90d7471ec 100644 --- a/examples/hello-world/step-by-step/cifar10/README.md +++ b/examples/hello-world/step-by-step/cifar10/README.md @@ -1,5 +1,5 @@ -# Training a image classifier with CIFAR10 data +# Training an image classifier with the CIFAR10 dataset We will use the original [Training a Classifer](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) example in pytorch as the code base. @@ -9,12 +9,16 @@ The images in CIFAR-10 are of size 3x32x32, i.e. 3-channel color images of 32x32 ![image](data/cifar10.png) -In the follow examples, we will show various Federated Learning workflows. +In the following examples, we will show various Federated Learning workflows and features: -* [Image intensity histogram caculation](stats) -* Scatter and Gather (SAG) workflow with NVFLARE Client APIs -* Cyclic Weight Transfer workflow with NVFLARE Client APIs -* Swarm Learning Workflow with NVFLARE Client APIs -* SAG with NVFLARE model learner APIs -* SAG with NVFLARE Executor APIs -* SAG with NVFLARE Client APIs + MLflow +* [stats](stats) - Federated statistics image intensity histogram calculation. +* [sag](sag) - Scatter and Gather (SAG) workflow with Client API +* [sag_deploy_map](sag_deploy_map) - SAG with deploy_map configuration for deployment of apps to different sites. +* [sag_model_learner](sag_model_learner) - SAG with Model Learner API +* [sag_executor](sag_executor) - SAG with Executor API +* [sag_mlflow](sag_mlflow) - SAG with MLflow experiment tracking logs. +* [sag_he](sag_he) - SAG with homomorphic encryption using POC -he mode. +* [cse](cse) - Cross-site evaluation with server-side controller. +* [cyclic](cyclic) - Cyclic Weight Transfer (cyclic) workflow with server-side controller. +* [cyclic_ccwf](cyclic_ccwf) - Client-controlled cyclic workflow with client-side controller. +* [swarm](swarm) - Swarm learning and client-controlled cross-site evaluation. diff --git a/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb b/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb index 176fc875f7..a25c7ffd32 100644 --- a/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb +++ b/examples/hello-world/step-by-step/cifar10/cse/cse.ipynb @@ -180,9 +180,9 @@ "id": "48271064", "metadata": {}, "source": [ - "For additional resources, see other examples for SAG with CSE using the [ModelLearner](../sag_model_learner/sag_model_learner.ipynb), [Executor](../sag_executor/sag_executor.ipynb), and [Hello-Numpy](https://github.com/NVIDIA/NVFlare/tree/main/examples/hello-world/hello-numpy-cross-val).\n", + "For additional resources, see other examples for SAG with CSE using the [ModelLearner](../sag_model_learner/sag_model_learner.ipynb) and [Executor](../sag_executor/sag_executor.ipynb). [Hello-Numpy](https://github.com/NVIDIA/NVFlare/tree/main/examples/hello-world/hello-numpy-cross-val) also demonstrates how to run cross-site evaluation using the previous training results.\n", "\n", - "Also the ability to run Cross-site Evaluation without having to re-run training will be added in the near future." + "Next, we will look at the [cyclic](../cyclic/cyclic.ipynb) example, which shows the cyclic workflow for the Cyclic Weight Transfer algorithm."
] }, { diff --git a/examples/hello-world/step-by-step/cifar10/cyclic/cyclic.ipynb b/examples/hello-world/step-by-step/cifar10/cyclic/cyclic.ipynb index 153a4b6468..5dbf03dd2c 100644 --- a/examples/hello-world/step-by-step/cifar10/cyclic/cyclic.ipynb +++ b/examples/hello-world/step-by-step/cifar10/cyclic/cyclic.ipynb @@ -140,7 +140,10 @@ "id": "48271064", "metadata": {}, "source": [ - "As an additional resource, also see the [hello-cyclic](../../../../hello-world/hello-cyclic/README.md) for a Tensorflow Executor implementation using the MNIST dataset." + "As an additional resource, also see the [hello-cyclic](../../../../hello-world/hello-cyclic/README.md) for a Tensorflow Executor implementation using the MNIST dataset.\n", + "\n", + "While this example focused on the server-controlled cyclic workflow, now we will introduce the idea of client-controlled workflows.\n", + "The next [cyclic_ccwf](../cyclic_ccwf/cyclic_ccwf.ipynb) example is a client-controlled version of the cyclic workflow." ] }, { diff --git a/examples/hello-world/step-by-step/cifar10/cyclic_ccwf/cyclic_ccwf.ipynb b/examples/hello-world/step-by-step/cifar10/cyclic_ccwf/cyclic_ccwf.ipynb index f90ce77d13..778943998e 100644 --- a/examples/hello-world/step-by-step/cifar10/cyclic_ccwf/cyclic_ccwf.ipynb +++ b/examples/hello-world/step-by-step/cifar10/cyclic_ccwf/cyclic_ccwf.ipynb @@ -145,7 +145,9 @@ "cell_type": "markdown", "id": "9bef3134", "metadata": {}, - "source": [] + "source": [ + "Lastly, we have the [swarm](../swarm/swarm.ipynb) example, which covers swarm learning and client-controlled cross-site evaluation workflows." + ] } ], "metadata": { diff --git a/examples/hello-world/step-by-step/cifar10/sag/fed_avg_one_round.png b/examples/hello-world/step-by-step/cifar10/sag/fed_avg_one_round.png new file mode 100644 index 0000000000..fea79a329b Binary files /dev/null and b/examples/hello-world/step-by-step/cifar10/sag/fed_avg_one_round.png differ diff --git a/examples/hello-world/step-by-step/cifar10/sag/mpi_gather.png b/examples/hello-world/step-by-step/cifar10/sag/mpi_gather.png new file mode 100644 index 0000000000..0c01ee9a6d Binary files /dev/null and b/examples/hello-world/step-by-step/cifar10/sag/mpi_gather.png differ diff --git a/examples/hello-world/step-by-step/cifar10/sag/mpi_scatter.png b/examples/hello-world/step-by-step/cifar10/sag/mpi_scatter.png new file mode 100644 index 0000000000..421507a339 Binary files /dev/null and b/examples/hello-world/step-by-step/cifar10/sag/mpi_scatter.png differ diff --git a/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb b/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb index 64466a781e..f29f41249e 100644 --- a/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb +++ b/examples/hello-world/step-by-step/cifar10/sag/sag.ipynb @@ -10,19 +10,32 @@ "# FedAvg Algorithm with SAG (Scatter & Gather) workflow\n", "\n", "\n", - "In this example, we will demonstrate the SAG workflow with FedAvg using CIFAR10 dataset. \n", + "In this example, we will demonstrate the SAG workflow with FedAvg using the CIFAR10 dataset.\n", "\n", - "Both Job Lifecycle and training workflow are controlled on the **server side**, we will just use the existing available SAG controller availalbe in NVFLARE. 
\n", + "Both Job Lifecycle and training workflow are controlled on the server side; we will just use the existing available SAG controller available in NVFLARE.\n", + "\n", + "For client-side training code, we will leverage the new DL to FL Client API.\n", + "\n", + "First, let's look at the FedAvg Algorithm and SAG Workflow.\n", + "\n", + "\n", + "## Scatter and Gather (SAG)\n", + "\n", + "FLARE's Scatter and Gather workflow is similar to the Message Passing Interface (MPI)'s MPI Broadcast + MPI Gather. [MPI](https://en.wikipedia.org/wiki/Message_Passing_Interface) is a standardized and portable message-passing standard designed to function on parallel computing architectures. MPI consists of some [collective communication routines](https://mpitutorial.com/tutorials/mpi-broadcast-and-collective-communication/), such as MPI Broadcast, MPI Scatter, and MPI Gather.\n", + "\n", + "\"scatter\"\"gather\"\n", "\n", - "For client side training code, we will leverage new DL to FL **Client API**\n", "\n", - "First, Let's look at the FedAvg Algorithm and SAG Workflow. \n", "\n", "## FedAvg with SAG\n", + "We use [SAG workflow](https://nvflare.readthedocs.io/en/main/programming_guide/controllers/scatter_and_gather_workflow.html) to implement the FedAvg algorithm. You can see one round of training in such workflow.\n", + "\n", + "\"FedAvg\"\n", + "\n", "\n", "\"FedAvg\" \"Scatter\n", "\n", - "The Fed Avg aggregation is done on the server side, its weighted on the number of training steps on each client\n", + "The FedAvg aggregation is done on the server side, its weighted on the number of training steps on each client\n", " \n", "## Convert training code to federated learning training code\n", "\n", @@ -91,12 +104,11 @@ "source": [ "## Job Folder and Configurations\n", "\n", + " \n", + "Now we need to set up the configurations for the server and clients and construct the Job folder NVFLARE needs to run. We can do this using NVFLARE job CLI. You can study the [Job CLI tutorials](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb) later with all the details. But for now, you can just use the following commands to find out the available job templates.\n", "\n", - "Now we need to setup the configurations for server and clients and constructure Job folder NVFLARE needed to run. We can do this using NVFLARE job CLI. You can study the [Job CLI tutorials](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb) later with all the details. But for now, you can just use the following commands\n", - "\n", - "* Find out the available job templates\n", - "\n", - "We need to set the job templates directory, so the job cli commands can find the job templates. If have already set NVFLARE_HOME=``` ```then, you can skipt the folllowing step. \n" + "We need to set the job templates directory so the job CLI commands can find the job templates. If you have already set `NVFLARE_HOME` to ``, then you can skip the following step.\n", + "\n" ] }, { @@ -130,7 +142,7 @@ "source": [ "* Create job folder and initial configs\n", "\n", - "The template **'sag_pt'** seems to fit our needs: SAG with pytorch, using client API. Lets create a job folder with this template initially without specifying the code location, just see what's needs to be changed" + "The template **'sag_pt'** seems to fit our needs: SAG with PyTorch, using the client API. 
" \n", "## Convert training code to federated learning training code\n", @@ -91,12 +104,11 @@ "source": [ "## Job Folder and Configurations\n", "\n", + " \n", + "Now we need to set up the configurations for the server and clients and construct the Job folder NVFLARE needs to run. We can do this using the NVFLARE Job CLI. You can study the [Job CLI tutorials](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb) later with all the details. But for now, you can just use the following commands to find out the available job templates.\n", - "Now we need to setup the configurations for server and clients and constructure Job folder NVFLARE needed to run. We can do this using NVFLARE job CLI. You can study the [Job CLI tutorials](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb) later with all the details. But for now, you can just use the following commands\n", - "\n", - "* Find out the available job templates\n", - "\n", - "We need to set the job templates directory, so the job cli commands can find the job templates. If have already set NVFLARE_HOME=``` ```then, you can skipt the folllowing step. \n" + "We need to set the job templates directory so the job CLI commands can find the job templates. If you have already set `NVFLARE_HOME` to ``, then you can skip the following step.\n", + "\n" ] }, { @@ -130,7 +142,7 @@ "source": [ "* Create job folder and initial configs\n", "\n", - "The template **'sag_pt'** seems to fit our needs: SAG with pytorch, using client API. Lets create a job folder with this template initially without specifying the code location, just see what's needs to be changed" + "The template **'sag_pt'** seems to fit our needs: SAG with PyTorch, using the client API. Let's create a job folder with this template initially without specifying the code location, just to see what needs to be changed.\n" ] }, { @@ -182,10 +194,9 @@ "id": "dbb79b5c-f97f-472f-91a5-5a5175fb9759", "metadata": {}, "source": [ - "* Create job folder with all the configs\n", + "* Create a job folder with all the configurations.\n", "\n", - "Let's change the num_rounds = 5, script = train.py, min_clients = 2 for meta.conf. We also like to change the arguments for train.py \n", "dataset_path=CIFAR10_ROOT, batch_size=6, num_workers = 2. Here dataset_path is actually not changed, but we just want to show you could change. " + "Let's change the `num_rounds` to 5, `script` to `train.py`, and `min_clients` to 2 in `meta.conf`. We also want to change the arguments for `train.py`: `dataset_path=CIFAR10_ROOT`, `batch_size=6`, `num_workers=2`. Note that the `dataset_path` is not actually changed, but we just want to show you that it could be changed.\n" ] }, { @@ -251,15 +262,18 @@ "id": "b055bde7-432d-4e6b-9163-b5ab7ede7b73", "metadata": {}, "source": [ - "The job should be running in the simulator mode. We are done with the training. " + "The job should be running in the simulator mode. We are done with the training. \n", + "\n", + "The next 5 examples will use the same ScatterAndGather workflow, but will demonstrate different execution APIs and features.\n", + "In the next example [sag_deploy_map](../sag_deploy_map/sag_deploy_map.ipynb), we will learn about the deploy_map configuration for deployment of apps to different sites." ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "nvflare_example", "language": "python", - "name": "python3" + "name": "nvflare_example" }, "language_info": { "codemirror_mode": { diff --git a/examples/hello-world/step-by-step/cifar10/sag_with_deploy_map/sag_deploy_map.ipynb b/examples/hello-world/step-by-step/cifar10/sag_deploy_map/sag_deploy_map.ipynb similarity index 97% rename from examples/hello-world/step-by-step/cifar10/sag_with_deploy_map/sag_deploy_map.ipynb rename to examples/hello-world/step-by-step/cifar10/sag_deploy_map/sag_deploy_map.ipynb index 91d4beef0d..b59dbad293 100644 --- a/examples/hello-world/step-by-step/cifar10/sag_with_deploy_map/sag_deploy_map.ipynb +++ b/examples/hello-world/step-by-step/cifar10/sag_deploy_map/sag_deploy_map.ipynb @@ -261,7 +261,10 @@ "id": "0af8036f-1f94-426d-8eb7-6e8b9be70a7e", "metadata": {}, "source": [ - "The job should be running in the simulator mode. We are done with the training. " + "The job should be running in the simulator mode. We are done with the training. \n", + "\n", + "In the next example [sag_model_learner](../sag_model_learner/sag_model_learner.ipynb), we will illustrate how to use the Model Learner API instead of the Client API,\n", + "and highlight why and when to use it." ] } ], diff --git a/examples/hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb b/examples/hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb index aed5f6ba74..e71dd87cbf 100644 --- a/examples/hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb +++ b/examples/hello-world/step-by-step/cifar10/sag_executor/sag_executor.ipynb @@ -222,7 +222,12 @@ "id": "48271064", "metadata": {}, "source": [ - "For additional resources, take a look at the various other executors with different use cases in the app_common, app_opt, and examples folder."
+ "For additional resources, take a look at the various other executors with different use cases in the app_common, app_opt, and examples folder.\n", + "\n", + "In the previous examples we have finished covering each of Execution API types: the Client API, Model Learner, and Executor.\n", + "Now we will be using the Client API in future examples to highlight other features and workflows.\n", + "\n", + "Next we have the [sag_mlflow](../sag_mlflow/sag_mlflow.ipynb) example, which shows how to enable MLflow experiment tracking logs." ] }, { diff --git a/examples/hello-world/step-by-step/cifar10/sag_he/sag_he.ipynb b/examples/hello-world/step-by-step/cifar10/sag_he/sag_he.ipynb index 12936dd208..c80f7d37af 100644 --- a/examples/hello-world/step-by-step/cifar10/sag_he/sag_he.ipynb +++ b/examples/hello-world/step-by-step/cifar10/sag_he/sag_he.ipynb @@ -197,7 +197,10 @@ "id": "b19da336", "metadata": {}, "source": [ - "As an additional resource, see the [CIFAR10 Real World Example](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) for creating a secure workspace for HE using provisioning instead of POC mode." + "As an additional resource, see the [CIFAR10 Real World Example](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/cifar10/cifar10-real-world) for creating a secure workspace for HE using provisioning instead of POC mode.\n", + "\n", + "Now we will begin to take a look at other workflows besides ScatterAndGather.\n", + "First we have the [cse](../cse/cse.ipynb) example, which shows the server-controlled cross-site evaluation workflow." ] } ], diff --git a/examples/hello-world/step-by-step/cifar10/sag_mlflow/sag_mlflow.ipynb b/examples/hello-world/step-by-step/cifar10/sag_mlflow/sag_mlflow.ipynb index cb39afaf61..fa295c0e3b 100644 --- a/examples/hello-world/step-by-step/cifar10/sag_mlflow/sag_mlflow.ipynb +++ b/examples/hello-world/step-by-step/cifar10/sag_mlflow/sag_mlflow.ipynb @@ -183,12 +183,12 @@ ] }, { - "cell_type": "code", - "execution_count": null, - "id": "e69c9ed2-359a-4f97-820f-25e9323a4e92", + "cell_type": "markdown", + "id": "58037d1e", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "Next we will look at the [sag_he](../sag_he/sag_he.ipynb) example, which demonstrates how to enable homomorphic encryption using the POC -he mode." + ] } ], "metadata": { diff --git a/examples/hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb b/examples/hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb index 081ad00674..ae8aec8a6c 100644 --- a/examples/hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb +++ b/examples/hello-world/step-by-step/cifar10/sag_model_learner/sag_model_learner.ipynb @@ -19,7 +19,7 @@ "\n", "Key Concepts:\n", "- Learning\n", - " - `FLModel` object defines structure to containe essential information about the learning task, such as `params`, `metrics`, `meta`, etc.\n", + " - `FLModel` object defines structure to contain essential information about the learning task, such as `params`, `metrics`, `meta`, etc.\n", " - learning logic implemented in `train()` and `validate` methods, which both receive and send an `FLModel` object\n", " - return requested model via `get_model()`\n", "- Lifecycle\n", @@ -204,7 +204,9 @@ "id": "48271064", "metadata": {}, "source": [ - "As an additional resource, also see the [CIFAR10 examples](../../../../advanced/cifar10/README.md) for a comprehensive implementation of a PyTorch ModelLearner." 
+ "As an additional resource, also see the [CIFAR10 examples](../../../../advanced/cifar10/README.md) for a comprehensive implementation of a PyTorch ModelLearner.\n", + "\n", + "In the next example [sag_executor](../sag_executor/sag_executor.ipynb), we will illustrate how to use the Executor API for more specific use cases." ] }, { @@ -216,9 +218,9 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "nvflare_example", "language": "python", - "name": "python3" + "name": "nvflare_example" }, "language_info": { "codemirror_mode": { @@ -230,7 +232,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/examples/hello-world/step-by-step/cifar10/stats/image_stats.ipynb b/examples/hello-world/step-by-step/cifar10/stats/image_stats.ipynb index 0ba0d4e5fd..ede5f7b1e1 100644 --- a/examples/hello-world/step-by-step/cifar10/stats/image_stats.ipynb +++ b/examples/hello-world/step-by-step/cifar10/stats/image_stats.ipynb @@ -664,9 +664,8 @@ "\n", "If you would like to see another example of federated statistics calculations and configurations, please checkout [federated_statistics](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/federated-statistics) and [fed_stats with spleen_ct_segmentation](https://github.com/NVIDIA/NVFlare/tree/main/integration/monai/examples/spleen_ct_segmentation_sim)\n", "\n", - "Let's move on to the next example and see how can we train the image classifier using pytorch with CIFAR10 data.\n", - "\n", - "\n" + "Let's move on to the next examples and see how can we train the image classifier using pytorch with CIFAR10 data.\n", + "First we will look at the [sag](../sag/sag.ipynb) example, which illustrates how to use the ScatterAndGather workflow for FedAvg with the Client API.\n" ] } ], diff --git a/examples/hello-world/step-by-step/cifar10/stats/requirements.txt b/examples/hello-world/step-by-step/cifar10/stats/requirements.txt index 45e20cc1ee..9e0a46f617 100644 --- a/examples/hello-world/step-by-step/cifar10/stats/requirements.txt +++ b/examples/hello-world/step-by-step/cifar10/stats/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc numpy monai[itk] pandas diff --git a/examples/hello-world/step-by-step/cifar10/swarm/swarm.ipynb b/examples/hello-world/step-by-step/cifar10/swarm/swarm.ipynb index af5ba361b3..bc66e200f2 100644 --- a/examples/hello-world/step-by-step/cifar10/swarm/swarm.ipynb +++ b/examples/hello-world/step-by-step/cifar10/swarm/swarm.ipynb @@ -154,7 +154,10 @@ "id": "48271064", "metadata": {}, "source": [ - "As an additional resource, also see the [Swarm Learning Example](../../../../advanced/swarm_learning/README.md) which utilizes the CIFAR10 ModelLearner instead of the Client API." + "As an additional resource, also see the [Swarm Learning Example](../../../../advanced/swarm_learning/README.md) which utilizes the CIFAR10 ModelLearner instead of the Client API.\n", + "\n", + "Congratulations! You have completed the CIFAR10 step-by-step example series.\n", + "Next take a look at the [higgs](../../higgs/README.md) example series for how to use machine learning methods for federated learning on tabular data." 
] }, { diff --git a/examples/hello-world/step-by-step/higgs/README.md b/examples/hello-world/step-by-step/higgs/README.md index 09baafba9b..9f344b43d3 100644 --- a/examples/hello-world/step-by-step/higgs/README.md +++ b/examples/hello-world/step-by-step/higgs/README.md @@ -1,19 +1,23 @@ -# Training traditional ML classifiers with HIGGS data +# Training traditional ML classifiers with the HIGGS dataset -[HIGGS dataset](https://archive.ics.uci.edu/dataset/280/higgs) contains 11 million instances, each with 28 attributes, for binary classification to predict whether an event corresponds to the decayment of a Higgs boson or not. +The [HIGGS dataset](https://archive.ics.uci.edu/dataset/280/higgs) contains 11 million instances, each with 28 attributes, for binary classification to predict whether an event corresponds to the decay of a Higgs boson or not. Follow the [prepare_data.ipynb](prepare_data.ipynb) notebook to download the HIGGS dataset and prepare the data splits. +(Please note that the [UCI website](https://archive.ics.uci.edu/dataset/280/higgs) may experience occasional downtime.) The first 21 features (columns 2-22) are kinematic properties measured by the particle detectors in the accelerator. The data has been produced using Monte Carlo simulations. The first 21 features are kinematic properties measured by the particle detectors in the accelerator. The last 7 features are functions of the first 21 features; these are high-level features derived by physicists to help discriminate between the two classes. -Please note that the [UCI's website](https://archive.ics.uci.edu/dataset/280/higgs) may experience occasional downtime. +Key Concepts: +* How to use the NVFlare Client APIs to convert traditional machine learning code to federated learning code. Most of the examples contain local training scripts as baselines for comparison. +* How different machine learning methods can be applied to the same problem. Different behaviors and accuracies can be observed, as a reference for choosing the right method for the problem. +* How federated learning impacts different machine learning methods. Some methods are more sensitive to the federated learning process, and some are less. -With the HIGGs Dataset, in the following examples, we like to demonstrate traditional machine learning techniques in federated learning. -These include: +In the following examples, we will demonstrate traditional machine learning techniques with tabular data for federated learning: + +* [stats](stats) - Federated statistics for tabular histogram calculation. +* [sklearn-linear](sklearn-linear) - Federated linear model (logistic regression on binary classification) learning. +* [sklearn-svm](sklearn-svm) - Federated SVM model learning. +* [sklearn-kmeans](sklearn-kmeans) - Federated k-Means clustering. +* [xgboost](xgboost) - Federated horizontal XGBoost learning with bagging collaboration. -* Federated Statistics for tabular data -* Federated Logistic Regression -* Federated Kmeans -* Federated SVM -* Federated Horizontal XGBoost diff --git a/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans.ipynb b/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans.ipynb index 961ab69c66..1989af4eaa 100644 --- a/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans.ipynb +++ b/examples/hello-world/step-by-step/higgs/sklearn-kmeans/sklearn_kmeans.ipynb @@ -452,7 +452,10 @@ "HIGGS dataset is challenging for unsupervised clustering, as we can observe from the result.
As shown by the local training with same number of iterations, the score is `model homogeneity_score: 0.0049`. As compared with the FL score of `0.0068`, FL in this case still provides some benefit from the collaborative learning.\n", "\n", "## We are done !\n", - "Congratulations! you have just completed the federated k-Means clustering for tabular data. " + "Congratulations! You have just completed the federated k-Means clustering for tabular data. \n", + "\n", + "Now we will move on from scikit-learn and take a look at how to use federated XGBoost.\n", + "In the next example [xgboost](../xgboost/xgboost_horizontal.ipynb), we will show federated horizontal XGBoost learning with bagging collaboration." ] }, { diff --git a/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear.ipynb b/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear.ipynb index 6e653a86bc..462a7448bc 100644 --- a/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear.ipynb +++ b/examples/hello-world/step-by-step/higgs/sklearn-linear/sklearn_linear.ipynb @@ -454,12 +454,14 @@ "id": "ea7bbacc-b059-4f82-9785-2b22bf840ef9", "metadata": {}, "source": [ - "In this experiment, all three clients have relatively large amount data wiht homogeneous distribution, we would expect the three numbers align within reasonable variation range. \n", + "In this experiment, all three clients have a relatively large amount of data with a homogeneous distribution, so we would expect the three numbers to align within a reasonable variation range. \n", "\n", "The final result for iterative learning is `ending model AUC: 0.6352`, and one-shot learning is `local model AUC: 0.6355`, as compared with FL's `local model AUC: 0.6351`, the numbers do align.\n", "\n", "## We are done !\n", - "Congratulations! you have just completed the federated linear model for tabular data. " + "Congratulations! You have just completed the federated linear model for tabular data. \n", + "\n", + "In the next example [sklearn-svm](../sklearn-svm/sklearn_svm.ipynb), we will demonstrate training a federated SVM model." ] }, { diff --git a/examples/hello-world/step-by-step/higgs/sklearn-svm/sklearn_svm.ipynb b/examples/hello-world/step-by-step/higgs/sklearn-svm/sklearn_svm.ipynb index 29a85b8c44..850cc955c4 100644 --- a/examples/hello-world/step-by-step/higgs/sklearn-svm/sklearn_svm.ipynb +++ b/examples/hello-world/step-by-step/higgs/sklearn-svm/sklearn_svm.ipynb @@ -431,7 +431,9 @@ "The final result for local SVM learning is `model AUC: 0.6217`, as compared with FL's `model AUC: 0.6403`, this confirms our expectation.\n", "\n", "## We are done !\n", - "Congratulations! you have just completed the federated SVM model for tabular data. " + "Congratulations! You have just completed the federated SVM model for tabular data. \n", + "\n", + "In the next example [sklearn-kmeans](../sklearn-kmeans/sklearn_kmeans.ipynb), we will illustrate federated k-Means clustering." ] }, { diff --git a/examples/hello-world/step-by-step/higgs/stats/tabular_stats.ipynb b/examples/hello-world/step-by-step/higgs/stats/tabular_stats.ipynb index 8946a902ae..455309941c 100644 --- a/examples/hello-world/step-by-step/higgs/stats/tabular_stats.ipynb +++ b/examples/hello-world/step-by-step/higgs/stats/tabular_stats.ipynb @@ -293,7 +293,7 @@ "source": [ "## Create Federated Statistics Job\n", "\n", - "We are going to use NVFLARE job cli to create job.
For detailed instructions on Job CLI, please follow the [job cli tutorial](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb)\n", + "We are going to use the NVFLARE Job CLI to create a job. For detailed instructions on the Job CLI, please follow the [job cli tutorial](https://github.com/NVIDIA/NVFlare/blob/main/examples/tutorials/job_cli.ipynb)\n", "\n", "Let's check the available job templates, we are going to use one of the existing job templates and modify it to fit our needs. The job template is nothing but server and client-side job configurations." ] @@ -607,7 +610,10 @@ "## We are done !\n", "Congratulations! you have just completed the federated stats calulation for tabular data. \n", "\n", - "If you would like to see a detailed discussion regarding privacy filtering, please checkout the example in [federated statistics](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/federated-statistics) examples." + "If you would like to see a detailed discussion regarding privacy filtering, please check out the [federated statistics](https://github.com/NVIDIA/NVFlare/tree/main/examples/advanced/federated-statistics) examples.\n", + "\n", + "Let's move on to the next examples and see how we can use scikit-learn to train federated models on tabular data.\n", + "First, we will look at the [sklearn-linear](../sklearn-linear/sklearn_linear.ipynb) example, which illustrates how to train a federated linear model (logistic regression on binary classification)." ] }, { diff --git a/examples/hello-world/step-by-step/higgs/xgboost/xgboost_horizontal.ipynb b/examples/hello-world/step-by-step/higgs/xgboost/xgboost_horizontal.ipynb index 6b1702d5ca..54b35705ab 100644 --- a/examples/hello-world/step-by-step/higgs/xgboost/xgboost_horizontal.ipynb +++ b/examples/hello-world/step-by-step/higgs/xgboost/xgboost_horizontal.ipynb @@ -481,7 +481,11 @@ "Both oneshot and iterative training schemes yield idential results of `local model AUC: 0.81928`. As compared with FL's `global model AUC: 0.82085`, we can notice FL gives some benefits, even under homogeneous data distribution across clients.\n", "\n", "## We are done !\n", - "Congratulations! you have just completed the federated xgboost model for tabular data. " + "Congratulations! You have just completed the federated XGBoost model for tabular data. \n", + "\n", + "You have now completed the HIGGS step-by-step example series.\n", + "Next, either take a look at the [cifar10](../../cifar10/README.md) example series for how to train an image classifier with PyTorch, or explore the\n", "[examples/advanced](../../../../advanced/README.md) directory for more in-depth examples." ] }, { diff --git a/examples/hello-world/step-by-step/requirements.txt b/examples/hello-world/step-by-step/requirements.txt index e4bdfc07bf..3bbfea441b 100644 --- a/examples/hello-world/step-by-step/requirements.txt +++ b/examples/hello-world/step-by-step/requirements.txt @@ -1,4 +1,4 @@ -nvflare>=2.3.2 +nvflare~=2.4.0rc torch torchvision tensorboard diff --git a/examples/nvflare_setup.ipynb b/examples/nvflare_setup.ipynb index eede563c67..c911037cca 100644 --- a/examples/nvflare_setup.ipynb +++ b/examples/nvflare_setup.ipynb @@ -23,7 +23,7 @@ "source": [ "### Install NVFLARE from PyPI\n", "```\n", - "pip install 'nvflare>=2.3.0'\n", + "pip install 'nvflare~=2.4.0'\n", "\n", "```\n", "We do not recommend running NVFlare CLI commands in jupyter notebook cells."
diff --git a/examples/tutorials/flare_simulator.ipynb b/examples/tutorials/flare_simulator.ipynb index dd58e77901..89aa73df0f 100644 --- a/examples/tutorials/flare_simulator.ipynb +++ b/examples/tutorials/flare_simulator.ipynb @@ -7,7 +7,7 @@ "source": [ "## Intro to the FL Simulator\n", "\n", - "The [FL Simulator](https://nvflare.readthedocs.io/en/main/user_guide/fl_simulator.html) runs a local simulation of a running NVFLARE FL deployment. This allows researchers to test and debug an application without provisioning a real, distributed FL project. The FL Simulator runs a server and multiple clients in the same local process, with communication that mimics a real deployment. This allows researchers to more quickly build out new components and jobs that can be directly used in a production deployment.\n", + "The [FL Simulator](https://nvflare.readthedocs.io/en/latest/user_guide/nvflare_cli/fl_simulator.html) runs a local simulation of a running NVFLARE FL deployment. This allows researchers to test and debug an application without provisioning a real, distributed FL project. The FL Simulator runs a server and multiple clients in the same local process, with communication that mimics a real deployment. This allows researchers to more quickly build out new components and jobs that can be directly used in a production deployment.\n", "\n", "### Setup\n", "The NVFlare [Getting Started Guide](https://nvflare.readthedocs.io/en/main/getting_started.html) provides instructions for setting up FLARE on a local system or in a Docker image. We've also cloned the NVFlare GitHub in our top-level working directory." diff --git a/examples/tutorials/job_cli.ipynb b/examples/tutorials/job_cli.ipynb index d57d855f90..858c199069 100644 --- a/examples/tutorials/job_cli.ipynb +++ b/examples/tutorials/job_cli.ipynb @@ -15,7 +15,9 @@ "tags": [] }, "source": [ - "In this notebook, we will go through the different commands of the Job CLI to show the syntax and usage of each.\n" + "In this notebook, we will go through the different commands of the Job CLI to show the syntax and usage of each.\n", + "Refer to the [Job CLI Documentation](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/job_cli.html) for more details.\n", + "\n" ] }, { diff --git a/examples/tutorials/setup_poc.ipynb b/examples/tutorials/setup_poc.ipynb index 26b99210c1..d1a1aa0007 100644 --- a/examples/tutorials/setup_poc.ipynb +++ b/examples/tutorials/setup_poc.ipynb @@ -7,7 +7,7 @@ "source": [ "# Set Up NVFLARE in POC Mode\n", "\n", - "[POC mode](https://nvflare.readthedocs.io/en/main/user_guide/poc_command.html) allows users to test the features of a full FLARE deployment on a single machine, without the overhead of a true distributed deployment.\n", + "[POC mode](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/poc_command.html) allows users to test the features of a full FLARE deployment on a single machine, without the overhead of a true distributed deployment.\n", "\n", "Compared to the FL Simulator, where the job run is automated on a single system, POC mode allows you to establish and connect distinct server and client \"systems\" which can then be orchestrated using the FLARE Console. 
This can be useful in preparation for a distributed deployment.\n", "\n", @@ -628,23 +628,24 @@ "source": [ "### Support Homomorphic Encryption (HE)\n", "\n", - "To support HE, we need the provision process to generates Tenseal homomorphic encryption context for server and client and writes them to server and client\n", - "participant folders see [provision context](https://nvflare.readthedocs.io/en/main/programming_guide/provisioning_system.html#provision-context). This is achieved by Provision builder, in particular for HE, HEBuilder. Instead of manaully add HEBuilder to project.yml file, one can use ```-he``` in poc command\n", + "To support HE, we need the provision process to generate Tenseal homomorphic encryption context for the server and client and write them to the server and client participant folders. See [provision context](https://nvflare.readthedocs.io/en/main/programming_guide/provisioning_system.html#provision-context). This is achieved by a Provision builder, specifically the HEBuilder for HE. Instead of manually adding the HEBuilder to the `project.yml` file, one can use `-he` in the POC command.\n", + "\n", + "For example, if we use the above command with HE, we can write:\n", "\n", - "For example, if we use above command with HE, we can write as\n", "\n", "```\n", "echo 'y' | nvflare poc prepare -c hospital_1 hospital_2 -d 'nvflare/nvflare' -he\n", "\n", "```\n", - "But before you run the command, you must have the correct dependency. NVFLARE uses Tenseal library as HE dependency. By default it is optional dependency. \n", - "To use HE, you could install it with \n", + "\n", "\n", "```\n", " pip install nvflare[HE]\n", + " \n", "```\n", - "> Note\n" - " * Tenseal is not avaiable in Mac OS\n" + "\n", + "\n", + "> Note: Tenseal is not available on macOS\n" ] }, { diff --git a/integration/monai/README.md b/integration/monai/README.md index e8ef9a2be2..f1803fba72 100644 --- a/integration/monai/README.md +++ b/integration/monai/README.md @@ -9,10 +9,6 @@ Add `ClientAlgoExecutor` class to allow using MONAI's `ClientAlgo` class in fede Allow the use of bundles from the MONAI [model zoo](https://github.com/Project-MONAI/model-zoo) or custom configurations with NVFlare. -### Non-goals: - -n/a - ## Background MONAI allows the definition of AI models using the "[bundle](https://docs.monai.io/en/latest/bundle.html)" concept. It allows for easy experimentation and sharing of models that have been developed using MONAI. diff --git a/integration/monai/examples/spleen_ct_segmentation_local/README.md b/integration/monai/examples/spleen_ct_segmentation_local/README.md index 7ba9892029..bd7d481316 100644 --- a/integration/monai/examples/spleen_ct_segmentation_local/README.md +++ b/integration/monai/examples/spleen_ct_segmentation_local/README.md @@ -21,8 +21,13 @@ And go to the folder containing this tutorial To execute the below commands, please open a terminal and go to the folder containing this tutorial. -We recommend following the instructions for setting up a [virtual environment](../../../../examples/README.md#set-up-a-virtual-environment), -and using it in [JupyterLab](../../../../examples/README.md#Set-up-JupyterLab-for-notebooks) for running the notebooks the MONAI integration examples. +Follow the [setup](../../README.md#requirements) to create a virtual environment with the MONAI-NVFlare integration installed for use in JupyterLab. + +Install the required packages in your virtual environment: + +``` +pip install -r ./requirements.txt +``` ### 1.
Download the Spleen Bundle @@ -102,9 +107,9 @@ By default, POC will create startup kits at `/tmp/nvflare/poc`. ### 3.3 Start FL system in POC mode -Then, start the FL system with all provisioned clients by running +Then, in another terminal, start the FL system in POC mode with all provisioned clients by running: ``` -nvflare poc start +nvflare poc start -ex admin@nvidia.com ``` ### 4.1 (Optional) Secure FL workspace @@ -149,37 +154,26 @@ For details about resource management and consumption, please refer to the [docu > **Note:** Full FL training could take several hours for this task. > To speed up your experimentation, you can reduce the `num_rounds` value in `config_fed_server.json`, e.g. to 5 rounds. -### 5.1 FLARE-MONAI Integration Experiment tracking +### 5.1 FLARE-MONAI Integration Experiment Tracking with MLflow Experiment tracking for the FLARE-MONAI integration now uses `NVFlareStatsHandler` to provide a set of Ignite Event-handlers to support both iteration and epoch-level events for automatic metric streaming. -In this example, the `spleen_ct_segmentation_local` job is configured to automatically log metrics to MLflow through the FL server. - -The `config_fed_client.json` contains the `NVFlareStatsHandler`, `MetricsSender`, and `MetricRelay` (with their respective pipes) to send the metrics to the server side as federated events. -Then in `config_fed_server.json`, the `MLflowReceiver` is configured for the server to receive the results in "mlruns" or via the tracking uri if specified. - -View the results by running the following command at the `mlruns/` directory in the workspace: - -``` -mlflow ui --port 5000 -``` -> **_NOTE:_** The receiver on the server side can be easily configured to support other experiment tracking formats. -> In addition to the `MLflowReceiver`, the `WandBReceiver` and `TBAnalyticsReceiver` can also be used in `config_fed_server.json` for Tensorboard and > Weights & Biases experiment tracking streaming to the server. +In this example, the `spleen_ct_segmentation_local` job is configured to automatically log metrics to MLflow through the FL server. -### 5.2 MONAI Experiment tracking with MLflow +- The `config_fed_client.json` contains the `NVFlareStatsHandler`, `MetricsSender`, and `MetricRelay` (with their respective pipes) to send the metrics to the server side as federated events. +- Then in `config_fed_server.json`, the `MLflowReceiver` is configured for the server to write the results to the default MLflow tracking server URI. -The `spleen_ct_segmentation_loc_non_agg` job is the previous configuration that uses MONAI's experiment [tracking feature](https://github.com/Project-MONAI/tutorials/tree/main/experiment_management) -with clients logging to the MLflow tracking server without going through the FL server. -For `spleen_ct_segmentation_loc_non_agg`, an MLflow server is expected, so in a new terminal, start the mlflow server with: +With this configuration, the MLflow tracking server must be started before running the job: ``` mlflow server ``` -You can access the MLflow dashboard in your browser using the default tracking uri `http://127.0.0.1:5000`. +> **_NOTE:_** The receiver on the server side can be easily configured to support other experiment tracking formats. In addition to the `MLflowReceiver`, the `WandBReceiver` and `TBAnalyticsReceiver` can also be used in `config_fed_server.json` for Tensorboard and Weights & Biases experiment tracking streaming to the server. -Next, submit the job. +Next, we can submit the job.
-### 5.3 Federated averaging +### 5.2 Federated averaging To run FedAvg using with the Job CLI, submit the job with: @@ -207,13 +201,9 @@ You should see the cross-site validation results at [DOWNLOAD_DIR]/[JOB_ID]/workspace/cross_site_val/cross_val_results.json ``` -Once the training started, you can the experiment curves for the local clients in the current run on the MLflow dashboard. - -![MLflow dashboard](./mlflow.png) - -### 5.4 Secure aggregation using homomorphic encryption +### 5.3 Secure aggregation using homomorphic encryption -Next we run FedAvg using homomorphic encryption (HE) for secure aggregation on the server. +Alternatively we can run FedAvg using homomorphic encryption (HE) for secure aggregation on the server. > **_NOTE:_** For HE, we need to use the securely provisioned workspace. > It will also take longer due to the additional encryption, decryption, encrypted aggregation, @@ -225,3 +215,13 @@ Then, submit the job to run FedAvg with HE: ``` nvflare job submit -j jobs/spleen_ct_segementation_he ``` + +### 5.4 MLflow experiment tracking results + +To view the results, you can access the MLflow dashboard in your browser using the default tracking uri `http://127.0.0.1:5000`. + +> **_NOTE:_** To write the results to the server workspace instead of using the MLflow server, users can remove the `tracking_uri` argument from the `MLflowReceiver` configuration and instead view the results by running `mlflow ui --port 5000` in the directory that contains the `mlruns/` directory in the server workspace. + +Once the training is started, you can see the experiment curves for the local clients in the current run on the MLflow dashboard. + +![MLflow dashboard](./mlflow.png) \ No newline at end of file diff --git a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/app/config/config_fed_client.json b/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/app/config/config_fed_client.json deleted file mode 100644 index a8bd20e94f..0000000000 --- a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/app/config/config_fed_client.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - "format_version": 2, - - "executors": [ - { - "tasks": [ - "train", "submit_model", "validate" - ], - "executor": { - "id": "executor", - "path": "monai_nvflare.client_algo_executor.ClientAlgoExecutor", - "args": { - "client_algo_id": "client_algo", - "key_metric": "val_mean_dice" - } - } - } - ], - - "task_result_filters": [ - ], - "task_data_filters": [ - ], - - "tracking": "mlflow", - "experiment_name": "monai_nvflare", - "tracking_uri": "http://127.0.0.1:5000", - - "components": [ - { - "id": "client_algo", - "path": "monai.fl.client.MonaiAlgo", - "args": { - "bundle_root": "config/spleen_ct_segmentation", - "local_epochs": 10, - "train_kwargs": { - "tracking": "{tracking}", - "tracking_uri": "{tracking_uri}", - "experiment_name": "{experiment_name}" - }, - "eval_kwargs": { - "tracking": "{tracking}", - "tracking_uri": "{tracking_uri}", - "experiment_name": "{experiment_name}" - } - } - } - ] -} diff --git a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/app/config/config_fed_server.json b/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/app/config/config_fed_server.json deleted file mode 100644 index 581e3a8c26..0000000000 --- 
a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/app/config/config_fed_server.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "format_version": 2, - - "min_clients": 2, - "num_rounds": 100, - - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "persistor", - "path": "monai_nvflare.monai_bundle_persistor.MonaiBundlePersistor", - "args": { - "bundle_root": "config/spleen_ct_segmentation" - } - }, - { - "id": "shareable_generator", - "name": "FullModelShareableGenerator", - "args": {} - }, - { - "id": "aggregator", - "name": "InTimeAccumulateWeightedAggregator", - "args": { - "expected_data_kind": "WEIGHT_DIFF" - } - }, - { - "id": "model_selector", - "name": "IntimeModelSelector", - "args": {} - }, - { - "id": "model_locator", - "name": "PTFileModelLocator", - "args": { - "pt_persistor_id": "persistor" - } - }, - { - "id": "json_generator", - "name": "ValidationJsonGenerator", - "args": {} - } - ], - "workflows": [ - { - "id": "scatter_gather_ctl", - "name": "ScatterAndGather", - "args": { - "min_clients" : "{min_clients}", - "num_rounds" : "{num_rounds}", - "start_round": 0, - "wait_time_after_min_received": 10, - "aggregator_id": "aggregator", - "persistor_id": "persistor", - "shareable_generator_id": "shareable_generator", - "train_task_name": "train", - "train_timeout": 0 - } - }, - { - "id": "cross_site_model_eval", - "name": "CrossSiteModelEval", - "args": { - "model_locator_id": "model_locator", - "submit_model_timeout": 600, - "validation_timeout": 6000, - "cleanup_models": true - } - } - ] -} diff --git a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/meta.json b/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/meta.json deleted file mode 100644 index 4947562644..0000000000 --- a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_loc_non_agg/meta.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "spleen-bundle", - "resource_spec": {}, - "min_clients" : 2, - "deploy_map": { - "app": [ - "@ALL" - ] - } -} diff --git a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_client.json b/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_client.conf similarity index 100% rename from integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_client.json rename to integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_client.conf diff --git a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_server.json b/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_server.conf similarity index 98% rename from integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_server.json rename to integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_server.conf index dae4608749..43650a1a9d 100644 --- a/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_server.json +++ b/integration/monai/examples/spleen_ct_segmentation_local/jobs/spleen_ct_segmentation_local/app/config/config_fed_server.conf 
@@ -47,6 +47,7 @@ "id": "mlflow_receiver_with_tracking_uri", "path": "nvflare.app_opt.tracking.mlflow.mlflow_receiver.MLflowReceiver", "args": { + "tracking_uri": "http://127.0.0.1:5000", "kwargs": { "experiment_name": "monai-spleen-experiment", "run_name": "monai-spleen-with-mlflow", diff --git a/integration/monai/examples/spleen_ct_segmentation_local/requirements.txt b/integration/monai/examples/spleen_ct_segmentation_local/requirements.txt index f1b948f162..716276ccd5 100644 --- a/integration/monai/examples/spleen_ct_segmentation_local/requirements.txt +++ b/integration/monai/examples/spleen_ct_segmentation_local/requirements.txt @@ -3,7 +3,7 @@ nibabel fire pytorch-ignite>=0.4.10 monai>=1.3.0 -nvflare>=2.3.0 +nvflare~=2.4.0rc monai_nvflare>=0.2.3 tensorboard mlflow diff --git a/integration/monai/examples/spleen_ct_segmentation_sim/README.md b/integration/monai/examples/spleen_ct_segmentation_sim/README.md index 5c951bde8f..e7b8dd03ed 100644 --- a/integration/monai/examples/spleen_ct_segmentation_sim/README.md +++ b/integration/monai/examples/spleen_ct_segmentation_sim/README.md @@ -9,8 +9,13 @@ For an example with real-world deployment settings, see [here](../spleen_ct_segm To execute the below commands, please open a terminal and go to the folder containing this tutorial. -We recommend following the instructions for setting up a [virtual environment](../../../../examples/README.md#set-up-a-virtual-environment), -and using it in [JupyterLab](../../../../examples/README.md#Set-up-JupyterLab-for-notebooks) for running the notebooks the MONAI integration examples. +Follow the [setup](../../README.md#requirements) to create a virtual environment with the MONAI-NVFlare integration installed to use in JupyterLab. + +Install the required packages in your virtual environment: + +``` +pip install -r ./requirements.txt +``` ### 1. Download the Spleen Bundle diff --git a/integration/monai/examples/spleen_ct_segmentation_sim/requirements.txt b/integration/monai/examples/spleen_ct_segmentation_sim/requirements.txt index ff0407a24e..1b0abf8525 100644 --- a/integration/monai/examples/spleen_ct_segmentation_sim/requirements.txt +++ b/integration/monai/examples/spleen_ct_segmentation_sim/requirements.txt @@ -3,7 +3,7 @@ nibabel fire pytorch-ignite>=0.4.10 monai>=1.3.0 -nvflare>=2.3.0 +nvflare~=2.4.0rc monai_nvflare>=0.2.3 tensorboard scikit-image diff --git a/integration/monai/monai_nvflare/nvflare_stats_handler.py b/integration/monai/monai_nvflare/nvflare_stats_handler.py index 07d5204bd5..2a7ab7d840 100644 --- a/integration/monai/monai_nvflare/nvflare_stats_handler.py +++ b/integration/monai/monai_nvflare/nvflare_stats_handler.py @@ -171,7 +171,6 @@ def _default_epoch_sender(self, engine: Engine) -> None: current_epoch = self.global_epoch_transform(engine.state.epoch) summary_dict = engine.state.metrics for name, value in summary_dict.items(): - print(f"\n\t{name}", type(value), value) self._send_stats(engine, name, value, AnalyticsDataType.SCALAR, current_epoch) if self.state_attributes is not None: diff --git a/integration/nemo/README.md b/integration/nemo/README.md index 94a1dd4e7e..4cb4ed604b 100644 --- a/integration/nemo/README.md +++ b/integration/nemo/README.md @@ -1,63 +1,16 @@ # NeMo Integration -## Objective -Execute [NVIDIA NeMo™](https://developer.nvidia.com/nemo) in federated environments. - -### Goals: - -Allow NeMo models to be trained and adapted with NVFlare. 
- -### Non-goals: - -n/a - -## Background -NVIDIA NeMo™ is an end-to-end cloud-native enterprise framework for developers to +[NVIDIA NeMo™](https://developer.nvidia.com/nemo) is an end-to-end cloud-native enterprise framework for developers to build, customize, and deploy generative AI models with billions of parameters. -## Description -NVFlare utilizes features from NeMo, such as prompt learning to run LLM tasks in federated environments. - -### Examples - -For an example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) with NeMo for prompt learning, -see [examples/prompt_learning](examples/prompt_learning/README.md) - -For an example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) with NeMo for supervised fine-tuning (SFT), -see [examples/supervised_fine_tuning](examples/supervised_fine_tuning/README.md) +Here, we show how NVFlare utilizes features from NeMo to run LLM tasks in federated environments with several [examples](./examples). ## Requirements -### Using docker -For simplicity, we recommend using NVIDIA's docker containers that include all the requirements for running NeMo models. -``` -docker pull nvcr.io/nvidia/nemo:23.02 -``` - -### Install NeMo-NVFlare package - - - -#### Mount the source code -For easy development with NeMo, install NVFlare and mount the code inside the folder. -``` -pip install nvflare>=2.3.0 -export PYTHONPATH=${PWD} -``` + -### Using docker (Recommended) +For simplicity, we recommend using NVIDIA's [NeMo docker containers](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) that include all the requirements for running NeMo models. - +> Note: each example in this folder might require a different container version. Please check its README for details. ### Installation in a virtual environment @@ -68,4 +21,4 @@ and using it in [JupyterLab](../../examples/README.md#notebooks) for running the notebooks in the NeMo integration examples. Follow the NeMo installation steps [here](https://github.com/NVIDIA/NeMo#installation) -before installing the NeMo-NVFlare package. +before installing NVFlare and adding the source to the PYTHONPATH. diff --git a/integration/nemo/examples/README.md b/integration/nemo/examples/README.md index 4e7ed42f32..2e5bd21c2a 100644 --- a/integration/nemo/examples/README.md +++ b/integration/nemo/examples/README.md @@ -1,16 +1,16 @@ # Examples of NeMo-NVFlare Integration ### [Parameter-Efficient Fine-Tuning (PEFT) with NeMo](./peft/README.md) -In this example, we utilize NeMo's [PEFT](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/nemo_megatron/peft/landing_page.html) +In this example, we utilize NeMo's [PEFT](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/nemo_megatron/peft/landing_page.html) methods, using NVFlare's new Client API (minimal code changes are required to run a NeMo script in FL), to showcase how to adapt a large language model (LLM) to a downstream task, such as financial sentiment predictions. -### [Prompt learning with NeMo and NVFlare](./prompt_learning/README.md) -An example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) -with NeMo for [prompt learning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/prompt_learning.html) -to adapt a large language model (LLM) to a downstream task.
- -### [Supervised fine-tuning (SFT) with NeMo and NVFlare](./prompt_learning/README.md) +### [Supervised fine-tuning (SFT) with NeMo and NVFlare](./supervised_fine_tuning/README.md) An example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) with NeMo for [supervised fine-tuning (SFT)](https://github.com/NVIDIA/NeMo-Megatron-Launcher#5152-sft-training) to fine-tune all parameters of a large language model (LLM) on supervised data to teach the model how to follow user-specified instructions. + +### [Prompt learning with NeMo and NVFlare](./prompt_learning/README.md) +An example of using [NVIDIA FLARE](https://nvflare.readthedocs.io/en/main/index.html) +with NeMo for [prompt learning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/prompt_learning.html) using NVFlare's Learner API +to adapt a large language model (LLM) to a downstream task. diff --git a/integration/nemo/examples/peft/README.md b/integration/nemo/examples/peft/README.md index c5f1085cbb..01d116272d 100644 --- a/integration/nemo/examples/peft/README.md +++ b/integration/nemo/examples/peft/README.md @@ -10,17 +10,34 @@ that condition the model to produce the desired output for the downstream task. For more details, see the [PEFT script](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/tuning/megatron_gpt_peft_tuning.py) in NeMo, which we adapt using NVFlare's Lightning client API to run in a federated scenario. ## Dependencies -We assume you followed the instructions [here](../../README.md#requirements) -to install the NeMo, NVFlare, and the NeMo-NVFlare package. +The example was tested with the [NeMo 23.10 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo). +In the following, we assume this example folder is mounted to `/workspace` inside the container, and that all download and similar operations are relative to this root path. -The example was tested with the main branch of [NeMo](https://github.com/NVIDIA/NeMo). +> Note: in the following, we mount both the [current directory](./) and the [job_templates](../../../../job_templates) +> directory to locations inside the docker container. Please make sure you have cloned the full NVFlare repo. + +Start the docker container from **this directory** using +``` +# cd NVFlare/integration/nemo/examples/peft +DOCKER_IMAGE="nvcr.io/nvidia/nemo:23.10" +docker run --runtime=nvidia -it --rm --shm-size=16g -p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit stack=67108864 \ +-v ${PWD}/../../../../job_templates:/job_templates -v ${PWD}:/workspace -w /workspace ${DOCKER_IMAGE} +``` + +For easy experimentation with NeMo, install NVFlare and make the code inside the [nemo_nvflare](./nemo_nvflare) folder available on the `PYTHONPATH`. +``` +pip install nvflare~=2.4.0rc7 +export PYTHONPATH=${PYTHONPATH}:/workspace +``` ## Examples ### 1. Federated PEFT using a 345 million parameter GPT model -This example requires a GPU with at least 24GB memory to run three clients in parallel on the same GPU. We use [JupyterLab](https://jupyterlab.readthedocs.io) for this example. To start JupyterLab, run ``` jupyter lab . ``` and open [peft.ipynb](./peft.ipynb). + +#### Hardware requirement +This example requires a GPU with at least 24 GB of memory to run three clients in parallel on the same GPU.
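As the README notes, the NeMo PEFT script is adapted with NVFlare's Lightning Client API. For orientation, here is a minimal sketch of that pattern; it assumes `trainer` and `model` are the PyTorch Lightning objects built by the NeMo script, so treat it as an illustration rather than the exact contents of [megatron_gpt_peft_tuning.py](./nemo_nvflare/megatron_gpt_peft_tuning.py).

```
# Minimal sketch of the Lightning Client API pattern (illustrative only; the
# real adaptation lives in nemo_nvflare/megatron_gpt_peft_tuning.py).
import nvflare.client.lightning as flare
from pytorch_lightning import LightningModule, Trainer


def run_federated_peft(trainer: Trainer, model: LightningModule) -> None:
    flare.patch(trainer)  # patch the trainer so fit/validate exchange models with FLARE
    while flare.is_running():
        input_model = flare.receive()  # global model for the current FL round
        print(f"current round: {input_model.current_round}")
        trainer.validate(model)  # evaluate the received global model first
        trainer.fit(model)  # local PEFT training on this client's data split
```

Validating before fitting is the step that lets server-side model selection track the global model's performance across rounds.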
diff --git a/integration/nemo/examples/peft/nemo_nvflare/__init__.py b/integration/nemo/examples/peft/nemo_nvflare/__init__.py new file mode 100644 index 0000000000..d6050992d1 --- /dev/null +++ b/integration/nemo/examples/peft/nemo_nvflare/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .peft_model import PEFTmodel diff --git a/integration/nemo/examples/peft/code/megatron_gpt_peft_fl_eval_config.yaml b/integration/nemo/examples/peft/nemo_nvflare/megatron_gpt_peft_fl_eval_config.yaml similarity index 100% rename from integration/nemo/examples/peft/code/megatron_gpt_peft_fl_eval_config.yaml rename to integration/nemo/examples/peft/nemo_nvflare/megatron_gpt_peft_fl_eval_config.yaml diff --git a/integration/nemo/examples/peft/code/megatron_gpt_peft_tuning.py b/integration/nemo/examples/peft/nemo_nvflare/megatron_gpt_peft_tuning.py similarity index 100% rename from integration/nemo/examples/peft/code/megatron_gpt_peft_tuning.py rename to integration/nemo/examples/peft/nemo_nvflare/megatron_gpt_peft_tuning.py diff --git a/integration/nemo/examples/peft/code/megatron_gpt_peft_tuning_config.yaml b/integration/nemo/examples/peft/nemo_nvflare/megatron_gpt_peft_tuning_config.yaml similarity index 100% rename from integration/nemo/examples/peft/code/megatron_gpt_peft_tuning_config.yaml rename to integration/nemo/examples/peft/nemo_nvflare/megatron_gpt_peft_tuning_config.yaml diff --git a/integration/nemo/nemo_nvflare/peft_model.py b/integration/nemo/examples/peft/nemo_nvflare/peft_model.py similarity index 100% rename from integration/nemo/nemo_nvflare/peft_model.py rename to integration/nemo/examples/peft/nemo_nvflare/peft_model.py diff --git a/integration/nemo/examples/peft/nemo_nvflare/utils.py b/integration/nemo/examples/peft/nemo_nvflare/utils.py new file mode 100644 index 0000000000..7ca186eae5 --- /dev/null +++ b/integration/nemo/examples/peft/nemo_nvflare/utils.py @@ -0,0 +1,34 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import torch + + +def convert_global_to_ckpt(global_model_filepath: str, ckpt_path: str): + """Helper function to convert global models saved by NVFlare to NeMo ckpt format""" + + nvflare_ckpt = torch.load(global_model_filepath) + if "train_conf" in nvflare_ckpt: + print("Loaded NVFlare global checkpoint with train_conf", nvflare_ckpt["train_conf"]) + + assert ( + "model" in nvflare_ckpt + ), f"Expected global model to contain a 'model' key but it only had {list(nvflare_ckpt.keys())}" + global_weights = nvflare_ckpt["model"] + + torch.save({"state_dict": global_weights}, ckpt_path) + + print(f"Saved NeMo ckpt with {len(global_weights)} entries to {ckpt_path}") + diff --git a/integration/nemo/examples/peft/peft.ipynb b/integration/nemo/examples/peft/peft.ipynb index a86d28b4b5..ba31f585b9 100644 --- a/integration/nemo/examples/peft/peft.ipynb +++ b/integration/nemo/examples/peft/peft.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "0c534975", + "id": "5020dd81", "metadata": {}, "source": [ "# Parameter-Efficient Fine-Tuning (PEFT) with NeMo\n", @@ -19,17 +19,17 @@ }, { "cell_type": "markdown", - "id": "513eb148", + "id": "dc9769ef", "metadata": {}, "source": [ "## Dependencies\n", - "We assume you followed the instructions [here](../../README.md#requirements) \n", - "to install the NeMo framework and the NeMo-NVFlare package. " + "We assume you followed the instructions [here](./README.md) \n", + "to install the NeMo and NVFlare frameworks and mount the required codes." ] }, { "cell_type": "markdown", - "id": "bb97927a", + "id": "dab4c639", "metadata": {}, "source": [ "## Download the pre-trained LLM\n", @@ -39,7 +39,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c2f6c8b5", + "id": "20921eea", "metadata": {}, "outputs": [], "source": [ @@ -51,7 +51,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2036e09e", + "id": "aa852d07", "metadata": {}, "outputs": [], "source": [ @@ -66,7 +66,7 @@ }, { "cell_type": "markdown", - "id": "67f48638", + "id": "fa530a42", "metadata": {}, "source": [ "## Data preprocessing\n", @@ -74,12 +74,12 @@ "\n", "The Financial PhraseBank dataset contains the sentiments for financial news headlines from a retail investor's perspective. Further details about the dataset can be found in Malo et al.'s [\"Good Debt or Bad Debt: Detecting Semantic Orientations in Economic Texts\"](https://arxiv.org/abs/1307.5336).\n", "\n", - "We can configure the prompt template used by NeMo to solve this downstream task by setting `prompt_template: \"{sentence} sentiment: {label}\"` in [megatron_gpt_peft_tuning_config.yaml](./code/megatron_gpt_peft_tuning_config.yaml) accordingly." + "We can configure the prompt template used by NeMo to solve this downstream task by setting `prompt_template: \"{sentence} sentiment: {label}\"` in [megatron_gpt_peft_tuning_config.yaml](./nemo_nvflare/megatron_gpt_peft_tuning_config.yaml) accordingly." ] }, { "cell_type": "markdown", - "id": "29dd0470", + "id": "b5737e50", "metadata": {}, "source": [ "#### 1. Download the preprocessing scripts\n", @@ -89,7 +89,7 @@ { "cell_type": "code", "execution_count": null, - "id": "f37039ed", + "id": "b2c32fa5", "metadata": {}, "outputs": [], "source": [ @@ -102,7 +102,7 @@ }, { "cell_type": "markdown", - "id": "f1b1a07b", + "id": "13a2f952", "metadata": {}, "source": [ "#### 2. Download the Financial PhraseBank Dataset\n", @@ -114,7 +114,7 @@ }, { "cell_type": "markdown", - "id": "f335899e", + "id": "40199807", "metadata": {}, "source": [ "#### 3. 
Preprocess the dataset" @@ -123,7 +123,7 @@ { "cell_type": "code", "execution_count": null, - "id": "dc66ef42", + "id": "80f84586", "metadata": {}, "outputs": [], "source": [ @@ -132,7 +132,7 @@ }, { "cell_type": "markdown", - "id": "365a58c8", + "id": "d9f8fa9a", "metadata": {}, "source": [ "#### 4. Split the dataset to simulate clients\n", @@ -143,7 +143,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3f9214af", + "id": "a6725683", "metadata": {}, "outputs": [], "source": [ @@ -160,7 +160,7 @@ }, { "cell_type": "markdown", - "id": "cc85565d", + "id": "6c506c6b", "metadata": {}, "source": [ "Below are some examples of how the training data is distributed amount the three clients when using different values of `alpha`.\n", @@ -173,7 +173,7 @@ }, { "cell_type": "markdown", - "id": "3eea187a", + "id": "704ff05d", "metadata": {}, "source": [ "## Federated learning simulations\n", @@ -187,7 +187,7 @@ }, { "cell_type": "markdown", - "id": "aa23b7c7", + "id": "01fce4ae", "metadata": {}, "source": [ "#### 1. Convert NeMo PEFT script to FL\n", @@ -205,7 +205,7 @@ "\"Drawing\"\n", "\n", "\n", - "You can directly use all the PEFT methods implemented in the NeMo script, by changing the value of [peft_scheme](./code/megatron_gpt_peft_tuning_config.yaml) in the client configuration shown below accordingly:\n", + "You can directly use all the PEFT methods implemented in the NeMo script, by changing the value of [peft_scheme](./nemo_nvflare/megatron_gpt_peft_tuning_config.yaml) in the client configuration shown below accordingly:\n", "* p-tuning\n", "* adapter + p-tuning\n", "* adapter\n", @@ -221,7 +221,7 @@ }, { "cell_type": "markdown", - "id": "95b07067", + "id": "655a1f0a", "metadata": {}, "source": [ "#### 1. Local training\n", @@ -236,16 +236,16 @@ { "cell_type": "code", "execution_count": null, - "id": "b6e001c1", + "id": "51e4fb4d", "metadata": {}, "outputs": [], "source": [ - "!nvflare config -jt ../../../../job_templates" + "!nvflare config -jt /job_templates" ] }, { "cell_type": "markdown", - "id": "f3528af2", + "id": "2e515dc2", "metadata": {}, "source": [ "Then, create the job and configure it for simulating local training." @@ -254,8 +254,10 @@ { "cell_type": "code", "execution_count": null, - "id": "9905ebaa", - "metadata": {}, + "id": "404fe5fe", + "metadata": { + "scrolled": true + }, "outputs": [], "source": [ "import os\n", @@ -272,7 +274,7 @@ "num_rounds=1\n", "trainer_config=\"trainer.max_steps\\=1000 trainer.val_check_interval\\=100\"\n", "\n", - "!nvflare job create -force -j \"./jobs/peft_{peft_scheme}_local_345M\" -w \"sag_nemo\" -sd \"code\" \\\n", + "!nvflare job create -force -j \"./jobs/peft_{peft_scheme}_local_345M\" -w \"sag_nemo\" -sd \"nemo_nvflare\" \\\n", " -f app_1/config_fed_client.conf app_script={app_script} app_config=\"{peft_scheme_arg} model.restore_from_path\\={restore_from_path} {trainer_config} {val_files} {train_files_prefix}-1.jsonl\\]\" \\\n", " -f app_2/config_fed_client.conf app_script={app_script} app_config=\"{peft_scheme_arg} model.restore_from_path\\={restore_from_path} {trainer_config} {val_files} {train_files_prefix}-2.jsonl\\]\" \\\n", " -f app_3/config_fed_client.conf app_script={app_script} app_config=\"{peft_scheme_arg} model.restore_from_path\\={restore_from_path} {trainer_config} {val_files} {train_files_prefix}-3.jsonl\\]\" \\\n", @@ -281,7 +283,7 @@ }, { "cell_type": "markdown", - "id": "945b7d71", + "id": "df9ca0a5", "metadata": {}, "source": [ "Next, simulate each client training on their local dataset using the FL simulator. 
To do this, we only run 1 round of FL, with each client running 1000 steps on their local dataset." @@ -290,12 +292,16 @@ { "cell_type": "code", "execution_count": null, - "id": "09ef104c", + "id": "8d7f4970", "metadata": { "scrolled": true }, "outputs": [], "source": [ + "# required by NeMo models\n", + "import torch.multiprocessing as mp\n", + "mp.set_start_method(\"spawn\", force=True)\n", + "\n", "from nvflare import SimulatorRunner \n", "\n", "simulator = SimulatorRunner(\n", @@ -310,7 +316,7 @@ }, { "cell_type": "markdown", - "id": "bccf7bed", + "id": "2e56653f", "metadata": {}, "source": [ "#### 2. Federated training\n", @@ -323,7 +329,7 @@ { "cell_type": "code", "execution_count": null, - "id": "782af9c0", + "id": "ad3406a6", "metadata": { "scrolled": true }, @@ -333,7 +339,7 @@ "num_rounds=5\n", "trainer_config=\"trainer.max_steps\\=200 trainer.val_check_interval\\=100\"\n", "\n", - "!nvflare job create -force -j \"./jobs/peft_{peft_scheme}_fedavg_345M\" -w \"sag_nemo\" -sd \"code\" \\\n", + "!nvflare job create -force -j \"./jobs/peft_{peft_scheme}_fedavg_345M\" -w \"sag_nemo\" -sd \"nemo_nvflare\" \\\n", " -f app_1/config_fed_client.conf app_script={app_script} app_config=\"{peft_scheme_arg} model.restore_from_path\\={restore_from_path} {trainer_config} {val_files} {train_files_prefix}-1.jsonl\\]\" \\\n", " -f app_2/config_fed_client.conf app_script={app_script} app_config=\"{peft_scheme_arg} model.restore_from_path\\={restore_from_path} {trainer_config} {val_files} {train_files_prefix}-2.jsonl\\]\" \\\n", " -f app_3/config_fed_client.conf app_script={app_script} app_config=\"{peft_scheme_arg} model.restore_from_path\\={restore_from_path} {trainer_config} {val_files} {train_files_prefix}-3.jsonl\\]\" \\\n", @@ -342,7 +348,7 @@ }, { "cell_type": "markdown", - "id": "41088905", + "id": "5e591653", "metadata": {}, "source": [ "Next, simulate the federated training using FedAvg. " @@ -351,12 +357,16 @@ { "cell_type": "code", "execution_count": null, - "id": "00109b1e", + "id": "8559b79f", "metadata": { "scrolled": true }, "outputs": [], "source": [ + "# required by NeMo models\n", + "import torch.multiprocessing as mp\n", + "mp.set_start_method(\"spawn\", force=True)\n", + "\n", "from nvflare import SimulatorRunner \n", "\n", "simulator = SimulatorRunner(\n", @@ -371,25 +381,15 @@ }, { "cell_type": "markdown", - "id": "d3d8d656", - "metadata": {}, - "source": [ - "You can visualize the training process using TensorBoard" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7f6755b6", + "id": "3e20ca56", "metadata": {}, - "outputs": [], "source": [ - "!tensorboard --logdir /tmp/nvflare/nemo" + "You can visualize the training process using TensorBoard by running `tensorboard --logdir /tmp/nvflare/nemo` in a new terminal." ] }, { "cell_type": "markdown", - "id": "d0c35f89", + "id": "a8e5b7c0", "metadata": {}, "source": [ "## Results\n", @@ -405,7 +405,7 @@ }, { "cell_type": "markdown", - "id": "7174a47a", + "id": "65833f4b", "metadata": {}, "source": [ "## Inference\n", @@ -417,7 +417,7 @@ { "cell_type": "code", "execution_count": null, - "id": "72d1d6e9", + "id": "dcf08bc6", "metadata": {}, "outputs": [], "source": [ @@ -432,7 +432,7 @@ }, { "cell_type": "markdown", - "id": "afe4ed67", + "id": "7b3667c0", "metadata": {}, "source": [ "First, we need to convert the best global PEFT model into a NeMo ckpt." 
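The conversion itself can reuse the `convert_global_to_ckpt` helper from [nemo_nvflare/utils.py](./nemo_nvflare/utils.py); a usage sketch follows, where the workspace path and checkpoint filename are assumptions that depend on your simulator run.

```
# Usage sketch for the helper defined in nemo_nvflare/utils.py.
# Both paths below are hypothetical; point them at your own simulator output.
from nemo_nvflare.utils import convert_global_to_ckpt

server_run_dir = "/tmp/nvflare/nemo/peft_workspace/server/simulate_job/app_server"  # assumed location
convert_global_to_ckpt(
    global_model_filepath=f"{server_run_dir}/best_FL_global_model.pt",  # assumed filename
    ckpt_path="peft_global_model.ckpt",  # output NeMo ckpt
)
```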
@@ -441,7 +441,7 @@ { "cell_type": "code", "execution_count": null, - "id": "54f93b59", + "id": "3d08150a", "metadata": {}, "outputs": [], "source": [ @@ -456,7 +456,7 @@ }, { "cell_type": "markdown", - "id": "6311edbd", + "id": "2963d08e", "metadata": {}, "source": [ "Next, we will load the global model." @@ -465,7 +465,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5b07ecbc", + "id": "9ffe513d", "metadata": {}, "outputs": [], "source": [ @@ -475,7 +475,7 @@ "from omegaconf import OmegaConf\n", "\n", "# Load model configuration for inference of the global model\n", - "cfg = OmegaConf.load(\"code/megatron_gpt_peft_fl_eval_config.yaml\")\n", + "cfg = OmegaConf.load(\"nemo_nvflare/megatron_gpt_peft_fl_eval_config.yaml\")\n", "\n", "# Build trainer\n", "trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()\n", @@ -499,7 +499,7 @@ }, { "cell_type": "markdown", - "id": "b6b00b36", + "id": "59fa62cb", "metadata": {}, "source": [ "Run the model" @@ -508,7 +508,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c03a073d", + "id": "acd89469", "metadata": {}, "outputs": [], "source": [ @@ -535,7 +535,7 @@ }, { "cell_type": "markdown", - "id": "b9d8fd7c", + "id": "d14026fc", "metadata": {}, "source": [ "The expected output of a well-trained model looks something like this. Note that the test sentences do not include ground truth labels.\n", @@ -555,7 +555,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3e7aaaa5", + "id": "db70a19a", "metadata": {}, "outputs": [], "source": [] diff --git a/integration/nemo/examples/prompt_learning/README.md b/integration/nemo/examples/prompt_learning/README.md index a19fc8744e..0e2125a92f 100644 --- a/integration/nemo/examples/prompt_learning/README.md +++ b/integration/nemo/examples/prompt_learning/README.md @@ -3,7 +3,7 @@ In this example, we utilize NeMo's [prompt learning](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/nlp/nemo_megatron/prompt_learning.html) feature to showcase how to adapt a large language model (LLM) to a downstream task such as financial sentiment predictions. -As the prompt learning technique shown in the example is p-tuning which adds a small prompt encoder network to the LLM +The prompt learning technique shown in the example is p-tuning, which adds a small prompt encoder network to the LLM to produce virtual tokens that guide the model toward the desired output of the downstream task. @@ -13,14 +13,25 @@ In our federated implementation, the LLM parameters stay fixed. Prompt encoder p ## Dependencies -We assume you followed the instructions [here](../../README.md#requirements) -to install the NeMo, NVFlare, and the NeMo-NVFlare package. - The example was tested with the [NeMo 23.02 container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo). +In the following, we assume this example folder is mounted to `/workspace` inside the container, and that all download and similar operations are relative to this root path. + +Start the docker container from **this directory** using +``` +# cd NVFlare/integration/nemo/examples/prompt_learning +DOCKER_IMAGE="nvcr.io/nvidia/nemo:23.02" +docker run --runtime=nvidia -it --rm --shm-size=16g -p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit stack=67108864 \ +-v ${PWD}:/workspace -w /workspace ${DOCKER_IMAGE} +``` + +For easy experimentation with NeMo, install NVFlare and make the code inside the [nemo_nvflare](./nemo_nvflare) folder available on the `PYTHONPATH`. +``` +pip install nvflare~=2.4.0rc7 +export PYTHONPATH=${PYTHONPATH}:/workspace +``` ## Examples ### 1.
Federated p-tuning using a 345 million parameter GPT model -This example requires a GPU with at least 16GB memory to run three clients in parallel on the same GPU. We use [JupyterLab](https://jupyterlab.readthedocs.io) for this example. To start JupyterLab, run ``` @@ -28,9 +39,14 @@ jupyter lab . ``` and open [prompt_learning.ipynb](./prompt_learning.ipynb). +#### Hardware requirement +This example requires a GPU with at least 16GB of memory to run three clients in parallel on the same GPU. + ### 2. Federated p-tuning using a 20 billion parameter GPT model -This example running a 20B GPT model requires more computational resources. -To run three clients in parallel, we require at least six GPUs with 64 GB memory or more each -(Ampere or later GPU architecture). +This example of running a 20B GPT model requires more computational resources. To run the example, follow the instructions in [prompt_learning_20B.md](prompt_learning_20B.md). + +#### Hardware requirement +To run three clients in parallel, we require at least six GPUs with 64 GB memory or more each +(Ampere or later GPU architecture). diff --git a/integration/nemo/nemo_nvflare/__init__.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/__init__.py similarity index 84% rename from integration/nemo/nemo_nvflare/__init__.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/__init__.py index f109d45c1c..802119c693 100644 --- a/integration/nemo/nemo_nvflare/__init__.py +++ b/integration/nemo/examples/prompt_learning/nemo_nvflare/__init__.py @@ -13,12 +13,8 @@ # limitations under the License. from .config_sharer import ConfigSharer -from .config_sharer_sft import ConfigSharerSFT from .fed_megatron_gpt_prompt_learning_model import FedMegatronGPTPromptLearningModel from .learner_executor import NemoLearnerExecutor from .prompt_encoder import ServerPromptEncoder from .prompt_learner import PromptLearner -from .server_sft_model import ServerSFTModel -from .sft_learner import SFTLearner from .share_config import ShareConfig -from .share_config_sft import ShareConfigSFT diff --git a/integration/nemo/nemo_nvflare/config_sharer.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/config_sharer.py similarity index 100% rename from integration/nemo/nemo_nvflare/config_sharer.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/config_sharer.py diff --git a/integration/nemo/nemo_nvflare/constants.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/constants.py similarity index 100% rename from integration/nemo/nemo_nvflare/constants.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/constants.py diff --git a/integration/nemo/nemo_nvflare/fed_megatron_gpt_prompt_learning_model.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/fed_megatron_gpt_prompt_learning_model.py similarity index 100% rename from integration/nemo/nemo_nvflare/fed_megatron_gpt_prompt_learning_model.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/fed_megatron_gpt_prompt_learning_model.py diff --git a/integration/nemo/nemo_nvflare/learner_executor.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/learner_executor.py similarity index 100% rename from integration/nemo/nemo_nvflare/learner_executor.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/learner_executor.py diff --git a/integration/nemo/nemo_nvflare/prompt_encoder.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/prompt_encoder.py similarity index 100% rename from 
integration/nemo/nemo_nvflare/prompt_encoder.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/prompt_encoder.py diff --git a/integration/nemo/nemo_nvflare/prompt_learner.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/prompt_learner.py similarity index 100% rename from integration/nemo/nemo_nvflare/prompt_learner.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/prompt_learner.py diff --git a/integration/nemo/nemo_nvflare/share_config.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/share_config.py similarity index 100% rename from integration/nemo/nemo_nvflare/share_config.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/share_config.py diff --git a/integration/nemo/nemo_nvflare/utils.py b/integration/nemo/examples/prompt_learning/nemo_nvflare/utils.py similarity index 100% rename from integration/nemo/nemo_nvflare/utils.py rename to integration/nemo/examples/prompt_learning/nemo_nvflare/utils.py diff --git a/integration/nemo/examples/prompt_learning/prompt_learning.ipynb b/integration/nemo/examples/prompt_learning/prompt_learning.ipynb index ef51377001..66f80a2dc9 100644 --- a/integration/nemo/examples/prompt_learning/prompt_learning.ipynb +++ b/integration/nemo/examples/prompt_learning/prompt_learning.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "b43584a8", + "id": "56e442d4", "metadata": {}, "source": [ "# Prompt Learning with NeMo\n", @@ -14,22 +14,28 @@ "The prompt learning technique shown in the example is [p-tuning](https://arxiv.org/abs/2103.10385), which adds a small prompt encoder network to the LLM\n", "to produce virtual token embeddings that guide the model toward the desired output of the downstream task.\n", "\n", - "For more details on how to change hyperparameters for prompt learning in NeMo, see this [tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/nlp/Multitask_Prompt_and_PTuning.ipynb) which is also the basis for this NVFlare tutorial." + "For more details on how to change hyperparameters for prompt learning in NeMo, see this [tutorial](https://github.com/NVIDIA/NeMo/blob/main/tutorials/nlp/Multitask_Prompt_and_PTuning.ipynb) which is also the basis for this NVFlare tutorial.\n", + "\n", + "\n", + "\n", + "In our federated implementation, the LLM parameters stay fixed. Prompt encoder parameters are trained/updated and averaged on the server.\n", + "\n", + "" ] }, { "cell_type": "markdown", - "id": "578585e4", + "id": "6dac11e2", "metadata": {}, "source": [ "## Dependencies\n", - "We assume you followed the instructions [here](../../README.md#requirements) \n", - "to install the NeMo framework and the NeMo-NVFlare package. " + "We assume you followed the instructions [here](./README.md) \n", + "to install the NeMo and NVFlare frameworks and mount the required codes." ] }, { "cell_type": "markdown", - "id": "199f1fe5", + "id": "47b1d4dc", "metadata": {}, "source": [ "## Download the pre-trained LLM\n", @@ -39,7 +45,7 @@ { "cell_type": "code", "execution_count": null, - "id": "4ac31bab", + "id": "581035ee", "metadata": {}, "outputs": [], "source": [ @@ -51,7 +57,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9a14ccb9", + "id": "154be2b0", "metadata": {}, "outputs": [], "source": [ @@ -66,7 +72,7 @@ }, { "cell_type": "markdown", - "id": "9b4b0a65", + "id": "9a420d6e", "metadata": {}, "source": [ "## Data preprocessing\n", @@ -77,7 +83,7 @@ }, { "cell_type": "markdown", - "id": "f4a845d4", + "id": "3b3d4155", "metadata": {}, "source": [ "#### 1. 
Download the preprocessing scripts\n", @@ -87,7 +93,7 @@ { "cell_type": "code", "execution_count": null, - "id": "3a00456f", + "id": "f5c33254", "metadata": {}, "outputs": [], "source": [ @@ -100,19 +106,31 @@ }, { "cell_type": "markdown", - "id": "353a28e0", + "id": "974248b8", "metadata": {}, "source": [ "#### 2. Download the Financial PhraseBank Dataset\n", "\n", "Download the `FinancialPhraseBank-v1.0.zip` dataset from [here](https://www.researchgate.net/profile/Pekka_Malo/publication/251231364_FinancialPhraseBank-v1.0/data/0c96051eee4fb1d56e000000/FinancialPhraseBank-v1.0.zip).\n", "\n", - "Then extract it under `./data`." + "Then extract it under `./data`. Note, after extraction, the data folder should have the following content\n", + "```\n", + "data\n", + "├── FinancialPhraseBank-v1.0\n", + "│   ├── License.txt\n", + "│   ├── README.txt\n", + "│   ├── Sentences_50Agree.txt\n", + "│   ├── Sentences_66Agree.txt\n", + "│   ├── Sentences_75Agree.txt\n", + "│   └── Sentences_AllAgree.txt\n", + "├── FinancialPhraseBank-v1.0.zip\n", + "└── split_financial_phrase_data.py\n", + "```" ] }, { "cell_type": "markdown", - "id": "12bb6682", + "id": "b1f8ad50", "metadata": {}, "source": [ "#### 3. Preprocess the dataset" @@ -121,7 +139,7 @@ { "cell_type": "code", "execution_count": null, - "id": "2ceb4180", + "id": "fbbf86af", "metadata": {}, "outputs": [], "source": [ @@ -130,7 +148,7 @@ }, { "cell_type": "markdown", - "id": "baa61a74", + "id": "29aaffee", "metadata": {}, "source": [ "#### 4. Split the dataset to simulate clients\n", @@ -140,7 +158,7 @@ { "cell_type": "code", "execution_count": null, - "id": "339884a1", + "id": "725115cc", "metadata": {}, "outputs": [], "source": [ @@ -149,7 +167,7 @@ }, { "cell_type": "markdown", - "id": "cbcab01b", + "id": "7a45c985", "metadata": {}, "source": [ "## Federated learning simulations\n", @@ -161,7 +179,7 @@ }, { "cell_type": "markdown", - "id": "4fbc7c4c", + "id": "8b56fc06", "metadata": {}, "source": [ "#### 1. Local P-Tuning\n", @@ -172,7 +190,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14896baa", + "id": "7dd7e496", "metadata": {}, "outputs": [], "source": [ @@ -181,7 +199,7 @@ }, { "cell_type": "markdown", - "id": "1ea16b74", + "id": "1529d090", "metadata": {}, "source": [ "Next, simulate each client p-tuning on their local dataset using the FL simulator. To do this, we only run 1 round of FL, with each client running 50 p-tuning epochs on their local dataset." @@ -190,7 +208,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5abf5055", + "id": "6a05bdd8", "metadata": { "scrolled": true }, @@ -210,7 +228,7 @@ }, { "cell_type": "markdown", - "id": "f0bb49cb", + "id": "6456327c", "metadata": {}, "source": [ "#### 2. Federated P-Tuning\n", @@ -221,7 +239,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5151467a", + "id": "c6ec1399", "metadata": {}, "outputs": [], "source": [ @@ -230,7 +248,7 @@ }, { "cell_type": "markdown", - "id": "eadb0a5c", + "id": "e5083061", "metadata": {}, "source": [ "Next, simulate the federated p-tuning using FedAvg. Here, each client p-tunes for one local epoch before sending their local model updates to the server for aggregation. This is repeated for 50 FL rounds." 
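The simulation cell that follows mirrors the local-training one; here is a sketch of what it might contain, where the job folder name is an assumption about what the config-creation step produced.

```
# Sketch of the FedAvg simulation cell (the job_folder name is an assumption).
from nvflare import SimulatorRunner

simulator = SimulatorRunner(
    job_folder="jobs/gpt_p-tuning_fedavg_345M",  # hypothetical job folder name
    workspace="/tmp/nvflare/nemo",  # matches the TensorBoard logdir mentioned below
    n_clients=3,  # three simulated clients sharing one GPU
    threads=3,
)
run_status = simulator.run()
print("Simulator finished with run_status", run_status)
```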
@@ -239,7 +257,7 @@ { "cell_type": "code", "execution_count": null, - "id": "eea2c83a", + "id": "38609850", "metadata": { "scrolled": true }, @@ -259,25 +277,15 @@ }, { "cell_type": "markdown", - "id": "a9276ce2", + "id": "9fc069d3", "metadata": {}, "source": [ - "You can visualize the training process using TensorBoard" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5c93483c", - "metadata": {}, - "outputs": [], - "source": [ - "!tensorboard --logdir /tmp/nvflare/nemo" + "You can visualize the training process using TensorBoard by running `tensorboard --logdir /tmp/nvflare/nemo` in a new terminal." ] }, { "cell_type": "markdown", - "id": "0e6763ca", + "id": "aad25a10", "metadata": {}, "source": [ "## Results\n", @@ -288,7 +296,7 @@ }, { "cell_type": "markdown", - "id": "639e95aa", + "id": "fbbac75c", "metadata": {}, "source": [ "## Inference\n", @@ -300,7 +308,7 @@ { "cell_type": "code", "execution_count": null, - "id": "38ae679d", + "id": "52bb91c3", "metadata": {}, "outputs": [], "source": [ @@ -315,7 +323,7 @@ }, { "cell_type": "markdown", - "id": "23ce4e16", + "id": "e5740fbf", "metadata": {}, "source": [ "Next, we will load the global model." @@ -324,7 +332,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c02f046e", + "id": "605d0d1c", "metadata": {}, "outputs": [], "source": [ @@ -369,7 +377,7 @@ }, { "cell_type": "markdown", - "id": "253cdc30", + "id": "7bf97036", "metadata": {}, "source": [ "Overwrite the prompt encoder with the best global model" @@ -378,7 +386,7 @@ { "cell_type": "code", "execution_count": null, - "id": "0f257854", + "id": "33f9771b", "metadata": {}, "outputs": [], "source": [ @@ -391,7 +399,7 @@ }, { "cell_type": "markdown", - "id": "57b954e7", + "id": "69c35011", "metadata": {}, "source": [ "Run the model" @@ -400,7 +408,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8781d8f2", + "id": "64402b65", "metadata": {}, "outputs": [], "source": [ @@ -414,7 +422,7 @@ }, { "cell_type": "markdown", - "id": "12613bbd", + "id": "f1ff31f1", "metadata": {}, "source": [ "The expected output predictions look something like this\n", @@ -434,7 +442,7 @@ { "cell_type": "code", "execution_count": null, - "id": "d69d4973", + "id": "98d73f49", "metadata": {}, "outputs": [], "source": [] @@ -456,7 +464,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/integration/nemo/examples/prompt_learning/prompt_learning_20B.md b/integration/nemo/examples/prompt_learning/prompt_learning_20B.md index 97bb03a2d2..1e2d54ba48 100644 --- a/integration/nemo/examples/prompt_learning/prompt_learning_20B.md +++ b/integration/nemo/examples/prompt_learning/prompt_learning_20B.md @@ -15,13 +15,18 @@ To run three clients in parallel, we require at least six GPUs with 64 GB memory (Ampere or later GPU architecture). The example was tested on 6xA100 GPUs with 80 GB each. -We assume you followed the instructions [here](../../README.md#requirements) -to install the NeMo framework and the NeMo-NVFlare package. +We assume you followed the instructions [here](./README.md) +to install the NeMo framework and mount the required code. The example was tested using the [NeMo Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo), available with `docker pull nvcr.io/nvidia/nemo:23.02`. For downloading the pre-trained model, we use [git lfs](https://git-lfs.com). 
+Install it in the container with +``` +apt update +apt install git-lfs +``` ## Download the pre-trained LLM In this example, we use a [Megatron-GPT 20B](https://huggingface.co/nvidia/nemo-megatron-gpt-20B), a transformer-based language model based on the GPT architecture. @@ -29,7 +34,8 @@ In this example, we use a [Megatron-GPT 20B](https://huggingface.co/nvidia/nemo- # download the model from HuggingFace using git lfs git clone https://huggingface.co/nvidia/nemo-megatron-gpt-20B ``` -After download, the checkpoint `nemo_gpt20B_bf16_tp4.nemo` should have a size of 38 GB. +> Note, this will take some time. After download, the checkpoint `nemo_gpt20B_bf16_tp4.nemo` should have a size of 38 GB. +> You can check the download status with `du -sh nemo-megatron-gpt-20B/nemo_gpt20B_bf16_tp4.nemo`. Next, in order to minimize the number of GPUs needed to simulate each client, we convert the downloaded checkpoint that was trained using tensor parallel of size 4, to tensor parallel of size 2. @@ -115,27 +121,30 @@ In a standard terminal, run ``` python3 create_configs.py --job_folder "jobs/gpt_p-tuning_local_20B" --num_clients 3 --devices 2 --aggregation_epochs 50 --num_rounds 1 ``` -Next, submit the federated p-tuning job using the admin prompt. -Replace `[PWD]` with the path to this directory. +Next, submit the federated p-tuning job in the terminal running the admin command prompt. + ``` -submit_job [PWD]/jobs/gpt_p-tuning_local_20B +submit_job /workspace/jobs/gpt_p-tuning_local_20B ``` #### 2. Federated P-Tuning We use the [FedAvg](https://arxiv.org/abs/1602.05629) algorithm to p-tune the model in a federated scenario. First, create and modify the configuration files again. This time, we increase the number of FL rounds and decrease the number of local epochs per round to match the federated scenario. -Here, each client p-tunes for one local epoch before sending their local model updates to the server for aggregation. This is repeated for 50 FL rounds. +Here, each client p-tunes for one local epoch before sending their local model updates to the server for aggregation. +This is repeated for 50 FL rounds. + +In a standard terminal, run ``` python3 create_configs.py --job_folder "jobs/gpt_p-tuning_fedavg_20B" --num_clients 3 --devices 2 --aggregation_epochs 1 --num_rounds 50 ``` -Next, simulate the federated p-tuning using FedAvg. +Next, simulate the federated p-tuning using FedAvg in the terminal running the admin command prompt. 
``` -submit_job [PWD]/jobs/gpt_p-tuning_fedavg_20B +submit_job /workspace/jobs/gpt_p-tuning_fedavg_20B ``` You can visualize the training process using TensorBoard ``` -tensorboard --logdir /tmp/nvflare/nemo +tensorboard --logdir /tmp/nvflare/poc ``` ## Results diff --git a/integration/nemo/examples/prompt_learning/requirements.txt b/integration/nemo/examples/prompt_learning/requirements.txt index 3fdbf10587..e4605852b5 100644 --- a/integration/nemo/examples/prompt_learning/requirements.txt +++ b/integration/nemo/examples/prompt_learning/requirements.txt @@ -1 +1 @@ -nvflare>=2.3.0 +nvflare~=2.4.0rc diff --git a/integration/nemo/examples/supervised_fine_tuning/README.md b/integration/nemo/examples/supervised_fine_tuning/README.md index 9fbce4cc3c..4fb829ea7d 100644 --- a/integration/nemo/examples/supervised_fine_tuning/README.md +++ b/integration/nemo/examples/supervised_fine_tuning/README.md @@ -1,22 +1,40 @@ ## Supervised Fine-tuning (SFT) with NeMo In this example, we utilize NeMo's [supervised fine-tuning](https://github.com/NVIDIA/NeMo-Megatron-Launcher#515-instruction-following-via-supervised-finetuning--sft-) -feature to showcase how to fine-tune the whole model on supervised data for learning how to follow user specified instructions. +feature to showcase how to fine-tune the whole model on supervised data for learning how to follow user-specified instructions. Due to the large model size of the LLM, we use NVFlare's streaming feature to transfer the model in chunks. -## Dependencies -This example running a 1.3B GPT model requires considerable computational resources. For training 1.3B model, SFT needs ~24GB GPU memory using fp16 precision. Hence, to run three clients in parallel, we can compute the resource needed accordingly. - +## Hardware requirement The example for a 3-client 1.3B GPT model experiment can be performed on either three 32 GB V100 GPUs, or one 80 GB A100 GPU. -We assume you followed the instructions [here](../../README.md#requirements) -to install the NeMo, NVFlare, and the NeMo-NVFlare package. +## Dependencies +This example of running a 1.3B GPT model requires considerable computational resources. For training 1.3B model, SFT needs ~24GB GPU memory using fp16 precision. Hence, we can compute the resources needed accordingly to run three clients in parallel. The example was tested using the [NeMo Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo), -available with `docker pull nvcr.io/nvidia/nemo:23.02`. In the following, we assume the root folder of the container is mounted to `/workspace` and all downloading, etc. operations are based on this root path. +available with `docker pull nvcr.io/nvidia/nemo:23.06`. +In the following, we assume this example folder of the container is mounted to `/workspace` and all downloading, etc. operations are based on this root path. -For downloading the pre-trained model, we use [git lfs](https://git-lfs.com). +Start the docker container from **this directory** using +``` +# cd NVFlare/integration/nemo/examples/supervised_fine_tuning +DOCKER_IMAGE="nvcr.io/nvidia/nemo:23.06" +docker run --runtime=nvidia -it --rm --shm-size=16g -p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit stack=67108864 \ +-v ${PWD}:/workspace -w /workspace ${DOCKER_IMAGE} +``` + +For easy experimentation with NeMo, install NVFlare and mount the code inside the [nemo_nvflare](./nemo_nvflare) folder. 
+``` +pip install nvflare~=2.4.0rc7 +export PYTHONPATH=${PYTHONPATH}:/workspace +``` + +To download the pre-trained model, we use [git lfs](https://git-lfs.com). +Install it in the container with +``` +apt update +apt install git-lfs +``` ## Download the pre-trained LLM In this example, we use [Megatron-GPT 1.3B](https://huggingface.co/nvidia/nemo-megatron-gpt-1.3B), a transformer-based language model based on the GPT architecture. @@ -34,9 +52,9 @@ For SFT task, we will use three datasets: - [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) - [OpenAssistant Conversations](https://huggingface.co/datasets/OpenAssistant/oasst1) -These three datasets contain instruction-following data in different formats under different settings: oasst1 features a tree struture for full conversations, while the other two are instruction(w/ or w/o context)-response pairs. +These three datasets contain instruction-following data in different formats under different settings: oasst1 features a tree structure for full conversations, while the other two are instruction(w/ or w/o context)-response pairs. -In this example, we first preprocess them following the [NeMo SFT](https://github.com/NVIDIA/NeMo-Megatron-Launcher#5151-sft-data-formatting)'s instructions. The script converts the "Instruction", "Context" and "Response" fields (or their equivalents) into "Input" and "Output". The script also concatenates the "Instruction" and "Context" fields with a \n\n separator and randomizes the order in which they appear in the input to generate a new JSONL file. +In this example, we first preprocess them following the [NeMo SFT](https://github.com/NVIDIA/NeMo-Megatron-Launcher#5151-sft-data-formatting) instructions. The script converts the "Instruction", "Context" and "Response" fields (or their equivalents) into "Input" and "Output". The script also concatenates the "Instruction" and "Context" fields with a \n\n separator and randomizes the order in which they appear in the input to generate a new JSONL file. #### 1. Download the datasets We download the datasets from HuggingFace: @@ -62,7 +80,7 @@ python utils/preprocess_oasst1.py --training_file Data/oasst1/data/train-00000-o ``` #### 3. Combine for centralized training -We also generate a combined version for centralized training baseline: +We also generate a combined version for a centralized training baseline: ``` mkdir Data/Processed/combined python utils/combine_jsonl.py --file_list Data/Processed/alpaca/training.jsonl Data/Processed/dolly/training.jsonl Data/Processed/oasst1/training.jsonl --output_path Data/Processed/combined/training.jsonl @@ -110,7 +128,7 @@ nvflare simulator jobs/gpt_sft_1.3B_fedavg -w workspace_simulator_fedavg -n 3 -g ``` ### Use POC mode -Alternatively, we can also NVFlare's [POC mode](https://nvflare.readthedocs.io/en/main/getting_started.html#setting-up-poc) to simulate +Alternatively, we can also use NVFlare's [POC mode](https://nvflare.readthedocs.io/en/main/getting_started.html#setting-up-poc) to simulate #### 1. Local and Centralized SFT For single-site and centralized training experiments, we create the poc workspaces: @@ -127,7 +145,7 @@ nvflare poc start -p admin@nvidia.com ``` -Next, copy the jobs to temp workspace. +Next, copy the jobs to the temp workspace. 
``` cp -r jobs/gpt_sft_1.3B_* /tmp/nvflare/poc/example_project/prod_00/admin\@nvidia.com/transfer/ ``` @@ -139,6 +157,11 @@ submit_job gpt_sft_1.3B_dolly submit_job gpt_sft_1.3B_oasst1 submit_job gpt_sft_1.3B_combined ``` +During training, we can visualize the training process using TensorBoard. +With FL simulator, use +``` +tensorboard --logdir /workspace +``` #### 2. Federated SFT We use the [FedAvg](https://arxiv.org/abs/1602.05629) algorithm to perform SFT on the model in a federated scenario with 3 clients, each uses one of the three datasets. @@ -157,7 +180,7 @@ nvflare poc start -p admin@nvidia.com ``` -Next, simulate the federated SFT using FedAvg, similarly to single-client experiments +Next, simulate the federated SFT using FedAvg, similarly to single-client experiments: ``` cp -r jobs/gpt_sft_1.3B_fedavg /tmp/nvflare/poc/example_project/prod_00/admin\@nvidia.com/transfer/ ``` @@ -166,11 +189,14 @@ and to submit the FedAvg job submit_job gpt_sft_1.3B_fedavg ``` -## Results -During training, we can visualize the training process using TensorBoard +During training, we can visualize the training process using TensorBoard. +With the POC mode, use ``` -tensorboard --logdir /tmp/nvflare/nemo +tensorboard --logdir /tmp/nvflare/poc ``` + +## Results + In this scenario, all experiments utilize the same validation set, allowing for a direct comparison across all models. Note that we ran FL for 5 rounds, and asked NeMo to record the validation losses every few steps during local training. The validation losses for all experiments are shown below. @@ -203,7 +229,7 @@ As shown, FedAvg is able to generate a model with the best overall performance. We use NeMo's [inference script](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_eval.py) for generation task with models after SFT. Below, we define some test examples to feed to the SFT model to see its predictions. -First, we ask the model to generate answer to an open question "Tell me an interesting fact about space travel." +First, we ask the model to generate an answer to an open question: "Tell me an interesting fact about space travel." ``` ALPACA: The first human to orbit the Earth was Neil Armstrong, who flew on the Apollo 11 mission in 1969.' DOLLY: The International Space Station is the largest floating structure in the universe. It is made of steel and is about the size of a small house. @@ -211,7 +237,7 @@ OASST: Sure! Here are a few interesting facts about space travel:\n\n1. Space tr COMBINED: The first human to set foot on the Moon was Neil Armstrong. FEDAVG: The first person to travel to space was Neil Armstrong, who set foot on the moon in 1969. ``` -Note that models mostly gives plausible answers, but ALPACA-finetuned model in fact gives misinformation, since it should be Yuri Gagarin who is the first human to orbit the Earth. +Note that models mostly give plausible answers, but the ALPACA-finetuned model, in fact, gives misinformation since it should be Yuri Gagarin who is the first human to orbit the Earth. On the other hand, the model trained on the combined dataset, as well as the FL model trained with FedAvg, are able to generate a more accurate answer. Next, we ask the model to answer a question according to a given context, one instance from [SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/). 
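To make the test setup concrete, here is a tiny hypothetical sketch of composing such a context-grounded prompt; the template is an assumption for illustration (the preprocessing described earlier joins instruction and context with a `\n\n` separator, in possibly randomized order), not the exact format used by the evaluation script.

```
# Hypothetical prompt composition for the SQuAD-style test example; the
# template is an assumption, not the exact format used by megatron_gpt_eval.py.
context = (
    "Super Bowl 50 was an American football game to determine the champion "
    "of the National Football League (NFL) for the 2015 season."
)
instruction = "Which NFL team won Super Bowl 50?"
prompt = f"{instruction}\n\n{context}"
print(prompt)
```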
@@ -228,6 +254,6 @@ OASST: The Denver Broncos defeated the Carolina Panthers 24–10 to win the Supe COMBINED: The Denver Broncos' FEDAVG: The AFC champion Denver Broncos defeated the NFC champion Carolina Panthers 24–10 to win the Super Bowl.' ``` -As we can see, the key word "Denver Broncos" is correctly captured by all models. However, ALPACA and FedAvg answers are a bit redundant, and OASST answer is not directly "to the question". +As we can see, the keyword "Denver Broncos" is correctly captured by all models. However, ALPACA and FedAvg answers are a bit redundant, and OASST answer is not directly "to the question". -Based on the above results, we can see that the models trained on the combined dataset and in a federated fashion are able to generate more stable and accurate answers. +Based on the above results, we can see that the models trained on the combined dataset and in a federated fashion can generate more stable and accurate answers. diff --git a/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/__init__.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/__init__.py new file mode 100644 index 0000000000..18e75a481b --- /dev/null +++ b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_sharer_sft import ConfigSharerSFT +from .learner_executor import NemoLearnerExecutor +from .server_sft_model import ServerSFTModel +from .sft_learner import SFTLearner +from .share_config_sft import ShareConfigSFT diff --git a/integration/nemo/nemo_nvflare/config_sharer_sft.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/config_sharer_sft.py similarity index 100% rename from integration/nemo/nemo_nvflare/config_sharer_sft.py rename to integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/config_sharer_sft.py diff --git a/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/constants.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/constants.py new file mode 100644 index 0000000000..2c54b42039 --- /dev/null +++ b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/constants.py @@ -0,0 +1,23 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class NemoConstants(object): + TASK_SHARE_CONFIG = "share_config" + + +class NemoDataKind(object): + CONFIGS = "nemo_configs" + NEMO_CONFIG = "nemo_config" + TASK_TEMPLATES = "nemo_task_templates" diff --git a/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/learner_executor.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/learner_executor.py new file mode 100644 index 0000000000..a8fccbb9a5 --- /dev/null +++ b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/learner_executor.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.dxo import from_shareable +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.app_constant import AppConstants +from nvflare.app_common.executors.learner_executor import LearnerExecutor + +from .constants import NemoConstants, NemoDataKind + + +class NemoLearnerExecutor(LearnerExecutor): + def __init__( + self, + learner_id, + train_task=AppConstants.TASK_TRAIN, + submit_model_task=AppConstants.TASK_SUBMIT_MODEL, + validate_task=AppConstants.TASK_VALIDATION, + share_config_task=NemoConstants.TASK_SHARE_CONFIG, + ): + """Key component to run learner on clients. + + Args: + learner_id (str): id of the learner object + train_task (str, optional): task name for train. Defaults to AppConstants.TASK_TRAIN. + submit_model_task (str, optional): task name for submit model. Defaults to AppConstants.TASK_SUBMIT_MODEL. + validate_task (str, optional): task name for validation. Defaults to AppConstants.TASK_VALIDATION. + share_config_task (str, optional): share config task name. 
+ """ + super().__init__( + learner_id=learner_id, + train_task=train_task, + submit_model_task=submit_model_task, + validate_task=validate_task, + ) + self.share_config_task = share_config_task + self.is_initialized = False + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + if not self.is_initialized: + self.is_initialized = True + self.initialize(fl_ctx) + + if task_name == self.share_config_task: + self.log_info(fl_ctx, f"Client trainer got task: {task_name}") + try: + return self._set_learner_configs(shareable, fl_ctx, abort_signal) + except Exception as e: + self.log_error(fl_ctx, f"Setting config failed with exception {e}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + else: + return super().execute(task_name=task_name, shareable=shareable, fl_ctx=fl_ctx, abort_signal=abort_signal) + + def _set_learner_configs(self, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + dxo = from_shareable(shareable) + + if dxo.data_kind != NemoDataKind.CONFIGS: + raise ValueError(f"Expected DXO data to be of kind NemoDataKind.CONFIGS but got {dxo.data_kind}") + + if not dxo.data: + raise ValueError("Received config data is empty!") + + self.learner.set_configs(configs=dxo.data) + self.log_info(fl_ctx, f"Received config with {len(dxo.data)} entries from server.") + + return make_reply(ReturnCode.OK) diff --git a/integration/nemo/nemo_nvflare/server_sft_model.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/server_sft_model.py similarity index 100% rename from integration/nemo/nemo_nvflare/server_sft_model.py rename to integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/server_sft_model.py diff --git a/integration/nemo/nemo_nvflare/sft_learner.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/sft_learner.py similarity index 100% rename from integration/nemo/nemo_nvflare/sft_learner.py rename to integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/sft_learner.py diff --git a/integration/nemo/nemo_nvflare/share_config_sft.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/share_config_sft.py similarity index 100% rename from integration/nemo/nemo_nvflare/share_config_sft.py rename to integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/share_config_sft.py diff --git a/integration/nemo/nemo_nvflare/utils_sft.py b/integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/utils_sft.py similarity index 100% rename from integration/nemo/nemo_nvflare/utils_sft.py rename to integration/nemo/examples/supervised_fine_tuning/nemo_nvflare/utils_sft.py diff --git a/integration/sample/README.md b/integration/sample/README.md index cbe0a1917c..3a5322e40e 100644 --- a/integration/sample/README.md +++ b/integration/sample/README.md @@ -29,4 +29,4 @@ Every project and implementation has some assumptions behind it, either about th ## Required NVFLARE version -pip3 install nvflare>=2.3.0 +pip3 install nvflare>=2.4.0 diff --git a/job_templates/cyclic_cc_pt/info.conf b/job_templates/cyclic_cc_pt/info.conf index f6906f1b33..6777ec617b 100644 --- a/job_templates/cyclic_cc_pt/info.conf +++ b/job_templates/cyclic_cc_pt/info.conf @@ -1,5 +1,5 @@ { description = "client-controlled cyclic workflow with PyTorch ClientAPI trainer" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "client" } \ No newline at end of file diff --git a/job_templates/cyclic_pt/info.conf b/job_templates/cyclic_pt/info.conf index 
ff5f710307..b5d297af4b 100644
--- a/job_templates/cyclic_pt/info.conf
+++ b/job_templates/cyclic_pt/info.conf
@@ -1,5 +1,5 @@
 {
   description = "server-controlled cyclic workflow with PyTorch ClientAPI trainer"
-  client_category = "client_api"
+  execution_api_type = "client_api"
   controller_type = "server"
 }
\ No newline at end of file
diff --git a/job_templates/psi_csv/info.conf b/job_templates/psi_csv/info.conf
index f6b6d49a8c..6894c5ba35 100644
--- a/job_templates/psi_csv/info.conf
+++ b/job_templates/psi_csv/info.conf
@@ -1,5 +1,5 @@
 {
   description = "private-set intersection for csv data"
-  client_category = "Executor"
+  execution_api_type = "Executor"
   controller_type = "server"
 }
\ No newline at end of file
diff --git a/job_templates/readme.md b/job_templates/readme.md
index ba4e6fc61f..245ce3c198 100644
--- a/job_templates/readme.md
+++ b/job_templates/readme.md
@@ -13,8 +13,45 @@ Each job template contains the following information
 * information card: info.md for display purpose
 * information config: used by program
 
-# Configuration format
+Refer to the [Job CLI Documentation](https://nvflare.readthedocs.io/en/main/user_guide/nvflare_cli/job_cli.html) for details on how to use the Job Templates with the Job CLI.
+
+## Configuration format
 
 Configurations are written in HOCON (Human-Optimized Config Object Notation). As a variant of JSON, .conf can also use json format. The HOCON format allows comments, and you can remove many of the double quotes as well as replace ":" with "=" to make the configurations look cleaner.
-You can find details in [pyhoconb: HOCON Parser for python](https://github.com/chimpler/pyhocon).
+You can find details in [pyhocon: HOCON Parser for python](https://github.com/chimpler/pyhocon). A short parsing sketch follows the table below.
+
+## List of Job Templates
+
+View all the available job templates with the following command:
+
+```nvflare job list_templates```
+
+| Example | Controller-Type | Execution API Type | Description |
+|---------|-----------------|-----------------|-------------|
+| [cyclic_cc_pt](./cyclic_cc_pt) | client | client_api | client-controlled cyclic workflow with PyTorch ClientAPI trainer |
+| [cyclic_pt](./cyclic_pt) | server | client_api | server-controlled cyclic workflow with PyTorch ClientAPI trainer |
+| [psi_csv](./psi_csv) | server | Executor | private-set intersection for csv data |
+| [sag_cross_np](./sag_cross_np) | server | client executor | scatter & gather and cross-site validation using numpy |
+| [sag_cse_pt](./sag_cse_pt) | server | client_api | scatter & gather workflow and cross-site evaluation with PyTorch |
+| [sag_gnn](./sag_gnn) | server | client_api | scatter & gather workflow for gnn learning |
+| [sag_nemo](./sag_nemo) | server | client_api | Scatter and Gather Workflow for NeMo |
+| [sag_np](./sag_np) | server | client_api | scatter & gather workflow using numpy |
+| [sag_np_cell_pipe](./sag_np_cell_pipe) | server | client_api | scatter & gather workflow using numpy |
+| [sag_np_metrics](./sag_np_metrics) | server | client_api | scatter & gather workflow using numpy |
+| [sag_pt](./sag_pt) | server | client_api | scatter & gather workflow using pytorch |
+| [sag_pt_deploy_map](./sag_pt_deploy_map) | server | client_api | SAG workflow with pytorch, deploy_map, site-specific configs |
+| [sag_pt_executor](./sag_pt_executor) | server | Executor | scatter & gather workflow and cross-site evaluation with PyTorch |
+| [sag_pt_he](./sag_pt_he) | server | client_api | scatter & gather workflow using pytorch and homomorphic encryption |
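+| [sag_pt_mlflow](./sag_pt_mlflow) | server | client_api | scatter & gather workflow using pytorch with MLflow tracking |
+| [sag_pt_model_learner](./sag_pt_model_learner) | server | ModelLearner | scatter & gather workflow and cross-site evaluation with PyTorch ModelLearner |
+| [sag_tf](./sag_tf) | server | client_api | scatter & gather workflow using TensorFlow |
+| [sklearn_kmeans](./sklearn_kmeans) | server | client_api | scikit-learn KMeans model |
+| [sklearn_linear](./sklearn_linear) | server | client_api | scikit-learn linear model |
+| [sklearn_svm](./sklearn_svm) | server | client_api | scikit-learn SVM model |
+| [stats_df](./stats_df) | server | stats executor | FedStats: tabular data with pandas |
+| [stats_image](./stats_image) | server | stats executor | FedStats: image intensity histogram |
+| [swarm_cse_pt](./swarm_cse_pt) | client | client_api | Swarm Learning with Cross-Site Evaluation with PyTorch |
+| [swarm_cse_pt_model_learner](./swarm_cse_pt_model_learner) | client | ModelLearner | Swarm Learning with Cross-Site Evaluation with PyTorch ModelLearner |
+| [vertical_xgb](./vertical_xgb) | server | Executor | vertical federated xgboost |
+| [xgboost_tree](./xgboost_tree) | server | client_api | xgboost horizontal tree-based collaboration model |
+
+As a minimal illustration of the configuration format described above (a hypothetical sketch, not taken from any template; the field values mirror the `sag_pt` row), an `info.conf`-style HOCON string can be parsed with the pyhocon library linked earlier:
+
+```python
+# Hypothetical sketch: parse an info.conf-style HOCON string with pyhocon.
+from pyhocon import ConfigFactory
+
+conf_text = """
+{
+  # HOCON allows comments, unquoted keys, and "=" instead of ":"
+  description = "scatter & gather workflow using pytorch"
+  execution_api_type = "client_api"
+  controller_type = "server"
+}
+"""
+conf = ConfigFactory.parse_string(conf_text)
+print(conf["execution_api_type"])  # prints: client_api
+```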
diff --git a/job_templates/sag_cross_np/info.conf b/job_templates/sag_cross_np/info.conf
index 838c9bc165..f9be1549e4 100644
--- a/job_templates/sag_cross_np/info.conf
+++ b/job_templates/sag_cross_np/info.conf
@@ -1,5 +1,5 @@
 {
   description = "scatter & gather and cross-site validation using numpy"
-  client_category = "client executor"
+  execution_api_type = "client executor"
   controller_type = "server"
 }
\ No newline at end of file
diff --git a/job_templates/sag_cse_pt/info.conf b/job_templates/sag_cse_pt/info.conf
index 310b4423fb..258d56d9a1 100644
--- a/job_templates/sag_cse_pt/info.conf
+++ b/job_templates/sag_cse_pt/info.conf
@@ -1,5 +1,5 @@
 {
   description = "scatter & gather workflow and cross-site evaluation with PyTorch"
-  client_category = "client_api"
+  execution_api_type = "client_api"
   controller_type = "server"
 }
\ No newline at end of file
diff --git a/job_templates/sag_gnn/info.conf b/job_templates/sag_gnn/info.conf
index 48882d4e4c..e4aaabd7c2 100644
--- a/job_templates/sag_gnn/info.conf
+++ b/job_templates/sag_gnn/info.conf
@@ -1,5 +1,5 @@
 {
   description = "scatter & gather workflow for gnn learning"
-  client_category = "client_api"
+  execution_api_type = "client_api"
   controller_type = "server"
 }
\ No newline at end of file
diff --git a/job_templates/sag_nemo/info.conf b/job_templates/sag_nemo/info.conf
index ddd0fec10d..96b046b4b6 100644
--- a/job_templates/sag_nemo/info.conf
+++ b/job_templates/sag_nemo/info.conf
@@ -1,5 +1,5 @@
 {
   description = "Scatter and Gather Workflow for NeMo"
-  client_category = "client_api"
+  execution_api_type = "client_api"
   controller_type = "server"
 }
\ No newline at end of file
diff --git a/job_templates/sag_np/info.conf b/job_templates/sag_np/info.conf
index e00a60749f..365c3cb62b 100644
--- a/job_templates/sag_np/info.conf
+++ b/job_templates/sag_np/info.conf
@@ -1,5 +1,5 @@
 {
   description = "scatter & gather workflow using numpy"
-  client_category = "client_api"
+  execution_api_type = "client_api"
   controller_type = "server"
 }
\ No newline at end of file
diff --git a/job_templates/sag_np_cell_pipe/info.conf b/job_templates/sag_np_cell_pipe/info.conf
index e00a60749f..365c3cb62b 100644
--- a/job_templates/sag_np_cell_pipe/info.conf
+++ b/job_templates/sag_np_cell_pipe/info.conf
@@ -1,5 +1,5 @@
 {
description = "scatter & gather workflow using numpy" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sag_np_metrics/info.conf b/job_templates/sag_np_metrics/info.conf index e00a60749f..365c3cb62b 100644 --- a/job_templates/sag_np_metrics/info.conf +++ b/job_templates/sag_np_metrics/info.conf @@ -1,5 +1,5 @@ { description = "scatter & gather workflow using numpy" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sag_pt/info.conf b/job_templates/sag_pt/info.conf index 7e1015daf9..31ded18cf9 100644 --- a/job_templates/sag_pt/info.conf +++ b/job_templates/sag_pt/info.conf @@ -1,5 +1,5 @@ { description = "scatter & gather workflow using pytorch" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sag_pt_deploy_map/info.conf b/job_templates/sag_pt_deploy_map/info.conf index 6c387de8f7..df5ffdb4e3 100644 --- a/job_templates/sag_pt_deploy_map/info.conf +++ b/job_templates/sag_pt_deploy_map/info.conf @@ -1,5 +1,5 @@ { description = "SAG workflow with pytorch, deploy_map, site-specific configs" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sag_pt_executor/info.conf b/job_templates/sag_pt_executor/info.conf index e9a2aac332..914f92e5f9 100644 --- a/job_templates/sag_pt_executor/info.conf +++ b/job_templates/sag_pt_executor/info.conf @@ -1,5 +1,5 @@ { description = "scatter & gather workflow and cross-site evaluation with PyTorch Executor" - client_category = "Executor" + execution_api_type = "Executor" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sag_pt_he/info.conf b/job_templates/sag_pt_he/info.conf index 93c4edb7c9..1e60edc4c5 100644 --- a/job_templates/sag_pt_he/info.conf +++ b/job_templates/sag_pt_he/info.conf @@ -1,5 +1,5 @@ { description = "scatter & gather workflow using pytorch and homomorphic encryption" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sag_pt_mlflow/info.conf b/job_templates/sag_pt_mlflow/info.conf index 4e9272a588..e4a251ede0 100644 --- a/job_templates/sag_pt_mlflow/info.conf +++ b/job_templates/sag_pt_mlflow/info.conf @@ -1,5 +1,5 @@ { description = "scatter & gather workflow using pytorch with MLflow tracking" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } diff --git a/job_templates/sag_pt_model_learner/info.conf b/job_templates/sag_pt_model_learner/info.conf index c2b58b68dc..91df8ec626 100644 --- a/job_templates/sag_pt_model_learner/info.conf +++ b/job_templates/sag_pt_model_learner/info.conf @@ -1,5 +1,5 @@ { description = "scatter & gather workflow and cross-site evaluation with PyTorch ModelLearner" - client_category = "ModelLearner" + execution_api_type = "ModelLearner" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sag_tf/info.conf b/job_templates/sag_tf/info.conf index 27e4bba015..363950935b 100644 --- a/job_templates/sag_tf/info.conf +++ b/job_templates/sag_tf/info.conf @@ -1,5 +1,5 @@ { description = "scatter & gather workflow using TensorFlow" - client_category = "client_api" + execution_api_type = "client_api" 
controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sklearn_kmeans/info.conf b/job_templates/sklearn_kmeans/info.conf index c6d3bf68bc..9ca7516388 100644 --- a/job_templates/sklearn_kmeans/info.conf +++ b/job_templates/sklearn_kmeans/info.conf @@ -1,5 +1,5 @@ { description = "scikit-learn KMeans model" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sklearn_linear/info.conf b/job_templates/sklearn_linear/info.conf index 52ef59cb2d..13774438e9 100644 --- a/job_templates/sklearn_linear/info.conf +++ b/job_templates/sklearn_linear/info.conf @@ -1,5 +1,5 @@ { description = "scikit-learn linear model" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/sklearn_svm/info.conf b/job_templates/sklearn_svm/info.conf index d6678a6e4b..8dc7d5c6dc 100644 --- a/job_templates/sklearn_svm/info.conf +++ b/job_templates/sklearn_svm/info.conf @@ -1,5 +1,5 @@ { description = "scikit-learn SVM model" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/stats_df/info.conf b/job_templates/stats_df/info.conf index ba06c1e0b7..f45779ffd8 100644 --- a/job_templates/stats_df/info.conf +++ b/job_templates/stats_df/info.conf @@ -1,5 +1,5 @@ { description = "FedStats: tabular data with pandas" - client_category = "stats executor" + execution_api_type = "stats executor" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/stats_image/info.conf b/job_templates/stats_image/info.conf index 5e4d21eda8..9c1a2e4ade 100644 --- a/job_templates/stats_image/info.conf +++ b/job_templates/stats_image/info.conf @@ -1,5 +1,5 @@ { description = "FedStats: image intensity histogram" - client_category = "stats executor" + execution_api_type = "stats executor" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/swarm_cse_pt/info.conf b/job_templates/swarm_cse_pt/info.conf index 79c97c66f5..ad0bcd2cae 100644 --- a/job_templates/swarm_cse_pt/info.conf +++ b/job_templates/swarm_cse_pt/info.conf @@ -1,5 +1,5 @@ { description = "Swarm Learning with Cross-Site Evaluation with PyTorch" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "client" } diff --git a/job_templates/swarm_cse_pt_model_learner/info.conf b/job_templates/swarm_cse_pt_model_learner/info.conf index 35d79ca5d1..a690f8637c 100644 --- a/job_templates/swarm_cse_pt_model_learner/info.conf +++ b/job_templates/swarm_cse_pt_model_learner/info.conf @@ -1,5 +1,5 @@ { description = "Swarm Learning with Cross-Site Evaluation with PyTorch ModelLearner" - client_category = "ModelLearner" + execution_api_type = "ModelLearner" controller_type = "client" } diff --git a/job_templates/vertical_xgb/info.conf b/job_templates/vertical_xgb/info.conf index d391599686..99643d8607 100644 --- a/job_templates/vertical_xgb/info.conf +++ b/job_templates/vertical_xgb/info.conf @@ -1,5 +1,5 @@ { description = "vertical federated xgboost" - client_category = "Executor" + execution_api_type = "Executor" controller_type = "server" } \ No newline at end of file diff --git a/job_templates/xgboost_tree/info.conf b/job_templates/xgboost_tree/info.conf index c380f6dc39..9b3f3b0952 100644 --- a/job_templates/xgboost_tree/info.conf +++ b/job_templates/xgboost_tree/info.conf @@ -1,5 +1,5 
@@ { description = "xgboost horizontal tree-based collaboration model" - client_category = "client_api" + execution_api_type = "client_api" controller_type = "server" } \ No newline at end of file diff --git a/nvflare/apis/analytix.py b/nvflare/apis/analytix.py index 68772b13a6..0cb0ab1b45 100644 --- a/nvflare/apis/analytix.py +++ b/nvflare/apis/analytix.py @@ -183,11 +183,7 @@ def convert_data_type( return sender_data_type if sender == LogWriterName.MLFLOW and receiver == LogWriterName.TORCH_TB: - if AnalyticsDataType.PARAMETER == sender_data_type: - return AnalyticsDataType.SCALAR - elif AnalyticsDataType.PARAMETERS == sender_data_type: - return AnalyticsDataType.SCALARS - elif AnalyticsDataType.METRIC == sender_data_type: + if AnalyticsDataType.METRIC == sender_data_type: return AnalyticsDataType.SCALAR elif AnalyticsDataType.METRICS == sender_data_type: return AnalyticsDataType.SCALARS diff --git a/nvflare/app_common/abstract/params_converter.py b/nvflare/app_common/abstract/params_converter.py index 6ec2e0e546..1ae611a836 100644 --- a/nvflare/app_common/abstract/params_converter.py +++ b/nvflare/app_common/abstract/params_converter.py @@ -16,12 +16,11 @@ from typing import Any, List from nvflare.apis.dxo import from_shareable -from nvflare.apis.filter import Filter from nvflare.apis.fl_context import FLContext from nvflare.apis.shareable import Shareable -class ParamsConverter(Filter, ABC): +class ParamsConverter(ABC): def __init__(self, supported_tasks: List[str] = None): self.supported_tasks = supported_tasks diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index 6fdc9b8372..776c5a0206 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -33,7 +33,7 @@ def __init__( external_execution_wait: float = 5.0, peer_read_timeout: Optional[float] = None, monitor_interval: float = 0.01, - read_interval: float = 0.001, + read_interval: float = 0.5, heartbeat_interval: float = 5.0, heartbeat_timeout: float = 30.0, workers: int = 4, @@ -43,8 +43,8 @@ def __init__( submit_model_task_name: str = "submit_model", from_nvflare_converter_id: Optional[str] = None, to_nvflare_converter_id: Optional[str] = None, - params_exchange_format: ExchangeFormat = ExchangeFormat.NUMPY, - params_transfer_type: TransferType = TransferType.FULL, + params_exchange_format: str = ExchangeFormat.NUMPY, + params_transfer_type: str = TransferType.FULL, config_file_name: str = CLIENT_API_CONFIG, ) -> None: """Initializes the ClientAPILauncherExecutor. @@ -70,8 +70,8 @@ def __init__( This ParamsConverter will be called when model is sent from nvflare controller side to executor side. to_nvflare_converter_id (Optional[str]): Identifier used to get the ParamsConverter from NVFlare components. This ParamsConverter will be called when model is sent from nvflare executor side to controller side. - params_exchange_format (ExchangeFormat): What format to exchange the parameters. - params_transfer_type (TransferType): How to transfer the parameters. FULL means the whole model parameters are sent. + params_exchange_format (str): What format to exchange the parameters. + params_transfer_type (str): How to transfer the parameters. FULL means the whole model parameters are sent. DIFF means that only the difference is sent. config_file_name (str): The config file name to write attributes into, the client api will read in this file. 
""" diff --git a/nvflare/app_common/executors/launcher_executor.py b/nvflare/app_common/executors/launcher_executor.py index a9fd1fab34..90cc6a3d0a 100644 --- a/nvflare/app_common/executors/launcher_executor.py +++ b/nvflare/app_common/executors/launcher_executor.py @@ -45,7 +45,7 @@ def __init__( external_execution_wait: float = 5.0, peer_read_timeout: Optional[float] = None, monitor_interval: float = 1.0, - read_interval: float = 0.1, + read_interval: float = 0.5, heartbeat_interval: float = 5.0, heartbeat_timeout: float = 30.0, workers: int = 1, diff --git a/nvflare/app_common/executors/task_exchanger.py b/nvflare/app_common/executors/task_exchanger.py index e7521d3380..77a7b19bb9 100644 --- a/nvflare/app_common/executors/task_exchanger.py +++ b/nvflare/app_common/executors/task_exchanger.py @@ -33,7 +33,7 @@ class TaskExchanger(Executor): def __init__( self, pipe_id: str, - read_interval: float = 0.1, + read_interval: float = 0.5, heartbeat_interval: float = 5.0, heartbeat_timeout: Optional[float] = 30.0, resend_interval: float = 2.0, @@ -46,16 +46,27 @@ def __init__( """Constructor of TaskExchanger. Args: - pipe_id: component id of pipe - read_interval: how often to read from pipe - heartbeat_interval: how often to send heartbeat to peer - heartbeat_timeout: max amount of time to allow missing heartbeats before treating peer as dead - resend_interval: how often to resend a message when failing to send - max_resends: max number of resends. None means no limit - peer_read_timeout: time to wait for peer to accept sent message - task_wait_time: how long to wait for a task to complete. None means waiting forever - result_poll_interval: how often to poll task result - pipe_channel_name: the channel name for sending task requests + pipe_id (str): component id of pipe. + read_interval (float): how often to read from pipe. + Defaults to 0.5. + heartbeat_interval (float): how often to send heartbeat to peer. + Defaults to 5.0. + heartbeat_timeout (float, optional): how long to wait for a + heartbeat from the peer before treating the peer as dead, + 0 means DO NOT check for heartbeat. Defaults to 30.0. + resend_interval (float): how often to resend a message if failing to send. + None means no resend. Note that if the pipe does not support resending, + then no resend. Defaults to 2.0. + max_resends (int, optional): max number of resend. None means no limit. + Defaults to None. + peer_read_timeout (float, optional): time to wait for peer to accept sent message. + Defaults to 5.0. + task_wait_time (float, optional): how long to wait for a task to complete. + None means waiting forever. Defaults to None. + result_poll_interval (float): how often to poll task result. + Defaults to 0.5. + pipe_channel_name: the channel name for sending task requests. + Defaults to "task". 
""" Executor.__init__(self) check_str("pipe_id", pipe_id) @@ -104,7 +115,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): self.pipe_handler.set_status_cb(self._pipe_status_cb) self.pipe.open(self.pipe_channel_name) self.pipe_handler.start() - elif event_type == EventType.END_RUN: + elif event_type == EventType.ABOUT_TO_END_RUN: self.log_info(fl_ctx, "Stopping pipe handler") if self.pipe_handler: self.pipe_handler.notify_end("end_of_job") diff --git a/nvflare/app_common/utils/fl_model_utils.py b/nvflare/app_common/utils/fl_model_utils.py index 204928fd13..2d84daa14f 100644 --- a/nvflare/app_common/utils/fl_model_utils.py +++ b/nvflare/app_common/utils/fl_model_utils.py @@ -97,7 +97,11 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) -> params = None meta = {} - try: + submit_model_name = shareable.get_header(AppConstants.SUBMIT_MODEL_NAME) + if submit_model_name: + # this only happens in cross-site eval right now + meta[MetaKey.SUBMIT_MODEL_NAME] = submit_model_name + else: dxo = from_shareable(shareable) meta = dict(dxo.meta) if dxo.data_kind == DataKind.METRICS: @@ -115,10 +119,6 @@ def from_shareable(shareable: Shareable, fl_ctx: Optional[FLContext] = None) -> if MetaKey.INITIAL_METRICS in meta: metrics = meta[MetaKey.INITIAL_METRICS] - except: - # this only happens in cross-site eval right now - submit_model_name = shareable.get_header(AppConstants.SUBMIT_MODEL_NAME) - meta[MetaKey.SUBMIT_MODEL_NAME] = submit_model_name current_round = shareable.get_header(AppConstants.CURRENT_ROUND, None) total_rounds = shareable.get_header(AppConstants.NUM_ROUNDS, None) diff --git a/nvflare/app_common/workflows/base_fedavg.py b/nvflare/app_common/workflows/base_fedavg.py index d031998e35..cd196d86b3 100644 --- a/nvflare/app_common/workflows/base_fedavg.py +++ b/nvflare/app_common/workflows/base_fedavg.py @@ -17,6 +17,7 @@ from nvflare.apis.fl_constant import FLMetaKey from nvflare.app_common.abstract.fl_model import FLModel +from nvflare.app_common.abstract.model import make_model_learnable from nvflare.app_common.aggregators.weighted_aggregation_helper import WeightedAggregationHelper from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.app_event_type import AppEventType @@ -142,5 +143,8 @@ def update_model(self, aggr_result): self.model = FLModelUtils.update_model(self.model, aggr_result) - self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.model, private=True, sticky=True) + # persistor uses Learnable format to save model + ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta) + self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True) + self.event(AppEventType.AFTER_SHAREABLE_TO_LEARNABLE) diff --git a/nvflare/app_common/workflows/cyclic_ctl.py b/nvflare/app_common/workflows/cyclic_ctl.py index b274aa1f77..754e1b06b6 100644 --- a/nvflare/app_common/workflows/cyclic_ctl.py +++ b/nvflare/app_common/workflows/cyclic_ctl.py @@ -16,6 +16,7 @@ import random from nvflare.apis.client import Client +from nvflare.apis.fl_constant import ReturnCode from nvflare.apis.fl_context import FLContext from nvflare.apis.impl.controller import ClientTask, Controller, Task from nvflare.apis.shareable import Shareable @@ -145,6 +146,19 @@ def _get_relay_orders(self, fl_ctx: FLContext): return targets def _process_result(self, client_task: ClientTask, fl_ctx: FLContext): + result = client_task.result + rc = result.get_return_code() + client_name = client_task.client.name + + # Raise errors if 
ReturnCode is not OK. + if rc and rc != ReturnCode.OK: + self.system_panic( + f"Result from {client_name} is bad, error code: {rc}. " + f"{self.__class__.__name__} exiting at round {self._current_round}.", + fl_ctx=fl_ctx, + ) + return False + # submitted shareable is stored in client_task.result # we need to update task.data with that shareable so the next target # will get the updated shareable diff --git a/nvflare/app_common/workflows/model_controller.py b/nvflare/app_common/workflows/model_controller.py index 5320a432be..0b65f07539 100644 --- a/nvflare/app_common/workflows/model_controller.py +++ b/nvflare/app_common/workflows/model_controller.py @@ -138,7 +138,9 @@ def start_controller(self, fl_ctx: FLContext) -> None: else: self.model = FLModel(params_type=ParamsType.FULL, params={}) - self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, self.model, private=True, sticky=True) + # persistor uses Learnable format to save model + ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta) + self.fl_ctx.set_prop(AppConstants.GLOBAL_MODEL, ml, private=True, sticky=True) self.event(AppEventType.INITIAL_MODEL_LOADED) self.engine = self.fl_ctx.get_engine() @@ -231,7 +233,11 @@ def _process_result(self, client_task: ClientTask, fl_ctx: FLContext) -> None: result = client_task.result client_name = client_task.client.name + self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True) + + self.event(AppEventType.BEFORE_CONTRIBUTION_ACCEPT) self._accept_train_result(client_name=client_name, result=result, fl_ctx=fl_ctx) + self.event(AppEventType.AFTER_CONTRIBUTION_ACCEPT) # Turn result into FLModel result_model = FLModelUtils.from_shareable(result) @@ -270,7 +276,6 @@ def _accept_train_result(self, client_name: str, result: Shareable, fl_ctx: FLCo ) return - self.fl_ctx.set_prop(AppConstants.CURRENT_ROUND, self._current_round, private=True, sticky=True) self.fl_ctx.set_prop(AppConstants.TRAINING_RESULT, result, private=True, sticky=False) @abstractmethod @@ -307,6 +312,7 @@ def save_model(self): ) or self._current_round == self._num_rounds - 1: self.info("Start persist model on server.") self.event(AppEventType.BEFORE_LEARNABLE_PERSIST) + # persistor uses Learnable format to save model ml = make_model_learnable(weights=self.model.params, meta_props=self.model.meta) self.persistor.save(ml, self.fl_ctx) self.event(AppEventType.AFTER_LEARNABLE_PERSIST) diff --git a/nvflare/app_common/xgb/__init__.py b/nvflare/app_common/xgb/__init__.py new file mode 100644 index 0000000000..df104c37e9 --- /dev/null +++ b/nvflare/app_common/xgb/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
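+
+# Re-export the public entry points of the XGB integration so that they can be
+# imported directly from the nvflare.app_common.xgb package.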
+ +from nvflare.app_common.xgb.fed_controller import XGBFedController +from nvflare.app_common.xgb.fed_executor import FedXGBHistogramExecutor +from nvflare.app_common.xgb.mock.mock_controller import MockXGBController +from nvflare.app_common.xgb.mock.mock_executor import MockXGBExecutor diff --git a/nvflare/app_common/xgb/adaptors/__init__.py b/nvflare/app_common/xgb/adaptors/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/nvflare/app_common/xgb/adaptors/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/xgb/adaptors/adaptor.py b/nvflare/app_common/xgb/adaptors/adaptor.py new file mode 100644 index 0000000000..861e196de7 --- /dev/null +++ b/nvflare/app_common/xgb/adaptors/adaptor.py @@ -0,0 +1,431 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +from abc import ABC, abstractmethod + +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.xgb.defs import Constant +from nvflare.app_common.xgb.runners.xgb_runner import XGBRunner +from nvflare.app_common.xgb.sender import Sender +from nvflare.fuel.utils.validation_utils import check_non_negative_int, check_object_type, check_positive_int + + +class XGBAdaptor(ABC, FLComponent): + """ + XGBAdaptors are used to integrate FLARE with XGBoost Target (Server or Client) in run time. + + For example, an XGB server could be run as a gRPC server process, or be run as part of the FLARE's FL server + process. Similarly, an XGB client could be run as a gRPC client process, or be run as part of the + FLARE's FL client process. + + Each type of XGB Target requires an appropriate adaptor to integrate it with FLARE's XGB Controller or Executor. + + The XGBAdaptor class defines commonly required methods for all adaptor implementations. + """ + + def __init__(self): + FLComponent.__init__(self) + self.abort_signal = None + self.xgb_runner = None + + def set_runner(self, runner: XGBRunner): + """Set the XGB Runner that will be used to run XGB processing logic. + Note that the adaptor is only responsible for starting the runner appropriately (in a thread or in a + separate process). 
+ + Args: + runner: the runner to be set + + Returns: None + + """ + if not isinstance(runner, XGBRunner): + raise TypeError(f"runner must be XGBRunner but got {type(runner)}") + self.xgb_runner = runner + + def set_abort_signal(self, abort_signal: Signal): + """Called by XGB Controller/Executor to set the abort_signal. + + The abort_signal is assigned by FLARE's XGB Controller/Executor. It is used by the Controller/Executor + to tell the adaptor that the job has been aborted. + + Args: + abort_signal: the abort signal assigned by the caller. + + Returns: None + + """ + check_object_type("abort_signal", abort_signal, Signal) + self.abort_signal = abort_signal + + def initialize(self, fl_ctx: FLContext): + """Called by the Controller/Executor to initialize the adaptor. + + Args: + fl_ctx: the FL context + + Returns: None + + """ + pass + + @abstractmethod + def start(self, fl_ctx: FLContext): + """Called by XGB Controller/Executor to start the target. + If any error occurs when starting the target, this method should raise an exception. + + Args: + fl_ctx: the FL context. + + Returns: None + + """ + pass + + @abstractmethod + def stop(self, fl_ctx: FLContext): + """Called by XGB Controller/Executor to stop the target. + If any error occurs when stopping the target, this method should raise an exception. + + Args: + fl_ctx: the FL context. + + Returns: None + + """ + pass + + @abstractmethod + def configure(self, config: dict, fl_ctx: FLContext): + """Called by XGB Controller/Executor to configure the adaptor. + If any error occurs, this method should raise an exception. + + Args: + config: config data + fl_ctx: the FL context + + Returns: None + + """ + pass + + @abstractmethod + def _is_stopped(self) -> (bool, int): + """Called by the adaptor's monitor to know whether the target is stopped. + Note that this method is not called by XGB Controller/Executor. + + Returns: a tuple of: whether the target is stopped, and return code (if stopped) + + Note that a non-zero return code is considered abnormal completion of the target. + + """ + pass + + def _monitor(self, fl_ctx: FLContext, target_stopped_cb): + while True: + if self.abort_signal.triggered: + # asked to abort + self.stop(fl_ctx) + return + + stopped, rc = self._is_stopped() + if stopped: + # target already stopped - notify the caller + target_stopped_cb(rc, fl_ctx) + return + + time.sleep(0.1) + + def monitor_target(self, fl_ctx: FLContext, target_stopped_cb): + """Called by XGB Controller/Executor to monitor the health of the target. + + The monitor periodically checks the abort signal. Once set, it calls the adaptor's stop() method + to stop the running of the target. + + The monitor also periodically checks whether the target is already stopped (by calling the is_stopped + method). If the target is stopped, the monitor will call the specified target_stopped_cb. + + Args: + fl_ctx: FL context + target_stopped_cb: the callback function to be called when the target is stopped. + + Returns: None + + """ + if not callable(target_stopped_cb): + raise RuntimeError(f"target_stopped_cb must be callable but got {type(target_stopped_cb)}") + + # start the monitor in a separate daemon thread! + t = threading.Thread(target=self._monitor, args=(fl_ctx, target_stopped_cb), daemon=True) + t.start() + + +class XGBServerAdaptor(XGBAdaptor): + """ + XGBServerAdaptor specifies commonly required methods for server adaptor implementations. 
+ """ + + def __init__(self): + XGBAdaptor.__init__(self) + self.world_size = None + + def configure(self, config: dict, fl_ctx: FLContext): + """Called by XGB Controller to configure the target. + + The world_size is a required config parameter. + + Args: + config: config data + fl_ctx: FL context + + Returns: None + + """ + ws = config.get(Constant.CONF_KEY_WORLD_SIZE) + if not ws: + raise RuntimeError("world_size is not configured") + + check_positive_int(Constant.CONF_KEY_WORLD_SIZE, ws) + self.world_size = ws + + @abstractmethod + def all_gather(self, rank: int, seq: int, send_buf: bytes, fl_ctx: FLContext) -> bytes: + """Called by the XGB Controller to perform Allgather operation, per XGBoost spec. + + Args: + rank: rank of the calling client + seq: sequence number of the request + send_buf: operation input data + fl_ctx: FL context + + Returns: operation result + + """ + pass + + @abstractmethod + def all_gather_v(self, rank: int, seq: int, send_buf: bytes, fl_ctx: FLContext) -> bytes: + """Called by the XGB Controller to perform AllgatherV operation, per XGBoost spec. + + Args: + rank: rank of the calling client + seq: sequence number of the request + send_buf: input data + fl_ctx: FL context + + Returns: operation result + + """ + pass + + @abstractmethod + def all_reduce( + self, + rank: int, + seq: int, + data_type: int, + reduce_op: int, + send_buf: bytes, + fl_ctx: FLContext, + ) -> bytes: + """Called by the XGB Controller to perform Allreduce operation, per XGBoost spec. + + Args: + rank: rank of the calling client + seq: sequence number of the request + data_type: data type of the input + reduce_op: reduce operation to be performed + send_buf: input data + fl_ctx: FL context + + Returns: operation result + + """ + pass + + @abstractmethod + def broadcast(self, rank: int, seq: int, root: int, send_buf: bytes, fl_ctx: FLContext) -> bytes: + """Called by the XGB Controller to perform Broadcast operation, per XGBoost spec. + + Args: + rank: rank of the calling client + seq: sequence number of the request + root: root rank of the broadcast + send_buf: input data + fl_ctx: FL context + + Returns: operation result + + """ + pass + + +class XGBClientAdaptor(XGBAdaptor): + """ + XGBClientAdaptor specifies commonly required methods for client adaptor implementations. + """ + + def __init__(self): + """Constructor of XGBClientAdaptor""" + XGBAdaptor.__init__(self) + self.engine = None + self.sender = None + self.stopped = False + self.rank = None + self.num_rounds = None + self.world_size = None + + def set_sender(self, sender: Sender): + """Set the sender to be used to send XGB operation requests to the server. + + Args: + sender: the sender to be set + + Returns: None + + """ + if not isinstance(sender, Sender): + raise TypeError(f"sender must be Sender but got {type(sender)}") + self.sender = sender + + def configure(self, config: dict, fl_ctx: FLContext): + """Called by XGB Executor to configure the target. + + The rank, world size, and number of rounds are required config parameters. 
+ + Args: + config: config data + fl_ctx: FL context + + Returns: None + + """ + ws = config.get(Constant.CONF_KEY_WORLD_SIZE) + if not ws: + raise RuntimeError("world_size is not configured") + + check_positive_int(Constant.CONF_KEY_WORLD_SIZE, ws) + self.world_size = ws + + rank = config.get(Constant.CONF_KEY_RANK) + if rank is None: + raise RuntimeError("rank is not configured") + + check_non_negative_int(Constant.CONF_KEY_RANK, rank) + self.rank = rank + + num_rounds = config.get(Constant.CONF_KEY_NUM_ROUNDS) + if num_rounds is None: + raise RuntimeError("num_rounds is not configured") + + check_positive_int(Constant.CONF_KEY_NUM_ROUNDS, num_rounds) + self.num_rounds = num_rounds + + def _send_request(self, op: str, req: Shareable) -> bytes: + """Send XGB operation request to the FL server via FLARE message. + + Args: + op: the XGB operation + req: operation data + + Returns: operation result + + """ + reply = self.sender.send_to_server(op, req, self.abort_signal) + if isinstance(reply, Shareable): + rcv_buf = reply.get(Constant.PARAM_KEY_RCV_BUF) + if not isinstance(rcv_buf, bytes): + raise RuntimeError(f"invalid rcv_buf for {op=}: expect bytes but got {type(rcv_buf)}") + return rcv_buf + else: + raise RuntimeError(f"invalid reply for op {op}: expect Shareable but got {type(reply)}") + + def _send_all_gather(self, rank: int, seq: int, send_buf: bytes) -> bytes: + """This method is called by a concrete client adaptor to send Allgather operation to the server. + + Args: + rank: rank of the client + seq: sequence number of the request + send_buf: input data + + Returns: operation result + + """ + req = Shareable() + req[Constant.PARAM_KEY_RANK] = rank + req[Constant.PARAM_KEY_SEQ] = seq + req[Constant.PARAM_KEY_SEND_BUF] = send_buf + return self._send_request(Constant.OP_ALL_GATHER, req) + + def _send_all_gather_v(self, rank: int, seq: int, send_buf: bytes) -> bytes: + """This method is called by a concrete client adaptor to send AllgatherV operation to the server. + + Args: + rank: rank of the client + seq: sequence number of the request + send_buf: operation input + + Returns: operation result + + """ + req = Shareable() + req[Constant.PARAM_KEY_RANK] = rank + req[Constant.PARAM_KEY_SEQ] = seq + req[Constant.PARAM_KEY_SEND_BUF] = send_buf + return self._send_request(Constant.OP_ALL_GATHER_V, req) + + def _send_all_reduce(self, rank: int, seq: int, data_type: int, reduce_op: int, send_buf: bytes) -> bytes: + """This method is called by a concrete client adaptor to send Allreduce operation to the server. + + Args: + rank: rank of the client + seq: sequence number of the request + data_type: data type of the input + reduce_op: reduce operation to be performed + send_buf: operation input + + Returns: operation result + + """ + req = Shareable() + req[Constant.PARAM_KEY_RANK] = rank + req[Constant.PARAM_KEY_SEQ] = seq + req[Constant.PARAM_KEY_DATA_TYPE] = data_type + req[Constant.PARAM_KEY_REDUCE_OP] = reduce_op + req[Constant.PARAM_KEY_SEND_BUF] = send_buf + return self._send_request(Constant.OP_ALL_REDUCE, req) + + def _send_broadcast(self, rank: int, seq: int, root: int, send_buf: bytes) -> bytes: + """This method is called by a concrete client adaptor to send Broadcast operation to the server. 
+ + Args: + rank: rank of the client + seq: sequence number of the request + root: root rank of the broadcast + send_buf: operation input + + Returns: operation result + + """ + req = Shareable() + req[Constant.PARAM_KEY_RANK] = rank + req[Constant.PARAM_KEY_SEQ] = seq + req[Constant.PARAM_KEY_ROOT] = root + req[Constant.PARAM_KEY_SEND_BUF] = send_buf + return self._send_request(Constant.OP_BROADCAST, req) diff --git a/nvflare/app_common/xgb/adaptors/grpc_client_adaptor.py b/nvflare/app_common/xgb/adaptors/grpc_client_adaptor.py new file mode 100644 index 0000000000..e2f82b7483 --- /dev/null +++ b/nvflare/app_common/xgb/adaptors/grpc_client_adaptor.py @@ -0,0 +1,253 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +import threading + +import nvflare.app_common.xgb.proto.federated_pb2 as pb2 +from nvflare.apis.fl_constant import FLContextKey +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.xgb.adaptors.adaptor import XGBClientAdaptor +from nvflare.app_common.xgb.defs import Constant +from nvflare.app_common.xgb.grpc_server import GrpcServer +from nvflare.app_common.xgb.proto.federated_pb2_grpc import FederatedServicer +from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port +from nvflare.security.logging import secure_format_exception, secure_log_traceback + + +class _ClientStarter: + """This small class is used to start XGB client runner. It is used when running the runner in a thread + or in a separate process. + + """ + + def __init__(self, runner): + self.xgb_runner = runner + self.error = None + self.started = True + self.stopped = False + + def start(self, ctx: dict): + """Start the runner and wait for it to finish. 
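+
+        Any exception raised by the runner is captured in self.error rather
+        than re-raised, so the caller can inspect it afterwards.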
+ + Args: + ctx: + + Returns: + + """ + try: + self.xgb_runner.run(ctx) + self.stopped = True + except Exception as e: + secure_log_traceback() + self.error = f"Exception happens when running xgb train: {secure_format_exception(e)}" + self.started = False + + +class GrpcClientAdaptor(XGBClientAdaptor, FederatedServicer): + def __init__( + self, + int_server_grpc_options=None, + in_process=True, + ): + XGBClientAdaptor.__init__(self) + self.int_server_grpc_options = int_server_grpc_options + self.in_process = in_process + self.internal_xgb_server = None + self.stopped = False + self.internal_server_addr = None + self._training_stopped = False + self._client_name = None + self._app_dir = None + self._workspace = None + self._run_dir = None + self._process = None + self._starter = None + + def initialize(self, fl_ctx: FLContext): + self._client_name = fl_ctx.get_identity_name() + engine = fl_ctx.get_engine() + ws = engine.get_workspace() + self._app_dir = ws.get_app_dir(fl_ctx.get_job_id()) + self._workspace = fl_ctx.get_prop(FLContextKey.WORKSPACE_OBJECT) + run_number = fl_ctx.get_prop(FLContextKey.CURRENT_RUN) + self._run_dir = self._workspace.get_run_dir(run_number) + + def _start_client(self, server_addr: str): + """Start the XGB client runner in a separate thread or separate process based on config. + Note that when starting runner in a separate process, we must not call a method defined in this + class since the self object contains a sender that contains a Core Cell which cannot be sent to + the new process. Instead, we use a small _ClientStarter object to run the process. + + Args: + server_addr: the internal gRPC server address that the XGB client will connect to + + Returns: None + + """ + ctx = { + Constant.RUNNER_CTX_WORLD_SIZE: self.world_size, + Constant.RUNNER_CTX_CLIENT_NAME: self._client_name, + Constant.RUNNER_CTX_SERVER_ADDR: server_addr, + Constant.RUNNER_CTX_RANK: self.rank, + Constant.RUNNER_CTX_NUM_ROUNDS: self.num_rounds, + Constant.RUNNER_CTX_MODEL_DIR: self._run_dir, + Constant.RUNNER_CTX_TB_DIR: self._app_dir, + } + starter = _ClientStarter(self.xgb_runner) + if self.in_process: + self.logger.info("starting XGB client in another thread") + t = threading.Thread( + target=starter.start, + args=(ctx,), + daemon=True, + name="xgb_client_thread_runner", + ) + t.start() + if not starter.started: + self.logger.error(f"cannot start XGB client: {starter.error}") + raise RuntimeError(starter.error) + self._starter = starter + else: + # start as a separate local process + self.logger.info("starting XGB client in another process") + self._process = multiprocessing.Process( + target=starter.start, + args=(ctx,), + daemon=True, + name="xgb_client_process_runner", + ) + self._process.start() + + def _stop_client(self): + self._training_stopped = True + if self.in_process: + if self.xgb_runner: + self.xgb_runner.stop() + else: + if self._process: + self._process.kill() + + def _is_stopped(self) -> (bool, int): + if self.in_process: + if self._starter: + if self._starter.stopped: + return True, 0 + + if self._training_stopped: + return True, 0 + + if self.xgb_runner: + return self.xgb_runner.is_stopped() + else: + return True, 0 + else: + if self._process: + assert isinstance(self._process, multiprocessing.Process) + ec = self._process.exitcode + if ec is None: + return False, 0 + else: + return True, ec + else: + return True, 0 + + def start(self, fl_ctx: FLContext): + if self.rank is None: + raise RuntimeError("cannot start - my rank is not set") + + if not self.num_rounds: + 
raise RuntimeError("cannot start - num_rounds is not set") + + # dynamically determine address on localhost + port = get_open_tcp_port(resources={}) + if not port: + raise RuntimeError("failed to get a port for XGB server") + self.internal_server_addr = f"127.0.0.1:{port}" + self.logger.info(f"Start internal server at {self.internal_server_addr}") + self.internal_xgb_server = GrpcServer(self.internal_server_addr, 10, self.int_server_grpc_options, self) + self.internal_xgb_server.start(no_blocking=True) + self.logger.info(f"Started internal server at {self.internal_server_addr}") + self._start_client(self.internal_server_addr) + self.logger.info("Started external XGB Client") + + def stop(self, fl_ctx: FLContext): + if self.stopped: + return + + self.stopped = True + self._stop_client() + + if self.internal_xgb_server: + self.logger.info("Stop internal XGB Server") + self.internal_xgb_server.shutdown() + + def _abort(self, reason: str): + # stop the gRPC XGB client (the target) + self.abort_signal.trigger(True) + + # abort the FL client + with self.engine.new_context() as fl_ctx: + self.system_panic(reason, fl_ctx) + + def Allgather(self, request: pb2.AllgatherRequest, context): + try: + rcv_buf = self._send_all_gather( + rank=request.rank, + seq=request.sequence_number, + send_buf=request.send_buffer, + ) + return pb2.AllgatherReply(receive_buffer=rcv_buf) + except Exception as ex: + self._abort(reason=f"send_all_gather exception: {secure_format_exception(ex)}") + return None + + def AllgatherV(self, request: pb2.AllgatherVRequest, context): + try: + rcv_buf = self._send_all_gather_v( + rank=request.rank, + seq=request.sequence_number, + send_buf=request.send_buffer, + ) + return pb2.AllgatherVReply(receive_buffer=rcv_buf) + except Exception as ex: + self._abort(reason=f"send_all_gather_v exception: {secure_format_exception(ex)}") + return None + + def Allreduce(self, request: pb2.AllreduceRequest, context): + try: + rcv_buf = self._send_all_reduce( + rank=request.rank, + seq=request.sequence_number, + data_type=request.data_type, + reduce_op=request.reduce_operation, + send_buf=request.send_buffer, + ) + return pb2.AllreduceReply(receive_buffer=rcv_buf) + except Exception as ex: + self._abort(reason=f"send_all_reduce exception: {secure_format_exception(ex)}") + return None + + def Broadcast(self, request: pb2.BroadcastRequest, context): + try: + rcv_buf = self._send_broadcast( + rank=request.rank, + seq=request.sequence_number, + root=request.root, + send_buf=request.send_buffer, + ) + return pb2.BroadcastReply(receive_buffer=rcv_buf) + except Exception as ex: + self._abort(reason=f"send_broadcast exception: {secure_format_exception(ex)}") + return None diff --git a/nvflare/app_common/xgb/adaptors/grpc_server_adaptor.py b/nvflare/app_common/xgb/adaptors/grpc_server_adaptor.py new file mode 100644 index 0000000000..540c164745 --- /dev/null +++ b/nvflare/app_common/xgb/adaptors/grpc_server_adaptor.py @@ -0,0 +1,158 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +import threading + +import nvflare.app_common.xgb.proto.federated_pb2 as pb2 +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.xgb.adaptors.adaptor import XGBServerAdaptor +from nvflare.app_common.xgb.defs import Constant +from nvflare.app_common.xgb.grpc_client import GrpcClient +from nvflare.fuel.f3.drivers.net_utils import get_open_tcp_port +from nvflare.security.logging import secure_format_exception + + +class GrpcServerAdaptor(XGBServerAdaptor): + def __init__( + self, + int_client_grpc_options=None, + xgb_server_ready_timeout=Constant.XGB_SERVER_READY_TIMEOUT, + in_process=True, + ): + XGBServerAdaptor.__init__(self) + self.int_client_grpc_options = int_client_grpc_options + self.xgb_server_ready_timeout = xgb_server_ready_timeout + self.in_process = in_process + self.internal_xgb_client = None + self._process = None + self._server_stopped = False + + def _try_start_server(self, addr: str, port: int, world_size: int): + ctx = { + Constant.RUNNER_CTX_SERVER_ADDR: addr, + Constant.RUNNER_CTX_WORLD_SIZE: world_size, + Constant.RUNNER_CTX_PORT: port, + } + try: + self.xgb_runner.run(ctx) + except Exception as ex: + self.logger.error(f"Exception running xgb_runner {ctx=}: {secure_format_exception(ex)}") + raise ex + + def _start_server(self, addr: str, port: int, world_size: int): + if self.in_process: + self.logger.info("starting XGB server in another thread") + t = threading.Thread( + name="xgb_server_thread", target=self._try_start_server, args=(addr, port, world_size), daemon=True + ) + t.start() + else: + self.logger.info("starting XGB server in another process") + self._process = multiprocessing.Process( + name="xgb_server_process", target=self._try_start_server, args=(addr, port, world_size), daemon=True + ) + self._process.start() + + def _stop_server(self): + self._server_stopped = True + if self.in_process: + if self.xgb_runner: + self.xgb_runner.stop() + else: + if self._process: + self._process.kill() + self._process = None + + def _is_stopped(self) -> (bool, int): + if self._server_stopped: + return True, 0 + + if self.in_process: + if self.xgb_runner: + return self.xgb_runner.is_stopped() + else: + return True, 0 + else: + if self._process: + assert isinstance(self._process, multiprocessing.Process) + ec = self._process.exitcode + if ec is None: + return False, 0 + else: + return True, ec + else: + return True, 0 + + def start(self, fl_ctx: FLContext): + # we dynamically create server address on localhost + port = get_open_tcp_port(resources={}) + if not port: + raise RuntimeError("failed to get a port for XGB server") + + server_addr = f"127.0.0.1:{port}" + self._start_server(server_addr, port, self.world_size) + + # start XGB client + self.internal_xgb_client = GrpcClient(server_addr, self.int_client_grpc_options) + self.internal_xgb_client.start(ready_timeout=self.xgb_server_ready_timeout) + + def stop(self, fl_ctx: FLContext): + client = self.internal_xgb_client + self.internal_xgb_client = None + if client: + self.log_info(fl_ctx, "Stopping internal XGB client") + client.stop() + self._stop_server() + + def all_gather(self, rank: int, seq: int, send_buf: bytes, fl_ctx: FLContext) -> bytes: + result = self.internal_xgb_client.send_allgather(seq_num=seq, rank=rank, data=send_buf) + if isinstance(result, pb2.AllgatherReply): + return result.receive_buffer + else: + raise RuntimeError(f"bad result from XGB server: expect 
AllgatherReply but got {type(result)}") + + def all_gather_v(self, rank: int, seq: int, send_buf: bytes, fl_ctx: FLContext) -> bytes: + result = self.internal_xgb_client.send_allgatherv(seq_num=seq, rank=rank, data=send_buf) + if isinstance(result, pb2.AllgatherVReply): + return result.receive_buffer + else: + raise RuntimeError(f"bad result from XGB server: expect AllgatherVReply but got {type(result)}") + + def all_reduce( + self, + rank: int, + seq: int, + data_type: int, + reduce_op: int, + send_buf: bytes, + fl_ctx: FLContext, + ) -> bytes: + result = self.internal_xgb_client.send_allreduce( + seq_num=seq, + rank=rank, + data=send_buf, + data_type=data_type, + reduce_op=reduce_op, + ) + if isinstance(result, pb2.AllreduceReply): + return result.receive_buffer + else: + raise RuntimeError(f"bad result from XGB server: expect AllreduceReply but got {type(result)}") + + def broadcast(self, rank: int, seq: int, root: int, send_buf: bytes, fl_ctx: FLContext) -> bytes: + result = self.internal_xgb_client.send_broadcast(seq_num=seq, rank=rank, data=send_buf, root=root) + if isinstance(result, pb2.BroadcastReply): + return result.receive_buffer + else: + raise RuntimeError(f"bad result from XGB server: expect BroadcastReply but got {type(result)}") diff --git a/nvflare/app_common/xgb/controller.py b/nvflare/app_common/xgb/controller.py new file mode 100644 index 0000000000..83bd6f90a1 --- /dev/null +++ b/nvflare/app_common/xgb/controller.py @@ -0,0 +1,609 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import time + +from nvflare.apis.client import Client +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import ClientTask, Controller, Task +from nvflare.apis.shareable import ReturnCode, Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.xgb.adaptors.adaptor import XGBServerAdaptor +from nvflare.fuel.utils.validation_utils import check_number_range, check_object_type, check_positive_number, check_str +from nvflare.security.logging import secure_format_exception + +from .defs import Constant + + +class ClientStatus: + """ + Objects of this class keep processing status of each FL client during job execution. + """ + + def __init__(self): + # Set when the client's config reply is received and the reply return code is OK. + # If the client failed to reply or the return code is not OK, this value is not set. + self.configured_time = None + + # Set when the client's start reply is received and the reply return code is OK. + # If the client failed to reply or the return code is not OK, this value is not set. 
+        self.started_time = None
+
+        # operation of the last XGB request from this client
+        self.last_op = None
+
+        # time of the last XGB op request from this client
+        self.last_op_time = time.time()
+
+        # whether the XGB process is done on this client
+        self.xgb_done = False
+
+
+class XGBController(Controller):
+    def __init__(
+        self,
+        adaptor_component_id: str,
+        num_rounds: int,
+        configure_task_name=Constant.CONFIG_TASK_NAME,
+        configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT,
+        start_task_name=Constant.START_TASK_NAME,
+        start_task_timeout=Constant.START_TASK_TIMEOUT,
+        job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL,
+        max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL,
+        progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT,
+        client_ranks=None,
+    ):
+        """
+        Constructor
+
+        Args:
+            adaptor_component_id - the component ID of server target adaptor
+            num_rounds - number of rounds
+            configure_task_name - name of the config task
+            configure_task_timeout - time to wait for clients’ responses to the config task before timeout.
+            start_task_name - name of the start task
+            start_task_timeout - time to wait for clients’ responses to the start task before timeout.
+            job_status_check_interval - how often to check client statuses of the job
+            max_client_op_interval - max amount of time allowed between XGB ops from a client
+            progress_timeout - the maximum amount of time allowed for the workflow to not make any progress.
+                In other words, at least one participating client must have made progress during this time.
+                Otherwise, the workflow will be considered to be in trouble and the job will be aborted.
+            client_ranks: client rank assignments.
+                If specified, must be a dict of client_name => rank.
+                If not specified, client ranks will be randomly assigned.
+                No matter how assigned, ranks must be consecutive integers, starting from 0.
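+                For example (hypothetical client names): {"site-1": 0, "site-2": 1}.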
+ """ + Controller.__init__(self) + self.adaptor_component_id = adaptor_component_id + self.num_rounds = num_rounds + self.configure_task_name = configure_task_name + self.start_task_name = start_task_name + self.start_task_timeout = start_task_timeout + self.configure_task_timeout = configure_task_timeout + self.max_client_op_interval = max_client_op_interval + self.progress_timeout = progress_timeout + self.job_status_check_interval = job_status_check_interval + self.client_ranks = client_ranks # client rank assignments + + self.adaptor = None + self.participating_clients = None + self.status_lock = threading.Lock() + self.client_statuses = {} # client name => ClientStatus + self.abort_signal = None + + check_str("adaptor_component_id", adaptor_component_id) + check_number_range("configure_task_timeout", configure_task_timeout, min_value=1) + check_number_range("start_task_timeout", start_task_timeout, min_value=1) + check_positive_number("job_status_check_interval", job_status_check_interval) + check_positive_number("num_rounds", num_rounds) + check_number_range("max_client_op_interval", max_client_op_interval, min_value=10.0) + check_number_range("progress_timeout", progress_timeout, min_value=5.0) + if client_ranks: + check_object_type("client_ranks", client_ranks, dict) + + # set up operation handlers + self.op_table = { + Constant.OP_ALL_GATHER: self._process_all_gather, + Constant.OP_ALL_GATHER_V: self._process_all_gather_v, + Constant.OP_ALL_REDUCE: self._process_all_reduce, + Constant.OP_BROADCAST: self._process_broadcast, + } + + def get_adaptor(self, fl_ctx: FLContext): + engine = fl_ctx.get_engine() + return engine.get_component(self.adaptor_component_id) + + def start_controller(self, fl_ctx: FLContext): + all_clients = self._engine.get_clients() + self.participating_clients = [t.name for t in all_clients] + + for c in self.participating_clients: + self.client_statuses[c] = ClientStatus() + + adaptor = self.get_adaptor(fl_ctx) + if not adaptor: + self.system_panic(f"cannot get component for {self.adaptor_component_id}", fl_ctx) + return None + + if not isinstance(adaptor, XGBServerAdaptor): + self.system_panic( + f"invalid component '{self.adaptor_component_id}': expect XGBServerBridge but got {type(adaptor)}", + fl_ctx, + ) + return None + + adaptor.initialize(fl_ctx) + self.adaptor = adaptor + + engine = fl_ctx.get_engine() + engine.register_aux_message_handler( + topic=Constant.TOPIC_XGB_REQUEST, + message_handle_func=self._process_xgb_request, + ) + engine.register_aux_message_handler( + topic=Constant.TOPIC_CLIENT_DONE, + message_handle_func=self._process_client_done, + ) + + def _trigger_stop(self, fl_ctx: FLContext, error=None): + # first trigger the abort_signal to tell all components (mainly the controller's control_flow and adaptor) + # that check this signal to abort. + if self.abort_signal: + self.abort_signal.trigger(value=True) + + # if there is error, call system_panic to terminate the job with proper status. + # if no error, the job will end normally. + if error: + self.system_panic(reason=error, fl_ctx=fl_ctx) + + def _is_stopped(self): + # check whether the abort signal is triggered + return self.abort_signal and self.abort_signal.triggered + + def _update_client_status(self, fl_ctx: FLContext, op=None, client_done=False): + """Update the status of the requesting client. 
+
+        Args:
+            fl_ctx: FL context
+            op: the XGB operation requested
+            client_done: whether the client is done
+
+        Returns: None
+
+        """
+        with self.status_lock:
+            peer_ctx = fl_ctx.get_peer_context()
+            if not peer_ctx:
+                self.log_error(fl_ctx, "missing peer_ctx from fl_ctx")
+                return
+            if not isinstance(peer_ctx, FLContext):
+                self.log_error(fl_ctx, f"expect peer_ctx to be FLContext but got {type(peer_ctx)}")
+                return
+            client_name = peer_ctx.get_identity_name()
+            if not client_name:
+                self.log_error(fl_ctx, "missing identity from peer_ctx")
+                return
+            status = self.client_statuses.get(client_name)
+            if not status:
+                self.log_error(fl_ctx, f"no status record for client {client_name}")
+                return
+            assert isinstance(status, ClientStatus)
+            if op:
+                status.last_op = op
+            if client_done:
+                status.xgb_done = client_done
+            status.last_op_time = time.time()
+
+    def _process_client_done(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable:
+        """Process the ClientDone report for a client
+
+        Args:
+            topic: topic of the message
+            request: request to be processed
+            fl_ctx: the FL context
+
+        Returns: reply to the client
+
+        """
+        exit_code = request.get(Constant.MSG_KEY_EXIT_CODE)
+
+        # TBD: should we check the exit_code and determine job status?
+        # Problem is that even if the exit_code is not 0, we can't say the job failed.
+        if exit_code == 0:
+            self.log_info(fl_ctx, f"XGB client is done with exit code {exit_code}")
+        else:
+            self.log_warning(fl_ctx, f"XGB client is done with exit code {exit_code}")
+
+        self._update_client_status(fl_ctx, client_done=True)
+        return make_reply(ReturnCode.OK)
+
+    def _process_all_gather(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
+        """This is the op handler for Allgather.
+
+        Args:
+            request: the request containing op params
+            fl_ctx: FL context
+
+        Returns: a Shareable containing operation result
+
+        """
+        rank = request.get(Constant.PARAM_KEY_RANK)
+        seq = request.get(Constant.PARAM_KEY_SEQ)
+        send_buf = request.get(Constant.PARAM_KEY_SEND_BUF)
+        rcv_buf = self.adaptor.all_gather(rank, seq, send_buf, fl_ctx)
+        reply = Shareable()
+        reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf
+        return reply
+
+    def _process_all_gather_v(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
+        """This is the op handler for AllgatherV.
+
+        Args:
+            request: the request containing op params
+            fl_ctx: FL context
+
+        Returns: a Shareable containing operation result
+
+        """
+        rank = request.get(Constant.PARAM_KEY_RANK)
+        seq = request.get(Constant.PARAM_KEY_SEQ)
+        send_buf = request.get(Constant.PARAM_KEY_SEND_BUF)
+        rcv_buf = self.adaptor.all_gather_v(rank, seq, send_buf, fl_ctx)
+        reply = Shareable()
+        reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf
+        return reply
+
+    def _process_all_reduce(self, request: Shareable, fl_ctx: FLContext) -> Shareable:
+        """This is the op handler for Allreduce.
+ + Args: + request: the request containing op params + fl_ctx: FL context + + Returns: a Shareable containing operation result + + """ + rank = request.get(Constant.PARAM_KEY_RANK) + seq = request.get(Constant.PARAM_KEY_SEQ) + send_buf = request.get(Constant.PARAM_KEY_SEND_BUF) + data_type = request.get(Constant.PARAM_KEY_DATA_TYPE) + reduce_op = request.get(Constant.PARAM_KEY_REDUCE_OP) + assert isinstance(self.adaptor, XGBServerAdaptor) + rcv_buf = self.adaptor.all_reduce(rank, seq, data_type, reduce_op, send_buf, fl_ctx) + reply = Shareable() + reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf + return reply + + def _process_broadcast(self, request: Shareable, fl_ctx: FLContext) -> Shareable: + """This is the op handler for Broadcast. + + Args: + request: the request containing op params + fl_ctx: FL context + + Returns: a Shareable containing operation result + + """ + rank = request.get(Constant.PARAM_KEY_RANK) + seq = request.get(Constant.PARAM_KEY_SEQ) + send_buf = request.get(Constant.PARAM_KEY_SEND_BUF) + root = request.get(Constant.PARAM_KEY_ROOT) + assert isinstance(self.adaptor, XGBServerAdaptor) + rcv_buf = self.adaptor.broadcast(rank, seq, root, send_buf, fl_ctx) + reply = Shareable() + reply[Constant.PARAM_KEY_RCV_BUF] = rcv_buf + return reply + + def _process_xgb_request(self, topic: str, request: Shareable, fl_ctx: FLContext) -> Shareable: + op = request.get_header(Constant.MSG_KEY_XGB_OP) + if self._is_stopped(): + self.log_error(fl_ctx, f"dropped XGB request '{op}' since server is already stopped") + return make_reply(ReturnCode.SERVICE_UNAVAILABLE) + + # since XGB protocol is very strict, we'll stop the control flow when any error occurs + bad_req_error = "bad XGB request" + process_error = "XGB request process error" + if not op: + self.log_error(fl_ctx, "missing op from XGB request") + self._trigger_stop(fl_ctx, bad_req_error) + return make_reply(ReturnCode.BAD_REQUEST_DATA) + + # find and call the op handlers + process_f = self.op_table.get(op) + if process_f is None: + self.log_error(fl_ctx, f"invalid op '{op}' from XGB request") + self._trigger_stop(fl_ctx, bad_req_error) + return make_reply(ReturnCode.BAD_REQUEST_DATA) + + self._update_client_status(fl_ctx, op=op) + + if not callable(process_f): + # impossible but we must declare process_f to be callable; otherwise PyCharm will complain about + # process_f(request, fl_ctx). + raise RuntimeError(f"op handler for {op} is not callable") + try: + reply = process_f(request, fl_ctx) + except Exception as ex: + self.log_exception(fl_ctx, f"exception processing {op}: {secure_format_exception(ex)}") + self._trigger_stop(fl_ctx, process_error) + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + self.log_info(fl_ctx, f"received reply for '{op}'") + reply.set_header(Constant.MSG_KEY_XGB_OP, op) + return reply + + def _configure_clients(self, abort_signal: Signal, fl_ctx: FLContext): + self.log_info(fl_ctx, f"Configuring clients {self.participating_clients}") + + shareable = Shareable() + + # compute client ranks + if not self.client_ranks: + # dynamically assign ranks, starting from 0 + # Assumption: all clients are used + clients = self.participating_clients + + # Sort by client name so rank is consistent + clients.sort() + self.client_ranks = {clients[i]: i for i in range(0, len(clients))} + else: + # validate ranks - ranks must be unique consecutive integers, starting from 0. 
+ num_clients = len(self.participating_clients) + assigned_ranks = {} # rank => client + if len(self.client_ranks) != num_clients: + # either missing client or duplicate client + self.system_panic( + f"expecting rank assignments for {self.participating_clients} but got {self.client_ranks}", fl_ctx + ) + return False + + # all clients must have ranks + for c in self.participating_clients: + if c not in self.client_ranks: + self.system_panic(f"missing rank assignment for client '{c}'", fl_ctx) + return False + + # check each client's rank + for c, r in self.client_ranks.items(): + if not isinstance(r, int): + self.system_panic(f"bad rank assignment {r} for client '{c}': expect int but got {type(r)}", fl_ctx) + return False + + if r < 0 or r >= num_clients: + self.system_panic(f"bad rank assignment {r} for client '{c}': must be 0 to {num_clients-1}", fl_ctx) + return False + + assigned_client = assigned_ranks.get(r) + if assigned_client: + self.system_panic(f"rank {r} is assigned to both client '{c}' and '{assigned_client}'", fl_ctx) + return False + + assigned_ranks[r] = c + + shareable[Constant.CONF_KEY_CLIENT_RANKS] = self.client_ranks + shareable[Constant.CONF_KEY_NUM_ROUNDS] = self.num_rounds + + task = Task( + name=self.configure_task_name, + data=shareable, + timeout=self.configure_task_timeout, + result_received_cb=self._process_configure_reply, + ) + + self.log_info(fl_ctx, f"sending task {self.configure_task_name} to clients {self.participating_clients}") + start_time = time.time() + self.broadcast_and_wait( + task=task, + targets=self.participating_clients, + min_responses=len(self.participating_clients), + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + time_taken = time.time() - start_time + self.log_info(fl_ctx, f"client configuration took {time_taken} seconds") + + failed_clients = [] + for c, cs in self.client_statuses.items(): + assert isinstance(cs, ClientStatus) + if not cs.configured_time: + failed_clients.append(c) + + # if any client failed to configure, terminate the job + if failed_clients: + self.system_panic(f"failed to configure clients {failed_clients}", fl_ctx) + return False + + self.log_info(fl_ctx, f"successfully configured clients {self.participating_clients}") + return True + + def _start_clients(self, abort_signal: Signal, fl_ctx: FLContext): + self.log_info(fl_ctx, f"Starting clients {self.participating_clients}") + + task = Task( + name=self.start_task_name, + data=Shareable(), + timeout=self.start_task_timeout, + result_received_cb=self._process_start_reply, + ) + + self.log_info(fl_ctx, f"sending task {self.start_task_name} to clients {self.participating_clients}") + start_time = time.time() + self.broadcast_and_wait( + task=task, + targets=self.participating_clients, + min_responses=len(self.participating_clients), + fl_ctx=fl_ctx, + abort_signal=abort_signal, + ) + + time_taken = time.time() - start_time + self.log_info(fl_ctx, f"client starting took {time_taken} seconds") + + failed_clients = [] + for c, cs in self.client_statuses.items(): + assert isinstance(cs, ClientStatus) + if not cs.started_time: + failed_clients.append(c) + + # if any client failed to start, terminate the job + if failed_clients: + self.system_panic(f"failed to start clients {failed_clients}", fl_ctx) + return False + + self.log_info(fl_ctx, f"successfully started clients {self.participating_clients}") + return True + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + """ + This is the control flow of the XGB Controller. 
To ensure smooth XGB execution:
+        - ensure that all clients are online and ready to go before starting the server
+        - ensure that the server is started and ready to take requests before asking clients to start operation
+        - monitor the health of the clients
+        - if anything goes wrong, terminate the job
+
+        Args:
+            abort_signal: abort signal that is used to notify components to abort
+            fl_ctx: FL context
+
+        Returns: None
+
+        """
+        self.abort_signal = abort_signal
+
+        # the adaptor uses the same abort signal!
+        self.adaptor.set_abort_signal(abort_signal)
+
+        # wait for every client to become online and properly configured
+        self.log_info(fl_ctx, f"Waiting for clients to be ready: {self.participating_clients}")
+
+        # configure all clients
+        if not self._configure_clients(abort_signal, fl_ctx):
+            self.system_panic("failed to configure all clients", fl_ctx)
+            return
+
+        # start the server adaptor
+        try:
+            self.adaptor.configure({Constant.CONF_KEY_WORLD_SIZE: len(self.participating_clients)}, fl_ctx)
+            self.adaptor.start(fl_ctx)
+        except Exception as ex:
+            error = f"failed to start server adaptor: {secure_format_exception(ex)}"
+            self.log_error(fl_ctx, error)
+            self.system_panic(error, fl_ctx)
+            return
+
+        self.adaptor.monitor_target(fl_ctx, self._xgb_server_stopped)
+
+        # start all clients
+        if not self._start_clients(abort_signal, fl_ctx):
+            self.system_panic("failed to start all clients", fl_ctx)
+            return
+
+        # monitor client health:
+        # we periodically check job status until all clients are done or the system is stopped
+        self.log_info(fl_ctx, "Waiting for clients to finish ...")
+        while not self._is_stopped():
+            done = self._check_job_status(fl_ctx)
+            if done:
+                break
+            time.sleep(self.job_status_check_interval)
+
+    def _xgb_server_stopped(self, rc, fl_ctx: FLContext):
+        # this CB is called when the XGB server target is stopped
+        error = None
+        if rc != 0:
+            self.log_error(fl_ctx, f"XGB Server stopped abnormally with code {rc}")
+            error = "XGB server abnormal stop"
+
+        # the XGB server could stop at any moment, so we trigger the abort_signal in case it is checked by any
+        # other components
+        self._trigger_stop(fl_ctx, error)
+
+    def _process_configure_reply(self, client_task: ClientTask, fl_ctx: FLContext):
+        result = client_task.result
+        client_name = client_task.client.name
+
+        rc = result.get_return_code()
+        if rc == ReturnCode.OK:
+            self.log_info(fl_ctx, f"successfully configured client {client_name}")
+            cs = self.client_statuses.get(client_name)
+            if cs:
+                assert isinstance(cs, ClientStatus)
+                cs.configured_time = time.time()
+        else:
+            self.log_error(fl_ctx, f"client {client_name} failed to configure: {rc}")
+
+    def _process_start_reply(self, client_task: ClientTask, fl_ctx: FLContext):
+        result = client_task.result
+        client_name = client_task.client.name
+
+        rc = result.get_return_code()
+        if rc == ReturnCode.OK:
+            self.log_info(fl_ctx, f"successfully started client {client_name}")
+            cs = self.client_statuses.get(client_name)
+            if cs:
+                assert isinstance(cs, ClientStatus)
+                cs.started_time = time.time()
+        else:
+            self.log_error(fl_ctx, f"client {client_name} failed to start")
+
+    def _check_job_status(self, fl_ctx: FLContext) -> bool:
+        """Check job status and determine whether the job is done.
+
+        Args:
+            fl_ctx: FL context
+
+        Returns: whether the job is considered done.
+
+        """
+        now = time.time()
+
+        # overall_last_progress_time is the latest time that any client made progress.
+        overall_last_progress_time = 0.0
+        clients_done = 0
+        for client_name, cs in self.client_statuses.items():
+            assert isinstance(cs, ClientStatus)
+
+            if cs.xgb_done:
+                self.log_info(fl_ctx, f"client {client_name} is Done")
+                clients_done += 1
+            elif now - cs.last_op_time > self.max_client_op_interval:
+                self.system_panic(
+                    f"client {client_name} didn't have any activity for {self.max_client_op_interval} seconds",
+                    fl_ctx,
+                )
+                return True
+
+            if overall_last_progress_time < cs.last_op_time:
+                overall_last_progress_time = cs.last_op_time
+
+        if clients_done == len(self.client_statuses):
+            # all clients are done - the job is considered done
+            return True
+        elif time.time() - overall_last_progress_time > self.progress_timeout:
+            # there has been no progress from any client for too long.
+            # this could be because the clients got stuck.
+            # consider the job done and abort the job.
+            self.system_panic(f"the job has made no progress for {self.progress_timeout} seconds", fl_ctx)
+            return True
+        return False
+
+    def process_result_of_unknown_task(
+        self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext
+    ):
+        self.log_warning(fl_ctx, f"ignored unknown task {task_name} from client {client.name}")
+
+    def stop_controller(self, fl_ctx: FLContext):
+        if self.adaptor:
+            self.log_info(fl_ctx, "Stopping server adaptor")
+            self.adaptor.stop(fl_ctx)
diff --git a/nvflare/app_common/xgb/data_loader.py b/nvflare/app_common/xgb/data_loader.py
new file mode 100644
index 0000000000..2fa8855c99
--- /dev/null
+++ b/nvflare/app_common/xgb/data_loader.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import Tuple
+
+import xgboost as xgb
+
+
+class XGBDataLoader(ABC):
+    @abstractmethod
+    def load_data(self, client_id: str) -> Tuple[xgb.core.DMatrix, xgb.core.DMatrix]:
+        """Loads data for xgboost.
+
+        Args:
+            client_id: the name of the FL client that is loading the data
+
+        Returns:
+            A tuple of (train_data, validation_data)
+        """
+        pass
diff --git a/nvflare/app_common/xgb/defs.py b/nvflare/app_common/xgb/defs.py
new file mode 100644
index 0000000000..c7997951f6
--- /dev/null
+++ b/nvflare/app_common/xgb/defs.py
@@ -0,0 +1,95 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
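+
+# For illustration, with hypothetical client names: the config task carries
+# data keyed by the constants defined below, e.g.
+#
+#   shareable[Constant.CONF_KEY_CLIENT_RANKS] = {"site-1": 0, "site-2": 1}
+#   shareable[Constant.CONF_KEY_NUM_ROUNDS] = 100
+#
+# and each client reads its own rank from the map and derives
+# world_size = len(client_ranks).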
+ +from nvflare.fuel.f3.drivers.net_utils import MAX_FRAME_SIZE + + +class Constant: + + # task name defaults + CONFIG_TASK_NAME = "config" + START_TASK_NAME = "start" + + # keys of adaptor config parameters + CONF_KEY_CLIENT_RANKS = "client_ranks" + CONF_KEY_RANK = "rank" + CONF_KEY_WORLD_SIZE = "world_size" + CONF_KEY_NUM_ROUNDS = "num_rounds" + + # default component config values + CONFIG_TASK_TIMEOUT = 10 + START_TASK_TIMEOUT = 10 + XGB_SERVER_READY_TIMEOUT = 10.0 + + TASK_CHECK_INTERVAL = 0.5 + JOB_STATUS_CHECK_INTERVAL = 2.0 + MAX_CLIENT_OP_INTERVAL = 90.0 + WORKFLOW_PROGRESS_TIMEOUT = 3600.0 + + # message topics + TOPIC_XGB_REQUEST = "xgb.request" + TOPIC_XGB_REQUEST_CHECK = "xgb.req_check" + TOPIC_CLIENT_DONE = "xgb.client_done" + + # keys for Shareable between client and server + MSG_KEY_EXIT_CODE = "exit_code" + MSG_KEY_XGB_OP = "xgb.op" + MSG_KEY_XGB_REQ_ID = "xgb.req_id" + MSG_KEY_XGB_REQ_TRY_NUM = "xgb.req_try_num" + MSG_KEY_XGB_REQ_RECEIVED = "xgb.req_received" + + # XGB operation names + OP_ALL_GATHER = "all_gather" + OP_ALL_GATHER_V = "all_gather_v" + OP_ALL_REDUCE = "all_reduce" + OP_BROADCAST = "broadcast" + + # XGB operation codes + OPCODE_NONE = 0 + OPCODE_ALL_GATHER = 1 + OPCODE_ALL_GATHER_V = 2 + OPCODE_ALL_REDUCE = 3 + OPCODE_BROADCAST = 4 + OPCODE_DONE = 99 + + # XGB operation error codes + ERR_OP_MISMATCH = -1 + ERR_INVALID_RANK = -2 + ERR_NO_CLIENT_FOR_RANK = -3 + ERR_TARGET_ERROR = -4 + + # XGB operation parameter keys + PARAM_KEY_RANK = "xgb.rank" + PARAM_KEY_SEQ = "xgb.seq" + PARAM_KEY_SEND_BUF = "xgb.send_buf" + PARAM_KEY_DATA_TYPE = "xgb.data_type" + PARAM_KEY_REDUCE_OP = "xgb.reduce_op" + PARAM_KEY_ROOT = "xgb.root" + PARAM_KEY_RCV_BUF = "xgb.rcv_buf" + + RUNNER_CTX_SERVER_ADDR = "server_addr" + RUNNER_CTX_PORT = "port" + RUNNER_CTX_CLIENT_NAME = "client_name" + RUNNER_CTX_NUM_ROUNDS = "num_rounds" + RUNNER_CTX_WORLD_SIZE = "world_size" + RUNNER_CTX_RANK = "rank" + RUNNER_CTX_DATA_LOADER = "data_loader" + RUNNER_CTX_TB_DIR = "tb_dir" + RUNNER_CTX_MODEL_DIR = "model_dir" + + +GRPC_DEFAULT_OPTIONS = [ + ("grpc.max_send_message_length", MAX_FRAME_SIZE), + ("grpc.max_receive_message_length", MAX_FRAME_SIZE), +] diff --git a/nvflare/app_common/xgb/executor.py b/nvflare/app_common/xgb/executor.py new file mode 100644 index 0000000000..b647b3b308 --- /dev/null +++ b/nvflare/app_common/xgb/executor.py @@ -0,0 +1,171 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
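+
+# Task sequence handled by the executor below, sketched with hypothetical
+# values for illustration:
+#   1. "config" task -> my_rank = ranks[me]; world_size = len(ranks)
+#   2. "start" task  -> adaptor.start(fl_ctx); adaptor.monitor_target(...)
+#   3. XGB target exits -> aux message on Constant.TOPIC_CLIENT_DONE carrying
+#      Constant.MSG_KEY_EXIT_CODE back to the server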
+
+from nvflare.apis.event_type import EventType
+from nvflare.apis.executor import Executor
+from nvflare.apis.fl_constant import ReturnCode
+from nvflare.apis.fl_context import FLContext
+from nvflare.apis.shareable import Shareable, make_reply
+from nvflare.apis.signal import Signal
+from nvflare.app_common.xgb.adaptors.adaptor import XGBClientAdaptor
+from nvflare.fuel.f3.cellnet.fqcn import FQCN
+from nvflare.security.logging import secure_format_exception
+
+from .defs import Constant
+from .sender import Sender
+
+
+class XGBExecutor(Executor):
+    def __init__(
+        self,
+        adaptor_component_id: str,
+        configure_task_name=Constant.CONFIG_TASK_NAME,
+        start_task_name=Constant.START_TASK_NAME,
+        req_timeout=10.0,
+    ):
+        """Constructor
+
+        Args:
+            adaptor_component_id: the component ID of the client target adaptor
+            configure_task_name: name of the config task
+            start_task_name: name of the start task
+            req_timeout: how long (in seconds) to wait for a reply to an XGB request sent to the server
+        """
+        Executor.__init__(self)
+        self.adaptor_component_id = adaptor_component_id
+        self.req_timeout = req_timeout
+        self.configure_task_name = configure_task_name
+        self.start_task_name = start_task_name
+        self.adaptor = None
+
+        # create the abort signal to be used for signaling the adaptor
+        self.abort_signal = Signal()
+
+    def get_adaptor(self, fl_ctx: FLContext):
+        """Get the adaptor to be used by this executor.
+        This is the default implementation that gets the adaptor based on the configured adaptor_component_id.
+        A subclass of XGBExecutor may get the adaptor in a different way.
+
+        Args:
+            fl_ctx: the FL context
+
+        Returns: an XGBClientAdaptor object
+
+        """
+        engine = fl_ctx.get_engine()
+        return engine.get_component(self.adaptor_component_id)
+
+    def handle_event(self, event_type: str, fl_ctx: FLContext):
+        if event_type == EventType.START_RUN:
+            adaptor = self.get_adaptor(fl_ctx)
+            if not adaptor:
+                self.system_panic(f"cannot get component for {self.adaptor_component_id}", fl_ctx)
+                return
+
+            if not isinstance(adaptor, XGBClientAdaptor):
+                self.system_panic(
+                    f"invalid component '{self.adaptor_component_id}': expect XGBClientAdaptor but got {type(adaptor)}",
+                    fl_ctx,
+                )
+                return
+
+            adaptor.set_abort_signal(self.abort_signal)
+            engine = fl_ctx.get_engine()
+            adaptor.set_sender(Sender(engine, self.req_timeout))
+            adaptor.initialize(fl_ctx)
+            self.adaptor = adaptor
+        elif event_type == EventType.END_RUN:
+            self.abort_signal.trigger(True)
+
+    def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable:
+        if task_name == self.configure_task_name:
+            # there are two important config params for the client:
+            # the rank assigned to the client, and
+            # the number of rounds for training.
+ ranks = shareable.get(Constant.CONF_KEY_CLIENT_RANKS) + if not ranks: + self.log_error(fl_ctx, f"missing {Constant.CONF_KEY_CLIENT_RANKS} from config") + return make_reply(ReturnCode.BAD_TASK_DATA) + + if not isinstance(ranks, dict): + self.log_error(fl_ctx, f"expect config data to be dict but got {ranks}") + return make_reply(ReturnCode.BAD_TASK_DATA) + + me = fl_ctx.get_identity_name() + my_rank = ranks.get(me) + if my_rank is None: + self.log_error(fl_ctx, f"missing rank for me ({me}) in config data") + return make_reply(ReturnCode.BAD_TASK_DATA) + + self.log_info(fl_ctx, f"got my rank: {my_rank}") + + num_rounds = shareable.get(Constant.CONF_KEY_NUM_ROUNDS) + if not num_rounds: + self.log_error(fl_ctx, f"missing {Constant.CONF_KEY_NUM_ROUNDS} from config") + return make_reply(ReturnCode.BAD_TASK_DATA) + + world_size = len(ranks) + + # configure the XGB client target via the adaptor + self.adaptor.configure( + { + Constant.CONF_KEY_RANK: my_rank, + Constant.CONF_KEY_NUM_ROUNDS: num_rounds, + Constant.CONF_KEY_WORLD_SIZE: world_size, + }, + fl_ctx, + ) + return make_reply(ReturnCode.OK) + elif task_name == self.start_task_name: + # start adaptor + try: + self.adaptor.start(fl_ctx) + except Exception as ex: + self.log_exception(fl_ctx, f"failed to start adaptor: {secure_format_exception(ex)}") + return make_reply(ReturnCode.EXECUTION_EXCEPTION) + + # start to monitor the XGB target via the adaptor + self.adaptor.monitor_target(fl_ctx, self._notify_client_done) + return make_reply(ReturnCode.OK) + else: + self.log_error(fl_ctx, f"ignored unsupported {task_name}") + return make_reply(ReturnCode.TASK_UNSUPPORTED) + + def _notify_client_done(self, rc, fl_ctx: FLContext): + """This is called when the XGB client target is done. + We send a message to the FL server telling it that this client is done. + + Args: + rc: the return code from the XGB client target + fl_ctx: FL context + + Returns: None + + """ + if rc != 0: + self.log_error(fl_ctx, f"XGB Client stopped with RC {rc}") + else: + self.log_info(fl_ctx, "XGB Client Stopped") + + # tell server that this client is done + engine = fl_ctx.get_engine() + req = Shareable() + req[Constant.MSG_KEY_EXIT_CODE] = rc + engine.send_aux_request( + targets=[FQCN.ROOT_SERVER], + topic=Constant.TOPIC_CLIENT_DONE, + request=req, + timeout=0, # fire and forget + fl_ctx=fl_ctx, + optional=True, + ) diff --git a/nvflare/app_common/xgb/fed_controller.py b/nvflare/app_common/xgb/fed_controller.py new file mode 100644 index 0000000000..959fb6dcce --- /dev/null +++ b/nvflare/app_common/xgb/fed_controller.py @@ -0,0 +1,62 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
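+
+# A hedged sketch of how the controller below might appear in a server job
+# config; the id and argument values are illustrative, not prescribed by this
+# module:
+#
+#   "workflows": [{
+#       "id": "xgb_controller",
+#       "path": "nvflare.app_common.xgb.fed_controller.XGBFedController",
+#       "args": {"num_rounds": 100, "in_process": true}
+#   }]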
+ +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.xgb.adaptors.grpc_server_adaptor import GrpcServerAdaptor +from nvflare.app_common.xgb.runners.xgb_server_runner import XGBServerRunner + +from .controller import XGBController +from .defs import Constant + + +class XGBFedController(XGBController): + def __init__( + self, + num_rounds: int, + configure_task_name=Constant.CONFIG_TASK_NAME, + configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT, + start_task_name=Constant.START_TASK_NAME, + start_task_timeout=Constant.START_TASK_TIMEOUT, + job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL, + max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL, + progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, + client_ranks=None, + int_client_grpc_options=None, + in_process=True, + ): + XGBController.__init__( + self, + adaptor_component_id="", + num_rounds=num_rounds, + configure_task_name=configure_task_name, + configure_task_timeout=configure_task_timeout, + start_task_name=start_task_name, + start_task_timeout=start_task_timeout, + job_status_check_interval=job_status_check_interval, + max_client_op_interval=max_client_op_interval, + progress_timeout=progress_timeout, + client_ranks=client_ranks, + ) + self.int_client_grpc_options = int_client_grpc_options + self.in_process = in_process + + def get_adaptor(self, fl_ctx: FLContext): + runner = XGBServerRunner() + runner.initialize(fl_ctx) + adaptor = GrpcServerAdaptor( + int_client_grpc_options=self.int_client_grpc_options, + in_process=self.in_process, + ) + adaptor.set_runner(runner) + return adaptor diff --git a/nvflare/app_common/xgb/fed_executor.py b/nvflare/app_common/xgb/fed_executor.py new file mode 100644 index 0000000000..f64ecf5f32 --- /dev/null +++ b/nvflare/app_common/xgb/fed_executor.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
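+
+# A matching client-side sketch (all values illustrative). The executor below
+# is configured with a data_loader_id pointing at an XGBDataLoader component;
+# a hypothetical CSV-based loader (assuming `import xgboost as xgb` and the
+# XGBDataLoader base class from nvflare.app_common.xgb.data_loader):
+#
+#   class CsvDataLoader(XGBDataLoader):
+#       def load_data(self, client_id: str):
+#           # XGBoost URI syntax: CSV files with the label in column 0
+#           train = xgb.DMatrix(f"/data/{client_id}/train.csv?format=csv&label_column=0")
+#           valid = xgb.DMatrix(f"/data/{client_id}/valid.csv?format=csv&label_column=0")
+#           return train, valid
+#
+# with a client job config entry along the lines of:
+#
+#   "executor": {
+#       "path": "nvflare.app_common.xgb.fed_executor.FedXGBHistogramExecutor",
+#       "args": {
+#           "early_stopping_rounds": 2,
+#           "xgb_params": {"max_depth": 8, "objective": "binary:logistic"},
+#           "data_loader_id": "dataloader"
+#       }
+#   }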
+from nvflare.apis.fl_context import FLContext +from nvflare.app_common.xgb.adaptors.grpc_client_adaptor import GrpcClientAdaptor +from nvflare.app_common.xgb.runners.xgb_client_runner import XGBClientRunner + +from .executor import XGBExecutor + + +class FedXGBHistogramExecutor(XGBExecutor): + def __init__( + self, + early_stopping_rounds, + xgb_params: dict, + data_loader_id: str, + verbose_eval=False, + use_gpus=False, + int_server_grpc_options=None, + req_timeout=10.0, + model_file_name="model.json", + in_process=True, + ): + XGBExecutor.__init__( + self, + adaptor_component_id="", + req_timeout=req_timeout, + ) + self.early_stopping_rounds = early_stopping_rounds + self.xgb_params = xgb_params + self.data_loader_id = data_loader_id + self.verbose_eval = verbose_eval + self.use_gpus = use_gpus + self.int_server_grpc_options = int_server_grpc_options + self.model_file_name = model_file_name + self.in_process = in_process + + def get_adaptor(self, fl_ctx: FLContext): + runner = XGBClientRunner( + data_loader_id=self.data_loader_id, + early_stopping_rounds=self.early_stopping_rounds, + xgb_params=self.xgb_params, + verbose_eval=self.verbose_eval, + use_gpus=self.use_gpus, + model_file_name=self.model_file_name, + ) + runner.initialize(fl_ctx) + adaptor = GrpcClientAdaptor( + int_server_grpc_options=self.int_server_grpc_options, + in_process=self.in_process, + ) + adaptor.set_runner(runner) + return adaptor diff --git a/nvflare/app_common/xgb/grpc_client.py b/nvflare/app_common/xgb/grpc_client.py new file mode 100644 index 0000000000..7b0135e9e6 --- /dev/null +++ b/nvflare/app_common/xgb/grpc_client.py @@ -0,0 +1,178 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc + +import nvflare.app_common.xgb.proto.federated_pb2 as pb2 +from nvflare.app_common.xgb.defs import GRPC_DEFAULT_OPTIONS +from nvflare.app_common.xgb.proto.federated_pb2_grpc import FederatedStub +from nvflare.fuel.utils.obj_utils import get_logger + + +class GrpcClient: + """This class implements a gRPC XGB Client that is capable of sending XGB operations to a gRPC XGB Server.""" + + def __init__(self, server_addr, grpc_options=None): + """Constructor + + Args: + server_addr: address of the gRPC server to connect to + grpc_options: gRPC options for the gRPC client + """ + if not grpc_options: + grpc_options = GRPC_DEFAULT_OPTIONS + + self.stub = None + self.channel = None + self.server_addr = server_addr + self.grpc_options = grpc_options + self.started = False + self.logger = get_logger(self) + + def start(self, ready_timeout=10): + """Start the gRPC client and wait for the server to be ready. 
+ + Args: + ready_timeout: how long to wait for the server to be ready + + Returns: None + + """ + if self.started: + return + + self.started = True + + # TBD: need to support secure channel as well + self.channel = grpc.insecure_channel(self.server_addr, options=self.grpc_options) + self.stub = FederatedStub(self.channel) + + # wait for channel ready + try: + grpc.channel_ready_future(self.channel).result(timeout=ready_timeout) + except grpc.FutureTimeoutError: + raise RuntimeError(f"cannot connect to server after {ready_timeout} seconds") + + def send_allgather(self, seq_num, rank, data: bytes): + """Send Allgather request to gRPC server + + Args: + seq_num: sequence number + rank: rank of the client + data: the send_buf data + + Returns: an AllgatherReply object; or None if processing error is encountered + + """ + req = pb2.AllgatherRequest( + sequence_number=seq_num, + rank=rank, + send_buffer=data, + ) + + result = self.stub.Allgather(req) + if not isinstance(result, pb2.AllgatherReply): + self.logger.error(f"expect reply to be pb2.AllgatherReply but got {type(result)}") + return None + return result + + def send_allgatherv(self, seq_num, rank, data: bytes): + """Send AllgatherV request to gRPC server + + Args: + seq_num: sequence number + rank: rank of the client + data: the send_buf data + + Returns: an AllgatherVReply object; or None if processing error is encountered + + """ + req = pb2.AllgatherVRequest( + sequence_number=seq_num, + rank=rank, + send_buffer=data, + ) + + result = self.stub.AllgatherV(req) + if not isinstance(result, pb2.AllgatherVReply): + self.logger.error(f"expect reply to be pb2.AllgatherVReply but got {type(result)}") + return None + return result + + def send_allreduce(self, seq_num, rank, data: bytes, data_type, reduce_op): + """Send Allreduce request to gRPC server + + Args: + seq_num: sequence number + rank: rank of the client + data: the send_buf data + data_type: data type of the input + reduce_op: reduce op to be performed + + Returns: an AllreduceReply object; or None if processing error is encountered + + """ + req = pb2.AllreduceRequest( + sequence_number=seq_num, + rank=rank, + send_buffer=data, + data_type=data_type, + reduce_operation=reduce_op, + ) + + result = self.stub.Allreduce(req) + if not isinstance(result, pb2.AllreduceReply): + self.logger.error(f"expect reply to be pb2.AllreduceReply but got {type(result)}") + return None + return result + + def send_broadcast(self, seq_num, rank, data: bytes, root): + """Send Broadcast request to gRPC server + + Args: + seq_num: sequence number + rank: rank of the client + data: the send_buf data + root: rank of the root + + Returns: a BroadcastReply object; or None if processing error is encountered + + """ + req = pb2.BroadcastRequest( + sequence_number=seq_num, + rank=rank, + send_buffer=data, + root=root, + ) + + result = self.stub.Broadcast(req) + if not isinstance(result, pb2.BroadcastReply): + self.logger.error(f"expect reply to be pb2.BroadcastReply but got {type(result)}") + return None + return result + + def stop(self): + """Stop the gRPC client + + Returns: None + + """ + ch = self.channel + self.channel = None # set to None in case another thread also tries to close. 
+        if ch:
+            try:
+                ch.close()
+            except Exception:
+                # ignore errors when closing the channel
+                pass
diff --git a/nvflare/app_common/xgb/grpc_server.py b/nvflare/app_common/xgb/grpc_server.py
new file mode 100644
index 0000000000..6ace57aecc
--- /dev/null
+++ b/nvflare/app_common/xgb/grpc_server.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import concurrent.futures as futures
+
+import grpc
+
+from nvflare.app_common.xgb.defs import GRPC_DEFAULT_OPTIONS
+from nvflare.app_common.xgb.proto.federated_pb2_grpc import FederatedServicer, add_FederatedServicer_to_server
+from nvflare.fuel.utils.obj_utils import get_logger
+from nvflare.fuel.utils.validation_utils import check_object_type, check_positive_int
+from nvflare.security.logging import secure_format_exception
+
+
+class GrpcServer:
+    """This class implements a gRPC XGB Server that is capable of processing XGB operations."""
+
+    def __init__(self, addr, max_workers: int, grpc_options, servicer):
+        """Constructor
+
+        Args:
+            addr: the listening address of the server
+            max_workers: max number of workers
+            grpc_options: gRPC options
+            servicer: the servicer that is capable of processing XGB requests
+        """
+        if not grpc_options:
+            grpc_options = GRPC_DEFAULT_OPTIONS
+
+        check_object_type("servicer", servicer, FederatedServicer)
+        check_positive_int("max_workers", max_workers)
+        self.grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=max_workers), options=grpc_options)
+        add_FederatedServicer_to_server(servicer, self.grpc_server)
+        self.logger = get_logger(self)
+
+        try:
+            # TBD: will be enhanced to support secure port
+            self.grpc_server.add_insecure_port(addr)
+            self.logger.info(f"XGBServer: added insecure port at {addr}")
+        except Exception as ex:
+            self.logger.error(f"cannot listen on {addr}: {secure_format_exception(ex)}")
+
+    def start(self, no_blocking=False):
+        """Called to start the server
+
+        Args:
+            no_blocking: if True, return immediately after starting; otherwise block the
+                current thread and wait for server termination
+
+        Returns: None
+
+        """
+        self.logger.info("starting gRPC Server")
+        self.grpc_server.start()
+        if no_blocking:
+            # don't wait for server termination
+            return
+        else:
+            self.grpc_server.wait_for_termination()
+            self.logger.info("gRPC XGB server terminated")
+
+    def shutdown(self):
+        """Shut down the gRPC server gracefully.
+
+        Returns: None
+
+        """
+        self.logger.info("shutting down gRPC XGB server")
+        server = self.grpc_server
+        self.grpc_server = None  # in case another thread calls shutdown at the same time
+        if server:
+            server.stop(grace=0.5)
diff --git a/nvflare/app_common/xgb/mock/__init__.py b/nvflare/app_common/xgb/mock/__init__.py
new file mode 100644
index 0000000000..4fc50543f1
--- /dev/null
+++ b/nvflare/app_common/xgb/mock/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/xgb/mock/aggr_servicer.py b/nvflare/app_common/xgb/mock/aggr_servicer.py new file mode 100644 index 0000000000..142fa36a10 --- /dev/null +++ b/nvflare/app_common/xgb/mock/aggr_servicer.py @@ -0,0 +1,122 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading + +import nvflare.app_common.xgb.proto.federated_pb2 as pb2 +from nvflare.app_common.xgb.proto.federated_pb2_grpc import FederatedServicer +from nvflare.fuel.utils.obj_utils import get_logger + + +class ReqWaiter: + def __init__(self, exp_num_clients: int, exp_seq: int, exp_op): + self.exp_num_clients = exp_num_clients + self.exp_seq = exp_seq + self.exp_op = exp_op + self.reqs = {} + self.result = {} + self.waiter = threading.Event() + + def add_request(self, op: str, rank, seq, req): + if seq != self.exp_seq: + raise RuntimeError(f"expecting seq {self.exp_seq} from {rank=} but got {seq}") + + if op != self.exp_op: + raise RuntimeError(f"expecting op {self.exp_op} from {rank=} but got {op}") + + if rank in self.reqs: + raise RuntimeError(f"duplicate request from {op=} {rank=} {seq=}") + + self.reqs[rank] = req + + if isinstance(req, pb2.AllgatherRequest): + reply = pb2.AllgatherReply(receive_buffer=req.send_buffer) + elif isinstance(req, pb2.AllgatherVRequest): + reply = pb2.AllgatherVReply(receive_buffer=req.send_buffer) + elif isinstance(req, pb2.AllreduceRequest): + reply = pb2.AllreduceReply(receive_buffer=req.send_buffer) + elif isinstance(req, pb2.BroadcastRequest): + reply = pb2.BroadcastReply(receive_buffer=req.send_buffer) + else: + raise RuntimeError(f"unknown request type {type(req)}") + self.result[rank] = reply + if len(self.reqs) == self.exp_num_clients: + self.waiter.set() + + def wait(self, timeout): + return self.waiter.wait(timeout) + + +class AggrServicer(FederatedServicer): + def __init__(self, num_clients, aggr_timeout=10.0): + self.logger = get_logger(self) + self.num_clients = num_clients + self.aggr_timeout = aggr_timeout + self.req_lock = threading.Lock() + self.req_waiter = None + + def _wait_for_result(self, op, rank, seq, request): + with self.req_lock: + if not self.req_waiter: + self.logger.info(f"setting new waiter: {seq=} {op=}") + self.req_waiter = ReqWaiter( + exp_num_clients=self.num_clients, + exp_seq=seq, + exp_op=op, + ) + self.req_waiter.add_request(op, rank, seq, request) + if not self.req_waiter.wait(self.aggr_timeout): + self.logger.error(f"results not received from all ranks after {self.aggr_timeout} seconds") + self.logger.info(f"for {rank=}: results remaining: 
{self.req_waiter.result.keys()}") + with self.req_lock: + result = self.req_waiter.result.pop(rank, None) + if len(self.req_waiter.result) == 0: + self.logger.info("all results are retrieved - reset req_waiter to None") + self.req_waiter = None + return result + + def Allgather(self, request: pb2.AllgatherRequest, context): + seq = request.sequence_number + rank = request.rank + data = request.send_buffer + op = "Allgather" + self.logger.info(f"got {op}: {seq=} {rank=} data_size={len(data)}") + return self._wait_for_result(op, rank, seq, request) + + def AllgatherV(self, request: pb2.AllgatherVRequest, context): + seq = request.sequence_number + rank = request.rank + data = request.send_buffer + op = "AllgatherV" + self.logger.info(f"got {op}: {seq=} {rank=} data_size={len(data)}") + return self._wait_for_result(op, rank, seq, request) + + def Allreduce(self, request: pb2.AllreduceRequest, context): + seq = request.sequence_number + rank = request.rank + data = request.send_buffer + reduce_op = request.reduce_operation + data_type = request.data_type + op = "Allreduce" + self.logger.info(f"got {op}: {seq=} {rank=} {reduce_op=} {data_type=} data_size={len(data)}") + return self._wait_for_result(op, rank, seq, request) + + def Broadcast(self, request: pb2.BroadcastRequest, context): + seq = request.sequence_number + rank = request.rank + data = request.send_buffer + root = request.root + op = "Broadcast" + self.logger.info(f"got {op}: {seq=} {rank=} {root=} data_size={len(data)}") + return self._wait_for_result(op, rank, seq, request) diff --git a/nvflare/app_common/xgb/mock/mock_client_runner.py b/nvflare/app_common/xgb/mock/mock_client_runner.py new file mode 100644 index 0000000000..7e16dd6f17 --- /dev/null +++ b/nvflare/app_common/xgb/mock/mock_client_runner.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
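+
+# Note: the AggrServicer above echoes each request's send_buffer back as the
+# receive_buffer once all ranks have checked in, so the runner below validates
+# the full round trip simply by checking result.receive_buffer == data.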
+import os
+import time
+
+import nvflare.app_common.xgb.proto.federated_pb2 as pb2
+from nvflare.apis.fl_component import FLComponent
+from nvflare.app_common.xgb.defs import Constant
+from nvflare.app_common.xgb.grpc_client import GrpcClient
+from nvflare.app_common.xgb.runners.xgb_runner import XGBRunner
+
+
+class MockClientRunner(XGBRunner, FLComponent):
+    def __init__(self):
+        FLComponent.__init__(self)
+        self.training_stopped = False
+        self.asked_to_stop = False
+
+    def run(self, ctx: dict):
+        server_addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR)
+        rank = ctx.get(Constant.RUNNER_CTX_RANK)
+        num_rounds = ctx.get(Constant.RUNNER_CTX_NUM_ROUNDS)
+
+        client = GrpcClient(server_addr=server_addr)
+        client.start()
+
+        seq = 0
+        total_time = 0
+        total_reqs = 0
+        for i in range(num_rounds):
+            if self.asked_to_stop:
+                self.logger.info("training aborted")
+                self.training_stopped = True
+                return
+
+            self.logger.info(f"Test round {i}")
+            data = os.urandom(1000000)
+
+            self.logger.info("sending allgather")
+            start = time.time()
+            result = client.send_allgather(seq_num=seq + 1, rank=rank, data=data)
+            total_reqs += 1
+            total_time += time.time() - start
+            if not isinstance(result, pb2.AllgatherReply):
+                self.logger.error(f"expect reply to be pb2.AllgatherReply but got {type(result)}")
+            elif result.receive_buffer != data:
+                self.logger.error("allgather result does not match request")
+            else:
+                self.logger.info("OK: allgather result matches request!")
+
+            self.logger.info("sending allgatherV")
+            start = time.time()
+            result = client.send_allgatherv(seq_num=seq + 2, rank=rank, data=data)
+            total_reqs += 1
+            total_time += time.time() - start
+            if not isinstance(result, pb2.AllgatherVReply):
+                self.logger.error(f"expect reply to be pb2.AllgatherVReply but got {type(result)}")
+            elif result.receive_buffer != data:
+                self.logger.error("allgatherV result does not match request")
+            else:
+                self.logger.info("OK: allgatherV result matches request!")
+
+            self.logger.info("sending allreduce")
+            start = time.time()
+            result = client.send_allreduce(
+                seq_num=seq + 3,
+                rank=rank,
+                data=data,
+                reduce_op=2,
+                data_type=2,
+            )
+            total_reqs += 1
+            total_time += time.time() - start
+            if not isinstance(result, pb2.AllreduceReply):
+                self.logger.error(f"expect reply to be pb2.AllreduceReply but got {type(result)}")
+            elif result.receive_buffer != data:
+                self.logger.error("allreduce result does not match request")
+            else:
+                self.logger.info("OK: allreduce result matches request!")
+
+            self.logger.info("sending broadcast")
+            start = time.time()
+            result = client.send_broadcast(
+                seq_num=seq + 4,
+                rank=rank,
+                data=data,
+                root=3,
+            )
+            total_reqs += 1
+            total_time += time.time() - start
+            if not isinstance(result, pb2.BroadcastReply):
+                self.logger.error(f"expect reply to be pb2.BroadcastReply but got {type(result)}")
+            elif result.receive_buffer != data:
+                self.logger.error("broadcast result does not match request")
+            else:
+                self.logger.info("OK: broadcast result matches request!")
+
+            seq += 4
+            time.sleep(1.0)
+
+        time_per_req = total_time / total_reqs
+        self.logger.info(f"DONE: {total_reqs=} {total_time=} {time_per_req=}")
+        self.training_stopped = True
+
+    def stop(self):
+        self.asked_to_stop = True
+
+    def is_stopped(self) -> (bool, int):
+        return self.training_stopped, 0
diff --git a/nvflare/app_common/xgb/mock/mock_controller.py
b/nvflare/app_common/xgb/mock/mock_controller.py new file mode 100644 index 0000000000..904ba66f83 --- /dev/null +++ b/nvflare/app_common/xgb/mock/mock_controller.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.xgb.adaptors.grpc_server_adaptor import GrpcServerAdaptor +from nvflare.app_common.xgb.controller import XGBController +from nvflare.app_common.xgb.defs import Constant +from nvflare.app_common.xgb.mock.mock_server_runner import MockServerRunner + + +class MockXGBController(XGBController): + def __init__( + self, + num_rounds: int, + configure_task_name=Constant.CONFIG_TASK_NAME, + configure_task_timeout=Constant.CONFIG_TASK_TIMEOUT, + start_task_name=Constant.START_TASK_NAME, + start_task_timeout=Constant.START_TASK_TIMEOUT, + job_status_check_interval: float = Constant.JOB_STATUS_CHECK_INTERVAL, + max_client_op_interval: float = Constant.MAX_CLIENT_OP_INTERVAL, + progress_timeout: float = Constant.WORKFLOW_PROGRESS_TIMEOUT, + client_ranks=None, + int_client_grpc_options=None, + in_process=True, + ): + XGBController.__init__( + self, + adaptor_component_id="", + num_rounds=num_rounds, + configure_task_name=configure_task_name, + configure_task_timeout=configure_task_timeout, + start_task_name=start_task_name, + start_task_timeout=start_task_timeout, + job_status_check_interval=job_status_check_interval, + max_client_op_interval=max_client_op_interval, + progress_timeout=progress_timeout, + client_ranks=client_ranks, + ) + self.int_client_grpc_options = int_client_grpc_options + self.in_process = in_process + + def get_adaptor(self, fl_ctx: FLContext): + runner = MockServerRunner() + runner.initialize(fl_ctx) + adaptor = GrpcServerAdaptor( + int_client_grpc_options=self.int_client_grpc_options, + in_process=self.in_process, + ) + adaptor.set_runner(runner) + return adaptor diff --git a/nvflare/app_common/xgb/mock/mock_executor.py b/nvflare/app_common/xgb/mock/mock_executor.py new file mode 100644 index 0000000000..d384ea2c63 --- /dev/null +++ b/nvflare/app_common/xgb/mock/mock_executor.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
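+
+# Client-side twin of MockXGBController: the controller side runs a
+# MockServerRunner (echo aggregator) while each client runs a MockClientRunner,
+# exercising the full NVFLARE messaging path without real XGBoost training.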
+from nvflare.apis.fl_context import FLContext +from nvflare.app_common.xgb.adaptors.grpc_client_adaptor import GrpcClientAdaptor +from nvflare.app_common.xgb.executor import XGBExecutor +from nvflare.app_common.xgb.mock.mock_client_runner import MockClientRunner + + +class MockXGBExecutor(XGBExecutor): + def __init__( + self, + int_server_grpc_options=None, + req_timeout=10.0, + in_process=True, + ): + XGBExecutor.__init__( + self, + adaptor_component_id="", + req_timeout=req_timeout, + ) + self.int_server_grpc_options = int_server_grpc_options + self.in_process = in_process + + def get_adaptor(self, fl_ctx: FLContext): + runner = MockClientRunner() + runner.initialize(fl_ctx) + adaptor = GrpcClientAdaptor( + int_server_grpc_options=self.int_server_grpc_options, + in_process=self.in_process, + ) + adaptor.set_runner(runner) + return adaptor diff --git a/nvflare/app_common/xgb/mock/mock_server_runner.py b/nvflare/app_common/xgb/mock/mock_server_runner.py new file mode 100644 index 0000000000..8539b3c290 --- /dev/null +++ b/nvflare/app_common/xgb/mock/mock_server_runner.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.app_common.xgb.defs import Constant +from nvflare.app_common.xgb.grpc_server import GrpcServer +from nvflare.app_common.xgb.mock.aggr_servicer import AggrServicer +from nvflare.app_common.xgb.runners.xgb_runner import XGBRunner + + +class MockServerRunner(XGBRunner): + def __init__(self, server_max_workers=10): + self.server_max_workers = server_max_workers + self._stopped = False + self._server = None + + def run(self, ctx: dict): + world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) + addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR) + + self._server = GrpcServer( + addr, + max_workers=self.server_max_workers, + grpc_options=None, + servicer=AggrServicer(num_clients=world_size), + ) + self._server.start(no_blocking=False) + + def stop(self): + s = self._server + self._server = None + if s: + s.shutdown() + self._stopped = True + + def is_stopped(self) -> (bool, int): + return self._stopped, 0 diff --git a/nvflare/app_common/xgb/mock/run_client.py b/nvflare/app_common/xgb/mock/run_client.py new file mode 100644 index 0000000000..bbd6551322 --- /dev/null +++ b/nvflare/app_common/xgb/mock/run_client.py @@ -0,0 +1,108 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
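+
+# Standalone driver for manual testing against run_server.py. A hypothetical
+# invocation (address and values are illustrative):
+#
+#   python run_server.py --addr localhost:8002 --num_clients 2
+#   python run_client.py --addr localhost:8002 --rank 0 --num_rounds 3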
+ +import argparse +import os +import time + +import nvflare.app_common.xgb.proto.federated_pb2 as pb2 +from nvflare.app_common.xgb.grpc_client import GrpcClient + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--addr", "-a", type=str, help="server address", required=True) + parser.add_argument("--rank", "-r", type=int, help="client rank", required=True) + parser.add_argument("--num_rounds", "-n", type=int, help="number of rounds", required=True) + + args = parser.parse_args() + client = GrpcClient(server_addr=args.addr) + client.start() + + rank = args.rank + seq = 0 + total_time = 0 + total_reqs = 0 + for i in range(args.num_rounds): + print(f"Test round {i}") + data = os.urandom(1000000) + + print("sending allgather") + start = time.time() + result = client.send_allgather(seq_num=seq + 1, rank=rank, data=data) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.AllgatherReply): + print(f"expect reply to be pb2.AllgatherReply but got {type(result)}") + elif result.receive_buffer != data: + print("ERROR: allgather result does not match request") + else: + print("OK: allgather result matches request!") + + print("sending allgatherV") + start = time.time() + result = client.send_allgatherv(seq_num=seq + 2, rank=rank, data=data) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.AllgatherVReply): + print(f"expect reply to be pb2.AllgatherVReply but got {type(result)}") + elif result.receive_buffer != data: + print("ERROR: allgatherV result does not match request") + else: + print("OK: allgatherV result matches request!") + + print("sending allreduce") + start = time.time() + result = client.send_allreduce( + seq_num=seq + 3, + rank=rank, + data=data, + reduce_op=2, + data_type=2, + ) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.AllreduceReply): + print(f"expect reply to be pb2.AllreduceReply but got {type(result)}") + elif result.receive_buffer != data: + print("ERROR: allreduce result does not match request") + else: + print("OK: allreduce result matches request!") + + print("sending broadcast") + start = time.time() + result = client.send_broadcast( + seq_num=seq + 4, + rank=rank, + data=data, + root=3, + ) + total_reqs += 1 + total_time += time.time() - start + if not isinstance(result, pb2.BroadcastReply): + print(f"expect reply to be pb2.BroadcastReply but got {type(result)}") + elif result.receive_buffer != data: + print("ERROR: broadcast result does not match request") + else: + print("OK: broadcast result matches request!") + + seq += 4 + time.sleep(1.0) + + time_per_req = total_time / total_reqs + print(f"DONE: {total_reqs=} {total_time=} {time_per_req=}") + + +if __name__ == "__main__": + main() diff --git a/nvflare/app_common/xgb/mock/run_server.py b/nvflare/app_common/xgb/mock/run_server.py new file mode 100644 index 0000000000..6ac4ae02e3 --- /dev/null +++ b/nvflare/app_common/xgb/mock/run_server.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
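The tester above can also be driven programmatically. A condensed sketch of one round trip, assuming the mock server (see run_server below) was started with ``--num_clients 1`` at an illustrative address; the mock AggrServicer echoes the payload back, which is exactly what the tester verifies:

.. code-block:: python

    import nvflare.app_common.xgb.proto.federated_pb2 as pb2
    from nvflare.app_common.xgb.grpc_client import GrpcClient

    client = GrpcClient(server_addr="localhost:8002")  # illustrative address
    client.start()

    # one broadcast round; the mock servicer echoes the payload back
    reply = client.send_broadcast(seq_num=1, rank=0, data=b"ping", root=0)
    assert isinstance(reply, pb2.BroadcastReply) and reply.receive_buffer == b"ping"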
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging + +from nvflare.app_common.xgb.grpc_server import GrpcServer +from nvflare.app_common.xgb.mock.aggr_servicer import AggrServicer + + +def main(): + logging.basicConfig() + logging.getLogger().setLevel(logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--addr", "-a", type=str, help="server address", required=True) + parser.add_argument("--num_clients", "-c", type=int, help="number of clients", required=True) + parser.add_argument("--max_workers", "-w", type=int, help="max number of workers", required=False, default=20) + + args = parser.parse_args() + print(f"starting server at {args.addr} max_workers={args.max_workers}") + server = GrpcServer( + args.addr, + max_workers=args.max_workers, + grpc_options=None, + servicer=AggrServicer(num_clients=args.num_clients), + ) + server.start() + + +if __name__ == "__main__": + main() diff --git a/nvflare/app_common/xgb/proto/__init__.py b/nvflare/app_common/xgb/proto/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/nvflare/app_common/xgb/proto/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/xgb/proto/federated.proto b/nvflare/app_common/xgb/proto/federated.proto new file mode 100644 index 0000000000..f412204813 --- /dev/null +++ b/nvflare/app_common/xgb/proto/federated.proto @@ -0,0 +1,85 @@ +/*! + * Copyright 2022-2023 XGBoost contributors + */ +syntax = "proto3"; + +package xgboost.collective.federated; + +service Federated { + rpc Allgather(AllgatherRequest) returns (AllgatherReply) {} + rpc AllgatherV(AllgatherVRequest) returns (AllgatherVReply) {} + rpc Allreduce(AllreduceRequest) returns (AllreduceReply) {} + rpc Broadcast(BroadcastRequest) returns (BroadcastReply) {} +} + +enum DataType { + HALF = 0; + FLOAT = 1; + DOUBLE = 2; + LONG_DOUBLE = 3; + INT8 = 4; + INT16 = 5; + INT32 = 6; + INT64 = 7; + UINT8 = 8; + UINT16 = 9; + UINT32 = 10; + UINT64 = 11; +} + +enum ReduceOperation { + MAX = 0; + MIN = 1; + SUM = 2; + BITWISE_AND = 3; + BITWISE_OR = 4; + BITWISE_XOR = 5; +} + +message AllgatherRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherReply { + bytes receive_buffer = 1; +} + +message AllgatherVRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; +} + +message AllgatherVReply { + bytes receive_buffer = 1; +} + +message AllreduceRequest { + // An incrementing counter that is unique to each round to operations. 
+ uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + DataType data_type = 4; + ReduceOperation reduce_operation = 5; +} + +message AllreduceReply { + bytes receive_buffer = 1; +} + +message BroadcastRequest { + // An incrementing counter that is unique to each round to operations. + uint64 sequence_number = 1; + int32 rank = 2; + bytes send_buffer = 3; + // The root rank to broadcast from. + int32 root = 4; +} + +message BroadcastReply { + bytes receive_buffer = 1; +} \ No newline at end of file diff --git a/nvflare/app_common/xgb/proto/federated_pb2.py b/nvflare/app_common/xgb/proto/federated_pb2.py new file mode 100644 index 0000000000..ba80c1e5d6 --- /dev/null +++ b/nvflare/app_common/xgb/proto/federated_pb2.py @@ -0,0 +1,59 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: federated.proto +"""Generated protocol buffer code.""" +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x66\x65\x64\x65rated.proto\x12\x1cxgboost.collective.federated\"N\n\x10\x41llgatherRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\"(\n\x0e\x41llgatherReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"O\n\x11\x41llgatherVRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\")\n\x0f\x41llgatherVReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\xd2\x01\n\x10\x41llreduceRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x39\n\tdata_type\x18\x04 \x01(\x0e\x32&.xgboost.collective.federated.DataType\x12G\n\x10reduce_operation\x18\x05 \x01(\x0e\x32-.xgboost.collective.federated.ReduceOperation\"(\n\x0e\x41llreduceReply\x12\x16\n\x0ereceive_buffer\x18\x01 \x01(\x0c\"\\\n\x10\x42roadcastRequest\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\x0c\n\x04rank\x18\x02 \x01(\x05\x12\x13\n\x0bsend_buffer\x18\x03 \x01(\x0c\x12\x0c\n\x04root\x18\x04 \x01(\x05\"(\n\x0e\x42roadcastReply\x12\x16\n\x0ereceive_buffer\x18\x01 
\x01(\x0c*\x96\x01\n\x08\x44\x61taType\x12\x08\n\x04HALF\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06\x44OUBLE\x10\x02\x12\x0f\n\x0bLONG_DOUBLE\x10\x03\x12\x08\n\x04INT8\x10\x04\x12\t\n\x05INT16\x10\x05\x12\t\n\x05INT32\x10\x06\x12\t\n\x05INT64\x10\x07\x12\t\n\x05UINT8\x10\x08\x12\n\n\x06UINT16\x10\t\x12\n\n\x06UINT32\x10\n\x12\n\n\x06UINT64\x10\x0b*^\n\x0fReduceOperation\x12\x07\n\x03MAX\x10\x00\x12\x07\n\x03MIN\x10\x01\x12\x07\n\x03SUM\x10\x02\x12\x0f\n\x0b\x42ITWISE_AND\x10\x03\x12\x0e\n\nBITWISE_OR\x10\x04\x12\x0f\n\x0b\x42ITWISE_XOR\x10\x05\x32\xc2\x03\n\tFederated\x12k\n\tAllgather\x12..xgboost.collective.federated.AllgatherRequest\x1a,.xgboost.collective.federated.AllgatherReply\"\x00\x12n\n\nAllgatherV\x12/.xgboost.collective.federated.AllgatherVRequest\x1a-.xgboost.collective.federated.AllgatherVReply\"\x00\x12k\n\tAllreduce\x12..xgboost.collective.federated.AllreduceRequest\x1a,.xgboost.collective.federated.AllreduceReply\"\x00\x12k\n\tBroadcast\x12..xgboost.collective.federated.BroadcastRequest\x1a,.xgboost.collective.federated.BroadcastReply\"\x00\x62\x06proto3') + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'federated_pb2', globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _DATATYPE._serialized_start=687 + _DATATYPE._serialized_end=837 + _REDUCEOPERATION._serialized_start=839 + _REDUCEOPERATION._serialized_end=933 + _ALLGATHERREQUEST._serialized_start=49 + _ALLGATHERREQUEST._serialized_end=127 + _ALLGATHERREPLY._serialized_start=129 + _ALLGATHERREPLY._serialized_end=169 + _ALLGATHERVREQUEST._serialized_start=171 + _ALLGATHERVREQUEST._serialized_end=250 + _ALLGATHERVREPLY._serialized_start=252 + _ALLGATHERVREPLY._serialized_end=293 + _ALLREDUCEREQUEST._serialized_start=296 + _ALLREDUCEREQUEST._serialized_end=506 + _ALLREDUCEREPLY._serialized_start=508 + _ALLREDUCEREPLY._serialized_end=548 + _BROADCASTREQUEST._serialized_start=550 + _BROADCASTREQUEST._serialized_end=642 + _BROADCASTREPLY._serialized_start=644 + _BROADCASTREPLY._serialized_end=684 + _FEDERATED._serialized_start=936 + _FEDERATED._serialized_end=1386 +# @@protoc_insertion_point(module_scope) diff --git a/nvflare/app_common/xgb/proto/federated_pb2.pyi b/nvflare/app_common/xgb/proto/federated_pb2.pyi new file mode 100644 index 0000000000..8e2a7e740e --- /dev/null +++ b/nvflare/app_common/xgb/proto/federated_pb2.pyi @@ -0,0 +1,100 @@ +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +BITWISE_AND: ReduceOperation +BITWISE_OR: ReduceOperation +BITWISE_XOR: ReduceOperation +DESCRIPTOR: _descriptor.FileDescriptor +DOUBLE: DataType +FLOAT: DataType +HALF: DataType +INT16: DataType +INT32: DataType +INT64: DataType +INT8: DataType +LONG_DOUBLE: DataType +MAX: ReduceOperation +MIN: ReduceOperation +SUM: ReduceOperation +UINT16: DataType +UINT32: DataType +UINT64: DataType +UINT8: DataType + +class AllgatherReply(_message.Message): + __slots__ = ["receive_buffer"] + RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] + receive_buffer: bytes + def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... 
+ +class AllgatherRequest(_message.Message): + __slots__ = ["rank", "send_buffer", "sequence_number"] + RANK_FIELD_NUMBER: _ClassVar[int] + SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + rank: int + send_buffer: bytes + sequence_number: int + def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... + +class AllgatherVReply(_message.Message): + __slots__ = ["receive_buffer"] + RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] + receive_buffer: bytes + def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... + +class AllgatherVRequest(_message.Message): + __slots__ = ["rank", "send_buffer", "sequence_number"] + RANK_FIELD_NUMBER: _ClassVar[int] + SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + rank: int + send_buffer: bytes + sequence_number: int + def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ...) -> None: ... + +class AllreduceReply(_message.Message): + __slots__ = ["receive_buffer"] + RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] + receive_buffer: bytes + def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... + +class AllreduceRequest(_message.Message): + __slots__ = ["data_type", "rank", "reduce_operation", "send_buffer", "sequence_number"] + DATA_TYPE_FIELD_NUMBER: _ClassVar[int] + RANK_FIELD_NUMBER: _ClassVar[int] + REDUCE_OPERATION_FIELD_NUMBER: _ClassVar[int] + SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + data_type: DataType + rank: int + reduce_operation: ReduceOperation + send_buffer: bytes + sequence_number: int + def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., data_type: _Optional[_Union[DataType, str]] = ..., reduce_operation: _Optional[_Union[ReduceOperation, str]] = ...) -> None: ... + +class BroadcastReply(_message.Message): + __slots__ = ["receive_buffer"] + RECEIVE_BUFFER_FIELD_NUMBER: _ClassVar[int] + receive_buffer: bytes + def __init__(self, receive_buffer: _Optional[bytes] = ...) -> None: ... + +class BroadcastRequest(_message.Message): + __slots__ = ["rank", "root", "send_buffer", "sequence_number"] + RANK_FIELD_NUMBER: _ClassVar[int] + ROOT_FIELD_NUMBER: _ClassVar[int] + SEND_BUFFER_FIELD_NUMBER: _ClassVar[int] + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + rank: int + root: int + send_buffer: bytes + sequence_number: int + def __init__(self, sequence_number: _Optional[int] = ..., rank: _Optional[int] = ..., send_buffer: _Optional[bytes] = ..., root: _Optional[int] = ...) -> None: ... + +class DataType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + +class ReduceOperation(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] diff --git a/nvflare/app_common/xgb/proto/federated_pb2_grpc.py b/nvflare/app_common/xgb/proto/federated_pb2_grpc.py new file mode 100644 index 0000000000..206d8474da --- /dev/null +++ b/nvflare/app_common/xgb/proto/federated_pb2_grpc.py @@ -0,0 +1,179 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +import nvflare.app_common.xgb.proto.federated_pb2 as federated__pb2 + + +class FederatedStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.Allgather = channel.unary_unary( + '/xgboost.collective.federated.Federated/Allgather', + request_serializer=federated__pb2.AllgatherRequest.SerializeToString, + response_deserializer=federated__pb2.AllgatherReply.FromString, + ) + self.AllgatherV = channel.unary_unary( + '/xgboost.collective.federated.Federated/AllgatherV', + request_serializer=federated__pb2.AllgatherVRequest.SerializeToString, + response_deserializer=federated__pb2.AllgatherVReply.FromString, + ) + self.Allreduce = channel.unary_unary( + '/xgboost.collective.federated.Federated/Allreduce', + request_serializer=federated__pb2.AllreduceRequest.SerializeToString, + response_deserializer=federated__pb2.AllreduceReply.FromString, + ) + self.Broadcast = channel.unary_unary( + '/xgboost.collective.federated.Federated/Broadcast', + request_serializer=federated__pb2.BroadcastRequest.SerializeToString, + response_deserializer=federated__pb2.BroadcastReply.FromString, + ) + + +class FederatedServicer(object): + """Missing associated documentation comment in .proto file.""" + + def Allgather(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def AllgatherV(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Allreduce(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Broadcast(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_FederatedServicer_to_server(servicer, server): + rpc_method_handlers = { + 'Allgather': grpc.unary_unary_rpc_method_handler( + servicer.Allgather, + request_deserializer=federated__pb2.AllgatherRequest.FromString, + response_serializer=federated__pb2.AllgatherReply.SerializeToString, + ), + 'AllgatherV': grpc.unary_unary_rpc_method_handler( + servicer.AllgatherV, + request_deserializer=federated__pb2.AllgatherVRequest.FromString, + response_serializer=federated__pb2.AllgatherVReply.SerializeToString, + ), + 'Allreduce': grpc.unary_unary_rpc_method_handler( + servicer.Allreduce, + 
request_deserializer=federated__pb2.AllreduceRequest.FromString, + response_serializer=federated__pb2.AllreduceReply.SerializeToString, + ), + 'Broadcast': grpc.unary_unary_rpc_method_handler( + servicer.Broadcast, + request_deserializer=federated__pb2.BroadcastRequest.FromString, + response_serializer=federated__pb2.BroadcastReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'xgboost.collective.federated.Federated', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + + # This class is part of an EXPERIMENTAL API. +class Federated(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Allgather(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allgather', + federated__pb2.AllgatherRequest.SerializeToString, + federated__pb2.AllgatherReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def AllgatherV(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/AllgatherV', + federated__pb2.AllgatherVRequest.SerializeToString, + federated__pb2.AllgatherVReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Allreduce(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Allreduce', + federated__pb2.AllreduceRequest.SerializeToString, + federated__pb2.AllreduceReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) + + @staticmethod + def Broadcast(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary(request, target, '/xgboost.collective.federated.Federated/Broadcast', + federated__pb2.BroadcastRequest.SerializeToString, + federated__pb2.BroadcastReply.FromString, + options, channel_credentials, + insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/nvflare/app_common/xgb/proto/gen_proto.sh b/nvflare/app_common/xgb/proto/gen_proto.sh new file mode 100644 index 0000000000..10afcf5b3b --- /dev/null +++ b/nvflare/app_common/xgb/proto/gen_proto.sh @@ -0,0 +1 @@ +python -m grpc_tools.protoc -I. --python_out=. --pyi_out=. --grpc_python_out=. federated.proto diff --git a/nvflare/app_common/xgb/runners/__init__.py b/nvflare/app_common/xgb/runners/__init__.py new file mode 100644 index 0000000000..4fc50543f1 --- /dev/null +++ b/nvflare/app_common/xgb/runners/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
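Because the generated stub mirrors the service definition one-to-one, the Federated service can also be exercised without the GrpcClient wrapper. A sketch, again assuming a server is listening at an illustrative address:

.. code-block:: python

    import grpc

    import nvflare.app_common.xgb.proto.federated_pb2 as pb2
    import nvflare.app_common.xgb.proto.federated_pb2_grpc as pb2_grpc

    with grpc.insecure_channel("localhost:8002") as channel:  # illustrative address
        stub = pb2_grpc.FederatedStub(channel)
        req = pb2.AllgatherRequest(sequence_number=1, rank=0, send_buffer=b"chunk")
        reply = stub.Allgather(req)  # blocks until the server replies with an AllgatherReply
        print(len(reply.receive_buffer))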
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nvflare/app_common/xgb/runners/xgb_client_runner.py b/nvflare/app_common/xgb/runners/xgb_client_runner.py new file mode 100644 index 0000000000..7771590813 --- /dev/null +++ b/nvflare/app_common/xgb/runners/xgb_client_runner.py @@ -0,0 +1,141 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import xgboost as xgb +from xgboost import callback + +from nvflare.apis.fl_component import FLComponent +from nvflare.apis.fl_context import FLContext +from nvflare.app_common.xgb.data_loader import XGBDataLoader +from nvflare.app_common.xgb.defs import Constant +from nvflare.app_common.xgb.runners.xgb_runner import XGBRunner +from nvflare.app_common.xgb.tb import TensorBoardCallback +from nvflare.app_common.xgb.xgb_params import XGBoostParams +from nvflare.fuel.utils.import_utils import optional_import +from nvflare.fuel.utils.obj_utils import get_logger + + +class XGBClientRunner(XGBRunner, FLComponent): + def __init__( + self, + data_loader_id: str, + early_stopping_rounds: int, + xgb_params: dict, + verbose_eval, + use_gpus, + model_file_name, + ): + FLComponent.__init__(self) + self.early_stopping_rounds = early_stopping_rounds + self.xgb_params = xgb_params + self.verbose_eval = verbose_eval + self.use_gpus = use_gpus + self.model_file_name = model_file_name + self.data_loader_id = data_loader_id + self.logger = get_logger(self) + + self._client_name = None + self._rank = None + self._world_size = None + self._num_rounds = None + self._server_addr = None + self._data_loader = None + self._tb_dir = None + self._model_dir = None + self._stopped = False + + def initialize(self, fl_ctx: FLContext): + engine = fl_ctx.get_engine() + self._data_loader = engine.get_component(self.data_loader_id) + if not isinstance(self._data_loader, XGBDataLoader): + self.system_panic(f"data_loader should be type XGBDataLoader but got {type(self._data_loader)}", fl_ctx) + + def _xgb_train(self, params: XGBoostParams, train_data, val_data) -> xgb.core.Booster: + """XGBoost training logic. + + Args: + params (XGBoostParams): xgboost parameters. + train_data (xgb.DMatrix): training data. + val_data (xgb.DMatrix): validation data. + + Returns: + An xgboost booster.
+ """ + # Specify validations set to watch performance + watchlist = [(val_data, "eval"), (train_data, "train")] + + callbacks = [callback.EvaluationMonitor(rank=self._rank)] + tensorboard, flag = optional_import(module="torch.utils.tensorboard") + if flag and self._tb_dir: + callbacks.append(TensorBoardCallback(self._tb_dir, tensorboard)) + + # Run training, all the features in training API is available. + bst = xgb.train( + params.xgb_params, + train_data, + params.num_rounds, + evals=watchlist, + early_stopping_rounds=params.early_stopping_rounds, + verbose_eval=params.verbose_eval, + callbacks=callbacks, + ) + return bst + + def run(self, ctx: dict): + self._client_name = ctx.get(Constant.RUNNER_CTX_CLIENT_NAME) + self._rank = ctx.get(Constant.RUNNER_CTX_RANK) + self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) + self._num_rounds = ctx.get(Constant.RUNNER_CTX_NUM_ROUNDS) + self._server_addr = ctx.get(Constant.RUNNER_CTX_SERVER_ADDR) + self._data_loader = ctx.get(Constant.RUNNER_CTX_DATA_LOADER) + self._tb_dir = ctx.get(Constant.RUNNER_CTX_TB_DIR) + self._model_dir = ctx.get(Constant.RUNNER_CTX_MODEL_DIR) + + if self.use_gpus: + # mapping each rank to a GPU (can set to cuda:0 if simulating with only one gpu) + self.logger.info(f"Training with GPU {self._rank}") + self.xgb_params["device"] = f"cuda:{self._rank}" + + self.logger.info(f"Using xgb params: {self.xgb_params}") + params = XGBoostParams( + xgb_params=self.xgb_params, + num_rounds=self._num_rounds, + early_stopping_rounds=self.early_stopping_rounds, + verbose_eval=self.verbose_eval, + ) + + self.logger.info(f"server address is {self._server_addr}") + communicator_env = { + "xgboost_communicator": "federated", + "federated_server_address": f"{self._server_addr}", + "federated_world_size": self._world_size, + "federated_rank": self._rank, + } + with xgb.collective.CommunicatorContext(**communicator_env): + # Load the data. Dmatrix must be created with column split mode in CommunicatorContext for vertical FL + train_data, val_data = self._data_loader.load_data(self._client_name) + + bst = self._xgb_train(params, train_data, val_data) + + # Save the model. + bst.save_model(os.path.join(self._model_dir, self.model_file_name)) + xgb.collective.communicator_print("Finished training\n") + + self._stopped = True + + def stop(self): + # currently no way to stop the runner + pass + + def is_stopped(self) -> (bool, int): + return self._stopped, 0 diff --git a/nvflare/app_common/xgb/runners/xgb_runner.py b/nvflare/app_common/xgb/runners/xgb_runner.py new file mode 100644 index 0000000000..0dae41ada6 --- /dev/null +++ b/nvflare/app_common/xgb/runners/xgb_runner.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from abc import ABC, abstractmethod + +from nvflare.apis.fl_context import FLContext + + +class XGBRunner(ABC): + + """An XGBRunner implements XGB (server or client) processing logic.""" + + def initialize(self, fl_ctx: FLContext): + """Called by Controller/Executor to initialize the runner. + This happens when the job is about to start. + + Args: + fl_ctx: FL context + + Returns: None + + """ + pass + + @abstractmethod + def run(self, ctx: dict): + """Called to start the execution of XGB processing logic. + + Args: + ctx: the contextual info to help the runner execution + + Returns: None + + """ + pass + + @abstractmethod + def stop(self): + """Called to stop the runner. + + Returns: None + + """ + pass + + @abstractmethod + def is_stopped(self) -> (bool, int): + """Called to check whether the runner is already stopped. + + Returns: whether the runner is stopped. If stopped, the exit code. + + """ + pass diff --git a/nvflare/app_common/xgb/runners/xgb_server_runner.py b/nvflare/app_common/xgb/runners/xgb_server_runner.py new file mode 100644 index 0000000000..92a4f81c35 --- /dev/null +++ b/nvflare/app_common/xgb/runners/xgb_server_runner.py @@ -0,0 +1,42 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import xgboost.federated as xgb_federated + +from nvflare.app_common.xgb.defs import Constant +from nvflare.app_common.xgb.runners.xgb_runner import XGBRunner + + +class XGBServerRunner(XGBRunner): + def __init__(self): + self._port = None + self._world_size = None + self._stopped = False + + def run(self, ctx: dict): + self._port = ctx.get(Constant.RUNNER_CTX_PORT) + self._world_size = ctx.get(Constant.RUNNER_CTX_WORLD_SIZE) + + xgb_federated.run_federated_server( + port=self._port, + world_size=self._world_size, + ) + self._stopped = True + + def stop(self): + # currently no way to stop the server + pass + + def is_stopped(self) -> (bool, int): + return self._stopped, 0 diff --git a/nvflare/app_common/xgb/sender.py b/nvflare/app_common/xgb/sender.py new file mode 100644 index 0000000000..7177fbb214 --- /dev/null +++ b/nvflare/app_common/xgb/sender.py @@ -0,0 +1,89 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
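For reference, the smallest runner that satisfies the XGBRunner contract above; it does no real work and exists only to show the run/stop/is_stopped life cycle:

.. code-block:: python

    from nvflare.app_common.xgb.runners.xgb_runner import XGBRunner


    class NoOpRunner(XGBRunner):
        """Toy runner for illustration: finishes immediately with exit code 0."""

        def __init__(self):
            self._stopped = False

        def run(self, ctx: dict):
            # real runners pull rank, world size, server address, etc. from ctx here
            self._stopped = True

        def stop(self):
            self._stopped = True

        def is_stopped(self) -> (bool, int):
            return self._stopped, 0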
+ +from nvflare.apis.shareable import ReturnCode, Shareable +from nvflare.apis.signal import Signal +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.utils.obj_utils import get_logger + +from .defs import Constant + + +class Sender: + """ + A Sender is used to send XGB requests from the client to the server and wait for the reply. + TBD: currently the sender simply sends the request with an aux message. It will be enhanced to be more + reliable in dealing with unstable networks. + """ + + def __init__(self, engine, timeout): + """Constructor + + Args: + engine: the client engine that can send aux messages + timeout: the timeout for XGB requests + """ + self.engine = engine + self.timeout = timeout + self.logger = get_logger(self) + + def _extract_result(self, reply, expected_op): + if not reply: + return None + if not isinstance(reply, dict): + self.logger.error(f"expect reply to be a dict but got {type(reply)}") + return None + result = reply.get(FQCN.ROOT_SERVER) + if not result: + self.logger.error(f"no reply from {FQCN.ROOT_SERVER} for request {expected_op}") + return None + if not isinstance(result, Shareable): + self.logger.error(f"expect result to be a Shareable but got {type(result)}") + return None + rc = result.get_return_code() + if rc != ReturnCode.OK: + self.logger.error(f"server failed to process request: {rc=}") + return None + reply_op = result.get_header(Constant.MSG_KEY_XGB_OP) + if reply_op != expected_op: + self.logger.error(f"received op {reply_op} != expected op {expected_op}") + return None + return result + + def send_to_server(self, op: str, req: Shareable, abort_signal: Signal): + """Send an XGB request to the server. + + Args: + op: the XGB operation code + req: the XGB request + abort_signal: used for checking whether the job is aborted. + + Returns: reply from the server + + Note: when this method is enhanced to be more reliable, we'll keep resending until either the request is + sent successfully or the job is aborted. + + """ + req.set_header(Constant.MSG_KEY_XGB_OP, op) + + server_name = FQCN.ROOT_SERVER + with self.engine.new_context() as fl_ctx: + reply = self.engine.send_aux_request( + targets=[server_name], + topic=Constant.TOPIC_XGB_REQUEST, + request=req, + timeout=self.timeout, + fl_ctx=fl_ctx, + ) + return self._extract_result(reply, op) diff --git a/nvflare/app_common/xgb/tb.py b/nvflare/app_common/xgb/tb.py new file mode 100644 index 0000000000..0719d5b57d --- /dev/null +++ b/nvflare/app_common/xgb/tb.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+import os + +import xgboost.callback + + +class TensorBoardCallback(xgboost.callback.TrainingCallback): + def __init__(self, app_dir: str, tensorboard): + xgboost.callback.TrainingCallback.__init__(self) + self.train_writer = tensorboard.SummaryWriter(log_dir=os.path.join(app_dir, "train-auc/")) + self.val_writer = tensorboard.SummaryWriter(log_dir=os.path.join(app_dir, "val-auc/")) + + def after_iteration(self, model, epoch: int, evals_log: xgboost.callback.TrainingCallback.EvalsLog): + if not evals_log: + return False + + for data, metric in evals_log.items(): + for metric_name, log in metric.items(): + score = log[-1][0] if isinstance(log[-1], tuple) else log[-1] + if data == "train": + self.train_writer.add_scalar(metric_name, score, epoch) + else: + self.val_writer.add_scalar(metric_name, score, epoch) + return False diff --git a/nvflare/app_common/xgb/xgb_params.py b/nvflare/app_common/xgb/xgb_params.py new file mode 100644 index 0000000000..bf5d4f9b81 --- /dev/null +++ b/nvflare/app_common/xgb/xgb_params.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class XGBoostParams: + def __init__(self, xgb_params: dict, num_rounds=10, early_stopping_rounds=2, verbose_eval=False): + """Container for all XGBoost parameters. + + Args: + xgb_params: This dict is passed to `xgboost.train()` as the first argument `params`. + It contains all the Booster parameters. + Please refer to XGBoost documentation for details: + https://xgboost.readthedocs.io/en/stable/python/python_api.html#module-xgboost.training + num_rounds: number of boosting rounds. + early_stopping_rounds: stop training if no improvement after this many rounds. + verbose_eval: whether (or how often) evaluation results are printed. + """ + self.num_rounds = num_rounds + self.early_stopping_rounds = early_stopping_rounds + self.verbose_eval = verbose_eval + self.xgb_params: dict = xgb_params if xgb_params else {} diff --git a/nvflare/app_opt/lightning/api.py b/nvflare/app_opt/lightning/api.py index d34025e913..9035b2fe7c 100644 --- a/nvflare/app_opt/lightning/api.py +++ b/nvflare/app_opt/lightning/api.py @@ -38,13 +38,42 @@ def patch(trainer: pl.Trainer, restore_state: bool = True, load_state_dict_strict: bool = True): - """Patch the lightning trainer for usage with NVFlare. + """Patches the PyTorch Lightning Trainer for usage with NVFlare. Args: trainer: the PyTorch Lightning trainer. - restore_state: whether to restore optimizer and learning rate scheduler states. Defaults to `True`. - load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()` used load the received model. Defaults to `True`. + restore_state: whether to restore optimizer and learning rate scheduler states. + Defaults to `True`. + load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()` + used to load the received model. Defaults to `True`. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details. + + Example: + + Normal usage: + + ..
code-block:: python + + trainer = Trainer(max_epochs=1) + flare.patch(trainer) + + + Advanced usage: + + If users want to pass additional information to FLARE server side via the lightning API, + they will need to set the information in an attribute called ``__fl_meta__`` in their LightningModule. + + .. code-block:: python + + class LitNet(LightningModule): + def __init__(self): + super().__init__() + self.save_hyperparameters() + self.model = Net() + self.train_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES) + self.valid_acc = Accuracy(task="multiclass", num_classes=NUM_CLASSES) + self.__fl_meta__ = {"CUSTOM_VAR": "VALUE_OF_THE_VAR"} + """ fl_callback = FLCallback(rank=trainer.global_rank, load_state_dict_strict=load_state_dict_strict) callbacks = trainer.callbacks @@ -67,7 +96,8 @@ def __init__(self, rank: int = 0, load_state_dict_strict: bool = True): Args: rank: global rank of the PyTorch Lightning trainer. - load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()` used load the received model. Defaults to `True`. + load_state_dict_strict: exposes `strict` argument of `torch.nn.Module.load_state_dict()` + used to load the received model. Defaults to `True`. See https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict for details. """ super(FLCallback, self).__init__() diff --git a/nvflare/app_opt/tracking/tb/tb_receiver.py b/nvflare/app_opt/tracking/tb/tb_receiver.py index d30b2b94c0..585087dadc 100644 --- a/nvflare/app_opt/tracking/tb/tb_receiver.py +++ b/nvflare/app_opt/tracking/tb/tb_receiver.py @@ -28,13 +28,22 @@ AnalyticsDataType.TEXT: "add_text", AnalyticsDataType.IMAGE: "add_image", AnalyticsDataType.SCALARS: "add_scalars", - AnalyticsDataType.PARAMETER: "add_scalar", - AnalyticsDataType.PARAMETERS: "add_scalars", AnalyticsDataType.METRIC: "add_scalar", AnalyticsDataType.METRICS: "add_scalars", } + +def _create_new_data(key, value, sender): + if isinstance(value, (int, float)): + data_type = AnalyticsDataType.SCALAR + elif isinstance(value, str): + data_type = AnalyticsDataType.TEXT + else: + return None + + return AnalyticsData(key=key, value=value, data_type=data_type, sender=sender) + + class TBAnalyticsReceiver(AnalyticsReceiver): def __init__(self, tb_folder="tb_events", events: Optional[List[str]] = None): """Receives analytics data to save to TensorBoard.
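To make the conversion concrete: ``_create_new_data`` maps numbers to SCALAR records, strings to TEXT records, and returns None for anything else, which ``_convert_to_records`` (below) then skips with a warning. A sketch, assuming ``LogWriterName`` is importable from ``nvflare.apis.analytix``:

.. code-block:: python

    from nvflare.apis.analytix import LogWriterName

    for k, v in {"lr": 0.01, "optimizer": "sgd", "model": object()}.items():
        rec = _create_new_data(k, v, LogWriterName.TORCH_TB)
        print(k, "->", None if rec is None else rec.data_type)
    # lr -> AnalyticsDataType.SCALAR, optimizer -> AnalyticsDataType.TEXT, model -> None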
@@ -71,6 +80,27 @@ def initialize(self, fl_ctx: FLContext): os.makedirs(root_log_dir, exist_ok=True) self.root_log_dir = root_log_dir + def _convert_to_records(self, analytic_data: AnalyticsData, fl_ctx: FLContext) -> List[AnalyticsData]: + # break a dict of values into smaller records to support + # AnalyticsDataType.PARAMETER and AnalyticsDataType.PARAMETERS + records = [] + + if analytic_data.data_type in (AnalyticsDataType.PARAMETER, AnalyticsDataType.PARAMETERS): + for k, v in ( + analytic_data.value.items() + if analytic_data.data_type == AnalyticsDataType.PARAMETERS + else [(analytic_data.tag, analytic_data.value)] + ): + new_data = _create_new_data(k, v, analytic_data.sender) + if new_data is None: + self.log_warning(fl_ctx, f"Entry {k} of type {type(v)} is not supported.", fire_event=False) + else: + records.append(new_data) + else: + records.append(analytic_data) + + return records + def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin): dxo = from_shareable(shareable) analytic_data = AnalyticsData.from_dxo(dxo) @@ -86,19 +116,22 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin): # do different things depending on the type in dxo self.log_debug( fl_ctx, - f"save data {analytic_data} from {record_origin}", + f"try to save data {analytic_data} from {record_origin}", fire_event=False, ) - func_name = FUNCTION_MAPPING.get(analytic_data.data_type, None) - if func_name is None: - self.log_error(fl_ctx, f"The data_type {analytic_data.data_type} is not supported.", fire_event=False) - return - - func = getattr(writer, func_name) - if analytic_data.step: - func(analytic_data.tag, analytic_data.value, analytic_data.step) - else: - func(analytic_data.tag, analytic_data.value) + data_records = self._convert_to_records(analytic_data, fl_ctx) + + for data_record in data_records: + func_name = FUNCTION_MAPPING.get(data_record.data_type, None) + if func_name is None: + self.log_warning(fl_ctx, f"The data_type {data_record.data_type} is not supported.", fire_event=False) + continue + + func = getattr(writer, func_name) + if data_record.step: + func(data_record.tag, data_record.value, data_record.step) + else: + func(data_record.tag, data_record.value) def finalize(self, fl_ctx: FLContext): for writer in self.writers_table.values(): diff --git a/nvflare/cli.py b/nvflare/cli.py index 4319ed9608..46491cb130 100644 --- a/nvflare/cli.py +++ b/nvflare/cli.py @@ -48,7 +48,7 @@ def check_python_version(): if sys.version_info >= (3, 11): raise RuntimeError("Python versions 3.11 and above are not yet supported. Please use Python 3.8, 3.9 or 3.10.") if sys.version_info < (3, 8): - raise RuntimeError("Python versions 3.6 and below are not supported. Please use Python 3.8, 3.9 or 3.10") + raise RuntimeError("Python versions 3.7 and below are not supported.
Please use Python 3.8, 3.9 or 3.10") def def_provision_parser(sub_cmd): @@ -116,6 +119,9 @@ def def_config_parser(sub_cmd): def handle_config_cmd(args): config_file_path, nvflare_config = get_hidden_config() + if not args.job_templates_dir or not os.path.isdir(args.job_templates_dir): + raise ValueError(f"job_templates_dir='{args.job_templates_dir}' is not a directory") + nvflare_config = create_startup_kit_config(nvflare_config, args.startup_kit_dir) nvflare_config = create_poc_workspace_config(nvflare_config, args.poc_workspace_dir) nvflare_config = create_job_template_config(nvflare_config, args.job_templates_dir) @@ -143,7 +146,8 @@ def parse_args(prog_name: str): if argv: msg = f"{prog_name} {cmd}: unrecognized arguments: {' '.join(argv)}\n" print(f"\nerror: {msg}") - sub_cmd_parser.print_help() + if sub_cmd_parser: + sub_cmd_parser.print_help() _parser.exit(2, "\n") return _parser, _parser.parse_args(), sub_cmd_parsers diff --git a/nvflare/client/api.py b/nvflare/client/api.py index 6fd5b945aa..f9aaf886bb 100644 --- a/nvflare/client/api.py +++ b/nvflare/client/api.py @@ -14,7 +14,7 @@ import importlib import os -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple from nvflare.apis.analytix import AnalyticsDataType from nvflare.apis.utils.analytix_utils import create_analytic_dxo @@ -32,13 +32,11 @@ PROCESS_MODEL_REGISTRY = None -def _create_client_config(config: Union[str, Dict]) -> ClientConfig: +def _create_client_config(config: str) -> ClientConfig: if isinstance(config, str): client_config = from_file(config_file=config) - elif isinstance(config, dict): - client_config = ClientConfig(config=config) else: - raise ValueError("config should be either a string or dictionary.") + raise ValueError("config should be a string.") return client_config @@ -63,15 +61,24 @@ def _register_tensor_decomposer(): def init( - config: Union[str, Dict] = f"config/{CLIENT_API_CONFIG}", rank: Optional[str] = None, -): +) -> None: """Initializes NVFlare Client API environment. Args: - config (str or dict): configuration file or config dictionary. rank (str): local rank of the process. It is only useful when the training script has multiple worker processes. (for example, multi-GPU) + + Returns: + None + + Example: + + .. code-block:: python + + nvflare.client.init() + + """ global PROCESS_MODEL_REGISTRY # Declare PROCESS_MODEL_REGISTRY as global @@ -82,7 +89,7 @@ def init( print("Warning: called init() more than once. The subsequent calls are ignored") return - client_config = _create_client_config(config=config) + client_config = _create_client_config(config=f"config/{CLIENT_API_CONFIG}") flare_agent = None try: @@ -114,6 +121,7 @@ def init( def get_model_registry() -> ModelRegistry: + """Gets the ModelRegistry.""" if PROCESS_MODEL_REGISTRY is None: raise RuntimeError("needs to call init method first") return PROCESS_MODEL_REGISTRY @@ -124,6 +132,13 @@ def receive(timeout: Optional[float] = None) -> Optional[FLModel]: Returns: An FLModel received. + + Example: + + .. code-block:: python + + nvflare.client.receive() + """ model_registry = get_model_registry() return model_registry.get_model(timeout) @@ -135,6 +150,13 @@ def send(fl_model: FLModel, clear_registry: bool = True) -> None: Args: fl_model (FLModel): Sends a FLModel object. clear_registry (bool): To clear the registry or not. + + Example: + + ..
code-block:: python + + nvflare.client.send(fl_model=FLModel(...)) + """ model_registry = get_model_registry() model_registry.submit_model(model=fl_model) @@ -143,7 +165,15 @@ def send(fl_model: FLModel, clear_registry: bool = True) -> None: def clear(): - """Clears the model registry.""" + """Clears the model registry. + + Example: + + .. code-block:: python + + nvflare.client.clear() + + """ model_registry = get_model_registry() model_registry.clear() @@ -154,29 +184,89 @@ def system_info() -> Dict: System information will be available after a valid FLModel is received. It does not retrieve information actively. + Note: + system information includes job id and site name. + Returns: A dict of system information. + + Example: + + .. code-block:: python + + sys_info = nvflare.client.system_info() + """ model_registry = get_model_registry() return model_registry.get_sys_info() def get_config() -> Dict: + """Gets the ClientConfig dictionary. + + Returns: + A dict of the configuration used in Client API. + + Example: + + .. code-block:: python + + config = nvflare.client.get_config() + + """ model_registry = get_model_registry() return model_registry.config.config def get_job_id() -> str: + """Gets job id. + + Returns: + The current job id. + + Example: + + .. code-block:: python + + job_id = nvflare.client.get_job_id() + + """ sys_info = system_info() return sys_info.get(ConfigKey.JOB_ID, "") def get_site_name() -> str: + """Gets site name. + + Returns: + The site name of this client. + + Example: + + .. code-block:: python + + site_name = nvflare.client.get_site_name() + + """ sys_info = system_info() return sys_info.get(ConfigKey.SITE_NAME, "") def is_running() -> bool: + """Returns whether the NVFlare system is up and running. + + Returns: + True, if the system is up and running. False, otherwise. + + Example: + + .. code-block:: python + + while nvflare.client.is_running(): + # receive model, perform task, send model, etc. + ... + + """ try: receive() return True @@ -185,6 +275,20 @@ def is_running() -> bool: def is_train() -> bool: + """Returns whether the current task is a training task. + + Returns: + True, if the current task is a training task. False, otherwise. + + Example: + + .. code-block:: python + + if nvflare.client.is_train(): + # perform train task on received model + ... + + """ model_registry = get_model_registry() if model_registry.rank != "0": raise RuntimeError("only rank 0 can call is_train!") @@ -192,6 +296,20 @@ def is_train() -> bool: def is_evaluate() -> bool: + """Returns whether the current task is an evaluate task. + + Returns: + True, if the current task is an evaluate task. False, otherwise. + + Example: + + .. code-block:: python + + if nvflare.client.is_evaluate(): + # perform evaluate task on received model + ... + + """ model_registry = get_model_registry() if model_registry.rank != "0": raise RuntimeError("only rank 0 can call is_evaluate!") @@ -199,17 +317,58 @@ def is_evaluate() -> bool: def is_submit_model() -> bool: + """Returns whether the current task is a submit_model task. + + Returns: + True, if the current task is a submit_model. False, otherwise. + + Example: + + .. code-block:: python + + if nvflare.client.is_submit_model(): + # perform submit_model task to obtain the best local model + ... 
+ + """ model_registry = get_model_registry() if model_registry.rank != "0": raise RuntimeError("only rank 0 can call is_submit_model!") return model_registry.task_name == model_registry.config.get_submit_model_task() -def log(key: str, value: Any, data_type: AnalyticsDataType, **kwargs): +def log(key: str, value: Any, data_type: AnalyticsDataType, **kwargs) -> bool: + """Logs a key value pair. + + We suggest users use the high-level APIs in nvflare/client/tracking.py + + Args: + key (str): key string. + value (Any): value to log. + data_type (AnalyticsDataType): the data type of the "value". + kwargs: additional arguments to be included. + + Returns: + whether the key value pair is logged successfully + + Example: + + .. code-block:: python + + log( + key=tag, + value=scalar, + data_type=AnalyticsDataType.SCALAR, + global_step=global_step, + writer=LogWriterName.TORCH_TB, + **kwargs, + ) + + """ model_registry = get_model_registry() if model_registry.rank != "0": raise RuntimeError("only rank 0 can call log!") flare_agent = model_registry.flare_agent dxo = create_analytic_dxo(tag=key, value=value, data_type=data_type, **kwargs) - flare_agent.log(dxo) + return flare_agent.log(dxo) diff --git a/nvflare/client/config.py b/nvflare/client/config.py index 2e7b9e3e17..e85d3ab837 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -14,7 +14,6 @@ import json import os -from enum import Enum from typing import Dict, Optional from nvflare.fuel.utils.config_factory import ConfigFactory @@ -26,7 +25,7 @@ class ExchangeFormat: NUMPY = "numpy" -class TransferType(str, Enum): +class TransferType: FULL = "FULL" DIFF = "DIFF" @@ -34,7 +33,6 @@ class TransferType(str, Enum): class ConfigKey: EXCHANGE_FORMAT = "exchange_format" TRANSFER_TYPE = "transfer_type" - GLOBAL_EVAL = "global_eval" TRAIN_WITH_EVAL = "train_with_eval" TRAIN_TASK_NAME = "train_task_name" EVAL_TASK_NAME = "eval_task_name" @@ -50,47 +48,72 @@ class ConfigKey: class ClientConfig: - """Config class used in nvflare.client module. + """Config class used in `nvflare.client` module. + + Note: + The config has the following keys: + + .. code-block:: + + EXCHANGE_FORMAT: Format to exchange, pytorch, raw, or numpy + TRANSFER_TYPE: Either FULL or DIFF (means difference) + TRAIN_WITH_EVAL: Whether train task needs to also do evaluation + TRAIN_TASK_NAME: Name of the train task + EVAL_TASK_NAME: Name of the evaluate task + SUBMIT_MODEL_TASK_NAME: Name of the submit_model task + PIPE_CHANNEL_NAME: Channel name of the pipe + PIPE: pipe section + CLASS_NAME: Class name + ARG: Arguments + SITE_NAME: Site name + JOB_ID: Job id + TASK_EXCHANGE: TASK_EXCHANGE section + METRICS_EXCHANGE: METRICS_EXCHANGE section Example: - { - "METRICS_EXCHANGE": { - "pipe_channel_name": "metric", - "pipe": { - "CLASS_NAME": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe", - "ARG": { - "mode": "ACTIVE", - "site_name": "site-1", - "token": "simulate_job", - "root_url": "tcp://0:51893", - "secure_mode": false, - "workspace_dir": "xxx" + The content of config looks like: + + .. 
code-block:: json + + { + "METRICS_EXCHANGE": { + "pipe_channel_name": "metric", + "pipe": { + "CLASS_NAME": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe", + "ARG": { + "mode": "ACTIVE", + "site_name": "site-1", + "token": "simulate_job", + "root_url": "tcp://0:51893", + "secure_mode": false, + "workspace_dir": "xxx" + } + } + }, + "SITE_NAME": "site-1", + "JOB_ID": "simulate_job", + "TASK_EXCHANGE": { + "train_with_eval": true, + "exchange_format": "numpy", + "transfer_type": "DIFF", + "train_task_name": "train", + "eval_task_name": "evaluate", + "submit_model_task_name": "submit_model", + "pipe_channel_name": "task", + "pipe": { + "CLASS_NAME": "nvflare.fuel.utils.pipe.cell_pipe.CellPipe", + "ARG": { + "mode": "ACTIVE", + "site_name": "site-1", + "token": "simulate_job", + "root_url": "tcp://0:51893", + "secure_mode": false, + "workspace_dir": "xxx" + } + } + } + } - } + """ def __init__(self, config: Optional[Dict] = None): @@ -110,10 +133,10 @@ def get_pipe_args(self, section: str) -> dict: def get_pipe_class(self, section: str) -> str: return self.config[section][ConfigKey.PIPE][ConfigKey.CLASS_NAME] - def get_exchange_format(self) -> ExchangeFormat: + def get_exchange_format(self) -> str: return self.config[ConfigKey.TASK_EXCHANGE][ConfigKey.EXCHANGE_FORMAT] - def get_transfer_type(self): + def get_transfer_type(self) -> str: return self.config.get(ConfigKey.TASK_EXCHANGE, {}).get(ConfigKey.TRANSFER_TYPE, "FULL") def get_train_task(self): diff --git a/nvflare/client/decorator.py b/nvflare/client/decorator.py index 87c6762fbe..a70e0bba18 100644 --- a/nvflare/client/decorator.py +++ b/nvflare/client/decorator.py @@ -30,6 +30,23 @@ def train( _func=None, **root_kwargs, ): + """A decorator that wraps the training logic. + + Note: + FLARE will pass the model received from the server side to the first argument of the decorated method. + The return value of the decorated training method needs to be an FLModel object. + + Usage: + + .. code-block:: python + + @nvflare.client.train + def my_train(input_model=None, device="cuda:0"): + ... + return new_model + + """ + + def decorator(train_fn): + @functools.wraps(train_fn) + def wrapper(*args, **kwargs): @@ -65,6 +82,25 @@ def evaluate( _func=None, **root_kwargs, ): + """A decorator that wraps the evaluation logic. + + Note: + FLARE will pass the model received from the server side to the first argument of the decorated method. + The return value of the decorated method needs to be a float metric. + The decorated method needs to be run BEFORE the training method, + so the metrics will be sent along with the trained output model. + + Usage: + + .. code-block:: python + + @nvflare.client.evaluate + def my_eval(input_model, device="cuda:0"): + ...
+ return metrics + + """ + def decorator(eval_fn): @functools.wraps(eval_fn) def wrapper(*args, **kwargs): diff --git a/nvflare/client/flare_agent.py b/nvflare/client/flare_agent.py index a9f9c7c890..763449d824 100644 --- a/nvflare/client/flare_agent.py +++ b/nvflare/client/flare_agent.py @@ -16,7 +16,7 @@ import threading import time import traceback -from typing import Optional +from typing import Any, Optional from nvflare.apis.dxo import DXO, MetaKey, from_shareable from nvflare.apis.fl_constant import FLContextKey @@ -71,17 +71,34 @@ def __init__( max_resends=None, submit_result_timeout=30.0, metric_pipe=None, - task_channel_name=PipeChannelName.TASK, - metric_channel_name=PipeChannelName.METRIC, + task_channel_name: str = PipeChannelName.TASK, + metric_channel_name: str = PipeChannelName.METRIC, close_pipe: bool = True, close_metric_pipe: bool = True, ): - """Constructor of Flare Agent. The agent is responsible for communicating with the Flare Client Job cell (CJ) + """Constructor of Flare Agent. + + The agent is responsible for communicating with the Flare Client Job cell (CJ) to get task and to submit task result. Args: - pipe: pipe for communication - submit_result_timeout: when submitting task result, how long to wait for response from the CJ + pipe (Pipe): pipe for task communication. + read_interval (float): how often to read from the pipe. Defaults to 0.1. + heartbeat_interval (float): how often to send a heartbeat to the peer. Defaults to 5.0. + heartbeat_timeout (float): how long to wait for a heartbeat from the peer before treating the peer as dead, + 0 means DO NOT check for heartbeat. Defaults to 30.0. + resend_interval (float): how often to resend a message if failing to send. None means no resend. + Note that if the pipe does not support resending, then no resend. Defaults to 2.0. + max_resends (int, optional): max number of resend. None means no limit. Defaults to None. + submit_result_timeout (float): when submitting task result, + how long to wait for response from the CJ. Defaults to 30.0. + metric_pipe (Pipe, optional): pipe for metric communication. Defaults to None. + task_channel_name (str): channel name for task. Defaults to ``task``. + metric_channel_name (str): channel name for metric. Defaults to ``metric``. + close_pipe (bool): whether to close the task pipe when stopped. Defaults to True. + Usually for ``FilePipe`` we set to False, for ``CellPipe`` we set to True. + close_metric_pipe (bool): whether to close the metric pipe when stopped. Defaults to True. + Usually for ``FilePipe`` we set to False, for ``CellPipe`` we set to True. """ flare_decomposers.register() common_decomposers.register() @@ -119,7 +136,9 @@ def __init__( self._close_metric_pipe = close_metric_pipe def start(self): - """Start the agent. This method must be called to enable CJ/Agent communication. + """Start the agent. + + This method must be called to enable CJ/Agent communication. Returns: None @@ -141,7 +160,9 @@ def _status_cb(self, msg: Message, pipe_handler: PipeHandler, channel): pipe_handler.stop(self._close_pipe) def stop(self): - """Stop the agent. After this is called, there will be no more communications between CJ and agent. + """Stop the agent. + + After this is called, there will be no more communications between CJ and agent. 
Returns: None @@ -152,13 +173,17 @@ def stop(self): if self.metric_pipe_handler: self.metric_pipe_handler.stop(self._close_metric_pipe) - def shareable_to_task_data(self, shareable: Shareable): + def shareable_to_task_data(self, shareable: Shareable) -> Any: """Convert the Shareable object received from the TaskExchanger to an app-friendly format. + Subclass can override this method to convert to its own app-friendly task data. By default, we convert to DXO object. Args: shareable: the Shareable object received from the TaskExchanger. + + Returns: + task data. """ try: dxo = from_shareable(shareable) @@ -175,7 +200,7 @@ def shareable_to_task_data(self, shareable: Shareable): self.logger.error(f"failed to extract DXO from shareable object: {ex}") raise ex - def get_task(self, timeout: Optional[float] = None): + def get_task(self, timeout: Optional[float] = None) -> Optional[Task]: """Get a task from FLARE. This is a blocking call. Args: @@ -184,6 +209,7 @@ def get_task(self, timeout: Optional[float] = None): Returns: None if no task is available before timeout; or a Task object if task is available. + Raises: AgentClosed exception if the agent has been closed before timeout. CallStateError exception if the call has not been made properly. @@ -229,6 +255,7 @@ def get_task(self, timeout: Optional[float] = None): def submit_result(self, result, rc=RC.OK) -> bool: """Submit the result of the current task. + This is a blocking call. The agent will try to send the result to flare site until it is successfully sent or the task is aborted or the agent is closed. @@ -236,8 +263,11 @@ def submit_result(self, result, rc=RC.OK) -> bool: result: result to be submitted rc: return code - Returns: whether the result is submitted successfully - Raises: the CallStateError exception if the submit_result call is not made properly. + Returns: + whether the result is submitted successfully + + Raises: + the CallStateError exception if the submit_result call is not made properly. Notes: the application must only make this call after the received task is processed. The call can only be made a single time regardless whether the submission is successful. @@ -261,14 +291,15 @@ def submit_result(self, result, rc=RC.OK) -> bool: return result - def task_result_to_shareable(self, result, rc) -> Shareable: + def task_result_to_shareable(self, result: Any, rc) -> Shareable: """Convert the result object to Shareable object before sending back to the TaskExchanger. + Subclass can override this method to convert its app-friendly result type to Shareable. By default, we expect the result to be DXO object. Args: - result: the result object to be converted to Shareable. If None, an empty Shareable object will be - created with the rc only. + result: the result object to be converted to Shareable. + If None, an empty Shareable object will be created with the rc only. rc: the return code. Returns: @@ -289,7 +320,15 @@ def _do_submit_result(self, current_task: _TaskContext, result, rc): reply = Message.new_reply(topic=current_task.task_name, req_msg_id=current_task.msg_id, data=result) return self.pipe_handler.send_to_peer(reply, self.submit_result_timeout) - def log(self, record: DXO): + def log(self, record: DXO) -> bool: + """Logs a metric record. + + Args: + record (DXO): A metric record. 
+
+        Returns:
+            whether the metric record is submitted successfully
+        """
         if not self.metric_pipe_handler:
             raise RuntimeError("metric pipe is not available")
@@ -313,6 +352,24 @@ def __init__(
         submit_result_timeout=30.0,
         has_metrics=False,
     ):
+        """Constructor of Flare Agent with Cell Pipe. This is a convenience class.
+
+        Args:
+            agent_id (str): unique id to guarantee the uniqueness of cell's FQCN.
+            site_name (str): name of the FLARE site
+            root_url (str): the root url of the cellnet that the pipe's cell will join
+            secure_mode (bool): whether connection to the root is secure (TLS)
+            workspace_dir (str): the directory that contains startup for joining the cellnet. Required only in secure mode
+            read_interval (float): how often to read from the pipe.
+            heartbeat_interval (float): how often to send a heartbeat to the peer.
+            heartbeat_timeout (float): how long to wait for a heartbeat from the peer before treating the peer as gone,
+                0 means DO NOT check for heartbeat.
+            resend_interval (float): how often to resend a message if failing to send. None means no resend.
+                Note that if the pipe does not support resending, then no resend.
+            max_resends (int, optional): max number of resends. None means no limit.
+            submit_result_timeout (float): when submitting task result, how long to wait for response from the CJ.
+            has_metrics (bool): whether there is a metric pipe.
+        """
         pipe = CellPipe(
             mode=Mode.ACTIVE,
             token=agent_id,
diff --git a/nvflare/client/lightning/__init__.py b/nvflare/client/lightning/__init__.py
index a3f1d5acbb..395e6728ab 100644
--- a/nvflare/client/lightning/__init__.py
+++ b/nvflare/client/lightning/__init__.py
@@ -12,6 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+"""PyTorch Lightning API integration module for simplified imports.
+
+Usage:
+    from nvflare.client.lightning import patch
+
+For detailed information on usage and the API, refer to:
+    :mod:`nvflare.app_opt.lightning.api`
+
+"""
+
 from nvflare.fuel.utils.import_utils import optional_import

 pytorch_lightning, ok = optional_import(module="pytorch_lightning")
diff --git a/nvflare/client/model_registry.py b/nvflare/client/model_registry.py
index da3869bcbb..92c24eb1da 100644
--- a/nvflare/client/model_registry.py
+++ b/nvflare/client/model_registry.py
@@ -23,7 +23,7 @@

 class ModelRegistry(TaskRegistry):
-    """This class is used to remember attributes that need to share for a user code.
+    """This class is used to remember attributes that need to be shared by user code.

     For example, after "global_evaluate" we should remember the "metrics" value.
     And set that into the model that we want to submit after "train".
@@ -39,6 +39,17 @@ def __init__(self, config: ClientConfig, rank: Optional[str] = None, flare_agent
         self.metrics = None

     def get_model(self, timeout: Optional[float] = None) -> Optional[FLModel]:
+        """Gets a model from the FLARE client.
+
+        This method gets the task from the FLARE client and extracts the `task.data` out.
+
+        Args:
+            timeout (float, optional): If specified, this call is blocked only for the specified amount of time.
+                If not specified, this call is blocked forever until a task has been received or the agent has been closed.
+
+        Returns:
+            None if flare agent is None; or an FLModel object if a task is available within timeout.
+        """
         task = self.get_task(timeout)
         if task is not None and task.data is not None:
             if not isinstance(task.data, FLModel):
@@ -47,6 +58,11 @@ def get_model(self, timeout: Optional[float] = None) -> Optional[FLModel]:
         return None

     def submit_model(self, model: FLModel) -> None:
+        """Submits a model to the FLARE client.
+
+        Args:
+            model (FLModel): Trained local model to be submitted.
+        """
         if not self.flare_agent:
             return None
         if self.config.get_transfer_type() == "DIFF":
@@ -74,5 +90,6 @@ def submit_model(self, model: FLModel) -> None:
         self.submit_task(model)

     def clear(self):
+        """Clears the model registry cache."""
         super().clear()
         self.metrics = None
diff --git a/nvflare/client/task_registry.py b/nvflare/client/task_registry.py
index 7e4dac2dbd..88f556d63c 100644
--- a/nvflare/client/task_registry.py
+++ b/nvflare/client/task_registry.py
@@ -20,7 +20,7 @@

 class TaskRegistry:
-    """This class is used to remember attributes that need to share for a user code."""
+    """This class is used to remember attributes that need to be shared by user code."""

     def __init__(self, config: ClientConfig, rank: Optional[str] = None, flare_agent: Optional[FlareAgent] = None):
         self.flare_agent = flare_agent
@@ -41,6 +41,9 @@ def _receive(self, timeout: Optional[float] = None):

         task = self.flare_agent.get_task(timeout)

+        if task is None:
+            raise RuntimeError(f"no task received within timeout: {timeout}")
+
         if task.data is None:
             raise RuntimeError("no received task.data")

@@ -48,24 +51,58 @@
         self.task_name = task.task_name
         self.cache_loaded = True

-    def set_task_name(self, task_name: str):
+    def set_task_name(self, task_name: str) -> None:
+        """Sets the current task name.
+
+        This method is only used in the multiprocess scenario with the lightning API.
+        Non-rank-0 processes do not get tasks from the FLARE side, so they rely on
+        the rank 0 process to tell them the current task name and use this method to set it.
+
+        Args:
+            task_name (str): current task name
+        """
         self.task_name = task_name

     def get_task(self, timeout: Optional[float] = None) -> Optional[Task]:
+        """Gets the cached received task.
+
+        Args:
+            timeout (float, optional): If specified, this call is blocked only for the specified amount of time.
+                If not specified, this call is blocked forever until a task has been received or the agent has been closed.
+
+        Returns:
+            None if flare agent is None; or a Task object if a task is available within timeout.
+        """
         if not self.cache_loaded:
             self._receive(timeout)
         return self.received_task

     def get_sys_info(self) -> Dict:
+        """Gets NVFlare system information.
+
+        Returns:
+            A dict of system information.
+        """
         return self.sys_info

-    def submit_task(self, data: Any, return_code: str = RC.OK) -> None:
+    def submit_task(self, data: Any, return_code: str = RC.OK) -> bool:
+        """Submits the result of the current task.
+ + Args: + data: task result + return_code (str): return code of the task execution + + Returns: + whether the result is submitted successfully + """ if not self.flare_agent or not self.task_name or self.received_task is None: - return None + return False - self.flare_agent.submit_result(result=data, rc=return_code) + return self.flare_agent.submit_result(result=data, rc=return_code) - def clear(self): + def clear(self) -> None: + """Clears the cached received task.""" self.received_task = None self.cache_loaded = False diff --git a/nvflare/client/tracking.py b/nvflare/client/tracking.py index 2a8cef83a7..c04b797037 100644 --- a/nvflare/client/tracking.py +++ b/nvflare/client/tracking.py @@ -21,7 +21,12 @@ class SummaryWriter: - """Mimics Tensorboard apis.""" + """SummaryWriter mimics the usage of Tensorboard's SummaryWriter. + + Users can replace the import of Tensorboard's SummaryWriter with FLARE's SummaryWriter. + They would then use SummaryWriter the same as before. + SummaryWriter will send log records to the FLARE system. + """ def add_scalar(self, tag: str, scalar: float, global_step: Optional[int] = None, **kwargs): """Sends a scalar. @@ -61,6 +66,13 @@ def add_scalars(self, tag: str, scalars: dict, global_step: Optional[int] = None class WandBWriter: + """WandBWriter mimics the usage of weights and biases. + + Users can replace the import of wandb with FLARE's WandBWriter. + They would then use WandBWriter the same as they would use wandb. + WandBWriter will send log records to the FLARE system. + """ + def log(self, metrics: Dict[str, float], step: Optional[int] = None): """Log multiple metrics for the current run. @@ -69,7 +81,7 @@ def log(self, metrics: Dict[str, float], step: Optional[int] = None): step (int, optional): A single integer step at which to log the specified Metrics. """ log( - tag="metrics", + key="metrics", value=metrics, data_type=AnalyticsDataType.METRICS, global_step=step, @@ -78,11 +90,11 @@ def log(self, metrics: Dict[str, float], step: Optional[int] = None): class MLflowWriter: - """MLflowWriter mimics the usage of mlflow. + """MLflowWriter mimics the usage of MLflow. - Users can replace the import of mlflow with MLflowWriter. They would then use - MLflowWriter the same as they would use mlflow. MLflowWriter will send log records to - the receiver. + Users can replace the import of MLflow with FLARE's MLflowWriter. + They would then use MLflowWriter the same as they would use MLflow. + MLflowWriter will send log records to the FLARE system. 
""" def log_param(self, key: str, value: any) -> None: diff --git a/nvflare/dashboard/application/blob.py b/nvflare/dashboard/application/blob.py index 389b5bb4c6..174e2135bb 100644 --- a/nvflare/dashboard/application/blob.py +++ b/nvflare/dashboard/application/blob.py @@ -25,24 +25,17 @@ lighter_folder = os.path.dirname(utils.__file__) template = utils.load_yaml(os.path.join(lighter_folder, "impl", "master_template.yml")) - - -def get_csp_template(csp, participant, template): - return template[f"{csp}_start_{participant}_sh"] +supported_csps = ["aws", "azure"] +for csp in supported_csps: + csp_template_file = os.path.join(lighter_folder, "impl", f"{csp}_template.yml") + if os.path.exists(csp_template_file): + template.update(utils.load_yaml(csp_template_file)) def get_csp_start_script_name(csp): return f"{csp}_start.sh" -def _write(file_full_path, content, mode, exe=False): - mode = mode + "w" - with open(file_full_path, mode) as f: - f.write(content) - if exe: - os.chmod(file_full_path, 0o755) - - def gen_overseer(key): project = Project.query.first() entity = Entity(project.overseer) @@ -54,21 +47,19 @@ def gen_overseer(key): dest_dir = os.path.join(overseer_dir, "startup") os.mkdir(overseer_dir) os.mkdir(dest_dir) - _write( + utils._write( os.path.join(dest_dir, "start.sh"), template["start_ovsr_sh"], "t", exe=True, ) - _write( + utils._write( os.path.join(dest_dir, "gunicorn.conf.py"), utils.sh_replace(template["gunicorn_conf_py"], {"port": "8443"}), "t", exe=False, ) - _write(os.path.join(dest_dir, "overseer.crt"), cert_pair.ser_cert, "b", exe=False) - _write(os.path.join(dest_dir, "overseer.key"), cert_pair.ser_pri_key, "b", exe=False) - _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False) + utils._write_pki(type="overseer", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert) run_args = ["zip", "-rq", "-P", key, "tmp.zip", "."] subprocess.run(run_args, cwd=tmp_dir) fileobj = io.BytesIO() @@ -121,6 +112,8 @@ def gen_server(key, first_server=True): "ha_mode": "true" if project.ha_mode else "false", "docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare", "org_name": "", + "type": "server", + "cln_uid": "", } tplt = tplt_utils.Template(template) with tempfile.TemporaryDirectory() as tmp_dir: @@ -128,82 +121,33 @@ def gen_server(key, first_server=True): dest_dir = os.path.join(server_dir, "startup") os.mkdir(server_dir) os.mkdir(dest_dir) - _write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t") - _write( - os.path.join(dest_dir, "docker.sh"), - utils.sh_replace(template["docker_svr_sh"], replacement_dict), - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, "start.sh"), - utils.sh_replace(template["start_svr_sh"], replacement_dict), - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, "sub_start.sh"), - utils.sh_replace(template["sub_start_svr_sh"], replacement_dict), - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, "stop_fl.sh"), - template["stop_fl_sh"], - "t", - exe=True, + utils._write_common( + type="server", + dest_dir=dest_dir, + template=template, + tplt=tplt, + replacement_dict=replacement_dict, + config=config, ) - _write(os.path.join(dest_dir, "server.crt"), cert_pair.ser_cert, "b", exe=False) - _write(os.path.join(dest_dir, "server.key"), cert_pair.ser_pri_key, "b", exe=False) - _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False) + utils._write_pki(type="server", dest_dir=dest_dir, cert_pair=cert_pair, 
root_cert=project.root_cert) if not project.ha_mode: - _write( - os.path.join(dest_dir, get_csp_start_script_name("azure")), - utils.sh_replace( - tplt.get_cloud_script_header() + get_csp_template("azure", "svr", template), - {"server_name": entity.name, "ORG": ""}, - ), - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, get_csp_start_script_name("aws")), - utils.sh_replace( - tplt.get_cloud_script_header() + get_csp_template("aws", "svr", template), - {"server_name": entity.name, "ORG": ""}, - ), - "t", - exe=True, - ) + for csp in supported_csps: + utils._write( + os.path.join(dest_dir, get_csp_start_script_name(csp)), + tplt.get_start_sh(csp=csp, type="server", entity=entity), + "t", + exe=True, + ) signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key)) json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt")) # local folder creation dest_dir = os.path.join(server_dir, "local") os.mkdir(dest_dir) - _write( - os.path.join(dest_dir, "log.config.default"), - template["log_config"], - "t", - ) - _write( - os.path.join(dest_dir, "resources.json.default"), - template["local_server_resources"], - "t", - ) - _write( - os.path.join(dest_dir, "privacy.json.sample"), - template["sample_privacy"], - "t", - ) - _write( - os.path.join(dest_dir, "authorization.json.default"), - template["default_authz"], - "t", - ) + utils._write_local(type="server", dest_dir=dest_dir, template=template) # workspace folder file - _write( + utils._write( os.path.join(server_dir, "readme.txt"), template["readme_fs"], "t", @@ -233,6 +177,8 @@ def gen_client(key, id): "config_folder": "config", "docker_image": project.app_location.split(" ")[-1] if project.app_location else "nvflare/nvflare", "org_name": entity.org, + "type": "client", + "cln_uid": f"uid={entity.name}", } if project.ha_mode: overseer_agent = {"path": "nvflare.ha.overseer_agent.HttpOverseerAgent"} @@ -254,85 +200,34 @@ def gen_client(key, id): os.mkdir(client_dir) os.mkdir(dest_dir) - _write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t") - _write( - os.path.join(dest_dir, "docker.sh"), - utils.sh_replace(template["docker_cln_sh"], replacement_dict), - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, "start.sh"), - template["start_cln_sh"], - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, "sub_start.sh"), - utils.sh_replace(template["sub_start_cln_sh"], replacement_dict), - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, "stop_fl.sh"), - template["stop_fl_sh"], - "t", - exe=True, - ) - _write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False) - _write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False) - _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False) - _write( - os.path.join(dest_dir, get_csp_start_script_name("azure")), - utils.sh_replace( - tplt.get_cloud_script_header() + get_csp_template("azure", "cln", template), - {"SITE": entity.name, "ORG": entity.org}, - ), - "t", - exe=True, - ) - _write( - os.path.join(dest_dir, get_csp_start_script_name("aws")), - utils.sh_replace( - tplt.get_cloud_script_header() + get_csp_template("aws", "cln", template), - {"SITE": entity.name, "ORG": entity.org}, - ), - "t", - exe=True, + utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert) + utils._write_common( + type="client", + dest_dir=dest_dir, + template=template, + tplt=tplt, + replacement_dict=replacement_dict, + config=config, ) + + for 
csp in supported_csps: + utils._write( + os.path.join(dest_dir, get_csp_start_script_name(csp)), + tplt.get_start_sh(csp=csp, type="client", entity=entity), + "t", + exe=True, + ) + signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key)) json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt")) # local folder creation dest_dir = os.path.join(client_dir, "local") os.mkdir(dest_dir) - _write( - os.path.join(dest_dir, "log.config.default"), - template["log_config"], - "t", - ) - resources = json.loads(template["local_client_resources"]) - for component in resources["components"]: - if "nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager" == component["path"]: - component["args"] = json.loads(client.capacity.capacity) - break - _write( - os.path.join(dest_dir, "resources.json.default"), - json.dumps(resources, indent=2), - "t", - ) - _write( - os.path.join(dest_dir, "privacy.json.sample"), - template["sample_privacy"], - "t", - ) - _write( - os.path.join(dest_dir, "authorization.json.default"), - template["default_authz"], - "t", - ) + utils._write_local(type="client", dest_dir=dest_dir, template=template, capacity=client.capacity.capacity) + # workspace folder file - _write( + utils._write( os.path.join(client_dir, "readme.txt"), template["readme_fc"], "t", @@ -378,16 +273,14 @@ def gen_user(key, id): os.mkdir(user_dir) os.mkdir(dest_dir) - _write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t") - _write( + utils._write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t") + utils._write( os.path.join(dest_dir, "fl_admin.sh"), utils.sh_replace(template["fl_admin_sh"], replacement_dict), "t", exe=True, ) - _write(os.path.join(dest_dir, "client.crt"), cert_pair.ser_cert, "b", exe=False) - _write(os.path.join(dest_dir, "client.key"), cert_pair.ser_pri_key, "b", exe=False) - _write(os.path.join(dest_dir, "rootCA.pem"), project.root_cert, "b", exe=False) + utils._write_pki(type="client", dest_dir=dest_dir, cert_pair=cert_pair, root_cert=project.root_cert) signatures = utils.sign_all(dest_dir, deserialize_ca_key(project.root_key)) json.dump(signatures, open(os.path.join(dest_dir, "signature.json"), "wt")) @@ -396,12 +289,12 @@ def gen_user(key, id): os.mkdir(dest_dir) # workspace folder file - _write( + utils._write( os.path.join(user_dir, "readme.txt"), template["readme_am"], "t", ) - _write( + utils._write( os.path.join(user_dir, "system_info.ipynb"), utils.sh_replace(template["adm_notebook"], replacement_dict), "t", diff --git a/nvflare/dashboard/cli.py b/nvflare/dashboard/cli.py index a31409b545..58e45dbe10 100644 --- a/nvflare/dashboard/cli.py +++ b/nvflare/dashboard/cli.py @@ -21,7 +21,6 @@ import docker import nvflare from nvflare.apis.utils.format_check import name_check -from nvflare.dashboard.application.blob import _write from nvflare.lighter import tplt_utils, utils supported_csp = ("azure", "aws") @@ -146,7 +145,7 @@ def cloud(args): dsb_start = template[f"{csp}_start_dsb_sh"] version = nvflare.__version__ replacement_dict = {"NVFLARE": f"nvflare=={version}", "START_OPT": f"-i {args.image}" if args.image else ""} - _write( + utils._write( dest, utils.sh_replace(tplt.get_cloud_script_header() + dsb_start, replacement_dict), "t", diff --git a/nvflare/fuel/data_event/__init__.py b/nvflare/fuel/data_event/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/nvflare/fuel/data_event/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA 
CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import threading
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Callable, List
+
+from nvflare.fuel.data_event.pub_sub import EventPubSub
+
+
+class DataBus(EventPubSub):
+    """
+    Singleton class for a simple data bus implementation.
+
+    This class allows components to subscribe to topics, publish messages to topics,
+    and store/retrieve messages associated with specific keys and topics.
+    """
+
+    _instance = None
+    _lock = threading.Lock()
+
+    def __new__(cls) -> "DataBus":
+        """
+        Create a new instance of the DataBus class.
+
+        This method ensures that only one instance of the class is created
+        (singleton pattern), so the same DataBus is shared across the process.
+        """
+        with cls._lock:
+            if not cls._instance:
+                cls._instance = super(DataBus, cls).__new__(cls)
+                cls._instance.subscribers = {}
+                cls._instance.data_store = {}
+            return cls._instance
+
+    def subscribe(self, topics: List[str], callback: Callable[[str, Any, "DataBus"], None]) -> None:
+        """
+        Subscribe a callback function to one or more topics.
+
+        Args:
+            topics (List[str]): A list of topics to subscribe to.
+            callback (Callable): The callback function to be called when messages are published to the subscribed topics.
+        """
+
+        if not topics:
+            raise ValueError("topics must be non-empty")
+
+        for topic in topics:
+            if topic.isspace():
+                raise ValueError(f"topics {topics} contains a whitespace-only topic")
+
+            with self._lock:
+                if topic not in self.subscribers:
+                    self.subscribers[topic] = []
+                self.subscribers[topic].append(callback)
+
+    def publish(self, topics: List[str], datum: Any) -> None:
+        """
+        Publish data to one or more topics, notifying all subscribed callbacks.
+
+        Args:
+            topics (List[str]): A list of topics to publish the data to.
+            datum (Any): The data to be published to the specified topics.
+        """
+        if topics:
+            for topic in topics:
+                if topic in self.subscribers:
+                    with self._lock:
+                        executor = ThreadPoolExecutor(max_workers=len(self.subscribers[topic]))
+                        for callback in self.subscribers[topic]:
+                            executor.submit(callback, topic, datum, self)
+                        executor.shutdown()
+
+    def put_data(self, key: Any, datum: Any) -> None:
+        """
+        Store data associated with a key.
+
+        Args:
+            key (Any): The key to associate with the stored data.
+            datum (Any): The data to be stored.
+        """
+        with self._lock:
+            self.data_store[key] = datum
+
+    def get_data(self, key: Any) -> Any:
+        """
+        Retrieve the stored data associated with a key.
+
+        Args:
+            key (Any): The key associated with the stored data.
+
+        Returns:
+            Any: The stored datum if found, or None if not found.
+        """
+        return self.data_store.get(key)
diff --git a/nvflare/fuel/data_event/event_manager.py b/nvflare/fuel/data_event/event_manager.py
new file mode 100644
index 0000000000..6421f8bf4e
--- /dev/null
+++ b/nvflare/fuel/data_event/event_manager.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+from nvflare.fuel.data_event.data_bus import DataBus
+
+
+class EventManager:
+    """
+    Class for managing events by interacting with a DataBus.
+
+    Args:
+        data_bus (DataBus): An instance of the DataBus class used for event communication.
+    """
+
+    def __init__(self, data_bus: "DataBus"):
+        """
+        Initialize the EventManager with a DataBus instance.
+
+        Args:
+            data_bus (DataBus): An instance of the DataBus class used for event communication.
+        """
+        self.data_bus = data_bus
+
+    def fire_event(self, event_name: str, event_data: Optional[Any] = None) -> None:
+        """
+        Fire an event by publishing it to the DataBus.
+
+        Args:
+            event_name (str): The name of the event to be fired.
+            event_data (Any, optional): Additional data associated with the event (default is None).
+        """
+        self.data_bus.publish([event_name], event_data)
diff --git a/nvflare/fuel/data_event/pub_sub.py b/nvflare/fuel/data_event/pub_sub.py
new file mode 100644
index 0000000000..63583c8b13
--- /dev/null
+++ b/nvflare/fuel/data_event/pub_sub.py
@@ -0,0 +1,34 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Callable, List + + +class EventPubSub: + def subscribe(self, topics: List[str], callback: Callable[[str, Any, "DataBus"], None]) -> None: + """ + Subscribe a callback function to one or more topics. + + Args: + topics (List[str]): A list of topics to subscribe to. + callback (Callable): The callback function to be called when messages are published to the subscribed topics. + """ + + def publish(self, topics: List[str], datum: Any) -> None: + """ + Publish a message to one or more topics, notifying all subscribed callbacks. + + Args: + topics (List[str]): A list of topics to publish the message to. + datum (Any): The message to be published to the specified topics. + """ diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 57aeb7a9eb..9ca89a85e6 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -1159,7 +1159,18 @@ def _send_target_messages( if ep: reachable_targets[t] = ep else: - self.log_error(f"cannot send to '{t}': {err}", tm.message) + msg = Message(headers=copy.copy(tm.message.headers), payload=tm.message.payload) + msg.add_headers( + { + MessageHeaderKey.CHANNEL: tm.channel, + MessageHeaderKey.TOPIC: tm.topic, + MessageHeaderKey.FROM_CELL: self.my_info.fqcn, + MessageHeaderKey.TO_CELL: t, + MessageHeaderKey.ORIGIN: self.my_info.fqcn, + MessageHeaderKey.DESTINATION: t, + } + ) + self.log_error(f"cannot send to '{t}': {err}", msg) send_errs[t] = err for t, ep in reachable_targets.items(): @@ -1170,12 +1181,12 @@ def _send_target_messages( { MessageHeaderKey.CHANNEL: tm.channel, MessageHeaderKey.TOPIC: tm.topic, - MessageHeaderKey.ORIGIN: self.my_info.fqcn, MessageHeaderKey.FROM_CELL: self.my_info.fqcn, + MessageHeaderKey.TO_CELL: ep.name, + MessageHeaderKey.ORIGIN: self.my_info.fqcn, + MessageHeaderKey.DESTINATION: t, MessageHeaderKey.MSG_TYPE: MessageType.REQ, MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())], - MessageHeaderKey.DESTINATION: t, - MessageHeaderKey.TO_CELL: ep.name, } ) @@ -1506,11 +1517,12 @@ def send_reply(self, reply: Message, to_cell: str, for_req_ids: List[str], secur reply.add_headers( { MessageHeaderKey.FROM_CELL: self.my_info.fqcn, + MessageHeaderKey.TO_CELL: to_cell, MessageHeaderKey.ORIGIN: self.my_info.fqcn, - MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())], MessageHeaderKey.DESTINATION: to_cell, MessageHeaderKey.REQ_ID: for_req_ids, MessageHeaderKey.MSG_TYPE: MessageType.REPLY, + MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())], MessageHeaderKey.SECURE: secure, MessageHeaderKey.OPTIONAL: optional, } @@ -1926,9 +1938,9 @@ def _process_received_msg(self, endpoint: Endpoint, connection: Connection, mess MessageHeaderKey.CHANNEL: channel, MessageHeaderKey.TOPIC: topic, MessageHeaderKey.FROM_CELL: self.my_info.fqcn, + MessageHeaderKey.TO_CELL: endpoint.name, MessageHeaderKey.ORIGIN: self.my_info.fqcn, MessageHeaderKey.DESTINATION: origin, - MessageHeaderKey.TO_CELL: endpoint.name, MessageHeaderKey.REQ_ID: req_id, MessageHeaderKey.MSG_TYPE: MessageType.REPLY, MessageHeaderKey.ROUTE: [(self.my_info.fqcn, time.time())], diff --git a/nvflare/fuel/f3/drivers/aio_http_driver.py b/nvflare/fuel/f3/drivers/aio_http_driver.py index f8e0f95d5a..61c953a867 100644 --- a/nvflare/fuel/f3/drivers/aio_http_driver.py +++ b/nvflare/fuel/f3/drivers/aio_http_driver.py @@ -85,7 +85,8 @@ async def _async_send_frame(self, frame: BytesAlike): # This is to yield control. 
See bug: https://github.com/aaugustin/websockets/issues/865 await asyncio.sleep(0) except Exception as ex: - log.error(f"Error sending frame for connection {self}: {secure_format_exception(ex)}") + log.error(f"Error sending frame for connection {self}, closing: {secure_format_exception(ex)}") + self.close() class AioHttpDriver(BaseDriver): @@ -184,5 +185,11 @@ async def _handler(self, websocket): async def _read_loop(conn: WsConnection): while not conn.closing: # Reading from websocket and call receiver CB - frame = await conn.websocket.recv() - conn.process_frame(frame) + try: + frame = await conn.websocket.recv() + conn.process_frame(frame) + except ConnectionClosedOK as ex: + raise ex + except Exception as ex: + log.error(f"Exception {type(ex)} on connection {conn}: {ex}") + raise ex diff --git a/nvflare/fuel/f3/streaming/byte_receiver.py b/nvflare/fuel/f3/streaming/byte_receiver.py index 96aa9f2ff5..21244ab4aa 100644 --- a/nvflare/fuel/f3/streaming/byte_receiver.py +++ b/nvflare/fuel/f3/streaming/byte_receiver.py @@ -39,7 +39,7 @@ MAX_OUT_SEQ_CHUNKS = 16 # 1/4 of the window size ACK_INTERVAL = 1024 * 1024 * 4 -READ_TIMEOUT = 60 +READ_TIMEOUT = 300 COUNTER_NAME_RECEIVED = "received" diff --git a/nvflare/fuel/f3/streaming/byte_streamer.py b/nvflare/fuel/f3/streaming/byte_streamer.py index e06fbed2b9..437a292c14 100644 --- a/nvflare/fuel/f3/streaming/byte_streamer.py +++ b/nvflare/fuel/f3/streaming/byte_streamer.py @@ -38,7 +38,7 @@ STREAM_CHUNK_SIZE = 1024 * 1024 STREAM_WINDOW_SIZE = 16 * STREAM_CHUNK_SIZE -STREAM_ACK_WAIT = 10 +STREAM_ACK_WAIT = 60 STREAM_TYPE_BYTE = "byte" STREAM_TYPE_BLOB = "blob" diff --git a/nvflare/fuel/hci/client/api.py b/nvflare/fuel/hci/client/api.py index 0597ee7a2b..492b11a2f0 100644 --- a/nvflare/fuel/hci/client/api.py +++ b/nvflare/fuel/hci/client/api.py @@ -46,7 +46,7 @@ _CMD_TYPE_SERVER = 2 MAX_AUTO_LOGIN_TRIES = 300 -AUTO_LOGIN_INTERVAL = 1.0 +AUTO_LOGIN_INTERVAL = 1.5 class ResultKey(object): @@ -313,7 +313,7 @@ def __init__( session_timeout_interval=None, session_status_check_interval=None, auto_login_delay: int = 5, - auto_login_max_tries: int = 5, + auto_login_max_tries: int = 15, event_handlers=None, ): """API to keep certs, keys and connection information and to execute admin commands through do_command. diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index d9a75dd51c..ab95ec437e 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -157,18 +157,18 @@ def __init__( site_name: str, token: str, root_url: str = "", - secure_mode=True, + secure_mode: bool = True, workspace_dir: str = "", ): """The constructor of the CellPipe. Args: mode: passive or active mode - site_name: name of the FLARE site - token: unique id to guarantee the uniqueness of cell's FQCN. - root_url: the root url of the cellnet that the pipe's cell will join - secure_mode: whether connection to the root is secure (TLS) - workspace_dir: the directory that contains startup for joining the cellnet. Required only in secure_mode + site_name (str): name of the FLARE site + token (str): unique id to guarantee the uniqueness of cell's FQCN. + root_url (str): the root url of the cellnet that the pipe's cell will join + secure_mode (bool): whether connection to the root is secure (TLS) + workspace_dir (str): the directory that contains startup for joining the cellnet. 
Required only in secure_mode """ super().__init__(mode) self.logger = logging.getLogger(self.__class__.__name__) @@ -211,6 +211,16 @@ def set_cell_cb(self, channel_name: str): self.logger.info(f"registered CellPipe request CB for {self.channel}") def send(self, msg: Message, timeout=None) -> bool: + """Sends the specified message to the peer. + + Args: + msg: the message to be sent + timeout: if specified, number of secs to wait for the peer to read the message. + If not specified, wait indefinitely. + + Returns: + Whether the message is read by the peer. + """ with self.pipe_lock: if self.closed: raise BrokenPipeError("pipe closed") diff --git a/nvflare/fuel/utils/pipe/file_pipe.py b/nvflare/fuel/utils/pipe/file_pipe.py index 3d02df3dc9..d5d7384876 100644 --- a/nvflare/fuel/utils/pipe/file_pipe.py +++ b/nvflare/fuel/utils/pipe/file_pipe.py @@ -135,7 +135,7 @@ def clear(self): self._clear_dir(self.y_path) self._clear_dir(self.t_path) - def _monitor_file(self, file_path: str, timeout) -> bool: + def _monitor_file(self, file_path: str, timeout=None) -> bool: """Monitors the file until it's read-and-removed by peer, or timed out. If timeout, remove the file. @@ -147,8 +147,6 @@ def _monitor_file(self, file_path: str, timeout) -> bool: Returns: whether the file has been read and removed """ - if not timeout: - return False start = time.time() while True: if not self.pipe_path: @@ -156,7 +154,7 @@ def _monitor_file(self, file_path: str, timeout) -> bool: if not os.path.exists(file_path): return True - if time.time() - start > timeout: + if timeout and time.time() - start > timeout: # timed out - try to delete the file try: os.remove(file_path) @@ -247,13 +245,15 @@ def y_get(self, timeout=None): return self._get_from_dir(self.y_path, timeout) def send(self, msg: Message, timeout=None) -> bool: - """ + """Sends the specified message to the peer. Args: - msg: - timeout: + msg: the message to be sent + timeout: if specified, number of secs to wait for the peer to read the message. + If not specified, wait indefinitely. - Returns: whether the message is read by peer (if timeout is specified) + Returns: + Whether the message is read by the peer. """ if not self.pipe_path: diff --git a/nvflare/fuel/utils/pipe/pipe.py b/nvflare/fuel/utils/pipe/pipe.py index 5993896e7d..c4aeb81b3e 100644 --- a/nvflare/fuel/utils/pipe/pipe.py +++ b/nvflare/fuel/utils/pipe/pipe.py @@ -99,14 +99,15 @@ def clear(self): @abstractmethod def send(self, msg: Message, timeout=None) -> bool: - """Send the specified message to the peer. + """Sends the specified message to the peer. Args: msg: the message to be sent timeout: if specified, number of secs to wait for the peer to read the message. + If not specified, wait indefinitely. - Returns: whether the message is read by the peer. - If timeout is not specified, always return False. + Returns: + Whether the message is read by the peer. """ pass @@ -117,8 +118,10 @@ def receive(self, timeout=None) -> Union[None, Message]: Args: timeout: how long (number of seconds) to try + If not specified, return right away. - Returns: the message received; or None if no message + Returns: + the message received; or None if no message """ pass diff --git a/nvflare/fuel/utils/pipe/pipe_handler.py b/nvflare/fuel/utils/pipe/pipe_handler.py index 6efc9c67f2..4826c0bfa4 100644 --- a/nvflare/fuel/utils/pipe/pipe_handler.py +++ b/nvflare/fuel/utils/pipe/pipe_handler.py @@ -64,19 +64,18 @@ def __init__( max_resends=None, default_request_timeout=5.0, ): - """ - Constructor of the PipeHandler. 
+ """Constructor of the PipeHandler. Args: - pipe: the pipe to be monitored - read_interval: how often to read from the pipe - heartbeat_interval: how often to send a heartbeat to the peer - heartbeat_timeout: how long to wait for a heartbeat from the peer before treating the peer as gone, + pipe (Pipe): the pipe to be monitored. + read_interval (float): how often to read from the pipe. + heartbeat_interval (float): how often to send a heartbeat to the peer. + heartbeat_timeout (float): how long to wait for a heartbeat from the peer before treating the peer as gone, 0 means DO NOT check for heartbeat. - resend_interval: how often to resend a message if failing to send. None means no resend. + resend_interval (float): how often to resend a message if failing to send. None means no resend. Note that if the pipe does not support resending, then no resend. - max_resends: max number of resends. None means no limit. - default_request_timeout: default timeout for request if timeout not specified + max_resends (int, optional): max number of resends. None means no limit. + default_request_timeout (float): default timeout for request if timeout not specified. """ check_positive_number("read_interval", read_interval) check_positive_number("heartbeat_interval", heartbeat_interval) @@ -108,6 +107,9 @@ def __init__( self.peer_is_up_or_dead = threading.Event() self._pause = False self._last_heartbeat_received_time = None + self._check_interval = 0.01 + self.heartbeat_sender = threading.Thread(target=self._heartbeat) + self.heartbeat_sender.daemon = True def set_status_cb(self, cb, *args, **kwargs): """Set CB for status handling. When the peer status is changed (ABORT, END, GONE), this CB is called. @@ -209,6 +211,9 @@ def start(self): if not self.reader.is_alive(): self.reader.start() + if not self.heartbeat_sender.is_alive(): + self.heartbeat_sender.start() + def stop(self, close_pipe=True): """Stops the handler and optionally close the monitored pipe. @@ -232,7 +237,7 @@ def send_to_peer(self, msg: Message, timeout=None, abort_signal: Signal = None) Args: msg: message to be sent timeout: how long to wait for the peer to read the data. - If not specified, return False immediately. + If not specified, will use ``self.default_request_timeout``. abort_signal: Returns: @@ -286,15 +291,13 @@ def _read(self): def _try_read(self): self._last_heartbeat_received_time = time.time() - last_heartbeat_sent_time = 0.0 while not self.asked_to_stop: - now = time.time() - if self._pause: time.sleep(self.read_interval) continue msg = self.pipe.receive() + now = time.time() if msg: self._last_heartbeat_received_time = now @@ -319,13 +322,23 @@ def _try_read(self): ) break + time.sleep(self.read_interval) + self.reader = None + + def _heartbeat(self): + last_heartbeat_sent_time = 0.0 + while not self.asked_to_stop: + if self._pause: + time.sleep(self._check_interval) + continue + now = time.time() + # send heartbeat to the peer if now - last_heartbeat_sent_time > self.heartbeat_interval: self.send_to_peer(self._make_event_message(Topic.HEARTBEAT, "")) last_heartbeat_sent_time = now - time.sleep(self.read_interval) - self.reader = None + time.sleep(self._check_interval) def get_next(self) -> Optional[Message]: """Gets the next message from the message queue. 
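Note: the following is a minimal, illustrative usage sketch of the updated PipeHandler with its dedicated heartbeat sender thread. It is a hedged example, not part of the change itself: the FilePipe root path, channel name, topic, and payload below are made-up values, and only the PipeHandler, FilePipe, Mode, and Message APIs shown in the diffs above are relied upon.

.. code-block:: python

    from nvflare.fuel.utils.pipe.file_pipe import FilePipe
    from nvflare.fuel.utils.pipe.pipe import Message, Mode
    from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler

    # Hypothetical pipe root; the peer must open a PASSIVE FilePipe on the same path.
    pipe = FilePipe(Mode.ACTIVE, root_path="/tmp/pipe_root")
    pipe.open("demo_channel")

    # Heartbeats are now sent from a dedicated daemon thread, so a slow read loop
    # no longer delays them; heartbeat_timeout=0 disables peer-liveness checking.
    handler = PipeHandler(pipe, read_interval=0.1, heartbeat_interval=5.0, heartbeat_timeout=30.0)
    handler.start()

    handler.send_to_peer(Message.new_request("demo_topic", {"step": 1}))
    msg = handler.get_next()  # returns None if nothing has been received yet
    handler.stop(close_pipe=True)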
diff --git a/nvflare/lighter/dummy_project.yml b/nvflare/lighter/dummy_project.yml
index 51d8cc6379..57311da4ae 100644
--- a/nvflare/lighter/dummy_project.yml
+++ b/nvflare/lighter/dummy_project.yml
@@ -12,9 +12,10 @@ participants:
   - name: site-1
     type: client
     org: nvidia
-    # listening_host will enable creating one pair of cert/private key for this client
-    # so it can behave like a server for client api. The value must be a hostname that
-    # client api can reach via network.
+    # Specifying listening_host will enable the creation of one pair of
+    # certificate/private key for this client, allowing the client to function
+    # as a server for 3rd-party integration.
+    # The value must be a hostname that the external trainer can reach via the network.
     # listening_host: site-1-lh
   - name: site-2
     type: client
@@ -28,7 +29,10 @@ participants:
 builders:
   - path: nvflare.lighter.impl.workspace.WorkspaceBuilder
     args:
-      template_file: master_template.yml
+      template_file:
+        - master_template.yml
+        - aws_template.yml
+        - azure_template.yml
   - path: nvflare.lighter.impl.template.TemplateBuilder
   - path: nvflare.lighter.impl.static_file.StaticFileBuilder
     args:
diff --git a/nvflare/lighter/ha_project.yml b/nvflare/lighter/ha_project.yml
index 2a5fecad28..7216dcc762 100644
--- a/nvflare/lighter/ha_project.yml
+++ b/nvflare/lighter/ha_project.yml
@@ -40,7 +40,10 @@ participants:
 builders:
   - path: nvflare.lighter.impl.workspace.WorkspaceBuilder
     args:
-      template_file: master_template.yml
+      template_file:
+        - master_template.yml
+        - aws_template.yml
+        - azure_template.yml
   - path: nvflare.lighter.impl.template.TemplateBuilder
   - path: nvflare.lighter.impl.docker.DockerBuilder
     args:
diff --git a/nvflare/lighter/impl/aws_template.yml b/nvflare/lighter/impl/aws_template.yml
new file mode 100644
index 0000000000..8ba14d6f2d
--- /dev/null
+++ b/nvflare/lighter/impl/aws_template.yml
@@ -0,0 +1,261 @@
+aws_start_sh: |
+  VM_NAME=nvflare_{~~type~~}
+  SECURITY_GROUP=nvflare_{~~type~~}_sg_$RANDOM
+  DEST_FOLDER=/var/tmp/cloud
+  KEY_PAIR=NVFlare{~~type~~}KeyPair
+  KEY_FILE=${KEY_PAIR}.pem
+
+  echo "This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed."
+
+  check_binary aws "Please see https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html on how to install it on your system."
+  check_binary sshpass "Please install it first."
+  check_binary dig "Please install it first."
+  check_binary jq "Please install it first."
+
+  if [ -z ${image_name+x} ]
+  then
+    container=false
+  else
+    container=true
+  fi
+
+  if [ $container = true ]
+  then
+    AMI_IMAGE=ami-06b8d5099f3a8d79d
+    EC2_TYPE=t2.xlarge
+    REGION=us-west-2
+  else
+    AMI_IMAGE=ami-04bad3c587fe60d89
+    EC2_TYPE=t2.small
+    REGION=us-west-2
+  fi
+
+  if [ -z ${config_file+x} ]
+  then
+    useDefault=true
+  else
+    useDefault=false
+    . $config_file
+    report_status "$?" "Loading config file"
+  fi
+
+
+  if [ $useDefault = true ]
+  then
+    while true
+    do
+      prompt AMI_IMAGE "Cloud AMI image, press ENTER to accept default ${AMI_IMAGE}: "
+      prompt EC2_TYPE "Cloud EC2 type, press ENTER to accept default ${EC2_TYPE}: "
+      prompt REGION "Cloud EC2 region, press ENTER to accept default ${REGION}: "
+      prompt ans "region = ${REGION}, ami image = ${AMI_IMAGE}, EC2 type = ${EC2_TYPE}, OK? (Y/n) "
+      if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]
+      then
+        break
+      fi
+    done
+  fi
+
+  if [ $container = false ]
+  then
+    echo "If the {~~type~~} requires additional dependencies, please copy the requirements.txt to ${DIR}."
+ prompt ans "Press ENTER when it's done or no additional dependencies. " + fi + + cd $DIR/.. + # Generate key pair + + echo "Generating key pair for VM" + + aws ec2 delete-key-pair --key-name $KEY_PAIR > /dev/null 2>&1 + rm -rf $KEY_FILE + aws ec2 create-key-pair --key-name $KEY_PAIR --query 'KeyMaterial' --output text > $KEY_FILE + report_status "$?" "creating key pair" + chmod 400 $KEY_FILE + + # Generate Security Group + # Try not reusing existing security group because we have to modify it for our own need. + sg_id=$(aws ec2 create-security-group --group-name $SECURITY_GROUP --description "NVFlare security group" | jq -r .GroupId) + report_status "$?" "creating security group" + my_public_ip=$(dig +short myip.opendns.com @resolver1.opendns.com) + if [ "$?" -eq 0 ] && [[ "$my_public_ip" =~ ^(([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))\.){3}([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))$ ]] + then + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr ${my_public_ip}/32 > /tmp/sec_grp.log + else + echo "getting my public IP failed, please manually configure the inbound rule to limit SSH access" + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr 0.0.0.0/0 > /tmp/sec_grp.log + fi + {~~inbound_rule~~} + report_status "$?" "creating security group rules" + + # Start provisioning + + echo "Creating VM at region $REGION, may take a few minutes." + + aws ec2 run-instances --region $REGION --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --security-group-ids $sg_id > vm_create.json + report_status "$?" "creating VM" + instance_id=$(jq -r .Instances[0].InstanceId vm_create.json) + + aws ec2 wait instance-status-ok --instance-ids $instance_id + aws ec2 describe-instances --instance-ids $instance_id > vm_result.json + + IP_ADDRESS=$(jq -r .Reservations[0].Instances[0].PublicIpAddress vm_result.json) + + echo "VM created with IP address: ${IP_ADDRESS}" + + echo "Copying files to $VM_NAME" + DEST_SITE=ubuntu@${IP_ADDRESS} + DEST=${DEST_SITE}:${DEST_FOLDER} + echo "Destination folder is ${DEST}" + scp -q -i $KEY_FILE -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST + report_status "$?" "copying startup kits to VM" + + if [ $container = true ] + then + echo "Launching container with docker option ${DOCKER_OPTION}." + ssh -f -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} --network host ${DOCKER_OPTION} ${image_name} \ + /bin/bash -c \"python -u -m nvflare.private.fed.app.{~~type~~}.{~~type~~}_train -m ${DEST_FOLDER} \ + -s fed_{~~type~~}.json --set {~~cln_uid~~} secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/nvflare.log 2>&1 + report_status "$?" "launching container" + else + echo "Installing packages in $VM_NAME, may take a few minutes." + ssh -f -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + "pwd && wget -q https://bootstrap.pypa.io/get-pip.py && \ + python3 get-pip.py && python3 -m pip install nvflare && \ + touch ${DEST_FOLDER}/startup/requirements.txt && \ + python3 -m pip install -r ${DEST_FOLDER}/startup/requirements.txt && \ + nohup ${DEST_FOLDER}/startup/start.sh && sleep 20 && \ + exit" > /tmp/nvflare.log 2>&1 + report_status "$?" "installing packages" + fi + + echo "System was provisioned" + echo "To terminate the EC2 instance, run the following command." 
+ echo "aws ec2 terminate-instances --instance-ids ${instance_id}" + echo "Other resources provisioned" + echo "security group: ${SECURITY_GROUP}" + echo "key pair: ${KEY_PAIR}" + +aws_start_dsb_sh: | + VM_NAME=nvflare_dashboard + AMI_IMAGE=ami-04bad3c587fe60d89 + EC2_TYPE=t2.small + SECURITY_GROUP=nvflare_dashboard_sg_$RANDOM + REGION=us-west-2 + ADMIN_USERNAME=ubuntu + DEST_FOLDER=/home/${ADMIN_USERNAME} + KEY_PAIR=NVFlareDashboardKeyPair + KEY_FILE=${KEY_PAIR}.pem + + echo "This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed." + + check_binary aws "Please see https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary dig "Please install it first." + check_binary jq "Please install it first." + + echo "One initial user will be created when starting dashboard." + echo "Please enter the email address for this user." + read email + credential="${email}:$RANDOM" + + # Generate key pair + + echo "Generating key pair for VM" + + aws ec2 delete-key-pair --key-name $KEY_PAIR > /dev/null 2>&1 + rm -rf $KEY_FILE + aws ec2 create-key-pair --key-name $KEY_PAIR --query 'KeyMaterial' --output text > $KEY_FILE + report_status "$?" "creating key pair" + chmod 400 $KEY_FILE + + # Generate Security Group + + sg_id=$(aws ec2 create-security-group --group-name $SECURITY_GROUP --description "NVFlare security group" | jq -r .GroupId) + report_status "$?" "creating security group" + echo "Security group id: ${sg_id}" + my_public_ip=$(dig +short myip.opendns.com @resolver1.opendns.com) + if [ "$?" -eq 0 ] && [[ "$my_public_ip" =~ ^(([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))\.){3}([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))$ ]] + then + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr ${my_public_ip}/32 > /tmp/sec_grp.log + else + echo "getting my public IP failed, please manually configure the inbound rule to limit SSH access" + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr 0.0.0.0/0 > /tmp/sec_grp.log + fi + aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 443 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log + report_status "$?" "creating security group rules" + + # Start provisioning + + echo "Creating VM at region $REGION, may take a few minutes." + + aws ec2 run-instances --region $REGION --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --security-group-ids $sg_id > vm_create.json + report_status "$?" "creating VM" + instance_id=$(jq -r .Instances[0].InstanceId vm_create.json) + + aws ec2 wait instance-status-ok --instance-ids $instance_id + aws ec2 describe-instances --instance-ids $instance_id > vm_result.json + + IP_ADDRESS=$(jq -r .Reservations[0].Instances[0].PublicIpAddress vm_result.json) + + echo "VM created with IP address: ${IP_ADDRESS}" + + echo "Installing docker engine in $VM_NAME, may take a few minutes." 
+ DEST_SITE=${ADMIN_USERNAME}@${IP_ADDRESS} + scripts=$(cat << 'EOF' + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ + sudo mkdir -p /etc/apt/keyrings && \ + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io + EOF + ) + ssh -t -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} "$scripts" > /tmp/docker_engine.log + report_status "$?" "installing docker engine" + ssh -t -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} "sudo usermod -aG docker $ADMIN_USERNAME && exit" >> /tmp/docker_engine.log + report_status "$?" "installing docker engine" + + echo "Installing nvflare in $VM_NAME, may take a few minutes." + ssh -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + "export PATH=/home/ubuntu/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin && \ + wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && \ + python3 -m pip install {~~NVFLARE~~} && \ + mkdir -p ./cert && \ + exit" > /tmp/nvflare.json + report_status "$?" "installing nvflare" + + echo "Checking if certificate (web.crt) and private key (web.key) are available" + if [[ -f "web.crt" && -f "web.key" ]]; then + CERT_FOLDER=${DEST_SITE}:${DEST_FOLDER}/cert + echo "Cert folder is ${CERT_FOLDER}" + scp -i $KEY_FILE -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null web.{crt,key} $CERT_FOLDER + report_status "$?" "copying cert/key to VM ${CERT_FOLDER} folder" + secure=true + else + echo "No web.crt and web.key found" + secure=false + fi + + echo "Starting dashboard" + ssh -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ + "export PATH=/home/ubuntu/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin && \ + python3 -m nvflare.dashboard.cli --start -f ${DEST_FOLDER} --cred ${credential} {~~START_OPT~~}" > /tmp/dashboard.json + + echo "Dashboard url is running at IP address ${IP_ADDRESS}, listening to port 443." + if [ "$secure" = true ] + then + echo "URL is https://${IP_ADDRESS}" + else + echo "URL is http://${IP_ADDRESS}:443" + fi + echo "Note: you may need to configure DNS server with your DNS hostname and the above IP address." + echo "Project admin credential (username:password) is ${credential} ." + echo "To terminate the EC2 instance, run the following command." 
+ echo "aws ec2 terminate-instances --instance-ids ${instance_id}" + echo "Other resources provisioned" + echo "security group: ${SECURITY_GROUP}" + echo "key pair: ${KEY_PAIR}" diff --git a/nvflare/lighter/impl/azure_template.yml b/nvflare/lighter/impl/azure_template.yml new file mode 100644 index 0000000000..9c42a10cf3 --- /dev/null +++ b/nvflare/lighter/impl/azure_template.yml @@ -0,0 +1,517 @@ +azure_start_svr_header_sh: | + RESOURCE_GROUP=nvflare_rg + VM_NAME=nvflare_server + VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + NSG_NAME=nvflare_nsgs + ADMIN_USERNAME=nvflare + PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" + DEST_FOLDER=/var/tmp/cloud + NIC_NAME=${VM_NAME}VMNic + SERVER_NAME={~~server_name~~} + FL_PORT=8002 + ADMIN_PORT=8003 + + echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." + + check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary jq "Please install it first." + + self_dns=true + if [[ "$SERVER_NAME" = *".cloudapp.azure.com"* ]] + then + DNS_TAG=$(echo $SERVER_NAME | cut -d "." -f 1) + DERIVED_LOCATION=$(echo $SERVER_NAME | cut -d "." -f 2) + LOCATION=$DERIVED_LOCATION + self_dns=false + else + echo "Warning: ${SERVER_NAME} does not end with .cloudapp.azure.com." + echo "The cloud launch process will not create the domain name for you." + echo "Please use your own DNS to set the information." + LOCATION=westus2 + fi + + if [ -z ${image_name+x} ] + then + container=false + else + container=true + fi + + if [ $container = true ] + then + VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest + VM_SIZE=Standard_D8s_v3 + else + VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + fi + + if [ -z ${config_file+x} ] + then + useDefault=true + else + useDefault=false + . $config_file + report_status "$?" "Loading config file" + if [ $self_dns = false ] && [ $DERIVED_LOCATION != $LOCATION ] + then + echo "Server name implies LOCATION=${DERIVED_LOCATION} but the config file specifies LOCATION=${LOCATION}. Unable to continue provisioning." + exit 1 + fi + fi + + if [ $useDefault = true ] + then + while true + do + prompt VM_IMAGE "Cloud VM image, press ENTER to accept default ${VM_IMAGE}: " + prompt VM_SIZE "Cloud VM size, press ENTER to accept default ${VM_SIZE}: " + if [ $self_dns = true ] + then + prompt LOCATION "Cloud location, press ENTER to accept default ${LOCATION}: " + prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, location = ${LOCATION}, OK? (Y/n) " + else + prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n) " + fi + if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi + done + fi + + if [ $container = false ] + then + echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}." + prompt ans "Press ENTER when it's done or no additional dependencies. " + fi + + az login --use-device-code -o none + report_status "$?" "login" + + # Start provisioning + + if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] + then + echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" + az group create --output none --name $RESOURCE_GROUP --location $LOCATION + report_status "$?" "creating resource group" + elif [ $useDefault = true ] + then + report_status "1" "Only one NVFL server VM and its resource group is allowed. 
$RESOURCE_GROUP already exists; refusing to create a duplicate resource group"
+ else
+ echo "Reusing the existing Resource Group $RESOURCE_GROUP as requested. This script will modify the group and may not always work."
+ fi
+
+ echo "Creating Virtual Machine, will take a few minutes"
+ if [ $self_dns = true ]
+ then
+ az vm create \
+ --output json \
+ --resource-group $RESOURCE_GROUP \
+ --location $LOCATION \
+ --name $VM_NAME \
+ --image $VM_IMAGE \
+ --size $VM_SIZE \
+ --admin-username $ADMIN_USERNAME \
+ --admin-password $PASSWORD \
+ --authentication-type password \
+ --public-ip-address nvflare_server_ip \
+ --public-ip-address-allocation static \
+ --public-ip-sku Standard > /tmp/vm.json
+ else
+ az vm create \
+ --output json \
+ --resource-group $RESOURCE_GROUP \
+ --location $LOCATION \
+ --name $VM_NAME \
+ --image $VM_IMAGE \
+ --size $VM_SIZE \
+ --admin-username $ADMIN_USERNAME \
+ --admin-password $PASSWORD \
+ --authentication-type password \
+ --public-ip-address nvflare_server_ip \
+ --public-ip-address-allocation static \
+ --public-ip-sku Standard \
+ --public-ip-address-dns-name $DNS_TAG > /tmp/vm.json
+ fi
+ report_status "$?" "creating virtual machine"
+
+ IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json)
+ echo "Setting up network related configuration"
+ az network nsg create \
+ --output none \
+ --resource-group $RESOURCE_GROUP \
+ --location $LOCATION \
+ --name $NSG_NAME
+ report_status "$?" "creating network security group"
+
+ az network nsg rule create \
+ --output none \
+ --resource-group $RESOURCE_GROUP \
+ --name SSH \
+ --nsg-name $NSG_NAME \
+ --priority 1000 \
+ --protocol Tcp \
+ --destination-port-ranges 22
+ report_status "$?" "creating network security group rule for SSH"
+
+ az network nsg rule create \
+ --output none \
+ --resource-group $RESOURCE_GROUP \
+ --name FL_PORT \
+ --nsg-name $NSG_NAME \
+ --priority 1001 \
+ --protocol Tcp \
+ --destination-port-ranges $FL_PORT
+ report_status "$?" "creating network security group rule for FL port"
+
+ az network nsg rule create \
+ --output none \
+ --resource-group $RESOURCE_GROUP \
+ --name ADMIN_PORT \
+ --nsg-name $NSG_NAME \
+ --priority 1002 \
+ --protocol Tcp \
+ --destination-port-ranges $ADMIN_PORT
+ report_status "$?" "creating network security group rule for Admin port"
+
+azure_start_cln_header_sh: |
+ RESOURCE_GROUP=nvflare_client_rg_${RANDOM}_${RANDOM}
+ VM_NAME=nvflare_client
+ VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest
+ VM_SIZE=Standard_B2ms
+ NSG_NAME=nvflare_nsgc
+ ADMIN_USERNAME=nvflare
+ PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd"
+ DEST_FOLDER=/var/tmp/cloud
+ LOCATION=westus2
+ NIC_NAME=${VM_NAME}VMNic
+ echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed."
+
+ check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system."
+ check_binary sshpass "Please install it first."
+ check_binary jq "Please install it first."
+
+
+ if [ -z ${image_name+x} ]
+ then
+ container=false
+ else
+ container=true
+ fi
+
+ if [ $container = true ]
+ then
+ VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest
+ VM_SIZE=Standard_D8s_v3
+ else
+ VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest
+ VM_SIZE=Standard_B2ms
+ fi
+ if [ -z ${config_file+x} ]
+ then
+ useDefault=true
+ else
+ useDefault=false
+ . $config_file
+ report_status "$?"
"Loading config file" + fi + + if [ $useDefault = true ] + then + while true + do + prompt LOCATION "Cloud location, press ENTER to accept default ${LOCATION}: " + prompt VM_IMAGE "Cloud VM image, press ENTER to accept default ${VM_IMAGE}: " + prompt VM_SIZE "Cloud VM size, press ENTER to accept default ${VM_SIZE}: " + prompt ans "location = ${LOCATION}, VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n) " + if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi + done + fi + + if [ $container = false ] + then + echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}." + prompt ans "Press ENTER when it's done or no additional dependencies. " + fi + + az login --use-device-code -o none + report_status "$?" "login" + + # Start provisioning + + if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] + then + echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" + az group create --output none --name $RESOURCE_GROUP --location $LOCATION + report_status "$?" "creating resource group" + else + echo "Resource Group $RESOURCE_GROUP exists, will reuse it." + fi + + echo "Creating Virtual Machine, will take a few minutes" + az vm create \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $VM_NAME \ + --image $VM_IMAGE \ + --size $VM_SIZE \ + --admin-username $ADMIN_USERNAME \ + --admin-password $PASSWORD \ + --authentication-type password \ + --public-ip-sku Standard > /tmp/vm.json + report_status "$?" "creating virtual machine" + + IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) + + echo "Setting up network related configuration" + + az network nsg create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $NSG_NAME + report_status "$?" "creating network security group" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name SSH \ + --nsg-name $NSG_NAME \ + --priority 1000 \ + --protocol Tcp \ + --destination-port-ranges 22 + report_status "$?" "creating network security group rule for SSH" + +azure_start_common_sh: | + az network nic update \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name $NIC_NAME \ + --network-security-group $NSG_NAME + report_status "$?" "updating network interface card" + + echo "Copying files to $VM_NAME" + DEST=$ADMIN_USERNAME@${IP_ADDRESS}:$DEST_FOLDER + echo "Destination folder is ${DEST}" + cd $DIR/.. && sshpass -p $PASSWORD scp -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST + report_status "$?" "copying startup kits to VM" + + if [ $container = true ] + then + echo "Installing and lauching container in $VM_NAME, may take a few minutes." 
+ scripts=$(cat << 'EOF' + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ + sudo mkdir -p /etc/apt/keyrings && \ + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io + EOF + ) + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "$scripts" > /tmp/docker_engine.json + report_status "$?" "installing docker engine" + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "sudo usermod -aG docker $ADMIN_USERNAME" >> /tmp/docker_engine.json + report_status "$?" "Setting user group" + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} {~~docker_network~~} ${image_name} /bin/bash -c \"python -u -m nvflare.private.fed.app.{~~type~~}.{~~type~~}_train -m ${DEST_FOLDER} -s fed_{~~type~~}.json --set {~~cln_uid~~} secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/vm_create.json 2>&1 + report_status "$?" "launching container" + else + echo "Installing packages in $VM_NAME, may take a few minutes." + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "echo ${DEST_FOLDER} && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python3 -m pip install --ignore-installed nvflare && touch ${DEST_FOLDER}/startup/requirements.txt && python3 -m pip install -r ${DEST_FOLDER}/startup/requirements.txt && ${DEST_FOLDER}/startup/start.sh && sleep 20 && cat ${DEST_FOLDER}/log.txt" > /tmp/vm_create.json + report_status "$?" "installing packages" + fi + echo "System was provisioned" + echo "To delete the resource group (also delete the VM), run the following command" + echo "az group delete -n ${RESOURCE_GROUP}" + echo "To login to the VM with SSH, use ${ADMIN_USERNAME} : ${PASSWORD}" > vm_credential.txt + +azure_start_dsb_sh: | + RESOURCE_GROUP=nvflare_dashboard_rg_${RANDOM}_${RANDOM} + VM_NAME=nvflare_dashboard + VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest + VM_SIZE=Standard_B2ms + NSG_NAME=nvflare_nsgc + ADMIN_USERNAME=nvflare + PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" + DEST_FOLDER=/var/tmp/cloud + LOCATION=westus2 + NIC_NAME=${VM_NAME}VMNic + + echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." + + check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." + check_binary sshpass "Please install it first." + check_binary jq "Please install it first." + + echo "One initial user will be created when starting dashboard." + echo "Please enter the email address for this user." + read email + credential="${email}:$RANDOM" + + az login --use-device-code -o none + report_status "$?" 
"login" + + # Start provisioning + if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] + then + echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" + az group create --output none --name $RESOURCE_GROUP --location $LOCATION + report_status "$?" "creating resource group" + else + echo "Resource Group $RESOURCE_GROUP exists, will reuse it." + fi + + echo "Creating Virtual Machine, will take a few minutes" + az vm create \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $VM_NAME \ + --image $VM_IMAGE \ + --size $VM_SIZE \ + --admin-username $ADMIN_USERNAME \ + --admin-password $PASSWORD \ + --authentication-type password \ + --public-ip-sku Standard > /tmp/vm.json + report_status "$?" "creating virtual machine" + + IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) + report_status "$?" "extracting ip address" + + echo "Setting up network related configuration" + az network nsg create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --location $LOCATION \ + --name $NSG_NAME + report_status "$?" "creating network security group" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name SSH \ + --nsg-name $NSG_NAME \ + --priority 1000 \ + --protocol Tcp \ + --destination-port-ranges 22 + report_status "$?" "creating network security group rule for SSH" + + az network nsg rule create \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name HTTPS \ + --nsg-name $NSG_NAME \ + --priority 1001 \ + --protocol Tcp \ + --destination-port-ranges 443 + report_status "$?" "creating network security group rule for HTTPS" + + az network nic update \ + --output none \ + --resource-group $RESOURCE_GROUP \ + --name $NIC_NAME \ + --network-security-group $NSG_NAME + report_status "$?" "updating network interface card" + + echo "Installing docker engine in $VM_NAME, may take a few minutes." + scripts=$(cat << 'EOF' + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ + sudo mkdir -p /etc/apt/keyrings && \ + curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ + echo \ + "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ + $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ + sudo apt-get update && \ + sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io + EOF + ) + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "$scripts" > /tmp/docker_engine.json + report_status "$?" "installing docker engine" + az vm run-command invoke \ + --output json \ + --resource-group $RESOURCE_GROUP \ + --command-id RunShellScript \ + --name $VM_NAME \ + --scripts \ + "sudo usermod -aG docker $ADMIN_USERNAME" >> /tmp/docker_engine.json + report_status "$?" "installing docker engine" + + DEST_FOLDER=/home/${ADMIN_USERNAME} + echo "Installing nvflare in $VM_NAME, may take a few minutes." 
+ az vm run-command invoke \
+ --output json \
+ --resource-group $RESOURCE_GROUP \
+ --command-id RunShellScript \
+ --name $VM_NAME \
+ --scripts \
+ "echo ${DEST_FOLDER} && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python3 -m pip install --ignore-installed {~~NVFLARE~~} && mkdir -p ${DEST_FOLDER}/cert && chown -R ${ADMIN_USERNAME} ${DEST_FOLDER}" > /tmp/nvflare.json
+ report_status "$?" "installing nvflare"
+
+ echo "Checking if certificate (web.crt) and private key (web.key) are available"
+ if [[ -f "web.crt" && -f "web.key" ]]; then
+ DEST=$ADMIN_USERNAME@$IP_ADDRESS:${DEST_FOLDER}/cert
+ echo "Destination folder is ${DEST}"
+ sshpass -p $PASSWORD scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null web.{crt,key} $DEST
+ report_status "$?" "copying cert/key to VM ${DEST} folder"
+ secure=true
+ else
+ echo "No web.crt and web.key found"
+ secure=false
+ fi
+
+ echo "Starting dashboard"
+ az vm run-command invoke \
+ --output json \
+ --resource-group $RESOURCE_GROUP \
+ --command-id RunShellScript \
+ --name $VM_NAME \
+ --scripts \
+ "cd ${DEST_FOLDER} && python3 -m nvflare.dashboard.cli --start -f ${DEST_FOLDER} --cred ${credential} {~~START_OPT~~}" > /tmp/dashboard.json
+
+ # credential=$(jq -r .value[0].message /tmp/dashboard.json | grep "Project admin")
+ # echo "The VM was created with user: ${ADMIN_USERNAME} and password: ${PASSWORD}"
+ if [ "$secure" = true ]
+ then
+ echo "URL is https://${IP_ADDRESS}"
+ else
+ echo "URL is http://${IP_ADDRESS}:443"
+ fi
+ echo "Note: you may need to configure your DNS server with your DNS hostname and the above IP address."
+ echo "Project admin credential (username:password) is ${credential} ."
+ echo "To stop the dashboard, run az group delete -n ${RESOURCE_GROUP}"
+ echo "To login to the VM with SSH, use ${ADMIN_USERNAME} : ${PASSWORD}" > vm_credential.txt
diff --git a/nvflare/lighter/impl/master_template.yml b/nvflare/lighter/impl/master_template.yml
index 5d816858c6..24342030ba 100644
--- a/nvflare/lighter/impl/master_template.yml
+++ b/nvflare/lighter/impl/master_template.yml
@@ -417,7 +417,7 @@ stop_fl_sh: |
 ;;
 esac

-sub_start_cln_sh: |
+sub_start_sh: |
 #!/usr/bin/env bash
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 echo "WORKSPACE set to $DIR/.."
@@ -440,7 +440,7 @@ sub_start_cln_sh: |
 exit
 fi
 lst=$SECONDS
- ((python3 -u -m nvflare.private.fed.app.client.client_train -m $DIR/.. -s fed_client.json --set secure_train=true uid={~~client_name~~} org={~~org_name~~} config_folder={~~config_folder~~} 2>&1 & echo $! >&3 ) 3>$DIR/../pid.fl )
+ ((python3 -u -m nvflare.private.fed.app.{~~type~~}.{~~type~~}_train -m $DIR/.. -s fed_{~~type~~}.json --set secure_train=true {~~cln_uid~~} org={~~org_name~~} config_folder={~~config_folder~~} 2>&1 & echo $! >&3 ) 3>$DIR/../pid.fl )
 pid=`cat $DIR/../pid.fl`
 echo "new pid ${pid}"
 }
@@ -506,93 +506,6 @@ sub_start_cln_sh: |
 rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl $DIR/../daemon_pid.fl

-sub_start_svr_sh: |
- #!/usr/bin/env bash
- DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
- echo "WORKSPACE set to $DIR/.."
- mkdir -p $DIR/../transfer
-
- SECONDS=0
- lst=-400
- restart_count=0
- start_fl() {
- if [[ $(( $SECONDS - $lst )) -lt 300 ]]; then
- ((restart_count++))
- else
- restart_count=0
- fi
- if [[ $(($SECONDS - $lst )) -lt 300 && $restart_count -ge 5 ]]; then
- echo "System is in trouble and unable to start the task!!!!!"
- rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl $DIR/../daemon_pid.fl - exit - fi - lst=$SECONDS - ((python3 -u -m nvflare.private.fed.app.server.server_train -m $DIR/.. -s fed_server.json --set secure_train=true org={~~org_name~~} config_folder={~~config_folder~~} 2>&1 & echo $! >&3 ) 3>$DIR/../pid.fl ) - pid=`cat $DIR/../pid.fl` - echo "new pid ${pid}" - } - - stop_fl() { - if [[ ! -f "$DIR/../pid.fl" ]]; then - echo "No pid.fl. No need to kill process." - return - fi - pid=`cat $DIR/../pid.fl` - sleep 5 - kill -0 ${pid} 2> /dev/null 1>&2 - if [[ $? -ne 0 ]]; then - echo "Process already terminated" - return - fi - kill -9 $pid - rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl - } - - if [[ -f "$DIR/../daemon_pid.fl" ]]; then - dpid=`cat $DIR/../daemon_pid.fl` - kill -0 ${dpid} 2> /dev/null 1>&2 - if [[ $? -eq 0 ]]; then - echo "There seems to be one instance, pid=$dpid, running." - echo "If you are sure it's not the case, please kill process $dpid and then remove daemon_pid.fl in $DIR/.." - exit - fi - rm -f $DIR/../daemon_pid.fl - fi - - echo $BASHPID > $DIR/../daemon_pid.fl - - while true - do - sleep 5 - if [[ ! -f "$DIR/../pid.fl" ]]; then - echo "start fl because of no pid.fl" - start_fl - continue - fi - pid=`cat $DIR/../pid.fl` - kill -0 ${pid} 2> /dev/null 1>&2 - if [[ $? -ne 0 ]]; then - if [[ -f "$DIR/../shutdown.fl" ]]; then - echo "Gracefully shutdown." - break - fi - echo "start fl because process of ${pid} does not exist" - start_fl - continue - fi - if [[ -f "$DIR/../shutdown.fl" ]]; then - echo "About to shutdown." - stop_fl - break - fi - if [[ -f "$DIR/../restart.fl" ]]; then - echo "About to restart." - stop_fl - fi - done - - rm -f $DIR/../pid.fl $DIR/../shutdown.fl $DIR/../restart.fl $DIR/../daemon_pid.fl - docker_cln_sh: | #!/usr/bin/env bash DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" @@ -910,589 +823,6 @@ cloud_script_header: | shift done -azure_start_svr_sh: | - RESOURCE_GROUP=nvflare_rg - VM_NAME=nvflare_server - VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest - VM_SIZE=Standard_B2ms - NSG_NAME=nvflare_nsgs - ADMIN_USERNAME=nvflare - PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" - DEST_FOLDER=/var/tmp/cloud - NIC_NAME=${VM_NAME}VMNic - SERVER_NAME={~~server_name~~} - FL_PORT=8002 - ADMIN_PORT=8003 - - echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." - - check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." - check_binary sshpass "Please install it first." - check_binary jq "Please install it first." - - self_dns=true - if [[ "$SERVER_NAME" = *".cloudapp.azure.com"* ]] - then - DNS_TAG=$(echo $SERVER_NAME | cut -d "." -f 1) - DERIVED_LOCATION=$(echo $SERVER_NAME | cut -d "." -f 2) - LOCATION=$DERIVED_LOCATION - self_dns=false - else - echo "Warning: ${SERVER_NAME} does not end with .cloudapp.azure.com." - echo "The cloud launch process will not create the domain name for you." - echo "Please use your own DNS to set the information." - LOCATION=westus2 - fi - - if [ -z ${image_name+x} ] - then - container=false - else - container=true - fi - - if [ $container = true ] - then - VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest - VM_SIZE=Standard_D8s_v3 - else - VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest - VM_SIZE=Standard_B2ms - fi - - if [ -z ${config_file+x} ] - then - useDefault=true - else - useDefault=false - . 
$config_file - report_status "$?" "Loading config file" - if [ $self_dns = false ] && [ $DERIVED_LOCATION != $LOCATION ] - then - echo "Server name implies LOCATION=${DERIVED_LOCATION} but the config file specifies LOCATION=${LOCATION}. Unable to continue provisioning." - exit 1 - fi - fi - - if [ $useDefault = true ] - then - while true - do - prompt VM_IMAGE "Cloud VM image, press ENTER to accept default ${VM_IMAGE}: " - prompt VM_SIZE "Cloud VM size, press ENTER to accept default ${VM_SIZE}: " - if [ $self_dns = true ] - then - prompt LOCATION "Cloud location, press ENTER to accept default ${LOCATION}: " - prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, location = ${LOCATION}, OK? (Y/n) " - else - prompt ans "VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n) " - fi - if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi - done - fi - - if [ $container = false ] - then - echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}." - prompt ans "Press ENTER when it's done or no additional dependencies. " - fi - - az login --use-device-code -o none - report_status "$?" "login" - - # Start provisioning - - if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] - then - echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" - az group create --output none --name $RESOURCE_GROUP --location $LOCATION - report_status "$?" "creating resource group" - elif [ $useDefault = true ] - then - report_status "1" "Only one NVFL server VM and its resource group is allowed. $RESOURCE_GROUP exists and thus creating duplicate resource group" - else - echo "Users require to reuse Resource Group $RESOURCE_GROUP. This script will modify the group and may not work always." - fi - - echo "Creating Virtual Machine, will take a few minutes" - if [ $self_dns = true ] - then - az vm create \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --location $LOCATION \ - --name $VM_NAME \ - --image $VM_IMAGE \ - --size $VM_SIZE \ - --admin-username $ADMIN_USERNAME \ - --admin-password $PASSWORD \ - --authentication-type password \ - --public-ip-address nvflare_server_ip \ - --public-ip-address-allocation static \ - --public-ip-sku Standard > /tmp/vm.json - else - az vm create \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --location $LOCATION \ - --name $VM_NAME \ - --image $VM_IMAGE \ - --size $VM_SIZE \ - --admin-username $ADMIN_USERNAME \ - --admin-password $PASSWORD \ - --authentication-type password \ - --public-ip-address nvflare_server_ip \ - --public-ip-address-allocation static \ - --public-ip-sku Standard \ - --public-ip-address-dns-name $DNS_TAG > /tmp/vm.json - fi - report_status "$?" "creating virtual machine" - - IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) - echo "Setting up network related configuration" - az network nsg create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --location $LOCATION \ - --name $NSG_NAME - report_status "$?" "creating network security group" - - az network nsg rule create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name SSH \ - --nsg-name $NSG_NAME \ - --priority 1000 \ - --protocol Tcp \ - --destination-port-ranges 22 - report_status "$?" "creating network security group rule for SSH" - - az network nsg rule create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name FL_PORT \ - --nsg-name $NSG_NAME \ - --priority 1001 \ - --protocol Tcp \ - --destination-port-ranges $FL_PORT - report_status "$?" 
"creating network security group rule for FL port" - - az network nsg rule create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name ADMIN_PORT \ - --nsg-name $NSG_NAME \ - --priority 1002 \ - --protocol Tcp \ - --destination-port-ranges $ADMIN_PORT - report_status "$?" "creating network security group rule for Admin port" - - az network nic update \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name $NIC_NAME \ - --network-security-group $NSG_NAME - report_status "$?" "updating network interface card" - - echo "Copying files to $VM_NAME" - DEST=$ADMIN_USERNAME@${IP_ADDRESS}:$DEST_FOLDER - echo "Destination folder is ${DEST}" - cd $DIR/.. && sshpass -p $PASSWORD scp -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST - report_status "$?" "copying startup kits to VM" - - if [ $container = true ] - then - echo "Installing and lauching container in $VM_NAME, may take a few minutes." - scripts=$(cat << 'EOF' - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ - sudo mkdir -p /etc/apt/keyrings && \ - curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ - echo \ - "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io - EOF - ) - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "$scripts" > /tmp/docker_engine.json - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "sudo usermod -aG docker $ADMIN_USERNAME" >> /tmp/docker_engine.json - report_status "$?" "installing docker engine" - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} --network host ${DOCKER_OPTION} ${image_name} /bin/bash -c \"python -u -m nvflare.private.fed.app.server.server_train -m ${DEST_FOLDER} -s fed_server.json --set secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/vm_create.json 2>&1 - report_status "$?" "launching container" - else - echo "Installing packages in $VM_NAME, may take a few minutes." - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "echo ${DEST_FOLDER} && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python3 -m pip install --ignore-installed nvflare && touch ${DEST_FOLDER}/startup/requirements.txt && python3 -m pip install -r ${DEST_FOLDER}/startup/requirements.txt && ${DEST_FOLDER}/startup/start.sh && sleep 20 && cat ${DEST_FOLDER}/log.txt" > /tmp/vm_create.json - report_status "$?" 
"installing packages" - fi - echo "System was provisioned" - echo "To delete the resource group (also delete the VM), run the following command" - echo "az group delete -n ${RESOURCE_GROUP}" - echo "To login to the VM with SSH, use ${ADMIN_USERNAME} : ${PASSWORD}" > vm_credential.txt - -azure_start_cln_sh: | - RESOURCE_GROUP=nvflare_client_rg_${RANDOM}_${RANDOM} - VM_NAME=nvflare_client - VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest - VM_SIZE=Standard_B2ms - NSG_NAME=nvflare_nsgc - ADMIN_USERNAME=nvflare - PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" - DEST_FOLDER=/var/tmp/cloud - LOCATION=westus2 - NIC_NAME=${VM_NAME}VMNic - echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." - - check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." - check_binary sshpass "Please install it first." - check_binary jq "Please install it first." - - - if [ -z ${image_name+x} ] - then - container=false - else - container=true - fi - - if [ $container = true ] - then - VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest - VM_SIZE=Standard_D8s_v3 - else - VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest - VM_SIZE=Standard_B2ms - fi - if [ -z ${config_file+x} ] - then - useDefault=true - else - useDefault=false - . $config_file - report_status "$?" "Loading config file" - fi - - if [ $useDefault = true ] - then - while true - do - prompt LOCATION "Cloud location, press ENTER to accept default ${LOCATION}: " - prompt VM_IMAGE "Cloud VM image, press ENTER to accept default ${VM_IMAGE}: " - prompt VM_SIZE "Cloud VM size, press ENTER to accept default ${VM_SIZE}: " - prompt ans "location = ${LOCATION}, VM image = ${VM_IMAGE}, VM size = ${VM_SIZE}, OK? (Y/n) " - if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]]; then break; fi - done - fi - - if [ $container = false ] - then - echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}." - prompt ans "Press ENTER when it's done or no additional dependencies. " - fi - - az login --use-device-code -o none - report_status "$?" "login" - - # Start provisioning - - if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] - then - echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" - az group create --output none --name $RESOURCE_GROUP --location $LOCATION - report_status "$?" "creating resource group" - else - echo "Resource Group $RESOURCE_GROUP exists, will reuse it." - fi - - echo "Creating Virtual Machine, will take a few minutes" - az vm create \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --location $LOCATION \ - --name $VM_NAME \ - --image $VM_IMAGE \ - --size $VM_SIZE \ - --admin-username $ADMIN_USERNAME \ - --admin-password $PASSWORD \ - --authentication-type password \ - --public-ip-sku Standard > /tmp/vm.json - report_status "$?" "creating virtual machine" - - IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) - - echo "Setting up network related configuration" - - az network nsg create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --location $LOCATION \ - --name $NSG_NAME - report_status "$?" "creating network security group" - - az network nsg rule create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name SSH \ - --nsg-name $NSG_NAME \ - --priority 1000 \ - --protocol Tcp \ - --destination-port-ranges 22 - report_status "$?" 
"creating network security group rule for SSH" - - az network nic update \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name $NIC_NAME \ - --network-security-group $NSG_NAME - report_status "$?" "updating network interface card" - - echo "Copying files to $VM_NAME" - DEST=$ADMIN_USERNAME@$IP_ADDRESS:$DEST_FOLDER - echo "Destination folder is ${DEST}" - cd $DIR/.. && sshpass -p $PASSWORD scp -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST - report_status "$?" "copying startup kits to VM" - - if [ $container = true ] - then - echo "Installing and lauching container in $VM_NAME, may take a few minutes." - scripts=$(cat <<- 'EOF' - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ - sudo mkdir -p /etc/apt/keyrings && \ - curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ - echo \ - "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io - EOF - ) - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "$scripts" > /tmp/docker_engine.json - report_status "$?" "installing docker engine" - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "sudo usermod -aG docker $ADMIN_USERNAME" >> /tmp/docker_engine.json - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} ${image_name} /bin/bash -c \"python -u -m nvflare.private.fed.app.client.client_train -m ${DEST_FOLDER} -s fed_client.json --set uid={~~SITE~~} secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/vm_create.json 2>&1 - report_status "$?" "launching container" - else - echo "Installing packages in $VM_NAME, may take a few minutes." - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "echo ${DEST_FOLDER} && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python3 -m pip install --ignore-installed nvflare && touch ${DEST_FOLDER}/startup/requirements.txt && python3 -m pip install -r ${DEST_FOLDER}/startup/requirements.txt && ${DEST_FOLDER}/startup/start.sh && sleep 20 && cat ${DEST_FOLDER}/log.txt" > /tmp/vm_create.json - report_status "$?" 
"installing packages" - fi - echo "System was provisioned" - echo "To delete the resource group (also delete the VM), run the following command" - echo "az group delete -n ${RESOURCE_GROUP}" - echo "To login to the VM with SSH, use ${ADMIN_USERNAME} : ${PASSWORD}" > vm_credential.txt - -azure_start_dsb_sh: | - RESOURCE_GROUP=nvflare_dashboard_rg_${RANDOM}_${RANDOM} - VM_NAME=nvflare_dashboard - VM_IMAGE=Canonical:0001-com-ubuntu-server-focal:20_04-lts-gen2:latest - VM_SIZE=Standard_B2ms - NSG_NAME=nvflare_nsgc - ADMIN_USERNAME=nvflare - PASSWORD="NVFl@r3_P@88"$RANDOM"w0rd" - DEST_FOLDER=/var/tmp/cloud - LOCATION=westus2 - NIC_NAME=${VM_NAME}VMNic - - echo "This script requires az (Azure CLI), sshpass and jq. Now checking if they are installed." - - check_binary az "Please see https://learn.microsoft.com/en-us/cli/azure/install-azure-cli on how to install it on your system." - check_binary sshpass "Please install it first." - check_binary jq "Please install it first." - - echo "One initial user will be created when starting dashboard." - echo "Please enter the email address for this user." - read email - credential="${email}:$RANDOM" - - az login --use-device-code -o none - report_status "$?" "login" - - # Start provisioning - if [ $(az group exists -n $RESOURCE_GROUP) == 'false' ] - then - echo "Creating Resource Group $RESOURCE_GROUP at Location $LOCATION" - az group create --output none --name $RESOURCE_GROUP --location $LOCATION - report_status "$?" "creating resource group" - else - echo "Resource Group $RESOURCE_GROUP exists, will reuse it." - fi - - echo "Creating Virtual Machine, will take a few minutes" - az vm create \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --location $LOCATION \ - --name $VM_NAME \ - --image $VM_IMAGE \ - --size $VM_SIZE \ - --admin-username $ADMIN_USERNAME \ - --admin-password $PASSWORD \ - --authentication-type password \ - --public-ip-sku Standard > /tmp/vm.json - report_status "$?" "creating virtual machine" - - IP_ADDRESS=$(jq -r .publicIpAddress /tmp/vm.json) - report_status "$?" "extracting ip address" - - echo "Setting up network related configuration" - az network nsg create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --location $LOCATION \ - --name $NSG_NAME - report_status "$?" "creating network security group" - - az network nsg rule create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name SSH \ - --nsg-name $NSG_NAME \ - --priority 1000 \ - --protocol Tcp \ - --destination-port-ranges 22 - report_status "$?" "creating network security group rule for SSH" - - az network nsg rule create \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name HTTPS \ - --nsg-name $NSG_NAME \ - --priority 1001 \ - --protocol Tcp \ - --destination-port-ranges 443 - report_status "$?" "creating network security group rule for HTTPS" - - az network nic update \ - --output none \ - --resource-group $RESOURCE_GROUP \ - --name $NIC_NAME \ - --network-security-group $NSG_NAME - report_status "$?" "updating network interface card" - - echo "Installing docker engine in $VM_NAME, may take a few minutes." 
- scripts=$(cat << 'EOF' - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ - sudo mkdir -p /etc/apt/keyrings && \ - curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ - echo \ - "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io - EOF - ) - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "$scripts" > /tmp/docker_engine.json - report_status "$?" "installing docker engine" - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "sudo usermod -aG docker $ADMIN_USERNAME" >> /tmp/docker_engine.json - report_status "$?" "installing docker engine" - - DEST_FOLDER=/home/${ADMIN_USERNAME} - echo "Installing nvflare in $VM_NAME, may take a few minutes." - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "echo ${DEST_FOLDER} && wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python3 -m pip install --ignore-installed {~~NVFLARE~~} && mkdir -p ${DEST_FOLDER}/cert && chown -R ${ADMIN_USERNAME} ${DEST_FOLDER}" > /tmp/nvflare.json - report_status "$?" "installing nvflare" - - echo "Checking if certificate (web.crt) and private key (web.key) are available" - if [[ -f "web.crt" && -f "web.key" ]]; then - DEST=$ADMIN_USERNAME@$IP_ADDRESS:${DEST_FOLDER}/cert - echo "Destination folder is ${DEST}" - sshpass -p $PASSWORD scp -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null web.{crt,key} $DEST - report_status "$?" "copying cert/key to VM ${DEST} folder" - secure=true - else - echo "No web.crt and web.key found" - secure=false - fi - - echo "Starting dashboard" - az vm run-command invoke \ - --output json \ - --resource-group $RESOURCE_GROUP \ - --command-id RunShellScript \ - --name $VM_NAME \ - --scripts \ - "cd ${DEST_FOLDER} && python3 -m nvflare.dashboard.cli --start -f ${DEST_FOLDER} --cred ${credential} {~~START_OPT~~}" > /tmp/dashboard.json - - # credential=$(jq -r .value[0].message /tmp/dashboard.json | grep "Project admin") - # echo "The VM was created with user: ${ADMIN_USERNAME} and password: ${PASSWORD}" - if [ "$secure" = true ] - then - echo "URL is https://${IP_ADDRESS}" - else - echo "URL is http://${IP_ADDRESS}:443" - fi - echo "Note: you may need to configure DNS server with your DNS hostname and the above IP address." - echo "Project admin credential (username:password) is ${credential} ." - echo "To stop the dashboard, run az group delete -n ${RESOURCE_GROUP}" - echo "To login to the VM with SSH, use ${ADMIN_USERNAME} : ${PASSWORD}" > vm_credential.txt - adm_notebook: | { "cells": [ @@ -1611,402 +941,3 @@ adm_notebook: | "nbformat_minor": 5 } -aws_start_svr_sh: | - VM_NAME=nvflare_server - SECURITY_GROUP=nvflare_server_sg - DEST_FOLDER=/var/tmp/cloud - KEY_PAIR=NVFlareServerKeyPair - KEY_FILE=${KEY_PAIR}.pem - - echo "This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed." 
- - check_binary aws "Please see https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html on how to install it on your system." - check_binary sshpass "Please install it first." - check_binary dig "Please install it first." - check_binary jq "Please install it first." - - if [ -z ${image_name+x} ] - then - container=false - else - container=true - fi - - if [ $container = true ] - then - AMI_IMAGE=ami-06b8d5099f3a8d79d - EC2_TYPE=t2.xlarge - REGION=us-west-2 - else - AMI_IMAGE=ami-04bad3c587fe60d89 - EC2_TYPE=t2.small - REGION=us-west-2 - fi - - if [ -z ${config_file+x} ] - then - useDefault=true - else - useDefault=false - . $config_file - report_status "$?" "Loading config file" - fi - - - if [ $useDefault = true ] - then - while true - do - prompt AMI_IMAGE "Cloud AMI image, press ENTER to accept default ${AMI_IMAGE}: " - prompt EC2_TYPE "Cloud EC2 type, press ENTER to accept default ${EC2_TYPE}: " - prompt REGIION "Cloud EC2 region, press ENTER to accept default ${REGION}: " - prompt ans "region = ${REGION}, ami image = ${AMI_IMAGE}, EC2 type = ${EC2_TYPE}, OK? (Y/n) " - if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]] - then - break - fi - done - fi - - if [ $container = false ] - then - echo "If the server requires additional dependencies, please copy the requirements.txt to ${DIR}." - prompt ans "Press ENTER when it's done or no additional dependencies. " - fi - - cd $DIR/.. - # Generate key pair - - echo "Generating key pair for VM" - - aws ec2 delete-key-pair --key-name $KEY_PAIR > /dev/null 2>&1 - rm -rf $KEY_FILE - aws ec2 create-key-pair --key-name $KEY_PAIR --query 'KeyMaterial' --output text > $KEY_FILE - report_status "$?" "creating key pair" - chmod 400 $KEY_FILE - - # Generate Security Group - - sg_result=$(aws ec2 create-security-group --group-name $SECURITY_GROUP --description "NVFlare security group") - report_status "$?" "Only one NVFL server VM and its security group is allowed. $SECURITY_GROUP exists and thus creating duplicate security group" - sg_id=$(echo $sg_result | jq -r .GroupId) - my_public_ip=$(dig +short myip.opendns.com @resolver1.opendns.com) - if [ "$?" -eq 0 ] && [[ "$my_public_ip" =~ ^(([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))\.){3}([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))$ ]] - then - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr ${my_public_ip}/32 > /tmp/sec_grp.log - else - echo "getting my public IP failed, please manually configure the inbound rule to limit SSH access" - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr 0.0.0.0/0 > /tmp/sec_grp.log - fi - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log - report_status "$?" "creating security group rules" - - # Start provisioning - - echo "Creating VM at region $REGION, may take a few minutes." - - aws ec2 run-instances --region $REGION --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --security-group-ids $sg_id > vm_create.json - report_status "$?" 
"creating VM" - instance_id=$(jq -r .Instances[0].InstanceId vm_create.json) - - aws ec2 wait instance-status-ok --instance-ids $instance_id - aws ec2 describe-instances --instance-ids $instance_id > vm_result.json - - IP_ADDRESS=$(jq -r .Reservations[0].Instances[0].PublicIpAddress vm_result.json) - - echo "VM created with IP address: ${IP_ADDRESS}" - - echo "Copying files to $VM_NAME" - DEST_SITE=ubuntu@${IP_ADDRESS} - DEST=${DEST_SITE}:${DEST_FOLDER} - echo "Destination folder is ${DEST}" - scp -q -i $KEY_FILE -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST - report_status "$?" "copying startup kits to VM" - - if [ $container = true ] - then - echo "Launching container with docker option ${DOCKER_OPTION}." - ssh -f -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ - "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} --network host ${DOCKER_OPTION} ${image_name} \ - /bin/bash -c \"python -u -m nvflare.private.fed.app.server.server_train -m ${DEST_FOLDER} \ - -s fed_server.json --set secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/nvflare.log 2>&1 - report_status "$?" "launching container" - else - ssh -f -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ - "pwd && wget -q https://bootstrap.pypa.io/get-pip.py && \ - python3 get-pip.py && python3 -m pip install nvflare && \ - touch ${DEST_FOLDER}/startup/requirements.txt && \ - python3 -m pip install -r ${DEST_FOLDER}/startup/requirements.txt && \ - nohup ${DEST_FOLDER}/startup/start.sh && sleep 20 && \ - exit" > /tmp/nvflare.log 2>&1 - report_status "$?" "installing packages" - fi - - echo "System was provisioned" - echo "To terminate the EC2 instance, run the following command." - echo "aws ec2 terminate-instances --instance-ids ${instance_id}" - echo "Other resources provisioned" - echo "security group: ${SECURITY_GROUP}" - echo "key pair: ${KEY_PAIR}" - -aws_start_cln_sh: | - VM_NAME=nvflare_client - SECURITY_GROUP=nvflare_client_sg_$RANDOM - DEST_FOLDER=/var/tmp/cloud - KEY_PAIR=NVFlareClientKeyPair - KEY_FILE=${KEY_PAIR}.pem - - echo "This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed." - - check_binary aws "Please see https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html on how to install it on your system." - check_binary sshpass "Please install it first." - check_binary dig "Please install it first." - check_binary jq "Please install it first." - - if [ -z ${image_name+x} ] - then - container=false - else - container=true - fi - - if [ $container = true ] - then - AMI_IMAGE=ami-06b8d5099f3a8d79d - EC2_TYPE=t2.xlarge - REGION=us-west-2 - else - AMI_IMAGE=ami-04bad3c587fe60d89 - EC2_TYPE=t2.small - REGION=us-west-2 - fi - - if [ -z ${config_file+x} ] - then - useDefault=true - else - useDefault=false - . $config_file - report_status "$?" "Loading config file" - fi - - if [ $useDefault = true ] - then - while true - do - prompt AMI_IMAGE "Cloud AMI image, press ENTER to accept default ${AMI_IMAGE}: " - prompt EC2_TYPE "Cloud EC2 type, press ENTER to accept default ${EC2_TYPE}: " - prompt REGIION "Cloud EC2 region, press ENTER to accept default ${REGION}: " - prompt ans "region = ${REGION}, ami image = ${AMI_IMAGE}, EC2 type = ${EC2_TYPE}, OK? 
(Y/n) " - if [[ $ans = "" ]] || [[ $ans =~ ^(y|Y)$ ]] - then - break - fi - done - fi - - if [ $container = false ] - then - echo "If the client requires additional dependencies, please copy the requirements.txt to ${DIR}." - prompt ans "Press ENTER when it's done or no additional dependencies. " - fi - - cd $DIR/.. - # Generate key pair - - echo "Generating key pair for VM" - - aws ec2 delete-key-pair --key-name $KEY_PAIR > /dev/null 2>&1 - rm -rf $KEY_FILE - aws ec2 create-key-pair --key-name $KEY_PAIR --query 'KeyMaterial' --output text > $KEY_FILE - report_status "$?" "creating key pair" - chmod 400 $KEY_FILE - - # Generate Security Group - # Try not reusing existing security group because we have to modify it for our own need. - sg_id=$(aws ec2 create-security-group --group-name $SECURITY_GROUP --description "NVFlare security group" | jq -r .GroupId) - report_status "$?" "creating security group" - my_public_ip=$(dig +short myip.opendns.com @resolver1.opendns.com) - if [ "$?" -eq 0 ] && [[ "$my_public_ip" =~ ^(([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))\.){3}([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))$ ]] - then - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr ${my_public_ip}/32 > /tmp/sec_grp.log - else - echo "getting my public IP failed, please manually configure the inbound rule to limit SSH access" - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr 0.0.0.0/0 > /tmp/sec_grp.log - fi - report_status "$?" "creating security group rules" - - # Start provisioning - - echo "Creating VM at region $REGION, may take a few minutes." - - aws ec2 run-instances --region $REGION --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --security-group-ids $sg_id > vm_create.json - report_status "$?" "creating VM" - instance_id=$(jq -r .Instances[0].InstanceId vm_create.json) - - aws ec2 wait instance-status-ok --instance-ids $instance_id - aws ec2 describe-instances --instance-ids $instance_id > vm_result.json - - IP_ADDRESS=$(jq -r .Reservations[0].Instances[0].PublicIpAddress vm_result.json) - - echo "VM created with IP address: ${IP_ADDRESS}" - - echo "Copying files to $VM_NAME" - DEST_SITE=ubuntu@${IP_ADDRESS} - DEST=${DEST_SITE}:${DEST_FOLDER} - echo "Destination folder is ${DEST}" - scp -q -i $KEY_FILE -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null $PWD $DEST - report_status "$?" "copying startup kits to VM" - - if [ $container = true ] - then - echo "Launching container with docker option ${DOCKER_OPTION}." - ssh -f -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ - "docker run -d -v ${DEST_FOLDER}:${DEST_FOLDER} --network host ${DOCKER_OPTION} ${image_name} \ - /bin/bash -c \"python -u -m nvflare.private.fed.app.client.client_train -m ${DEST_FOLDER} \ - -s fed_client.json --set uid={~~SITE~~} secure_train=true config_folder=config org={~~ORG~~} \" " > /tmp/nvflare.log 2>&1 - report_status "$?" "launching container" - else - echo "Installing packages in $VM_NAME, may take a few minutes." 
- ssh -f -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ - "pwd && wget -q https://bootstrap.pypa.io/get-pip.py && \ - python3 get-pip.py && python3 -m pip install nvflare && \ - touch ${DEST_FOLDER}/startup/requirements.txt && \ - python3 -m pip install -r ${DEST_FOLDER}/startup/requirements.txt && \ - nohup ${DEST_FOLDER}/startup/start.sh && sleep 20 && \ - exit" > /tmp/nvflare.log 2>&1 - - report_status "$?" "installing packages" - fi - - echo "System was provisioned" - echo "To terminate the EC2 instance, run the following command." - echo "aws ec2 terminate-instances --instance-ids ${instance_id}" - echo "Other resources provisioned" - echo "security group: ${SECURITY_GROUP}" - echo "key pair: ${KEY_PAIR}" - - -aws_start_dsb_sh: | - VM_NAME=nvflare_dashboard - AMI_IMAGE=ami-04bad3c587fe60d89 - EC2_TYPE=t2.small - SECURITY_GROUP=nvflare_dashboard_sg_$RANDOM - REGION=us-west-2 - ADMIN_USERNAME=ubuntu - DEST_FOLDER=/home/${ADMIN_USERNAME} - KEY_PAIR=NVFlareDashboardKeyPair - KEY_FILE=${KEY_PAIR}.pem - - echo "This script requires aws (AWS CLI), sshpass, dig and jq. Now checking if they are installed." - - check_binary aws "Please see https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html on how to install it on your system." - check_binary sshpass "Please install it first." - check_binary dig "Please install it first." - check_binary jq "Please install it first." - - echo "One initial user will be created when starting dashboard." - echo "Please enter the email address for this user." - read email - credential="${email}:$RANDOM" - - # Generate key pair - - echo "Generating key pair for VM" - - aws ec2 delete-key-pair --key-name $KEY_PAIR > /dev/null 2>&1 - rm -rf $KEY_FILE - aws ec2 create-key-pair --key-name $KEY_PAIR --query 'KeyMaterial' --output text > $KEY_FILE - report_status "$?" "creating key pair" - chmod 400 $KEY_FILE - - # Generate Security Group - - sg_id=$(aws ec2 create-security-group --group-name $SECURITY_GROUP --description "NVFlare security group" | jq -r .GroupId) - report_status "$?" "creating security group" - echo "Security group id: ${sg_id}" - my_public_ip=$(dig +short myip.opendns.com @resolver1.opendns.com) - if [ "$?" -eq 0 ] && [[ "$my_public_ip" =~ ^(([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))\.){3}([1-9]?[0-9]|1[0-9][0-9]|2([0-4][0-9]|5[0-5]))$ ]] - then - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr ${my_public_ip}/32 > /tmp/sec_grp.log - else - echo "getting my public IP failed, please manually configure the inbound rule to limit SSH access" - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 22 --cidr 0.0.0.0/0 > /tmp/sec_grp.log - fi - aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 443 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log - report_status "$?" "creating security group rules" - - # Start provisioning - - echo "Creating VM at region $REGION, may take a few minutes." - - aws ec2 run-instances --region $REGION --image-id $AMI_IMAGE --count 1 --instance-type $EC2_TYPE --key-name $KEY_PAIR --security-group-ids $sg_id > vm_create.json - report_status "$?" 
"creating VM" - instance_id=$(jq -r .Instances[0].InstanceId vm_create.json) - - aws ec2 wait instance-status-ok --instance-ids $instance_id - aws ec2 describe-instances --instance-ids $instance_id > vm_result.json - - IP_ADDRESS=$(jq -r .Reservations[0].Instances[0].PublicIpAddress vm_result.json) - - echo "VM created with IP address: ${IP_ADDRESS}" - - echo "Installing docker engine in $VM_NAME, may take a few minutes." - DEST_SITE=${ADMIN_USERNAME}@${IP_ADDRESS} - scripts=$(cat << 'EOF' - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y ca-certificates curl gnupg lsb-release && \ - sudo mkdir -p /etc/apt/keyrings && \ - curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg && \ - echo \ - "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \ - $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null && \ - sudo apt-get update && \ - sudo DEBIAN_FRONTEND=noninteractive apt-get install -y docker-ce docker-ce-cli containerd.io - EOF - ) - ssh -t -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} "$scripts" > /tmp/docker_engine.log - report_status "$?" "installing docker engine" - ssh -t -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} "sudo usermod -aG docker $ADMIN_USERNAME && exit" >> /tmp/docker_engine.log - report_status "$?" "installing docker engine" - - echo "Installing nvflare in $VM_NAME, may take a few minutes." - ssh -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ - "export PATH=/home/ubuntu/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin && \ - wget -q https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && \ - python3 -m pip install {~~NVFLARE~~} && \ - mkdir -p ./cert && \ - exit" > /tmp/nvflare.json - report_status "$?" "installing nvflare" - - echo "Checking if certificate (web.crt) and private key (web.key) are available" - if [[ -f "web.crt" && -f "web.key" ]]; then - CERT_FOLDER=${DEST_SITE}:${DEST_FOLDER}/cert - echo "Cert folder is ${CERT_FOLDER}" - scp -i $KEY_FILE -r -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null web.{crt,key} $CERT_FOLDER - report_status "$?" "copying cert/key to VM ${CERT_FOLDER} folder" - secure=true - else - echo "No web.crt and web.key found" - secure=false - fi - - echo "Starting dashboard" - ssh -i $KEY_FILE -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null ${DEST_SITE} \ - "export PATH=/home/ubuntu/.local/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/snap/bin && \ - python3 -m nvflare.dashboard.cli --start -f ${DEST_FOLDER} --cred ${credential} {~~START_OPT~~}" > /tmp/dashboard.json - - echo "Dashboard url is running at IP address ${IP_ADDRESS}, listening to port 443." - if [ "$secure" = true ] - then - echo "URL is https://${IP_ADDRESS}" - else - echo "URL is http://${IP_ADDRESS}:443" - fi - echo "Note: you may need to configure DNS server with your DNS hostname and the above IP address." - echo "Project admin credential (username:password) is ${credential} ." - echo "To terminate the EC2 instance, run the following command." 
- echo "aws ec2 terminate-instances --instance-ids ${instance_id}" - echo "Other resources provisioned" - echo "security group: ${SECURITY_GROUP}" - echo "key pair: ${KEY_PAIR}" diff --git a/nvflare/lighter/impl/static_file.py b/nvflare/lighter/impl/static_file.py index 21ef4c8f04..a024c84a43 100644 --- a/nvflare/lighter/impl/static_file.py +++ b/nvflare/lighter/impl/static_file.py @@ -18,8 +18,8 @@ import yaml +from nvflare.lighter import utils from nvflare.lighter.spec import Builder -from nvflare.lighter.utils import sh_replace class StaticFileBuilder(Builder): @@ -61,13 +61,6 @@ def __init__( self.snapshot_persistor = snapshot_persistor self.components = components - def _write(self, file_full_path, content, mode, exe=False): - mode = mode + "w" - with open(file_full_path, mode) as f: - f.write(content) - if exe: - os.chmod(file_full_path, 0o755) - def get_server_name(self, server): return server.name @@ -76,7 +69,7 @@ def get_overseer_name(self, overseer): def _build_overseer(self, overseer, ctx): dest_dir = self.get_kit_dir(overseer, ctx) - self._write( + utils._write( os.path.join(dest_dir, "start.sh"), self.template["start_svr_sh"], "t", @@ -95,7 +88,7 @@ def _build_overseer(self, overseer, ctx): privilege_dict[role].append(admin.subject) else: privilege_dict[role] = [admin.subject] - self._write( + utils._write( os.path.join(dest_dir, "privilege.yml"), yaml.dump(privilege_dict, Dumper=yaml.Dumper), "t", @@ -103,19 +96,19 @@ def _build_overseer(self, overseer, ctx): ) if self.docker_image: - self._write( + utils._write( os.path.join(dest_dir, "docker.sh"), - sh_replace(self.template["docker_svr_sh"], replacement_dict), + utils.sh_replace(self.template["docker_svr_sh"], replacement_dict), "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "gunicorn.conf.py"), - sh_replace(self.template["gunicorn_conf_py"], replacement_dict), + utils.sh_replace(self.template["gunicorn_conf_py"], replacement_dict), "t", exe=False, ) - self._write( + utils._write( os.path.join(dest_dir, "start.sh"), self.template["start_ovsr_sh"], "t", @@ -140,11 +133,6 @@ def _build_server(self, server, ctx): server_0["service"]["scheme"] = self.scheme server_0["admin_host"] = self.get_server_name(server) server_0["admin_port"] = admin_port - # if self.download_job_url: - # server_0["download_job_url"] = self.download_job_url - # config["enable_byoc"] = server.enable_byoc - # if self.app_validator: - # config["app_validator"] = {"path": self.app_validator} if self.overseer_agent: overseer_agent = copy.deepcopy(self.overseer_agent) if overseer_agent.get("overseer_exists", True): @@ -158,46 +146,36 @@ def _build_server(self, server, ctx): } overseer_agent.pop("overseer_exists", None) config["overseer_agent"] = overseer_agent - # if self.snapshot_persistor: - # config["snapshot_persistor"] = self.snapshot_persistor - # components = server.props.get("components", []) - # config["components"] = list() - # for comp in components: - # temp_dict = {"id": comp} - # temp_dict.update(components[comp]) - # config["components"].append(temp_dict) - # provisioned_client_list = list() - # for client in self.project.get_participants_by_type("client", first_only=False): - # provisioned_client_list.append(client.name) - # config["provisioned_client_list"] = provisioned_client_list - self._write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t") + utils._write(os.path.join(dest_dir, "fed_server.json"), json.dumps(config, indent=2), "t") replacement_dict = { "admin_port": admin_port, 
"fed_learn_port": fed_learn_port, "config_folder": self.config_folder, "docker_image": self.docker_image, "org_name": server.org, + "type": "server", + "cln_uid": "", } if self.docker_image: - self._write( + utils._write( os.path.join(dest_dir, "docker.sh"), - sh_replace(self.template["docker_svr_sh"], replacement_dict), + utils.sh_replace(self.template["docker_svr_sh"], replacement_dict), "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "start.sh"), self.template["start_svr_sh"], "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "sub_start.sh"), - sh_replace(self.template["sub_start_svr_sh"], replacement_dict), + utils.sh_replace(self.template["sub_start_sh"], replacement_dict), "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "stop_fl.sh"), self.template["stop_fl_sh"], "t", @@ -205,29 +183,29 @@ def _build_server(self, server, ctx): ) # local folder creation dest_dir = self.get_local_dir(server, ctx) - self._write( + utils._write( os.path.join(dest_dir, "log.config.default"), self.template["log_config"], "t", ) - self._write( + utils._write( os.path.join(dest_dir, "resources.json.default"), self.template["local_server_resources"], "t", ) - self._write( + utils._write( os.path.join(dest_dir, "privacy.json.sample"), self.template["sample_privacy"], "t", ) - self._write( + utils._write( os.path.join(dest_dir, "authorization.json.default"), self.template["default_authz"], "t", ) # workspace folder file - self._write( + utils._write( os.path.join(self.get_ws_dir(server, ctx), "readme.txt"), self.template["readme_fs"], "t", @@ -247,6 +225,8 @@ def _build_client(self, client, ctx): "config_folder": self.config_folder, "docker_image": self.docker_image, "org_name": client.org, + "type": "client", + "cln_uid": f"uid={client.subject}", } if self.overseer_agent: overseer_agent = copy.deepcopy(self.overseer_agent) @@ -266,27 +246,27 @@ def _build_client(self, client, ctx): # temp_dict.update(components[comp]) # config["components"].append(temp_dict) - self._write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t") + utils._write(os.path.join(dest_dir, "fed_client.json"), json.dumps(config, indent=2), "t") if self.docker_image: - self._write( + utils._write( os.path.join(dest_dir, "docker.sh"), - sh_replace(self.template["docker_cln_sh"], replacement_dict), + utils.sh_replace(self.template["docker_cln_sh"], replacement_dict), "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "start.sh"), self.template["start_cln_sh"], "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "sub_start.sh"), - sh_replace(self.template["sub_start_cln_sh"], replacement_dict), + utils.sh_replace(self.template["sub_start_sh"], replacement_dict), "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "stop_fl.sh"), self.template["stop_fl_sh"], "t", @@ -294,29 +274,29 @@ def _build_client(self, client, ctx): ) # local folder creation dest_dir = self.get_local_dir(client, ctx) - self._write( + utils._write( os.path.join(dest_dir, "log.config.default"), self.template["log_config"], "t", ) - self._write( + utils._write( os.path.join(dest_dir, "resources.json.default"), self.template["local_client_resources"], "t", ) - self._write( + utils._write( os.path.join(dest_dir, "privacy.json.sample"), self.template["sample_privacy"], "t", ) - self._write( + utils._write( os.path.join(dest_dir, "authorization.json.default"), self.template["default_authz"], "t", ) # workspace folder file - 
self._write( + utils._write( os.path.join(self.get_ws_dir(client, ctx), "readme.txt"), self.template["readme_fc"], "t", @@ -335,21 +315,21 @@ def _build_admin(self, admin, ctx): config = self.prepare_admin_config(admin, ctx) - self._write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t") + utils._write(os.path.join(dest_dir, "fed_admin.json"), json.dumps(config, indent=2), "t") if self.docker_image: - self._write( + utils._write( os.path.join(dest_dir, "docker.sh"), - sh_replace(self.template["docker_adm_sh"], replacement_dict), + utils.sh_replace(self.template["docker_adm_sh"], replacement_dict), "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "fl_admin.sh"), - sh_replace(self.template["fl_admin_sh"], replacement_dict), + utils.sh_replace(self.template["fl_admin_sh"], replacement_dict), "t", exe=True, ) - self._write( + utils._write( os.path.join(dest_dir, "readme.txt"), self.template["readme_am"], "t", diff --git a/nvflare/lighter/impl/template.py b/nvflare/lighter/impl/template.py index a85277ad11..e3a19e8261 100644 --- a/nvflare/lighter/impl/template.py +++ b/nvflare/lighter/impl/template.py @@ -26,6 +26,8 @@ class TemplateBuilder(Builder): def initialize(self, ctx): resource_dir = self.get_resources_dir(ctx) - template_file = ctx.get("template_file") - template = load_yaml(os.path.join(resource_dir, template_file)) + template_files = ctx.get("template_files") + template = dict() + for tplt_file in template_files: + template.update(load_yaml(os.path.join(resource_dir, tplt_file))) ctx["template"] = template diff --git a/nvflare/lighter/impl/workspace.py b/nvflare/lighter/impl/workspace.py index 4926b20629..6b203227df 100644 --- a/nvflare/lighter/impl/workspace.py +++ b/nvflare/lighter/impl/workspace.py @@ -43,9 +43,9 @@ def __init__(self, template_file): wip/ <--- this is only used during runtime, and will be removed when the provision command exits Args: - template_file: name of template file containing scripts and configs to put into startup folders + template_file: name(s) of template file(s) containing scripts and configs to put into startup folders """ - self.template_file = template_file + self.template_files = template_file def _make_dir(self, dirs): for dir in dirs: @@ -61,10 +61,15 @@ def initialize(self, ctx): if stage > last: last = stage ctx["last_prod_stage"] = last - template_file_full_path = os.path.join(self.get_resources_dir(ctx), self.template_file) - file_path = pathlib.Path(__file__).parent.absolute() - shutil.copyfile(os.path.join(file_path, self.template_file), template_file_full_path) - ctx["template_file"] = self.template_file + if not isinstance(self.template_files, list): + self.template_files = [self.template_files] + tplt_file_list = [] + for tplt_file in self.template_files: + tplt_file_full_path = os.path.join(self.get_resources_dir(ctx), tplt_file) + file_path = pathlib.Path(__file__).parent.absolute() + shutil.copyfile(os.path.join(file_path, tplt_file), tplt_file_full_path) + tplt_file_list.append(tplt_file) + ctx["template_files"] = tplt_file_list def build(self, project: Project, ctx: dict): dirs = [self.get_kit_dir(p, ctx) for p in project.participants] diff --git a/nvflare/lighter/tool_consts.py b/nvflare/lighter/tool_consts.py new file mode 100644 index 0000000000..2ba1970a9f --- /dev/null +++ b/nvflare/lighter/tool_consts.py @@ -0,0 +1,17 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +NVFLARE_PREFIX = ".__nvfl_" +NVFLARE_SIG_FILE = ".__nvfl_sig.json" +NVFLARE_SUBMITTER_CRT_FILE = ".__nvfl_submitter.crt" diff --git a/nvflare/lighter/tplt_utils.py b/nvflare/lighter/tplt_utils.py index e0ebf5aec9..590052a5cd 100644 --- a/nvflare/lighter/tplt_utils.py +++ b/nvflare/lighter/tplt_utils.py @@ -13,9 +13,81 @@ # limitations under the License. +from . import utils + + class Template: def __init__(self, template): self.template = template + self.supported_csps = ("azure", "aws") def get_cloud_script_header(self): return self.template.get("cloud_script_header") + + def get_azure_server_start_sh(self, entity): + tmp = self.get_cloud_script_header() + self.get_azure_start_svr_header_sh() + self.get_azure_start_common_sh() + script = utils.sh_replace( + tmp, + { + "type": "server", + "docker_network": "--network host", + "cln_uid": "", + "server_name": entity.name, + "ORG": "", + }, + ) + return script + + def get_aws_server_start_sh(self, entity): + tmp = self.get_cloud_script_header() + self.template.get("aws_start_sh") + script = utils.sh_replace( + tmp, + { + "type": "server", + "inbound_rule": "aws ec2 authorize-security-group-ingress --group-id $sg_id --protocol tcp --port 8002-8003 --cidr 0.0.0.0/0 >> /tmp/sec_grp.log", + "cln_uid": "", + "server_name": entity.name, + "ORG": "", + }, + ) + return script + + def get_azure_client_start_sh(self, entity): + tmp = self.get_cloud_script_header() + self.get_azure_start_cln_header_sh() + self.get_azure_start_common_sh() + script = utils.sh_replace( + tmp, + {"type": "client", "docker_network": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org}, + ) + return script + + def get_aws_client_start_sh(self, entity): + tmp = self.get_cloud_script_header() + self.template.get("aws_start_sh") + script = utils.sh_replace( + tmp, {"type": "client", "inbound_rule": "", "cln_uid": f"uid={entity.name}", "ORG": entity.org} + ) + return script + + def get_azure_start_svr_header_sh(self): + return self.template.get("azure_start_svr_header_sh") + + def get_azure_start_cln_header_sh(self): + return self.template.get("azure_start_cln_header_sh") + + def get_azure_start_common_sh(self): + return self.template.get("azure_start_common_sh") + + def get_sub_start_sh(self): + return self.template.get("sub_start_sh") + + def get_azure_svr_sh(self): + return self.get_cloud_script_header() + self.get_azure_start_svr_header_sh() + self.get_azure_start_common_sh() + + def get_azure_cln_sh(self): + return self.get_cloud_script_header() + self.get_azure_start_cln_header_sh() + self.get_azure_start_common_sh() + + def get_start_sh(self, csp, type, entity): + try: + func = getattr(self, f"get_{csp}_{type}_start_sh") + return func(entity) + except AttributeError: + return "" diff --git a/nvflare/lighter/utils.py b/nvflare/lighter/utils.py index fa202b480a..2bb1b2b753 100644 --- a/nvflare/lighter/utils.py +++ b/nvflare/lighter/utils.py @@ -24,6 +24,7 @@ from cryptography.hazmat.primitives.asymmetric import padding from 
nvflare.lighter.impl.cert import load_crt +from nvflare.lighter.tool_consts import NVFLARE_SIG_FILE, NVFLARE_SUBMITTER_CRT_FILE def generate_password(passlen=16): @@ -56,7 +57,7 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999): depth = depth + 1 signatures = dict() for file in files: - if file == ".__nvfl_sig.json" or file == ".__nvfl_submitter.crt": + if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE: continue signature = signing_pri_key.sign( data=open(os.path.join(root, file), "rb").read(), @@ -78,8 +79,8 @@ def sign_folders(folder, signing_pri_key, crt_path, max_depth=9999): ) signatures[folder] = b64encode(signature).decode("utf-8") - json.dump(signatures, open(os.path.join(root, ".__nvfl_sig.json"), "wt")) - shutil.copyfile(crt_path, os.path.join(root, ".__nvfl_submitter.crt")) + json.dump(signatures, open(os.path.join(root, NVFLARE_SIG_FILE), "wt")) + shutil.copyfile(crt_path, os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) if depth >= max_depth: break @@ -90,8 +91,8 @@ def verify_folder_signature(src_folder, root_ca_path): root_ca_public_key = root_ca_cert.public_key() for root, folders, files in os.walk(src_folder): try: - signatures = json.load(open(os.path.join(root, ".__nvfl_sig.json"), "rt")) - cert = load_crt(os.path.join(root, ".__nvfl_submitter.crt")) + signatures = json.load(open(os.path.join(root, NVFLARE_SIG_FILE), "rt")) + cert = load_crt(os.path.join(root, NVFLARE_SUBMITTER_CRT_FILE)) public_key = cert.public_key() except: continue # TODO: shall return False @@ -101,7 +102,7 @@ def verify_folder_signature(src_folder, root_ca_path): for k in signatures: signatures[k] = b64decode(signatures[k].encode("utf-8")) for file in files: - if file == ".__nvfl_sig.json" or file == ".__nvfl_submitter.crt": + if file == NVFLARE_SIG_FILE or file == NVFLARE_SUBMITTER_CRT_FILE: continue signature = signatures.get(file) if signature: @@ -224,3 +225,75 @@ def update_storage_locations( json_object = json.dumps(resources, indent=4) with open(target_resource, "w") as outfile: outfile.write(json_object) + + +def _write(file_full_path, content, mode, exe=False): + mode = mode + "w" + with open(file_full_path, mode) as f: + f.write(content) + if exe: + os.chmod(file_full_path, 0o755) + + +def _write_common(type, dest_dir, template, tplt, replacement_dict, config): + mapping = {"server": "svr", "client": "cln"} + _write(os.path.join(dest_dir, f"fed_{type}.json"), json.dumps(config, indent=2), "t") + _write( + os.path.join(dest_dir, "docker.sh"), + sh_replace(template[f"docker_{mapping[type]}_sh"], replacement_dict), + "t", + exe=True, + ) + _write( + os.path.join(dest_dir, "start.sh"), + sh_replace(template[f"start_{mapping[type]}_sh"], replacement_dict), + "t", + exe=True, + ) + _write( + os.path.join(dest_dir, "sub_start.sh"), + sh_replace(tplt.get_sub_start_sh(), replacement_dict), + "t", + exe=True, + ) + _write( + os.path.join(dest_dir, "stop_fl.sh"), + template["stop_fl_sh"], + "t", + exe=True, + ) + + +def _write_local(type, dest_dir, template, capacity=""): + _write( + os.path.join(dest_dir, "log.config.default"), + template["log_config"], + "t", + ) + _write( + os.path.join(dest_dir, "privacy.json.sample"), + template["sample_privacy"], + "t", + ) + _write( + os.path.join(dest_dir, "authorization.json.default"), + template["default_authz"], + "t", + ) + resources = json.loads(template["local_client_resources"]) + if type == "client": + for component in resources["components"]: + if 
"nvflare.app_common.resource_managers.gpu_resource_manager.GPUResourceManager" == component["path"]: + component["args"] = json.loads(capacity) + break + _write( + os.path.join(dest_dir, "resources.json.default"), + json.dumps(resources, indent=2), + "t", + ) + + +def _write_pki(type, dest_dir, cert_pair, root_cert): + _write(os.path.join(dest_dir, f"{type}.crt"), cert_pair.ser_cert, "b", exe=False) + _write(os.path.join(dest_dir, f"{type}.key"), cert_pair.ser_pri_key, "b", exe=False) + _write(os.path.join(dest_dir, "rootCA.pem"), root_cert, "b", exe=False) diff --git a/nvflare/private/fed/app/simulator/simulator_runner.py b/nvflare/private/fed/app/simulator/simulator_runner.py index 7e2d1cfd98..b6a84c4516 100644 --- a/nvflare/private/fed/app/simulator/simulator_runner.py +++ b/nvflare/private/fed/app/simulator/simulator_runner.py @@ -94,6 +94,15 @@ def __init__( self.clients_created = 0 + running_dir = os.getcwd() + if self.workspace is None: + self.workspace = "simulator_workspace" + self.logger.warn( + f"Simulator workspace is not provided. Set it to the default location:" + f" {os.path.join(running_dir, self.workspace)}" + ) + self.workspace = os.path.join(running_dir, self.workspace) + def _generate_args( self, job_folder: str, workspace: str, clients=None, n_clients=None, threads=None, gpu=None, max_clients=100 ): @@ -110,15 +119,6 @@ def _generate_args( return args def setup(self): - running_dir = os.getcwd() - if self.workspace is None: - self.workspace = "simulator_workspace" - self.logger.warn( - f"Simulator workspace is not provided. Set it to the default location:" - f" {os.path.join(running_dir, self.workspace)}" - ) - self.workspace = os.path.join(running_dir, self.workspace) - self.args = self._generate_args( self.job_folder, self.workspace, self.clients, self.n_clients, self.threads, self.gpu, self.max_clients ) @@ -331,9 +331,6 @@ def create_client(self, client_name): client_name, self.args ) self.federated_clients.append(client) - app_root = os.path.join(self.simulator_root, "app_" + client_name) - app_custom_folder = os.path.join(app_root, "custom") - sys.path.append(app_custom_folder) def _set_client_status(self): for client in self.federated_clients: @@ -348,7 +345,7 @@ def run(self): try: manager = Manager() return_dict = manager.dict() - process = Process(target=self.run_processs, args=(return_dict,)) + process = Process(target=self.run_process, args=(return_dict,)) process.start() process.join() run_status = self._get_return_code(return_dict, process, self.workspace) @@ -380,7 +377,7 @@ def _get_return_code(self, return_dict, process, workspace): self.logger.info(f"return_code from process.exitcode: {return_code}") return return_code - def run_processs(self, return_dict): + def run_process(self, return_dict): # run_status = self.simulator_run_main() try: run_status = mpm.run( diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index eae4b8abb0..c775f7b8b9 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -21,11 +21,12 @@ import time from abc import ABC, abstractmethod -from nvflare.apis.fl_constant import AdminCommandNames, RunProcessKey +from nvflare.apis.fl_constant import AdminCommandNames, RunProcessKey, SystemConfigs from nvflare.apis.resource_manager_spec import ResourceManagerSpec from nvflare.fuel.common.exit_codes import PROCESS_EXIT_REASON, ProcessExitCode from nvflare.fuel.f3.cellnet.core_cell import FQCN from nvflare.fuel.f3.cellnet.defs 
import MessageHeaderKey, ReturnCode +from nvflare.fuel.utils.config_service import ConfigService from nvflare.private.defs import CellChannel, CellChannelTopic, JobFailureMsgKey, new_cell_message from nvflare.private.fed.utils.fed_utils import get_return_code from nvflare.security.logging import secure_format_exception, secure_log_traceback @@ -138,6 +139,10 @@ def __init__(self, client, startup): self.run_processes = {} self.lock = threading.Lock() + self.job_query_timeout = ConfigService.get_float_var( + name="job_query_timeout", conf=SystemConfigs.APPLICATION_CONF, default=5.0 + ) + def start_app( self, client, @@ -216,10 +221,9 @@ def start_app( thread.start() def notify_job_status(self, job_id, job_status): - with self.lock: - run_process = self.run_processes.get(job_id) - if run_process: - run_process[RunProcessKey.STATUS] = job_status + run_process = self.run_processes.get(job_id) + if run_process: + run_process[RunProcessKey.STATUS] = job_status def check_status(self, job_id): """Checks the status of the running client. @@ -231,9 +235,8 @@ def check_status(self, job_id): A client status message """ try: - with self.lock: - process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED) - return get_status_message(process_status) + process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED) + return get_status_message(process_status) except Exception as e: self.logger.error(f"check_status execution exception: {secure_format_exception(e)}.") secure_log_traceback() @@ -249,23 +252,23 @@ def get_run_info(self, job_id): A dict of run information. """ try: - with self.lock: - data = {} - fqcn = FQCN.join([self.client.client_name, job_id]) - request = new_cell_message({}, data) - return_data = self.client.cell.send_request( - target=fqcn, - channel=CellChannel.CLIENT_COMMAND, - topic=AdminCommandNames.SHOW_STATS, - request=request, - optional=True, - ) - return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE) - if return_code == ReturnCode.OK: - run_info = return_data.payload - return run_info - else: - return {} + data = {} + fqcn = FQCN.join([self.client.client_name, job_id]) + request = new_cell_message({}, data) + return_data = self.client.cell.send_request( + target=fqcn, + channel=CellChannel.CLIENT_COMMAND, + topic=AdminCommandNames.SHOW_STATS, + request=request, + optional=True, + timeout=self.job_query_timeout, + ) + return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE) + if return_code == ReturnCode.OK: + run_info = return_data.payload + return run_info + else: + return {} except Exception as e: self.logger.error(f"get_run_info execution exception: {secure_format_exception(e)}.") secure_log_traceback() @@ -281,23 +284,23 @@ def get_errors(self, job_id): A dict of error information. 
""" try: - with self.lock: - data = {"command": AdminCommandNames.SHOW_ERRORS, "data": {}} - fqcn = FQCN.join([self.client.client_name, job_id]) - request = new_cell_message({}, data) - return_data = self.client.cell.send_request( - target=fqcn, - channel=CellChannel.CLIENT_COMMAND, - topic=AdminCommandNames.SHOW_ERRORS, - request=request, - optional=True, - ) - return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE) - if return_code == ReturnCode.OK: - errors_info = return_data.payload - return errors_info - else: - return None + data = {"command": AdminCommandNames.SHOW_ERRORS, "data": {}} + fqcn = FQCN.join([self.client.client_name, job_id]) + request = new_cell_message({}, data) + return_data = self.client.cell.send_request( + target=fqcn, + channel=CellChannel.CLIENT_COMMAND, + topic=AdminCommandNames.SHOW_ERRORS, + request=request, + optional=True, + timeout=self.job_query_timeout, + ) + return_code = return_data.get_header(MessageHeaderKey.RETURN_CODE) + if return_code == ReturnCode.OK: + errors_info = return_data.payload + return errors_info + else: + return None except Exception as e: self.logger.error(f"get_errors execution exception: {secure_format_exception(e)}.") secure_log_traceback() @@ -310,17 +313,16 @@ def reset_errors(self, job_id): job_id: the job_id """ try: - with self.lock: - data = {"command": AdminCommandNames.RESET_ERRORS, "data": {}} - fqcn = FQCN.join([self.client.client_name, job_id]) - request = new_cell_message({}, data) - self.client.cell.fire_and_forget( - targets=fqcn, - channel=CellChannel.CLIENT_COMMAND, - topic=AdminCommandNames.RESET_ERRORS, - message=request, - optional=True, - ) + data = {"command": AdminCommandNames.RESET_ERRORS, "data": {}} + fqcn = FQCN.join([self.client.client_name, job_id]) + request = new_cell_message({}, data) + self.client.cell.fire_and_forget( + targets=fqcn, + channel=CellChannel.CLIENT_COMMAND, + topic=AdminCommandNames.RESET_ERRORS, + message=request, + optional=True, + ) except Exception as e: self.logger.error(f"reset_errors execution exception: {secure_format_exception(e)}.") @@ -332,41 +334,41 @@ def abort_app(self, job_id): Args: job_id: the job_id """ - with self.lock: - # When the HeartBeat cleanup process try to abort the worker process, the job maybe already terminated, - # Use retry to avoid print out the error stack trace. - retry = 1 - while retry >= 0: - process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED) - if process_status == ClientStatus.STARTED: - try: + # When the HeartBeat cleanup process try to abort the worker process, the job maybe already terminated, + # Use retry to avoid print out the error stack trace. 
+ retry = 1 + while retry >= 0: + process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED) + if process_status == ClientStatus.STARTED: + try: + with self.lock: child_process = self.run_processes[job_id][RunProcessKey.CHILD_PROCESS] - data = {} - fqcn = FQCN.join([self.client.client_name, job_id]) - request = new_cell_message({}, data) - self.client.cell.fire_and_forget( - targets=fqcn, - channel=CellChannel.CLIENT_COMMAND, - topic=AdminCommandNames.ABORT, - message=request, - optional=True, - ) - self.logger.debug("abort sent to worker") - t = threading.Thread(target=self._terminate_process, args=[child_process, job_id]) - t.start() - t.join() - break - except Exception as e: - if retry == 0: - self.logger.error( - f"abort_worker_process execution exception: {secure_format_exception(e)} for run: {job_id}." - ) - secure_log_traceback() - retry -= 1 - time.sleep(5.0) - else: - self.logger.info(f"Client worker process for run: {job_id} was already terminated.") + data = {} + fqcn = FQCN.join([self.client.client_name, job_id]) + request = new_cell_message({}, data) + self.client.cell.fire_and_forget( + targets=fqcn, + channel=CellChannel.CLIENT_COMMAND, + topic=AdminCommandNames.ABORT, + message=request, + optional=True, + ) + self.logger.debug("abort sent to worker") + t = threading.Thread(target=self._terminate_process, args=[child_process, job_id]) + t.start() + t.join() break + except Exception as e: + if retry == 0: + self.logger.error( + f"abort_worker_process execution exception: {secure_format_exception(e)} for run: {job_id}." + ) + secure_log_traceback() + retry -= 1 + time.sleep(5.0) + else: + self.logger.info(f"Client worker process for run: {job_id} was already terminated.") + break self.logger.info("Client worker process is terminated.") @@ -405,25 +407,23 @@ def abort_task(self, job_id): Args: job_id: the job_id """ - with self.lock: - process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED) - if process_status == ClientStatus.STARTED: - data = {"command": AdminCommandNames.ABORT_TASK, "data": {}} - fqcn = FQCN.join([self.client.client_name, job_id]) - request = new_cell_message({}, data) - self.client.cell.fire_and_forget( - targets=fqcn, - channel=CellChannel.CLIENT_COMMAND, - topic=AdminCommandNames.ABORT_TASK, - message=request, - optional=True, - ) - self.logger.debug("abort_task sent") + process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.NOT_STARTED) + if process_status == ClientStatus.STARTED: + data = {"command": AdminCommandNames.ABORT_TASK, "data": {}} + fqcn = FQCN.join([self.client.client_name, job_id]) + request = new_cell_message({}, data) + self.client.cell.fire_and_forget( + targets=fqcn, + channel=CellChannel.CLIENT_COMMAND, + topic=AdminCommandNames.ABORT_TASK, + message=request, + optional=True, + ) + self.logger.debug("abort_task sent") def _wait_child_process_finish(self, client, job_id, allocated_resource, token, resource_manager, workspace): self.logger.info(f"run ({job_id}): waiting for child worker process to finish.") - with self.lock: - child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.CHILD_PROCESS) + child_process = self.run_processes.get(job_id, {}).get(RunProcessKey.CHILD_PROCESS) if child_process: child_process.wait() @@ -452,13 +452,13 @@ def _wait_child_process_finish(self, client, job_id, allocated_resource, token, resource_manager.free_resources( resources=allocated_resource, token=token, 
fl_ctx=client.engine.new_context() ) - self.run_processes.pop(job_id, None) + with self.lock: + self.run_processes.pop(job_id, None) self.logger.debug(f"run ({job_id}): child worker resources freed.") def get_status(self, job_id): - with self.lock: - process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED) - return process_status + process_status = self.run_processes.get(job_id, {}).get(RunProcessKey.STATUS, ClientStatus.STOPPED) + return process_status def get_run_processes_keys(self): with self.lock: diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 6abd8068f9..b6c01f8d43 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -93,7 +93,7 @@ def __init__( cell=cell, client_register_interval=client_args.get("client_register_interval", 2.0), timeout=client_args.get("communication_timeout", 30.0), - maint_msg_timeout=client_args.get("maint_msg_timeout", 5.0), + maint_msg_timeout=client_args.get("maint_msg_timeout", 30.0), ) self.secure_train = secure_train diff --git a/nvflare/tool/job/config/configer.py b/nvflare/tool/job/config/configer.py index 0221ac0b3a..e50dafc254 100644 --- a/nvflare/tool/job/config/configer.py +++ b/nvflare/tool/job/config/configer.py @@ -18,6 +18,7 @@ from pyhocon import ConfigFactory, ConfigTree from nvflare.fuel.utils.config import ConfigFormat +from nvflare.lighter.tool_consts import NVFLARE_PREFIX from nvflare.tool.job.config.config_indexer import KeyIndex, build_reverse_order_index from nvflare.tool.job.job_client_const import ( APP_CONFIG_DIR, @@ -436,7 +437,7 @@ def build_config_file_indices(job_folder: str, app_names: List[str]) -> Dict[str for root, dirs, files in os.walk(custom_dir): for f in files: for ext in config_extensions: - if f.endswith(ext): + if f.endswith(ext) and not f.startswith(NVFLARE_PREFIX): file = os.path.join(root, f) config_files = app_config_files.get(app_name, []) config_files.append(file) diff --git a/nvflare/tool/job/job_cli.py b/nvflare/tool/job/job_cli.py index 9b3522c83e..ceab31fe58 100644 --- a/nvflare/tool/job/job_cli.py +++ b/nvflare/tool/job/job_cli.py @@ -41,13 +41,13 @@ JOB_CONFIG_FILE_NAME, JOB_CONFIG_VAR_NAME, JOB_CONFIG_VAR_VALUE, - JOB_INFO_CLIENT_TYPE, - JOB_INFO_CLIENT_TYPE_KEY, JOB_INFO_CONF, JOB_INFO_CONTROLLER_TYPE, JOB_INFO_CONTROLLER_TYPE_KEY, JOB_INFO_DESC, JOB_INFO_DESC_KEY, + JOB_INFO_EXECUTION_API_TYPE, + JOB_INFO_EXECUTION_API_TYPE_KEY, JOB_INFO_KEYS, JOB_INFO_MD, JOB_META_BASE_NAME, @@ -318,13 +318,13 @@ def display_available_templates(template_index_conf): print("-" * total_length) name_fix_length = 20 description_fix_length = 60 - controller_type_fix_length = 20 - client_category_fix_length = 20 + controller_type_fix_length = 17 + execution_api_type_fix_length = 23 name = fix_length_format("name", name_fix_length) description = fix_length_format(JOB_INFO_DESC, description_fix_length) - client_category = fix_length_format(JOB_INFO_CLIENT_TYPE, client_category_fix_length) + execution_api_type = fix_length_format(JOB_INFO_EXECUTION_API_TYPE, execution_api_type_fix_length) controller_type = fix_length_format(JOB_INFO_CONTROLLER_TYPE, controller_type_fix_length) - print(" " * left_margin, name, description, controller_type, client_category) + print(" " * left_margin, name, description, controller_type, execution_api_type) print("-" * total_length) for file_path in sorted(template_registry.keys()): name = os.path.basename(file_path) @@ -333,9 
+333,11 @@ def display_available_templates(template_index_conf):
         template_info = template_registry.get(name)
         name = fix_length_format(name, name_fix_length)
         description = fix_length_format(template_info.get(JOB_INFO_DESC_KEY), description_fix_length)
-        client_category = fix_length_format(template_info.get(JOB_INFO_CLIENT_TYPE_KEY), client_category_fix_length)
+        execution_api_type = fix_length_format(
+            template_info.get(JOB_INFO_EXECUTION_API_TYPE_KEY), execution_api_type_fix_length
+        )
         controller_type = fix_length_format(template_info.get(JOB_INFO_CONTROLLER_TYPE_KEY), controller_type_fix_length)
-        print(" " * left_margin, name, description, controller_type, client_category)
+        print(" " * left_margin, name, description, controller_type, execution_api_type)
         print("-" * total_length)
diff --git a/nvflare/tool/job/job_client_const.py b/nvflare/tool/job/job_client_const.py
index 8a26e64efc..8c37df5673 100644
--- a/nvflare/tool/job/job_client_const.py
+++ b/nvflare/tool/job/job_client_const.py
@@ -16,8 +16,8 @@
 JOB_INFO_DESC = "Description"
 JOB_INFO_CONTROLLER_TYPE_KEY = "controller_type"
 JOB_INFO_CONTROLLER_TYPE = "Controller Type"
-JOB_INFO_CLIENT_TYPE_KEY = "client_category"
-JOB_INFO_CLIENT_TYPE = "Client Category"
+JOB_INFO_EXECUTION_API_TYPE_KEY = "execution_api_type"
+JOB_INFO_EXECUTION_API_TYPE = "Execution API Type"
 
 JOB_TEMPLATES = "job_templates"
 JOB_TEMPLATE = "job_template"
@@ -25,7 +25,7 @@
 JOB_INFO_CONF = "info.conf"
 JOB_INFO_MD = "info.md"
 
-JOB_INFO_KEYS = [JOB_INFO_DESC_KEY, JOB_INFO_CONTROLLER_TYPE_KEY, JOB_INFO_CLIENT_TYPE_KEY]
+JOB_INFO_KEYS = [JOB_INFO_DESC_KEY, JOB_INFO_CONTROLLER_TYPE_KEY, JOB_INFO_EXECUTION_API_TYPE_KEY]
 
 CONFIG_FILE_BASE_NAME_WO_EXTS = ["config_fed_client", "config_fed_server", "meta"]
 APP_CONFIG_FILE_BASE_NAMES = ["config_fed_client", "config_fed_server"]
diff --git a/nvflare/tool/poc/poc_commands.py b/nvflare/tool/poc/poc_commands.py
index 7809e64cd8..d1fcc5b274 100644
--- a/nvflare/tool/poc/poc_commands.py
+++ b/nvflare/tool/poc/poc_commands.py
@@ -164,19 +164,19 @@ def _prepare_jobs_dir(jobs_dir: str, workspace: str, config_packages: Optional[T
         dst = os.path.join(console_dir, transfer)
         if not is_dir_empty(dst):
             print(" ")
-            answer = input(f"Examples at {dst} is already exists, replace with new one ? (y/N) ")
+            answer = input(f"Job directory at {dst} already exists. Replace it with a new one? (y/N) ")
             if answer.strip().upper() == "Y":
                 if os.path.islink(dst):
                     os.unlink(dst)
                 if os.path.isdir(dst):
                     shutil.rmtree(dst, ignore_errors=True)
-                print(f"link examples from {src} to {dst}")
+                print(f"Linking job directory from {src} to {dst}")
                 os.symlink(src, dst)
         else:
             if os.path.isdir(dst):
                 shutil.rmtree(dst, ignore_errors=True)
-            print(f"link examples from {src} to {dst}")
+            print(f"Linking job directory from {src} to {dst}")
             os.symlink(src, dst)
diff --git a/research/auto-fed-rl/requirements.txt b/research/auto-fed-rl/requirements.txt
index 0804527963..3bbfea441b 100644
--- a/research/auto-fed-rl/requirements.txt
+++ b/research/auto-fed-rl/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 torch
 torchvision
 tensorboard
diff --git a/research/fed-ce/requirements.txt b/research/fed-ce/requirements.txt
index b7dd1625cf..5757f5b0ea 100644
--- a/research/fed-ce/requirements.txt
+++ b/research/fed-ce/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 torch
 torchvision
 tensorboard
diff --git a/research/fed-sm/requirements.txt b/research/fed-sm/requirements.txt
index 56ed083280..71feadf756 100644
--- a/research/fed-sm/requirements.txt
+++ b/research/fed-sm/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 websockets
 torch
 torchvision
diff --git a/research/one-shot-vfl/requirements.txt b/research/one-shot-vfl/requirements.txt
index 2f293a7a7c..0e0eac05ba 100644
--- a/research/one-shot-vfl/requirements.txt
+++ b/research/one-shot-vfl/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 torch
 torchvision
 tensorboard
diff --git a/research/quantifying-data-leakage/requirements.txt b/research/quantifying-data-leakage/requirements.txt
index 060b0dd44c..b362ba016b 100644
--- a/research/quantifying-data-leakage/requirements.txt
+++ b/research/quantifying-data-leakage/requirements.txt
@@ -1,4 +1,4 @@
-nvflare>=2.3.0
+nvflare~=2.4.0rc
 pytorch-ignite>=0.4.10
 torchvision
 monai>=1.0.1
diff --git a/runtest.sh b/runtest.sh
index e7644cf736..b7df3b6be1 100755
--- a/runtest.sh
+++ b/runtest.sh
@@ -92,7 +92,7 @@ function check_license() {
     folders_to_check_license="nvflare examples tests integration research"
     echo "checking license header in folder: $folders_to_check_license"
     (grep -r --include "*.py" --exclude-dir "*protos*" -L \
-        "\(# Copyright (c) \(2021\|2022\|2023\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \
+        "\(# Copyright (c) \(2021\|2022\|2023\|2024\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \
         ${folders_to_check_license} || true) > no_license.lst
     if [ -s no_license.lst ]; then
         # The file is not-empty.
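
Note on the requirements changes above: the research examples move from the open-ended pin nvflare>=2.3.0 to the PEP 440 compatible-release pin nvflare~=2.4.0rc, which only admits 2.4.* versions (release candidates included). A quick sketch of what that specifier accepts, using the third-party packaging library; the candidate version strings below are illustrative, not actual releases:

    # What "nvflare~=2.4.0rc" means under PEP 440: ">=2.4.0rc0, ==2.4.*".
    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=2.4.0rc", prereleases=True)
    for candidate in ["2.3.0", "2.4.0rc1", "2.4.0", "2.4.7", "2.5.0"]:
        print(candidate, "accepted" if candidate in spec else "rejected")
    # 2.3.0 and 2.5.0 are rejected; all 2.4.* versions, rc included, are accepted.
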
diff --git a/setup.cfg b/setup.cfg index 23868dbe05..7b1b6e834f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -46,6 +46,8 @@ PT = SKLEARN = scikit-learn TRACKING = + mlflow + wandb tensorboard CONFIG = omegaconf @@ -55,6 +57,8 @@ app_opt = %(PT)s %(SKLEARN)s %(TRACKING)s + pytorch_lightning + xgboost app_opt_mac = %(PT)s %(SKLEARN)s diff --git a/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml b/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml index e578f70420..f3a87da6d7 100644 --- a/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml +++ b/tests/integration_test/data/test_configs/standalone_job/hello_numpy_examples.yml @@ -43,3 +43,19 @@ tests: args: { server_model_names: ["server"] } - path: tests.integration_test.src.validators.NumpySAGResultValidator args: { expected_result: [ [ 4, 5, 6 ], [ 7, 8, 9 ], [ 10, 11, 12 ] ] } + - test_name: "run hello-ccwf" + # TODO: add a result validator for the "models" saved on client site (ccwf) + event_sequence: + - "trigger": + "type": "server_log" + "data": "Server started" + "actions": [ "submit_job hello-ccwf/jobs/swarm_cse_numpy" ] + "result": + "type": "job_submit_success" + - "trigger": + "type": "run_state" + "data": { "run_finished": True } + "actions": [ "ensure_current_job_done" ] + "result": + "type": "run_state" + "data": { "run_finished": True } diff --git a/tests/integration_test/data/test_configs/standalone_job/hello_tf_examples.yml b/tests/integration_test/data/test_configs/standalone_job/hello_tf_examples.yml index 7f104c77b0..39b1a59c3b 100644 --- a/tests/integration_test/data/test_configs/standalone_job/hello_tf_examples.yml +++ b/tests/integration_test/data/test_configs/standalone_job/hello_tf_examples.yml @@ -22,6 +22,10 @@ tests: "data": { "run_finished": True } validators: - path: tests.integration_test.src.validators.TFModelValidator + setup: + - python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()" + teardown: + - rm ~/.keras/datasets/mnist.npz - test_name: "run hello-tf2" event_sequence: - "trigger": @@ -39,3 +43,7 @@ tests: "data": { "run_finished": True } validators: - path: tests.integration_test.src.validators.TFModelValidator + setup: + - python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()" + teardown: + - rm ~/.keras/datasets/mnist.npz diff --git a/tests/integration_test/data/test_configs/standalone_job/internal_tf.yml b/tests/integration_test/data/test_configs/standalone_job/internal_tf.yml index 38bfb1641e..59daaf8dbe 100644 --- a/tests/integration_test/data/test_configs/standalone_job/internal_tf.yml +++ b/tests/integration_test/data/test_configs/standalone_job/internal_tf.yml @@ -22,3 +22,7 @@ tests: "data": { "run_finished": True } validators: - path: tests.integration_test.src.validators.TFModelValidator + setup: + - python -c "import tensorflow as tf; tf.keras.datasets.mnist.load_data()" + teardown: + - rm ~/.keras/datasets/mnist.npz diff --git a/tests/integration_test/data/test_configs/standalone_job/xgb_histogram_examples.yml b/tests/integration_test/data/test_configs/standalone_job/xgb_histogram_examples.yml new file mode 100644 index 0000000000..38f3cd71ec --- /dev/null +++ b/tests/integration_test/data/test_configs/standalone_job/xgb_histogram_examples.yml @@ -0,0 +1,44 @@ +n_servers: 1 +n_clients: 2 +additional_python_paths: +- ../../examples/advanced/xgboost +cleanup: true +jobs_root_dir: ../../examples/advanced/xgboost/histogram-based/jobs + + +tests: +- test_name: Test a 
simplified copy of job higgs_2_histogram_uniform_split_uniform_lr + for xgboost histogram-based example. + event_sequence: + - actions: + - submit_job higgs_2_histogram_uniform_split_uniform_lr_copy + result: + type: job_submit_success + trigger: + data: Server started + type: server_log + - actions: + - ensure_current_job_done + result: + data: + run_finished: true + type: run_state + trigger: + data: + run_finished: true + type: run_state + setup: + - cp ../../examples/advanced/xgboost/histogram-based/requirements.txt + ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - sed -i '/nvflare\|jupyter\|notebook/d' ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - pip install -r ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + - bash ../../examples/advanced/xgboost/histogram-based/prepare_data.sh + - python3 ../../examples/advanced/xgboost/utils/prepare_job_config.py --site_num 2 --training_mode histogram + --split_method uniform --lr_mode uniform --nthread 16 --tree_method hist + - python3 convert_to_test_job.py + --job ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_uniform_split_uniform_lr + --post _copy + - rm -f ../../examples/advanced/xgboost/histogram-based/temp_requirements.txt + teardown: + - rm -rf ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_uniform_split_uniform_lr + - rm -rf ../../examples/advanced/xgboost/histogram-based/jobs/higgs_2_histogram_uniform_split_uniform_lr_copy diff --git a/tests/integration_test/data/test_configs/standalone_job/xgb_tree_examples.yml b/tests/integration_test/data/test_configs/standalone_job/xgb_tree_examples.yml new file mode 100644 index 0000000000..145b6aecf0 --- /dev/null +++ b/tests/integration_test/data/test_configs/standalone_job/xgb_tree_examples.yml @@ -0,0 +1,75 @@ +n_servers: 1 +n_clients: 5 +additional_python_paths: +- ../../examples/advanced/xgboost +cleanup: true +jobs_root_dir: ../../examples/advanced/xgboost/tree-based/jobs + + +tests: +- test_name: Test a simplified copy of job higgs_5_cyclic_uniform_split_uniform_lr + for xgboost tree-based example. 
+  event_sequence:
+  - actions:
+    - submit_job higgs_5_cyclic_uniform_split_uniform_lr_copy
+    result:
+      type: job_submit_success
+    trigger:
+      data: Server started
+      type: server_log
+  - actions:
+    - ensure_current_job_done
+    result:
+      data:
+        run_finished: true
+      type: run_state
+    trigger:
+      data:
+        run_finished: true
+      type: run_state
+  setup:
+  - cp ../../examples/advanced/xgboost/tree-based/requirements.txt
+    ../../examples/advanced/xgboost/tree-based/temp_requirements.txt
+  - sed -i '/nvflare\|jupyter\|notebook/d' ../../examples/advanced/xgboost/tree-based/temp_requirements.txt
+  - pip install -r ../../examples/advanced/xgboost/tree-based/temp_requirements.txt
+  - bash ../../examples/advanced/xgboost/tree-based/prepare_data.sh
+  - python3 ../../examples/advanced/xgboost/utils/prepare_job_config.py --site_num 5 --training_mode cyclic
+    --split_method uniform --lr_mode uniform --nthread 16 --tree_method hist
+  - python3 convert_to_test_job.py
+    --job ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_cyclic_uniform_split_uniform_lr
+    --post _copy
+  - rm -f ../../examples/advanced/xgboost/tree-based/temp_requirements.txt
+  teardown:
+  - rm -rf ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_cyclic_uniform_split_uniform_lr
+  - rm -rf ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_cyclic_uniform_split_uniform_lr_copy
+
+- test_name: Test a simplified copy of job higgs_5_bagging_uniform_split_uniform_lr
+  for xgboost tree-based example.
+  event_sequence:
+  - actions:
+    - submit_job higgs_5_bagging_uniform_split_uniform_lr_copy
+    result:
+      type: job_submit_success
+    trigger:
+      data: Server started
+      type: server_log
+  - actions:
+    - ensure_current_job_done
+    result:
+      data:
+        run_finished: true
+      type: run_state
+    trigger:
+      data:
+        run_finished: true
+      type: run_state
+  setup:
+  - python3 ../../examples/advanced/xgboost/utils/prepare_job_config.py --site_num 5 --training_mode bagging
+    --split_method uniform --lr_mode uniform --nthread 16 --tree_method hist
+  - python3 convert_to_test_job.py
+    --job ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_bagging_uniform_split_uniform_lr
+    --post _copy
+  - rm -f ../../examples/advanced/xgboost/tree-based/temp_requirements.txt
+  teardown:
+  - rm -rf ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_bagging_uniform_split_uniform_lr
+  - rm -rf ../../examples/advanced/xgboost/tree-based/jobs/higgs_5_bagging_uniform_split_uniform_lr_copy
diff --git a/tests/integration_test/src/example.py b/tests/integration_test/src/example.py
index f46d01b53c..58dbeeea35 100644
--- a/tests/integration_test/src/example.py
+++ b/tests/integration_test/src/example.py
@@ -17,7 +17,7 @@
 
 class Example:
-    """This class represents a standardized example structure in NVFlare."""
+    """This class represents a standardized example folder structure in NVFlare."""
 
     def __init__(
         self,
@@ -27,9 +27,37 @@ def __init__(
         additional_python_path: Optional[str] = None,
         prepare_data_script: Optional[str] = None,
     ):
+        """Constructor of Example.
+
+        A standardized example folder looks like the following:
+
+        .. code-block::
+
+            ./[example_root]
+                ./[jobs_folder_in_example]
+                    ./job_name1
+                    ./job_name2
+                    ./job_name3
+                ./[requirements]
+                ./[prepare_data_script]
+
+        For example:
+
+        .. code-block::
+
+            ./cifar10-sim
+                ./jobs
+                    ./cifar10_central
+                    ./cifar10_fedavg
+                    ./cifar10_fedopt
+                    ...
+ ./requirements.txt + ./prepare_data.sh + + """ self.root = os.path.abspath(root) if not os.path.exists(self.root): - raise FileNotFoundError("Example root directory does not exist.") + raise FileNotFoundError("Example's root directory does not exist.") self.name = os.path.basename(self.root) diff --git a/tests/integration_test/src/utils.py b/tests/integration_test/src/utils.py index ac0d09291e..0a945607c9 100644 --- a/tests/integration_test/src/utils.py +++ b/tests/integration_test/src/utils.py @@ -292,6 +292,7 @@ def _replace_config_fed_client(client_json_path: str): with open(client_json_path, "r+") as f: config_fed_client = json.load(f) config_fed_client["TRAIN_SPLIT_ROOT"] = "/tmp/nvflare/test_data" + config_fed_client["num_rounds"] = 2 config_fed_client["AGGREGATION_EPOCHS"] = 1 f.seek(0) json.dump(config_fed_client, f, indent=4) @@ -318,67 +319,84 @@ def simplify_job(job_folder_path: str, postfix: str = POSTFIX): def generate_test_config_yaml_for_example( example: Example, project_yaml: str = PROJECT_YAML, - postfix: str = POSTFIX, + job_postfix: str = POSTFIX, ) -> List[str]: - """Generates test configuration yaml for NVFlare example. + """Generates test configurations for an NVFlare example folder. Args: - example: A well-formatted NVFlare example. - project_yaml: Project yaml file for the testing of this example. - postfix: Postfix for the newly generated job. + example (Example): A well-formatted NVFlare example folder. + project_yaml (str): Project yaml file for the testing of this example. + job_postfix (str): Postfix for the newly generated job. """ - output_yamls = [] os.makedirs(OUTPUT_YAML_DIR, exist_ok=True) for job in os.listdir(example.jobs_root_dir): - output_yaml = os.path.join(OUTPUT_YAML_DIR, f"{example.name}_{job}.yml") - job_dir = os.path.join(example.jobs_root_dir, job) - requirements_file = os.path.join(example.root, example.requirements_file) - new_requirements_file = os.path.join(example.root, "temp_requirements.txt") - exclude_requirements = "\\|".join(REQUIREMENTS_TO_EXCLUDE) - - setup = [ - f"cp {requirements_file} {new_requirements_file}", - f"sed -i '/{exclude_requirements}/d' {new_requirements_file}", - f"pip install -r {new_requirements_file}", - ] - if example.prepare_data_script is not None: - setup.append(f"bash {example.prepare_data_script}") - setup.append(f"python convert_to_test_job.py --job {job_dir} --post {postfix}") - setup.append(f"rm -f {new_requirements_file}") - - config = { - "ha": True, - "jobs_root_dir": example.jobs_root_dir, - "cleanup": True, - "project_yaml": project_yaml, - "additional_python_paths": example.additional_python_paths, - "tests": [ - { - "test_name": f"Test a simplified copy of job {job} for example {example.name}.", - "event_sequence": [ - { - "trigger": {"type": "server_log", "data": "Server started"}, - "actions": [f"submit_job {job}{postfix}"], - "result": {"type": "job_submit_success"}, - }, - { - "trigger": {"type": "run_state", "data": {"run_finished": True}}, - "actions": ["ensure_current_job_done"], - "result": {"type": "run_state", "data": {"run_finished": True}}, - }, - ], - "setup": setup, - "teardown": [f"rm -rf {job_dir}{postfix}"], - } - ], - } - with open(output_yaml, "w") as yaml_file: - yaml.dump(config, yaml_file, default_flow_style=False) + output_yaml = _generate_test_config_for_one_job(example, job, project_yaml, job_postfix) output_yamls.append(output_yaml) return output_yamls +def _generate_test_config_for_one_job( + example: Example, + job: str, + project_yaml: str = PROJECT_YAML, + postfix: 
str = POSTFIX, +) -> str: + """Generates test configuration yaml for an NVFlare example. + + Args: + example (Example): A well-formatted NVFlare example. + job (str): name of the job. + project_yaml (str): Project yaml file for the testing of this example. + postfix (str): Postfix for the newly generated job. + """ + output_yaml = os.path.join(OUTPUT_YAML_DIR, f"{example.name}_{job}.yml") + job_dir = os.path.join(example.jobs_root_dir, job) + requirements_file = os.path.join(example.root, example.requirements_file) + new_requirements_file = os.path.join(example.root, "temp_requirements.txt") + exclude_requirements = "\\|".join(REQUIREMENTS_TO_EXCLUDE) + + setup = [ + f"cp {requirements_file} {new_requirements_file}", + f"sed -i '/{exclude_requirements}/d' {new_requirements_file}", + f"pip install -r {new_requirements_file}", + ] + if example.prepare_data_script is not None: + setup.append(f"bash {example.prepare_data_script}") + setup.append(f"python convert_to_test_job.py --job {job_dir} --post {postfix}") + setup.append(f"rm -f {new_requirements_file}") + + config = { + "ha": True, + "jobs_root_dir": example.jobs_root_dir, + "cleanup": True, + "project_yaml": project_yaml, + "additional_python_paths": example.additional_python_paths, + "tests": [ + { + "test_name": f"Test a simplified copy of job {job} for example {example.name}.", + "event_sequence": [ + { + "trigger": {"type": "server_log", "data": "Server started"}, + "actions": [f"submit_job {job}{postfix}"], + "result": {"type": "job_submit_success"}, + }, + { + "trigger": {"type": "run_state", "data": {"run_finished": True}}, + "actions": ["ensure_current_job_done"], + "result": {"type": "run_state", "data": {"run_finished": True}}, + }, + ], + "setup": setup, + "teardown": [f"rm -rf {job_dir}{postfix}"], + } + ], + } + with open(output_yaml, "w") as yaml_file: + yaml.dump(config, yaml_file, default_flow_style=False) + return output_yaml + + def _read_admin_json_file(admin_json_file) -> dict: if not os.path.exists(admin_json_file): raise RuntimeError("Missing admin json file.") diff --git a/tests/integration_test/src/validators/np_sag_result_validator.py b/tests/integration_test/src/validators/np_sag_result_validator.py index 0e8690abb1..bbf6c94d63 100644 --- a/tests/integration_test/src/validators/np_sag_result_validator.py +++ b/tests/integration_test/src/validators/np_sag_result_validator.py @@ -20,9 +20,10 @@ class NumpySAGResultValidator(FinishJobResultValidator): - def __init__(self, expected_result): + def __init__(self, expected_result, model_name: str = "server.npy"): super().__init__() self.expected_result = np.array(expected_result) + self.model_name = model_name def validate_finished_results(self, job_result, client_props) -> bool: server_run_dir = job_result["workspace_root"] @@ -32,7 +33,7 @@ def validate_finished_results(self, job_result, client_props) -> bool: self.logger.error(f"models dir {models_dir} doesn't exist.") return False - model_path = os.path.join(models_dir, "server.npy") + model_path = os.path.join(models_dir, self.model_name) if not os.path.isfile(model_path): self.logger.error(f"model_path {model_path} doesn't exist.") return False diff --git a/tests/integration_test/test_configs.yml b/tests/integration_test/test_configs.yml index a179a6ce79..d75f59a784 100644 --- a/tests/integration_test/test_configs.yml +++ b/tests/integration_test/test_configs.yml @@ -29,3 +29,6 @@ test_configs: - ./data/test_configs/standalone_job/cifar_examples.yml stats: - ./data/test_configs/standalone_job/image_stats.yml + 
xgboost: + - ./data/test_configs/standalone_job/xgb_histogram_examples.yml + - ./data/test_configs/standalone_job/xgb_tree_examples.yml diff --git a/tests/unit_test/fuel/data_event/__init__.py b/tests/unit_test/fuel/data_event/__init__.py new file mode 100644 index 0000000000..d9155f923f --- /dev/null +++ b/tests/unit_test/fuel/data_event/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/unit_test/fuel/data_event/data_bus_test.py b/tests/unit_test/fuel/data_event/data_bus_test.py new file mode 100644 index 0000000000..4979688d13 --- /dev/null +++ b/tests/unit_test/fuel/data_event/data_bus_test.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +from nvflare.fuel.data_event.data_bus import DataBus +from nvflare.fuel.data_event.event_manager import EventManager + + +class TestMessageBus(unittest.TestCase): + def setUp(self): + self.data_bus = DataBus() + self.event_manager = EventManager(self.data_bus) + + def test_subscribe_and_publish(self): + result = {"count": 0} + + def callback_function(topic, datum, data_bus): + result["count"] += 1 + + self.data_bus.subscribe(["test_topic"], callback_function) + self.data_bus.publish(["test_topic"], "Test Message 1") + self.data_bus.publish(["test_topic"], "Test Message 2") + + self.assertEqual(result["count"], 2) + + def test_singleton_message_bus(self): + data_bus1 = DataBus() + data_bus1.put_data("user_1", "Hello from User 1!") + user_1_message = data_bus1.get_data("user_1") + self.assertEqual(user_1_message, "Hello from User 1!") + + message_bus2 = DataBus() + user_1_message = message_bus2.get_data("user_1") + self.assertEqual(user_1_message, "Hello from User 1!") + + def test_send_message_and_receive_messages(self): + self.data_bus.put_data("user_1", "Hello from User 1!") + self.data_bus.put_data("user_2", "Greetings from User 2!") + + user_1_message = self.data_bus.get_data("user_1") + user_2_message = self.data_bus.get_data("user_2") + + self.assertEqual(user_1_message, "Hello from User 1!") + self.assertEqual(user_2_message, "Greetings from User 2!") + + self.data_bus.put_data("user_1", "2nd greetings from User 1!") + user_1_message = self.data_bus.get_data("user_1") + self.assertEqual(user_1_message, "2nd greetings from User 1!") + + def test_send_message_and_receive_messages_abnormal(self): + user_3_message = self.data_bus.get_data("user_3") + self.assertEqual(user_3_message, None) + + def test_fire_event(self): + + result = { + "test_event": {"event_received": False}, + "dev_event": {"event_received": False}, + "prod_event": {"event_received": False}, + } + + def event_handler(topic, data, data_bus): + result[topic]["event_received"] = True + if data_bus.get_data("hi") == "hello": + self.data_bus.put_data("hi", "hello-world") + + self.data_bus.put_data("hi", "hello") + + self.data_bus.subscribe(["test_event", "dev_event", "prod_event"], event_handler) + self.event_manager.fire_event("test_event", {"key": "value"}) + self.event_manager.fire_event("dev_event", {"key": "value"}) + + self.assertTrue(result["test_event"]["event_received"]) + self.assertTrue(result["dev_event"]["event_received"]) + self.assertFalse(result["prod_event"]["event_received"])
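
The new DataBus tests above exercise a small API surface: singleton construction, a shared key/value store (put_data/get_data), topic subscriptions whose callbacks take (topic, data, data_bus), and EventManager.fire_event to publish on a topic. A minimal usage sketch assuming only that tested behavior; the topic name and payload are illustrative, not part of the tested API:

    from nvflare.fuel.data_event.data_bus import DataBus
    from nvflare.fuel.data_event.event_manager import EventManager

    bus = DataBus()                   # singleton: every DataBus() returns the same instance
    bus.put_data("model_version", 3)  # shared key/value storage

    def on_metrics(topic, data, data_bus):
        # Callback signature as exercised by the tests: (topic, data, data_bus).
        print(topic, data, "model_version =", data_bus.get_data("model_version"))

    bus.subscribe(["metrics"], on_metrics)                       # subscribe takes a list of topics
    EventManager(bus).fire_event("metrics", {"accuracy": 0.91})  # delivered to on_metrics
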