-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
00e1d40
commit dc4a17d
Showing
159 changed files
with
38,523 additions
and
0 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
{ | ||
"dataset": "fed_isic2019", | ||
"results_file": "results_isic42.csv", | ||
"strategies": { | ||
"Cyclic": { | ||
"deterministic_cycle": false, | ||
"learning_rate": 0.0031622776601683, | ||
"optimizer_class": "torch.optim.SGD" | ||
}, | ||
"FedAdagrad": { | ||
"learning_rate": 0.01, | ||
"optimizer_class": "torch.optim.SGD", | ||
"server_learning_rate": 0.0316227766016837 | ||
}, | ||
"FedAdam": { | ||
"learning_rate": 0.01, | ||
"optimizer_class": "torch.optim.SGD", | ||
"server_learning_rate": 0.0031622776601683 | ||
}, | ||
"FedAvg": { | ||
"learning_rate": 0.01, | ||
"optimizer_class": "torch.optim.SGD" | ||
}, | ||
"FedProx": { | ||
"learning_rate": 0.01, | ||
"mu": 0.001, | ||
"optimizer_class": "torch.optim.SGD" | ||
}, | ||
"FedYogi": { | ||
"learning_rate": 0.01, | ||
"optimizer_class": "torch.optim.SGD", | ||
"server_learning_rate": 0.0031622776601683 | ||
}, | ||
"Scaffold": { | ||
"learning_rate": 0.01, | ||
"optimizer_class": "torch.optim.SGD", | ||
"server_learning_rate": 1.0 | ||
}, | ||
"FedAvgFineTuning": { | ||
"learning_rate": 0.01, | ||
"optimizer_class": "torch.optim.SGD", | ||
"num_fine_tuning_steps": 3 | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import argparse | ||
|
||
from MedicalDiagnosis.utils import create_config, write_value_in_config | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--path", type=str, help="The path where the dataset is located", required=True | ||
) | ||
parser.add_argument( | ||
"--dataset-name", | ||
type=str, | ||
help="The name of the dataset you downloaded", | ||
required=True, | ||
) | ||
parser.add_argument( | ||
"--debug", | ||
action="store_true", | ||
help="whether or not to update the config fro debug mode or the real one.", | ||
) | ||
args = parser.parse_args() | ||
dict, config_file = create_config(args.path, args.debug, args.dataset_name) | ||
write_value_in_config(config_file, "dataset_path", args.path) | ||
write_value_in_config(config_file, "download_complete", True) | ||
write_value_in_config(config_file, "preprocessing_complete", True) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
from torch.nn.modules.loss import _Loss | ||
from torch.utils.data import DataLoader, Dataset | ||
from torchvision import models | ||
|
||
|
||
class FedDummyDataset(Dataset): | ||
def __init__( | ||
self, | ||
center=0, | ||
train=True, | ||
pooled=False, | ||
X_dtype=torch.float32, | ||
y_dtype=torch.float32, | ||
debug=False, | ||
): | ||
super().__init__() | ||
self.X_dtype = X_dtype | ||
self.y_dtype = y_dtype | ||
self.size = (center + 1) * 10 * 42 | ||
self.centers = center | ||
|
||
def __len__(self): | ||
return self.size | ||
|
||
def __getitem__(self, idx): | ||
return ( | ||
torch.rand(3, 224, 224).to(self.X_dtype), | ||
torch.randint(0, 2, (1,)).to(self.y_dtype), | ||
) | ||
|
||
|
||
class BaselineLoss(_Loss): | ||
def __init__(self, reduction="mean"): | ||
super(BaselineLoss, self).__init__(reduction=reduction) | ||
self.bce = torch.nn.BCEWithLogitsLoss() | ||
|
||
def forward(self, input: torch.Tensor, target: torch.Tensor): | ||
return self.bce(input, target) | ||
|
||
|
||
class Baseline(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.architecture = models.mobilenet_v2(pretrained=False) | ||
nftrs = [m for m in self.architecture.classifier.modules()][-1].in_features | ||
self.architecture.classifier = torch.nn.Linear(nftrs, 1) | ||
|
||
def forward(self, X): | ||
return self.architecture(X) | ||
|
||
|
||
if __name__ == "__main__": | ||
m = Baseline() | ||
lo = BaselineLoss() | ||
dl = DataLoader( | ||
FedDummyDataset(center=1, train=True), batch_size=32, shuffle=True, num_workers=0 | ||
) | ||
it = iter(dl) | ||
X, y = next(it) | ||
opt = torch.optim.SGD(m.parameters(), lr=1.0) | ||
y_pred = m(X) | ||
ls = lo(y_pred, y) | ||
ls.backward() | ||
opt.step() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import sys | ||
sys.path.insert(0,"/home/lixiang/FLamby-main") | ||
|
||
import torch | ||
from MedicalDiagnosis.utils import evaluate_model_on_tests | ||
|
||
print(torch.Tensor([1,2]).cuda()) | ||
# 2 lines of code to change to switch to another dataset | ||
from MedicalDiagnosis.datasets.fed_isic2019 import ( | ||
BATCH_SIZE, | ||
LR, | ||
NUM_EPOCHS_POOLED, | ||
Baseline, | ||
BaselineLoss, | ||
metric, | ||
NUM_CLIENTS, | ||
get_nb_max_rounds | ||
) | ||
from MedicalDiagnosis.datasets.fed_isic2019 import FedIsic2019 as FedDataset | ||
|
||
# 1st line of code to change to switch to another strategy | ||
from MedicalDiagnosis.strategies.fed_prox import FedProx as strat | ||
|
||
# We loop on all the clients of the distributed dataset and instantiate associated data loaders | ||
train_dataloaders = [ | ||
torch.utils.data.DataLoader( | ||
FedDataset(center = i, train = True, pooled = False), | ||
batch_size = BATCH_SIZE, | ||
shuffle = True, | ||
num_workers = 0 | ||
) | ||
for i in range(NUM_CLIENTS) | ||
] | ||
full_dataset = FedDataset(train = False, pooled = True) | ||
valid_size = int(0.25 * len(full_dataset)) | ||
test_size = len(full_dataset) -valid_size | ||
valid_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [valid_size, test_size]) | ||
print(len(valid_dataset), len(test_dataset)) | ||
test_dataloaders = [ | ||
torch.utils.data.DataLoader( | ||
test_dataset, | ||
batch_size = BATCH_SIZE, | ||
shuffle = False, | ||
num_workers = 0, | ||
) | ||
] | ||
valid_dataloaders = [ | ||
torch.utils.data.DataLoader( | ||
valid_dataset, | ||
batch_size = BATCH_SIZE, | ||
shuffle = False, | ||
num_workers = 0, | ||
) | ||
] | ||
lossfunc = BaselineLoss() | ||
m = Baseline() | ||
|
||
# Federated Learning loop | ||
# 2nd line of code to change to switch to another strategy (feed the FL strategy the right HPs) | ||
args = { | ||
"training_dataloaders": train_dataloaders, | ||
"valid_dataloaders": valid_dataloaders, | ||
"test_dataloaders": test_dataloaders, | ||
"model": m, | ||
"loss": lossfunc, | ||
"optimizer_class": torch.optim.SGD, | ||
"learning_rate": 0.01, | ||
"num_updates": 100, | ||
# This helper function returns the number of rounds necessary to perform approximately as many | ||
# epochs on each local dataset as with the pooled training | ||
"nrounds": 25, | ||
} | ||
s = strat(**args) | ||
seeds = [20,21,22,23,24] | ||
for seed in seeds: | ||
m = s.run(seed)[0] | ||
|
||
# Evaluation | ||
# We only instantiate one test set in this particular case: the pooled one | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import sys | ||
sys.path.insert(0,"/home/lixiang/FLamby-main") | ||
|
||
import torch | ||
|
||
from MedicalDiagnosis.utils import evaluate_model_on_tests | ||
print(torch.Tensor([1,2]).cuda()) | ||
# 2 lines of code to change to switch to another dataset | ||
from MedicalDiagnosis.datasets.fed_isic2019 import ( | ||
BATCH_SIZE, | ||
LR, | ||
NUM_EPOCHS_POOLED, | ||
Baseline, | ||
BaselineLoss, | ||
metric, | ||
NUM_CLIENTS, | ||
get_nb_max_rounds | ||
) | ||
from MedicalDiagnosis.datasets.fed_isic2019 import FedIsic2019 as FedDataset | ||
|
||
# 1st line of code to change to switch to another strategy | ||
from MedicalDiagnosis.strategies.fed_lsv import FedLSV as strat | ||
|
||
# We loop on all the clients of the distributed dataset and instantiate associated data loaders | ||
train_dataloaders = [ | ||
torch.utils.data.DataLoader( | ||
FedDataset(center = i, train = True, pooled = False), | ||
batch_size = BATCH_SIZE, | ||
shuffle = True, | ||
num_workers = 0 | ||
) | ||
for i in range(NUM_CLIENTS) | ||
] | ||
full_dataset = FedDataset(train = False, pooled = True) | ||
valid_size = int(0.25 * len(full_dataset)) | ||
test_size = len(full_dataset) - valid_size | ||
print(valid_size, test_size, len(full_dataset)) | ||
print(sum([9930,3163,2691,1807,655,351])) | ||
valid_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [valid_size, test_size]) | ||
print(len(valid_dataset), len(test_dataset)) | ||
test_dataloaders = [ | ||
torch.utils.data.DataLoader( | ||
test_dataset, | ||
batch_size = BATCH_SIZE, | ||
shuffle = False, | ||
num_workers = 0, | ||
) | ||
] | ||
valid_dataloaders = [ | ||
torch.utils.data.DataLoader( | ||
valid_dataset, | ||
batch_size = BATCH_SIZE, | ||
shuffle = False, | ||
num_workers = 0, | ||
) | ||
] | ||
|
||
lossfunc = BaselineLoss() | ||
m = Baseline() | ||
|
||
# Federated Learning loop | ||
# 2nd line of code to change to switch to another strategy (feed the FL strategy the right HPs) | ||
args = { | ||
"training_dataloaders": train_dataloaders, | ||
"valid_dataloaders": valid_dataloaders, | ||
"test_dataloaders": test_dataloaders, | ||
"model": m, | ||
"loss": lossfunc, | ||
"optimizer_class": torch.optim.SGD, | ||
"learning_rate": 0.01, | ||
"num_updates": 100, | ||
# This helper function returns the number of rounds necessary to perform approximately as many | ||
# epochs on each local dataset as with the pooled training | ||
"nrounds": 40, | ||
} | ||
s = strat(**args) | ||
seeds = [20,21,22,23,24] | ||
for seed in seeds: | ||
m = s.run(seed)[0] | ||
|
||
# Evaluation | ||
# We only instantiate one test set in this particular case: the pooled one | ||
|
Oops, something went wrong.