Skip to content

Commit

Permalink
add client executor launcher; upgrade flwr scripts to 1.7.0 versions
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Mar 4, 2024
1 parent 6434435 commit a2bb01a
Show file tree
Hide file tree
Showing 6 changed files with 476 additions and 264 deletions.
379 changes: 237 additions & 142 deletions examples/advanced/flower/fedprox/flower_fedprox.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,9 @@
"train"
]
executor {
path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
path = "executor_launcher.ExecutorLauncher"
args {
launcher_id = "launcher"
pipe_id = "pipe"
heartbeat_timeout = 60
params_exchange_format = "pytorch"
params_transfer_type = "DIFF"
train_with_evaluation = true
}
}
}
Expand All @@ -31,47 +26,5 @@
launch_once = true
}
}
{
id = "pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metrics_pipe"
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
args {
mode = "PASSIVE"
site_name = "{SITE_NAME}"
token = "{JOB_ID}"
root_url = "{ROOT_URL}"
secure_mode = "{SECURE_MODE}"
workspace_dir = "{WORKSPACE}"
}
}
{
id = "metric_relay"
path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
args {
pipe_id = "metrics_pipe"
event_type = "fed.analytix_log_stats"
read_interval = 0.1
}
}
{
id = "config_preparer"
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
args {
component_ids = [
"metric_relay"
]
}
}
]
}
173 changes: 108 additions & 65 deletions examples/advanced/flower/fedprox/jobs/flwr_cifar10/app/custom/client.py
Original file line number Diff line number Diff line change
@@ -1,98 +1,141 @@
from collections import OrderedDict
import argparse
import warnings
from collections import OrderedDict

import flwr as fl
from flwr_datasets import FederatedDataset
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm


# #############################################################################
# Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader
# #############################################################################

warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)


def train(net, trainloader, epochs):
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for images, labels in trainloader:
print("train...")
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()
"""Train the model on the training set."""
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for _ in range(epochs):
for batch in tqdm(trainloader, "Training"):
images = batch["img"]
labels = batch["label"]
optimizer.zero_grad()
criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
optimizer.step()


def test(net, testloader):
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, total, loss = 0, 0, 0.0
with torch.no_grad():
for images, labels in testloader:
outputs = net(images.to(DEVICE))
loss += criterion(outputs, labels.to(DEVICE)).item()
total += labels.size(0)
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
return loss / len(testloader.dataset), correct / total

def load_data():
"""Load CIFAR-10 (training and test set)."""
trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=trf)
testset = CIFAR10("./data", train=False, download=True, transform=trf)
return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)
"""Validate the model on the test set."""
criterion = torch.nn.CrossEntropyLoss()
correct, loss = 0, 0.0
with torch.no_grad():
for batch in tqdm(testloader, "Testing"):
images = batch["img"].to(DEVICE)
labels = batch["label"].to(DEVICE)
outputs = net(images)
loss += criterion(outputs, labels).item()
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
accuracy = correct / len(testloader.dataset)
return loss, accuracy


def load_data(node_id):
"""Load partition CIFAR10 data."""
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3})
partition = fds.load_partition(node_id)
# Divide data on each node: 80% train, 20% test
partition_train_test = partition.train_test_split(test_size=0.2)
pytorch_transforms = Compose(
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch

partition_train_test = partition_train_test.with_transform(apply_transforms)
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True)
testloader = DataLoader(partition_train_test["test"], batch_size=32)
return trainloader, testloader


# #############################################################################
# Federating the pipeline with Flower
# 2. Federation of the pipeline with Flower
# #############################################################################

# Get node id
#parser = argparse.ArgumentParser(description="Flower")
#parser.add_argument(
# "--node-id",
# choices=[0, 1, 2],
# required=True,
# type=int,
# help="Partition of the dataset divided into 3 iid partitions created artificially.",
#)
#node_id = parser.parse_args().node_id
node_id = np.random.randint(0,3)
print(f"START FLOWER CLIENT [node_id={node_id}]")

# Load model and data (simple CNN, CIFAR-10)
net = Net().to(DEVICE)
trainloader, testloader = load_data()
trainloader, testloader = load_data(node_id=node_id)


# Define Flower client
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)

def set_parameters(self, parameters):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
net.load_state_dict(state_dict, strict=True)
def fit(self, parameters, config):
self.set_parameters(parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}

def fit(self, parameters, config):
self.set_parameters(parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}
def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}

def evaluate(self, parameters, config):
self.set_parameters(parameters)
loss, accuracy = test(net, testloader)
return float(loss), len(testloader.dataset), {"accuracy": float(accuracy)}

# Start Flower client
fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=FlowerClient(), insecure=True) # "127.0.0.1:8080"
fl.client.start_client(
server_address="127.0.0.1:8080",
client=FlowerClient().to_client(),
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nvflare.app_common.abstract.launcher import Launcher, LauncherRunStatus
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.app_constant import AppConstants, ValidateType
from nvflare.fuel.utils.validation_utils import check_object_type


Expand All @@ -40,9 +41,14 @@ class ControllerLauncher(ModelController):
- def run(self)
"""

def __init__(self, launcher_id):
def __init__(self,
launcher_id,
task_name=AppConstants.TASK_TRAIN
):
super().__init__()
self._launcher_id = launcher_id
self._task_name = task_name
self.is_initialized = False

def _init_launcher(self, fl_ctx: FLContext):
engine = fl_ctx.get_engine()
Expand All @@ -51,18 +57,27 @@ def _init_launcher(self, fl_ctx: FLContext):
raise RuntimeError(f"Launcher can not be found using {self._launcher_id}")
check_object_type(self._launcher_id, launcher, Launcher)
self.launcher = launcher
self.is_initialized = True

def run(self):
self.info("Start Controller Launcher.")

self._init_launcher(self.fl_ctx)
if not self.is_initialized:
self._init_launcher(self.fl_ctx)

#self.launcher.launch_task("train", shareable=Shareable(), fl_ctx=self.fl_ctx, abort_signal=self.abort_signal)
self.launcher.initialize(fl_ctx=self.fl_ctx)

while True:
time.sleep(10)
print(f"Running task ... [{self.launcher._script}]")

time.sleep(10.0)
run_status = self.launcher.check_run_status(task_name=self._task_name, fl_ctx=self.fl_ctx)
if run_status == LauncherRunStatus.RUNNING:
print(f"Running ... [{self.launcher._script}]")
elif run_status == LauncherRunStatus.COMPLETE_SUCCESS:
print("run success")
break
else:
print(f"run failed or not start: {run_status}")
break
self.launcher.finalize(fl_ctx=self.fl_ctx)
self.info("Stop Controller Launcher.")

Loading

0 comments on commit a2bb01a

Please sign in to comment.