input requirements of flashselfattn #1

Open

taoye9 opened this issue Nov 14, 2022 · 1 comment
taoye9 commented Nov 14, 2022

Hi, I'm trying to use FlashAttention in cases where q, k, and v have different values.

Could someone confirm whether the inputs (q, k, v) packed into FlashSelfAttention's forward are expected to be the same tensor / have the same values? The term "self-attention" usually implies that q, k, and v come from the same tensor.

The code is here:

import torch
import torch.nn as nn
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func


class FlashSelfAttention(nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                       (default: 1/sqrt(d_keys) where d_keys is computed at
                       runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, qkv):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        max_s = seqlen
        cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                  device=qkv.device)
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
            softmax_scale=self.softmax_scale, causal=self.causal
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output
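
For context, the (B, S, 3, H, D) input above is usually produced by packing the three projections of the same hidden states along a new dimension. The sketch below only illustrates that packing; the sizes, the fused Wqkv projection, and the tensor names are hypothetical, not taken from the snippet above.

    import torch
    from einops import rearrange

    # Hypothetical sizes for illustration only.
    B, S, H, D = 2, 128, 12, 64
    x = torch.randn(B, S, H * D, device='cuda', dtype=torch.float16)

    # One fused projection produces q, k, v from the same input x (self-attention),
    # then the result is reshaped into the (B, S, 3, H, D) layout expected by forward().
    Wqkv = torch.nn.Linear(H * D, 3 * H * D, device='cuda', dtype=torch.float16)
    qkv = rearrange(Wqkv(x), 'b s (three h d) -> b s three h d', three=3, h=H)

    # Equivalently, separately projected q, k, v of shape (B, S, H, D) each
    # could be packed with: qkv = torch.stack([q, k, v], dim=2)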

tridao (Owner) commented Nov 14, 2022

Yes. There's also a cross-attention version where q is separate and k, v are stacked into one tensor. Or you can call flash_attn_unpadded_func, which takes in separate q, k, v tensors.
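
A rough sketch of the separate-tensor path, assuming the flash_attn_unpadded_func interface of this era (q, k, v in (total_tokens, nheads, headdim) layout plus int32 cumulative sequence lengths); the exact argument order and names may differ in other versions, so check the installed flash_attn_interface. All sizes below are hypothetical.

    import torch
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func

    # q and k/v may have different sequence lengths (cross-attention).
    B, Sq, Sk, H, D = 2, 128, 256, 12, 64
    q = torch.randn(B * Sq, H, D, device='cuda', dtype=torch.float16)
    k = torch.randn(B * Sk, H, D, device='cuda', dtype=torch.float16)
    v = torch.randn(B * Sk, H, D, device='cuda', dtype=torch.float16)

    # Cumulative sequence lengths for the unpadded (variable-length) interface;
    # here every sequence in the batch is assumed to be full length.
    cu_seqlens_q = torch.arange(0, (B + 1) * Sq, step=Sq, dtype=torch.int32, device='cuda')
    cu_seqlens_k = torch.arange(0, (B + 1) * Sk, step=Sk, dtype=torch.int32, device='cuda')

    out = flash_attn_unpadded_func(
        q, k, v, cu_seqlens_q, cu_seqlens_k, Sq, Sk,
        0.0,  # dropout probability
        softmax_scale=None, causal=False,
    )  # out: (B * Sq, H, D)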
