Skip to content

Commit

Permalink
address comments; enable selection of evaluation client
Browse files Browse the repository at this point in the history
  • Loading branch information
holgerroth committed Apr 10, 2024
1 parent 06f2f6f commit b9e1f13
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 38 deletions.
13 changes: 5 additions & 8 deletions research/fed-bpt/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,7 @@ with the addition of installing NVFlare for running federated learning and some
```commandline
conda create --name fedbpt python=3.8
conda activate fedbpt
pip install transformers==4.38.2
pip install fastNLP==0.6.0
pip install datasets
pip install cma
pip install scikit-learn
pip install tensorboard
pip install cvxopt
pip install nvflare==2.4.1rc
pip install -r requirements.txt
```

## 2. Run a federated learning experiment
Expand Down Expand Up @@ -51,16 +44,20 @@ nvflare job create -force -j "./jobs/fedbpt" -w "fedbpt" -sd "./src" \
--local_popsize 5 \
--perturb 1 \
--model_name roberta-large \
--eval_clients site-1 \
--llama_causal 1" \
-f app/config/config_fed_server.conf min_clients=${N_CLIENTS} num_rounds=200 seed=${SEED}
```
By default, we only evaluate the global model on client `site-1` because, in our setting, the global test set is shared by all clients.

Start the FL simulator with `N_CLIENTS` clients in parallel.
The following setting requires a GPU with at least 24 GB memory and enough system memory to run the clients in parallel (we recommend at least 40 GB).
For a system with fewer resources, you can set `-t` to a lower number to simulate the clients running sequentially.
```commandline
OUT_DIR="/tmp/nvflare/fedbpt"
nvflare simulator ./jobs/fedbpt -n ${N_CLIENTS} -t ${N_CLIENTS} -w ${OUT_DIR}
```
If you have more GPUs available on your system, you can use the `--gpu` argument of the simulator to run clients on different GPUs in parallel.

## 3. Example results
The training results, showing the global testing accuracy over 200 rounds, are shown below.
Expand Down
10 changes: 4 additions & 6 deletions research/fed-bpt/job_templates/fedbpt/config_fed_client.conf
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,20 @@
# This executor needs Pipe component
pipe_id = "pipe"

# Timeout in seconds for waiting for a heartbeat from the training script. Defaults to 30 seconds.
# Timeout in seconds for waiting for a heartbeat from the training script.
# Please refer to the class docstring for all available arguments
heartbeat_timeout = 120

# format of the exchange parameters
params_exchange_format = "numpy"

# if the transfer_type is FULL, then it will be sent directly
# if the transfer_type is DIFF, then we will calculate the
# difference VS received parameters and send the difference
# For GlobalES, the transfer_type is FULL.
params_transfer_type = "FULL"

# if train_with_evaluation is true, the executor expects
# the custom code to send back both the trained parameters and the evaluation metric;
# otherwise, only trained parameters are expected
train_with_evaluation = true
train_with_evaluation = false

}
}
Expand Down Expand Up @@ -117,7 +115,7 @@
}
},
{
# we use this component so the client api `flare.init()` can get required information
# We serialize CMAEvolutionStrategy object directly. This requires registering custom decomposers.
id = "register_decomposer"
path = "decomposer_widget.RegisterDecomposer"
args {}
Expand Down
7 changes: 2 additions & 5 deletions research/fed-bpt/job_templates/fedbpt/config_fed_server.conf
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@
# task result filter: if filters are provided, the filter will filter the result flow out of client to server.
task_result_filters = []

# This assumes that there will be a "net.py" file with class name "Net".
# If your model code is not in "net.py" and class name is not "Net", please modify here
# model_class_path = "net.Net"

# workflows: Array of workflows the control the Federated Learning workflow lifecycle.
# One can specify multiple workflows. The NVFLARE will run them in the order specified.
workflows = [
Expand Down Expand Up @@ -40,12 +36,13 @@
# List of components used in the server side workflow.
components = [
{
# Receive streamed tensorboard metrics
id = "receiver"
path = "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver"
args.events = ["fed.analytix_log_stats"]
},
{
# we use this component so the client api `flare.init()` can get required information
# We serialize CMAEvolutionStrategy object directly. This requires registering custom decomposers.
id = "register_decomposer"
path = "decomposer_widget.RegisterDecomposer"
args {}
Expand Down
8 changes: 8 additions & 0 deletions research/fed-bpt/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
transformers==4.38.2
fastNLP==0.6.0
datasets
cma
scikit-learn
tensorboard
cvxopt
nvflare~=2.4.1rc
48 changes: 29 additions & 19 deletions research/fed-bpt/src/fedbpt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@
parser.add_argument("--cat_or_add", default="add", type=str)
parser.add_argument("--parallel", action="store_true", help="Whether to allow parallel evaluation")
# fl args
parser.add_argument(
"--eval_clients",
default="site-1",
type=str,
help="Provide the name of client that should evaluate the global model. Can be comma-separated list, e.g., `site-1,site-2`",
)
parser.add_argument("--num_users", default=10, type=int)
parser.add_argument("--iid", default=1, type=int)
parser.add_argument("--local_popsize", default=20, type=int)
Expand Down Expand Up @@ -120,6 +126,7 @@
bound = args.bound
sigma = args.sigma
alpha = args.alpha
eval_clients = args.eval_clients.split(",")

if args.local_popsize > 0:
args.local_popsize = args.local_popsize
Expand Down Expand Up @@ -227,25 +234,28 @@
)
local_sigma_current = global_es.sigma

print("Global es evaluate on test data...")
global_api_setting["best_prompt"] = local_es.mean
model_forward_api.load_client_record(global_api_setting)
global_test_acc = model_forward_api.eval(prompt_embedding=local_es.mean, test_data=test_data)
print("Global test acc: {}".format(round(global_test_acc, 4)))
print("Global prompt norm: {}".format(np.linalg.norm(local_es.mean)))
writer.add_scalar("global_test_acc", global_test_acc, current_round)

if args.norm_prompt and np.linalg.norm(local_es.mean) < args.prompt_norm_threshold_upper:
args.prompt_norm_threshold += 1
model_forward_api.args = args
print("Set prompt_norm_threshold as {}".format(args.prompt_norm_threshold))
if args.save_prompt:
if global_test_acc > best_test_acc:
best_test_acc = global_test_acc
torch.save(
model_forward_api.model.prompt_embedding.cpu().detach(),
"results/llama/sst2/larger_global_pop_new_sigma_pert/fl_prompt.pt",
)
if flare.get_site_name() in eval_clients:
print("Global es evaluate on test data...")
global_api_setting["best_prompt"] = local_es.mean
model_forward_api.load_client_record(global_api_setting)
global_test_acc = model_forward_api.eval(prompt_embedding=local_es.mean, test_data=test_data)
print("Global test acc: {}".format(round(global_test_acc, 4)))
print("Global prompt norm: {}".format(np.linalg.norm(local_es.mean)))
writer.add_scalar("global_test_acc", global_test_acc, current_round)

if args.norm_prompt and np.linalg.norm(local_es.mean) < args.prompt_norm_threshold_upper:
args.prompt_norm_threshold += 1
model_forward_api.args = args
print("Set prompt_norm_threshold as {}".format(args.prompt_norm_threshold))
if args.save_prompt:
if global_test_acc > best_test_acc:
best_test_acc = global_test_acc
torch.save(
model_forward_api.model.prompt_embedding.cpu().detach(),
"results/llama/sst2/larger_global_pop_new_sigma_pert/fl_prompt.pt",
)
else:
global_test_acc = None

client_sigmas = {}

Expand Down

0 comments on commit b9e1f13

Please sign in to comment.