From f5d0fbd46805155a4406d36e8fb9f9d0324030c4 Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Tue, 28 Mar 2023 21:27:00 -0700
Subject: [PATCH] [FT] Fix FT's single query attention for bf16 hdim128 rotary

---
 ...decoder_masked_multihead_attention_utils.h | 30 +++++++++----------
 flash_attn/modules/mha.py                     |  6 ++--
 flash_attn/utils/generation.py                |  6 ++--
 tests/models/test_gpt_generation.py           |  3 +-
 4 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
index 48739f79b..21b801307 100644
--- a/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
+++ b/csrc/ft_attention/decoder_masked_multihead_attention_utils.h
@@ -1669,22 +1669,6 @@ __device__ __inline__ void write_smem_transpose(const float& vec, float* smem, i
     return;
 }
 
-#ifdef ENABLE_BF16
-template<>
-__device__ __inline__ void
-write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
-{
-    return;
-}
-
-template<>
-__device__ __inline__ void
-write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
-{
-    return;
-}
-#endif
-
 template<>
 __device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
 {
@@ -1776,6 +1760,20 @@ write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpo
     smem[transpose_idx] = vec.x;
     smem[smem_pitch + transpose_idx] = vec.y;
 }
+
+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
+}
+
+template<>
+__device__ __inline__ void
+write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
+{
+    write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
+}
 #endif
 
 template<>
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 4434b2fa2..96f8c8266 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -494,7 +494,8 @@ def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seql
                     *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                     *inference_params.key_value_memory_dict[self.layer_idx],
                     inference_params.lengths_per_sample, inference_params.sequence_len_offset,
-                    self.rotary_emb_dim
+                    self.rotary_emb_dim,
+                    not self.rotary_emb.interleaved  # neox_rotary_style
                 )
                 context = rearrange(context, 'b h d -> b 1 h d')
             else:
@@ -607,7 +608,8 @@ def forward(self, x, seqlen=None, inference_params=None, **kwargs):
                     *rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1),
                     *inference_params.key_value_memory_dict[self.layer_idx],
                     inference_params.lengths_per_sample, inference_params.sequence_len_offset,
-                    self.rotary_emb_dim
+                    self.rotary_emb_dim,
+                    not self.rotary_emb.interleaved  # neox_rotary_style
                 )
                 context = rearrange(context, 'b h d -> b 1 h d')
         if seqlen is None:
diff --git a/flash_attn/utils/generation.py b/flash_attn/utils/generation.py
index 233d07986..89796449f 100644
--- a/flash_attn/utils/generation.py
+++ b/flash_attn/utils/generation.py
@@ -82,6 +82,8 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
     Arguments:
         input_ids: (batch, seq_len)
         max_length: int
+        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
+            logits, the next token is taken from the teacher_outputs. Useful for testing.
     Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
         sequences: (batch, max_length)
         scores: tuples of (batch, vocab_size)
@@ -111,7 +113,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
             start = time.time()
         if vocab_size is not None:
             logits = logits[..., :vocab_size]
-        scores.append(logits)
+        scores.append(logits if not cg else logits.clone())
         if teacher_outputs is None or teacher_output_len <= seqlen_og:
            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
         else:
@@ -129,7 +131,7 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
                                                inference_params.sequence_len_offset)
             if vocab_size is not None:
                 logits = logits[..., :vocab_size]
-            scores.append(logits)
+            scores.append(logits if not cg else logits.clone())
             if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset + 1:
                 next_token = sample(logits, top_k=top_k, temperature=temperature)
             else:
diff --git a/tests/models/test_gpt_generation.py b/tests/models/test_gpt_generation.py
index dbbe05daa..652aca0cb 100644
--- a/tests/models/test_gpt_generation.py
+++ b/tests/models/test_gpt_generation.py
@@ -15,7 +15,6 @@
 from flash_attn.models.gpt import remap_state_dict_hf_gpt2
 from flash_attn.models.opt import remap_state_dict_hf_opt, opt_config_to_gpt2_config
 from flash_attn.utils.pretrained import state_dict_from_pretrained
-from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.generation import update_graph_cache
 
 
@@ -61,7 +60,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
     torch.manual_seed(0)
     tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
     input_ids = tokenizer("Hello, my dog is cute and",
-                         return_tensors="pt").input_ids.to(device=device)
+                          return_tensors="pt").input_ids.to(device=device)
     max_length = 30
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
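
Note on the flash_attn/utils/generation.py change: with CUDA graphs enabled (cg=True), the logits returned at each decoding step come out of a static output buffer that the next graph replay overwrites in place, so appending the tensor itself would leave every stored score aliasing the final step's values. The sketch below is a minimal standalone illustration of that aliasing, with made-up buffer and variable names; it is not code from this repository.

    import torch

    static_logits = torch.zeros(1, 4)         # stands in for the graph's reused output buffer
    aliased, cloned = [], []
    for step in range(3):
        static_logits.fill_(float(step))      # a graph replay writes into the same storage each step
        aliased.append(static_logits)         # every entry points at the same buffer
        cloned.append(static_logits.clone())  # snapshot of the values at this step

    print([t[0, 0].item() for t in aliased])  # [2.0, 2.0, 2.0] -- per-step history lost
    print([t[0, 0].item() for t in cloned])   # [0.0, 1.0, 2.0] -- per-step history preserved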