Commit 5010db0: add three-site results

ZiyueXu77 committed Nov 4, 2024
1 parent a5fbd07 commit 5010db0
Showing 7 changed files with 99 additions and 39 deletions.
42 changes: 27 additions & 15 deletions examples/advanced/llm_hf/README.md
@@ -45,17 +45,6 @@ Centralized trainings, as the baseline for comparison with other results, are do
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft --mode 0
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft --mode 1
```
### Pre: Launch Modes
Before adapting the local training script to a federated application, we first need to understand the launch modes of the NVFlare Client API.
In our [client settings](../../../job_templates/sag_pt/config_fed_client.conf), we have two launch modes, switched by the `launch_once` flag:
* If `launch_once` is true, the SubprocessLauncher launches an external process once for the whole job and reuses the same process for all rounds.
* If `launch_once` is false, the SubprocessLauncher launches a new external process every time it receives a task from the server, i.e., one process per round.

Setting `launch_once` to `false` can be useful in scenarios like quick prototyping, but for LLM applications, where the setup stage can consume significant resources, we want to set up only once. Hence, the steps below assume the `launch_once = true` scenario.

See [Client API](../../hello-world/ml-to-fl/pt/README.md) for more details.
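For reference, a minimal sketch of the launcher component inside `config_fed_client.conf` (an illustrative excerpt only; the surrounding fields and the script path are simplified, assuming the standard `SubprocessLauncher`):

```
{
  id = "launcher"
  path = "nvflare.app_common.launchers.subprocess_launcher.SubprocessLauncher"
  args {
    # script path is illustrative
    script = "python3 custom/hf_sft_peft_fl.py"
    # true: launch the external process once for the whole job
    # false: launch a new process for every task received from the server
    launch_once = true
  }
}
```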

### Adaptation Step 1: iterative training
To adapt the centralized training script to a federated application under the `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training, as sketched below.
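Conceptually, the adaptation follows this pattern (a simplified illustration of `src/hf_sft_peft_fl.py`; model weight loading/returning and checkpoint bookkeeping are elided):

```
import nvflare.client as flare

flare.init()
while flare.is_running():
    input_model = flare.receive()          # global weights from the server
    curr_round = input_model.current_round
    # (load input_model.params into the trainer's model here)
    if curr_round == 0:
        trainer.train()                    # first round: train from scratch
    else:
        # later rounds: extend the epoch budget, then resume from checkpoint
        trainer.args.num_train_epochs += 1
        trainer.train(resume_from_checkpoint=True)
    # (send the updated local weights back to the server here)
```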
@@ -98,8 +87,8 @@ The major code modifications are for receiving and returning the global model (r
### Federated Training Results
We run the federated training on a single client using NVFlare Simulator via [JobAPI](../job_api/README.md).
```
python3 sft_job.py --data_path ${PWD}/dataset/dolly --workspace_dir ${PWD}/workspace/hf_sft --job_dir ${PWD}/workspace/jobs/hf_sft --train_mode 0
python3 sft_job.py --data_path ${PWD}/dataset/dolly --workspace_dir ${PWD}/workspace/hf_peft --job_dir ${PWD}/workspace/jobs/hf_peft --train_mode 1
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft --job_dir ${PWD}/workspace/jobs/hf_sft --train_mode 0
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_peft --job_dir ${PWD}/workspace/jobs/hf_peft --train_mode 1
```
The SFT curves are shown below, black for centralized results, magenta for FL training. Allowing for some training randomness, the two SFT training loss curves align with each other.
![sft](./figs/fl_sft.png)
@@ -110,8 +99,8 @@ Similar patterns can be observed from the PEFT curves, purple for centralized re
## Model Precision Conversion for Communication
In the above example, we used float32 for communication. To reduce the message size, we can convert the model to a lower precision for communication. The conversion is enabled by NVFlare's [filter mechanism](https://nvflare.readthedocs.io/en/main/programming_guide/filters.html). We can use the following commands to run the federated training with model precision conversion:
```
python3 sft_job_compress.py --data_path ${PWD}/dataset/dolly --workspace_dir ${PWD}/workspace/hf_sft_compress --job_dir ${PWD}/workspace/jobs/hf_sft_compress --train_mode 0
python3 sft_job_compress.py --data_path ${PWD}/dataset/dolly --workspace_dir ${PWD}/workspace/hf_peft_compress --job_dir ${PWD}/workspace/jobs/hf_peft_compress --train_mode 1
python3 sft_job_compress.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_compress --job_dir ${PWD}/workspace/jobs/hf_sft_compress --train_mode 0
python3 sft_job_compress.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_peft_compress --job_dir ${PWD}/workspace/jobs/hf_peft_compress --train_mode 1
```
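Internally, `sft_job_compress.py` wires a precision-conversion filter into the job so that task results are converted before leaving each site. A minimal sketch of the wiring (the converter class name below is a placeholder; see the script for the actual component):

```
from nvflare.job_config.defs import FilterType

# placeholder name for the script's precision-conversion filter
compressor = ModelPrecisionConverter()  # e.g., float32 -> float16

# apply the filter to each client's "train" task results before transmission
job.to(compressor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
```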
The SFT curves are shown below, black for centralized results, yellow for FL training with compression. We can see that it achieves similar alignment with the centralized result.
![sft](./figs/fl_sft_comp.png)
@@ -126,3 +115,26 @@ For message size reduction, since we convert float32 to float16, the message size is red
```shell
Compressed all 147 params Before compression: 5993930752 bytes After compression: 2996965376 bytes
```
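This matches expectation: float16 stores each parameter element in 2 bytes instead of float32's 4 bytes, so the payload is exactly halved (5993930752 / 2 = 2996965376 bytes).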

## Federated Training with Multiple Clients
Building on the above example, we can easily extend the federated training to multiple clients with the following commands:
```
python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_multi --job_dir ${PWD}/workspace/jobs/hf_sft_multi --train_mode 0 --threads 1
python3 sft_job_compress.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_multi_compress --job_dir ${PWD}/workspace/jobs/hf_sft_multi_compress --train_mode 0 --threads 1
```
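Each client ID corresponds to its own subfolder under `--data_path`; the job script joins `data_path/<client_id>/training.jsonl` and `validation.jsonl`, so the expected layout is:

```
dataset/
├── dolly/
│   ├── training.jsonl
│   └── validation.jsonl
├── alpaca/
│   ├── training.jsonl
│   └── validation.jsonl
└── oasst1/
    ├── training.jsonl
    └── validation.jsonl
```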

For comparison, we also run centralized training on each of the other two sites' datasets:
```
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_sft --mode 0
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_sft --mode 0
```

The training loss curves are shown below:
Dolly:
![sft](./figs/fl_sft_dolly.png)
Alpaca:
![sft](./figs/fl_sft_alpaca.png)
Oasst1:
![sft](./figs/fl_sft_oasst1.png)

As shown, federated training with multiple clients (the curves with three segments) can achieve training loss comparable to or better than each site's individual centralized training, demonstrating the effectiveness of NVFlare for LLM tuning tasks.
Binary file added examples/advanced/llm_hf/figs/fl_sft_alpaca.png
Binary file added examples/advanced/llm_hf/figs/fl_sft_dolly.png
Binary file added examples/advanced/llm_hf/figs/fl_sft_oasst1.png
42 changes: 31 additions & 11 deletions examples/advanced/llm_hf/sft_job.py
@@ -25,13 +25,24 @@
def main():
args = define_parser()
train_script = "src/hf_sft_peft_fl.py"
num_clients = args.num_clients
client_ids = args.client_ids
num_clients = len(client_ids)

if args.threads:
num_threads = args.threads
else:
num_threads = num_clients

if num_threads < num_clients:
print("The number of threads smaller than the number of clients, runner clean-up will be performed.")
clean_up = 1
else:
clean_up = 0
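# Note: with fewer threads than clients, the simulator swaps runners between
# clients, so a client's in-memory trainer state can be cleaned up mid-job;
# the flag is forwarded to the training script (--clean_up) so it can recompute
# its epoch counter from the current round instead of incrementing in memory.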

num_rounds = args.num_rounds
workspace_dir = args.workspace_dir
job_dir = args.job_dir
model_name_or_path = args.model_name_or_path
data_path_train = os.path.join(args.data_path, "training.jsonl")
data_path_valid = os.path.join(args.data_path, "validation.jsonl")
train_mode = args.train_mode

# Create the FedJob
@@ -61,10 +72,13 @@ def main():

# Send ScriptRunner to all clients
for i in range(num_clients):
site_name = f"site-{i}"
client_id = client_ids[i]
site_name = f"site-{client_id}"
data_path_train = os.path.join(args.data_path, client_id, "training.jsonl")
data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
runner = ScriptRunner(
script=train_script,
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --mode {train_mode}",
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --mode {train_mode} --clean_up {clean_up}",
)
job.to(runner, site_name, tasks=["train"])

Expand All @@ -74,16 +88,18 @@ def main():

# Run the job
print("workspace_dir=", workspace_dir)
job.simulator_run(workspace_dir)
print("num_threads=", num_threads)
job.simulator_run(workspace_dir, threads=num_threads)


def define_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num_clients",
type=int,
default=1,
help="Number of clients, default to 1",
"--client_ids",
nargs="+",
type=str,
default="",
help="Clinet IDs, used to get the data path for each client",
)
parser.add_argument(
"--num_rounds",
@@ -121,7 +137,11 @@ def define_parser():
default=0,
help="training mode, 0: SFT, 1: PEFT, default to 0",
)

parser.add_argument(
"--threads",
type=int,
help="number of threads to use for FL simulation, default to the number of clients",
)
return parser.parse_args()


42 changes: 31 additions & 11 deletions examples/advanced/llm_hf/sft_job_compress.py
@@ -27,13 +27,24 @@
def main():
args = define_parser()
train_script = "src/hf_sft_peft_fl.py"
num_clients = args.num_clients
client_ids = args.client_ids
num_clients = len(client_ids)

if args.threads:
num_threads = args.threads
else:
num_threads = num_clients

if num_threads < num_clients:
print("Number of threads is smaller than the number of clients; runner clean-up will be performed.")
clean_up = 1
else:
clean_up = 0

num_rounds = args.num_rounds
workspace_dir = args.workspace_dir
job_dir = args.job_dir
model_name_or_path = args.model_name_or_path
data_path_train = os.path.join(args.data_path, "training.jsonl")
data_path_valid = os.path.join(args.data_path, "validation.jsonl")
train_mode = args.train_mode

# Create the FedJob
@@ -68,10 +79,13 @@ def main():

# Send ScriptRunner to all clients
for i in range(num_clients):
site_name = f"site-{i}"
client_id = client_ids[i]
site_name = f"site-{client_id}"
data_path_train = os.path.join(args.data_path, client_id, "training.jsonl")
data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
runner = ScriptRunner(
script=train_script,
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --mode {train_mode}",
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --mode {train_mode} --clean_up {clean_up}",
)
job.to(runner, site_name, tasks=["train"])
job.to(compressor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
Expand All @@ -83,16 +97,18 @@ def main():

# Run the job
print("workspace_dir=", workspace_dir)
job.simulator_run(workspace_dir)
print("num_threads=", num_threads)
job.simulator_run(workspace_dir, threads=num_threads)


def define_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--num_clients",
type=int,
default=1,
help="Number of clients, default to 1",
"--client_ids",
nargs="+",
type=str,
default="",
help="Clinet IDs, used to get the data path for each client",
)
parser.add_argument(
"--num_rounds",
@@ -130,7 +146,11 @@ def define_parser():
default=0,
help="training mode, 0: SFT, 1: PEFT, default to 0",
)

parser.add_argument(
"--threads",
type=int,
help="number of threads to use for FL simulation, default to the number of clients",
)
return parser.parse_args()


12 changes: 10 additions & 2 deletions examples/advanced/llm_hf/src/hf_sft_peft_fl.py
@@ -64,6 +64,8 @@ def main():
default="./workspace_federated/llama-3.2-1b-dolly-sft",
)
parser.add_argument("--mode", type=int, default=0)
parser.add_argument("--local_epoch", type=int, default=1)
parser.add_argument("--clean_up", type=int, default=0)
args = parser.parse_args()

# Dataset
@@ -114,7 +116,7 @@
# Training arguments
train_args = SFTConfig(
output_dir=args.output_path,
num_train_epochs=1,
num_train_epochs=args.local_epoch,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gra_accu_steps,
gradient_checkpointing=False,
@@ -192,7 +194,13 @@ def evaluate(input_weights, mode):
# Disable safetensor for now
trainer.model.save_pretrained(resume_from_checkpoint_folder, safe_serialization=False)
# increment num_train_epochs so that the trainer will continue training
trainer.args.num_train_epochs += 1
if args.clean_up:
# the runner was cleaned up, so recompute num_train_epochs from curr_round
trainer.args.num_train_epochs = (curr_round + 1) * args.local_epoch
else:
# the runner is still alive, so increment num_train_epochs by local_epoch
trainer.args.num_train_epochs += args.local_epoch
print(f"Set num_train_epochs to {trainer.args.num_train_epochs}")
# continue training
trainer.train(resume_from_checkpoint=True)

Expand Down
