[FT] Fix FT's single query attention for bf16 hdim128 rotary
tridao committed Mar 29, 2023
1 parent 4d87e4d commit f5d0fbd
Showing 4 changed files with 23 additions and 22 deletions.
30 changes: 14 additions & 16 deletions csrc/ft_attention/decoder_masked_multihead_attention_utils.h
@@ -1669,22 +1669,6 @@ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, i
return;
}

-#ifdef ENABLE_BF16
-template<>
-__device__ __inline__ void
-write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
-{
-return;
-}
-
-template<>
-__device__ __inline__ void
-write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
-{
-return;
-}
-#endif

template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
@@ -1776,6 +1760,20 @@ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpo
smem[transpose_idx] = vec.x;
smem[smem_pitch + transpose_idx] = vec.y;
}

+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
+}
+
+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
+}
#endif

template<>
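Context for the change above: the old bf16_4_t/bf16_8_t overloads were no-op stubs, so with bf16 and rotary embeddings at head dim 128 (where the 4- and 8-element vector types are exercised) the key vector was never actually written to shared memory. The fix reuses the existing 16-bit uint2/uint4 → uint16_t specializations via reinterpret_cast, which is valid because the transpose only moves 16-bit lanes around and never interprets their values. Below is a rough PyTorch analogue of that bit-level argument, not the CUDA kernel itself; the write pattern is simplified for illustration.

# Illustrative sketch only: a transposed write of 16-bit lanes gives the same
# result whether the data is treated as bf16 or as raw 16-bit integers, which
# is why delegating the bf16 overloads to the uint specializations is safe.
import torch

vec = torch.randn(8, dtype=torch.bfloat16)       # stand-in for a bf16_8_t fragment
smem = torch.zeros(4, 4, dtype=torch.bfloat16)   # stand-in for the shared-memory tile

def write_transpose(v, out, transpose_idx=0, smem_pitch=4):
    # Mirrors the pair pattern of the __nv_bfloat162 overload above:
    # smem[idx] = x; smem[pitch + idx] = y;  repeated over the vector's pairs.
    flat = out.view(-1)
    for i in range(v.numel() // 2):
        flat[transpose_idx + i] = v[2 * i]
        flat[smem_pitch + transpose_idx + i] = v[2 * i + 1]

direct = smem.clone()
write_transpose(vec, direct)                      # operate on bf16 directly

as_bits = smem.clone().view(torch.int16)
write_transpose(vec.view(torch.int16), as_bits)   # same bits, viewed as 16-bit ints

assert torch.equal(direct, as_bits.view(torch.bfloat16))  # identical results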
6 changes: 4 additions & 2 deletions flash_attn/modules/mha.py
@@ -494,7 +494,8 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seql
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
-self.rotary_emb_dim
+self.rotary_emb_dim,
+not self.rotary_emb.interleaved # neox_rotary_style
)
context = rearrange(context, 'b h d -> b 1 h d')
else:
@@ -607,7 +608,8 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
*rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
*inference_params.key_value_memory_dict[self.layer_idx],
inference_params.lengths_per_sample, inference_params.sequence_len_offset,
-self.rotary_emb_dim
+self.rotary_emb_dim,
+not self.rotary_emb.interleaved # neox_rotary_style
)
context = rearrange(context, 'b h d -> b 1 h d')
if seqlen is None:
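On the new argument: flash-attn's RotaryEmbedding uses interleaved=False for the GPT-NeoX-style layout (each element is paired with the one half a head dimension away) and interleaved=True for the GPT-J-style layout (adjacent elements are paired), so the FT kernel's neox_rotary_style flag is simply the negation of interleaved. A minimal sketch of the two pairings follows; it is illustrative only, with cos/sin assumed broadcastable to the half head dimension.

# Minimal sketch of the two rotary pairings the neox_rotary_style flag
# distinguishes; the fused kernel applies the same math per head and position.
import torch

def apply_rotary_neox(x, cos, sin):
    # GPT-NeoX style (interleaved=False): pair x[..., i] with x[..., i + d/2].
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

def apply_rotary_interleaved(x, cos, sin):
    # GPT-J style (interleaved=True): pair adjacent elements x[..., 2i], x[..., 2i+1].
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2)

# The fused FT kernel takes neox_rotary_style, hence the new
# `not self.rotary_emb.interleaved` argument in the calls above.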
6 changes: 4 additions & 2 deletions flash_attn/utils/generation.py
@@ -82,6 +82,8 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
Arguments:
input_ids: (batch, seq_len)
max_length: int
+teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
+logits, the next token is taken from the teacher_outputs. Useful for testing.
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
sequences: (batch, max_length)
scores: tuples of (batch, vocab_size)
@@ -111,7 +113,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
start = time.time()
if vocab_size is not None:
logits = logits[..., :vocab_size]
-scores.append(logits)
+scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= seqlen_og:
next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
else:
@@ -129,7 +131,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
inference_params.sequence_len_offset)
if vocab_size is not None:
logits = logits[..., :vocab_size]
-scores.append(logits)
+scores.append(logits if not cg else logits.clone())
if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
next_token = sample(logits, top_k=top_k, temperature=temperature)
else:
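The scores change matters for the CUDA-graph path (cg=True): the graph replays into a static output buffer, so every logits tensor returned during decoding aliases the same storage, and appending it without a copy would leave scores holding the last step's values repeated. A tiny sketch of the aliasing issue, with plain in-place writes standing in for graph replays:

# Sketch of why logits must be cloned when captured from a CUDA graph's static
# output buffer: without .clone(), every stored entry aliases the same tensor.
import torch

static_logits = torch.zeros(4)            # stand-in for the graph's static output
scores_aliased, scores_cloned = [], []

for step in range(3):
    static_logits.fill_(float(step))      # each "replay" overwrites the buffer in place
    scores_aliased.append(static_logits)          # alias: all entries see later writes
    scores_cloned.append(static_logits.clone())   # snapshot of this step's values

print([s[0].item() for s in scores_aliased])  # [2.0, 2.0, 2.0]
print([s[0].item() for s in scores_cloned])   # [0.0, 1.0, 2.0]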
3 changes: 1 addition & 2 deletions tests/models/test_gpt_generation.py
@@ -15,7 +15,6 @@
from flash_attn.models.gpt import remap_state_dict_hf_gpt2
from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained
-from flash_attn.utils.distributed import all_gather_raw
from flash_attn.utils.generation import update_graph_cache


@@ -61,7 +60,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
torch.manual_seed(0)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("Hello, my dog is cute and",
-return_tensors="pt").input_ids.to(device=device)
+return_tensors="pt").input_ids.to(device=device)
max_length = 30
# input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
# max_length = input_ids.shape[1] + 40
