Commit
run training with in process executor
holgerroth committed Apr 11, 2024
1 parent 9b59d97 commit 2e1afae
Showing 3 changed files with 355 additions and 0 deletions.
136 changes: 136 additions & 0 deletions examples/hello-world/python_jobs/client-api-pt/cifar10_fl.py
@@ -0,0 +1,136 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from net import Net

# (1) import nvflare client API
import nvflare.client as flare

# (optional) metrics
from nvflare.client.tracking import SummaryWriter

# (optional) set a fixed location so we don't need to download the dataset every time
DATASET_PATH = "/tmp/nvflare/data"
# If available, we use GPU to speed things up.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
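# Per-round flow in main() below: receive the current global model from the NVFlare server,
# train locally on CIFAR-10, evaluate the received model for server-side model selection,
# then send the updated weights and metrics back to the server.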


def main():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    batch_size = 4
    epochs = 2

    trainset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root=DATASET_PATH, train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

    net = Net()

    # (2) initializes NVFlare client API
    flare.init()

    summary_writer = SummaryWriter()
    while flare.is_running():
        # (3) receives FLModel from NVFlare
        input_model = flare.receive()
        print(f"current_round={input_model.current_round}")

        # (4) loads model from NVFlare
        net.load_state_dict(input_model.params)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

        # (optional) use GPU to speed things up
        net.to(DEVICE)
        # (optional) calculate total steps
        steps = epochs * len(trainloader)
        for epoch in range(epochs):  # loop over the dataset multiple times

            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                # (optional) use GPU to speed things up
                inputs, labels = data[0].to(DEVICE), data[1].to(DEVICE)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                    global_step = input_model.current_round * steps + epoch * len(trainloader) + i

                    summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step)
                    running_loss = 0.0

        print("Finished Training")

        PATH = "./cifar_net.pth"
        torch.save(net.state_dict(), PATH)

        # (5) wraps evaluation logic into a method to re-use for
        # evaluation on both trained and received model
        def evaluate(input_weights):
            net = Net()
            net.load_state_dict(input_weights)
            # (optional) use GPU to speed things up
            net.to(DEVICE)

            correct = 0
            total = 0
            # since we're not training, we don't need to calculate the gradients for our outputs
            with torch.no_grad():
                for data in testloader:
                    # (optional) use GPU to speed things up
                    images, labels = data[0].to(DEVICE), data[1].to(DEVICE)
                    # calculate outputs by running images through the network
                    outputs = net(images)
                    # the class with the highest energy is what we choose as prediction
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

            print(f"Accuracy of the network on the 10000 test images: {100 * correct // total} %")
            return 100 * correct // total

        # (6) evaluate on received model for model selection
        accuracy = evaluate(input_model.params)
        # (7) construct trained FL model
        output_model = flare.FLModel(
            params=net.cpu().state_dict(),
            metrics={"accuracy": accuracy},
            meta={"NUM_STEPS_CURRENT_ROUND": steps},
        )
        # (8) send model back to NVFlare
        flare.send(output_model)


if __name__ == "__main__":
    main()
181 changes: 181 additions & 0 deletions examples/hello-world/python_jobs/client-api-pt/client_api_pt_job.py
@@ -0,0 +1,181 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import uuid
from typing import Union, List

from nvflare.job_config.fed_app_config import ClientAppConfig, ServerAppConfig, FedAppConfig
from nvflare.job_config.fed_job_config import FedJobConfig
from nvflare.app_common.widgets.validation_json_generator import ValidationJsonGenerator
from nvflare.app_common.workflows.cross_site_model_eval import CrossSiteModelEval
from nvflare.app_opt.pt import PTFileModelPersistor
from nvflare.app_opt.pt.client_api_launcher_executor import PTClientAPILauncherExecutor
from nvflare.app_common.launchers.subprocess_launcher import SubprocessLauncher
from nvflare.app_common.widgets.metric_relay import MetricRelay
from nvflare.app_common.widgets.external_configurator import ExternalConfigurator
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_locator import PTFileModelLocator
from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.pipe.file_pipe import FilePipe
from net import Net
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_opt.tracking.tb.tb_receiver import TBAnalyticsReceiver
from nvflare.app_opt.pt.in_process_client_api_executor import PTInProcessClientAPIExecutor
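
# A lightweight, pythonic way to assemble a federated job: FedApp/FedJob wrap NVFlare's
# ClientAppConfig/ServerAppConfig and FedJobConfig so a job can be built, exported,
# and run in the simulator directly from Python.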


class FedApp:
    def __init__(self):
        self.app = None  # Union[ClientAppConfig, ServerAppConfig]

    def get_app_config(self):
        return self.app


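# FedJob: builds a FedJobConfig; `to()` assigns a client or server app to a target site,
# `export_job()` writes the generated job config, and `simulator_run()` runs it in the FL simulator.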
class FedJob:
    def __init__(self, name="client-api-pt", workspace="/tmp/nvflare/simulator_workspace") -> None:
        self.job_name = name
        self.workspace = workspace

        self.job: FedJobConfig = FedJobConfig(job_name=self.job_name, min_clients=1)

    def to(self, app: FedApp, target):
        if not isinstance(app, FedApp):
            raise ValueError(f"App needs to be of type `FedApp` but was type {type(app)}")

        client_server_config = app.get_app_config()
        if isinstance(client_server_config, ClientAppConfig):
            app_config = FedAppConfig(server_app=None, client_app=client_server_config)
            app_name = f"app_client_{uuid.uuid4()}"
        elif isinstance(client_server_config, ServerAppConfig):
            app_config = FedAppConfig(server_app=client_server_config, client_app=None)
            app_name = f"app_server_{uuid.uuid4()}"
        else:
            raise ValueError(
                f"App needs to be of type `ClientAppConfig` or `ServerAppConfig` but was type {type(client_server_config)}"
            )

        self.job.add_fed_app(app_name, app_config)
        self.job.set_site_app(target, app_name)

    def export_job(self, job_root):
        self.job.generate_job_config(job_root)

    def simulator_run(self, job_root):
        self.job.simulator_run(job_root, self.workspace, threads=2)


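# Client-side app config: registers the given executors for the "train" task and adds the
# FilePipe/MetricRelay/ExternalConfigurator components used to stream metrics, plus the
# external training script cifar10_fl.py.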
class ExecutorApp(FedApp):
    def __init__(self, executors: List):
        super().__init__()
        self.executors = executors
        # self._site_name = f"site_{uuid.uuid4()}"
        # self._job_id = f"{uuid.uuid4()}"
        self._create_client_app()

    def _create_client_app(self):
        self.app = ClientAppConfig()

        for _executor in self.executors:
            self.app.add_executor(["train"], _executor)

        component = FilePipe(
            mode="PASSIVE",  # TODO: enable passing Mode.PASSIVE
            root_path="{WORKSPACE}/{JOB_ID}/{SITE_NAME}",  # TODO: this creates empty subfolder structure
        )
        self.app.add_component("metrics_pipe", component)

        component = MetricRelay(
            pipe_id="metrics_pipe",
            event_type="fed.analytix_log_stats",
            read_interval=0.1,
        )
        self.app.add_component("metric_relay", component)

        component = ExternalConfigurator(
            component_ids=["metric_relay"],
        )
        self.app.add_component("config_preparer", component)

        # self.app.add_component("net", Net())  # TODO: find another way to register files that need to be included in custom folder
        self.app.add_ext_script("cifar10_fl.py")


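# Server-side app config: adds the given controllers as workflows along with the PyTorch
# model persistor/locator, validation JSON generator, accuracy-based model selector,
# and TensorBoard analytics receiver.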
class ControllerApp(FedApp):
    def __init__(self, controllers: List, model_class_path=Net()):
        super().__init__()
        self.controllers = controllers
        self.model_class_path = model_class_path
        self._create_server_app()

    def _create_server_app(self):
        self.app: ServerAppConfig = ServerAppConfig()

        for i, _controller in enumerate(self.controllers):
            self.app.add_workflow(f"controller_{i}", _controller)

        # TODO: make optional or list of controllers? # add as new example
        # _controller = CrossSiteModelEval(model_locator_id="model_locator")
        # self.app.add_workflow("cross_site_validate", _controller)

        component = PTFileModelPersistor(
            model=Net()
        )
        self.app.add_component("persistor", component)

        component = PTFileModelLocator(
            pt_persistor_id="persistor"
        )
        self.app.add_component("model_locator", component)

        component = ValidationJsonGenerator()
        self.app.add_component("json_generator", component)

        component = IntimeModelSelector(
            key_metric="accuracy"
        )
        self.app.add_component("model_selector", component)

        component = TBAnalyticsReceiver(
            events=["fed.analytix_log_stats"]
        )
        self.app.add_component("receiver", component)

        # self.app.add_component("net", Net())  # TODO: have another way to register needed scripts


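# Example job: FedAvg across two simulated clients, each training cifar10_fl.main via the
# in-process Client API executor; the job config is exported to disk and then run in the simulator.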
if __name__ == "__main__":
    n_clients = 2
    num_rounds = 2

    job = FedJob(name="cifar10_fedavg")

    controller = FedAvg(
        min_clients=n_clients,
        num_rounds=num_rounds,
        persistor_id="persistor",  # TODO: why is it not using default
    )
    server_app = ControllerApp(controllers=[controller])
    job.to(server_app, "server")

    for i in range(n_clients):
        executor = PTInProcessClientAPIExecutor(
            task_fn_path="cifar10_fl.main",
            # task_fn_args={"batch_size": 1}
        )
        client_app = ExecutorApp(executors=[executor])
        job.to(client_app, f"site-{i}")

    job.export_job("/tmp/nvflare/jobs/job_config")
    job.simulator_run("/tmp/nvflare/jobs/workdir")


38 changes: 38 additions & 0 deletions examples/hello-world/python_jobs/client-api-pt/net.py
@@ -0,0 +1,38 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F
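
# Simple CNN for CIFAR-10 classification (the small network used in the classic PyTorch CIFAR-10 tutorial).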


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # from cifar10_fl import main  # TODO: remove
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
