fix gradient checkpointing when using PEFT (#1118)
pacman100 authored Dec 20, 2023
1 parent f2acd82 commit 830cadf
Showing 1 changed file with 11 additions and 10 deletions.

trl/trainer/sft_trainer.py
```diff
@@ -172,27 +172,28 @@ def __init__(
             )
 
             if not isinstance(model, PeftModel):
-                _support_gc_kwargs = hasattr(
-                    args, "gradient_checkpointing_kwargs"
-                ) and "gradient_checkpointing_kwargs" in list(
-                    inspect.signature(prepare_model_for_kbit_training).parameters
-                )
-
+                gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
                 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
+                    _support_gc_kwargs = hasattr(
+                        args, "gradient_checkpointing_kwargs"
+                    ) and "gradient_checkpointing_kwargs" in list(
+                        inspect.signature(prepare_model_for_kbit_training).parameters
+                    )
                     preprare_model_kwargs = {
                         "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
                     }
 
                     if _support_gc_kwargs:
-                        preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr(
-                            args, "gradient_checkpointing_kwargs", None
-                        )
+                        preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
 
                     model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
 
                     if args is not None:
                         args = dataclasses.replace(args, gradient_checkpointing=False)
-                elif getattr(args, "gradient_checkpointing", False):
+                elif getattr(args, "gradient_checkpointing", False) and (
+                    "use_reentrant" not in gradient_checkpointing_kwargs
+                    or gradient_checkpointing_kwargs["use_reentrant"]
+                ):
                     # For backward compatibility with older versions of transformers
                     if hasattr(model, "enable_input_require_grads"):
                         model.enable_input_require_grads()
```
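Why the new `use_reentrant` condition matters: reentrant activation checkpointing only propagates gradients through tensors passed as explicit inputs to the checkpointed function, so when the inputs come from a frozen embedding layer (the usual PEFT setup) nothing requires grad and backward breaks unless the trainer calls `enable_input_require_grads()`. Non-reentrant checkpointing has no such limitation, so the workaround can be skipped. Below is a minimal, self-contained PyTorch sketch of both behaviors, assuming a PyTorch recent enough that `checkpoint` accepts `use_reentrant`; the `base`/`adapter` modules are illustrative stand-ins, not TRL code.

```python
import torch
from torch.utils.checkpoint import checkpoint

# Frozen base weights plus a trainable adapter inside one checkpointed block,
# roughly the shape of a LoRA layer on a frozen backbone.
base = torch.nn.Linear(4, 4)
base.weight.requires_grad_(False)
base.bias.requires_grad_(False)
adapter = torch.nn.Linear(4, 4, bias=False)  # trainable

def block(h):
    return base(h) + adapter(h)

x = torch.randn(2, 4)  # e.g. output of a frozen embedding: requires_grad=False

# Reentrant checkpointing only tracks gradients through explicit inputs, so the
# output has no grad_fn here; calling backward() on it would raise
# "element 0 of tensors does not require grad and does not have a grad_fn".
out = checkpoint(block, x, use_reentrant=True)
assert not out.requires_grad

# Workaround 1: force the input to require grad, which is what
# model.enable_input_require_grads() arranges via a hook on the embeddings.
x_rg = x.clone().requires_grad_(True)
checkpoint(block, x_rg, use_reentrant=True).sum().backward()
assert adapter.weight.grad is not None

# Workaround 2 (what this commit enables): non-reentrant checkpointing builds a
# proper autograd graph and needs no input-grad hack.
adapter.weight.grad = None
checkpoint(block, x, use_reentrant=False).sum().backward()
assert adapter.weight.grad is not None
```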

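For reference, a hedged sketch of the setup this fix targets; the model name, dataset, and LoRA hyperparameters are placeholders, and `gradient_checkpointing_kwargs` assumes transformers >= 4.35. With `{"use_reentrant": False}` the trainer now skips `enable_input_require_grads()`, and for a 4-bit/8-bit model it forwards these kwargs through to `prepare_model_for_kbit_training`.

```python
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train[:1%]")

args = TrainingArguments(
    output_dir="sft-opt-350m",
    gradient_checkpointing=True,
    # Non-reentrant checkpointing: after this fix, the trainer no longer calls
    # model.enable_input_require_grads() for this setting.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = SFTTrainer(
    model="facebook/opt-350m",
    args=args,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
    peft_config=LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05),
)
trainer.train()
```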