diff --git a/examples/hello-world/python_jobs/client-api-pt/cifar10_fl.py b/examples/hello-world/python_jobs/client-api-pt/cifar10_fl.py new file mode 100644 index 0000000000..5bd28adc14 --- /dev/null +++ b/examples/hello-world/python_jobs/client-api-pt/cifar10_fl.py @@ -0,0 +1,136 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.optim as optim +import torchvision +import torchvision.transforms as transforms +from net import Net + +# (1) import nvflare client API +import nvflare.client as flare + +# (optional) metrics +from nvflare.client.tracking import SummaryWriter + +# (optional) use a fixed location so the dataset is not downloaded again on every run +DATASET_PATH = "/tmp/nvflare/data" +# If available, we use GPU to speed things up.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def _evaluate(input_weights, testloader):
    """Return the integer accuracy (in percent) of ``input_weights`` on ``testloader``.

    A fresh ``Net`` is built and loaded with ``input_weights`` so that the
    network being trained by the caller is left untouched.

    Args:
        input_weights: state dict accepted by ``Net.load_state_dict``.
        testloader: iterable yielding ``(images, labels)`` batches.

    Returns:
        ``100 * correct // total`` over all batches of ``testloader``.
    """
    net = Net()
    net.load_state_dict(input_weights)
    # (optional) use GPU to speed things up
    net.to(DEVICE)
    # inference mode for any layer that behaves differently while training
    net.eval()

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for images, labels in testloader:
            # (optional) use GPU to speed things up
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            # calculate outputs by running images through the network
            outputs = net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct // total
    print(f"Accuracy of the network on the 10000 test images: {accuracy} %")
    return accuracy


def main(batch_size=4, epochs=2, lr=0.001, momentum=0.9):
    """Run CIFAR-10 local training for every round handed out by NVFlare.

    For each round the function receives the global model, evaluates it on the
    test split (the metric sent back for model selection), trains it locally
    and sends the updated weights back.

    Args:
        batch_size: mini-batch size used by both the train and test loaders.
        epochs: number of local passes over the training set per round.
        lr: learning rate of the SGD optimizer.
        momentum: momentum of the SGD optimizer.

    All arguments default to the values that were previously hard-coded, so
    calling ``main()`` without arguments behaves as before.
    NOTE(review): presumably these can be supplied through the executor's
    ``task_fn_args`` -- confirm against the executor implementation.
    """
    log_interval = 2000  # report the running loss every 2000 mini-batches

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    net = Net()

    # Round-invariant objects are created once instead of once per round.
    criterion = nn.CrossEntropyLoss()
    batches_per_epoch = len(trainloader)
    # (optional) calculate total steps
    steps = epochs * batches_per_epoch

    # (2) initializes NVFlare client API
    flare.init()

    summary_writer = SummaryWriter()
    while flare.is_running():
        # (3) receives FLModel from NVFlare
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # (4) loads model from NVFlare
        net.load_state_dict(input_model.params)

        # A new optimizer per round, as before, so momentum buffers do not
        # carry over from the previous round's weights.
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

        # (optional) use GPU to speed things up
        net.to(DEVICE)
        for epoch in range(epochs):  # loop over the dataset multiple times
            running_loss = 0.0
            for i, (inputs, labels) in enumerate(trainloader, 0):
                # (optional) use GPU to speed things up
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % log_interval == log_interval - 1:
                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_interval:.3f}")
                    global_step = input_model.current_round * steps + epoch * batches_per_epoch + i

                    # NOTE(review): the logged scalar is the loss summed over the
                    # window, while the console shows the mean -- confirm which
                    # one the tracking backend is expected to receive.
                    summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step)
                    running_loss = 0.0

        print("Finished Training")

        PATH = "./cifar_net.pth"
        torch.save(net.state_dict(), PATH)

        # (5)/(6) evaluate the *received* model for model selection
        accuracy = _evaluate(input_model.params, testloader)
        # (7) construct trained FL model
        output_model = flare.FLModel(
            params=net.cpu().state_dict(),
            metrics={"accuracy": accuracy},
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )
        # (8) send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()


# ---- examples/hello-world/python_jobs/client-api-pt/client_api_pt_job.py ----
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from typing import Union, List

from nvflare.job_config.fed_app_config import ClientAppConfig, ServerAppConfig, FedAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval
from nvflare.app_opt.pt import PTFileModelPersistor
from nvflare.app_opt.pt.client_api_launcher_executor import PTClientAPILauncherExecutor
from nvflare.app_common.launchers.subprocess_launcher import SubprocessLauncher
from nvflare.app_common.widgets.metric_relay import MetricRelay
from nvflare.app_common.widgets.external_configurator import ExternalConfigurator
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.pipe.file_pipe import FilePipe
from net import Net
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.app_opt.pt.in_process_client_api_executor import PTInProcessClientAPIExecutor


class FedApp:
    """Base holder for a client or server application configuration."""

    def __init__(self):
        # Filled in by subclasses with a ClientAppConfig or a ServerAppConfig.
        self.app = None  # Union[ClientAppConfig, ServerAppConfig]

    def get_app_config(self):
        """Return the wrapped app configuration (``None`` until a subclass sets it)."""
        return self.app


