Skip to content

Commit

Permalink
hide ControllerApp/ExecutorApp
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 16, 2024
1 parent 36b59c7 commit ae54607
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 0 deletions.
43 changes: 43 additions & 0 deletions examples/hello-world/python_jobs/pt/client_api_pt_job2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from code.net import Net

from nvflare import FedJob2
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.in_process_client_api_executor import PTInProcessClientAPIExecutor


if __name__ == "__main__":
n_clients = 2
num_rounds = 2
train_script = "code/cifar10_fl.py"

job = FedJob2(name="cifar10_fedavg", init_model=Net(), external_scripts=[train_script]) # TODO: use load/save model in FedAvg

controller = FedAvg(
min_clients=n_clients,
num_rounds=num_rounds,
)
job.to(controller, "server")

for i in range(n_clients):
executor = PTInProcessClientAPIExecutor(
task_script_path=train_script,
task_script_args="" # --batch_size 32 --data_path f'/tmp/data/site-{i}'
)
job.to(executor, f"site-{i}", gpu=0)

job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir")
1 change: 1 addition & 0 deletions nvflare/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@

from nvflare.private.fed.app.simulator.simulator_runner import SimulatorRunner as SimulatorRunner
from nvflare.fed_job import FedJob, ControllerApp, ExecutorApp
from nvflare.fed_job2 import FedJob2
105 changes: 105 additions & 0 deletions nvflare/fed_job2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid
from typing import List, Any

from nvflare.app_common.widgets.external_configurator import ExternalConfigurator
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.widgets.metric_relay import MetricRelay
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_opt.pt import PTFileModelPersistor
from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.pipe.file_pipe import FilePipe
from nvflare.job_config.fed_app_config import ClientAppConfig, FedAppConfig, ServerAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig
from nvflare.apis.filter import Filter
from .fed_job import FedApp, ControllerApp, ExecutorApp
from nvflare.apis.impl.controller import Controller
from nvflare.apis.executor import Executor


class FedJob2:
def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None, init_model=None, external_scripts: List = None) -> None:
self.job_name = name
self.gpus = []
self.clients = []
self.job: FedJobConfig = FedJobConfig(job_name=self.job_name, min_clients=min_clients, mandatory_clients=mandatory_clients)
self.init_model = init_model
self.external_scripts = external_scripts
self._deploy_map = {}

def to(self, obj: Any, target: str, tasks: List[str] = None, gpu: int = None):
if isinstance(obj, Controller):
if target not in self._deploy_map:
self._deploy_map[target] = ControllerApp(init_model=self.init_model)
self._deploy_map[target].add_controller(obj)
elif isinstance(obj, Executor):
if target not in self._deploy_map:
self._deploy_map[target] = ExecutorApp(external_scripts=self.external_scripts)
self.clients.append(target)
if gpu is not None:
self.gpus.append(str(gpu))
self._deploy_map[target].add_executor(obj, tasks=tasks)
else:
if target not in self._deploy_map:
raise ValueError(f"{target} doesn't have a `Controller` or `Executor`. Deploy one first before adding components!")
self._deploy_map[target].add_component(obj)

def _deploy(self, app: FedApp, target: str):
if not isinstance(app, FedApp):
raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}")

client_server_config = app.get_app_config()
if isinstance(client_server_config, ClientAppConfig):
app_config = FedAppConfig(server_app=None, client_app=client_server_config)
app_name = f"app_client_{uuid.uuid4()}"
elif isinstance(client_server_config, ServerAppConfig):
app_config = FedAppConfig(server_app=client_server_config, client_app=None)
app_name = f"app_server_{uuid.uuid4()}"
else:
raise ValueError(
f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}"
)

self.job.add_fed_app(app_name, app_config)
self.job.set_site_app(target, app_name)

def _run_deploy(self):
for target in self._deploy_map:
self._deploy(self._deploy_map[target], target)

def export_job(self, job_root):
self._run_deploy()
self.job.generate_job_config(job_root)

def simulator_run(self, workspace, threads: int = None):
self._run_deploy()

n_clients = len(self.clients)
if threads is None:
threads = n_clients

job_root = os.path.join(workspace, "job")
self.job.simulator_run(
job_root,
workspace,
clients=",".join(self.clients),
n_clients=n_clients,
threads=threads,
gpu=",".join(self.gpus),
)

0 comments on commit ae54607

Please sign in to comment.