diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index bfddb777cf..8db4bef069 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -176,14 +176,19 @@ def __init__(
                     inspect.signature(prepare_model_for_kbit_training).parameters
                 )
 
-                preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+                preprare_model_kwargs = {
+                    "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
+                }
 
                 if _support_gc_kwargs:
-                    preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+                    preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr(
+                        args, "gradient_checkpointing_kwargs", None
+                    )
 
                 model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
 
-                args = dataclasses.replace(args, gradient_checkpointing=False)
+                if args is not None:
+                    args = dataclasses.replace(args, gradient_checkpointing=False)
             elif getattr(args, "gradient_checkpointing", False):
                 # For backward compatibility with older versions of transformers
                 if hasattr(model, "enable_input_require_grads"):
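
Note (not part of the patch): a minimal, self-contained sketch of the failure mode these guards address, assuming only standard Python semantics. Direct attribute access on an `args` object that is `None`, or that predates the `gradient_checkpointing*` fields, raises `AttributeError`, while `getattr` with a default and the explicit `None` check do not. `LegacyArgs` below is a hypothetical stand-in, not a TRL or transformers class.

import dataclasses


@dataclasses.dataclass
class LegacyArgs:
    # Hypothetical args object that lacks the newer gradient_checkpointing_kwargs field.
    gradient_checkpointing: bool = True


args = None
# args.gradient_checkpointing would raise AttributeError here; the getattr fallback does not.
print(getattr(args, "gradient_checkpointing", False))  # -> False

args = LegacyArgs()
# Missing attribute falls back to None instead of raising.
print(getattr(args, "gradient_checkpointing_kwargs", None))  # -> None

# dataclasses.replace only runs when args is not None, mirroring the guard in the patch.
if args is not None:
    args = dataclasses.replace(args, gradient_checkpointing=False)
print(args.gradient_checkpointing)  # -> False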