multi-site deployment
holgerroth committed Mar 11, 2024
1 parent d5e4b13 commit 3a440ac
Showing 11 changed files with 556 additions and 4 deletions.
@@ -14,6 +14,7 @@
import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter
writer = SummaryWriter()

TRAIN_STEP = 0
TEST_STEP = 0

@@ -137,7 +138,7 @@ def set_parameters(self, parameters):

def fit(self, parameters, config):
self.set_parameters(parameters)
- train(net, trainloader, epochs=1)
+ train(net, trainloader, epochs=5)
return self.get_parameters(config={}), len(trainloader.dataset), {}

def evaluate(self, parameters, config):
@@ -0,0 +1,114 @@
{
# version of the configuration
format_version = 2

# Client Computing Executors.
executors = [
{
# tasks the executors are defined to handle
tasks = ["train"]

# This particular executor
executor {

# Executor name: ClientAPILauncherExecutor
# This is an executor for the Client API. The underlying data exchange uses a Pipe.
path = "nvflare.app_common.executors.client_api_launcher_executor.ClientAPILauncherExecutor"

args {

# This executor takes a component named "launcher"
launcher_id = "launcher"

# This executor needs a Pipe component
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 60

# format of the exchanged parameters
params_exchange_format = "numpy"

# if params_transfer_type is FULL, the trained parameters are sent directly;
# if it is DIFF, the difference from the received parameters is computed
# and sent instead (see the sketch after this config)
params_transfer_type = "FULL"

# if train_with_evaluation is true, the executor expects the custom code
# to send back both the trained parameters and the evaluation metrics;
# otherwise only the trained parameters are expected
train_with_evaluation = true

}
}
}
],

# this defines an array of task data filters. If provided, they control the data sent from the server controller to the client executor
task_data_filters = []

# this defines an array of task result filters. If provided, they control the results sent from the client executor to the server controller
task_result_filters = []

components = [
{
# component id is "launcher"
id = "launcher"

# the class path of this component
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"

args {
# the launcher will invoke the script
script = "python3 custom/client.py --node-id 1"
# if launch_once is true, the SubprocessLauncher launches the script once for the whole job;
# if launch_once is false, it launches a new process for each task received from the server
# (see the launcher sketch after this config)
launch_once = true
}
},
{
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
},
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
},
{
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
# how often (in seconds) to read from the peer
read_interval = 0.1
}
},
{
# this component is used so the Client API call `flare.init()` can get the required information
id = "client_api_config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = ["metric_relay"]
}
}
]
}
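To make the params_transfer_type comment above concrete, here is a minimal sketch of the FULL vs. DIFF idea (illustration only, not NVFlare source code; the helper name and the dict-of-arrays layout are assumptions):

# Illustration: with params_transfer_type = "DIFF", the value sent back is the
# element-wise difference between the locally trained parameters and the
# parameters received from the server; with "FULL", the trained parameters
# are sent as-is.
import numpy as np

def to_transferable(received: dict, trained: dict, transfer_type: str = "FULL") -> dict:
    if transfer_type == "FULL":
        return trained
    # DIFF: send only the update relative to the received global parameters
    return {name: trained[name] - received[name] for name in received}

received = {"w": np.zeros(3)}
trained = {"w": np.array([0.1, -0.2, 0.3])}
print(to_transferable(received, trained, "DIFF"))  # {'w': array([ 0.1, -0.2,  0.3])}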

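The launcher and pipe components above are what actually run the training script on each site. As a conceptual sketch of the launch_once behavior (not the real SubprocessLauncher implementation; the site-to-node-id mapping below is a made-up assumption for illustration):

# Conceptual sketch: run the configured training script in a subprocess, once
# for the whole job when launch_once is true; NVFlare then exchanges parameters
# and metrics with that process over the configured Pipe components.
import subprocess

SITE_TO_NODE_ID = {"site-1": 0, "site-2": 1, "site-3": 2}  # hypothetical mapping

def launch_training(site_name: str) -> subprocess.Popen:
    node_id = SITE_TO_NODE_ID[site_name]
    cmd = ["python3", "custom/client.py", "--node-id", str(node_id)]
    return subprocess.Popen(cmd)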
@@ -0,0 +1,161 @@
import argparse
import warnings
from collections import OrderedDict

import flwr as fl
from flwr_datasets import FederatedDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

import nvflare.client as flare
from nvflare.client.tracking import SummaryWriter
writer = SummaryWriter()

TRAIN_STEP = 0
TEST_STEP = 0

# #############################################################################
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
# #############################################################################

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)

def train(net, trainloader, epochs):
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

global TRAIN_STEP
for _ in range(epochs):
avg_loss = 0.0
for batch in tqdm(trainloader, "Training"):
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
loss = criterion(net(images.to(DEVICE)), labels.to(DEVICE))
loss.backward()
optimizer.step()
avg_loss += loss.item()
writer.add_scalar("train_loss", avg_loss/len(trainloader), TRAIN_STEP)
TRAIN_STEP += 1


def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
global TEST_STEP
with torch.no_grad():
for batch in tqdm(testloader, "Testing"):
images = batch["img"].to(DEVICE)
labels = batch["label"].to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
writer.add_scalar("test_loss", loss, TEST_STEP)
writer.add_scalar("test_accuracy", accuracy, TEST_STEP)
TEST_STEP += 1
return loss, accuracy


def load_data(node_id):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(node_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
pytorch_transforms = Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch

partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=32)
return trainloader, testloader


# #############################################################################
# 2. Federation of the pipeline with Flower
# #############################################################################

# Get node id
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
choices=[0, 1, 2],
required=True,
type=int,
help="Partition of the dataset divided into 3 iid partitions created artificially.",
)
node_id = parser.parse_args().node_id

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data(node_id=node_id)


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)

def fit(self, parameters, config):
self.set_parameters(parameters)
train(net, trainloader, epochs=5)
return self.get_parameters(config={}), len(trainloader.dataset), {}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}


# initializes NVFlare interface
flare.init()

# get system information
sys_info = flare.system_info()
print(f"Flare system info is: {sys_info}")

# Start Flower client
fl.client.start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
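The client above expects a Flower server listening on 127.0.0.1:8080. For reference, a minimal server along these lines could look as follows (a sketch assuming the same flwr version and the default FedAvg strategy; the number of rounds is arbitrary, and the server actually used by this example may be configured differently):

# Minimal Flower server sketch: listens on port 8080 and runs 3 rounds of
# federated averaging with default settings.
import flwr as fl

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
)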