-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dev 11 minimal application code #15
base: main
Are you sure you want to change the base?
Changes from 9 commits
ddd4de2
4edbd8f
19c4b85
5ba73a0
9eeec25
828564c
b40f967
e7936d4
aac78c7
8d64834
d8a66d4
2388f04
e53b307
2eecbb0
36a492e
d372f86
518bd9d
6439ada
9cecb48
f68995d
9b81f7e
7c12bf8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,8 +6,26 @@ | |
# coverage report -m
# cd ..

# run unit tests of ODELIA swarm learning and report coverage
# MPLCONFIGDIR gives matplotlib a writable cache location inside the container
export MPLCONFIGDIR=/tmp
cd tests/unit_tests/controller
python3 -m coverage run --source=/workspace/controller/controller -m unittest discover
coverage report -m
# remove the coverage data file so reruns start clean
rm .coverage

# run simulation mode for minimal example
cd /workspace
nvflare simulator -w /tmp/minimal_training_test -n 2 -t 2 application/jobs/minimal_training_test -c simulated_node_0,simulated_node_1

# run proof-of-concept mode for minimal example
cd /workspace
nvflare poc prepare -c poc_client_0 poc_client_1
# NOTE(review): the jobs dir prepared here is application/test_jobs/ but the
# submit below uses application/jobs/ — confirm both paths are intended.
nvflare poc prepare-jobs-dir -j application/test_jobs/
# -ex excludes the admin client so the POC runs in the background
nvflare poc start -ex [email protected]
sleep 15
echo "Will submit job now after sleeping 15 seconds to allow the background process to complete"
nvflare job submit -j application/jobs/minimal_training_test
sleep 60
echo "Will shut down now after sleeping 60 seconds to allow the background process to complete"
sleep 2
nvflare poc stop
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
{
  # NVFlare client-side job configuration for the minimal swarm-learning example.
  format_version = 2

  # Training entry point launched on each client by the SubprocessLauncher below.
  app_script = "main.py"
  app_config = ""

  executors = [
    {
      # plain "train" tasks go to the PyTorch Client-API launcher executor
      tasks = [
        "train"
      ]
      executor {
        path = "nvflare.app_opt.pt.client_api_launcher_executor.PTClientAPILauncherExecutor"
        args {
          launcher_id = "launcher"
          pipe_id = "pipe"
          # seconds without a heartbeat from the launched script before the
          # executor considers it dead
          heartbeat_timeout = 600
          params_exchange_format = "pytorch"
          # DIFF: only weight deltas are exchanged (matches the aggregator's
          # expected_data_kind = WEIGHT_DIFF below)
          params_transfer_type = "DIFF"
          train_with_evaluation = true
        }
      }
    }
    {
      # All tasks prefixed with swarm_ are routed to SwarmClientController
      tasks = ["swarm_*"]
      executor {
        # client-side controller for training and logic and aggregation management
        path = "controller.SwarmClientController"
        args {
          # train task must be implemented by Executor
          learn_task_name = "train"
          # how long to wait for current learn task before timing out the gathering
          learn_task_timeout = 600
          # ids must map to corresponding components
          persistor_id = "persistor"
          aggregator_id = "aggregator"
          shareable_generator_id = "shareable_generator"
          # NOTE(review): requires 3 responses, but the minimal example runs
          # only 2 clients — confirm this is intended.
          min_responses_required = 3
          wait_time_after_min_resps_received = 300
        }
      }
    }
  ]
  task_data_filters = []
  task_result_filters = []
  components = [
    {
      # runs the training script as a subprocess, once per job
      id = "launcher"
      path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
      args {
        script = "python3 custom/{app_script} {app_config} "
        launch_once = true
      }
    }
    {
      # weighted in-time accumulation of client weight diffs
      id = "aggregator"
      path = "nvflare.app_common.aggregators.intime_accumulate_model_aggregator.InTimeAccumulateWeightedAggregator"
      args {
        expected_data_kind = "WEIGHT_DIFF"
      }
    }
    {
      # cell pipe carrying task data between executor and launched subprocess
      id = "pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    }
    {
      # persists the (tiny, testing-only) model to file between rounds
      id = "persistor"
      path = "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor"
      args {
        model {
          path = "models.mini_model.MiniCNNForTesting"
          args {
            in_ch = 1
            out_ch = 1
          }
        }
      }
    }
    {
      # converts between model objects and Shareables for the swarm workflow
      id = "shareable_generator"
      path = "nvflare.app_common.ccwf.comps.simple_model_shareable_generator.SimpleModelShareableGenerator"
      args {}
    }
    {
      # separate pipe dedicated to streaming metrics from the subprocess
      id = "metrics_pipe"
      path = "nvflare.fuel.utils.pipe.cell_pipe.CellPipe"
      args {
        mode = "PASSIVE"
        site_name = "{SITE_NAME}"
        token = "{JOB_ID}"
        root_url = "{ROOT_URL}"
        secure_mode = "{SECURE_MODE}"
        workspace_dir = "{WORKSPACE}"
      }
    }
    {
      # forwards metrics read from metrics_pipe as federated analytics events
      id = "metric_relay"
      path = "nvflare.app_common.widgets.metric_relay.MetricRelay"
      args {
        pipe_id = "metrics_pipe"
        event_type = "fed.analytix_log_stats"
        read_interval = 0.1
      }
    }
    {
      # tracks the best model seen so far, keyed on the "accuracy" metric
      id = "model_selector"
      path = "nvflare.app_common.widgets.intime_model_selector.IntimeModelSelector"
      args {
        key_metric = "accuracy"
      }
    }
    {
      # exports the listed components' configs to the external training script
      id = "config_preparer"
      path = "nvflare.app_common.widgets.external_configurator.ExternalConfigurator"
      args {
        component_ids = [
          "metric_relay"
        ]
      }
    }
  ]
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# NVFlare server-side job configuration for the minimal swarm-learning example.
format_version = 2
task_data_filters = []
task_result_filters = []
components = [
  {
    # write validation results to json file
    id = "json_generator"
    path = "nvflare.app_common.widgets.validation_json_generator.ValidationJsonGenerator"
    args {}
  }
]
workflows = [
  {
    # server-side controller to manage job life cycle
    id = "swarm_controller"
    path = "controller.SwarmServerController"
    args {
      # can also set aggregation clients and train clients, see class for all available args
      num_rounds = 1
    }
  }
]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .datamodule import DataModule

# Public API of this package.
__all__ = ['DataModule']
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import pytorch_lightning as pl | ||
import torch | ||
from torch.utils.data.dataloader import DataLoader | ||
import torch.multiprocessing as mp | ||
from torch.utils.data.sampler import WeightedRandomSampler, RandomSampler | ||
|
||
|
||
class DataModule(pl.LightningDataModule):
    """
    LightningDataModule for handling dataset loading and batching.

    Attributes:
        ds_train (object): Training dataset.
        ds_val (object): Validation dataset.
        ds_test (object): Test dataset.
        batch_size (int): Batch size for dataloaders.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed for reproducibility.
        pin_memory (bool): If True, pin memory for faster data transfer to GPU.
        weights (list): Weights for the weighted random sampler.
    """

    def __init__(
        self,
        ds_train: object = None,
        ds_val: object = None,
        ds_test: object = None,
        batch_size: int = 1,
        num_workers: int = mp.cpu_count(),
        seed: int = 0,
        pin_memory: bool = False,
        weights: list = None
    ):
        """
        Initializes the DataModule with datasets and parameters.

        Args:
            ds_train (object, optional): Training dataset. Defaults to None.
            ds_val (object, optional): Validation dataset. Defaults to None.
            ds_test (object, optional): Test dataset. Defaults to None.
            batch_size (int, optional): Batch size. Defaults to 1.
            num_workers (int, optional): Number of workers. Defaults to mp.cpu_count().
            seed (int, optional): Random seed. Defaults to 0.
            pin_memory (bool, optional): Pin memory. Defaults to False.
            weights (list, optional): Weights for sampling. Defaults to None.
        """
        super().__init__()
        # Record the constructor arguments for reproducibility, excluding
        # 'self' and the implicit '__class__' cell created by the
        # zero-argument super() call above.
        self.hyperparameters = {
            name: value for name, value in locals().items()
            if name not in ('self', '__class__')
        }

        self.ds_train = ds_train
        self.ds_val = ds_val
        self.ds_test = ds_test

        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed
        self.pin_memory = pin_memory
        self.weights = weights

    def _seeded_generator(self) -> torch.Generator:
        """Return a torch RNG seeded with ``self.seed`` so loaders are reproducible."""
        generator = torch.Generator()
        generator.manual_seed(self.seed)
        return generator

    def train_dataloader(self) -> DataLoader:
        """
        Returns the training dataloader.

        Uses a WeightedRandomSampler when ``self.weights`` is set, otherwise
        a plain RandomSampler without replacement; incomplete final batches
        are dropped.

        Returns:
            DataLoader: DataLoader for the training dataset.

        Raises:
            AssertionError: If the training dataset is not initialized.
        """
        if self.ds_train is None:
            raise AssertionError("A training set was not initialized.")

        generator = self._seeded_generator()
        if self.weights is not None:
            sampler = WeightedRandomSampler(self.weights, len(self.weights), generator=generator)
        else:
            sampler = RandomSampler(self.ds_train, replacement=False, generator=generator)
        return DataLoader(
            self.ds_train, batch_size=self.batch_size, num_workers=self.num_workers,
            sampler=sampler, generator=generator, drop_last=True, pin_memory=self.pin_memory
        )

    def _eval_dataloader(self, dataset, kind: str) -> DataLoader:
        """Build a deterministic (unshuffled, nothing dropped) loader for evaluation.

        Args:
            dataset: The dataset to wrap, or None if it was never provided.
            kind (str): Human-readable dataset name used in the error message.

        Raises:
            AssertionError: If ``dataset`` is None.
        """
        if dataset is None:
            raise AssertionError(f"A {kind} was not initialized.")
        return DataLoader(
            dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False,
            generator=self._seeded_generator(), drop_last=False, pin_memory=self.pin_memory
        )

    def val_dataloader(self) -> DataLoader:
        """
        Returns the validation dataloader.

        Returns:
            DataLoader: DataLoader for the validation dataset.

        Raises:
            AssertionError: If the validation dataset is not initialized.
        """
        return self._eval_dataloader(self.ds_val, "validation set")

    def test_dataloader(self) -> DataLoader:
        """
        Returns the test dataloader.

        Returns:
            DataLoader: DataLoader for the test dataset.

        Raises:
            AssertionError: If the test dataset is not initialized.
        """
        return self._eval_dataloader(self.ds_test, "test dataset")
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .mini_dataset_for_testing import MiniDatasetForTesting

# __all__ must list the *names* of public objects as strings, not the objects
# themselves; a non-string entry makes `from <pkg> import *` raise TypeError.
__all__ = ['MiniDatasetForTesting']
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import numpy as np | ||
import uuid | ||
import torch | ||
import torch.utils.data as data | ||
|
||
class MiniDatasetForTesting(data.Dataset):
    """Tiny synthetic dataset of 10 samples with alternating binary targets.

    Each sample is a dict with a random ``uid``, a deterministic 1x18x18
    half-precision image (``source``) and a label (``target``) equal to the
    sample index modulo 2.
    """

    def __init__(self):
        self.data = [
            {
                'uid': str(uuid.uuid4()),
                'source': self.dummy_image(idx),
                'target': idx % 2,
            }
            for idx in range(10)
        ]

    @staticmethod
    def dummy_image(index):
        """Return a 1x18x18 float16 tensor: zeros with a single 1 at
        ``[0, 0, index]`` for even indices, ones with a single 0 there for odd."""
        background = np.float16(0) if index % 2 == 0 else np.float16(1)
        pixels = np.full((1, 18, 18), background, dtype=np.float16)
        pixels[0, 0, index] = 1 - background
        return torch.from_numpy(pixels)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)

    def get_labels(self):
        """Return the list of targets, one per sample, in dataset order."""
        return [sample['target'] for sample in self.data]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import os | ||
from datetime import datetime | ||
|
||
|
||
def load_environment_variables():
    """Load environment variables and return them as a dictionary.

    Each entry falls back to a default when the corresponding variable is
    unset; numeric values are converted to int, and USE_ADAPTIVE_SYNC is
    treated as true only for a case-insensitive 'true'.
    """
    env = os.getenv
    config = {
        'scratch_dir': env('SCRATCH_DIR', '/scratch/'),
        'max_epochs': int(env('MAX_EPOCHS', 100)),
        'min_peers': int(env('MIN_PEERS', 2)),
        'max_peers': int(env('MAX_PEERS', 7)),
        'use_adaptive_sync': env('USE_ADAPTIVE_SYNC', 'False').lower() == 'true',
        'sync_frequency': int(env('SYNC_FREQUENCY', 1024)),
        'prediction_flag': env('PREDICT_FLAG', 'ext'),
    }
    return config
|
||
def create_run_directory(scratch_dir):
    """Return a timestamped run-directory path under *scratch_dir*.

    Ensures *scratch_dir* itself exists, creating intermediate directories as
    needed. Note that the returned run directory is NOT created here — callers
    are expected to create it when they first write to it.

    Args:
        scratch_dir (str): Base directory for run outputs.

    Returns:
        str: ``<scratch_dir>/<YYYY_MM_DD_HHMMSS>_minimal_training``.
    """
    current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
    # exist_ok=True avoids the check-then-create race of the previous
    # `if not os.path.exists(...)` pattern when several processes start at once.
    os.makedirs(scratch_dir, exist_ok=True)
    return os.path.join(scratch_dir, f"{current_time}_minimal_training")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we change the name of the script to [config_swarm_client.conf] ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'll change it.
As I copied this from a different job folder, we should also adapt it there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unable. The expected file names are hard-coded in NVFlare.