Skip to content

Commit

Permalink
update swarm learning
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 24, 2024
1 parent 0a80c4f commit dd787e0
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 7 deletions.
8 changes: 1 addition & 7 deletions examples/hello-world/python_jobs/pt/client_api_pt_swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
if __name__ == "__main__":
n_clients = 2
num_rounds = 3
train_script = "code/cifar10_fl.py"
train_script = "code/train_eval_submit.py"

job = FedJob(name="cifar10_swarm", init_model=Net(), external_scripts=[train_script]) # TODO: use load/save model in FedAvg

Expand All @@ -47,12 +47,6 @@
)
job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"])

pipe = FilePipe(
mode=Mode.PASSIVE,
root_path="{WORKSPACE}/{JOB_ID}/{SITE_NAME}",
)
job.to(pipe, f"site-{i}", id="pipe")

client_controller = SwarmClientController()
job.to(client_controller, f"site-{i}", tasks=["swarm_*"])

Expand Down
188 changes: 188 additions & 0 deletions examples/hello-world/python_jobs/pt/code/train_eval_submit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from net import Net

# (1) import nvflare client API
import nvflare.client as flare
from nvflare.app_common.app_constant import ModelName

# (optional) set a fixed place so we don't need to download every time
CIFAR10_ROOT = "/tmp/nvflare/data/cifar10"
# (optional) We change to use GPU to speed things up.
# if you want to use CPU, change DEVICE="cpu"
DEVICE = "cuda:0"


def define_parser(argv=None):
    """Parse command-line options for the CIFAR-10 federated training script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the normal
            script invocation path) argparse falls back to ``sys.argv[1:]``.
            Accepting an explicit list makes the parser testable without
            touching global state.

    Returns:
        argparse.Namespace with ``dataset_path``, ``batch_size``,
        ``num_workers``, ``local_epochs`` and ``model_path`` attributes.
    """
    parser = argparse.ArgumentParser()
    # nargs="?" lets each flag appear without a value and still fall back
    # to its default instead of raising a usage error.
    parser.add_argument("--dataset_path", type=str, default=CIFAR10_ROOT, nargs="?")
    parser.add_argument("--batch_size", type=int, default=4, nargs="?")
    parser.add_argument("--num_workers", type=int, default=1, nargs="?")
    parser.add_argument("--local_epochs", type=int, default=2, nargs="?")
    parser.add_argument("--model_path", type=str, default=f"{CIFAR10_ROOT}/cifar_net.pth", nargs="?")
    return parser.parse_args(argv)


