
fix backward error in stack llama2 DPO example if checkpointing is used #1056

Closed

Conversation

sywangyi
Contributor

No description provided.

@sywangyi
Contributor Author

@pacman100 @younesbelkada please help review.
When I run DPO from https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2/scripts, gradient_checkpointing is enabled by default and training crashes with:

```
/home/ywan171/miniconda3/envs/trl/lib/python3.9/site-packages/torch/utils/checkpoint.py:31: UserWarning: None of the inputs have requires_grad=True. Gradients will be None
  warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
Traceback (most recent call last):
  File "/home/ywan171/trl/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py", line 221, in <module>
    dpo_trainer.train()
  File "/home/ywan171/transformers/src/transformers/trainer.py", line 1511, in train
    return inner_training_loop(
  File "/home/ywan171/transformers/src/transformers/trainer.py", line 1811, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/ywan171/transformers/src/transformers/trainer.py", line 2669, in training_step
    self.accelerator.backward(loss)
  File "/home/ywan171/miniconda3/envs/trl/lib/python3.9/site-packages/accelerate/accelerator.py", line 1985, in backward
    loss.backward(**kwargs)
  File "/home/ywan171/miniconda3/envs/trl/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/home/ywan171/miniconda3/envs/trl/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
  0%|          | 0/1000 [00:12<?, ?it/s]
```

I did not use load_in_4bit=True in AutoPeftModelForCausalLM.from_pretrained.
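
For reference, the loading path in question looks roughly like this (a minimal sketch; the checkpoint path is a placeholder and the kwargs follow dpo_llama2.py):

```python
import torch
from peft import AutoPeftModelForCausalLM

# The adapter is loaded without any quantization flag, so
# prepare_model_for_kbit_training() never runs for this model.
model = AutoPeftModelForCausalLM.from_pretrained(
    "sft/final_checkpoint",    # placeholder: path to the SFT adapter
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    # load_in_4bit=True,       # deliberately left out in my run
)
```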

@sywangyi
Contributor Author

I checked the code: is_gradient_checkpointing is set in https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L1619 inside trainer.train(), i.e. only after AutoPeftModelForCausalLM.from_pretrained() has already returned.
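
For context, the Trainer logic referenced above is essentially this (paraphrased from transformers' trainer.py; the exact line number moves between versions):

```python
# Inside Trainer._inner_training_loop(): checkpointing is only switched
# on here, long after the PEFT model object has been constructed.
if args.gradient_checkpointing:
    self.model.gradient_checkpointing_enable()
```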

@sywangyi
Contributor Author

@yao-matrix

@BenjaminBossan
Member

The change in this PR is problematic because it means that PEFT will stop working with non-transformers models. It would also call _prepare_model_for_gradient_checkpointing unconditionally, whether it's needed or not.

> I checked the code: is_gradient_checkpointing is set in https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L1619 inside trainer.train(), i.e. only after AutoPeftModelForCausalLM.from_pretrained() has already returned.

Maybe it's something that needs to be fixed in transformers then, WDYT @younesbelkada?

@sywangyi
Contributor Author

Hi @BenjaminBossan, thanks for the review. How about a change like this? The PR is updated.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@younesbelkada
Contributor

Hi @sywangyi
Thanks for the PR. I think this patch is not needed; can you try running the script with gradient_checkpointing=True hardcoded here: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py#L183?
As you can see here: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L1618, GC is enabled only if it is set to True in the TrainingArguments.

@sywangyi
Contributor Author

> Hi @sywangyi
> Thanks for the PR. I think this patch is not needed; can you try running the script with gradient_checkpointing=True hardcoded here: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py#L183?
> As you can see here: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L1618, GC is enabled only if it is set to True in the TrainingArguments.

I did set gradient_checkpointing=True, and the error is exactly as described in the bug report above. Also, the default value of gradient_checkpointing in this script is already True, see https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py#L41. Another user has run into this as well, see huggingface/trl#906: "I found that gradient checkpointing doesn't work, so I disabled it. I guess it's caused by PEFT, for gradient checkpointing works well with deepspeed version of SFT (in my private repo)."

@younesbelkada
Contributor

@sywangyi, I see, thanks!
What happens if you call model.enable_input_require_grads() before feeding the model to DPOTrainer?

@sywangyi
Contributor Author

sywangyi commented Oct 30, 2023

> model.enable_input_require_grads()

The error disappears if we call model.enable_input_require_grads() before feeding the model to DPOTrainer. The failure is simply due to the missing enable_input_require_grads() call, regardless of whether it happens in trainer.train() or somewhere else.
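
For anyone hitting the same crash, the workaround looks like this (a sketch; the variable names follow dpo_llama2.py):

```python
from trl import DPOTrainer

# Make the frozen input embeddings emit outputs with requires_grad=True,
# so torch.utils.checkpoint sees at least one tensor requiring grad.
model.enable_input_require_grads()

dpo_trainer = DPOTrainer(
    model,
    ref_model=None,
    args=training_args,        # TrainingArguments as defined in the script
    beta=script_args.beta,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
)
dpo_trainer.train()
```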

@younesbelkada
Contributor

younesbelkada commented Oct 30, 2023

I wonder if we should change the line in prepare_model_for_kbit_training as such:

```diff
- if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
+ if use_gradient_checkpointing:
```

What is the reason for enabling that only for quantized models? cc @BenjaminBossan @pacman100
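
For context, the guarded block under discussion looks roughly like this (a condensed sketch of peft's prepare_model_for_kbit_training, not the literal source):

```python
# Input grads and gradient checkpointing are currently only enabled
# for quantized (k-bit / GPTQ) models.
if (loaded_in_kbit or is_gptq_quantized) and use_gradient_checkpointing:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        # Older transformers versions: attach the hook manually.
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
    model.gradient_checkpointing_enable()
```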

@BenjaminBossan
Member

> What is the reason for enabling that only for quantized models?

Good question, I don't know :)

This seems to be the commit that added the logic: 7dfb472. The commit message is:

> make gradient checkpointing optional when using PEFT+INT8

but AFAICT this message does not correspond to the actual code change. Should we move enabling of gradient checkpointing out of this function, since it's not directly related to quantization? Or maybe just leave it completely up to the user whether to enable it or not?

@younesbelkada
Contributor

Agreed! Not performing enable_input_require_grads in prepare_model_for_kbit_training might be too breaking. I will most likely apply a fix on the TRL side to make sure enable_input_require_grads is called if the model is not quantized. What do you think?

@BenjaminBossan
Member

> Not performing enable_input_require_grads in prepare_model_for_kbit_training might be too breaking

Good point (I assume you mean backwards compatibility breaking).

What I don't quite get is why we tie enabling input gradients to whether gradient checkpointing is being used. We could also allow gradients on the input embeddings without checkpointing, right? In that case, would it be up to the user to enable those gradients? Maybe, at the end of the day, it's also a matter of documenting this properly.
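
For reference, enable_input_require_grads in transformers boils down to a forward hook on the input embeddings (paraphrased from modeling_utils.py):

```python
def enable_input_require_grads(self):
    """Make the input embeddings' outputs require grad, which is useful
    for fine-tuning adapters while the base weights stay frozen."""

    def make_inputs_require_grads(module, input, output):
        output.requires_grad_(True)

    self._require_grads_hook = self.get_input_embeddings().register_forward_hook(
        make_inputs_require_grads
    )
```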

@sywangyi
Contributor Author

sywangyi commented Oct 31, 2023

Hi @younesbelkada, prepare_model_for_kbit_training is not called in this case (I am not using loaded_in_8bit or loaded_in_4bit). Also, I see logic in transformers like https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L1877-L1882. That's why I changed the logic in the PR this way: other applications (e.g. LoRA fine-tuning) may hit the same issue. In what scenario is this logic in transformers used?
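
The transformers logic linked above amounts to something like this (paraphrased from PreTrainedModel.gradient_checkpointing_enable; it only triggers when PEFT was loaded through transformers' own integration):

```python
if getattr(self, "_hf_peft_config_loaded", False):
    # With PEFT, only adapter weights require grad, so the inputs must be
    # forced to require grad for checkpointed segments to backpropagate.
    self.enable_input_require_grads()
```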

@younesbelkada
Contributor

Hi @sywangyi
I just merged a fix in TRL: huggingface/trl#927. A more general fix could be to call enable_input_require_grads whenever we call gradient_checkpointing_enable, but I am not sure here; I need to dig a bit. Can you let me know whether the above PR fixes the issue for the DPO script?
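
The gist of the TRL fix is along these lines (a paraphrased sketch of the approach, not the literal diff; see huggingface/trl#927):

```python
# In DPOTrainer.__init__: if the model is not quantized but gradient
# checkpointing is requested, force input grads before training starts.
if getattr(args, "gradient_checkpointing", False):
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        # Fallback for older transformers versions.
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)

        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
```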

@sywangyi
Contributor Author

sywangyi commented Nov 1, 2023

The PR in TRL fixes my issue. @younesbelkada

@younesbelkada
Contributor

Great! Thanks very much @sywangyi!
