diff --git a/litgpt/finetune/adapter_v2.py b/litgpt/finetune/adapter_v2.py
index f354decfd0..270674deff 100644
--- a/litgpt/finetune/adapter_v2.py
+++ b/litgpt/finetune/adapter_v2.py
@@ -24,6 +24,7 @@ from litgpt.utils import (
     CLI,
     CycleIterator,
+    load_checkpoint_update,
     check_valid_checkpoint_dir,
     choose_logger,
     chunked_cross_entropy,
@@ -43,6 +44,7 @@ def setup(
     precision: Optional[str] = None,
     quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
     devices: Union[int, str] = 1,
+    resume: bool = False,
     data: Optional[DataModule] = None,
     train: TrainArgs = TrainArgs(
         save_interval=1000,
@@ -110,7 +112,7 @@ def setup(
         strategy = "auto"
 
     fabric = L.Fabric(devices=devices, strategy=strategy, precision=precision, loggers=logger, plugins=plugins)
-    fabric.launch(main, devices, seed, config, data, checkpoint_dir, out_dir, train, eval)
+    fabric.launch(main, devices, seed, config, data, resume, checkpoint_dir, out_dir, train, eval)
 
 
 def main(
@@ -119,6 +121,7 @@ def main(
     seed: int,
     config: Config,
     data: DataModule,
+    resume: bool,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
@@ -149,7 +152,6 @@ def main(
     trainable_params = [p for p in model.parameters() if p.requires_grad]
     if isinstance(fabric.strategy.precision, BitsandbytesPrecision):
         import bitsandbytes as bnb
-
         optimizer_cls = bnb.optim.PagedAdamW
     else:
         optimizer_cls = torch.optim.AdamW
@@ -158,10 +160,23 @@
     )
     optimizer = fabric.setup_optimizers(optimizer)
     scheduler = get_lr_scheduler(optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps)
-
+    if resume:
+        # Find the most recent adapter checkpoint left in out_dir by a previous run
+        try:
+            resume = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=(lambda p: int(p.parent.name.split("-")[1])))
+            fabric.print(f"Resuming training from {resume}")
+            load_checkpoint_update(fabric, resume, model, checkpoint_path, strict=False)
+            resume = True
+        except ValueError:
+            fabric.print("No previous adapter found.\nFine-tuning from the start.")
+            resume = False
+            load_checkpoint(fabric, model, checkpoint_path, strict=False)
+    else:
     # strict=False because missing keys due to Adapter weights not contained in state dict
-    load_checkpoint(fabric, model, checkpoint_path, strict=False)
-
+        load_checkpoint(fabric, model, checkpoint_path, strict=False)
+
+    mark_only_adapter_v2_as_trainable(model)
+
     train_time = time.perf_counter()
     fit(
         fabric,
@@ -171,6 +186,7 @@
         train_dataloader,
         val_dataloader,
         devices,
+        resume,
         checkpoint_dir,
         out_dir,
         train,
@@ -206,6 +222,7 @@ def fit(
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     devices: int,
+    resume: bool,
     checkpoint_dir: Path,
     out_dir: Path,
     train: TrainArgs,
@@ -234,6 +251,14 @@
     total_t0 = time.perf_counter()
     val_loss = "n/a"
 
+    if resume:
+        try:
+            iter_match = max(out_dir.rglob("step-*/*.pth.adapter_v2"), key=lambda p: int(p.parent.name.split("-")[1]))
+            step_count = int(iter_match.parent.name.split("-")[1])
+        except ValueError:
+            step_count = 0
+
+        fabric.print(f"Starting at step count {step_count}")
     while step_count < max_steps and train_iterator.epoch < train.epochs:
         iter_num += 1
         iter_t0 = time.perf_counter()
diff --git a/litgpt/utils.py b/litgpt/utils.py
index 8a64b94110..8196647cef 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -324,6 +324,16 @@ def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, s
         state_dict = state_dict.get("model", state_dict)
         model.load_state_dict(state_dict, strict=strict)
 
 
+def load_checkpoint_update(fabric: L.Fabric, adapter_path: Path, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
+    if isinstance(fabric.strategy, FSDPStrategy):
+        fabric.load_raw(checkpoint_path, model, strict=strict)  # note: adapter_path is not merged on the FSDP path
+    else:
+        state_dict = lazy_load(checkpoint_path)
+        state_dict = state_dict.get("model", state_dict)
+        adapter_cp = lazy_load(adapter_path)
+        state_dict.update(adapter_cp)
+        model.load_state_dict(state_dict, strict=strict)
+
 def flops_per_param(max_seq_length: int, n_layer: int, n_embd: int, n_params: int) -> int:
     flops_per_token = 2 * n_params  # each parameter is used for a MAC (2 FLOPS) per network operation
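
For reviewers, a minimal standalone sketch (not part of the patch) of the resume discovery used in main() and fit() above. It assumes the checkpoint layout implied by the "step-*/*.pth.adapter_v2" glob, e.g. out_dir/step-000400/lit_model.pth.adapter_v2, and the helper name find_latest_adapter is illustrative only:

    # Illustrative sketch -- mirrors the rglob/max logic added in adapter_v2.py.
    from pathlib import Path
    from typing import Optional, Tuple

    def find_latest_adapter(out_dir: Path) -> Tuple[Optional[Path], int]:
        """Return (adapter_path, step_count) for the newest step-* adapter, or (None, 0) if none exists."""
        candidates = list(out_dir.rglob("step-*/*.pth.adapter_v2"))
        if not candidates:
            return None, 0
        latest = max(candidates, key=lambda p: int(p.parent.name.split("-")[1]))
        return latest, int(latest.parent.name.split("-")[1])

    # e.g. out_dir/step-000400/lit_model.pth.adapter_v2 -> (that path, 400); an empty out_dir -> (None, 0)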
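
And a toy illustration (not the library API) of the merge that load_checkpoint_update performs on the non-FSDP path: the adapter state dict is overlaid on the base checkpoint via dict.update, so resumed adapter weights take precedence while all other parameters come from the base checkpoint. The tensors below stand in for the real lazily loaded checkpoint files:

    import torch
    import torch.nn as nn

    model = nn.Linear(2, 2)
    base = {"weight": torch.zeros(2, 2), "bias": torch.zeros(2)}  # stands in for the base checkpoint
    adapter = {"bias": torch.ones(2)}                             # stands in for a *.pth.adapter_v2 file

    state_dict = dict(base)
    state_dict.update(adapter)  # adapter entries override the base entries
    model.load_state_dict(state_dict, strict=False)  # strict=False, as in the finetune script
    print(model.bias.detach())  # tensor([1., 1.]) -- the adapter value won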