-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient.py
90 lines (80 loc) · 3.86 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import flwr as fl
import argparse
from flower_helpers import (train, test, create_model,
get_weights, load_stored_tff, load_data)
from config import (MODEL_NAME, NUM_CLASSES, PRE_TRAINED,
SERVER_ADDRESS, DOUBLE_TRAIN, NUM_CLIENTS,
NUM_ROUNDS, LEARNING_RATE, BATCH_SIZE,
EPOCHS, TFF_DATA_DIR, NON_IID, TEST_SIZE,
TRAIN_SIZE,VAL_PORTION)
class FlowerClient(fl.client.NumPyClient):
    """Flower NumPy client that trains/validates on one local data split.

    The client holds no model object of its own. It stores the latest
    parameter arrays it was given (or produced) and passes them, together
    with ``model_config``, to the ``train``/``test`` helpers.
    """

    def __init__(self, model_config, trainloader, valloader):
        """Create a client.

        Args:
            model_config: Model configuration forwarded to ``train``/``test``.
            trainloader: Loader for this client's training split.
            valloader: Loader for this client's validation split.
        """
        self.model_config = model_config
        self.parameters = None  # filled by set_parameters / fit
        self.trainloader = trainloader
        self.valloader = valloader

    def get_parameters(self, config=None):
        """Return the parameters currently held by this client.

        ``config`` is accepted and ignored so the method works with both the
        older ``get_parameters(self)`` and the newer
        ``get_parameters(self, config)`` NumPyClient signatures.

        Returns ``None`` until ``set_parameters`` or ``fit`` has run.
        NOTE(review): if the server asks a client for initial parameters,
        ``None`` is not a valid answer. Confirm the strategy supplies its own
        initial parameters.
        """
        return self.parameters

    def set_parameters(self, parameters):
        """Store ``parameters`` as this client's current parameters."""
        self.parameters = parameters

    def fit(self, parameters, config):
        """Train locally, starting from the server-sent ``parameters``.

        Reads ``config['local_epochs']`` and ``config['learning_rate']`` from
        the server-supplied fit config.

        Returns:
            ``(new_parameters, data_size, metrics)`` exactly as produced by
            ``train``.
        """
        self.set_parameters(parameters)
        new_parameters, data_size, metrics = train(
            self.model_config,
            config['local_epochs'],
            config['learning_rate'],
            parameters,
            self.trainloader,
        )
        # Keep the locally trained weights so get_parameters reflects them.
        self.set_parameters(new_parameters)
        return new_parameters, data_size, metrics

    def evaluate(self, parameters, config):
        """Evaluate the server-sent ``parameters`` on the local validation split.

        ``config`` is unused. The ``accuracy``/``loss`` keys returned by
        ``test`` are renamed to ``val_accuracy``/``val_loss``. This avoids
        confusing them with test-set metrics.

        Returns:
            ``(val_loss, data_size, metrics)``.
        """
        self.set_parameters(parameters)
        data_size, metrics = test(self.model_config,
                                  parameters,
                                  self.valloader)
        # changing the name of the metric to avoid confusion
        metrics['val_accuracy'] = metrics.pop('accuracy')
        metrics['val_loss'] = metrics.pop('loss')
        return metrics['val_loss'], data_size, metrics
def prompt_client_no(num_clients, read=input):
    """Ask repeatedly until a valid client index is entered.

    Args:
        num_clients: Number of available client partitions; valid indices
            are ``0`` .. ``num_clients - 1``.
        read: Callable that takes a prompt string and returns the answer.
            Defaults to ``input``; it is injectable for non-interactive use.

    Returns:
        The chosen client index as an ``int``.

    Raises:
        ValueError: if ``num_clients`` is less than 1 (no valid choice
            exists, so the loop could never end).
    """
    if num_clients < 1:
        raise ValueError('no client partitions are available')
    prompt = f'Enter client number 0-{num_clients - 1}: '
    while True:
        try:
            client_no = int(read(prompt))
        except ValueError:
            client_no = -1  # non-numeric input counts as out of range
        if 0 <= client_no < num_clients:
            return client_no
        print('Invalid client number!')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Start a Flower client.')
    parser.add_argument(
        '--client-no', type=int, default=None,
        help='index of the local data partition to use '
             '(prompted interactively when omitted)')
    args = parser.parse_args()

    print('loading data...')
    if NON_IID:
        # non-iid dataset from tff (train and test are already split)
        trainloaders, valloaders, testloader = load_stored_tff(
            TFF_DATA_DIR, BATCH_SIZE, DOUBLE_TRAIN)
    else:
        # iid dataset from huggingface
        trainloaders, valloaders, testloader = load_data(
            MODEL_NAME, TEST_SIZE, TRAIN_SIZE, VAL_PORTION,
            BATCH_SIZE, NUM_CLIENTS, NUM_CLASSES)

    # The model is built only to capture its config; FlowerClient passes
    # parameter arrays plus this config to train/test.
    net = create_model(MODEL_NAME, NUM_CLASSES, PRE_TRAINED)
    MODEL_CONFIG = net.config
    del net  # only the config is needed from here on

    print('--'*20)
    print('num clients:', NUM_CLIENTS)
    print('num rounds:', NUM_ROUNDS)
    print('--'*20)
    print('client training set size:', [len(t.dataset) for t in trainloaders])
    print('client validation set size:', [len(v.dataset) for v in valloaders])
    print('test set size:', len(testloader.dataset))
    print('--'*20)
    print('model name:', MODEL_NAME)
    print('num classes:', NUM_CLASSES)
    print('pre-trained:', PRE_TRAINED)
    print('learning rate:', LEARNING_RATE)
    print('batch size:', BATCH_SIZE)
    print('epochs:', EPOCHS)
    print('--'*20)

    # Bound the choice by the loaders that were actually produced. Hard-coded
    # limits could index past the end of trainloaders/valloaders.
    num_partitions = min(len(trainloaders), len(valloaders))
    if num_partitions < 1:
        raise SystemExit('No client data partitions were loaded.')
    if args.client_no is None:
        client_no = prompt_client_no(num_partitions)
    elif 0 <= args.client_no < num_partitions:
        client_no = args.client_no
    else:
        parser.error(f'--client-no must be in 0-{num_partitions - 1}')

    # Start Flower client
    print(f'Starting Flower client#{client_no}...')
    fl.client.start_numpy_client(server_address=SERVER_ADDRESS,
                                 client=FlowerClient(
                                     MODEL_CONFIG,
                                     trainloaders[client_no],
                                     valloaders[client_no]))