From 0371e061a58e730dff70e320d1a4050cfd19aee7 Mon Sep 17 00:00:00 2001
From: Oleg S <97077423+RobotSail@users.noreply.github.com>
Date: Wed, 16 Oct 2024 16:16:04 +0000
Subject: [PATCH] Support saving LoRA models when training with FSDP

This change gives the training library the ability to save LoRA models
when training with FSDP as the distributed backend. It works by creating
a copy of the LoRA model on the CPU, loading in the state dict after
gathering it from the distributed model, and saving once the adapters
have been merged back into the original model. Afterwards, the CPU copy
is discarded and training continues.

Signed-off-by: Oleg S <97077423+RobotSail@users.noreply.github.com>
---
 src/instructlab/training/main_ds.py           |  65 +++-----
 src/instructlab/training/setup_accelerator.py |  29 ++--
 src/instructlab/training/utils.py             | 157 +++++++++++++++++-
 3 files changed, 193 insertions(+), 58 deletions(-)

diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py
index d1ae0e01..e581ba11 100644
--- a/src/instructlab/training/main_ds.py
+++ b/src/instructlab/training/main_ds.py
@@ -2,6 +2,7 @@
 
 # Standard
 from copy import deepcopy
+from datetime import timedelta
 from pathlib import Path
 import argparse
 import math
@@ -42,8 +43,8 @@
     add_noisy_embeddings,
     apply_gradient_checkpointing,
     convert_loss_to_reduce_sum,
+    create_lora_config,
     ensure_loadable_granite_checkpoint,
-    get_projection_layer_names,
     load_latest_full_state,
     prepare_peft_model,
     prepare_universal_checkpoint_from_latest,
@@ -114,13 +115,16 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
             args.model_name_or_path, args.output_dir
         ) as path:
             base_model_args["pretrained_model_name_or_path"] = path
+            base_model_args["use_padding_free_transformer"] = True
             model = GPTDolomiteForCausalLM.from_pretrained(
                 **base_model_args,
-                use_padding_free_transformer=True,
             )
     else:
         model = AutoModelForCausalLM.from_pretrained(**base_model_args)
 
+    # store the base model args so we can recall them later if saving a LoRA model
+    args.base_model_args = base_model_args
+
     if len(tokenizer) > model.config.vocab_size:
         print(
             f"WARNING: tokenizer has {len(tokenizer)} tokens but model has {model.config.vocab_size} vocab size"
@@ -175,46 +179,14 @@ def setup_model(args, tokenizer, train_loader, grad_accum):
     # - with the exception of granite, which handles it
     #   in the later stanza
     if args.lora_r > 0:
-        # if lora
-        # Third Party
-        from peft import LoraConfig
-
-        # ensure we select only the modules that exist in the model
-        proj_layers = get_projection_layer_names(model)
-        if not args.lora_target_modules:
-            print(
-                f"WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
-            )
-            if not proj_layers:
-                raise RuntimeError("could not find any projection layers in the model")
-            args.__dict__["lora_target_modules"] = proj_layers
-        else:
-            # when the user specifies the module, we should verify that they align with what's in the model
-            lora_target_modules_set = set(args.lora_target_modules)
-            diff = lora_target_modules_set - set(proj_layers)
-            layers_to_target = lora_target_modules_set - diff
-            if len(diff) == len(args.lora_target_modules):
-                raise ValueError(
-                    f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiuration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
- ) - if diff: - print( - f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m" - ) - args.__dict__["lora_target_modules"] = list(layers_to_target) - - peft_config = LoraConfig( - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - r=args.lora_r, - bias="none", - task_type="CAUSAL_LM", - target_modules=args.lora_target_modules, - ) + lora_config = create_lora_config(model, args) model = prepare_peft_model( - model, peft_config, gradient_checkpointing=not args.is_granite + model, + lora_config, + args.distributed_training_framework, + gradient_checkpointing=not args.is_granite, ) - + args.lora_config = lora_config elif not args.is_granite: model.gradient_checkpointing_enable() @@ -529,7 +501,11 @@ def main(args): #### distributed init ##### torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) args.local_rank = int(os.environ["LOCAL_RANK"]) - torch.distributed.init_process_group("nccl") + nccl_timeout: timedelta | None = None + if args.debug: + # surely we won't need any more than this... right? + nccl_timeout = timedelta(days=1) + torch.distributed.init_process_group("nccl", timeout=nccl_timeout) args.global_rank = torch.distributed.get_rank() tensor = torch.ByteTensor([False]).cuda() torch.distributed.all_reduce(tensor) @@ -932,6 +908,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py" ), ) + # hidden argument for our own sake + parser.add_argument( + "--debug", + help="Enables settings for debugging. For example, the NCCL timeout increases so more time can be spent in breakpoints.", + action="store_true", + default=False, + ) parser.add_argument("--disable_flash_attn", action="store_true") args = parser.parse_args() set_random_seed(args.seed) diff --git a/src/instructlab/training/setup_accelerator.py b/src/instructlab/training/setup_accelerator.py index 33972b59..3c3f9cec 100644 --- a/src/instructlab/training/setup_accelerator.py +++ b/src/instructlab/training/setup_accelerator.py @@ -3,12 +3,10 @@ # Third Party from accelerate import Accelerator -from torch.distributed.fsdp import ( # FullyShardedDataParallel as FSDP, - BackwardPrefetch, - MixedPrecision, - ShardingStrategy, -) +from peft.utils.other import fsdp_auto_wrap_policy +from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision, ShardingStrategy from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy +from transformers import PreTrainedModel import torch # First Party @@ -51,34 +49,43 @@ def get_ds_plugin(world_size, samples_per_gpu, grad_accum, opts: DeepSpeedOption return ds_plugin -def get_fsdp_config(args, model): +def get_fsdp_config(args, model: PreTrainedModel): # Third Party from accelerate.utils import FullyShardedDataParallelPlugin from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload block_name = model._no_split_modules[0] - fsdp_plugin = FullyShardedDataParallelPlugin( - auto_wrap_policy=partial( + wrap_policy = None + if args.lora_r > 0: + wrap_policy = fsdp_auto_wrap_policy(model) + else: + wrap_policy = partial( transformer_auto_wrap_policy, transformer_layer_cls={ get_module_class_from_name(model, block_name), }, - ), + ) + + fsdp_plugin = FullyShardedDataParallelPlugin( + auto_wrap_policy=wrap_policy, limit_all_gathers=True, mixed_precision_policy=MixedPrecision( param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16, 
buffer_dtype=torch.bfloat16, ), - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, + backward_prefetch=BackwardPrefetch.BACKWARD_POST, sharding_strategy=ShardingStrategy[args.fsdp_sharding_strategy], cpu_offload=CPUOffload(args.cpu_offload_params_fsdp), ) + if args.lora_r > 0: + fsdp_plugin.use_orig_params = False + return fsdp_plugin -def setup_accelerator(args, model, grad_accum): +def setup_accelerator(args, model: PreTrainedModel, grad_accum): if args.distributed_training_framework == "deepspeed": # Third Party from deepspeed import DeepSpeedEngine diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 6d79d897..1a6422c2 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -1,13 +1,14 @@ # SPDX-License-Identifier: Apache-2.0 # Standard +from argparse import Namespace from collections import OrderedDict from contextlib import contextmanager from copy import deepcopy from functools import partial from pathlib import Path from tempfile import TemporaryDirectory -from typing import Any, List, Optional +from typing import Any, List, Optional, Tuple import importlib import inspect import logging @@ -21,7 +22,7 @@ # Third Party # pylint: disable=no-name-in-module -from accelerate import Accelerator +from accelerate import Accelerator, DistributedType from instructlab.dolomite.hf_models import ( GPTDolomiteConfig, export_to_huggingface, @@ -29,17 +30,24 @@ ) from rich.logging import RichHandler from torch import distributed as dist +from torch import nn from torch.distributed import get_rank, is_initialized from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( CheckpointImpl, apply_activation_checkpointing, checkpoint_wrapper, ) -from transformers import PreTrainedModel +from torch.distributed.fsdp import FullStateDictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import StateDictType +from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer import numpy as np import torch import torch.nn.functional as F +# First Party +from instructlab.training.config import DistributedBackend + def retrieve_chat_template(chat_tmpl_path): try: @@ -304,9 +312,129 @@ def patch_target_module( setattr(source, obj_name_to_patch, replace_with) +def wraps(module: nn.Module, wrapped_classes: Tuple[Any]) -> bool: + """Checks if a module or its children are an instance of one of the provided classes. + + Args: + module (nn.Module): A PyTorch module. + wrapped_classes(Tuple): A tuple of potential classes the module could be. + + Returns: + bool: True if the module or any of its children are instances of one of `wrapped_classes`, False otherwise. 
+    """
+    if isinstance(module, wrapped_classes):
+        return True
+
+    for m in module.children():
+        if wraps(m, wrapped_classes):
+            return True
+
+    return False
+
+
+def create_lora_config(model: PreTrainedModel, args: Namespace) -> "peft.LoraConfig":
+    # peft is imported lazily since it is only needed for LoRA runs
+    # Third Party
+    from peft import LoraConfig
+
+    # ensure we select only the modules that exist in the model
+    proj_layers = get_projection_layer_names(model)
+    if not args.lora_target_modules:
+        print(
+            "WARNING: lora_target_modules was not specified, defaulting to all of the model's projection modules"
+        )
+        if not proj_layers:
+            raise RuntimeError("could not find any projection layers in the model")
+        args.__dict__["lora_target_modules"] = proj_layers
+    else:
+        # when the user specifies the modules, we should verify that they align with what's in the model
+        lora_target_modules_set = set(args.lora_target_modules)
+        diff = lora_target_modules_set - set(proj_layers)
+        layers_to_target = lora_target_modules_set - diff
+        if len(diff) == len(args.lora_target_modules):
+            raise ValueError(
+                f"None of the modules you requested exist in the model.\nRequested modules: {args.lora_target_modules}; Available modules: {proj_layers}.\nThis is usually a misconfiguration error. Consider omitting your `lora_target_modules` list to have these discovered automatically."
+            )
+        if diff:
+            print(
+                f"\033[33mWARNING: the following modules were targeted for LoRA but are not present in the model: {list(diff)}. Applying LoRA only to {list(layers_to_target)} modules.\033[0m"
+            )
+        args.__dict__["lora_target_modules"] = list(layers_to_target)
+
+    return LoraConfig(
+        lora_alpha=args.lora_alpha,
+        lora_dropout=args.lora_dropout,
+        r=args.lora_r,
+        bias="none",
+        task_type="CAUSAL_LM",
+        target_modules=args.lora_target_modules,
+    )
+
+
+def save_fsdp_lora_model(
+    args: Namespace,
+    model: FSDP,
+    tokenizer: PreTrainedTokenizer,
+    accelerator: Accelerator,
+    output_dir: Path,
+):
+    """Given a LoRA model wrapped by FSDP and Accelerate, save a full copy of the original
+    model with the trained LoRA adapters merged into the copy.
+
+    This function creates a full copy of the model being trained and stores it in CPU memory.
+    If you encounter out-of-memory errors on the CPU, this copy is a likely culprit.
+
+    Args:
+        args (Namespace): Args received by the ArgumentParser.
+        model (FSDP): FSDP-wrapped LoRA model as prepared by `accelerate.Accelerator`.
+        accelerator (Accelerator): The given accelerator object.
+    """
+    # Third Party
+    from peft import LoraConfig, LoraModel
+
+    if accelerator.distributed_type != DistributedType.FSDP:
+        raise RuntimeError(
+            "`save_fsdp_lora_model` was called when FSDP was not being used."
+        )
+    if not wraps(model, FSDP):
+        raise RuntimeError(
+            "`save_fsdp_lora_model` was called but provided model is not an FSDP model."
+        )
+    if not wraps(model, LoraModel):
+        raise RuntimeError(
+            "`save_fsdp_lora_model` was called but provided model is not a LoRA model."
+ ) + + # okay now that validation is out of the way, we are free to implement saving + lora_conf: LoraConfig = args.lora_config + sd_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, sd_config): + state = model.state_dict() + + if accelerator.is_main_process: + # remove device_map from args list so we can load the model on CPU + old_device_map = args.base_model_args.pop("device_map", None) + model_copy = AutoModelForCausalLM.from_pretrained( + **args.base_model_args, device_map="cpu" + ) + model_copy = LoraModel(model_copy, lora_conf, "default") + model_copy.load_state_dict(state) + model_copy.merge_and_unload(progressbar=True) + model_copy.save_pretrained(output_dir, safe_serialization=True) + model.config.to_json_file(f"{output_dir}/config.json") + tokenizer.save_pretrained(output_dir) + del model_copy + if old_device_map: + # return the previous device_map so it can be used later on if needed + args.base_model_args["device_map"] = old_device_map + + dist.barrier() + + def prepare_peft_model( - model, + model: PreTrainedModel, peft_config, + distributed_backend: str, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": True}, mixed_precision="bf16", @@ -314,6 +442,7 @@ def prepare_peft_model( # will guard this # Third Party from peft import ( + LoraModel, PeftConfig, PeftModel, get_peft_model, @@ -355,7 +484,11 @@ def make_inputs_require_grad(module, input, output): make_inputs_require_grad ) - model = get_peft_model(model, peft_config) + if distributed_backend == DistributedBackend.FSDP.value: + # FSDP doesn't like `get_peft_model` as it leads to dtype mismatches + model = LoraModel(model, peft_config, "default") + else: + model = get_peft_model(model, peft_config) if mixed_precision == "bf16" and getattr(model, "is_loaded_in_4bit", False): peft_module_casting_to_bf16(model) @@ -630,7 +763,7 @@ def _copy_no_lora_dict(state_dict): def save_dict_accelerate( - accelerator, + accelerator: Accelerator, state_to_save, save_directory, max_shard_size="5GB", @@ -681,6 +814,17 @@ def save_hf_format_accelerate( CONFIG_NAME = "config.json" output_config_file = output_dir / CONFIG_NAME + if is_lora and accelerator.distributed_type == DistributedType.FSDP: + save_fsdp_lora_model( + args=args, + model=model, + tokenizer=tokenizer, + accelerator=accelerator, + output_dir=output_dir, + ) + dist.barrier() + return + get_state_dict_unpatched = accelerator.get_state_dict def _get_state_dict_patched(model, unwrap=False): @@ -711,6 +855,7 @@ def _get_state_dict_patched(model, unwrap=False): safe_serialization=True, ) model.module.unmerge_adapter() + dist.barrier() if not is_lora: accelerator.save_model(
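
With this change, a LoRA run under FSDP writes a fully merged Hugging Face
checkpoint (plus config.json and the tokenizer) to the output directory rather
than a bare adapter. A minimal loading sketch, assuming `output_dir` points at
the directory that `save_hf_format_accelerate` wrote to (the path below is a
placeholder):

    # Load the merged checkpoint produced by save_fsdp_lora_model.
    # No PEFT wrapper is needed: the adapters were already merged into
    # the base weights before saving.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    output_dir = "/path/to/output_dir"  # placeholder checkpoint directory
    model = AutoModelForCausalLM.from_pretrained(output_dir, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(output_dir)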