def main():
    """Run the NVFlare client-side federated loop for CIFAR-10.

    Parses local options, prepares the CIFAR-10 train/test loaders, then
    loops on the NVFlare client API handling three task types:

    * ``train``  — load the received global weights, run local SGD, save the
      best local checkpoint, and send the updated weights plus an accuracy
      metric back to the server.
    * ``evaluate`` — score the received weights on the local test set and
      send back the accuracy metric only.
    * ``submit_model`` — return the best locally saved checkpoint.

    Raises:
        ValueError: if a ``submit_model`` request names an unknown model
            type or the best-model checkpoint cannot be loaded.
    """
    # define local parameters
    args = define_parser()

    dataset_path = args.dataset_path
    batch_size = args.batch_size
    num_workers = args.num_workers
    local_epochs = args.local_epochs
    model_path = args.model_path

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.CIFAR10(root=dataset_path, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    testset = torchvision.datasets.CIFAR10(root=dataset_path, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    net = Net()
    best_accuracy = 0.0

    # wraps evaluation logic into a method to re-use for
    # evaluation on both trained and received model
    def evaluate(input_weights):
        """Return the integer test-set accuracy (percent) for a state dict."""
        # Fresh instance so evaluation never disturbs the training model
        # (previously this local shadowed the outer `net`).
        eval_net = Net()
        eval_net.load_state_dict(input_weights)
        # (optional) use GPU to speed things up
        eval_net.to(DEVICE)

        correct = 0
        total = 0
        # since we're not training, we don't need to calculate the gradients for our outputs
        with torch.no_grad():
            for data in testloader:
                # (optional) use GPU to speed things up
                images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                # calculate outputs by running images through the network
                outputs = eval_net(images)
                # the class with the highest energy is what we choose as prediction
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # floor division keeps the metric an int, matching the other NVFlare examples
        return 100 * correct // total

    # (2) initialize NVFlare client API
    flare.init()

    # (3) run continously when launch_once=true
    while flare.is_running():

        # (4) receive FLModel from NVFlare
        input_model = flare.receive()
        client_id = flare.get_site_name()

        # Based on different "task" we will do different things
        # for "train" task (flare.is_train()) we use the received model to do training and/or evaluation
        # and send back updated model and/or evaluation metrics, if the "train_with_evaluation" is specified as True
        # in the config_fed_client we will need to do evaluation and include the evaluation metrics
        # for "evaluate" task (flare.is_evaluate()) we use the received model to do evaluation
        # and send back the evaluation metrics
        # for "submit_model" task (flare.is_submit_model()) we just need to send back the local model
        # (5) performing train task on received model
        if flare.is_train():
            print(f"({client_id}) current_round={input_model.current_round}, total_rounds={input_model.total_rounds}")

            # (5.1) loads model from NVFlare
            net.load_state_dict(input_model.params)

            criterion = nn.CrossEntropyLoss()
            optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

            # (optional) use GPU to speed things up
            net.to(DEVICE)
            # (optional) calculate total steps
            steps = local_epochs * len(trainloader)
            for epoch in range(local_epochs):  # loop over the dataset multiple times

                running_loss = 0.0
                for i, data in enumerate(trainloader, 0):
                    # get the inputs; data is a list of [inputs, labels]
                    # (optional) use GPU to speed things up
                    inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

                    # zero the parameter gradients
                    optimizer.zero_grad()

                    # forward + backward + optimize
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()

                    # print statistics
                    running_loss += loss.item()
                    if i % 2000 == 1999:  # print every 2000 mini-batches
                        print(f"({client_id}) [{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                        running_loss = 0.0

            print(f"({client_id}) Finished Training")

            # (5.2) evaluation on local trained model to save best model
            local_accuracy = evaluate(net.state_dict())
            print(f"({client_id}) Evaluating local trained model. Accuracy on the 10000 test images: {local_accuracy}")
            if local_accuracy > best_accuracy:
                best_accuracy = local_accuracy
                # NOTE: net is on DEVICE here, so the checkpoint holds CUDA
                # tensors; the submit_model branch loads with map_location.
                torch.save(net.state_dict(), model_path)

            # (5.3) evaluate on received model for model selection
            accuracy = evaluate(input_model.params)
            print(
                f"({client_id}) Evaluating received model for model selection. Accuracy on the 10000 test images: {accuracy}"
            )

            # (5.4) construct trained FL model
            output_model = flare.FLModel(
                params=net.cpu().state_dict(),
                metrics={"accuracy": accuracy},
                meta={"NUM_STEPS_CURRENT_ROUND": steps},
            )

            # (5.5) send model back to NVFlare
            flare.send(output_model)

        # (6) performing evaluate task on received model
        elif flare.is_evaluate():
            accuracy = evaluate(input_model.params)
            flare.send(flare.FLModel(metrics={"accuracy": accuracy}))

        # (7) performing submit_model task to obtain best local model
        elif flare.is_submit_model():
            model_name = input_model.meta["submit_model_name"]
            if model_name == ModelName.BEST_MODEL:
                try:
                    # map_location="cpu" so the checkpoint (saved from a CUDA
                    # model) can be restored even without a GPU present.
                    weights = torch.load(model_path, map_location="cpu")
                    net = Net()
                    net.load_state_dict(weights)
                    flare.send(flare.FLModel(params=net.cpu().state_dict()))
                except Exception as e:
                    raise ValueError("Unable to load best model") from e
            else:
                raise ValueError(f"Unknown model_type: {model_name}")


if __name__ == "__main__":
    main()

0 comments on commit dd787e0

Please sign in to comment.