add controller launcher with flower
holgerroth committed Mar 1, 2024
1 parent 1ddcdcb commit a2bd72b
Showing 10 changed files with 636 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/advanced/flower/fedprox/README.md
@@ -0,0 +1,2 @@
# Flower launcher example with FedProx

242 changes: 242 additions & 0 deletions examples/advanced/flower/fedprox/flower_fedprox.ipynb

Large diffs are not rendered by default.

97 changes: 97 additions & 0 deletions examples/advanced/flower/fedprox/flwr_scripts/client.py
@@ -0,0 +1,97 @@
from collections import OrderedDict
import warnings

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# #############################################################################
# Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
# #############################################################################

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)

def train(net, trainloader, epochs):
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in trainloader:
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()

def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
with torch.no_grad():
for images, labels in testloader:
outputs = net(images.to(DEVICE))
loss += criterion(outputs, labels.to(DEVICE)).item()
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
return loss / len(testloader.dataset), correct / total

def load_data():
"""Load CIFAR-10 (training and test set)."""
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=trf)
testset = CIFAR10("./data", train=False, download=True, transform=trf)
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)

# #############################################################################
# Federating the pipeline with Flower
# #############################################################################

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()

# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return float(loss), len(testloader.dataset), {"accuracy": float(accuracy)}

# Start Flower client
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient(), insecure=True)
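The train() above runs plain SGD without a proximal term, even though the example is titled FedProx. When Flower's FedProx strategy is used on the server, it is expected to send a proximal_mu value in the fit config, and the client adds the proximal term itself. A minimal sketch of such a variant, reusing DEVICE from the script above; the helper name and the config-key handling are assumptions, not part of this commit:

import torch

# Hypothetical variant of train() (not in this commit): adds the FedProx proximal term
# (mu / 2) * ||w - w_global||^2 to the loss; proximal_mu would come from the fit config.
def train_fedprox(net, trainloader, epochs, proximal_mu=0.0):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    global_params = [p.detach().clone() for p in net.parameters()]  # snapshot of the received global weights
    for _ in range(epochs):
        for images, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(images.to(DEVICE)), labels.to(DEVICE))
            prox = sum((p - g).pow(2).sum() for p, g in zip(net.parameters(), global_params))
            (loss + 0.5 * proximal_mu * prox).backward()
            optimizer.step()

# In FlowerClient.fit, proximal_mu could be read from the server-sent config, e.g.:
#     train_fedprox(net, trainloader, epochs=1, proximal_mu=config.get("proximal_mu", 0.0))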
7 changes: 7 additions & 0 deletions examples/advanced/flower/fedprox/flwr_scripts/server.py
@@ -0,0 +1,7 @@
import flwr as fl

# Start Flower server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
)
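This minimal server starts Flower with its default strategy (FedAvg). Since the example is titled FedProx, here is a hedged sketch of how Flower's FedProx strategy could be plugged in, assuming the installed flwr version provides flwr.server.strategy.FedProx with a proximal_mu argument:

import flwr as fl

# Hedged sketch (not part of this commit): start the Flower server with the FedProx strategy.
# The proximal_mu value is an arbitrary illustration.
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=fl.server.strategy.FedProx(proximal_mu=1.0),
)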
@@ -0,0 +1,77 @@
{
  format_version = 2
  app_script = "client.py"
  app_config = ""
  executors = [
    {
      tasks = [
        "train"
      ]
      executor {
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
        args {
          launcher_id = "launcher"
          pipe_id = "pipe"
          heartbeat_timeout = 60
          params_exchange_format = "pytorch"
          params_transfer_type = "DIFF"
          train_with_evaluation = true
        }
      }
    }
  ]
  task_data_filters = []
  task_result_filters = []
  components = [
    {
      id = "launcher"
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
      args {
        script = "python3 custom/{app_script} {app_config} "
        launch_once = true
      }
    }
    {
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    }
    {
      id = "metrics_pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    }
    {
      id = "metric_relay"
      path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
      args {
        pipe_id = "metrics_pipe"
        event_type = "fed.analytix_log_stats"
        read_interval = 0.1
      }
    }
    {
      id = "config_preparer"
      path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
      args {
        component_ids = [
          "metric_relay"
        ]
      }
    }
  ]
}
@@ -0,0 +1,26 @@
{
  format_version = 2
  task_data_filters = []
  task_result_filters = []
  app_script = "server.py"
  app_config = ""
  workflows = [
    {
      id = "controller_launcher"
      path = "controller_launcher.ControllerLauncher"
      args {
        launcher_id = "launcher"
      }
    }
  ]
  components = [
    {
      id = "launcher"
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
      args {
        script = "python3 custom/{app_script} {app_config} "
        launch_once = true
      }
    }
  ]
}
@@ -0,0 +1,98 @@
from collections import OrderedDict
import warnings

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# #############################################################################
# Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
# #############################################################################

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)

def train(net, trainloader, epochs):
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in trainloader:
print("train...")
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()

def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
with torch.no_grad():
for images, labels in testloader:
outputs = net(images.to(DEVICE))
loss += criterion(outputs, labels.to(DEVICE)).item()
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
return loss / len(testloader.dataset), correct / total

def load_data():
"""Load CIFAR-10 (training and test set)."""
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=trf)
testset = CIFAR10("./data", train=False, download=True, transform=trf)
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)

# #############################################################################
# Federating the pipeline with Flower
# #############################################################################

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()

# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return float(loss), len(testloader.dataset), {"accuracy": float(accuracy)}

# Start Flower client
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=FlowerClient(), insecure=True) # "127.0.0.1:8080"
@@ -0,0 +1,68 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from nvflare.app_common.workflows.model_controller import ModelController
from nvflare.app_common.abstract.launcher import Launcher, LauncherRunStatus
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.fuel.utils.validation_utils import check_object_type


class ControllerLauncher(ModelController):
"""The base controller for FedAvg Workflow. *Note*: This class is based on the experimental `ModelController`.
Implements [FederatedAveraging](https://arxiv.org/abs/1602.05629).
The model persistor (persistor_id) is used to load the initial global model which is sent to a list of clients.
Each client sends it's updated weights after local training which is aggregated.
Next, the global model is updated.
The model_persistor also saves the model after training.
Provides the default implementations for the follow routines:
- def sample_clients(self, min_clients)
- def aggregate(self, results: List[FLModel], aggregate_fn=None) -> FLModel
- def update_model(self, aggr_result)
The `run` routine needs to be implemented by the derived class:
- def run(self)
"""

    def __init__(self, launcher_id):
        super().__init__()
        self._launcher_id = launcher_id

    def _init_launcher(self, fl_ctx: FLContext):
        engine = fl_ctx.get_engine()
        launcher: Launcher = engine.get_component(self._launcher_id)
        if launcher is None:
            raise RuntimeError(f"Launcher cannot be found using component id {self._launcher_id}")
        check_object_type(self._launcher_id, launcher, Launcher)
        self.launcher = launcher

    def run(self):
        self.info("Start Controller Launcher.")

        self._init_launcher(self.fl_ctx)

        # self.launcher.launch_task("train", shareable=Shareable(), fl_ctx=self.fl_ctx, abort_signal=self.abort_signal)
        self.launcher.initialize(fl_ctx=self.fl_ctx)

        # Keep the workflow alive while the launched process runs; exit when the job is aborted.
        while not self.abort_signal.triggered:
            time.sleep(10)
            self.info(f"Running task ... [{self.launcher._script}]")

        self.info("Stop Controller Launcher.")

@@ -0,0 +1,8 @@
import flwr as fl

# Start Flower server
#fl.server.start_server(
# server_address="0.0.0.0:8080",
# config=fl.server.ServerConfig(num_rounds=3),
#)
print("Run Server code...")