forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
7cfc38d
commit ba47f3c
Showing
3 changed files
with
351 additions
and
0 deletions.
There are no files selected for viewing
136 changes: 136 additions & 0 deletions
136
examples/hello-world/python_jobs/client-api-pt/cifar10_fl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
from net import Net | ||
|
||
# (1) import nvflare client API | ||
import nvflare.client as flare | ||
|
||
# (optional) metrics | ||
from nvflare.client.tracking import SummaryWriter | ||
|
||
# (optional) set a fixed location so we don't need to download the dataset every time | ||
DATASET_PATH = "/tmp/nvflare/data" | ||
# If available, we use GPU to speed things up. | ||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
|
||
def _evaluate(input_weights, testloader):
    """Measure the accuracy of ``input_weights`` on the batches of ``testloader``.

    The weights are loaded into a freshly constructed ``Net`` so the network
    that the caller is training is left untouched.

    Args:
        input_weights: state dict accepted by ``Net.load_state_dict``.
        testloader: iterable yielding ``(images, labels)`` batches.

    Returns:
        The accuracy as an integer percentage (``100 * correct // total``).
    """
    net = Net()
    net.load_state_dict(input_weights)
    # (optional) use GPU to speed things up
    net.to(DEVICE)
    net.eval()

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for images, labels in testloader:
            # (optional) use GPU to speed things up
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            # calculate outputs by running images through the network
            outputs = net(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct // total
    print(f"Accuracy of the network on the 10000 test images: {accuracy} %")
    return accuracy


def main():
    """Train ``Net`` on CIFAR-10 through the NVFlare client API.

    While ``flare.is_running()`` is true, each iteration receives a model,
    trains it locally for ``epochs`` epochs, evaluates the *received* weights
    on the test set and sends the trained weights back together with that
    accuracy and the number of local steps.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    batch_size = 4
    epochs = 2
    log_interval = 2000  # report the running loss every 2000 mini-batches

    trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    net = Net()

    # (2) initializes NVFlare client API
    flare.init()

    summary_writer = SummaryWriter()
    while flare.is_running():
        # (3) receives FLModel from NVFlare
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # (4) loads model from NVFlare
        net.load_state_dict(input_model.params)

        # (optional) use GPU to speed things up; the model is moved before the
        # optimizer is built so the optimizer is created on the final parameters
        net.to(DEVICE)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

        # (optional) calculate total steps
        steps_per_epoch = len(trainloader)
        steps = epochs * steps_per_epoch
        for epoch in range(epochs):  # loop over the dataset multiple times
            running_loss = 0.0
            for i, (inputs, labels) in enumerate(trainloader):
                # (optional) use GPU to speed things up
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % log_interval == log_interval - 1:
                    # Report the mean loss of the window both on the console and
                    # to the metrics writer, so the two outputs agree.
                    mean_loss = running_loss / log_interval
                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {mean_loss:.3f}")
                    global_step = input_model.current_round * steps + epoch * steps_per_epoch + i

                    summary_writer.add_scalar(tag="loss_for_each_batch", scalar=mean_loss, global_step=global_step)
                    running_loss = 0.0

        print("Finished Training")

        PATH = "./cifar_net.pth"
        torch.save(net.state_dict(), PATH)

        # (5)/(6) evaluate on received model for model selection
        accuracy = _evaluate(input_model.params, testloader)
        # (7) construct trained FL model
        output_model = flare.FLModel(
            params=net.cpu().state_dict(),
            metrics={"accuracy": accuracy},
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )
        # (8) send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
177 changes: 177 additions & 0 deletions
177
examples/hello-world/python_jobs/client-api-pt/client_api_pt_job.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import uuid | ||
from typing import Union, List | ||
|
||
from nvflare.job_config.fed_app_config import ClientAppConfig, ServerAppConfig, FedAppConfig | ||
from nvflare.job_config.fed_job_config import FedJobConfig | ||
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator | ||
from nvflare.app_opt.pt import PTFileModelPersistor | ||
from nvflare.app_common.widgets.metric_relay import MetricRelay | ||
from nvflare.app_common.widgets.external_configurator import ExternalConfigurator | ||
from nvflare.app_common.workflows.fedavg import FedAvg | ||
from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator | ||
from nvflare.fuel.utils.constants import Mode | ||
from nvflare.fuel.utils.pipe.file_pipe import FilePipe | ||
from net import Net | ||
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector | ||
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver | ||
from nvflare.app_opt.pt.in_process_client_api_executor import PTInProcessClientAPIExecutor | ||
|
||
|
||
class FedApp:
    """Base wrapper around a client or server application config.

    Subclasses are expected to populate ``self.app``; until they do, it is
    ``None``.
    """

    def __init__(self):
        # Union[ClientAppConfig, ServerAppConfig], filled in by subclasses
        self.app = None

    def get_app_config(self):
        """Return the wrapped application config (``None`` if none was set)."""
        wrapped_config = self.app
        return wrapped_config
|
||
|
||
class FedJob:
    """Collects ``FedApp`` objects into one ``FedJobConfig`` and exports/runs it.

    Args:
        name: job name handed to ``FedJobConfig``.
        workspace: workspace directory passed to the simulator.
        min_clients: ``min_clients`` value handed to ``FedJobConfig``
            (defaults to the previously hard-coded ``1``).
    """

    def __init__(
        self,
        name="client-api-pt",
        workspace="/tmp/nvflare/simulator_workspace",
        min_clients: int = 1,
    ) -> None:
        self.job_name = name
        self.workspace = workspace

        self.job: FedJobConfig = FedJobConfig(job_name=self.job_name, min_clients=min_clients)

    def to(self, app: FedApp, target):
        """Register ``app`` with the job and assign it to the site ``target``.

        The app is added under a freshly generated unique name, prefixed with
        ``app_client_`` or ``app_server_`` depending on the wrapped config.

        Raises:
            ValueError: if ``app`` is not a ``FedApp`` or if its config is
                neither a ``ClientAppConfig`` nor a ``ServerAppConfig``.
        """
        if not isinstance(app, FedApp):
            raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}")

        client_server_config = app.get_app_config()
        if isinstance(client_server_config, ClientAppConfig):
            app_config = FedAppConfig(server_app=None, client_app=client_server_config)
            app_name = f"app_client_{uuid.uuid4()}"
        elif isinstance(client_server_config, ServerAppConfig):
            app_config = FedAppConfig(server_app=client_server_config, client_app=None)
            app_name = f"app_server_{uuid.uuid4()}"
        else:
            raise ValueError(
                "App needs to be of type `ClientAppConfig` or `ServerAppConfig` "
                f"but was type {type(client_server_config)}"
            )

        self.job.add_fed_app(app_name, app_config)
        self.job.set_site_app(target, app_name)

    def export_job(self, job_root):
        """Write the job configuration below ``job_root``."""
        self.job.generate_job_config(job_root)

    def simulator_run(self, job_root, threads: int = 2):
        """Run the job in the simulator using this job's workspace.

        Args:
            job_root: job directory passed through to the simulator run.
            threads: ``threads`` value passed to the simulator run
                (defaults to the previously hard-coded ``2``).
        """
        self.job.simulator_run(job_root, self.workspace, threads=threads)
|
||
|
||
class ExecutorApp(FedApp):
    """Client-side app: executors for the ``train`` task plus metric streaming.

    Args:
        executors: executors to register; each one is added for the
            ``"train"`` task.
        ext_scripts: scripts to register with ``add_ext_script``. Defaults to
            ``["cifar10_fl.py"]``, the script that used to be hard-coded.
    """

    def __init__(self, executors: List, ext_scripts: Union[List[str], None] = None):
        super().__init__()
        self.executors = executors
        # Copy so later changes to the caller's list do not leak into this app.
        self.ext_scripts = ["cifar10_fl.py"] if ext_scripts is None else list(ext_scripts)
        self._create_client_app()

    def _create_client_app(self):
        """Build the ``ClientAppConfig`` and store it in ``self.app``."""
        self.app = ClientAppConfig()

        for _executor in self.executors:
            self.app.add_executor(["train"], _executor)

        component = FilePipe(
            mode="PASSIVE",  # TODO: enable passing Mode.PASSIVE
            root_path="{WORKSPACE}/{JOB_ID}/{SITE_NAME}",  # TODO: this creates empty subfolder structure
        )
        self.app.add_component("metrics_pipe", component)

        # The relay reads from the pipe registered above under "metrics_pipe".
        component = MetricRelay(
            pipe_id="metrics_pipe",
            event_type="fed.analytix_log_stats",
            read_interval=0.1,
        )
        self.app.add_component("metric_relay", component)

        component = ExternalConfigurator(component_ids=["metric_relay"])
        self.app.add_component("config_preparer", component)

        # TODO: find another way to register files that need to be included in custom folder
        for script in self.ext_scripts:
            self.app.add_ext_script(script)
|
||
|
||
class ControllerApp(FedApp):
    """Server-side app: workflows plus persistence, model selection and metrics.

    Args:
        controllers: workflows to register; the i-th one is added under the
            id ``controller_{i}``.
        model_class_path: model instance handed to ``PTFileModelPersistor``.
            When omitted, a new ``Net()`` is created for this app.
            NOTE(review): despite the name, this is used as a model object,
            not as a class path - confirm the intended contract.
    """

    def __init__(self, controllers: List, model_class_path=None):
        super().__init__()
        self.controllers = controllers
        # Build the default per instance: a ``Net()`` in the signature would be
        # created once at definition time and shared by every ControllerApp.
        self.model_class_path = Net() if model_class_path is None else model_class_path
        self._create_server_app()

    def _create_server_app(self):
        """Build the ``ServerAppConfig`` and store it in ``self.app``."""
        self.app: ServerAppConfig = ServerAppConfig()

        for i, _controller in enumerate(self.controllers):
            self.app.add_workflow(f"controller_{i}", _controller)

        # TODO: make cross-site model evaluation optional or accept a list of
        # controllers? (add as new example)

        # Use the model chosen at construction time instead of always
        # creating a new ``Net()`` and ignoring the caller's argument.
        component = PTFileModelPersistor(model=self.model_class_path)
        self.app.add_component("persistor", component)

        # The locator refers to the persistor registered above by its id.
        component = PTFileModelLocator(pt_persistor_id="persistor")
        self.app.add_component("model_locator", component)

        component = ValidationJsonGenerator()
        self.app.add_component("json_generator", component)

        component = IntimeModelSelector(key_metric="accuracy")
        self.app.add_component("model_selector", component)

        component = TBAnalyticsReceiver(events=["fed.analytix_log_stats"])
        self.app.add_component("receiver", component)

        # TODO: have another way to register needed scripts
|
||
|
||
if __name__ == "__main__":
    # Size of the simulated federation and number of FedAvg rounds.
    n_clients, num_rounds = 2, 2

    job = FedJob(name="cifar10_fedavg")

    # Server: a single FedAvg workflow wrapped in a ControllerApp.
    fedavg_workflow = FedAvg(min_clients=n_clients, num_rounds=num_rounds)
    job.to(ControllerApp(controllers=[fedavg_workflow]), "server")

    # Clients: every site runs the same training script in-process.
    for site_index in range(n_clients):
        script_executor = PTInProcessClientAPIExecutor(
            task_script_path="cifar10_fl.py",
            # e.g. --batch_size 32 --data_path f"/tmp/data/site-{site_index}"
            task_script_args="",
        )
        job.to(ExecutorApp(executors=[script_executor]), f"site-{site_index}")  # gpu=0

    # Write the job definition to disk, then run it in the simulator.
    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    ``fc1`` expects ``16 * 5 * 5`` features per sample, which is what the two
    5x5 convolutions and 2x2 poolings produce from a 3-channel 32x32 input.

    Args:
        num_classes: size of the final output layer (default ``10``).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return the raw class scores for a batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)