Skip to content

Commit

Permalink
add to() routine
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 5, 2024
1 parent 947be32 commit d2d1cbc
Showing 1 changed file with 32 additions and 29 deletions.
61 changes: 32 additions & 29 deletions examples/hello-world/jobs/client-api-pt/client_api_pt_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import uuid
from typing import Union

from nvflare.app_common.job.fed_app_config import ClientAppConfig, ServerAppConfig, FedAppConfig
from nvflare.app_common.job.fed_job_config import FedJobConfig
Expand All @@ -30,42 +31,33 @@
from net import Net
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.app_common.job.fed_app_config import ClientAppConfig, ServerAppConfig

# TODO: FedApp -> server, FedApp -> client


class FedJob:
def __init__(self, job_name="client-api-pt", workspace="/tmp/nvflare/simulator_workspace") -> None:
self.job_name = job_name
def __init__(self, name="client-api-pt", workspace="/tmp/nvflare/simulator_workspace") -> None:
self.job_name = name
self.job_id = str(uuid.uuid4())
self.workspace = workspace
self.root_url = ""

self.job = self.define_job()

def define_job(self) -> FedJobConfig:
# job = FedJobConfig(job_name="hello-pt", min_clients=2, mandatory_clients="site-1")
job: FedJobConfig = FedJobConfig(job_name=self.job_name, min_clients=2)

# TODO: implement in .to() call
server_app = self._create_server_app(min_clients=2, num_rounds=2)
app = FedAppConfig(server_app=server_app, client_app=None)
job.add_fed_app("server", app)
job.set_site_app("server", "server")

client_app = self._create_client_app(site_name="site-1", app_script="cifar10_fl.py")
app = FedAppConfig(server_app=None, client_app=client_app)
job.add_fed_app("app1", app)
job.set_site_app("site-1", "app1")

client_app = self._create_client_app(site_name="site-2", app_script="cifar10_fl.py")
app = FedAppConfig(server_app=None, client_app=client_app)
job.add_fed_app("app2", app)
job.set_site_app("site-2", "app2")

return job

def _create_client_app(self, site_name, app_script, app_config=""):
#self.job = self.define_job()
self.job: FedJobConfig = FedJobConfig(job_name=self.job_name, min_clients=1)

def to(self, app: Union[ClientAppConfig, ServerAppConfig], target):
if isinstance(app, ClientAppConfig):
app_config = FedAppConfig(server_app=None, client_app=app)
elif isinstance(app, ServerAppConfig):
app_config = FedAppConfig(server_app=app, client_app=None)
else:
raise ValueError(f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(app)}")
app_name = f"app_{(uuid.uuid4())}"
self.job.add_fed_app(app_name, app_config)
self.job.set_site_app(target, app_name)

def create_client_app(self, site_name, app_script, app_config=""):
client_app = ClientAppConfig()
executor = PTClientAPILauncherExecutor(
launcher_id="launcher",
Expand Down Expand Up @@ -109,7 +101,7 @@ def _create_client_app(self, site_name, app_script, app_config=""):

return client_app

def _create_server_app(self, min_clients, num_rounds, model_class_path="net.Net"):
def create_server_app(self, min_clients, num_rounds, model_class_path="net.Net"):
server_app = ServerAppConfig()
controller = FedAvg(
min_clients=min_clients,
Expand Down Expand Up @@ -157,7 +149,18 @@ def simulator_run(self, job_root):


if __name__ == "__main__":
job = FedJob()
n_clients = 2

job = FedJob(name="cifar10_fedavg")

server_app = job.create_server_app(min_clients=2, num_rounds=2)
job.to(server_app, "server")

for i in range(n_clients):
client_app = job.create_client_app(site_name=f"site-{i}", app_script="cifar10_fl.py") # TODO: don't require site_name here
job.to(client_app, f"site-{i}")

#job.export_job("/tmp/nvflare/jobs")
job.simulator_run("/tmp/nvflare/jobs")


0 comments on commit d2d1cbc

Please sign in to comment.