Skip to content

Commit

Permalink
switch to 2.4.1.rc1
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed May 1, 2024
1 parent 36cf887 commit 23b8228
Show file tree
Hide file tree
Showing 3 changed files with 591 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,15 +109,15 @@ def main():
print(f"current_round={input_model.current_round}")

# (4) loads model from NVFlare and sends it to GPU
trainer.network.load_state_dict(input_model.params)
trainer.network.load_state_dict(input_model.params, strict=False) # TODO: enable strict
trainer.network.to(DEVICE)

trainer.run()

# (5) wraps evaluation logic into a method to re-use for
# evaluation on both trained and received model
def evaluate(input_weights):
model.load_state_dict(input_weights)
model.load_state_dict(input_weights, strict=False) # TODO: enable strict

# Check the prediction on the test dataset
dataset_dir = Path(root_dir, "MedNIST")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
format_version = 2

# This is the application script which will be invoked. Client can replace this script with user's own training script.
app_script = "monai_mednist_train.py"
app_script = "cifar10.py"

# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
# Additional arguments needed by the training code. For example, in lightning, these can be --trainer.batch_size=xxx.
app_config = ""
app_config = "monai_mednist_train.py"

# Client Computing Executors.
executors = [
Expand All @@ -18,33 +17,32 @@
# This particular executor
executor {

path = "nvflare.app_opt.pt.in_process_client_api_executor.PTInProcessClientAPIExecutor"
# This is an executor for the Client API. The underlying data exchange uses a Pipe.
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"

args {
task_script_path = "{app_script}"
task_script_args = "{app_config}"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "DIFF"

# if train_with_evaluation is true, the executor will expect
# the custom code to send back both the trained parameters and the evaluation metric;
# otherwise only trained parameters are expected
train_with_evaluation = true

# time interval in seconds. Time interval to wait before checking whether the local task has submitted the result.
# if the local task takes a long time, you can increase this interval to a larger number
# uncomment to overwrite the default, default is 0.5 seconds
result_pull_interval = 0.5

# time interval in seconds. Time interval to wait before checking whether the training code has logged a metric (such as
# a TensorBoard log, MLflow log, or Weights & Biases log). The result will be streamed to the server side
# and then to the corresponding tracking system.
# if the log is not needed, you can set this to a larger number
# uncomment to overwrite the default; the default is None, which disables the log streaming feature.
log_pull_interval = 0.1
# launcher_id is used to locate the Launcher object in "components"
launcher_id = "launcher"

# pipe_id is used to locate the Pipe object in "components"
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchange parameters
params_exchange_format = "pytorch"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
params_transfer_type = "DIFF"

# if train_with_evaluation is true, the executor will expect
# the custom code to send back both the trained parameters and the evaluation metric;
# otherwise only trained parameters are expected
train_with_evaluation = true
}
}
}
Expand All @@ -56,12 +54,63 @@
# this defines an array of task result filters. If provided, it will control the results from the client executor to the server controller
task_result_filters = []

# define this component that will help relay local metrics log to FL server.
components = [
{
"id": "event_to_fed",
"name": "ConvertToFedEvent",
"args": {"events_to_convert": ["analytix_log_stats"], "fed_event_prefix": "fed."}
{
# component id is "launcher"
id = "launcher"

# the class path of this component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 custom/{app_script} {app_config} "
# if launch_once is true, the SubprocessLauncher will launch once for the whole job
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server
launch_once = true
}
]
}
}
{
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
},
{
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
# how fast should it read from the peer
read_interval = 0.1
}
},
{
# we use this component so the client api `flare.init()` can get required information
id = "config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = ["metric_relay"]
}
}
]
}
Loading

0 comments on commit 23b8228

Please sign in to comment.