Flash attention error #973
Tagging @kashif and @younesbelkada.

I have to check the

Hey @SpursLipu

My dpo_qwen.py, please check.
Can you pass:

```python
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map={"": Accelerator().local_process_index},
    trust_remote_code=script_args.trust_remote_code,
    quantization_config=quantization_config,
)
```

But I am really not sure this will solve your bug; I just suspect that there might be some weird interaction between the compute dtype and FA-2 in their repository.
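For reference, a self-contained sketch of the suggested loading path (an assumption about how it would look in a standalone script, not the original dpo_qwen.py): the `model_name_or_path` value is hypothetical, and the dtype inspection at the end only shows which parameter dtypes the loaded model actually uses.

```python
# Hypothetical, minimal sketch of the suggested 4-bit + float16 loading path.
# model_name_or_path stands in for script_args.model_name_or_path in dpo_qwen.py.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_name_or_path = "Qwen/Qwen-14B-Chat"  # assumption: the model being fine-tuned

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # keep the 4-bit compute dtype in fp16
)

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map={"": Accelerator().local_process_index},
    trust_remote_code=True,
    quantization_config=quantization_config,
)

# Inspect the parameter dtypes that actually end up in the model; Qwen's
# remote flash-attention code asserts that q/k/v are float16 or bfloat16.
print("parameter dtypes:", {p.dtype for p in model.parameters()})
```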
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
I want to use DPO to fine-tune qwen-chat-14b, but I hit the error below. The inputs (q, k, v) to flash-attention in Qwen must be float16 or bfloat16, but in dpo_trainer they are float32. If I turn off flash-attention the error does not occur, but training becomes very slow. How can I solve this problem?
Traceback (most recent call last):
File "/mnt/afs/smartbrain/FastChat/fastchat/rlhf/dpo_qwen.py", line 215, in
dpo_trainer.train()
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1555, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 1837, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/transformers/trainer.py", line 2682, in training_step
loss = self.compute_loss(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 594, in compute_loss
loss, metrics = self.get_batch_metrics(model, inputs, train_eval="train")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 545, in get_batch_metrics
) = self.concatenated_forward(model, batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/trl/trainer/dpo_trainer.py", line 511, in concatenated_forward
all_logits = model(
^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 632, in forward
return model_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/utils/operations.py", line 620, in call
return convert_to_fp32(self.model_forward(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 14, in decorate_autocast
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/peft_model.py", line 918, in forward
return self.base_model(
^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 94, in forward
return self.model.forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 1108, in forward
transformer_outputs = self.transformer(
^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 938, in forward
outputs = block(
^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 639, in forward
attn_outputs = self.attn(
^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 546, in forward
context_layer = self.core_attention_flash(q, k, v, attention_mask=attention_mask)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/local/lib/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/root/.cache/huggingface/modules/transformers_modules/modeling_qwen.py", line 174, in forward
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
AssertionError
0%| | 0/127611 [00:01<?, ?it/s]
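To make the failing check concrete, here is a small, hypothetical reproduction of the assertion shown at modeling_qwen.py line 174 in the traceback (the helper name and tensor shapes are made up for illustration): it rejects float32 tensors and passes once they are cast to half precision, which is why running the model in float16 or bfloat16, e.g. via the loading snippet suggested above, avoids the error.

```python
# Minimal sketch of the dtype check that Qwen's flash-attention wrapper performs
# (see the assert at modeling_qwen.py line 174 in the traceback above).
import torch

def check_flash_attn_dtypes(q, k, v):
    # Flash-attention kernels only accept half-precision inputs.
    assert all(i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))

# Dummy q/k/v tensors; shapes are arbitrary and only for illustration.
q = torch.randn(1, 8, 16, 64)  # float32 by default
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

try:
    check_flash_attn_dtypes(q, k, v)  # raises AssertionError, as in the traceback
except AssertionError:
    print("float32 inputs rejected")

check_flash_attn_dtypes(q.half(), k.half(), v.half())  # passes after casting to float16
print("float16 inputs accepted")
```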