diff --git a/docs/source/sft_trainer.mdx b/docs/source/sft_trainer.mdx
index fdcc1b91eb..a6ca3e439c 100644
--- a/docs/source/sft_trainer.mdx
+++ b/docs/source/sft_trainer.mdx
@@ -279,33 +279,6 @@ trainer = SFTTrainer(
 trainer.train()
 ```
 
-Note that in case of training adapters, we manually add a saving callback to automatically save the adapters only:
-```python
-class PeftSavingCallback(TrainerCallback):
-    def on_save(self, args, state, control, **kwargs):
-        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
-        kwargs["model"].save_pretrained(checkpoint_path)
-
-        if "pytorch_model.bin" in os.listdir(checkpoint_path):
-            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
-```
-If you want to add more callbacks, make sure to add this one as well to properly save the adapters only during training.
-```python
-...
-
-callbacks = [YourCustomCallback(), PeftSavingCallback()]
-
-trainer = SFTTrainer(
-    "EleutherAI/gpt-neo-125m",
-    train_dataset=dataset,
-    dataset_text_field="text",
-    peft_config=peft_config,
-    callbacks=callbacks
-)
-
-trainer.train()
-```
-
 You can also continue training your `PeftModel`. For that, first load a `PeftModel` outside `SFTTrainer` and pass it directly to the trainer without the `peft_config` argument being passed.
 
 ### Training adapters with base 8 bit models
diff --git a/setup.py b/setup.py
index 1b2e1cb14c..523496c6cb 100644
--- a/setup.py
+++ b/setup.py
@@ -61,7 +61,7 @@
 REQUIRED_PKGS = [
     "torch>=1.4.0",
-    "transformers>=4.18.0",
+    "transformers>=4.31.0",
     "numpy>=1.18.2",
     "accelerate",
     "datasets",
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index e81705fbc2..7c239e417b 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -23,6 +23,7 @@
     DataCollatorForCompletionOnlyLM,
     RunningMoments,
     disable_dropout_in_model,
+    peft_module_casting_to_bf16,
 )
 
 # isort: on
diff --git a/trl/trainer/reward_trainer.py b/trl/trainer/reward_trainer.py
index ed81ef73e6..44a5b79223 100644
--- a/trl/trainer/reward_trainer.py
+++ b/trl/trainer/reward_trainer.py
@@ -26,7 +26,7 @@
 from ..import_utils import is_peft_available
 from .training_configs import RewardConfig
-from .utils import PeftSavingCallback, RewardDataCollatorWithPadding, compute_accuracy
+from .utils import RewardDataCollatorWithPadding, compute_accuracy
 
 
 if is_peft_available():
@@ -147,12 +147,6 @@ def __init__(
                 model = get_peft_model(model, peft_config)
 
-        if is_peft_available() and isinstance(model, PeftModel):
-            if callbacks is None:
-                callbacks = [PeftSavingCallback()]
-            else:
-                callbacks += [PeftSavingCallback()]
-
         if compute_metrics is None:
             compute_metrics = compute_accuracy
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 6c4a6c536a..677b0efd41 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -40,8 +40,8 @@
 from .utils import (
     ConstantLengthDataset,
     DataCollatorForCompletionOnlyLM,
-    PeftSavingCallback,
     neftune_post_forward_hook,
+    peft_module_casting_to_bf16,
 )
 
 
@@ -201,9 +201,8 @@ def make_inputs_require_grad(module, input, output):
                         model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
                 model = get_peft_model(model, peft_config)
-
-            if callbacks is None:
-                callbacks = [PeftSavingCallback]
+            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
+                peft_module_casting_to_bf16(model)
 
         if tokenizer is None:
             tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 554346cfd1..f0d90da138 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 import random
 import warnings
 from collections import deque
@@ -22,7 +21,7 @@
 import torch
 from torch.nn.utils.rnn import pad_sequence
 from torch.utils.data import IterableDataset
-from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase, TrainerCallback
+from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerBase
 
 
 class AdaptiveKLController:
@@ -204,6 +203,7 @@ class RewardDataCollatorWithPadding:
         return_tensors (`str`, `optional`, defaults to `"pt"`):
             The tensor type to use.
     """
+
     tokenizer: PreTrainedTokenizerBase
     padding: Union[bool, str] = True
     max_length: Optional[int] = None
@@ -281,6 +281,7 @@ class DPODataCollatorWithPadding:
         is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
             Whether or not you model has an encoder_decoder architecture.
     """
+
     pad_token_id: int = 0
     label_pad_token_id: int = -100
     is_encoder_decoder: Optional[bool] = False
@@ -358,6 +359,8 @@ class ConstantLengthDataset(IterableDataset):
             Id of the end of sequence token if the passed tokenizer does not have an EOS token.
         shuffle ('bool', *optional*, defaults to True)
             Shuffle the examples before they are returned
+        append_concat_token ('bool', *optional*, defaults to True)
+            If true, appends `eos_token_id` at the end of each sample being packed.
     """
 
     def __init__(
@@ -372,6 +375,7 @@ def __init__(
         chars_per_token=3.6,
         eos_token_id=0,
         shuffle=True,
+        append_concat_token=True,
     ):
         self.tokenizer = tokenizer
 
@@ -388,6 +392,7 @@ def __init__(
         self.current_size = 0
         self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
         self.shuffle = shuffle
+        self.append_concat_token = append_concat_token
         if formatting_func is None:
             self.formatting_func = lambda x: x[dataset_text_field]
         else:
@@ -424,7 +429,9 @@ def __iter__(self):
             tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
             all_token_ids = []
             for tokenized_input in tokenized_inputs:
-                all_token_ids.extend(tokenized_input + [self.concat_token_id])
+                if self.append_concat_token:
+                    tokenized_input = tokenized_input + [self.concat_token_id]
+                all_token_ids.extend(tokenized_input)
             examples = []
             for i in range(0, len(all_token_ids), self.seq_length):
                 input_ids = all_token_ids[i : i + self.seq_length]
@@ -440,16 +447,6 @@ def __iter__(self):
                 }
 
 
-class PeftSavingCallback(TrainerCallback):
-    def on_save(self, args, state, control, **kwargs):
-        if args.should_save:
-            checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
-            kwargs["model"].save_pretrained(checkpoint_path)
-
-            if "pytorch_model.bin" in os.listdir(checkpoint_path):
-                os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
-
-
 class RunningMoments:
     def __init__(self, accelerator):
         """
@@ -620,3 +617,17 @@ def neftune_post_forward_hook(module, input, output):
     mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
     output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
     return output
+
+
+def peft_module_casting_to_bf16(model):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for name, module in model.named_modules():
+        if isinstance(module, BaseTunerLayer):
+            module = module.to(torch.bfloat16)
+        if "norm" in name:
+            module = module.to(torch.float32)
+        if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
+            if hasattr(module, "weight"):
+                if module.weight.dtype == torch.float32:
+                    module = module.to(torch.bfloat16)