diff --git a/examples/advanced/llm_hf/README.md b/examples/advanced/llm_hf/README.md
index 0c82090905..6bb723d5d7 100644
--- a/examples/advanced/llm_hf/README.md
+++ b/examples/advanced/llm_hf/README.md
@@ -1,18 +1,21 @@
-# LLM Tuning via HuggingFace SFT Trainer
+# LLM Tuning via HuggingFace SFT/PEFT APIs
 This example shows how to use [NVIDIA FLARE](https://nvidia.github.io/NVFlare) for Large Language Models (LLMs) tuning tasks. It illustrates how to adapt a local training script with a [HuggingFace](https://huggingface.co/) trainer to NVFlare.

 ## Introduction
 This example illustrates both supervised fine-tuning (SFT) and parameter-efficient fine-tuning (PEFT) using the [SFT Trainer](https://huggingface.co/docs/trl/sft_trainer) from [HuggingFace](https://huggingface.co/) with the [PEFT library](https://github.com/huggingface/peft).
-We used the [Llama-2-7b-hf model](https://huggingface.co/meta-llama/Llama-2-7b-hf) to showcase the functionality of federated SFT and PEFT, allowing HuggingFace models to be trained and adapted with NVFlare.
+We used the [Llama-3.2-1B model](https://huggingface.co/meta-llama/Llama-3.2-1B) to showcase the functionality of federated SFT and PEFT, allowing HuggingFace models to be trained and adapted with NVFlare.
 All other models from HuggingFace can be easily adapted following the same steps.

 For PEFT, we used the LoRA method; other PEFT methods (e.g. p-tuning, prompt-tuning) can be easily adapted as well by modifying the configs following the [PEFT](https://github.com/huggingface/peft) examples.

-Mainly on two fronts:
-- Adapt local HuggingFace training scripts to federated application
-- Handling large model weights (~26 GB for Llama-2-7b-hf model), this is supported by NVFlare infrastructure, and does not need any code change.
+We would like to showcase two key points in this example:
+- Adapt local HuggingFace training scripts, both SFT and PEFT, to a federated application
+- Handle large model weights (~6 GB for the Llama-3.2-1B model at float32 precision for communication), which exceed protobuf's 2 GB hard limit. This is supported by the NVFlare infrastructure via streaming and does not need any code change.
+
+We conducted these experiments on a single 48GB RTX 6000 Ada GPU.
+
+To use the Llama-3.2-1B model, please request access at https://huggingface.co/meta-llama/Llama-3.2-1B and log in with an access token using `huggingface-cli`.

-We conducted these experiments on two 80GB A100 GPUs, PEFT only needs 1 GPU, while SFT needs both GPUs. Less computation resources will be needed if smaller models are used, simply replace Llama-2-7b-hf with other options from HuggingFace.

 ## Setup
 Please make sure you set up a virtual environment following the [example root readme](../../README.md).
@@ -20,52 +23,38 @@ Install additional requirements (if you already have a specific version of nvfla
 ```
 python3 -m pip install -r requirements.txt
 ```
+Git LFS is also needed for the dataset downloads; please follow the steps in this [link](https://github.com/git-lfs/git-lfs/blob/main/INSTALLING.md).

-## Model and Data Preparation
-We first download the model and save it to the `model` folder, note that approved access to the model is required
-```
-mkdir model
-cd model
-git clone https://huggingface.co/meta-llama/Llama-2-7b-hf
-cd ..
-```
-
-We then download and preprocess (to be consistent with our [NeMo example](../../../integration/nemo/examples/supervised_fine_tuning), we follow the same preprocessing steps) the dataset [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) for this example.
+## Data Preparation
+We download three datasets and preprocess them following the same steps as our [NeMo example](../../../integration/nemo/examples/supervised_fine_tuning/README.md).
 ```
 mkdir dataset
 cd dataset
+git clone https://huggingface.co/datasets/tatsu-lab/alpaca
 git clone https://huggingface.co/datasets/databricks/databricks-dolly-15k
+git clone https://huggingface.co/datasets/OpenAssistant/oasst1
 cd ..
 mkdir dataset/dolly
 python ./utils/preprocess_dolly.py --training_file dataset/databricks-dolly-15k/databricks-dolly-15k.jsonl --output_dir dataset/dolly
+python ./utils/preprocess_alpaca.py --training_file dataset/alpaca/data/train-00000-of-00001-a09b74b3ef9c3b56.parquet --output_dir dataset/alpaca
+python ./utils/preprocess_oasst1.py --training_file dataset/oasst1/data/train-00000-of-00001-b42a775f407cee45.parquet --validation_file dataset/oasst1/data/validation-00000-of-00001-134b8fd0c89408b6.parquet --output_dir dataset/oasst1
 ```

-## Centralized Training
+## Adaptation of Centralized Training Script to Federated
+To illustrate the adaptation process, we use a single dataset, [databricks-dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k).
 ### One-call training
-Centralized trainings, as the baseline for comparison with FL results, are done with the following command:
+Centralized training, as the baseline for comparison with the other results, is done with the following commands:
 ```
-python3 ./utils/hf_sft_peft.py --output_path ./workspace_centralized/llama2-7b-dolly-sft --mode 0
-python3 ./utils/hf_sft_peft.py --output_path ./workspace_centralized/llama2-7b-dolly-peft --mode 1
+python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft --train_mode SFT
+python3 ./utils/hf_sft_peft.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft --train_mode PEFT
 ```
-### Pre: Launch Modes
-Before we start adapting the local training script to federated application, we first need to understand the launch modes of NVFlare client API.
-In our [client settings](../../../job_templates/sag_pt/config_fed_client.conf), we have two launch modes by switching the `--launch_once` flag:
-* If launch_once is true, the SubprocessLauncher will launch an external process once for the whole job
-* If launch_once is false, the SubprocessLauncher will launch an external process everytime it receives a task from server
-So if it is false, the SubprocessLauncher will create new processes every round.
-If it is true, the SubprocessLauncher will reuse the same process for all rounds.
-
-Turning `launch_once` to `false` can be useful in some scenarios like quick prototyping, but for the application of LLM where setup stage can take significant resources, we would want to only setup once. Hence, the below steps are for `launch_once = true` scenario.
-
-See [Client API](../../hello-world/ml-to-fl/pt/README.md) for more details.
 ### Adaptation Step 1: iterative training
-To adapt the centralized training script to federated application, under `launch_once = true` setting, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
+To adapt the centralized training script to a federated application, we first need to "break" the single call to `trainer.train()` into iterative calls, one for each round of training.
+For this purpose, we provide `utils/hf_sft_peft_iter.py` as an example, which is a modified version of `utils/hf_sft_peft.py`.
 Their differences are highlighted below:

-
-
+
 Note that the `trainer.train()` call is replaced by a `for` loop, and the three training epochs become three rounds, one epoch per round.
@@ -79,18 +68,16 @@ If the intended model weights (serving as the starting point for each round, the
 To run iterative training, we use the following commands:
 ```
-python3 ./utils/hf_sft_peft_iter.py --output_path /workspace_centralized/llama2-7b-dolly-sft-iter --mode 0
-python3 ./utils/hf_sft_peft_iter.py --output_path /workspace_centralized/llama2-7b-dolly-peft-iter --mode 1
+python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_sft-iter --train_mode SFT
+python3 ./utils/hf_sft_peft_iter.py --output_path ./workspace/llama-3.2-1b-dolly-cen_peft-iter --train_mode PEFT
 ```
-The PEFT curves are shown below, blue for single call, black for iterative. We can see the "zig-zag" pattern in the iterative training loss curve.
+The SFT curves are shown below, black for single call, blue for iterative. We can see the "zig-zag" pattern in the iterative training loss curve.
+
+Similar patterns can be observed from the PEFT curves, purple for single call, green for iterative.


-Similar patterns can be observed from the SFT curves
-
-
-
 ### Adaptation Step 2: federated with NVFlare
 Once we have the iterative training script ready with "starting model" loading capability, it can be easily adapted to an NVFlare trainer by using the [Client API](../../hello-world/ml-to-fl/pt/README.md).
@@ -99,55 +86,53 @@ The major code modifications are for receiving and returning the global model (r



-## Job for NVFlare FL Training
-With the local training script ready, we can go ahead to generate the NVFlare job configs by reusing the job templates from [sag_pt](../../../job_templates/sag_pt/).
-
-Let's set the job template path with the following command.
-```bash
-nvflare config -jt ../../../job_templates/
-```
-Then we can check the available templates with the following command.
-```bash
-nvflare job list_templates
+### Federated Training Results
+We run the federated training on a single client using the NVFlare Simulator via the [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html).
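+Under the hood, `sft_job.py` (added in full later in this diff) assembles the job with the JobAPI. A condensed sketch of its core logic for a single `dolly` client, simplified from the actual script, looks like this:
+
+```python
+from nvflare import FedJob
+from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
+from nvflare.app_common.workflows.fedavg import FedAvg
+from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
+from nvflare.job_config.script_runner import ScriptRunner
+
+# One client ("dolly"), three FedAvg rounds
+job = FedJob(name="llm_hf_sft", min_clients=1)
+job.to(FedAvg(num_clients=1, num_rounds=3), "server")
+
+# Server side: model definition, persistor, and best-model selection on eval_loss
+job.to("src/hf_sft_model.py", "server")
+model_args = {"path": "src.hf_sft_model.CausalLMModel", "args": {"model_name_or_path": "meta-llama/llama-3.2-1b"}}
+job.to(PTFileModelPersistor(model=model_args), "server", id="persistor")
+job.to(IntimeModelSelector(key_metric="eval_loss", negate_key_metric=True), "server", id="model_selector")
+
+# Client side: the Client API training script and its arguments
+runner = ScriptRunner(
+    script="src/hf_sft_peft_fl.py",
+    script_args="--model_name_or_path meta-llama/llama-3.2-1b "
+    "--data_path_train dataset/dolly/training.jsonl --data_path_valid dataset/dolly/validation.jsonl "
+    "--output_path sft --train_mode SFT",
+)
+job.to(runner, "site-dolly", tasks=["train"])
+
+# Export the job config and run it in the FL simulator
+job.export_job("./workspace/jobs/hf_sft")
+job.simulator_run("./workspace/hf_sft", threads=1)
+```
+The full script supports both SFT and PEFT and is launched as follows: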
 ```
-We can see the "sag_pt" template is available, with which we further generate job configs for SFT and PEFT as:
-```
-nvflare job create -force -j "./jobs/hf_peft" -w "sag_pt" -sd "code" \
- -f meta.conf min_clients=1 \
- -f config_fed_client.conf app_script="hf_sft_peft_fl.py" app_config="--model_path ${PWD}/model/Llama-2-7b-hf --data_path_train ${PWD}/dataset/dolly/training.jsonl --data_path_valid ${PWD}/dataset/dolly/validation.jsonl --output_path llama2-7b-dolly-peft --mode 1" \
- -f config_fed_server.conf model_class_path="hf_peft_model.CausalLMPEFTModel" components[0].args.model.args.model_path="${PWD}/model/Llama-2-7b-hf" min_clients=1 num_rounds=3 key_metric="eval_loss" negate_key_metric=True
-```
-and
-```
-nvflare job create -force -j "./jobs/hf_sft" -w "sag_pt" -sd "code" \
- -f meta.conf min_clients=1 \
- -f config_fed_client.conf app_script="hf_sft_peft_fl.py" app_config="--model_path ${PWD}/model/Llama-2-7b-hf --data_path_train ${PWD}/dataset/dolly/training.jsonl --data_path_valid ${PWD}/dataset/dolly/validation.jsonl --output_path llama2-7b-dolly-sft --mode 0" \
- -f config_fed_server.conf model_class_path="hf_sft_model.CausalLMModel" components[0].args.model.args.model_path="${PWD}/model/Llama-2-7b-hf" min_clients=1 num_rounds=3 key_metric="eval_loss" negate_key_metric=True
+python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft --job_dir ${PWD}/workspace/jobs/hf_sft --train_mode SFT
+python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_peft --job_dir ${PWD}/workspace/jobs/hf_peft --train_mode PEFT
 ```
+The SFT curves are shown below, black for centralized results, magenta for FL training. With some training randomness, the two SFT training loss curves align with each other.
+

-For both client and server configs, we only set the necessary parameters for the SFT and PEFT tasks, and leave the rest to the default values.
+Similar patterns can be observed from the PEFT curves, purple for centralized results, orange for FL training. The alignment is even closer than in the SFT case.
+

-## Federated Training
-With the produced job, we run the federated training on a single client using NVFlare Simulator.
-```
-nvflare simulator -w ./workspace_fl/hf_peft -n 1 -t 1 ./jobs/hf_peft
+## Federated Training with Multiple Clients
+With the above example, we can easily extend the federated training to multiple clients, using the following command:
 ```
-and
 ```
+python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_sft_multi --job_dir ${PWD}/workspace/jobs/hf_sft_multi --train_mode SFT --threads 1
 ```
-nvflare simulator -w ./workspace_fl/hf_sft -n 1 -t 1 ./jobs/hf_sft
-```
-## Results
-In this example, our purpose is to showcase the adaptation process and FL functionality. Hence, we used 1-client setting, with which the training results should relatively align with centralized training.
-
-The PEFT curves are shown below, blue for centralized results from `./utils/hf_sft_peft.py`, black for FL training.
+For comparison, we run the other two sites' datasets in centralized training mode:
+```
+python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_sft --train_mode SFT
+python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_sft --train_mode SFT
+```

-We can see with some training randomness, the two PEFT training loss curves align with each other.
+The training loss curves are shown below:

-
+Dolly:
+
+Alpaca:
+
+Oasst1:
+

-Similar patterns can be observed from the SFT curves
+As shown, federated training with multiple clients (curves with three distinct segments, one per round) achieves training losses comparable to or better than each site's centralized training (continuous curves), demonstrating the effectiveness of federated learning.

-
+Similarly for PEFT, we can run the following commands:
+```
+python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/alpaca/training.jsonl --data_path_valid ./dataset/alpaca/validation.jsonl --output_path ./workspace/llama-3.2-1b-alpaca-cen_peft --train_mode PEFT
+python3 ./utils/hf_sft_peft.py --data_path_train ./dataset/oasst1/training.jsonl --data_path_valid ./dataset/oasst1/validation.jsonl --output_path ./workspace/llama-3.2-1b-oasst1-cen_peft --train_mode PEFT
+python3 sft_job.py --client_ids dolly alpaca oasst1 --data_path ${PWD}/dataset --workspace_dir ${PWD}/workspace/hf_peft_multi --job_dir ${PWD}/workspace/jobs/hf_peft_multi --train_mode PEFT --threads 1
+```
+The training loss curves are shown below:
+Dolly:
+
+Alpaca:
+
+Oasst1:
+
\ No newline at end of file
diff --git a/examples/advanced/llm_hf/figs/cen_peft.png b/examples/advanced/llm_hf/figs/cen_peft.png
index bb0ba57d9e..9a9fa56b42 100644
Binary files a/examples/advanced/llm_hf/figs/cen_peft.png and b/examples/advanced/llm_hf/figs/cen_peft.png differ
diff --git a/examples/advanced/llm_hf/figs/cen_sft.png b/examples/advanced/llm_hf/figs/cen_sft.png
index b0a4c47bfb..91133af9f1 100644
Binary files a/examples/advanced/llm_hf/figs/cen_sft.png and b/examples/advanced/llm_hf/figs/cen_sft.png differ
diff --git a/examples/advanced/llm_hf/figs/diff.png b/examples/advanced/llm_hf/figs/diff.png
new file mode 100644
index 0000000000..e5d1b5f980
Binary files /dev/null and b/examples/advanced/llm_hf/figs/diff.png differ
diff --git a/examples/advanced/llm_hf/figs/diff_1.png b/examples/advanced/llm_hf/figs/diff_1.png
deleted file mode 100644
index 098568294b..0000000000
Binary files a/examples/advanced/llm_hf/figs/diff_1.png and /dev/null differ
diff --git a/examples/advanced/llm_hf/figs/diff_2.png b/examples/advanced/llm_hf/figs/diff_2.png
deleted file mode 100644
index 6d9eaffc4a..0000000000
Binary files a/examples/advanced/llm_hf/figs/diff_2.png and /dev/null differ
diff --git a/examples/advanced/llm_hf/figs/diff_fl_1.png b/examples/advanced/llm_hf/figs/diff_fl_1.png
index f8190e9649..4de6721e17 100644
Binary files a/examples/advanced/llm_hf/figs/diff_fl_1.png and b/examples/advanced/llm_hf/figs/diff_fl_1.png differ
diff --git a/examples/advanced/llm_hf/figs/diff_fl_2.png b/examples/advanced/llm_hf/figs/diff_fl_2.png
index c5be32e863..45499b307c 100644
Binary files a/examples/advanced/llm_hf/figs/diff_fl_2.png and b/examples/advanced/llm_hf/figs/diff_fl_2.png differ
diff --git a/examples/advanced/llm_hf/figs/fl_peft.png b/examples/advanced/llm_hf/figs/fl_peft.png
index 2b8a1c7a89..bb701413cb 100644 Binary files a/examples/advanced/llm_hf/figs/fl_peft.png and b/examples/advanced/llm_hf/figs/fl_peft.png differ diff --git a/examples/advanced/llm_hf/figs/fl_peft_comp.png b/examples/advanced/llm_hf/figs/fl_peft_comp.png new file mode 100644 index 0000000000..8a19ed917f Binary files /dev/null and b/examples/advanced/llm_hf/figs/fl_peft_comp.png differ diff --git a/examples/advanced/llm_hf/figs/fl_sft.png b/examples/advanced/llm_hf/figs/fl_sft.png index a7c401f932..0ab67cc051 100644 Binary files a/examples/advanced/llm_hf/figs/fl_sft.png and b/examples/advanced/llm_hf/figs/fl_sft.png differ diff --git a/examples/advanced/llm_hf/figs/fl_sft_alpaca.png b/examples/advanced/llm_hf/figs/fl_sft_alpaca.png new file mode 100644 index 0000000000..401a72f2c4 Binary files /dev/null and b/examples/advanced/llm_hf/figs/fl_sft_alpaca.png differ diff --git a/examples/advanced/llm_hf/figs/fl_sft_comp.png b/examples/advanced/llm_hf/figs/fl_sft_comp.png new file mode 100644 index 0000000000..e5874a39fc Binary files /dev/null and b/examples/advanced/llm_hf/figs/fl_sft_comp.png differ diff --git a/examples/advanced/llm_hf/figs/fl_sft_dolly.png b/examples/advanced/llm_hf/figs/fl_sft_dolly.png new file mode 100644 index 0000000000..053b733c13 Binary files /dev/null and b/examples/advanced/llm_hf/figs/fl_sft_dolly.png differ diff --git a/examples/advanced/llm_hf/figs/fl_sft_oasst1.png b/examples/advanced/llm_hf/figs/fl_sft_oasst1.png new file mode 100644 index 0000000000..10bc870b8b Binary files /dev/null and b/examples/advanced/llm_hf/figs/fl_sft_oasst1.png differ diff --git a/examples/advanced/llm_hf/figs/peft_alpaca.png b/examples/advanced/llm_hf/figs/peft_alpaca.png new file mode 100644 index 0000000000..c597cb4e96 Binary files /dev/null and b/examples/advanced/llm_hf/figs/peft_alpaca.png differ diff --git a/examples/advanced/llm_hf/figs/peft_dolly.png b/examples/advanced/llm_hf/figs/peft_dolly.png new file mode 100644 index 0000000000..811e500403 Binary files /dev/null and b/examples/advanced/llm_hf/figs/peft_dolly.png differ diff --git a/examples/advanced/llm_hf/figs/peft_oasst1.png b/examples/advanced/llm_hf/figs/peft_oasst1.png new file mode 100644 index 0000000000..f09994c3bc Binary files /dev/null and b/examples/advanced/llm_hf/figs/peft_oasst1.png differ diff --git a/examples/advanced/llm_hf/requirements.txt b/examples/advanced/llm_hf/requirements.txt index 27e42b17b1..b5ef99c4c0 100644 --- a/examples/advanced/llm_hf/requirements.txt +++ b/examples/advanced/llm_hf/requirements.txt @@ -1,8 +1,8 @@ -nvflare~=2.5.0rc +nvflare torch datasets tensorboard transformers peft trl -flash-attn +bitsandbytes \ No newline at end of file diff --git a/examples/advanced/llm_hf/sft_job.py b/examples/advanced/llm_hf/sft_job.py new file mode 100644 index 0000000000..c850392897 --- /dev/null +++ b/examples/advanced/llm_hf/sft_job.py @@ -0,0 +1,151 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
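+
+# Federated LLM fine-tuning job launcher: builds an NVFlare job with the JobAPI
+# (FedJob + FedAvg controller + a per-site ScriptRunner) and runs it in the FL
+# simulator. Example usage (see README):
+#   python3 sft_job.py --client_ids dolly --data_path ${PWD}/dataset \
+#       --workspace_dir ${PWD}/workspace/hf_sft --job_dir ${PWD}/workspace/jobs/hf_sft --train_mode SFT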
+ +import argparse +import os + +from nvflare import FedJob +from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector +from nvflare.app_common.workflows.fedavg import FedAvg +from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor +from nvflare.job_config.script_runner import ScriptRunner + + +def main(): + args = define_parser() + train_script = "src/hf_sft_peft_fl.py" + client_ids = args.client_ids + num_clients = len(client_ids) + + if args.threads: + num_threads = args.threads + else: + num_threads = num_clients + + if num_threads < num_clients: + print("The number of threads smaller than the number of clients, runner clean-up will be performed.") + clean_up = 1 + else: + clean_up = 0 + + num_rounds = args.num_rounds + workspace_dir = args.workspace_dir + job_dir = args.job_dir + model_name_or_path = args.model_name_or_path + train_mode = args.train_mode + + # Create the FedJob + if train_mode.lower() == "sft": + job = FedJob(name="llm_hf_sft", min_clients=num_clients) + output_path = "sft" + elif train_mode.lower() == "peft": + job = FedJob(name="llm_hf_peft", min_clients=num_clients) + output_path = "peft" + else: + raise ValueError(f"Invalid train_mode: {train_mode}, only SFT and PEFT are supported.") + + # Define the FedAvg controller workflow and send to server + controller = FedAvg( + num_clients=num_clients, + num_rounds=num_rounds, + ) + job.to(controller, "server") + + # Define the model persistor and send to server + # First send the model to the server + job.to("src/hf_sft_model.py", "server") + # Then send the model persistor to the server + model_args = {"path": "src.hf_sft_model.CausalLMModel", "args": {"model_name_or_path": model_name_or_path}} + job.to(PTFileModelPersistor(model=model_args), "server", id="persistor") + + # Add model selection widget and send to server + job.to(IntimeModelSelector(key_metric="eval_loss", negate_key_metric=True), "server", id="model_selector") + + # Send ScriptRunner to all clients + for i in range(num_clients): + client_id = client_ids[i] + site_name = f"site-{client_id}" + data_path_train = os.path.join(args.data_path, client_id, "training.jsonl") + data_path_valid = os.path.join(args.data_path, client_id, "validation.jsonl") + runner = ScriptRunner( + script=train_script, + script_args=f"--model_name_or_path {model_name_or_path} --data_path_train {data_path_train} --data_path_valid {data_path_valid} --output_path {output_path} --train_mode {train_mode} --clean_up {clean_up}", + ) + job.to(runner, site_name, tasks=["train"]) + + # Export the job + print("job_dir=", job_dir) + job.export_job(job_dir) + + # Run the job + print("workspace_dir=", workspace_dir) + print("num_threads=", num_threads) + job.simulator_run(workspace_dir, threads=num_threads) + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--client_ids", + nargs="+", + type=str, + default="", + help="Clinet IDs, used to get the data path for each client", + ) + parser.add_argument( + "--num_rounds", + type=int, + default=3, + help="Number of rounds, default to 5", + ) + parser.add_argument( + "--workspace_dir", + type=str, + default="/tmp/nvflare/jobs/llm_hf/workdir", + help="work directory, default to '/tmp/nvflare/jobs/llm_hf/workdir'", + ) + parser.add_argument( + "--job_dir", + type=str, + default="/tmp/nvflare/jobs/llm_hf/jobdir", + help="directory for job export, default to '/tmp/nvflare/jobs/llm_hf/jobdir'", + ) + parser.add_argument( + "--model_name_or_path", + type=str, + 
default="meta-llama/llama-3.2-1b", + help="model name or path", + ) + parser.add_argument( + "--data_path", + type=str, + default="", + help="root directory for training and validation data", + ) + parser.add_argument( + "--train_mode", + type=str, + default="SFT", + help="training mode, SFT or PEFT, default to SFT", + ) + parser.add_argument( + "--threads", + type=int, + help="number of threads to use for FL simulation, default to the number of clients", + ) + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/llm_hf/code/hf_peft_model.py b/examples/advanced/llm_hf/src/hf_peft_model.py similarity index 94% rename from examples/advanced/llm_hf/code/hf_peft_model.py rename to examples/advanced/llm_hf/src/hf_peft_model.py index d3773455a5..b2545864c7 100755 --- a/examples/advanced/llm_hf/code/hf_peft_model.py +++ b/examples/advanced/llm_hf/src/hf_peft_model.py @@ -18,7 +18,7 @@ class CausalLMPEFTModel(torch.nn.Module): - def __init__(self, model_path): + def __init__(self, model_name_or_path): super(CausalLMPEFTModel, self).__init__() # PEFT configs peft_config = LoraConfig( @@ -29,7 +29,7 @@ def __init__(self, model_path): task_type="CAUSAL_LM", ) full_model = AutoModelForCausalLM.from_pretrained( - model_path, + model_name_or_path, ) self.model = get_peft_model(full_model, peft_config) diff --git a/examples/advanced/llm_hf/code/hf_sft_model.py b/examples/advanced/llm_hf/src/hf_sft_model.py similarity index 92% rename from examples/advanced/llm_hf/code/hf_sft_model.py rename to examples/advanced/llm_hf/src/hf_sft_model.py index f8428b0db7..fd84c3f06a 100755 --- a/examples/advanced/llm_hf/code/hf_sft_model.py +++ b/examples/advanced/llm_hf/src/hf_sft_model.py @@ -17,10 +17,10 @@ class CausalLMModel(torch.nn.Module): - def __init__(self, model_path): + def __init__(self, model_name_or_path): super(CausalLMModel, self).__init__() self.model = AutoModelForCausalLM.from_pretrained( - model_path, + model_name_or_path, ) def forward(self, input_id): diff --git a/examples/advanced/llm_hf/code/hf_sft_peft_fl.py b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py similarity index 70% rename from examples/advanced/llm_hf/code/hf_sft_peft_fl.py rename to examples/advanced/llm_hf/src/hf_sft_peft_fl.py index 76050ee48a..ef0093140f 100755 --- a/examples/advanced/llm_hf/code/hf_sft_peft_fl.py +++ b/examples/advanced/llm_hf/src/hf_sft_peft_fl.py @@ -16,15 +16,21 @@ import copy import os +# Add deterministic seed for reproducibility illustration +import random + import datasets +import numpy as np import torch from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict, utils -from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, trainer_utils -from trl import SFTTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer, trainer_utils +from trl import SFTConfig, SFTTrainer import nvflare.client as flare -use_flash_attention = True +torch.manual_seed(0) +random.seed(0) +np.random.seed(0) def format_instruction(example): @@ -38,9 +44,9 @@ def format_instruction(example): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--model_path", + "--model_name_or_path", type=str, - default="./model/Llama-2-7b-hf", + default="meta-llama/llama-3.2-1b", ) parser.add_argument( "--data_path_train", @@ -55,9 +61,16 @@ def main(): parser.add_argument( "--output_path", type=str, - default="llama2-7b-dolly-sft", + default="./workspace_federated/llama-3.2-1b-dolly-sft", + ) + 
parser.add_argument( + "--train_mode", + type=str, + default="SFT", + help="training mode, SFT or PEFT, default to SFT", ) - parser.add_argument("--mode", type=int, default=0) + parser.add_argument("--local_epoch", type=int, default=1) + parser.add_argument("--clean_up", type=int, default=0) args = parser.parse_args() # Dataset @@ -72,8 +85,30 @@ def main(): print(f"logging_steps: {logging_steps}") # Model configs - model_path = args.model_path - if args.mode: + model_name_or_path = args.model_name_or_path + peft_config = None + + # Load model + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map="auto", + use_cache=False, + torch_dtype=torch.bfloat16, + ) + torch.set_default_dtype(default_dtype) + + # Train mode + if args.train_mode.lower() == "sft": + train_mode = 0 + elif args.train_mode.lower() == "peft": + train_mode = 1 + else: + raise ValueError(f"Invalid train_mode: {args.train_mode}, only SFT and PEFT are supported.") + + # PEFT specific + if train_mode: # PEFT configs peft_config = LoraConfig( lora_alpha=16, @@ -82,56 +117,41 @@ def main(): bias="none", task_type="CAUSAL_LM", ) - # Load model - model = AutoModelForCausalLM.from_pretrained( - model_path, - use_cache=False, - use_flash_attention_2=use_flash_attention, - device_map="auto", - ) model = get_peft_model(model, peft_config) - else: - peft_config = None - model = AutoModelForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=use_flash_attention, - use_cache=False, - device_map="auto", - ) - model.config.pretraining_tp = 1 + # Set tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" # Training arguments - train_args = TrainingArguments( + train_args = SFTConfig( output_dir=args.output_path, - num_train_epochs=1, + num_train_epochs=args.local_epoch, per_device_train_batch_size=batch_size, gradient_accumulation_steps=gra_accu_steps, gradient_checkpointing=False, optim="paged_adamw_32bit", logging_steps=logging_steps, save_strategy="epoch", - learning_rate=2e-4, + learning_rate=5e-4, bf16=True, - tf32=True, max_grad_norm=0.3, warmup_ratio=0.03, lr_scheduler_type="constant", disable_tqdm=True, + max_seq_length=1024, + # safetensors has some issues in saving lm_head.weight, disable it for now + save_safetensors=False, ) # Trainer - max_seq_length = 1024 trainer = SFTTrainer( model=model, train_dataset=dataset_train, eval_dataset=dataset_valid, peft_config=peft_config, - max_seq_length=max_seq_length, tokenizer=tokenizer, packing=False, formatting_func=format_instruction, @@ -141,6 +161,8 @@ def main(): # initializes NVFlare client API flare.init() + # Train federated rounds + # start with global model at the beginning of each round while flare.is_running(): # receives FLModel from NVFlare input_model = flare.receive() @@ -156,7 +178,7 @@ def main(): # evaluation on both trained and received model def evaluate(input_weights, mode): # Special load func for PEFT - if mode: + if train_mode: set_peft_model_state_dict(trainer.model, input_weights) else: trainer.model.load_state_dict(input_weights) @@ -165,10 +187,10 @@ def evaluate(input_weights, mode): return metric_score # evaluate on received global model - eval_loss = evaluate(global_model, args.mode) + eval_loss = evaluate(global_model, train_mode) eval_loss = float(eval_loss["eval_loss"]) - # loads 
global model + # Load global model and previous training states # Since we perform iterative training by using "resume" functionality # we need to replace the resume weights with global weights every round if curr_round == 0: @@ -177,20 +199,27 @@ def evaluate(input_weights, mode): else: # replace local resume weights with global weights resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir) - if args.mode: + if train_mode: # PEFT model small, directly save via torch.save resume_model_file_path = os.path.join(resume_from_checkpoint_folder, utils.WEIGHTS_NAME) torch.save(global_model, resume_model_file_path) else: # SFT model can be large, save via HF API - trainer.model.save_pretrained(resume_from_checkpoint_folder) + # Disable safetensor for now + trainer.model.save_pretrained(resume_from_checkpoint_folder, safe_serialization=False) # increment num_train_epochs so that the trainer will continue training - trainer.args.num_train_epochs += 1 + if args.clean_up: + # runner got cleaned up, set num_train_epochs with curr_round + trainer.args.num_train_epochs = (curr_round + 1) * args.local_epoch + else: + # runner still alive, increment num_train_epochs with local_epoch + trainer.args.num_train_epochs += args.local_epoch + print(f"Increment num_train_epochs to {trainer.args.num_train_epochs}") # continue training trainer.train(resume_from_checkpoint=True) # compose output model to send back to server - if args.mode: + if train_mode: # PEFT, load PEFT part from trainer model out_param = get_peft_model_state_dict(trainer.model) else: @@ -198,10 +227,13 @@ def evaluate(input_weights, mode): out_param = trainer.model.state_dict() # update the key name sent to global model - if not args.mode: + if not train_mode: for key in list(out_param.keys()): out_param["model." 
+ key] = out_param.pop(key).cpu() + # cast out_param to float32 preparing for communication + out_param = {k: v.to(torch.float32) for k, v in out_param.items()} + # construct trained FL model output_model = flare.FLModel( params=out_param, diff --git a/examples/advanced/llm_hf/utils/hf_sft_peft.py b/examples/advanced/llm_hf/utils/hf_sft_peft.py index 2d07a05715..ae1d429281 100755 --- a/examples/advanced/llm_hf/utils/hf_sft_peft.py +++ b/examples/advanced/llm_hf/utils/hf_sft_peft.py @@ -14,12 +14,19 @@ import argparse +# Add deterministic seed for reproducibility illustration +import random + import datasets +import numpy as np +import torch from peft import LoraConfig, get_peft_model -from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments -from trl import SFTTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import SFTConfig, SFTTrainer -use_flash_attention = True +torch.manual_seed(0) +random.seed(0) +np.random.seed(0) def format_instruction(example): @@ -33,9 +40,9 @@ def format_instruction(example): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--model_path", + "--model_name_or_path", type=str, - default="./model/Llama-2-7b-hf", + default="meta-llama/llama-3.2-1b", ) parser.add_argument( "--data_path_train", @@ -50,9 +57,14 @@ def main(): parser.add_argument( "--output_path", type=str, - default="./workspace_centralized/llama2-7b-dolly-sft", + default="./workspace_centralized/llama-3.2-1b-dolly-sft", + ) + parser.add_argument( + "--train_mode", + type=str, + default="SFT", + help="training mode, SFT or PEFT, default to SFT", ) - parser.add_argument("--mode", type=int, default=0) args = parser.parse_args() # Dataset @@ -67,8 +79,30 @@ def main(): print(f"logging_steps: {logging_steps}") # Model configs - model_path = args.model_path - if args.mode: + model_name_or_path = args.model_name_or_path + peft_config = None + + # Load model + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map="auto", + use_cache=False, + torch_dtype=torch.bfloat16, + ) + torch.set_default_dtype(default_dtype) + + # Train mode + if args.train_mode.lower() == "sft": + train_mode = 0 + elif args.train_mode.lower() == "peft": + train_mode = 1 + else: + raise ValueError(f"Invalid train_mode: {args.train_mode}, only SFT and PEFT are supported.") + + # PEFT specific + if train_mode: # PEFT configs peft_config = LoraConfig( lora_alpha=16, @@ -77,31 +111,16 @@ def main(): bias="none", task_type="CAUSAL_LM", ) - # Load model - model = AutoModelForCausalLM.from_pretrained( - model_path, - use_cache=False, - use_flash_attention_2=use_flash_attention, - device_map="auto", - ) model = get_peft_model(model, peft_config) - else: - peft_config = None - model = AutoModelForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=use_flash_attention, - use_cache=False, - device_map="auto", - ) - model.config.pretraining_tp = 1 + # Set tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" # Training arguments - train_args = TrainingArguments( + train_args = SFTConfig( output_dir=args.output_path, num_train_epochs=3, per_device_train_batch_size=batch_size, @@ -110,23 +129,21 @@ def main(): optim="paged_adamw_32bit", logging_steps=logging_steps, save_strategy="epoch", - 
learning_rate=2e-4, + learning_rate=5e-4, bf16=True, - tf32=True, max_grad_norm=0.3, warmup_ratio=0.03, lr_scheduler_type="constant", disable_tqdm=True, + max_seq_length=1024, ) # Trainer - max_seq_length = 1024 trainer = SFTTrainer( model=model, train_dataset=dataset_train, eval_dataset=dataset_valid, peft_config=peft_config, - max_seq_length=max_seq_length, tokenizer=tokenizer, packing=False, formatting_func=format_instruction, diff --git a/examples/advanced/llm_hf/utils/hf_sft_peft_iter.py b/examples/advanced/llm_hf/utils/hf_sft_peft_iter.py index 3a0f5b3af7..065280f3f2 100755 --- a/examples/advanced/llm_hf/utils/hf_sft_peft_iter.py +++ b/examples/advanced/llm_hf/utils/hf_sft_peft_iter.py @@ -15,13 +15,19 @@ import argparse import os +# Add deterministic seed for reproducibility illustration +import random + import datasets +import numpy as np import torch from peft import LoraConfig, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict, utils -from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, trainer_utils -from trl import SFTTrainer +from transformers import AutoModelForCausalLM, AutoTokenizer, trainer_utils +from trl import SFTConfig, SFTTrainer -use_flash_attention = True +torch.manual_seed(0) +random.seed(0) +np.random.seed(0) def format_instruction(example): @@ -35,9 +41,9 @@ def format_instruction(example): def main(): parser = argparse.ArgumentParser() parser.add_argument( - "--model_path", + "--model_name_or_path", type=str, - default="./model/Llama-2-7b-hf", + default="meta-llama/llama-3.2-1b", ) parser.add_argument( "--data_path_train", @@ -52,9 +58,14 @@ def main(): parser.add_argument( "--output_path", type=str, - default="./workspace_centralized/llama2-7b-dolly-sft-iter", + default="./workspace_centralized/llama-3.2-1b-dolly-sft-iter", + ) + parser.add_argument( + "--train_mode", + type=str, + default="SFT", + help="training mode, SFT or PEFT, default to SFT", ) - parser.add_argument("--mode", type=int, default=0) args = parser.parse_args() # Dataset @@ -69,8 +80,30 @@ def main(): print(f"logging_steps: {logging_steps}") # Model configs - model_path = args.model_path - if args.mode: + model_name_or_path = args.model_name_or_path + peft_config = None + + # Load model + default_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map="auto", + use_cache=False, + torch_dtype=torch.bfloat16, + ) + torch.set_default_dtype(default_dtype) + + # Train mode + if args.train_mode.lower() == "sft": + train_mode = 0 + elif args.train_mode.lower() == "peft": + train_mode = 1 + else: + raise ValueError(f"Invalid train_mode: {args.train_mode}, only SFT and PEFT are supported.") + + # PEFT specific + if train_mode: # PEFT configs peft_config = LoraConfig( lora_alpha=16, @@ -79,39 +112,16 @@ def main(): bias="none", task_type="CAUSAL_LM", ) - # Load model - model = AutoModelForCausalLM.from_pretrained( - model_path, - use_cache=False, - use_flash_attention_2=use_flash_attention, - device_map="auto", - ) model = get_peft_model(model, peft_config) - else: - peft_config = None - model = AutoModelForCausalLM.from_pretrained( - model_path, - use_flash_attention_2=use_flash_attention, - use_cache=False, - device_map="auto", - ) - model.config.pretraining_tp = 1 + # Set tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) tokenizer.pad_token = tokenizer.eos_token 
tokenizer.padding_side = "right" - # Save base model state_dict, which will be used as the starting - # weights for each round - to show the weights are loaded correctly - if args.mode: - params = get_peft_model_state_dict(model) - else: - params = model.state_dict() - torch.save(params, "model_dict_base.pt") - # Training arguments - train_args = TrainingArguments( + train_args = SFTConfig( output_dir=args.output_path, num_train_epochs=1, per_device_train_batch_size=batch_size, @@ -120,37 +130,46 @@ def main(): optim="paged_adamw_32bit", logging_steps=logging_steps, save_strategy="epoch", - learning_rate=2e-4, + learning_rate=5e-4, bf16=True, - tf32=True, max_grad_norm=0.3, warmup_ratio=0.03, lr_scheduler_type="constant", disable_tqdm=True, + max_seq_length=1024, + # safetensors has some issues in saving lm_head.weight, disable it for now + save_safetensors=False, ) # Trainer - max_seq_length = 1024 trainer = SFTTrainer( model=model, train_dataset=dataset_train, eval_dataset=dataset_valid, peft_config=peft_config, - max_seq_length=max_seq_length, tokenizer=tokenizer, packing=False, formatting_func=format_instruction, args=train_args, ) + # Save base model state_dict, which will be used as the starting + # weights for each round - to show the weights are loaded correctly + initial_model_path = os.path.join(args.output_path, "model_dict_base.pt") + if train_mode: + params = get_peft_model_state_dict(model) + else: + params = model.state_dict() + torch.save(params, initial_model_path) + # Train iteratively by using "resume" functionality # and replace the resume weights every round for curr_round in range(3): print(f"current_round={curr_round}") - # Evaluate - state_dict_replace = torch.load("model_dict_base.pt", map_location="cpu") - if args.mode: + # Load and Evaluate model file + state_dict_replace = torch.load(initial_model_path, map_location="cpu", weights_only=True) + if train_mode: set_peft_model_state_dict(trainer.model, state_dict_replace) else: trainer.model.load_state_dict(state_dict_replace) @@ -163,13 +182,14 @@ def main(): else: # replace local resume weights with global weights resume_from_checkpoint_folder = trainer_utils.get_last_checkpoint(trainer.args.output_dir) - if args.mode: + if train_mode: # PEFT model small, directly save via torch.save resume_model_file_path = os.path.join(resume_from_checkpoint_folder, utils.WEIGHTS_NAME) torch.save(state_dict_replace, resume_model_file_path) else: # SFT model can be large, save via HF API - trainer.model.save_pretrained(resume_from_checkpoint_folder) + # Disable safetensor for now + trainer.model.save_pretrained(resume_from_checkpoint_folder, safe_serialization=False) # increment num_train_epochs so that the trainer will continue training trainer.args.num_train_epochs += 1 # continue training diff --git a/examples/advanced/llm_hf/utils/preprocess_alpaca.py b/examples/advanced/llm_hf/utils/preprocess_alpaca.py new file mode 100755 index 0000000000..c13c5a7e61 --- /dev/null +++ b/examples/advanced/llm_hf/utils/preprocess_alpaca.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os + +import numpy as np +import pandas as pd +import pyarrow.parquet as pq + + +def data_args(): + parser = argparse.ArgumentParser(description="Preprocess data to train and validation files in jsonl format") + parser.add_argument("--training_file", type=str, required=True, help="Path to training set") + parser.add_argument("--validation_file", type=str, help="Path to validation set, if given, append to training data") + parser.add_argument("--validation_ratio", type=float, default=0.1, help="Ratio of validation set, defult to 10%") + parser.add_argument("--testing_ratio", type=float, default=0.1, help="Ratio of testing set, defult to 10%") + parser.add_argument("--output_dir", type=str, required=True, help="Path to output folder") + args = parser.parse_args() + return args + + +def split_to_jsonl(data, output_dir, validation_ratio, testing_ratio): + print("Preprocessing data to NeMo_SFT jsonl format...") + output_path_tra = os.path.join(output_dir, "training.jsonl") + output_path_val = os.path.join(output_dir, "validation.jsonl") + output_path_tst = os.path.join(output_dir, "testing.jsonl") + + data_ct = len(data) + val_threshold = int(data_ct * validation_ratio) + test_threshold = int(data_ct * testing_ratio) + + with open(output_path_val, "w") as g, open(output_path_tst, "w") as h, open(output_path_tra, "w") as i: + for index, item in data.iterrows(): + input = item["input"].strip() + if input != "": + # Randomize input and instruction order. 
+ input_first = np.random.randint(0, 2) == 0 + if input_first: + instruction = item["instruction"].strip() + assert instruction != "" + input = f"{input}\n\n{instruction}" + output = item["output"] + else: + instruction = item["instruction"].strip() + assert instruction != "" + input = f"{instruction}\n\n{input}" + output = item["output"] + else: + input = item["instruction"] + output = item["output"] + # write to jsonl file according to index + if index < val_threshold: + h.write(json.dumps({"input": input, "output": output}) + "\n") + elif index < val_threshold + test_threshold: + g.write(json.dumps({"input": input, "output": output}) + "\n") + else: + i.write(json.dumps({"input": input, "output": output}) + "\n") + print(f"{index + 1} out of {data_ct} Data was successfully preprocessed and saved.") + + +def main(): + args = data_args() + # load training data + path_to_train = args.training_file + ds = pq.read_table(path_to_train) + train = ds.to_pandas() + # load validation data if provided and append to training data + if args.validation_file: + path_to_val = args.validation_file + ds = pq.read_table(path_to_val) + val = ds.to_pandas() + train = pd.concat([train, val]) + # randomize the order of the data + data_full = train.sample(frac=1, random_state=0).reset_index(drop=True) + # split data into training, validation and testing + val_ratio = args.validation_ratio + test_ratio = args.testing_ratio + output_dir = args.output_dir + split_to_jsonl(data_full, output_dir, val_ratio, test_ratio) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/llm_hf/utils/preprocess_oasst1.py b/examples/advanced/llm_hf/utils/preprocess_oasst1.py new file mode 100755 index 0000000000..de4de63040 --- /dev/null +++ b/examples/advanced/llm_hf/utils/preprocess_oasst1.py @@ -0,0 +1,101 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
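+
+# Preprocess the OpenAssistant oasst1 parquet export into training/validation/
+# testing jsonl files: keep top-ranked (rank 0) English assistant replies, pair
+# each one with its prompter parent message as the instruction, then shuffle and
+# split by the given ratios. Example usage (see README):
+#   python ./utils/preprocess_oasst1.py --training_file dataset/oasst1/data/train-00000-of-00001-b42a775f407cee45.parquet \
+#       --validation_file dataset/oasst1/data/validation-00000-of-00001-134b8fd0c89408b6.parquet --output_dir dataset/oasst1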
+ +import argparse +import json +import os + +import pandas as pd +import pyarrow.parquet as pq + + +def data_args(): + parser = argparse.ArgumentParser(description="Preprocess data to train and validation files in jsonl format") + parser.add_argument("--training_file", type=str, required=True, help="Path to training set") + parser.add_argument("--validation_file", type=str, help="Path to validation set, if given, append to training data") + parser.add_argument("--validation_ratio", type=float, default=0.1, help="Ratio of validation set, defult to 10%") + parser.add_argument("--testing_ratio", type=float, default=0.1, help="Ratio of testing set, defult to 10%") + parser.add_argument("--output_dir", type=str, required=True, help="Path to output folder") + args = parser.parse_args() + return args + + +def get_data_for_sft(data): + data_assistant = data[(data.role == "assistant") & (data["rank"] == 0.0)].copy() + data_prompter = data[(data.role == "prompter")].copy() + data_prompter = data_prompter.set_index("message_id") + data_assistant["output"] = data_assistant["text"].values + + inputs = [] + parent_ids = [] + for index, item in data_assistant.iterrows(): + input = data_prompter.loc[item.parent_id] + inputs.append(input.text) + parent_ids.append(input.parent_id) + data_assistant["instruction"] = inputs + data_assistant["parent_id"] = parent_ids + data_assistant = data_assistant[data_assistant.lang == "en"] + data_assistant = data_assistant[["instruction", "output"]] + return data_assistant + + +def split_to_jsonl(data, output_dir, validation_ratio, testing_ratio): + print("Preprocessing data to NeMo_SFT jsonl format...") + output_path_tra = os.path.join(output_dir, "training.jsonl") + output_path_val = os.path.join(output_dir, "validation.jsonl") + output_path_tst = os.path.join(output_dir, "testing.jsonl") + + data_ct = len(data) + val_threshold = int(data_ct * validation_ratio) + test_threshold = int(data_ct * testing_ratio) + + with open(output_path_val, "w") as g, open(output_path_tst, "w") as h, open(output_path_tra, "w") as i: + for index, item in data.iterrows(): + input = item["instruction"] + output = item["output"] + # write to jsonl file according to index + if index < val_threshold: + h.write(json.dumps({"input": input, "output": output}) + "\n") + elif index < val_threshold + test_threshold: + g.write(json.dumps({"input": input, "output": output}) + "\n") + else: + i.write(json.dumps({"input": input, "output": output}) + "\n") + print(f"{index + 1} out of {data_ct} Data was successfully preprocessed and saved.") + + +def main(): + args = data_args() + # load training data + path_to_train = args.training_file + ds = pq.read_table(path_to_train) + data = ds.to_pandas() + train = get_data_for_sft(data) + # load validation data if provided and append to training data + if args.validation_file: + path_to_val = args.validation_file + ds = pq.read_table(path_to_val) + data = ds.to_pandas() + val = get_data_for_sft(data) + train = pd.concat([train, val]) + # randomize the order of the data + data_full = train.sample(frac=1, random_state=0).reset_index(drop=True) + # split data into training, validation and testing + val_ratio = args.validation_ratio + test_ratio = args.testing_ratio + output_dir = args.output_dir + split_to_jsonl(data_full, output_dir, val_ratio, test_ratio) + + +if __name__ == "__main__": + main()
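Both preprocessing scripts above write jsonl records containing two fields, `input` and `output`. As an illustrative sanity check (not part of the example code; the file path is an assumption based on the README's data-preparation commands), a produced split can be loaded with the `datasets` library that the training scripts already use:

```python
# Illustrative check of a preprocessed split; the path assumes the README's
# data-preparation step has been run.
from datasets import load_dataset

ds = load_dataset("json", data_files="dataset/oasst1/training.jsonl", split="train")
print(len(ds))               # number of training records
print(ds[0].keys())          # dict_keys(['input', 'output'])
print(ds[0]["input"][:200])  # first 200 characters of the first prompt
```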