From a2aa0f0b09671eaf81a945eb5e4913165fee92fa Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 18 Mar 2024 12:20:54 +0100 Subject: [PATCH] FEAT: Add CLIs in TRL ! (#1419) * CLI V1 * v1 CLI * add rich enhancmeents * revert unindented change * some comments * cleaner CLI * fix * fix * remove print callback * move to cli instead of trl_cli * revert unneeded changes * fix test * Update trl/commands/sft.py Co-authored-by: Leandro von Werra * remove redundant strings * fix import issue * fix other issues * add packing * add config parser * some refactor * cleaner * add example config yaml file * small refactor * change a bit the logic * fix issues here and there * add CLI in docs * move to examples/sft * remove redundant licenses * make it work on dpo * set to None * switch to accelerate and fix many things * add docs * more docs * added tests * doc clarification * more docs * fix CI for windows and python 3.8 * fix * attempt to fix CI * fix? * test * fix * tweak? * fix * test * another test * fix * test * fix * fix * fix * skip tests for windows * test @lvwerra approach * make dev * revert unneeded changes * fix sft dpo * optimize a bit * address final comments * update docs * final comment --------- Co-authored-by: Leandro von Werra --- CONTRIBUTING.md | 2 +- MANIFEST.in | 2 +- Makefile | 5 + commands/run_dpo.sh | 2 + docs/source/_toctree.yml | 2 + docs/source/clis.mdx | 87 ++++++++++++++ example_config.yaml | 20 ++++ examples/scripts/dpo.py | 139 ++++++++++------------ examples/scripts/sft.py | 88 +++++++++----- setup.cfg | 2 +- setup.py | 73 +++++++----- tests/test_cli.py | 34 ++++++ trl/__init__.py | 161 ++++++++++++++++++------- trl/commands/__init__.py | 34 ++++++ trl/commands/cli.py | 65 +++++++++++ trl/commands/cli_utils.py | 227 ++++++++++++++++++++++++++++++++++++ trl/environment/__init__.py | 13 ++- trl/extras/__init__.py | 16 ++- trl/import_utils.py | 87 ++++++++++++-- trl/models/__init__.py | 59 +++++++--- trl/models/utils.py | 8 ++ trl/trainer/__init__.py | 112 +++++++++++++----- trl/trainer/sft_trainer.py | 7 ++ trl/trainer/utils.py | 73 ++++++++++++ 24 files changed, 1085 insertions(+), 233 deletions(-) create mode 100644 docs/source/clis.mdx create mode 100644 example_config.yaml create mode 100644 tests/test_cli.py create mode 100644 trl/commands/__init__.py create mode 100644 trl/commands/cli.py create mode 100644 trl/commands/cli_utils.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de1731bccb..2e11e59b3a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -5,7 +5,7 @@ Before you start contributing make sure you installed all the dev tools: ```bash -pip install -e ".[dev]" +make dev ``` ## Did you find a bug? 
diff --git a/MANIFEST.in b/MANIFEST.in index 5c0e7ced19..f0d7acb4da 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,4 +2,4 @@ include settings.ini include LICENSE include CONTRIBUTING.md include README.md -recursive-exclude * __pycache__ +recursive-exclude * __pycache__ \ No newline at end of file diff --git a/Makefile b/Makefile index a80e739d1e..dfdc96df47 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,11 @@ check_dirs := examples tests trl ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs COMMAND_FILES_PATH = `pwd`/commands + +dev: + [ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true + pip install -e ".[dev]" + test: python -m pytest -n auto --dist=loadfile -s -v ./tests/ diff --git a/commands/run_dpo.sh b/commands/run_dpo.sh index 1d6b6db7e9..632496f359 100644 --- a/commands/run_dpo.sh +++ b/commands/run_dpo.sh @@ -3,6 +3,7 @@ # but defaults to QLoRA + PEFT OUTPUT_DIR="test_dpo/" MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM" +DATASET_NAME="trl-internal-testing/Anthropic-hh-rlhf-processed" MAX_STEPS=5 BATCH_SIZE=2 SEQ_LEN=128 @@ -36,6 +37,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \ --mixed_precision 'fp16' \ `pwd`/examples/scripts/dpo.py \ --model_name_or_path $MODEL_NAME \ + --dataset_name $DATASET_NAME \ --output_dir $OUTPUT_DIR \ --max_steps $MAX_STEPS \ --per_device_train_batch_size $BATCH_SIZE \ diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f5f9654c42..2cbfb66914 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -5,6 +5,8 @@ title: Quickstart - local: installation title: Installation + - local: clis + title: Get started with Command Line Interfaces (CLIs) - local: how_to_train title: PPO Training FAQ - local: use_model diff --git a/docs/source/clis.mdx b/docs/source/clis.mdx new file mode 100644 index 0000000000..86a77e6eae --- /dev/null +++ b/docs/source/clis.mdx @@ -0,0 +1,87 @@ +# Command Line Interfaces (CLIs) + +You can use TRL to fine-tune your language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO) using the TRL CLIs. + +Currently supported CLIs are: + +- `trl sft` +- `trl dpo` + +## Get started + +Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found with the "text-generation" filter on the models page. Also make sure to pick a relevant dataset for your task. + +Also make sure to run: +```bash +accelerate config +``` +and pick the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command. + +We also recommend passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with the `trl sft` command. + +```yaml +model_name_or_path: + HuggingFaceM4/tiny-random-LlamaForCausalLM +dataset_name: + imdb +dataset_text_field: + text +report_to: + none +learning_rate: + 0.0001 +lr_scheduler_type: + cosine +``` + +Save that config in a `.yaml` file and get started right away! Note that you can override the arguments from the config file by explicitly passing them to the CLI, e.g.: + +```bash +trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts +``` + +This will force the use of `cosine_with_restarts` for `lr_scheduler_type`.
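+ +The YAML file can also set environment variables for your run: as noted in the comments of `example_config.yaml` added in this patch, you can add an `env` field to the config. Below is a minimal sketch reusing the values from that file: + +```yaml +env: + CUDA_VISIBLE_DEVICES: 0 +model_name_or_path: + HuggingFaceM4/tiny-random-LlamaForCausalLM +``` +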
+ +## Supported Arguments + +We support all arguments from `transformers.TrainingArguments`. For loading your model, we support all arguments from `~trl.ModelConfig`: + +[[autodoc]] ModelConfig + +You can pass any of these arguments either to the CLI or the YAML file. + +### Supervised Fine-tuning (SFT) + +Follow the basic instructions above and run `trl sft --output_dir <*args>`: + +```bash +trl sft --config config.yaml --output_dir your-output-dir +``` + +The SFT CLI is based on the `examples/scripts/sft.py` script. + +### Direct Preference Optimization (DPO) + +First, follow the basic instructions above and run `trl dpo --output_dir <*args>`. Make sure to process your DPO dataset in the TRL format as follows: + +1- Make sure to pre-tokenize the dataset using chat templates: + +```bash +python examples/datasets/tokenize_ds.py --model gpt2 --dataset yourdataset +``` + +You might need to adapt `examples/datasets/tokenize_ds.py` to use your chat template. + +2- Format the dataset into the TRL format (you can adapt `examples/datasets/anthropic_hh.py`): + +```bash +python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org +``` + +Once your dataset is pushed, run the DPO CLI as follows: + +```bash +trl dpo --config config.yaml --output_dir your-output-dir +``` + +The DPO CLI is based on the `examples/scripts/dpo.py` script. \ No newline at end of file diff --git a/example_config.yaml b/example_config.yaml new file mode 100644 index 0000000000..da04a9c213 --- /dev/null +++ b/example_config.yaml @@ -0,0 +1,20 @@ +# This is an example configuration file of TRL CLI, you can use it for +# SFT like that: `trl sft --config config.yaml --output_dir test-sft` +# The YAML file supports environment variables by adding an `env` field +# as below + +# env: +# CUDA_VISIBLE_DEVICES: 0 + +model_name_or_path: + HuggingFaceM4/tiny-random-LlamaForCausalLM +dataset_name: + imdb +dataset_text_field: + text +report_to: + none +learning_rate: + 1e-4 +lr_scheduler_type: + cosine \ No newline at end of file diff --git a/examples/scripts/dpo.py b/examples/scripts/dpo.py index 587a5efdfb..3a0d6486ca 100644 --- a/examples/scripts/dpo.py +++ b/examples/scripts/dpo.py @@ -1,3 +1,4 @@ +# flake8: noqa # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -48,76 +49,47 @@ --lora_r=16 \ --lora_alpha=16 """ -from dataclasses import dataclass, field -from typing import Dict, Optional +import logging +import os +from contextlib import nullcontext -import torch -from datasets import Dataset, load_dataset -from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments -from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config - - -@dataclass -class ScriptArguments: - beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) - max_length: int = field(default=512, metadata={"help": "max length of each sample"}) - max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"}) - max_target_length: int = field( - default=128, metadata={"help": "Only used for encoder decoder model. 
Max target of each sample's prompt"} - ) - sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"}) - ignore_bias_buffers: bool = field( - default=False, - metadata={ - "help": "debug argument for distributed training;" - "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" - "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" - }, - ) - generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"}) +TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) +from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser -def extract_anthropic_prompt(prompt_and_response): - """Extract the anthropic prompt from a prompt and response pair.""" - search_term = "\n\nAssistant:" - search_term_idx = prompt_and_response.rfind(search_term) - assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" - return prompt_and_response[: search_term_idx + len(search_term)] +if TRL_USE_RICH: + init_zero_verbose() + FORMAT = "%(message)s" + from rich.console import Console + from rich.logging import RichHandler -def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset: - """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format. - - The dataset is converted to a dictionary with the following structure: - { - 'prompt': List[str], - 'chosen': List[str], - 'rejected': List[str], - } +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments - Prompts should be structured as follows: - \n\nHuman: \n\nAssistant: - Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:. - """ - dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir) - if sanity_check: - dataset = dataset.select(range(min(len(dataset), 1000))) +from trl import ( + DPOTrainer, + ModelConfig, + RichProgressCallback, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) - def split_prompt_and_responses(sample) -> Dict[str, str]: - prompt = extract_anthropic_prompt(sample["chosen"]) - return { - "prompt": prompt, - "chosen": sample["chosen"][len(prompt) :], - "rejected": sample["rejected"][len(prompt) :], - } - return dataset.map(split_prompt_and_responses) +if TRL_USE_RICH: + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO) if __name__ == "__main__": - parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig)) - args, training_args, model_config = parser.parse_args_into_dataclasses() + parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig)) + args, training_args, model_config = parser.parse_args_and_config() + + # Force use our print callback + if TRL_USE_RICH: + training_args.disable_tqdm = True + console = Console() ################ # Model & Tokenizer @@ -152,28 +124,43 @@ def split_prompt_and_responses(sample) -> Dict[str, str]: name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool ] + ################ + # Optional rich context managers + ############### + init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...") + save_context = ( + nullcontext() + if not TRL_USE_RICH + else console.status(f"[bold green]Training completed! 
Saving the model to {training_args.output_dir}") + ) + ################ # Dataset ################ - train_dataset = get_hh("train", sanity_check=args.sanity_check) - eval_dataset = get_hh("test", sanity_check=args.sanity_check) + train_dataset = load_dataset(args.dataset_name, split="train") + eval_dataset = load_dataset(args.dataset_name, split="test") ################ # Training ################ - trainer = DPOTrainer( - model, - model_ref, - args=training_args, - beta=args.beta, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - max_length=args.max_length, - max_target_length=args.max_target_length, - max_prompt_length=args.max_prompt_length, - generate_during_eval=args.generate_during_eval, - peft_config=get_peft_config(model_config), - ) + with init_context: + trainer = DPOTrainer( + model, + model_ref, + args=training_args, + beta=args.beta, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + max_length=args.max_length, + max_target_length=args.max_target_length, + max_prompt_length=args.max_prompt_length, + generate_during_eval=args.generate_during_eval, + peft_config=get_peft_config(model_config), + callbacks=[RichProgressCallback] if TRL_USE_RICH else None, + ) + trainer.train() - trainer.save_model(training_args.output_dir) + + with save_context: + trainer.save_model(training_args.output_dir) diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py index 61f5eedb03..1533906141 100644 --- a/examples/scripts/sft.py +++ b/examples/scripts/sft.py @@ -1,3 +1,4 @@ +# flake8: noqa # Copyright 2023 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -43,30 +44,50 @@ --lora_r=64 \ --lora_alpha=16 """ -from dataclasses import dataclass, field +import logging +import os +from contextlib import nullcontext + +TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False) + +from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser + +if TRL_USE_RICH: + init_zero_verbose() + FORMAT = "%(message)s" + + from rich.console import Console + from rich.logging import RichHandler import torch from datasets import load_dataset -from tqdm import tqdm -from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments -from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config +from tqdm.rich import tqdm +from transformers import AutoTokenizer, TrainingArguments +from trl import ( + ModelConfig, + RichProgressCallback, + SFTTrainer, + get_peft_config, + get_quantization_config, + get_kbit_device_map, +) tqdm.pandas() - -@dataclass -class ScriptArguments: - dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"}) - dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"}) - max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"}) +if TRL_USE_RICH: + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO) if __name__ == "__main__": - parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig)) - args, training_args, model_config = parser.parse_args_into_dataclasses() - training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False) + parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig)) + args, training_args, model_config = parser.parse_args_and_config() + + # Force use our 
print callback + if TRL_USE_RICH: + training_args.disable_tqdm = True + console = Console() ################ # Model & Tokenizer @@ -96,20 +117,35 @@ class ScriptArguments: train_dataset = raw_datasets["train"] eval_dataset = raw_datasets["test"] + ################ + # Optional rich context managers + ############### + init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...") + save_context = ( + nullcontext() + if not TRL_USE_RICH + else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}") + ) + ################ # Training ################ - trainer = SFTTrainer( - model=model_config.model_name_or_path, - model_init_kwargs=model_kwargs, - args=training_args, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - dataset_text_field="text", - max_seq_length=args.max_seq_length, - tokenizer=tokenizer, - packing=True, - peft_config=get_peft_config(model_config), - ) + with init_context: + trainer = SFTTrainer( + model=model_config.model_name_or_path, + model_init_kwargs=model_kwargs, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + dataset_text_field=args.dataset_text_field, + max_seq_length=args.max_seq_length, + tokenizer=tokenizer, + packing=args.packing, + peft_config=get_peft_config(model_config), + callbacks=[RichProgressCallback] if TRL_USE_RICH else None, + ) + trainer.train() - trainer.save_model(training_args.output_dir) + + with save_context: + trainer.save_model(training_args.output_dir) diff --git a/setup.cfg b/setup.cfg index 0c9e0fc144..29b21fba75 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [metadata] -license_file = LICENSE +license_file = LICENSE \ No newline at end of file diff --git a/setup.py b/setup.py index 91f01ff214..ce85815527 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ 8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0). 
Then push the change with a message 'set dev version' """ +import os from setuptools import find_packages, setup @@ -79,34 +80,44 @@ for reqs in EXTRAS.values(): EXTRAS["dev"].extend(reqs) -setup( - name="trl", - license="Apache 2.0", - classifiers=[ - "Development Status :: 2 - Pre-Alpha", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Natural Language :: English", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - ], - url="https://github.com/huggingface/trl", - packages=find_packages(), - include_package_data=True, - install_requires=REQUIRED_PKGS, - extras_require=EXTRAS, - python_requires=">=3.7", - long_description=open("README.md", encoding="utf-8").read(), - long_description_content_type="text/markdown", - zip_safe=False, - version=__version__, - description="Train transformer language models with reinforcement learning.", - keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", - author="Leandro von Werra", - author_email="leandro.vonwerra@gmail.com", -) +try: + file_path = os.path.dirname(os.path.abspath(__file__)) + os.symlink(os.path.join(file_path, "examples/scripts"), os.path.join(file_path, "trl/commands/scripts")) + + setup( + name="trl", + license="Apache 2.0", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + ], + url="https://github.com/huggingface/trl", + entry_points={ + "console_scripts": ["trl=trl.commands.cli:main"], + }, + package_data={"trl": ["commands/scripts/*"]}, + packages=find_packages(), + include_package_data=True, + install_requires=REQUIRED_PKGS, + extras_require=EXTRAS, + python_requires=">=3.7", + long_description=open("README.md", encoding="utf-8").read(), + long_description_content_type="text/markdown", + zip_safe=False, + version=__version__, + description="Train transformer language models with reinforcement learning.", + keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf", + author="Leandro von Werra", + author_email="leandro.vonwerra@gmail.com", + ) +finally: + os.unlink(os.path.join(file_path, "trl/commands/scripts")) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000000..d26ed5ffba --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,34 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import subprocess +import sys +import unittest + + +@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") +def test_sft_cli(): + subprocess.run( + "trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine", + shell=True, + check=True, + ) + + +@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows") +def test_dpo_cli(): + subprocess.run( + "trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/Anthropic-hh-rlhf-processed --learning_rate 1e-4 --lr_scheduler_type cosine", + shell=True, + check=True, + ) diff --git a/trl/__init__.py b/trl/__init__.py index 943323390b..6ab0e4b39a 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -2,45 +2,126 @@ __version__ = "0.7.12.dev0" -from .core import set_seed -from .environment import TextEnvironment, TextHistory -from .extras import BestOfNSampler -from .import_utils import ( - is_bitsandbytes_available, - is_diffusers_available, - is_npu_available, - is_peft_available, - is_wandb_available, - is_xpu_available, -) -from .models import ( - AutoModelForCausalLMWithValueHead, - AutoModelForSeq2SeqLMWithValueHead, - PreTrainedModelWrapper, - create_reference_model, - setup_chat_format, -) -from .trainer import ( - DataCollatorForCompletionOnlyLM, - DPOTrainer, - IterativeSFTTrainer, - KTOConfig, - KTOTrainer, - ModelConfig, - PPOConfig, - PPOTrainer, - RewardConfig, - RewardTrainer, - SFTTrainer, -) -from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config - - -if is_diffusers_available(): +from typing import TYPE_CHECKING +from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable + +_import_structure = { + "core": [ + "set_seed", + ], + "environment": [ + "TextEnvironment", + "TextHistory", + ], + "extras": [ + "BestOfNSampler", + ], + "import_utils": [ + "is_bitsandbytes_available", + "is_diffusers_available", + "is_npu_available", + "is_peft_available", + "is_wandb_available", + "is_xpu_available", + ], + "models": [ + "AutoModelForCausalLMWithValueHead", + "AutoModelForSeq2SeqLMWithValueHead", + "PreTrainedModelWrapper", + "create_reference_model", + "setup_chat_format", + "SUPPORTED_ARCHITECTURES", + ], + "trainer": [ + "DataCollatorForCompletionOnlyLM", + "DPOTrainer", + "IterativeSFTTrainer", + "KTOConfig", + "KTOTrainer", + "ModelConfig", + "PPOConfig", + "PPOTrainer", + "RewardConfig", + "RewardTrainer", + "SFTTrainer", + ], + "commands": [], + "commands.utils": ["SftArgumentParser", "init_zero_verbose", "TrlParser", "DpoArgumentParser"], + "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"], + "multitask_prompt_tuning": [ + "MultitaskPromptEmbedding", + "MultitaskPromptTuningConfig", + "MultitaskPromptTuningInit", + ], +} + +try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["models"].extend( + [ + "DDPOPipelineOutput", + "DDPOSchedulerOutput", + "DDPOStableDiffusionPipeline", + "DefaultDDPOStableDiffusionPipeline", + ] + ) + _import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"]) + +if TYPE_CHECKING: + from .core import set_seed + from .environment import TextEnvironment, TextHistory + from .extras import BestOfNSampler + from .import_utils import ( + 
is_bitsandbytes_available, + is_diffusers_available, + is_npu_available, + is_peft_available, + is_wandb_available, + is_xpu_available, + ) from .models import ( - DDPOPipelineOutput, - DDPOSchedulerOutput, - DDPOStableDiffusionPipeline, - DefaultDDPOStableDiffusionPipeline, + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, + PreTrainedModelWrapper, + create_reference_model, + setup_chat_format, + SUPPORTED_ARCHITECTURES, ) - from .trainer import DDPOConfig, DDPOTrainer + from .trainer import ( + DataCollatorForCompletionOnlyLM, + DPOTrainer, + IterativeSFTTrainer, + KTOConfig, + KTOTrainer, + ModelConfig, + PPOConfig, + PPOTrainer, + RewardConfig, + RewardTrainer, + SFTTrainer, + ) + from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback + from .commands.utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser + + try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .models import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) + from .trainer import DDPOConfig, DDPOTrainer + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/commands/__init__.py b/trl/commands/__init__.py new file mode 100644 index 0000000000..2c0afefef9 --- /dev/null +++ b/trl/commands/__init__.py @@ -0,0 +1,34 @@ +# flake8: noqa + +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# flake8: noqa + +from typing import TYPE_CHECKING +from ..import_utils import _LazyModule, OptionalDependencyNotAvailable + + +_import_structure = { + "cli_utils": ["SftArgumentParser", "init_zero_verbose", "DpoScriptArguments", "TrlParser"], + "config_parser": ["YamlConfigParser"], +} + + +if TYPE_CHECKING: + from .cli_utils import SftScriptArguments, init_zero_verbose, DpoScriptArguments, TrlParser + from .config_parser import YamlConfigParser +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/commands/cli.py b/trl/commands/cli.py new file mode 100644 index 0000000000..9f4f8da1f5 --- /dev/null +++ b/trl/commands/cli.py @@ -0,0 +1,65 @@ +# This file is a copy of trl/examples/scripts/sft.py so that we could +# use it together with rich and the TRL CLI in a more customizable manner. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import subprocess +import sys +from subprocess import CalledProcessError + +from rich.console import Console + + +SUPPORTED_COMMANDS = ["sft", "dpo"] + + +def main(): + console = Console() + # Make sure to import things locally to avoid verbose from third party libs. + with console.status("[bold purple]Welcome! Initializing the TRL CLI..."): + from trl.commands.cli_utils import init_zero_verbose + + init_zero_verbose() + + command_name = sys.argv[1] + + if command_name not in SUPPORTED_COMMANDS: + raise ValueError( + f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}" + ) + + trl_examples_dir = os.path.dirname(__file__) + + # Force-use rich + os.environ["TRL_USE_RICH"] = "1" + + command = f""" + accelerate launch {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])} + """ + + try: + subprocess.run( + command.split(), + text=True, + check=True, + encoding="utf-8", + cwd=os.getcwd(), + env=os.environ.copy(), + ) + except (CalledProcessError, ChildProcessError): + console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.") + + +if __name__ == "__main__": + main() diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py new file mode 100644 index 0000000000..564828c4e7 --- /dev/null +++ b/trl/commands/cli_utils.py @@ -0,0 +1,227 @@ +# This file is a copy of trl/examples/scripts/sft.py so that we could +# use it together with rich and the TRL CLI in a more customizable manner. +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import inspect +import os +from copy import deepcopy +from dataclasses import asdict, dataclass, field, fields +from typing import Any, List + +import yaml +from transformers import HfArgumentParser + + +class YamlConfigParser: + def __init__(self, config_path: str = None, dataclasses: List[Any] = None): + self.config = None + + if config_path is not None: + with open(config_path) as yaml_file: + self.config = yaml.safe_load(yaml_file) + else: + self.config = {} + + if dataclasses is None: + dataclasses = [] + + # We create a dummy training args to compare the values before / after + # __post_init__ + # Here we import `TrainingArguments` from the local level to not + # break TRL lazy imports. 
+ from transformers import TrainingArguments + + self._dummy_training_args = TrainingArguments(output_dir="dummy-training-args") + + self.parse_and_set_env() + self.merge_dataclasses(dataclasses) + + def parse_and_set_env(self): + if "env" in self.config: + env_vars = self.config["env"] + if isinstance(env_vars, dict): + for key, value in env_vars.items(): + os.environ[key] = str(value) + else: + raise ValueError("`env` field should be a dict in the YAML file.") + + def merge_dataclasses(self, dataclasses): + from transformers import TrainingArguments + + dataclasses_copy = [deepcopy(dataclass) for dataclass in dataclasses] + + if len(self.config) > 0: + for i, dataclass in enumerate(dataclasses): + is_hf_training_args = False + + for data_class_field in fields(dataclass): + # Get the field here + field_name = data_class_field.name + field_value = getattr(dataclass, field_name) + + if not isinstance(dataclass, TrainingArguments): + default_value = data_class_field.default + else: + default_value = ( + getattr(self._dummy_training_args, field_name) + if field_name != "output_dir" + else field_name + ) + is_hf_training_args = True + + default_value_changed = field_value != default_value + + if field_value is not None or field_name in self.config: + if field_name in self.config: + # In case the field value is not different from default, overwrite it + if not default_value_changed: + value_to_replace = self.config[field_name] + + setattr(dataclasses_copy[i], field_name, value_to_replace) + # Otherwise do nothing + + # Re-init `TrainingArguments` to handle all post-processing correctly + if is_hf_training_args: + init_signature = list(inspect.signature(TrainingArguments.__init__).parameters) + dict_dataclass = asdict(dataclasses_copy[i]) + new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature} + dataclasses_copy[i] = TrainingArguments(**new_dict_dataclass) + + return dataclasses_copy + + def to_string(self): + final_string = """""" + for key, value in self.config.items(): + if isinstance(value, (dict, list)): + if len(value) != 0: + value = str(value) + value = value.replace("'", '"') + value = f"'{value}'" + else: + continue + + final_string += f"--{key} {value} " + return final_string + + +def init_zero_verbose(): + """ + Perform zero verbose init - use this method on top of the CLI modules to make + """ + import logging + import warnings + + from rich.logging import RichHandler + + FORMAT = "%(message)s" + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR) + + # Custom warning handler to redirect warnings to the logging system + def warning_handler(message, category, filename, lineno, file=None, line=None): + logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}") + + # Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well + warnings.showwarning = warning_handler + + +@dataclass +class SftScriptArguments: + dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"}) + dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"}) + max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"}) + packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"}) + config: str = field(default=None, metadata={"help": "Path to the optional config file"}) + 
gradient_checkpointing_use_reentrant: bool = field( + default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"} + ) + + +@dataclass +class DpoScriptArguments: + dataset_name: str = field(default=None, metadata={"help": "the dataset name"}) + beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"}) + max_length: int = field(default=512, metadata={"help": "max length of each sample"}) + max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"}) + max_target_length: int = field( + default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"} + ) + sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"}) + ignore_bias_buffers: bool = field( + default=False, + metadata={ + "help": "debug argument for distributed training;" + "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See" + "https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992" + }, + ) + generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"}) + config: str = field(default=None, metadata={"help": "Path to the optional config file"}) + gradient_checkpointing_use_reentrant: bool = field( + default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"} + ) + + +class TrlParser(HfArgumentParser): + def __init__(self, parsers): + """ + The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config + parsers for users that pass a valid `config` field and merge the values that are set in the config + with the processed parsers. + + Args: + parsers (`List[argparse.ArgumentParser`]): + List of parsers. + """ + super().__init__(parsers) + + def post_process_dataclasses(self, dataclasses): + # Apply additional post-processing in case some arguments needs a special + # care + training_args = trl_args = None + training_args_index = None + + for i, dataclass_obj in enumerate(dataclasses): + if dataclass_obj.__class__.__name__ == "TrainingArguments": + training_args = dataclass_obj + training_args_index = i + elif dataclass_obj.__class__.__name__ in ("SftScriptArguments", "DpoScriptArguments"): + trl_args = dataclass_obj + else: + ... + + if trl_args is not None and training_args is not None: + training_args.gradient_checkpointing_kwargs = dict( + use_reentrant=trl_args.gradient_checkpointing_use_reentrant + ) + dataclasses[training_args_index] = training_args + + return dataclasses + + def parse_args_and_config(self): + dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True) + # Pop the last element which should be the remaining strings + dataclasses = dataclasses[:-1] + self.config_parser = None + + for parser_dataclass in dataclasses: + if hasattr(parser_dataclass, "config"): + if self.config_parser is not None: + raise ValueError("You passed the `config` field twice! 
Make sure to pass `config` only once.") + self.config_parser = YamlConfigParser(parser_dataclass.config) + + if self.config_parser is not None: + dataclasses = self.config_parser.merge_dataclasses(dataclasses) + + dataclasses = self.post_process_dataclasses(dataclasses) + return dataclasses diff --git a/trl/environment/__init__.py b/trl/environment/__init__.py index ae1cda4ecb..85b0c501da 100644 --- a/trl/environment/__init__.py +++ b/trl/environment/__init__.py @@ -1,3 +1,14 @@ # flake8: noqa +from typing import TYPE_CHECKING +from ..import_utils import _LazyModule -from .base_environment import TextEnvironment, TextHistory +_import_structure = { + "base_environment": ["TextEnvironment", "TextHistory"], +} + +if TYPE_CHECKING: + from .base_environment import TextEnvironment, TextHistory +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/extras/__init__.py b/trl/extras/__init__.py index 6b3035db92..ef6b216986 100644 --- a/trl/extras/__init__.py +++ b/trl/extras/__init__.py @@ -13,4 +13,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .best_of_n_sampler import BestOfNSampler +from typing import TYPE_CHECKING + +from ..import_utils import _LazyModule + + +_import_structure = { + "best_of_n_sampler": ["BestOfNSampler"], +} + +if TYPE_CHECKING: + from .best_of_n_sampler import BestOfNSampler +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/import_utils.py b/trl/import_utils.py index 3cb97d9002..c58c44ac3d 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import os import sys +from importlib.util import find_spec +from itertools import chain +from types import ModuleType +from typing import Any if sys.version_info < (3, 8): @@ -22,11 +27,11 @@ def is_peft_available() -> bool: - return importlib.util.find_spec("peft") is not None + return find_spec("peft") is not None def is_unsloth_available() -> bool: - return importlib.util.find_spec("unsloth") is not None + return find_spec("unsloth") is not None def is_accelerate_greater_20_0() -> bool: @@ -66,26 +71,26 @@ def is_torch_greater_2_0() -> bool: def is_diffusers_available() -> bool: - return importlib.util.find_spec("diffusers") is not None + return find_spec("diffusers") is not None def is_bitsandbytes_available() -> bool: import torch # bnb can be imported without GPU but is not usable. 
- return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available() + return find_spec("bitsandbytes") is not None and torch.cuda.is_available() def is_torchvision_available() -> bool: - return importlib.util.find_spec("torchvision") is not None + return find_spec("torchvision") is not None def is_rich_available() -> bool: - return importlib.util.find_spec("rich") is not None + return find_spec("rich") is not None def is_wandb_available() -> bool: - return importlib.util.find_spec("wandb") is not None + return find_spec("wandb") is not None def is_xpu_available() -> bool: @@ -94,7 +99,7 @@ def is_xpu_available() -> bool: return accelerate.utils.is_xpu_available() else: - if importlib.util.find_spec("intel_extension_for_pytorch") is None: + if find_spec("intel_extension_for_pytorch") is None: return False try: import torch @@ -106,10 +111,74 @@ def is_xpu_available() -> bool: def is_npu_available() -> bool: """Checks if `torch_npu` is installed and potentially if a NPU is in the environment""" - if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + if find_spec("torch") is None or find_spec("torch_npu") is None: return False import torch import torch_npu # noqa: F401 return hasattr(torch, "npu") and torch.npu.is_available() + + +class _LazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. + for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError(f"module {self.__name__} has no attribute {name}") + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + try: + return importlib.import_module("." 
+ module_name, self.__name__) + except Exception as e: + raise RuntimeError( + f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" + f" traceback):\n{e}" + ) from e + + def __reduce__(self): + return (self.__class__, (self._name, self.__file__, self._import_structure)) + + +class OptionalDependencyNotAvailable(BaseException): + """Internally used error class for signalling an optional dependency was not found.""" diff --git a/trl/models/__init__.py b/trl/models/__init__.py index ec20345533..72c4a3ff22 100644 --- a/trl/models/__init__.py +++ b/trl/models/__init__.py @@ -13,23 +13,52 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from .modeling_base import PreTrainedModelWrapper, create_reference_model -from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead -from .utils import setup_chat_format +# flake8: noqa + +from typing import TYPE_CHECKING +from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable + +_import_structure = { + "modeling_base": ["PreTrainedModelWrapper", "create_reference_model"], + "modeling_value_head": [ + "AutoModelForCausalLMWithValueHead", + "AutoModelForSeq2SeqLMWithValueHead", + ], + "utils": ["setup_chat_format", "SUPPORTED_ARCHITECTURES"], +} -SUPPORTED_ARCHITECTURES = ( - AutoModelForCausalLMWithValueHead, - AutoModelForSeq2SeqLMWithValueHead, -) +try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_sd_base"] = [ + "DDPOPipelineOutput", + "DDPOSchedulerOutput", + "DDPOStableDiffusionPipeline", + "DefaultDDPOStableDiffusionPipeline", + ] -from ..import_utils import is_diffusers_available +if TYPE_CHECKING: + from .modeling_base import PreTrainedModelWrapper, create_reference_model + from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + from .utils import setup_chat_format, SUPPORTED_ARCHITECTURES + try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_sd_base import ( + DDPOPipelineOutput, + DDPOSchedulerOutput, + DDPOStableDiffusionPipeline, + DefaultDDPOStableDiffusionPipeline, + ) +else: + import sys -if is_diffusers_available(): - from .modeling_sd_base import ( - DDPOPipelineOutput, - DDPOSchedulerOutput, - DDPOStableDiffusionPipeline, - DefaultDDPOStableDiffusionPipeline, - ) + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/models/utils.py b/trl/models/utils.py index 849ac19020..8e9df978d1 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -3,6 +3,14 @@ from transformers import PreTrainedModel, PreTrainedTokenizer +from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead + + +SUPPORTED_ARCHITECTURES = ( + AutoModelForCausalLMWithValueHead, + AutoModelForSeq2SeqLMWithValueHead, +) + # TODO: Add Abstract Base Class if more formats are added @dataclass diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 86e654213f..e7b517dda1 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -15,34 +15,84 @@ # limitations under the License. 
# There is a circular import in the PPOTrainer if we let isort sort these -# isort: off -from .utils import ( - AdaptiveKLController, - FixedKLController, - ConstantLengthDataset, - DataCollatorForCompletionOnlyLM, - RunningMoments, - disable_dropout_in_model, - peft_module_casting_to_bf16, -) - -# isort: on - -from ..import_utils import is_diffusers_available -from .base import BaseTrainer -from .ddpo_config import DDPOConfig - - -if is_diffusers_available(): - from .ddpo_trainer import DDPOTrainer - -from .dpo_trainer import DPOTrainer -from .iterative_sft_trainer import IterativeSFTTrainer -from .kto_config import KTOConfig -from .kto_trainer import KTOTrainer -from .model_config import ModelConfig -from .ppo_config import PPOConfig -from .ppo_trainer import PPOTrainer -from .reward_config import RewardConfig -from .reward_trainer import RewardTrainer, compute_accuracy -from .sft_trainer import SFTTrainer +from typing import TYPE_CHECKING +from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable + +_import_structure = { + "utils": [ + "AdaptiveKLController", + "FixedKLController", + "ConstantLengthDataset", + "DataCollatorForCompletionOnlyLM", + "RunningMoments", + "disable_dropout_in_model", + "peft_module_casting_to_bf16", + "RichProgressCallback", + ], + "dpo_trainer": [ + "DPOTrainer", + ], + "iterative_sft_trainer": [ + "IterativeSFTTrainer", + ], + "kto_config": ["KTOConfig"], + "kto_trainer": ["KTOTrainer"], + "model_config": ["ModelConfig"], + "ppo_config": ["PPOConfig"], + "ppo_trainer": ["PPOTrainer"], + "reward_config": ["RewardConfig"], + "reward_trainer": ["RewardTrainer", "compute_accuracy"], + "sft_trainer": ["SFTTrainer"], + "base": ["BaseTrainer"], + "ddpo_config": ["DDPOConfig"], +} + +try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["ddpo_trainer"] = ["DDPOTrainer"] + + +if TYPE_CHECKING: + # isort: off + from .utils import ( + AdaptiveKLController, + FixedKLController, + ConstantLengthDataset, + DataCollatorForCompletionOnlyLM, + RunningMoments, + disable_dropout_in_model, + peft_module_casting_to_bf16, + RichProgressCallback, + ) + + # isort: on + + from .base import BaseTrainer + from .ddpo_config import DDPOConfig + + from .dpo_trainer import DPOTrainer + from .iterative_sft_trainer import IterativeSFTTrainer + from .kto_config import KTOConfig + from .kto_trainer import KTOTrainer + from .model_config import ModelConfig + from .ppo_config import PPOConfig + from .ppo_trainer import PPOTrainer + from .reward_config import RewardConfig + from .reward_trainer import RewardTrainer, compute_accuracy + from .sft_trainer import SFTTrainer + + try: + if not is_diffusers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .ddpo_trainer import DDPOTrainer +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index 0dd06dc728..edbf2c6000 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -42,6 +42,7 @@ from .utils import ( ConstantLengthDataset, DataCollatorForCompletionOnlyLM, + RichProgressCallback, neftune_post_forward_hook, peft_module_casting_to_bf16, trl_sanitze_kwargs_for_tagging, @@ -344,6 +345,12 @@ def make_inputs_require_grad(module, input, output): elif self.args.max_steps == -1 and 
packing: self.train_dataset.infinite = False + if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks): + for callback in self.callback_handler.callbacks: + # Remove the PrinterCallback to avoid duplicated prints in case we passed a `RichProgressCallback` + if callback.__class__.__name__ == "PrinterCallback": + self.callback_handler.pop_callback(callback) + @wraps(Trainer.train) def train(self, *args, **kwargs): # Activate neftune right before training. diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 33edbadfd3..bfbc57c279 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -20,9 +20,15 @@ import numpy as np import torch from accelerate import PartialState +from rich.console import Console, Group +from rich.live import Live +from rich.panel import Panel +from rich.progress import Progress from torch.nn.utils.rnn import pad_sequence from torch.utils.data import IterableDataset from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase +from transformers.trainer import TrainerCallback +from transformers.trainer_utils import has_length from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available from ..trainer.model_config import ModelConfig @@ -732,3 +738,70 @@ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]": ) return peft_config + + +class RichProgressCallback(TrainerCallback): + """ + A [`TrainerCallback`] that displays the progress of training or evaluation using Rich. + """ + + def __init__(self): + self.training_bar = Progress() + self.prediction_bar = Progress() + + self.training_task_id = None + self.prediction_task_id = None + + self.rich_group = None + self.rich_console = None + + def on_train_begin(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.rich_console = Console() + + self.training_status = self.rich_console.status("Nothing to log yet ...") + + self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status))) + self.rich_group.start() + + # self.training_bar.start() + self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps) + self.current_step = 0 + + def on_step_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True) + self.current_step = state.global_step + + def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs): + if state.is_world_process_zero and has_length(eval_dataloader): + if self.prediction_bar is None: + # self.prediction_bar.start() + self.prediction_task_id = self.prediction_bar.add_task( + "[blue]Predicting on the evaluation dataset", total=state.max_steps + ) + self.prediction_bar.update(self.prediction_task_id, advance=1, update=True) + + def on_evaluate(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.close() + self.prediction_bar = None + + def on_predict(self, args, state, control, **kwargs): + if state.is_world_process_zero: + if self.prediction_bar is not None: + self.prediction_bar.stop() + self.prediction_bar.remove_task(self.prediction_task_id) + self.prediction_bar = None + + def on_log(self, args, state, control, logs=None, **kwargs): + if state.is_world_process_zero and self.training_bar is not None: + _ = logs.pop("total_flos", None) + 
self.training_status.update(f"[bold green]Status = {str(logs)}") + + def on_train_end(self, args, state, control, **kwargs): + if state.is_world_process_zero: + self.training_bar.stop() + self.rich_group.stop() + self.training_bar = None
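Beyond the `trl sft` and `trl dpo` entry points, the parser utilities introduced by this patch can also be reused directly in a custom script. The sketch below is illustrative only and is not part of the patch; it mirrors how `examples/scripts/sft.py` wires `TrlParser`, `SftScriptArguments`, `TrainingArguments`, and `ModelConfig` together, and the printed fields are just an example.

```python
# Hypothetical usage sketch, not part of this commit: reusing the new parser
# utilities in a custom script, mirroring examples/scripts/sft.py.
from transformers import TrainingArguments

from trl import ModelConfig
from trl.commands.cli_utils import SftScriptArguments, TrlParser

if __name__ == "__main__":
    # TrlParser behaves like HfArgumentParser, but parse_args_and_config() also
    # merges values from a YAML file passed via `--config` (explicit CLI flags
    # win over YAML values) and re-initializes TrainingArguments on the result.
    parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()

    print(args.dataset_name, training_args.output_dir, model_config.model_name_or_path)
```

Run for instance as `python my_script.py --config example_config.yaml --output_dir test-sft` (the script name is hypothetical); arguments passed explicitly on the command line override values that only appear in the YAML file.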