forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
36b59c7
commit ae54607
Showing
3 changed files
with
149 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from code.net import Net | ||
|
||
from nvflare import FedJob2 | ||
from nvflare.app_common.workflows.fedavg import FedAvg | ||
from nvflare.app_opt.pt.in_process_client_api_executor import PTInProcessClientAPIExecutor | ||
|
||
|
||
if __name__ == "__main__": | ||
n_clients = 2 | ||
num_rounds = 2 | ||
train_script = "code/cifar10_fl.py" | ||
|
||
job = FedJob2(name="cifar10_fedavg", init_model=Net(), external_scripts=[train_script]) # TODO: use load/save model in FedAvg | ||
|
||
controller = FedAvg( | ||
min_clients=n_clients, | ||
num_rounds=num_rounds, | ||
) | ||
job.to(controller, "server") | ||
|
||
for i in range(n_clients): | ||
executor = PTInProcessClientAPIExecutor( | ||
task_script_path=train_script, | ||
task_script_args="" # --batch_size 32 --data_path f'/tmp/data/site-{i}' | ||
) | ||
job.to(executor, f"site-{i}", gpu=0) | ||
|
||
job.export_job("/tmp/nvflare/jobs/job_config") | ||
job.simulator_run("/tmp/nvflare/jobs/workdir") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import uuid | ||
from typing import List, Any | ||
|
||
from nvflare.app_common.widgets.external_configurator import ExternalConfigurator | ||
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector | ||
from nvflare.app_common.widgets.metric_relay import MetricRelay | ||
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator | ||
from nvflare.app_opt.pt import PTFileModelPersistor | ||
from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator | ||
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver | ||
from nvflare.fuel.utils.constants import Mode | ||
from nvflare.fuel.utils.pipe.file_pipe import FilePipe | ||
from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig | ||
from nvflare.job_config.fed_job_config import FedJobConfig | ||
from nvflare.apis.filter import Filter | ||
from .fed_job import FedApp, ControllerApp, ExecutorApp | ||
from nvflare.apis.impl.controller import Controller | ||
from nvflare.apis.executor import Executor | ||
|
||
|
||
class FedJob2: | ||
def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None, init_model=None, external_scripts: List = None) -> None: | ||
self.job_name = name | ||
self.gpus = [] | ||
self.clients = [] | ||
self.job: FedJobConfig = FedJobConfig(job_name=self.job_name, min_clients=min_clients, mandatory_clients=mandatory_clients) | ||
self.init_model = init_model | ||
self.external_scripts = external_scripts | ||
self._deploy_map = {} | ||
|
||
def to(self, obj: Any, target: str, tasks: List[str] = None, gpu: int = None): | ||
if isinstance(obj, Controller): | ||
if target not in self._deploy_map: | ||
self._deploy_map[target] = ControllerApp(init_model=self.init_model) | ||
self._deploy_map[target].add_controller(obj) | ||
elif isinstance(obj, Executor): | ||
if target not in self._deploy_map: | ||
self._deploy_map[target] = ExecutorApp(external_scripts=self.external_scripts) | ||
self.clients.append(target) | ||
if gpu is not None: | ||
self.gpus.append(str(gpu)) | ||
self._deploy_map[target].add_executor(obj, tasks=tasks) | ||
else: | ||
if target not in self._deploy_map: | ||
raise ValueError(f"{target} doesn't have a `Controller` or `Executor`. Deploy one first before adding components!") | ||
self._deploy_map[target].add_component(obj) | ||
|
||
def _deploy(self, app: FedApp, target: str): | ||
if not isinstance(app, FedApp): | ||
raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}") | ||
|
||
client_server_config = app.get_app_config() | ||
if isinstance(client_server_config, ClientAppConfig): | ||
app_config = FedAppConfig(server_app=None, client_app=client_server_config) | ||
app_name = f"app_client_{uuid.uuid4()}" | ||
elif isinstance(client_server_config, ServerAppConfig): | ||
app_config = FedAppConfig(server_app=client_server_config, client_app=None) | ||
app_name = f"app_server_{uuid.uuid4()}" | ||
else: | ||
raise ValueError( | ||
f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}" | ||
) | ||
|
||
self.job.add_fed_app(app_name, app_config) | ||
self.job.set_site_app(target, app_name) | ||
|
||
def _run_deploy(self): | ||
for target in self._deploy_map: | ||
self._deploy(self._deploy_map[target], target) | ||
|
||
def export_job(self, job_root): | ||
self._run_deploy() | ||
self.job.generate_job_config(job_root) | ||
|
||
def simulator_run(self, workspace, threads: int = None): | ||
self._run_deploy() | ||
|
||
n_clients = len(self.clients) | ||
if threads is None: | ||
threads = n_clients | ||
|
||
job_root = os.path.join(workspace, "job") | ||
self.job.simulator_run( | ||
job_root, | ||
workspace, | ||
clients=",".join(self.clients), | ||
n_clients=n_clients, | ||
threads=threads, | ||
gpu=",".join(self.gpus), | ||
) |