diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index 1f2e75414e..1e9bb0444b 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -115,6 +115,7 @@ "minicpm3", "baichuan", "deepseek_v2", + "deepseek_v3", "chatglm", "qwen2_vl", ] diff --git a/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 44986e22a1..86b5abf313 100644 --- a/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/optimum/habana/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -699,11 +699,18 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, + token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from DeepseekV2Attention.forward: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py + deltas are: + - add token_idx + - optimize KV cache + """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" @@ -741,7 +748,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if token_idx is None: + kv_seq_len += past_key_value[0].shape[-2] + else: + kv_seq_len = past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) @@ -754,10 +764,19 @@ def forward( key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) + if token_idx is None: + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + past_key_value[0].index_add_( + 2, token_idx - 1, key_states - torch.index_select(past_key_value[0], 2, token_idx - 1) + ) + past_key_value[1].index_add_( + 2, token_idx - 1, value_states - torch.index_select(past_key_value[1], 2, token_idx - 1) + ) + key_states = past_key_value[0] + value_states = past_key_value[1] + past_key_value = (key_states, value_states) if use_cache else None attn_weights = ( torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale @@ -803,295 +822,12 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV3 -class DeepseekV3FlashAttention2(DeepseekV3Attention): - """ - DeepseekV3 flash attention module. This module inherits from `DeepseekV3Attention` as the weights of the module stays - untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # DeepseekV3FlashAttention2 attention does not support output_attentions - if "padding_mask" in kwargs: - warnings.warn( - "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" - ) - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop("padding_mask") - - output_attentions = False - - bsz, q_len, _ = hidden_states.size() - - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split( - q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 - ) - - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split( - compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 - ) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) - - k_nope, value_states = torch.split( - kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 - ) - kv_seq_len = value_states.shape[-2] - - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - - if self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states 
= past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - dropout_rate = self.attention_dropout if self.training else 0.0 - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (DeepseekV3RMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - # Handle the case where the model is quantized - if hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - elif torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - else: - target_dtype = ( - self.q_proj.weight.dtype - if self.q_lora_rank is None - else self.q_a_proj.weight.dtype - ) - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = self._flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.softmax_scale, - ) - if self.q_head_dim != self.v_head_dim: - attn_output = attn_output[:, :, :, : self.v_head_dim] - - attn_output = attn_output.reshape( - bsz, q_len, self.num_heads * self.v_head_dim - ).contiguous() - attn_output = self.o_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - def _flash_attention_forward( - self, - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=0.0, - softmax_scale=None, - ): - """ - Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token - first unpad the input, then computes the attention scores and pad the final attention scores. - - Args: - query_states (`torch.Tensor`): - Input query states to be passed to Flash Attention API - key_states (`torch.Tensor`): - Input key states to be passed to Flash Attention API - value_states (`torch.Tensor`): - Input value states to be passed to Flash Attention API - attention_mask (`torch.Tensor`): - The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the - position of padding tokens and 1 for the position of non-padding tokens. - dropout (`int`, *optional*): - Attention dropout - softmax_scale (`float`, *optional*): - The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) - """ - if not self._flash_attn_uses_top_left_mask: - causal = self.is_causal - else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. 
For details, please see the comment in DeepseekV3FlashAttention2 __init__. - causal = self.is_causal and query_length != 1 - - # Contains at least one padding token in the sequence - if attention_mask is not None: - batch_size = query_states.shape[0] - ( - query_states, - key_states, - value_states, - indices_q, - cu_seq_lens, - max_seq_lens, - ) = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length - ) - - cu_seqlens_q, cu_seqlens_k = cu_seq_lens - max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens - - attn_output_unpad = flash_attn_varlen_func( - query_states, - key_states, - value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_in_batch_q, - max_seqlen_k=max_seqlen_in_batch_k, - dropout_p=dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - attn_output = pad_input( - attn_output_unpad, indices_q, batch_size, query_length - ) - else: - attn_output = flash_attn_func( - query_states, - key_states, - value_states, - dropout, - softmax_scale=softmax_scale, - causal=causal, - ) - - return attn_output - - def _upad_input( - self, query_layer, key_layer, value_layer, attention_mask, query_length - ): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) - batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape - - key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), - indices_k, - ) - if query_length == kv_seq_len: - query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), - indices_k, - ) - cu_seqlens_q = cu_seqlens_k - max_seqlen_in_batch_q = max_seqlen_in_batch_k - indices_q = indices_k - elif query_length == 1: - max_seqlen_in_batch_q = 1 - cu_seqlens_q = torch.arange( - batch_size + 1, dtype=torch.int32, device=query_layer.device - ) # There is a memcpy here, that is very bad. - indices_q = cu_seqlens_q[:-1] - query_layer = query_layer.squeeze(1) - else: - # The -q_len: slice assumes left padding. 
- attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask - ) - - return ( - query_layer, - key_layer, - value_layer, - indices_q, - (cu_seqlens_q, cu_seqlens_k), - (max_seqlen_in_batch_q, max_seqlen_in_batch_k), - ) - - -ATTENTION_CLASSES = { - "eager": DeepseekV3Attention, - "flash_attention_2": DeepseekV3FlashAttention2, -} - - class DeepseekV3DecoderLayer(nn.Module): def __init__(self, config: DeepseekV3Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = DeepseekV3Attention( config=config, layer_idx=layer_idx ) @@ -1119,10 +855,16 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + token_idx: Optional[torch.Tensor] = None, **kwargs, ) -> Tuple[ torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] ]: + """ + Copied from DeepseekV2DecoderLayer.forward: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py + The deltas are: + - add token_idx + """ """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -1153,6 +895,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + token_idx=token_idx, **kwargs, ) hidden_states = residual + hidden_states @@ -1337,6 +1080,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = ( output_attentions @@ -1367,11 +1111,8 @@ def forward( raise ValueError("You have to specify either input_ids or inputs_embeds") past_key_values_length = 0 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -1402,25 +1143,28 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = None + next_decoder_cache = () if use_cache else None - for decoder_layer in self.layers: + for idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + token_idx=token_idx, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1431,13 +1175,7 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = ( - next_decoder_cache.to_legacy_cache() - 
if use_legacy_cache - else next_decoder_cache - ) + next_cache = next_decoder_cache if use_cache else None if not return_dict: return tuple( v @@ -1494,6 +1232,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1545,6 +1284,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + token_idx=token_idx, ) hidden_states = outputs[0] @@ -1584,14 +1324,20 @@ def prepare_inputs_for_generation( inputs_embeds=None, **kwargs, ): + token_idx = kwargs.get("token_idx") + past_length = 0 + max_cache_length = None if past_key_values is not None: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where @@ -1622,13 +1368,16 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -input_ids.shape[1] :] + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + model_inputs = {"input_ids": input_ids.contiguous()} model_inputs.update( { @@ -1636,22 +1385,11 @@ def prepare_inputs_for_generation( "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, + "token_idx": token_idx, } ) return model_inputs - @staticmethod - def _reorder_cache(past_key_values, beam_idx): - reordered_past = () - for layer_past in past_key_values: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx.to(past_state.device)) - for past_state in layer_past - ), - ) - return reordered_past - @add_start_docstrings( """
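
Note on the KV-cache change in `DeepseekV3Attention.forward` above: when `token_idx` is set, the cache tensors are preallocated at their full bucket length and updated in place with `index_add_`, so tensor shapes stay identical across decode steps (which is what HPU graph compilation and reuse need); this is also why `kv_seq_len` is read directly from `past_key_value[0].shape[-2]` instead of being accumulated. A minimal standalone sketch of that update, with illustrative shapes and assuming `token_idx` is the 1-based position of the token currently being decoded:

```python
import torch

bsz, num_heads, max_seq_len, head_dim = 1, 4, 8, 16

# The cache is preallocated at the maximum sequence length, so its shape
# never changes across decode steps (no torch.cat, no dynamic shapes).
cache_k = torch.zeros(bsz, num_heads, max_seq_len, head_dim)

# One decode step produces a single new key slice for position token_idx - 1.
new_k = torch.randn(bsz, num_heads, 1, head_dim)
token_idx = torch.tensor([3])  # assumed 1-based, as suggested by the token_idx - 1 indexing

# index_add_ adds (new - old) at the target slot, which overwrites that slot
# in place while leaving every other position and the cache shape untouched.
old_k = torch.index_select(cache_k, 2, token_idx - 1)
cache_k.index_add_(2, token_idx - 1, new_k - old_k)

assert torch.allclose(cache_k[:, :, token_idx - 1, :], new_k)
```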
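The model forward above drops the `Cache`/`DynamicCache` objects in favor of the legacy layout: a tuple holding one `(key, value)` pair per decoder layer, indexed by `idx` inside the layer loop. A small sketch of that layout under the same illustrative shapes:

```python
import torch

num_layers, bsz, num_heads, seq_len, head_dim = 2, 1, 4, 8, 16

# One (key, value) pair per decoder layer, all preallocated.
past_key_values = tuple(
    (
        torch.zeros(bsz, num_heads, seq_len, head_dim),
        torch.zeros(bsz, num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

# Mirrors `past_key_values_length = past_key_values[0][0].shape[2]` in the diff:
# the cached length is read off the first layer's key tensor.
past_key_values_length = past_key_values[0][0].shape[2]
assert past_key_values_length == seq_len
```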
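On the generation side, `prepare_inputs_for_generation` uses the same `token_idx` to select the single current token out of a statically padded `input_ids` buffer instead of slicing processed tokens away; `position_ids` gets the same `index_select` treatment so every decode step keeps static shapes. A sketch, again assuming a 1-based `token_idx` (the padding value and bucket length below are illustrative):

```python
import torch

input_ids = torch.tensor([[101, 7592, 2088, 0, 0, 0]])  # padded to a bucket of 6
token_idx = torch.tensor([3])  # the token at 0-based index 2 is being processed

# Each decode step consumes only the current token's column; the buffer
# itself is never resized, so its shape is stable across steps.
step_input = torch.index_select(input_ids, 1, token_idx - 1)
print(step_input)  # tensor([[2088]])
```

Because the buffer is selected into rather than sliced, `attention_mask` can also stay at the full bucket length for every step, which keeps all per-step tensor shapes constant for graph reuse.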