update readme and scripts

ZiyueXu77 committed Nov 6, 2024
1 parent 5ef5e90 commit 61ad146
Showing 14 changed files with 90 additions and 32 deletions.
43 changes: 33 additions & 10 deletions examples/advanced/llm_hf/README.md
@@ -1,4 +1,4 @@
# LLM Tuning via HuggingFace SFT Trainer
# LLM Tuning via HuggingFace SFT/PEFT APIs
This example shows how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for Large Language Model (LLM) tuning tasks. It illustrates how to adapt a local training script that uses the [HuggingFace](https://huggingface.co/) trainer to NVFlare.

## Introduction
@@ -14,12 +14,20 @@ We would like to showcase two key points in this example:

We conducted these experiments on a single 48GB RTX 6000 Ada GPU.

To use the Llama-3.2-1B model, please request access at https://huggingface.co/meta-llama/Llama-3.2-1B and log in with an access token using `huggingface-cli`.
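If you have not authenticated before, you can log in either with `huggingface-cli login` or from Python. Below is a minimal sketch using the `huggingface_hub` package (pulled in as a dependency of `transformers`); the token string is a placeholder for your own access token:
```
# Authenticate to the Hugging Face Hub so that gated models such as
# meta-llama/Llama-3.2-1B can be downloaded. The token below is a placeholder.
from huggingface_hub import login

login(token="hf_xxx")  # replace with your personal access token
```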


## Setup
Please make sure you set up virtual environment following [example root readme](../../README.md).
Install additional requirements (if you already have a specific version of nvflare installed in your environment, you may want to remove nvflare from the requirements to avoid reinstalling it):
```
python3 -m pip install -r requirements.txt
```
flash-attn cannot be installed together with the other packages and needs to be installed separately after the step above:
```
python3 -m pip install flash-attn --no-build-isolation
```
Git LFS is also required for the downloads; please follow the steps in this [link](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md).

## Data Preparation
We download and preprocess the data, following the same preprocessing steps as our [NeMo example](../../../integration/nemo/examples/supervised_fine_tuning/README.md).
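After preprocessing, the commands in the following sections expect one folder per dataset under `./dataset`, each containing a `training.jsonl` and a `validation.jsonl` split. A purely illustrative sanity check of this layout (not part of the example scripts):
```
# Confirm that each dataset folder contains the training/validation splits
# that the commands in the rest of this README expect.
import os

for name in ["dolly", "alpaca", "oasst1"]:
    for split in ["training.jsonl", "validation.jsonl"]:
        path = os.path.join("dataset", name, split)
        print(path, "OK" if os.path.isfile(path) else "MISSING")
```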
@@ -41,17 +49,16 @@ To illustrate the adaptation process, we use a single dataset [databricks-dolly-
### One-call training
Centralized training runs, serving as the baseline for comparison with other results, are done with the following commands:
```
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft --mode 0
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft --mode 1
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft --train_mode SFT
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft --train_mode PEFT
```

### Adaptation Step 1: iterative training
To adapt the centralized training script to federated application, under `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
To adapt the centralized training script to a federated application, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
For this purpose, we provide `utils/hf_sft_peft_iter.py` as an example, which is a modified version of `utils/hf_sft_peft.py`.
Their differences are highlighted below:

![diff](./figs/diff_1.png)
![diff](./figs/diff_2.png)
![diff](./figs/diff.png)

Note that the `trainer.train()` call is replaced by a `for` loop, and the three training epochs become three rounds, one epoch per round.
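The sketch below illustrates this restructuring under simplified assumptions (a `trainer` that is already configured, checkpoints saved every round, and hypothetical `rounds`/`local_epochs` values); see `utils/hf_sft_peft_iter.py` for the actual script:
```
# Instead of one trainer.train() call covering three epochs,
# train one epoch per round and resume from the previous round's checkpoint.
rounds = 3
local_epochs = 1

for curr_round in range(rounds):
    # grow the epoch target so the resumed trainer runs exactly one more local epoch
    trainer.args.num_train_epochs = (curr_round + 1) * local_epochs
    if curr_round == 0:
        # first round: fresh start
        trainer.train()
    else:
        # later rounds: restore model/optimizer/scheduler state from the last checkpoint
        trainer.train(resume_from_checkpoint=True)
```
The key point is that each round performs one epoch of training and resumes the trainer state from the previous round's checkpoint, which is what allows the federated version to swap in global weights between rounds.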

@@ -65,8 +72,8 @@ If the intended model weights (serving as the starting point for each round, the

To run iterative training, we use the following commands:
```
python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft-iter --mode 0
python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft-iter --mode 1
python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft-iter --train_mode SFT
python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft-iter --train_mode PEFT
```

The SFT curves are shown below: black for the single-call training, blue for the iterative training. We can see a "zig-zag" pattern in the iterative training loss curve, which is expected because the iterative script reloads the base model weights at the start of each round to verify that the weights are loaded correctly.
@@ -103,8 +110,8 @@ python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset -

For comparison, we run the other two sites in centralized training mode:
```
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_sft --mode 0
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_sft --mode 0
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_sft --train_mode SFT
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_sft --train_mode SFT
```

The training loss curves are shown below:
@@ -117,3 +124,19 @@ Oasst1:
![sft](./figs/fl_sft_oasst1.png)

As shown, federated training with multiple clients (curves with three segments) can achieve training loss comparable to or better than each individual site's centralized training (continuous curves), demonstrating the effectiveness of federated learning.

Similarly for PEFT, we can run the following commands:
```
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_peft --train_mode PEFT
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_peft --train_mode PEFT
python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_peft_multi --job_dir ${PWD}/workspace/jobs/hf_peft_multi --train_mode PEFT --threads 1
```

The training loss curves are shown below:

Dolly:
![peft](./figs/peft_dolly.png)
Alpaca:
![peft](./figs/peft_alpaca.png)
Oasst1:
![peft](./figs/peft_oasst1.png)
Binary file added examples/advanced/llm_hf/figs/diff.png
Binary file removed examples/advanced/llm_hf/figs/diff_1.png
Binary file removed examples/advanced/llm_hf/figs/diff_2.png
Binary file modified examples/advanced/llm_hf/figs/diff_fl_1.png
Binary file modified examples/advanced/llm_hf/figs/diff_fl_2.png
Binary file added examples/advanced/llm_hf/figs/peft_alpaca.png
Binary file added examples/advanced/llm_hf/figs/peft_dolly.png
Binary file added examples/advanced/llm_hf/figs/peft_oasst1.png
3 changes: 1 addition & 2 deletions examples/advanced/llm_hf/requirements.txt
@@ -5,5 +5,4 @@ tensorboard
transformers
peft
trl
flash-attn
bitsandbytes
bitsandbytes
4 changes: 1 addition & 3 deletions examples/advanced/llm_hf/sft_job.py
@@ -49,11 +49,9 @@ def main():
if train_mode.lower() == "sft":
job = FedJob(name="llm_hf_sft", min_clients=num_clients)
output_path = "sft"
mode = 0
elif train_mode.lower() == "peft":
job = FedJob(name="llm_hf_peft", min_clients=num_clients)
output_path = "peft"
mode = 1
else:
raise ValueError(f"Invalid train_mode: {train_mode}, only SFT and PEFT are supported.")

@@ -82,7 +80,7 @@ def main():
data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
runner = ScriptRunner(
script=train_script,
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --mode {mode} --clean_up {clean_up}",
script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}",
)
job.to(runner, site_name, tasks=["train"])

30 changes: 22 additions & 8 deletions examples/advanced/llm_hf/src/hf_sft_peft_fl.py
@@ -63,7 +63,12 @@ def main():
type=str,
default="./workspace_federated/llama-3.2-1b-dolly-sft",
)
parser.add_argument("--mode", type=int, default=0)
parser.add_argument(
"--train_mode",
type=str,
default="SFT",
help="training mode, SFT or PEFT, default to SFT",
)
parser.add_argument("--local_epoch", type=int, default=1)
parser.add_argument("--clean_up", type=int, default=0)
args = parser.parse_args()
@@ -88,15 +93,22 @@
torch.set_default_dtype(torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
attn_implementation="flash_attention_2",
device_map="auto",
use_cache=False,
torch_dtype=torch.bfloat16,
)
torch.set_default_dtype(default_dtype)

# Train mode
if args.train_mode.lower() == "sft":
train_mode = 0
elif args.train_mode.lower() == "peft":
train_mode = 1
else:
raise ValueError(f"Invalid train_mode: {args.train_mode}, only SFT and PEFT are supported.")

# PEFT specific
if args.mode:
if train_mode:
# PEFT configs
peft_config = LoraConfig(
lora_alpha=16,
@@ -149,6 +161,8 @@ def main():
# initializes NVFlare client API
flare.init()

# Train federated rounds
# start with global model at the beginning of each round
while flare.is_running():
# receives FLModel from NVFlare
input_model = flare.receive()
@@ -164,7 +178,7 @@
# evaluation on both trained and received model
def evaluate(input_weights, mode):
# Special load func for PEFT
if mode:
if train_mode:
set_peft_model_state_dict(trainer.model, input_weights)
else:
trainer.model.load_state_dict(input_weights)
@@ -173,7 +187,7 @@ def evaluate(input_weights, mode):
return metric_score

# evaluate on received global model
eval_loss = evaluate(global_model, args.mode)
eval_loss = evaluate(global_model, train_mode)
eval_loss = float(eval_loss["eval_loss"])

# Load global model and previous training states
@@ -185,7 +199,7 @@ def evaluate(input_weights, mode):
else:
# replace local resume weights with global weights
resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir)
if args.mode:
if train_mode:
# PEFT model small, directly save via torch.save
resume_model_file_path = os.path.join(resume_from_checkpoint_folder, utils.WEIGHTS_NAME)
torch.save(global_model, resume_model_file_path)
@@ -205,15 +219,15 @@ def evaluate(input_weights, mode):
trainer.train(resume_from_checkpoint=True)

# compose output model to send back to server
if args.mode:
if train_mode:
# PEFT, load PEFT part from trainer model
out_param = get_peft_model_state_dict(trainer.model)
else:
# SFT, load whole model state_dict
out_param = trainer.model.state_dict()

# update the key name sent to global model
if not args.mode:
if not train_mode:
for key in list(out_param.keys()):
out_param["model." + key] = out_param.pop(key).cpu()

18 changes: 15 additions & 3 deletions examples/advanced/llm_hf/utils/hf_sft_peft.py
@@ -59,7 +59,12 @@ def main():
type=str,
default="./workspace_centralized/llama-3.2-1b-dolly-sft",
)
parser.add_argument("--mode", type=int, default=0)
parser.add_argument(
"--train_mode",
type=str,
default="SFT",
help="training mode, SFT or PEFT, default to SFT",
)
args = parser.parse_args()

# Dataset
@@ -82,15 +87,22 @@
torch.set_default_dtype(torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
attn_implementation="flash_attention_2",
device_map="auto",
use_cache=False,
torch_dtype=torch.bfloat16,
)
torch.set_default_dtype(default_dtype)

# Train mode
if args.train_mode.lower() == "sft":
train_mode = 0
elif args.train_mode.lower() == "peft":
train_mode = 1
else:
raise ValueError(f"Invalid train_mode: {args.train_mode}, only SFT and PEFT are supported.")

# PEFT specific
if args.mode:
if train_mode:
# PEFT configs
peft_config = LoraConfig(
lora_alpha=16,
24 changes: 18 additions & 6 deletions examples/advanced/llm_hf/utils/hf_sft_peft_iter.py
@@ -60,7 +60,12 @@ def main():
type=str,
default="./workspace_centralized/llama-3.2-1b-dolly-sft-iter",
)
parser.add_argument("--mode", type=int, default=0)
parser.add_argument(
"--train_mode",
type=str,
default="SFT",
help="training mode, SFT or PEFT, default to SFT",
)
args = parser.parse_args()

# Dataset
@@ -83,15 +88,22 @@
torch.set_default_dtype(torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
attn_implementation="flash_attention_2",
device_map="auto",
use_cache=False,
torch_dtype=torch.bfloat16,
)
torch.set_default_dtype(default_dtype)

# Train mode
if args.train_mode.lower() == "sft":
train_mode = 0
elif args.train_mode.lower() == "peft":
train_mode = 1
else:
raise ValueError(f"Invalid train_mode: {args.train_mode}, only SFT and PEFT are supported.")

# PEFT specific
if args.mode:
if train_mode:
# PEFT configs
peft_config = LoraConfig(
lora_alpha=16,
@@ -144,7 +156,7 @@ def main():
# Save base model state_dict, which will be used as the starting
# weights for each round - to show the weights are loaded correctly
initial_model_path = os.path.join(args.output_path, "model_dict_base.pt")
if args.mode:
if train_mode:
params = get_peft_model_state_dict(model)
else:
params = model.state_dict()
@@ -157,7 +169,7 @@

# Load and Evaluate model file
state_dict_replace = torch.load(initial_model_path, map_location="cpu", weights_only=True)
if args.mode:
if train_mode:
set_peft_model_state_dict(trainer.model, state_dict_replace)
else:
trainer.model.load_state_dict(state_dict_replace)
@@ -170,7 +182,7 @@
else:
# replace local resume weights with global weights
resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir)
if args.mode:
if train_mode:
# PEFT model small, directly save via torch.save
resume_model_file_path = os.path.join(resume_from_checkpoint_folder, utils.WEIGHTS_NAME)
torch.save(state_dict_replace, resume_model_file_path)
