Update LLM_HF example (#3054)
* update llm_hf example to comply with latest API change and add compression filter

* add three-site results

* minor rewording

* move filter to next PR

* updates to comments

* bug correction

* bug correction
ZiyueXu77 authored Nov 7, 2024
1 parent dc6598e commit 3f7b12b
Showing 27 changed files with 603 additions and 201 deletions.
145 changes: 65 additions & 80 deletions examples/advanced/llm_hf/README.md
@@ -1,71 +1,60 @@
# LLM Tuning via HuggingFace SFT Trainer
# LLM Tuning via HuggingFace SFT/PEFT APIs
This example shows how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for Large Language Model (LLM) tuning tasks. It illustrates how to adapt a local training script that uses the [HuggingFace](https://huggingface.co/) trainer to NVFlare.

## Introduction
This example illustrates both supervised fine-tuning (SFT) and parameter-efficient fine-tuning (PEFT) using the [SFT Trainer](https://huggingface.co/docs/trl/sft_trainer) from [HuggingFace](https://huggingface.co/) with [PEFT library](https://github.com/huggingface/peft).

We used the [Llama-2-7b-hf model](https://huggingface.co/meta-llama/Llama-2-7b-hf) to showcase the functionality of federated SFT and PEFT, allowing HuggingFace models to be trained and adapted with NVFlare.
We used the [Llama-3.2-1B model](https://huggingface.co/meta-llama/Llama-3.2-1B) to showcase the functionality of federated SFT and PEFT, allowing HuggingFace models to be trained and adapted with NVFlare. All other models from HuggingFace can be easily adapted following the same steps.

For PEFT, we used the LoRA method; other PEFT methods (e.g. p-tuning, prompt-tuning) can be adapted as well by modifying the configs, following the [PEFT](https://github.com/huggingface/peft) examples.

Mainly on two fronts:
- Adapt local HuggingFace training scripts to federated application
- Handling large model weights (~26 GB for Llama-2-7b-hf model), this is supported by NVFlare infrastructure, and does not need any code change.
We would like to showcase two key points in this example:
- Adapt local HuggingFace training scripts, both SFT and PEFT, to a federated application
- Handle large model weights (~6 GB for the Llama-3.2-1B model with float32 precision for communication), which exceed protobuf's 2 GB hard limit. This is supported by the NVFlare infrastructure via streaming and does not need any code change.
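The size argument above can be checked with a few lines of arithmetic. The following is a minimal sketch, not part of the example code: the parameter count is an assumed round figure chosen for illustration, and the helper names (`payload_bytes`, `needs_streaming`) are hypothetical.

```python
# Rough payload-size estimate for sending full model weights in one message.
# ASSUMPTION: the parameter count used below is an illustrative round figure,
# not a number taken from this README.
PROTOBUF_LIMIT_BYTES = 2 * 1024**3  # protobuf's 2 GB hard limit

BYTES_PER_PARAM = {"float32": 4, "float16": 2, "bfloat16": 2}


def payload_bytes(num_params: int, dtype: str = "float32") -> int:
    """Return the raw size in bytes of num_params values stored as dtype."""
    return num_params * BYTES_PER_PARAM[dtype]


def needs_streaming(num_params: int, dtype: str = "float32") -> bool:
    """True when the raw payload is larger than a single protobuf message allows."""
    return payload_bytes(num_params, dtype) > PROTOBUF_LIMIT_BYTES


if __name__ == "__main__":
    assumed_params = 1_500_000_000  # hypothetical parameter count
    size_gb = payload_bytes(assumed_params) / 1024**3
    print(f"float32 payload: {size_gb:.1f} GB")
    print(f"streaming needed: {needs_streaming(assumed_params)}")
```

Under this assumption the float32 payload is several gigabytes, well above the 2 GB limit, which is why streaming is needed for the full-model (SFT) case.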

We conducted these experiments on a single 48GB RTX 6000 Ada GPU.

To use the Llama-3.2-1B model, please request access to the model at https://huggingface.co/meta-llama/Llama-3.2-1B and log in with an access token using huggingface-cli.

We conducted these experiments on two 80GB A100 GPUs, PEFT only needs 1 GPU, while SFT needs both GPUs. Less computation resources will be needed if smaller models are used, simply replace Llama-2-7b-hf with other options from HuggingFace.

## Setup
Please make sure you set up a virtual environment following the [example root readme](../../README.md).
Install additional requirements (if you already have a specific version of nvflare installed in your environment, you may want to remove nvflare in the requirements to avoid reinstalling nvflare):
```
python3 -m pip install -r requirements.txt
```
Git LFS is also necessary for downloads, please follow the steps in this [link](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md).

## Model and Data Preparation
We first download the model and save it to the `model` folder, note that approved access to the model is required
```
mkdir model
cd model
git clone https://huggingface.co/meta-llama/Llama-2-7b-hf
cd ..
```

We then download and preprocess (to be consistent with our [NeMo example](../../../integration/nemo/examples/supervised_fine_tuning), we follow the same preprocessing steps) the dataset [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) for this example.
## Data Preparation
We download and preprocess the datasets, following the same preprocessing steps as our [NeMo example](../../../integration/nemo/examples/supervised_fine_tuning/README.md) for consistency.
```
mkdir dataset
cd dataset
git clone https://huggingface.co/datasets/tatsu-lab/alpaca
git clone https://huggingface.co/datasets/databricks/databricks-dolly-15k
git clone https://huggingface.co/datasets/OpenAssistant/oasst1
cd ..
mkdir dataset/dolly
python ./utils/preprocess_dolly.py --training_file dataset/databricks-dolly-15k/databricks-dolly-15k.jsonl --output_dir dataset/dolly
python ./utils/preprocess_alpaca.py --training_file dataset/alpaca/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet --output_dir dataset/alpaca
python ./utils/preprocess_oasst1.py --training_file dataset/oasst1/data/train-00000-of-00001-b42a775f407cee45.parquet --validation_file dataset/oasst1/data/validation-00000-of-00001-134b8fd0c89408b6.parquet --output_dir dataset/oasst1
```

## Centralized Training
## Adaptation of Centralized Training Script to Federated
To illustrate the adaptation process, we use a single dataset [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k).
### One-call training
Centralized trainings, as the baseline for comparison with FL results, are done with the following command:
Centralized training runs, which serve as the baseline for comparison with other results, are launched with the following commands:
```
python3 ./utils/hf_sft_peft.py --output_path ./workspace_centralized/llama2-7b-dolly-sft --mode 0
python3 ./utils/hf_sft_peft.py --output_path ./workspace_centralized/llama2-7b-dolly-peft --mode 1
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft --train_mode SFT
python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft --train_mode PEFT
```
### Pre: Launch Modes
Before we start adapting the local training script to federated application, we first need to understand the launch modes of NVFlare client API.
In our [client settings](../../../job_templates/sag_pt/config_fed_client.conf), we have two launch modes by switching the `--launch_once` flag:
* If launch_once is true, the SubprocessLauncher will launch an external process once for the whole job
* If launch_once is false, the SubprocessLauncher will launch an external process every time it receives a task from the server
So if it is false, the SubprocessLauncher will create new processes every round.
If it is true, the SubprocessLauncher will reuse the same process for all rounds.

Turning `launch_once` to `false` can be useful in some scenarios like quick prototyping, but for the application of LLM where setup stage can take significant resources, we would want to only setup once. Hence, the below steps are for `launch_once = true` scenario.

See [Client API](../../hello-world/ml-to-fl/pt/README.md) for more details.

### Adaptation Step 1: iterative training
To adapt the centralized training script to federated application, under `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
To adapt the centralized training script to federated application, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
For this purpose, we provided `utils/hf_sft_peft_iter.py` as an example, which is a modified version of `utils/hf_sft_peft.py`.
Their differences are highlighted below:

![diff](./figs/diff_1.png)
![diff](./figs/diff_2.png)
![diff](./figs/diff.png)

Note that the `trainer.train()` call is replaced by a `for` loop, and the three training epochs become three rounds, one epoch per round.
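The control-flow change can be sketched with a toy stand-in. This is only an illustration under stated assumptions: `ToyTrainer` and its `train(num_epochs, resume)` method are hypothetical and are not the HuggingFace trainer API; the actual changes are the ones shown in the diff figures.

```python
class ToyTrainer:
    """Hypothetical stand-in for a trainer; it only counts completed epochs."""

    def __init__(self):
        self.epochs_done = 0

    def train(self, num_epochs: int, resume: bool = False) -> None:
        """Train for num_epochs, starting from scratch unless resume is True."""
        if not resume:
            self.epochs_done = 0
        self.epochs_done += num_epochs


def one_call(total_epochs: int) -> int:
    """Single call: all epochs run inside one train() invocation."""
    trainer = ToyTrainer()
    trainer.train(total_epochs)
    return trainer.epochs_done


def iterative(num_rounds: int) -> int:
    """Iterative: one epoch per round, continuing from the previous round."""
    trainer = ToyTrainer()
    for round_idx in range(num_rounds):
        # The first round starts fresh; later rounds resume from saved state.
        trainer.train(1, resume=round_idx > 0)
    return trainer.epochs_done


if __name__ == "__main__":
    print(one_call(3), iterative(3))
```

Both paths cover the same number of epochs; the iterative form simply exposes a boundary between rounds where a different set of starting weights could be loaded.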

@@ -79,18 +68,16 @@ If the intended model weights (serving as the starting point for each round, the

To run iterative training, we use the following command:
```
python3 ./utils/hf_sft_peft_iter.py --output_path /workspace_centralized/llama2-7b-dolly-sft-iter --mode 0
python3 ./utils/hf_sft_peft_iter.py --output_path /workspace_centralized/llama2-7b-dolly-peft-iter --mode 1
python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft-iter --train_mode SFT
python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft-iter --train_mode PEFT
```

The PEFT curves are shown below, blue for single call, black for iterative. We can see the "zig-zag" pattern in the iterative training loss curve.
The SFT curves are shown below, black for single call, blue for iterative. We can see the "zig-zag" pattern in the iterative training loss curve.
![sft](./figs/cen_sft.png)

Similar patterns can be observed from the PEFT curves, purple for single call, green for iterative.
![peft](./figs/cen_peft.png)

Similar patterns can be observed from the SFT curves

![sft](./figs/cen_sft.png)

### Adaptation Step 2: federated with NVFlare
Once we have the iterative training script ready with "starting model" loading capability, it can be easily adapted to an NVFlare trainer by using [Client API](../../hello-world/ml-to-fl/pt/README.md).

@@ -99,55 +86,53 @@ The major code modifications are for receiving and returning the global model (r
![diff](./figs/diff_fl_1.png)
![diff](./figs/diff_fl_2.png)
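The receive, train, return cycle on the client side can be sketched as follows. This is a pure-Python illustration, not the NVFlare Client API: `FakeClient`, `local_train`, and `run_client` are hypothetical names, and the real calls and signatures should be taken from the Client API documentation linked above.

```python
from collections import deque


class FakeClient:
    """Hypothetical stand-in for a federated client channel (not the NVFlare API)."""

    def __init__(self, global_models):
        self._inbox = deque(global_models)  # global models queued by the "server"
        self.sent = []                      # local updates returned to the "server"

    def is_running(self) -> bool:
        return bool(self._inbox)

    def receive(self) -> dict:
        return self._inbox.popleft()

    def send(self, params: dict) -> None:
        self.sent.append(params)


def local_train(params: dict) -> dict:
    """Placeholder for one round of local training: shifts every weight by 0.5."""
    return {name: value + 0.5 for name, value in params.items()}


def run_client(client: FakeClient) -> None:
    while client.is_running():
        global_params = client.receive()      # 1. receive the global model
        updated = local_train(global_params)  # 2. train locally, starting from it
        client.send(updated)                  # 3. return the updated model


if __name__ == "__main__":
    client = FakeClient([{"w": 1.0}, {"w": 2.0}])
    run_client(client)
    print(client.sent)
```

The point of the sketch is the shape of the loop: each round starts from the received global weights rather than from the previous local state.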

## Job for NVFlare FL Training
With the local training script ready, we can go ahead to generate the NVFlare job configs by reusing the job templates from [sag_pt](../../../job_templates/sag_pt/).

Let's set the job template path with the following command.
```bash
nvflare config -jt ../../../job_templates/
```
Then we can check the available templates with the following command.
```bash
nvflare job list_templates
### Federated Training Results
We run the federated training on a single client using NVFlare Simulator via [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html).
```
We can see the "sag_pt" template is available, with which we further generate job configs for SFT and PEFT as:
```
nvflare job create -force -j "./jobs/hf_peft" -w "sag_pt" -sd "code" \
-f meta.conf min_clients=1 \
-f config_fed_client.conf app_script="hf_sft_peft_fl.py" app_config="--model_path ${PWD}/model/Llama-2-7b-hf --data_path_train ${PWD}/dataset/dolly/training.jsonl --data_path_valid ${PWD}/dataset/dolly/validation.jsonl --output_path llama2-7b-dolly-peft --mode 1" \
-f config_fed_server.conf model_class_path="hf_peft_model.CausalLMPEFTModel" components[0].args.model.args.model_path="${PWD}/model/Llama-2-7b-hf" min_clients=1 num_rounds=3 key_metric="eval_loss" negate_key_metric=True
```
and
```
nvflare job create -force -j "./jobs/hf_sft" -w "sag_pt" -sd "code" \
-f meta.conf min_clients=1 \
-f config_fed_client.conf app_script="hf_sft_peft_fl.py" app_config="--model_path ${PWD}/model/Llama-2-7b-hf --data_path_train ${PWD}/dataset/dolly/training.jsonl --data_path_valid ${PWD}/dataset/dolly/validation.jsonl --output_path llama2-7b-dolly-sft --mode 0" \
-f config_fed_server.conf model_class_path="hf_sft_model.CausalLMModel" components[0].args.model.args.model_path="${PWD}/model/Llama-2-7b-hf" min_clients=1 num_rounds=3 key_metric="eval_loss" negate_key_metric=True
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft --job_dir ${PWD}/workspace/jobs/hf_sft --train_mode SFT
python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_peft --job_dir ${PWD}/workspace/jobs/hf_peft --train_mode PEFT
```
The SFT curves are shown below, black for centralized results, magenta for FL training. With some training randomness, the two SFT training loss curves align with each other.
![sft](./figs/fl_sft.png)

For both client and server configs, we only set the necessary parameters for the SFT and PEFT tasks, and leave the rest to the default values.
Similar patterns can be observed from the PEFT curves, purple for centralized results, orange for FL training. The alignment is better than for SFT.
![peft](./figs/fl_peft.png)

## Federated Training
With the produced job, we run the federated training on a single client using NVFlare Simulator.
```
nvflare simulator -w ./workspace_fl/hf_peft -n 1 -t 1 ./jobs/hf_peft
## Federated Training with Multiple Clients
With the above example, we can easily extend the federated training to multiple clients. We can use the following command to run the federated training with multiple clients:
```
and
python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_multi --job_dir ${PWD}/workspace/jobs/hf_sft_multi --train_mode SFT --threads 1
```
nvflare simulator -w ./workspace_fl/hf_sft -n 1 -t 1 ./jobs/hf_sft
```

## Results
In this example, our purpose is to showcase the adaptation process and FL functionality. Hence, we used 1-client setting, with which the training results should relatively align with centralized training.

The PEFT curves are shown below, blue for centralized results from `./utils/hf_sft_peft.py`, black for FL training.
For comparison, we run the other two sites in centralized training mode:
```
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_sft --train_mode SFT
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_sft --train_mode SFT
```

We can see with some training randomness, the two PEFT training loss curves align with each other.
The training loss curves are shown below:

![peft](./figs/fl_peft.png)
Dolly:
![sft](./figs/fl_sft_dolly.png)
Alpaca:
![sft](./figs/fl_sft_alpaca.png)
Oasst1:
![sft](./figs/fl_sft_oasst1.png)

Similar patterns can be observed from the SFT curves
As shown, federated training with multiple clients (lines with three sections) can achieve comparable or better training loss than each individual site's centralized training (continuous curves), demonstrating the effectiveness of federated learning.
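For intuition on how the three sites are combined each round, here is a minimal sketch of weighted federated averaging over plain Python dictionaries. It is not the NVFlare `FedAvg` implementation; the `fed_avg` helper is hypothetical, and weighting each site by its number of training samples is an assumption made for illustration.

```python
def fed_avg(client_updates):
    """Average client parameters, weighting each client by its sample count.

    client_updates: list of (params, num_samples) pairs, where params maps a
    parameter name to a float. All clients must share the same parameter names.
    """
    total_samples = sum(num_samples for _, num_samples in client_updates)
    names = client_updates[0][0].keys()
    return {
        name: sum(params[name] * num_samples for params, num_samples in client_updates)
        / total_samples
        for name in names
    }


if __name__ == "__main__":
    # Toy numbers only; real updates are full model or adapter weights.
    updates = [
        ({"w": 1.0}, 1),  # e.g. site-dolly
        ({"w": 2.0}, 1),  # e.g. site-alpaca
        ({"w": 6.0}, 2),  # e.g. site-oasst1
    ]
    print(fed_avg(updates))
```

The averaged weights become the starting point each site receives for the next round.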

![sft](./figs/fl_sft.png)
Similarly for PEFT, we can run the following command:
```
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_peft --train_mode PEFT
python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_peft --train_mode PEFT
python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_peft_multi --job_dir ${PWD}/workspace/jobs/hf_peft_multi --train_mode PEFT --threads 1
```

The training loss curves are shown below:

Dolly:
![peft](./figs/peft_dolly.png)
Alpaca:
![peft](./figs/peft_alpaca.png)
Oasst1:
![peft](./figs/peft_oasst1.png)
Binary file modified examples/advanced/llm_hf/figs/cen_peft.png
Binary file modified examples/advanced/llm_hf/figs/cen_sft.png
Binary file added examples/advanced/llm_hf/figs/diff.png
Binary file removed examples/advanced/llm_hf/figs/diff_1.png
Binary file not shown.
Binary file removed examples/advanced/llm_hf/figs/diff_2.png
Binary file not shown.
Binary file modified examples/advanced/llm_hf/figs/diff_fl_1.png
Binary file modified examples/advanced/llm_hf/figs/diff_fl_2.png
Binary file modified examples/advanced/llm_hf/figs/fl_peft.png
Binary file added examples/advanced/llm_hf/figs/fl_peft_comp.png
Binary file modified examples/advanced/llm_hf/figs/fl_sft.png
Binary file added examples/advanced/llm_hf/figs/fl_sft_alpaca.png
Binary file added examples/advanced/llm_hf/figs/fl_sft_comp.png
Binary file added examples/advanced/llm_hf/figs/fl_sft_dolly.png
Binary file added examples/advanced/llm_hf/figs/fl_sft_oasst1.png
Binary file added examples/advanced/llm_hf/figs/peft_alpaca.png
Binary file added examples/advanced/llm_hf/figs/peft_dolly.png
Binary file added examples/advanced/llm_hf/figs/peft_oasst1.png
4 changes: 2 additions & 2 deletions examples/advanced/llm_hf/requirements.txt
@@ -1,8 +1,8 @@
nvflare~=2.5.0rc
nvflare
torch
datasets
tensorboard
transformers
peft
trl
flash-attn
bitsandbytes
151 changes: 151 additions & 0 deletions examples/advanced/llm_hf/sft_job.py
@@ -0,0 +1,151 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from nvflare import FedJob
from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
from nvflare.job_config.script_runner import ScriptRunner


def main():
    args = define_parser()
    train_script = "src/hf_sft_peft_fl.py"
    client_ids = args.client_ids
    num_clients = len(client_ids)

    # Default to one simulator thread per client unless --threads is given
    if args.threads:
        num_threads = args.threads
    else:
        num_threads = num_clients

    if num_threads < num_clients:
        print("The number of threads is smaller than the number of clients, runner clean-up will be performed.")
        clean_up = 1
    else:
        clean_up = 0

    num_rounds = args.num_rounds
    workspace_dir = args.workspace_dir
    job_dir = args.job_dir
    model_name_or_path = args.model_name_or_path
    train_mode = args.train_mode

    # Create the FedJob
    if train_mode.lower() == "sft":
        job = FedJob(name="llm_hf_sft", min_clients=num_clients)
        output_path = "sft"
    elif train_mode.lower() == "peft":
        job = FedJob(name="llm_hf_peft", min_clients=num_clients)
        output_path = "peft"
    else:
        raise ValueError(f"Invalid train_mode: {train_mode}, only SFT and PEFT are supported.")

    # Define the FedAvg controller workflow and send to server
    controller = FedAvg(
        num_clients=num_clients,
        num_rounds=num_rounds,
    )
    job.to(controller, "server")

    # Define the model persistor and send to server
    # First send the model to the server
    job.to("src/hf_sft_model.py", "server")
    # Then send the model persistor to the server
    model_args = {"path": "src.hf_sft_model.CausalLMModel", "args": {"model_name_or_path": model_name_or_path}}
    job.to(PTFileModelPersistor(model=model_args), "server", id="persistor")

    # Add model selection widget and send to server
    job.to(IntimeModelSelector(key_metric="eval_loss", negate_key_metric=True), "server", id="model_selector")

    # Send ScriptRunner to all clients
    for i in range(num_clients):
        client_id = client_ids[i]
        site_name = f"site-{client_id}"
        data_path_train = os.path.join(args.data_path, client_id, "training.jsonl")
        data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl")
        runner = ScriptRunner(
            script=train_script,
            script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}",
        )
        job.to(runner, site_name, tasks=["train"])

    # Export the job
    print("job_dir=", job_dir)
    job.export_job(job_dir)

    # Run the job
    print("workspace_dir=", workspace_dir)
    print("num_threads=", num_threads)
    job.simulator_run(workspace_dir, threads=num_threads)


def define_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--client_ids",
        nargs="+",
        type=str,
        default="",
        help="Client IDs, used to get the data path for each client",
    )
    parser.add_argument(
        "--num_rounds",
        type=int,
        default=3,
        help="Number of rounds, default to 3",
    )
    parser.add_argument(
        "--workspace_dir",
        type=str,
        default="/tmp/nvflare/jobs/llm_hf/workdir",
        help="work directory, default to '/tmp/nvflare/jobs/llm_hf/workdir'",
    )
    parser.add_argument(
        "--job_dir",
        type=str,
        default="/tmp/nvflare/jobs/llm_hf/jobdir",
        help="directory for job export, default to '/tmp/nvflare/jobs/llm_hf/jobdir'",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="meta-llama/llama-3.2-1b",
        help="model name or path",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="",
        help="root directory for training and validation data",
    )
    parser.add_argument(
        "--train_mode",
        type=str,
        default="SFT",
        help="training mode, SFT or PEFT, default to SFT",
    )
    parser.add_argument(
        "--threads",
        type=int,
        help="number of threads to use for FL simulation, default to the number of clients",
    )
    return parser.parse_args()


if __name__ == "__main__":
    main()