Skip to content

Commit

Permalink
some redesign
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 29, 2024
1 parent 14bd5e7 commit f157618
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions examples/hello-world/python_jobs/pt/client_api_pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,23 @@
num_rounds = 2
train_script = "code/cifar10_fl.py"

job = FedJob(name="cifar10_fedavg", init_model=Net(), external_scripts=[train_script]) # TODO: use load/save model in FedAvg
job = FedJob(name="cifar10_fedavg") #, init_model=Net(), external_scripts=[train_script]) # TODO: use load/save model in FedAvg

controller = FedAvg(
min_clients=n_clients,
num_rounds=num_rounds,
)
job.to(controller, "server")

# Define the initial server model
job.to(Net(), "server") # TODO: default to PTFileModelPersistor -> persistor

for i in range(n_clients):
executor = ScriptExecutor(
task_script_path=train_script,
task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}"
)
job.to(executor, f"site-{i}", gpu=0)
job.to(executor, f"site-{i}", gpu=0) # TODO: get script path from executor

job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir")

0 comments on commit f157618

Please sign in to comment.