class FedJob:
    """Collects apps, assigns them to sites and exports or simulates the job."""

    def __init__(self, name="client-api-pt", workspace="/tmp/nvflare/simulator_workspace") -> None:
        """Create an empty job.

        Args:
            name: job name passed to ``FedJobConfig``.
            workspace: directory handed to the simulator in ``simulator_run``.
        """
        self.job_name = name
        self.workspace = workspace

        self.job: FedJobConfig = FedJobConfig(job_name=self.job_name, min_clients=1)

    def to(self, app: FedApp, target):
        """Register ``app`` under a unique generated name and assign it to ``target``.

        Args:
            app: a ``FedApp`` wrapping a ``ClientAppConfig`` or ``ServerAppConfig``.
            target: site identifier passed to ``FedJobConfig.set_site_app``.

        Raises:
            ValueError: if ``app`` is not a ``FedApp`` or wraps neither a client
                nor a server app configuration.
        """
        if not isinstance(app, FedApp):
            raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}")

        client_server_config = app.get_app_config()
        if isinstance(client_server_config, ClientAppConfig):
            app_config = FedAppConfig(server_app=None, client_app=client_server_config)
            app_name = f"app_client_{uuid.uuid4()}"
        elif isinstance(client_server_config, ServerAppConfig):
            app_config = FedAppConfig(server_app=client_server_config, client_app=None)
            app_name = f"app_server_{uuid.uuid4()}"
        else:
            raise ValueError(
                f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}"
            )

        self.job.add_fed_app(app_name, app_config)
        self.job.set_site_app(target, app_name)

    def export_job(self, job_root):
        """Write the job configuration below ``job_root``."""
        self.job.generate_job_config(job_root)

    def simulator_run(self, job_root, threads=2):
        """Run the job in the simulator.

        Args:
            job_root: directory passed to ``FedJobConfig.simulator_run``.
            threads: number of simulator threads; defaults to the previously
                hard-coded value of 2.
        """
        self.job.simulator_run(job_root, self.workspace, threads=threads)


class ExecutorApp(FedApp):
    """Client-side app: executors for the ``train`` task plus metric streaming components."""

    def __init__(self, executors: List):
        """Build the client app configuration for ``executors``."""
        super().__init__()
        self.executors = executors
        self._create_client_app()

    def _create_client_app(self):
        """Populate ``self.app`` with executors, the metrics pipe/relay and the training script."""
        self.app = ClientAppConfig()

        for _executor in self.executors:
            self.app.add_executor(["train"], _executor)

        component = FilePipe(
            mode="PASSIVE",  # TODO: enable passing Mode.PASSIVE
            root_path="{WORKSPACE}/{JOB_ID}/{SITE_NAME}",  # TODO: this creates empty subfolder structure
        )
        self.app.add_component("metrics_pipe", component)

        component = MetricRelay(pipe_id="metrics_pipe", event_type="fed.analytix_log_stats", read_interval=0.1)
        self.app.add_component("metric_relay", component)

        component = ExternalConfigurator(component_ids=["metric_relay"])
        self.app.add_component("config_preparer", component)

        # TODO: find another way to register files that need to be included in custom folder
        self.app.add_ext_script("cifar10_fl.py")


class ControllerApp(FedApp):
    """Server-side app: workflows plus persistence, model selection and metric receiving."""

    def __init__(self, controllers: List, model_class_path=None):
        """Build the server app configuration.

        Args:
            controllers: workflows, registered as ``controller_<index>``.
            model_class_path: model handed to the persistor. ``None`` (the
                default) creates a new ``Net()`` for this instance.
                NOTE(review): despite its name the previous default was a
                ``Net`` instance, so an instance is assumed here -- confirm.
        """
        super().__init__()
        self.controllers = controllers
        # A `Net()` default in the signature would be built once at definition
        # time and shared by every ControllerApp; create it per instance instead.
        self.model_class_path = Net() if model_class_path is None else model_class_path
        self._create_server_app()

    def _create_server_app(self):
        """Populate ``self.app`` with the workflows and the server components."""
        self.app: ServerAppConfig = ServerAppConfig()

        for i, _controller in enumerate(self.controllers):
            self.app.add_workflow(f"controller_{i}", _controller)

        # TODO: make optional or list of controllers? # add as new example
        #   (e.g. CrossSiteModelEval(model_locator_id="model_locator") as "cross_site_validate")

        # Persist the model configured on this app instead of a fresh hard-coded Net().
        component = PTFileModelPersistor(model=self.model_class_path)
        self.app.add_component("persistor", component)

        component = PTFileModelLocator(pt_persistor_id="persistor")
        self.app.add_component("model_locator", component)

        component = ValidationJsonGenerator()
        self.app.add_component("json_generator", component)

        component = IntimeModelSelector(key_metric="accuracy")
        self.app.add_component("model_selector", component)

        component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
        self.app.add_component("receiver", component)

        # TODO: have another way to register needed scripts


if __name__ == "__main__":
    n_clients = 2
    num_rounds = 2

    job = FedJob(name="cifar10_fedavg")

    controller = FedAvg(
        min_clients=n_clients,
        num_rounds=num_rounds,
        persistor_id="persistor",  # TODO: why is it not using default
    )
    server_app = ControllerApp(controllers=[controller])
    job.to(server_app, "server")

    for i in range(n_clients):
        executor = PTInProcessClientAPIExecutor(
            task_fn_path="cifar10_fl.main",
            # task_fn_args={"batch_size": 1}
        )
        client_app = ExecutorApp(executors=[executor])
        job.to(client_app, f"site-{i}")

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")


# ---- examples/hello-world/python_jobs/client-api-pt/net.py ----
# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages followed by three linear layers.

    The final layer produces 10 outputs per sample. The flattened feature
    size of ``16 * 5 * 5`` fixes the spatial size the convolutional stages
    must produce.
    """

    def __init__(self):
        super().__init__()
        # Layer names and creation order are kept stable: they define the
        # state-dict keys and the order in which weights are initialized.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalized) class scores for the batch ``x``."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # flatten all dimensions except batch
        flat = features.flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# ---- nvflare/app_opt/pt/in_process_client_api_executor.py (diff hunk, definition only partially visible) ----
# PTInProcessClientAPIExecutor.__init__: the default of `task_fn_args` changes from `{}` to `None`.