Skip to content

Commit

Permalink
[SFTTrainer] Fix Trainer when args is None (#1064)
Browse files Browse the repository at this point in the history
* fix sfttrainer when args is None

* oops
  • Loading branch information
younesbelkada authored Dec 6, 2023
1 parent ee44946 commit 9fb00cf
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,19 @@ def __init__(
inspect.signature(prepare_model_for_kbit_training).parameters
)

preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
preprare_model_kwargs = {
"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
}

if _support_gc_kwargs:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr(
args, "gradient_checkpointing_kwargs", None
)

model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)

args = dataclasses.replace(args, gradient_checkpointing=False)
if args is not None:
args = dataclasses.replace(args, gradient_checkpointing=False)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
Expand Down

0 comments on commit 9fb00cf

Please sign in to comment.