Skip to content

Commit

Permalink
refactor Controller/ExcecutorApps
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 15, 2024
1 parent ac5d579 commit 36b59c7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 36 deletions.
8 changes: 5 additions & 3 deletions examples/hello-world/python_jobs/pt/client_api_pt_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,17 @@
min_clients=n_clients,
num_rounds=num_rounds,
)
server_app = ControllerApp(controllers=[controller], persistor_model=Net())
server_app = ControllerApp(init_model=Net()) # TODO: use load/save model in FedAvg
server_app.add_controller(controller)
job.to(server_app, "server")

for i in range(n_clients):
executor = PTInProcessClientAPIExecutor(
task_script_path=train_script,
task_script_args=""
task_script_args="" # --batch_size 32 --data_path f'/tmp/data/site-{i}'
)
client_app = ExecutorApp(executors=[executor], external_scripts=[train_script])
client_app = ExecutorApp(external_scripts=[train_script])
client_app.add_executor(executor, tasks=["train"])
job.to(client_app, f"site-{i}", gpu=0)

job.export_job("/tmp/nvflare/jobs/job_config")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@
min_clients=n_clients,
num_rounds=num_rounds,
)
server_app = ControllerApp(controllers=[controller], persistor_model=Net())
server_app = ControllerApp(init_model=Net())
server_app.add_controller(controller)
job.to(server_app, "server")

for i in range(n_clients):
executor = PTInProcessClientAPIExecutor(
task_script_path=train_script,
task_script_args=""
)
client_app = ExecutorApp(executors=[executor], external_scripts=[train_script]) # TODO: add_executor?
client_app = ExecutorApp(external_scripts=[train_script])
client_app.add_executor(executor=executor, tasks=["train"])

# add privacy filter. # TODO: is there a better way to handle task names?
client_app.add_task_result_filter(["train"], PercentilePrivacy(percentile=10, gamma=0.01))
Expand Down
11 changes: 9 additions & 2 deletions examples/hello-world/python_jobs/pt/model_learner_xsite_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@
alpha=alpha,
)

server_app = ControllerApp(controllers=[ctrl1, ctrl2], persistor_model=ModerateCNN(), extra_components=[data_splitter])
#server_app = ControllerApp(controllers=[ctrl1, ctrl2], persistor_model=ModerateCNN(), extra_components=[data_splitter])
server_app = ControllerApp()
server_app.add_controller(ctrl1)
server_app.add_controller(ctrl2)
server_app.add_component(data_splitter)
job.to(server_app, "server")

for i in range(n_clients):
Expand All @@ -60,7 +64,10 @@
executor = ModelLearnerExecutor(
learner_id=learner # TODO: change more places that use id to directly accept objects
)
client_app = ExecutorApp(executors=[executor], tasks=[["train", "submit_model", "validate"]])
#client_app = ExecutorApp(executors=[executor], tasks=[["train", "submit_model", "validate"]])
client_app = ExecutorApp()
client_app.add_executor(executor=executor, task=["train", "submit_model", "validate"], result_filter=[""], data_filter=[""])

job.to(client_app, f"site-{i+1}", gpu=0) # data splitter assumes client names start from 1

job.export_job("/tmp/nvflare/jobs/job_config")
Expand Down
51 changes: 22 additions & 29 deletions nvflare/fed_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
class FedApp:
def __init__(self):
self.app = None # Union[ClientAppConfig, ServerAppConfig]
self._component_count = 0

def get_app_config(self):
return self.app
Expand All @@ -43,6 +44,10 @@ def add_task_result_filter(self, tasks: List[str], task_filter: Filter):
def add_task_data_filter(self, tasks: List[str], task_filter: Filter):
self.app.add_task_data_filter(tasks, task_filter)

def add_component(self, component):
self.app.add_component(f"component_{self._component_count}", component)
self._component_count += 1


class FedJob:
def __init__(self, name="fed_job", min_clients=1, mandatory_clients=None) -> None:
Expand All @@ -58,13 +63,13 @@ def to(self, app: FedApp, target: str, gpu: int = None):
client_server_config = app.get_app_config()
if isinstance(client_server_config, ClientAppConfig):
app_config = FedAppConfig(server_app=None, client_app=client_server_config)
app_name = f"app_client_{(uuid.uuid4())}"
app_name = f"app_client_{uuid.uuid4()}"
self.clients.append(target)
if gpu is not None:
self.gpus.append(str(gpu))
elif isinstance(client_server_config, ServerAppConfig):
app_config = FedAppConfig(server_app=client_server_config, client_app=None)
app_name = f"app_server_{(uuid.uuid4())}"
app_name = f"app_server_{uuid.uuid4()}"
else:
raise ValueError(
f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}"
Expand Down Expand Up @@ -93,23 +98,19 @@ def simulator_run(self, workspace, threads: int = None):


class ExecutorApp(FedApp):
def __init__(self, executors: List, tasks: List[List] = None, external_scripts: List = None, extra_components: List = None):
def __init__(self, external_scripts: List = None):
super().__init__()
self.executors = executors
self.tasks = tasks
self.external_scripts = external_scripts
self.extra_components = extra_components
self._create_client_app()

def add_executor(self, executor, tasks=None):
if tasks is None:
tasks = ["train"]
self.app.add_executor(tasks, executor)

def _create_client_app(self):
self.app = ClientAppConfig()

if self.tasks is None:
self.tasks = [["train"]] * len(self.executors)

for _task, _executor in zip(self.tasks, self.executors):
self.app.add_executor(_task, _executor)

component = FilePipe( # TODO: support CellPipe, causes type error for passing secure_mode = "{SECURE_MODE}"
mode=Mode.PASSIVE,
root_path="{WORKSPACE}/{JOB_ID}/{SITE_NAME}", # TODO: this creates empty subfolder structure
Expand All @@ -122,31 +123,27 @@ def _create_client_app(self):
component = ExternalConfigurator(component_ids=["metric_relay"])
self.app.add_component("config_preparer", component)

if self.extra_components is not None:
for i, _component in enumerate(self.extra_components):
self.app.add_component(f"extra_component_{i}", _component)

if self.external_scripts is not None:
for _script in self.external_scripts:
self.app.add_ext_script(_script)


class ControllerApp(FedApp):
def __init__(self, controllers: List, persistor_model=None, extra_components: List = None):
def __init__(self, init_model=None):
super().__init__()
self.controllers = controllers
self.persistor_model = persistor_model
self.extra_components = extra_components
self.init_model = init_model
self._create_server_app()
self._controller_count = 0

def add_controller(self, controller):
self.app.add_workflow(f"controller_{self._controller_count}", controller)
self._controller_count += 1

def _create_server_app(self):
self.app: ServerAppConfig = ServerAppConfig()

for i, _controller in enumerate(self.controllers):
self.app.add_workflow(f"controller_{i}", _controller)

if self.persistor_model is not None:
component = PTFileModelPersistor(model=self.persistor_model)
if self.init_model is not None:
component = PTFileModelPersistor(model=self.init_model)
self.app.add_component("persistor", component)

component = PTFileModelLocator(pt_persistor_id="persistor")
Expand All @@ -160,7 +157,3 @@ def _create_server_app(self):

component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
self.app.add_component("receiver", component)

if self.extra_components is not None:
for i, _component in enumerate(self.extra_components):
self.app.add_component(f"extra_component_{i}", _component)

0 comments on commit 36b59c7

Please sign in to comment.