fix DPO + GC issues (huggingface#927)
younesbelkada authored and Andrew Lapp committed May 10, 2024
1 parent 4219c86 commit 703d8c7
Showing 1 changed file with 23 additions and 0 deletions.
trl/trainer/dpo_trainer.py
@@ -138,7 +138,30 @@ def __init__(
         elif is_peft_available() and peft_config is not None:
             if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                 model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
+            elif getattr(args, "gradient_checkpointing", False):
+                # For backward compatibility with older versions of transformers
+                if hasattr(model, "enable_input_require_grads"):
+                    model.enable_input_require_grads()
+                else:
+
+                    def make_inputs_require_grad(module, input, output):
+                        output.requires_grad_(True)
+
+                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
             model = get_peft_model(model, peft_config)
+        # For models that use gradient_checkpointing, we need to attach a hook that enables input
+        # to explicitly have `requires_grad=True`; otherwise training will either fail silently
+        # or fail outright.
+        elif getattr(args, "gradient_checkpointing", False):
+            # For backward compatibility with older versions of transformers
+            if hasattr(model, "enable_input_require_grads"):
+                model.enable_input_require_grads()
+            else:
+
+                def make_inputs_require_grad(module, input, output):
+                    output.requires_grad_(True)
+
+                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
         if generate_during_eval and not is_wandb_available():
             raise ValueError(
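The comment in the second hunk is the heart of the fix: with a PEFT setup the base weights are frozen, so the embedding output does not require grad, and a gradient-checkpointed segment downstream is recomputed with no connection to the autograd graph. Below is a minimal, self-contained sketch (not part of the commit) that reproduces the failure and the fix in plain PyTorch; the toy modules are illustrative stand-ins.

# Minimal repro sketch (not from the commit): why checkpointing needs
# inputs with requires_grad=True when the base weights are frozen.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

embed = nn.Embedding(100, 16)       # stands in for the base input embeddings
block = nn.Linear(16, 16)           # stands in for a trainable (e.g. LoRA) layer
embed.weight.requires_grad_(False)  # frozen base weights, as in a PEFT model

def make_inputs_require_grad(module, input, output):
    output.requires_grad_(True)

tokens = torch.randint(0, 100, (2, 8))

# Without the hook, the embedding output is a constant to autograd, so the
# checkpointed segment is detached and backward() has nothing to differentiate.
assert not embed(tokens).requires_grad

# With the hook (the same pattern the commit installs), gradients flow again.
handle = embed.register_forward_hook(make_inputs_require_grad)
hidden = embed(tokens)
out = checkpoint(block, hidden, use_reentrant=True)  # reentrant variant, as in older transformers
out.sum().backward()
assert block.weight.grad is not None
handle.remove()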

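For context, a hedged usage sketch of the configuration this commit repairs: DPOTrainer driven by TrainingArguments(gradient_checkpointing=True) together with a peft_config. Argument names follow the trl 0.7-era API and may differ in other releases; the model, dataset, and hyperparameters are placeholders.

# Hypothetical usage sketch; model, data, and hyperparameters are placeholders.
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# DPO expects prompt/chosen/rejected columns; a one-row toy dataset suffices here.
train_dataset = Dataset.from_dict(
    {"prompt": ["Q: 2 + 2?\nA:"], "chosen": [" 4"], "rejected": [" 5"]}
)

args = TrainingArguments(
    output_dir="dpo-gc-test",
    per_device_train_batch_size=1,
    gradient_checkpointing=True,  # the setting this commit makes safe with PEFT
    remove_unused_columns=False,  # DPOTrainer consumes the raw columns itself
)

trainer = DPOTrainer(
    model,
    ref_model=None,  # with a peft_config, the frozen base model serves as the reference
    args=args,
    beta=0.1,
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    peft_config=LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=16),
)
trainer.train()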