add cyclic workflow
holgerroth committed Apr 19, 2024
1 parent 58d9583 commit abb1144
Showing 3 changed files with 55 additions and 15 deletions.
54 changes: 54 additions & 0 deletions examples/hello-world/python_jobs/pt/client_api_pt_cyclic_cc.py
@@ -0,0 +1,54 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from code.net import Net

from nvflare import FedJob, ScriptExecutor
from nvflare.app_common.ccwf import CyclicClientController, CyclicServerController
from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor


if __name__ == "__main__":
n_clients = 2
num_rounds = 3
train_script = "code/cifar10_fl.py"

    job = FedJob(name="cifar10_cyclic", init_model=Net(), external_scripts=[train_script])  # TODO: use load/save model in FedAvg

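    # The cyclic workflow relays the model sequentially from one client to the
    # next each round (relay-style training) instead of aggregating client
    # updates in parallel as FedAvg does.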
controller = CyclicServerController(
num_rounds=num_rounds,
max_status_report_interval=300
)
job.to(controller, "server")

for i in range(n_clients):
executor = ScriptExecutor(
task_script_path=train_script,
task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
)
job.to(executor, f"site-{i}", gpu=0)

        # Add the client-side controller that executes the cyclic workflow tasks
        client_controller = CyclicClientController()
        job.to(client_controller, f"site-{i}", tasks=["cyclic_*"])

        # In client-controlled workflows such as cyclic, each client needs its own
        # persistor (to save/load the local PyTorch model) and shareable generator
        # (to convert between model and shareable formats)
job.to(PTFileModelPersistor(model=Net()), f"site-{i}", id="persistor")
job.to(SimpleModelShareableGenerator(), f"site-{i}", id="shareable_generator")

job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir")
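For reference, the new script both exports the job configuration and starts a local simulator run, so it can be launched directly (a minimal usage sketch, assuming an environment with nvflare and torch installed and the example's code/ directory on the import path):

    python client_api_pt_cyclic_cc.py

Per the paths in the script, the job configuration is written to /tmp/nvflare/jobs/job_config and the simulator workspace to /tmp/nvflare/jobs/workdir.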
14 changes: 0 additions & 14 deletions examples/hello-world/python_jobs/pt/client_api_pt_swarm.py
@@ -21,8 +21,6 @@
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
 from nvflare.app_common.aggregators.intime_accumulate_model_aggregator import InTimeAccumulateWeightedAggregator
 from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator
-from nvflare.app_common.launchers.subprocess_launcher import SubprocessLauncher
-from nvflare.app_opt.pt.client_api_launcher_executor import PTClientAPILauncherExecutor
 from nvflare.fuel.utils.constants import Mode
 from nvflare.fuel.utils.pipe.file_pipe import FilePipe

@@ -49,18 +47,6 @@
         )
         job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"])
 
-        # executor = PTClientAPILauncherExecutor(
-        #     launcher_id="launcher",
-        #     pipe_id="pipe"
-        # )
-        # job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"])
-        #
-        # launcher = SubprocessLauncher(
-        #     script=train_script,
-        #     launch_once=True
-        # )
-        # job.to(launcher, f"site-{i}", id="launcher")
-
         pipe = FilePipe(
             mode=Mode.PASSIVE,
             root_path="{WORKSPACE}/{JOB_ID}/{SITE_NAME}",
2 changes: 1 addition & 1 deletion nvflare/fed_job.py
@@ -78,7 +78,7 @@ def to(self, obj: Any, target: str, tasks: List[str] = None, gpu: int = None, fi
             if target not in self._deploy_map:
                 self._deploy_map[target] = ExecutorApp(external_scripts=self._external_scripts)
                 self.clients.append(target)
-            if gpu is not None:
+            if gpu is not None:  # TODO: make sure GPUs are not added several times per client. Use dict as well?
                 self.gpus.append(str(gpu))
             self._deploy_map[target].add_executor(obj, tasks=tasks)
         else:
else:
