Skip to content

Commit

Permalink
add FedApp
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 5, 2024
1 parent 6d9a56e commit ad945ec
Showing 1 changed file with 120 additions and 100 deletions.
220 changes: 120 additions & 100 deletions examples/hello-world/jobs/client-api-pt/client_api_pt_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import uuid
from typing import Union
from typing import Union, List

from nvflare.app_common.job.fed_app_config import ClientAppConfig, ServerAppConfig, FedAppConfig
from nvflare.app_common.job.fed_job_config import FedJobConfig
Expand All @@ -33,24 +33,34 @@
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.app_common.job.fed_app_config import ClientAppConfig, ServerAppConfig

# TODO: FedApp -> server, FedApp -> client

class FedApp:
    """Minimal holder for one app configuration.

    The ``app`` attribute starts out as ``None``; code that builds a concrete
    configuration is expected to assign it afterwards.
    """

    def __init__(self):
        # Intended to hold a Union[ClientAppConfig, ServerAppConfig] once populated.
        self.app = None

    def get_app_config(self):
        """Return the object currently stored in ``self.app`` (``None`` if unset)."""
        return self.app


class FedJob:
def __init__(self, name="client-api-pt", workspace="/tmp/nvflare/simulator_workspace") -> None:
    """Record the job name and workspace and create the backing ``FedJobConfig``.

    Args:
        name: job name, stored as ``self.job_name`` and passed to ``FedJobConfig``.
        workspace: stored as ``self.workspace`` for later use.
    """
    self.root_url = ""
    self.workspace = workspace
    self.job_name = name

    # The job config is always created with min_clients=1.
    job_config = FedJobConfig(job_name=name, min_clients=1)
    self.job: FedJobConfig = job_config

def to(self, app: FedApp, target):
    """Register ``app`` with the job and assign it to the site ``target``.

    The wrapped configuration is pulled out of ``app`` and placed into a
    ``FedAppConfig`` as either the client or the server part. That config is
    added to the job under a freshly generated ``app_<uuid4>`` name, which is
    then bound to ``target``.

    Args:
        app: a ``FedApp`` whose ``get_app_config()`` returns a
            ``ClientAppConfig`` or a ``ServerAppConfig``.
        target: site identifier forwarded to ``self.job.set_site_app``.

    Raises:
        ValueError: if ``app`` is not a ``FedApp``, or if the configuration it
            wraps is neither a ``ClientAppConfig`` nor a ``ServerAppConfig``.
    """
    if not isinstance(app, FedApp):
        raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}")

    client_server_config = app.get_app_config()
    if isinstance(client_server_config, ClientAppConfig):
        app_config = FedAppConfig(server_app=None, client_app=client_server_config)
    elif isinstance(client_server_config, ServerAppConfig):
        app_config = FedAppConfig(server_app=client_server_config, client_app=None)
    else:
        # Report the type of the unwrapped config, not of the FedApp wrapper.
        raise ValueError(
            f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}"
        )

    app_name = f"app_{(uuid.uuid4())}"
    self.job.add_fed_app(app_name, app_config)
    self.job.set_site_app(target, app_name)
Expand All @@ -62,107 +72,117 @@ def simulator_run(self, job_root):
self.job.simulator_run(job_root, self.workspace, threads=2)


def create_client_app(app_script, app_config=""):
    """Build a ``ClientAppConfig`` that runs ``app_script`` through a subprocess launcher.

    Args:
        app_script: script file name; the launcher command is
            ``python3 custom/<app_script> <app_config>``.
        app_config: extra command-line text appended after the script name.

    Returns:
        The populated ``ClientAppConfig``.
    """
    # Fresh identifiers per call so the file-pipe directories of separate apps do not overlap.
    _site_name = f"site_{uuid.uuid4()}"
    _job_id = f"{uuid.uuid4()}"

    client_app = ClientAppConfig()
    executor = PTClientAPILauncherExecutor(
        launcher_id="launcher",
        pipe_id="pipe",
        heartbeat_timeout=60,
        params_exchange_format="pytorch",
        params_transfer_type="DIFF",
        train_with_evaluation=True,
    )
    client_app.add_executor(["train"], executor)

    component = SubprocessLauncher(script=f"python3 custom/{app_script} {app_config}", launch_once=True)
    client_app.add_component("launcher", component)

    # TODO: Use CellPipe, create CellPipe objects as part of components that require it. Automatically set root_url in CellPipe
    component = FilePipe(
        mode=Mode.PASSIVE,
        root_path=f"/tmp/nvflare/_file_pipe/{_job_id}/{_site_name}"
    )
    client_app.add_component("pipe", component)

    # The metrics pipe gets its own root directory ("-1" suffix) so it does not
    # share a directory with the task pipe above.
    component = FilePipe(
        mode=Mode.PASSIVE,
        root_path=f"/tmp/nvflare/_file_pipe/{_job_id}-1/{_site_name}"
    )
    client_app.add_component("metrics_pipe", component)

    component = MetricRelay(
        pipe_id="metrics_pipe",
        event_type="fed.analytix_log_stats",
        read_interval=0.1
    )
    client_app.add_component("metric_relay", component)

    component = ExternalConfigurator(
        component_ids=["metric_relay"]
    )
    client_app.add_component("config_preparer", component)

    client_app.add_component("net", Net())  # TODO: find another way to register files that need to be included in custom folder

    return client_app


def create_server_app(min_clients, num_rounds, model_class_path="net.Net"):
    """Assemble a ``ServerAppConfig`` with a FedAvg workflow followed by cross-site evaluation.

    Args:
        min_clients: forwarded to ``FedAvg``.
        num_rounds: forwarded to ``FedAvg``.
        model_class_path: value placed under the ``"path"`` key of the
            persistor's ``model`` argument.

    Returns:
        The populated ``ServerAppConfig``.
    """
    app_cfg = ServerAppConfig()

    # Workflows, registered in this order.
    app_cfg.add_workflow(
        "fedavg_ctl",
        FedAvg(min_clients=min_clients, num_rounds=num_rounds, persistor_id="persistor"),
    )
    app_cfg.add_workflow("cross_site_validate", CrossSiteModelEval(model_locator_id="model_locator"))

    # Supporting components; ids must match the *_id arguments used above and below.
    app_cfg.add_component("persistor", PTFileModelPersistor(model={"path": model_class_path}))
    app_cfg.add_component("model_locator", PTFileModelLocator(pt_persistor_id="persistor"))
    app_cfg.add_component("json_generator", ValidationJsonGenerator())
    app_cfg.add_component("model_selector", IntimeModelSelector(key_metric="accuracy"))
    app_cfg.add_component("receiver", TBAnalyticsReceiver(events=["fed.analytix_log_stats"]))

    app_cfg.add_component("net", Net())  # TODO: have another way to register needed scripts

    return app_cfg

