FA2 broken for Cohere2 if Optional Mask is not passed in forward #35547

Open
Qubitium opened this issue Jan 7, 2025 · 5 comments
Qubitium (Contributor) commented Jan 7, 2025

System Info

transformers==4.48.0.dev0 (from git+https://github.com/huggingface/transformers.git@5615a393691c81e00251e420c73e4d04c6fe22e5)

Who can help?

@ArthurZucker @Cyrilvallez @SunMarc

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Check our CI test failures:

Gemma

https://github.com/ModelCloud/GPTQModel/actions/runs/12651906072/job/35253942521#step:12:1164

Cohere2

https://github.com/ModelCloud/GPTQModel/actions/runs/12651906072/job/35253938235#step:12:922
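
A minimal sketch of the failure path outside our CI harness (the checkpoint name below is only an example Cohere2 model, and an FA2-capable GPU is assumed):

# Sketch only: load a Cohere2 checkpoint with FA2 and call forward without an attention mask.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "CohereForAI/c4ai-command-r7b-12-2024"  # example checkpoint, substitute your own
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="cuda",
)

inputs = tokenizer("Hello world", return_tensors="pt").to(model.device)
# No attention_mask is passed, so `mask` arrives in flash_attention_forward as None
# and `seq_len` is never assigned before it is used.
with torch.no_grad():
    model(input_ids=inputs["input_ids"])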

We enabled FA2 by default in GPTQModel for inference of GPTQ-quantized models, and our CI tests are failing for multiple models. This looks like a regression in the FA2 attention code: seq_len is never set if mask is None, but the FA2 forward requires seq_len:

def flash_attention_forward(
    config: Cohere2Config,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor],
    target_dtype: torch.dtype = torch.float16,
    **_kwargs,
) -> Tuple[torch.Tensor, None]:
    if mask is not None:
        seq_len = mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]

    # TODO: These transpose are quite inefficient but Flash Attention requires the layout
    # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor rotary embedding
    query_states = query.transpose(1, 2)
    key_states = key.transpose(1, 2)
    value_states = value.transpose(1, 2)

    dropout_rate = config.attention_dropout if config.training else 0.0

    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        mask,
        seq_len,
        dropout=dropout_rate,
        is_causal=config.is_causal,
    )

@SunMarc I don't think this is related to quantization, and @ArthurZucker the FA2 code above is broken if mask is not passed or is None, since seq_len will never be set. The mask param is explicitly declared as Optional.
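
For illustration, one possible way to handle the missing mask (a sketch only, not necessarily the fix that will land upstream) is to fall back to the query's own sequence length:

    # Sketch of a possible fix (assumption, not the upstream patch).
    if mask is not None:
        seq_len = mask.shape[1]
        query = query[:, :, :seq_len]
        value = value[:, :, :seq_len]
    else:
        # query is laid out as [batch_size, num_heads, seq_len, head_dim] at this point
        seq_len = query.shape[2]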

Expected behavior

Work and not crash.

Qubitium added the bug label on Jan 7, 2025
SunMarc (Member) commented Jan 7, 2025

The Cohere2 flash attention 2 code is the original one from the author, as you can see here. Cohere2 is one of the few models that codes its own flash_attention_forward function. Maybe @alexrs-cohere can help you fix this issue.

Also, we are refactoring the attention code in #35235; please let us know if you face any issues with other models!

Cyrilvallez (Member) commented Jan 7, 2025

Not entirely sure why you chose this particular commit as a version, but this does not seem to be an issue on main

Qubitium changed the title from "FA2 broken if Optional Mask is not passed in forward" to "FA2 broken for Cohere2 if Optional Mask is not passed in forward" on Jan 7, 2025
Qubitium (Contributor, Author) commented Jan 7, 2025

> Not entirely sure why you chose this particular commit as a version, but this does not seem to be an issue on main

@Cyrilvallez My mistake. Our CI was force-checking out a commit after 4.47.1, but not the latest main, since the Cohere2 code was merged right after the 4.47.1 release.

So it looks like Cohere2 is the only model that still has the broken FA2 implementation.

@alexrs-cohere Please check.

Cyrilvallez (Member) commented

Ha indeed the issue persists for Cohere2! Thanks, I'll open a PR!

alexrs-cohere (Contributor) commented

Thanks for reporting this @Qubitium!

@Cyrilvallez let me know when the PR is ready and if you need any support from me!
