forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d5e4b13
commit 3a440ac
Showing
11 changed files
with
556 additions
and
4 deletions.
There are no files selected for viewing
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
...advanced/flower/cifar10/jobs/flwr_cifar10_tb_streaming/app2/config/config_fed_client.conf
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
{ | ||
# version of the configuration | ||
format_version = 2 | ||
|
||
# Client Computing Executors. | ||
executors = [ | ||
{ | ||
# tasks the executors are defined to handle | ||
tasks = ["train"] | ||
|
||
# This particular executor | ||
executor { | ||
|
||
# Executor name : ClientAPILauncherExecutor | ||
# This is an executor for Client API. The underline data exchange is using Pipe. | ||
path = "nvflare.app_common.executors.client_api_launcher_executor.ClientAPILauncherExecutor" | ||
|
||
args { | ||
|
||
# This executor take an component named "launcher" | ||
launcher_id = "launcher" | ||
|
||
# This executor needs Pipe component | ||
pipe_id = "pipe" | ||
|
||
# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds. | ||
# Please refer to the class docstring for all available arguments | ||
heartbeat_timeout = 60 | ||
|
||
# format of the exchange parameters | ||
params_exchange_format = "numpy" | ||
|
||
# if the transfer_type is FULL, then it will be sent directly | ||
# if the transfer_type is DIFF, then we will calculate the | ||
# difference VS received parameters and send the difference | ||
params_transfer_type = "FULL" | ||
|
||
# if train_with_evaluation is true, the executor will expect | ||
# the custom code need to send back both the trained parameters and the evaluation metric | ||
# otherwise only trained parameters are expected | ||
train_with_evaluation = true | ||
|
||
} | ||
} | ||
} | ||
], | ||
|
||
# this defined an array of task data filters. If provided, it will control the data from server controller to client executor | ||
task_data_filters = [] | ||
|
||
# this defined an array of task result filters. If provided, it will control the result from client executor to server controller | ||
task_result_filters = [] | ||
|
||
components = [ | ||
{ | ||
# component id is "launcher" | ||
id = "launcher" | ||
|
||
# the class path of this component | ||
path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher" | ||
|
||
args { | ||
# the launcher will invoke the script | ||
script = "python3 custom/client.py --node-id 1" | ||
# if launch_once is true, the SubprocessLauncher will launch once for the whole job | ||
# if launch_once is false, the SubprocessLauncher will launch a process for each task it receives from server | ||
launch_once = true | ||
} | ||
}, | ||
{ | ||
id = "pipe" | ||
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" | ||
args { | ||
mode = "PASSIVE" | ||
site_name = "{SITE_NAME}" | ||
token = "{JOB_ID}" | ||
root_url = "{ROOT_URL}" | ||
secure_mode = "{SECURE_MODE}" | ||
workspace_dir = "{WORKSPACE}" | ||
} | ||
}, | ||
{ | ||
id = "metrics_pipe" | ||
path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe" | ||
args { | ||
mode = "PASSIVE" | ||
site_name = "{SITE_NAME}" | ||
token = "{JOB_ID}" | ||
root_url = "{ROOT_URL}" | ||
secure_mode = "{SECURE_MODE}" | ||
workspace_dir = "{WORKSPACE}" | ||
} | ||
}, | ||
{ | ||
id = "metric_relay" | ||
path = "nvflare.app_common.widgets.metric_relay.MetricRelay" | ||
args { | ||
pipe_id = "metrics_pipe" | ||
event_type = "fed.analytix_log_stats" | ||
# how fast should it read from the peer | ||
read_interval = 0.1 | ||
} | ||
}, | ||
{ | ||
# we use this component so the client api `flare.init()` can get required information | ||
id = "client_api_config_preparer" | ||
path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator" | ||
args { | ||
component_ids = ["metric_relay"] | ||
} | ||
} | ||
] | ||
} | ||
|
161 changes: 161 additions & 0 deletions
161
examples/advanced/flower/cifar10/jobs/flwr_cifar10_tb_streaming/app2/custom/client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
import argparse | ||
import warnings | ||
from collections import OrderedDict | ||
|
||
import flwr as fl | ||
from flwr_datasets import FederatedDataset | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.utils.data import DataLoader | ||
from torchvision.transforms import Compose, Normalize, ToTensor | ||
from tqdm import tqdm | ||
|
||
import nvflare.client as flare | ||
from nvflare.client.tracking import SummaryWriter | ||
writer = SummaryWriter() | ||
|
||
TRAIN_STEP = 0 | ||
TEST_STEP = 0 | ||
|
||
# ############################################################################# | ||
# 1. Regular PyTorch pipeline: nn.Module, train, test, and DataLoader | ||
# ############################################################################# | ||
|
||
warnings.filterwarnings("ignore", category=UserWarning) | ||
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | ||
|
||
|
||
class Net(nn.Module): | ||
"""Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')""" | ||
|
||
def __init__(self) -> None: | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(3, 6, 5) | ||
self.pool = nn.MaxPool2d(2, 2) | ||
self.conv2 = nn.Conv2d(6, 16, 5) | ||
self.fc1 = nn.Linear(16 * 5 * 5, 120) | ||
self.fc2 = nn.Linear(120, 84) | ||
self.fc3 = nn.Linear(84, 10) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = x.view(-1, 16 * 5 * 5) | ||
x = F.relu(self.fc1(x)) | ||
x = F.relu(self.fc2(x)) | ||
return self.fc3(x) | ||
|
||
def train(net, trainloader, epochs): | ||
"""Train the model on the training set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | ||
|
||
global TRAIN_STEP | ||
for _ in range(epochs): | ||
avg_loss = 0.0 | ||
for batch in tqdm(trainloader, "Training"): | ||
images = batch["img"] | ||
labels = batch["label"] | ||
optimizer.zero_grad() | ||
loss = criterion(net(images.to(DEVICE)), labels.to(DEVICE)) | ||
loss.backward() | ||
optimizer.step() | ||
avg_loss += loss.item() | ||
writer.add_scalar("train_loss", avg_loss/len(trainloader), TRAIN_STEP) | ||
TRAIN_STEP += 1 | ||
|
||
|
||
def test(net, testloader): | ||
"""Validate the model on the test set.""" | ||
criterion = torch.nn.CrossEntropyLoss() | ||
correct, loss = 0, 0.0 | ||
global TEST_STEP | ||
with torch.no_grad(): | ||
for batch in tqdm(testloader, "Testing"): | ||
images = batch["img"].to(DEVICE) | ||
labels = batch["label"].to(DEVICE) | ||
outputs = net(images) | ||
loss += criterion(outputs, labels).item() | ||
correct += (torch.max(outputs.data, 1)[1] == labels).sum().item() | ||
accuracy = correct / len(testloader.dataset) | ||
writer.add_scalar("test_loss", loss, TEST_STEP) | ||
writer.add_scalar("test_accuracy", accuracy, TEST_STEP) | ||
TEST_STEP += 1 | ||
return loss, accuracy | ||
|
||
|
||
def load_data(node_id): | ||
"""Load partition CIFAR10 data.""" | ||
fds = FederatedDataset(dataset="cifar10", partitioners={"train": 3}) | ||
partition = fds.load_partition(node_id) | ||
# Divide data on each node: 80% train, 20% test | ||
partition_train_test = partition.train_test_split(test_size=0.2) | ||
pytorch_transforms = Compose( | ||
[ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] | ||
) | ||
|
||
def apply_transforms(batch): | ||
"""Apply transforms to the partition from FederatedDataset.""" | ||
batch["img"] = [pytorch_transforms(img) for img in batch["img"]] | ||
return batch | ||
|
||
partition_train_test = partition_train_test.with_transform(apply_transforms) | ||
trainloader = DataLoader(partition_train_test["train"], batch_size=32, shuffle=True) | ||
testloader = DataLoader(partition_train_test["test"], batch_size=32) | ||
return trainloader, testloader | ||
|
||
|
||
# ############################################################################# | ||
# 2. Federation of the pipeline with Flower | ||
# ############################################################################# | ||
|
||
# Get node id | ||
parser = argparse.ArgumentParser(description="Flower") | ||
parser.add_argument( | ||
"--node-id", | ||
choices=[0, 1, 2], | ||
required=True, | ||
type=int, | ||
help="Partition of the dataset divided into 3 iid partitions created artificially.", | ||
) | ||
node_id = parser.parse_args().node_id | ||
|
||
# Load model and data (simple CNN, CIFAR-10) | ||
net = Net().to(DEVICE) | ||
trainloader, testloader = load_data(node_id=node_id) | ||
|
||
|
||
# Define Flower client | ||
class FlowerClient(fl.client.NumPyClient): | ||
def get_parameters(self, config): | ||
return [val.cpu().numpy() for _, val in net.state_dict().items()] | ||
|
||
def set_parameters(self, parameters): | ||
params_dict = zip(net.state_dict().keys(), parameters) | ||
state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) | ||
net.load_state_dict(state_dict, strict=True) | ||
|
||
def fit(self, parameters, config): | ||
self.set_parameters(parameters) | ||
train(net, trainloader, epochs=5) | ||
return self.get_parameters(config={}), len(trainloader.dataset), {} | ||
|
||
def evaluate(self, parameters, config): | ||
self.set_parameters(parameters) | ||
loss, accuracy = test(net, testloader) | ||
return loss, len(testloader.dataset), {"accuracy": accuracy} | ||
|
||
|
||
# initializes NVFlare interface | ||
flare.init() | ||
|
||
# get system information | ||
sys_info = flare.system_info() | ||
print(f"Flare system info is: {sys_info}") | ||
|
||
# Start Flower client | ||
fl.client.start_client( | ||
server_address="127.0.0.1:8080", | ||
client=FlowerClient().to_client(), | ||
) |
Oops, something went wrong.