# TODO: add another FedApp class layer here
class ClientFedApp(FedApp):
    """``FedApp`` whose ``app`` is a ``ClientAppConfig`` running a script via a subprocess launcher."""

    def __init__(self, app_script, app_config=""):
        """Generate per-instance identifiers and build the client configuration.

        Args:
            app_script: script file name; launched as ``python3 custom/<app_script>``.
            app_config: extra command-line text appended after the script name.
        """
        super().__init__()

        # Random identifiers used only to build the file-pipe directory paths.
        self._site_name = f"site_{uuid.uuid4()}"
        self._job_id = f"{uuid.uuid4()}"
        self._create_client_app(app_script, app_config)

    def _create_client_app(self, app_script, app_config):
        """Populate ``self.app`` with the executor and its supporting components."""
        self.app = ClientAppConfig()
        cfg = self.app

        cfg.add_executor(
            ["train"],
            PTClientAPILauncherExecutor(
                launcher_id="launcher",
                pipe_id="pipe",
                heartbeat_timeout=60,
                params_exchange_format="pytorch",
                params_transfer_type="DIFF",
                train_with_evaluation=True,
            ),
        )

        launcher = SubprocessLauncher(script=f"python3 custom/{app_script} {app_config}", launch_once=True)
        cfg.add_component("launcher", launcher)

        # TODO: Use CellPipe, create CellPipe objects as part of components that require it. Automatically set root_url in CellPipe
        task_pipe = FilePipe(
            mode=Mode.PASSIVE,
            root_path=f"/tmp/nvflare/_file_pipe/{self._job_id}/{self._site_name}",
        )
        cfg.add_component("pipe", task_pipe)

        # The metrics pipe uses a separate root directory ("-1" suffix on the job id).
        metrics_pipe = FilePipe(
            mode=Mode.PASSIVE,
            root_path=f"/tmp/nvflare/_file_pipe/{self._job_id}-1/{self._site_name}",
        )
        cfg.add_component("metrics_pipe", metrics_pipe)

        relay = MetricRelay(pipe_id="metrics_pipe", event_type="fed.analytix_log_stats", read_interval=0.1)
        cfg.add_component("metric_relay", relay)

        cfg.add_component("config_preparer", ExternalConfigurator(component_ids=["metric_relay"]))

        cfg.add_component("net", Net())  # TODO: find another way to register files that need to be included in custom folder


class ServerFedApp(FedApp):
    """``FedApp`` whose ``app`` is a ``ServerAppConfig`` built from caller-supplied controllers."""

    def __init__(self, controllers: List, model_class_path="net.Net"):
        """Store the inputs and build the server configuration.

        Args:
            controllers: workflow objects; each is registered as ``controller_<index>``.
            model_class_path: value placed under the ``"path"`` key of the
                persistor's ``model`` argument.
        """
        super().__init__()
        self.controllers = controllers
        self.model_class_path = model_class_path
        self._create_server_app()

    def _create_server_app(self):
        """Populate ``self.app`` with the workflows and their supporting components."""
        self.app: ServerAppConfig = ServerAppConfig()
        cfg = self.app

        # Caller-supplied workflows first, in list order.
        for idx, workflow in enumerate(self.controllers):
            cfg.add_workflow(f"controller_{idx}", workflow)

        # TODO: make optional or list of controllers?
        cfg.add_workflow("cross_site_validate", CrossSiteModelEval(model_locator_id="model_locator"))

        cfg.add_component("persistor", PTFileModelPersistor(model={"path": self.model_class_path}))
        cfg.add_component("model_locator", PTFileModelLocator(pt_persistor_id="persistor"))
        cfg.add_component("json_generator", ValidationJsonGenerator())
        cfg.add_component("model_selector", IntimeModelSelector(key_metric="accuracy"))
        cfg.add_component("receiver", TBAnalyticsReceiver(events=["fed.analytix_log_stats"]))

        cfg.add_component("net", Net())  # TODO: have another way to register needed scripts


if __name__ == "__main__":
    # Example: one FedAvg server app plus `n_clients` client apps on a single job.
    n_clients = 2
    num_rounds = 2

    job = FedJob(name="cifar10_fedavg")

    controller = FedAvg(
        min_clients=n_clients,
        num_rounds=num_rounds,
        persistor_id="persistor"  # TODO: why is it not using default
    )
    # Apps are passed to the job as FedApp wrappers (FedJob.to accepts `app: FedApp`).
    server_app = ServerFedApp(controllers=[controller])
    job.to(server_app, "server")

    for i in range(n_clients):
        client_app = ClientFedApp(app_script="cifar10_fl.py")
        job.to(client_app, f"site-{i}")

    #job.export_job("/tmp/nvflare/jobs")
Expand Down

0 comments on commit ad945ec

Please sign in to comment.