From 1903f43773e5a4f23057a686a3c393162d759c39 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Thu, 19 Dec 2024 17:46:43 +0800 Subject: [PATCH 01/14] Optimized attention and MoE of deepseek_v2 on Gaudi --- .../habana/transformers/generation/utils.py | 2 + .../deepseek_v2/modeling_deepseek_v2.py | 1104 ++++++++++++++--- 2 files changed, 908 insertions(+), 198 deletions(-) mode change 100644 => 100755 optimum/habana/transformers/generation/utils.py diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py old mode 100644 new mode 100755 index 68b445c1b2..6a465de5b3 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1094,6 +1094,7 @@ def generate( "gemma2", "baichuan", "chatglm", + "deepseek_v2", ] ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan and chatglm at the moment" if not generation_config.bucket_internal: @@ -1301,6 +1302,7 @@ def generate( "gemma2", "qwen2_moe", "baichuan", + "deepseek_v2", ]: if ( hasattr(self.config, "max_position_embeddings") diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index ee271b7254..339ae93f2c 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -19,38 +19,94 @@ # limitations under the License. """PyTorch DeepSeekV2 model. Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/modeling_deepseek.py""" +import contextlib import math +import os import warnings from typing import List, Optional, Tuple, Union +import habana_frameworks.torch.core as htcore import torch import torch.distributed as dist + +# import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from torch import nn -from torch.nn import CrossEntropyLoss +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers import PretrainedConfig # , PreTrainedModel from transformers.activations import ACT2FN -from transformers.cache_utils import Cache +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.integrations.deepspeed import is_deepspeed_available +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, +) from transformers.modeling_outputs import ( BaseModelOutputWithPast, - CausalLMOutputWithPast, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, ) from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, ) from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, logging, - replace_return_docstrings, ) -from ....distributed.tensorparallel import _all_reduce +# from ....distributed.tensorparallel import _all_reduce +from transformers.utils.import_utils import is_torch_fx_available + from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask from .configuration_deepseek_v2 import DeepseekV2Config +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
+if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV2Config" + +try: + from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE + + print("Using HPU fused kernel for apply_rotary_pos_emb") +except ImportError: + print("Not using HPU fused kernel for apply_rotary_pos_emb") + FusedRoPE = None + +try: + from habana_frameworks.torch.hpex.normalization import FusedRMSNorm + + print("Using HPU fused kernel for RMSNorm") +except ImportError: + print("Not using HPU fused kernel for RMSNorm") + FusedRMSNorm = None + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA +except ImportError: + print("Not using HPU fused scaled dot-product attention kernel.") + FusedSDPA = None + +try: + from habana_frameworks.torch.hpu import sdp_kernel + + SDPContext = True +except ImportError: + SDPContext = False + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" @@ -68,6 +124,89 @@ def _get_unpad_data(attention_mask): ) +# Copied from transformers.models.mixtral.modeling_mixtral.load_balancing_loss_func +def load_balancing_loss_func( + gate_logits: Union[torch.Tensor, Tuple[torch.Tensor], None], + num_experts: Optional[int] = None, + top_k=2, + attention_mask: Optional[torch.Tensor] = None, +) -> Union[torch.Tensor, int]: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits: + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts: + Number of experts + top_k: + The number of experts to route per-token, can be also interpreted as the `top-k` routing + parameter. + attention_mask (`torch.Tensor`, *optional*): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + + Returns: + The auxiliary loss. 
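+        It is computed as `num_experts * sum_e(f_e * P_e)`, where `f_e` is the fraction of tokens routed to expert `e`
+        (`tokens_per_expert` below) and `P_e` is the mean router probability assigned to expert `e` (`router_prob_per_expert` below).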
+ """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + class DeepseekV2RMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -78,11 +217,23 @@ def __init__(self, hidden_size, eps=1e-6): self.variance_epsilon = eps def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) + if hidden_states.device.type == "hpu" and FusedRMSNorm: + # mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype + if hidden_states.dtype != self.weight.dtype: + orig_dtype = hidden_states.dtype + hidden_states = FusedRMSNorm.apply( + hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon + ) + return hidden_states.to(orig_dtype) + else: + hidden_states = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon) + return hidden_states + else: + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) @@ -118,7 +269,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - if 
seq_len > self.max_seq_len_cached: + if seq_len is not None and seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( @@ -273,6 +424,15 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb_sin, persistent=False) +def apply_customized_rope(q, k, cos, sin, position_ids): + if q.device.type == "hpu" and FusedRoPE: + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids + ), FusedRoPE.apply(k, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), position_ids) + else: + return apply_rotary_pos_emb(q, k, cos, sin, position_ids) + + # Copied from transformers.models.llama.modeling_llama.rotate_half def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -282,11 +442,10 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +def apply_rotary_pos_emb(q: torch.Tensor, cos, sin, position_ids, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. Args: q (`torch.Tensor`): The query tensor. - k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. position_ids (`torch.Tensor`): @@ -302,18 +461,19 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) b, h, s, d = q.shape q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - b, h, s, d = k.shape - k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) - - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed + if q.device.type == "hpu" and FusedRoPE: + return FusedRoPE.apply( + q, cos.unsqueeze(0).unsqueeze(0).clone(), sin.unsqueeze(0).unsqueeze(0).clone(), position_ids + ) + else: + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed class DeepseekV2MLP(nn.Module): @@ -362,18 +522,21 @@ def forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape ### compute gating score hidden_states = hidden_states.view(-1, h) - logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32), None) + + logits = F.linear(hidden_states.type(torch.bfloat16), self.weight.type(torch.bfloat16), None).to( + dtype=torch.float32 + ) if self.scoring_func == "softmax": - scores = F.softmax(logits, dim=-1, dtype=torch.float32) + scores = logits.softmax(dim=-1, dtype=torch.float32) else: raise NotImplementedError(f"insupportable scoring function for MoE gating: {self.scoring_func}") ### select top-k experts if self.topk_method == "greedy": - topk_weight, topk_idx = torch.topk(scores, self.top_k, dim=-1) + topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=True) elif self.topk_method == "group_limited_greedy": group_scores = scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values # [n, n_group] - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] + group_idx = torch.topk(group_scores, k=self.topk_group, 
dim=-1, sorted=True)[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = ( @@ -382,7 +545,7 @@ def forward(self, hidden_states): .reshape(bsz * seq_len, -1) ) # [n, e] tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=False) + topk_weight, topk_idx = torch.topk(tmp_scores, k=self.top_k, dim=-1, sorted=True) ### norm gate to sum 1 if self.top_k > 1 and self.norm_topk_prob: @@ -446,7 +609,7 @@ def __init__(self, config): super().__init__() self.config = config self.num_experts_per_tok = config.num_experts_per_tok - + self.experts_per_rank = config.n_routed_experts if hasattr(config, "ep_size") and config.ep_size > 1: assert config.ep_size == dist.get_world_size() self.ep_size = config.ep_size @@ -482,48 +645,68 @@ def forward(self, hidden_states): orig_shape = hidden_states.shape topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - flat_topk_idx = topk_idx.view(-1) + # we cast back to the input dtype + topk_weight = topk_weight.to(hidden_states.dtype) + batch = orig_shape[0] + sequence_length = orig_shape[1] + hidden_dim = orig_shape[2] if self.training: - hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0) - y = torch.empty_like(hidden_states) + padded_weights = torch.zeros( + (batch * sequence_length, self.config.n_routed_experts), + dtype=topk_weight.dtype, + device=topk_weight.device, + ) + padded_weights.scatter_(-1, topk_idx, topk_weight) + padded_weights = padded_weights.reshape(-1, sequence_length, self.config.n_routed_experts) + padded_weights = padded_weights.permute(2, 0, 1).unsqueeze(-1) + + final_hidden_states = torch.zeros( + (batch, sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) for i, expert in enumerate(self.experts): - y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) - y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) - y = y.to(hidden_states.dtype).view(*orig_shape) - y = AddAuxiliaryLoss.apply(y, aux_loss) + current_hidden_state = expert(hidden_states) + current_padded_weight = padded_weights[i] + final_hidden_states = ( + final_hidden_states + + current_hidden_state.reshape(-1, sequence_length, hidden_dim) * current_padded_weight + ) + final_hidden_states = final_hidden_states.type(hidden_states.dtype) + final_hidden_states = final_hidden_states.view(*orig_shape) + final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss) else: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) - if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) - return y - - @torch.no_grad() - def moe_infer(self, x, topk_ids, topk_weight): - """ - Rewrite DeepseekV2MoE.moe_infer: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py for static expert support - """ - out = torch.zeros_like(x) + # pre-processing for custom op inputs + experts_range = range(self.config.n_routed_experts) + gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range] + down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range] + up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range] + + act_fn_s = self.config.hidden_act + final_hidden_states = torch.ops.hpu.mixture_of_experts( 
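+                # fused HPU mixture-of-experts op: gathers each token's top-k experts and applies their
+                # gate/up/down projections weighted by the router weights in a single kernel, replacing the
+                # removed per-expert Python loop from moe_infer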
+ hidden_states=hidden_states, + expert_routing_table=topk_idx, + router_weights=topk_weight, + w1=gate_proj_list, + w2=up_proj_list, + w3=down_proj_list, + permuted_weights=True, + activation=act_fn_s, + experts_min=0, + experts_max=(self.config.n_routed_experts - 1), + ) + final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim) - seq_len, hidden_dim = x.shape - num_experts = len(self.experts) + if is_deepspeed_available(): + from deepspeed import comm as dist - padded_weights = torch.zeros((seq_len, num_experts), dtype=topk_weight.dtype, device=x.device) - padded_weights.scatter_(-1, topk_ids, topk_weight) - padded_weights = padded_weights.reshape(seq_len, num_experts) - padded_weights = padded_weights.permute(1, 0).unsqueeze(-1) + if dist.is_initialized(): + dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM) - # Loop over all available experts in the model and perform the computation on each expert - for i in range(self.experts_per_rank): - expert_idx = i + self.ep_rank * self.experts_per_rank - expert = self.experts[expert_idx] - padded_weight = padded_weights[expert_idx] - x_static = expert(x) * padded_weight - out += x_static + final_hidden_states = final_hidden_states.type(hidden_states.dtype) - if self.ep_size > 1: - out = _all_reduce(out) + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + self.shared_experts(identity) - return out + return final_hidden_states # Copied from transformers.models.llama.modeling_llama.repeat_kv @@ -539,6 +722,48 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) + else: + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[1] > 1 and cur.shape[1] <= prev.shape[1]: + # Initialize + prev[:, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[1] == 1, f"Cannot update kv-cache. Unsupported shapes. 
prev:{prev.shape} cur:{cur.shape}" + + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev + else: + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -594,6 +819,9 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): bias=config.attention_bias, ) self._init_rope() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 self.softmax_scale = self.q_head_dim ** (-0.5) if self.config.rope_scaling is not None: @@ -649,107 +877,270 @@ def _init_rope(self): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + compressed_kv_cache_shape = (batch_size, max_seq_len, self.kv_lora_rank) + k_pe_cache_shape = (batch_size, max_seq_len, self.qk_rope_head_dim) + device = self.kv_a_proj_with_mqa.weight.device + dtype = self.config.torch_dtype + + self.k_cache.allocate(inp_seq_len, dtype, device, compressed_kv_cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, k_pe_cache_shape) + + def update_sincos_cache(self, seq_len): + # Call rotary emb forward() to update cos/sin cache when infering more than self.max_position_embeddings + # This helps in avoiding creation of these caches during actual model forward pass and + # reduce memory consumption and improve performance. + if seq_len > self.max_position_embeddings: + self.max_position_embeddings = seq_len + _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + + def reorder(self, tensor, beam_idx, dim_a, dim_b): + updated = tensor.index_select(0, beam_idx) + tensor.copy_(updated) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + if self.k_cache.cache is None: + return (None, None) + + head_dim = self.k_cache.cache.size(-1) + seq_length = self.k_cache.cache.size(-2) + self.reorder(self.k_cache.cache, beam_idx, seq_length, head_dim) + self.reorder(self.v_cache.cache, beam_idx, seq_length, head_dim) + return (self.k_cache.cache.shape, self.v_cache.cache.shape) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim).transpose(1, 2).contiguous() + def split_kv_b_proj(self): + kv_b_proj_weight = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) + self.q_absorb = ( + kv_b_proj_weight[:, : self.qk_nope_head_dim, :].unsqueeze(0).transpose(0, 1) + ) # k, head (128) - dim 0,auto_tp split based on dim 0 /head + self.out_absorb = kv_b_proj_weight[:, self.qk_nope_head_dim :, :].unsqueeze(0) # v head (128) - dim 0 + # del self.kv_b_proj + + def compress_kv( + self, + hidden_states_kv: torch.Tensor, + kv_position_ids: torch.LongTensor, + past_key_value: Optional[Cache] = None, + ) -> torch.Tensor: + # return the RoPE'ed & compressed kv + bsz, kv_seq_len, _ = hidden_states_kv.size() + compressed_kv = self.kv_a_proj_with_mqa(hidden_states_kv) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + compressed_kv = self.kv_a_layernorm(compressed_kv) + k_pe = k_pe.view(bsz, kv_seq_len, 1, self.qk_rope_head_dim).transpose(1, 2) + cos, sin = self.rotary_emb.cos_cached, 
self.rotary_emb.sin_cached + k_pe = apply_rotary_pos_emb(k_pe, cos, sin, kv_position_ids).view(bsz, kv_seq_len, self.qk_rope_head_dim) + return compressed_kv, k_pe + def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ - Copied from DeepseekV2Attention.forward: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py - deltas are: - - add token_idx - - optimize KV cache + Attention masks and past cache are removed. + Input: + - hidden_states: [bsz, q_len, hidden_size] + - compressed_kv: [bsz, kv_len, kv_lora_rank] + - position_ids: [bsz, q_len] """ if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" ) - bsz, q_len, _ = hidden_states.size() + if self.training: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) - if self.q_lora_rank is None: - q = self.q_proj(hidden_states) - else: - q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) - q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) - compressed_kv = self.kv_a_proj_with_mqa(hidden_states) - compressed_kv, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) - kv = ( - self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) - .transpose(1, 2) - ) + if FusedSDPA: + with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): + attn_output = FusedSDPA.apply( + query_states, key_states, value_states, attention_mask, 0.0, False, None + ) + else: + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - k_nope, value_states = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - kv_seq_len = value_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - if token_idx is None: - kv_seq_len += past_key_value[0].shape[-2] + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + else: + hidden_states_q = hidden_states + hidden_states_kv = hidden_states + self.split_kv_b_proj() + q_position_ids = position_ids + kv_position_ids = position_ids + bsz, q_len, _ = hidden_states_q.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states_q) else: - kv_seq_len = past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - query_states[:, :, :, : self.qk_nope_head_dim] = q_nope - query_states[:, :, :, self.qk_nope_head_dim :] = q_pe - - key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) - key_states[:, :, :, : self.qk_nope_head_dim] = k_nope - key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - if past_key_value is not None: - if token_idx is None: - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states_q))) + + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + + q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + kv_seq_len = q_pe.shape[-2] + + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] + else: + if reuse_cache: + kv_seq_len = past_key_value[0][-2] + else: + kv_seq_len = past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(q_pe, seq_len=kv_seq_len) + q_pe = apply_rotary_pos_emb(q_pe, cos, sin, q_position_ids) + q_nope = torch.matmul(q_nope.transpose(0, 1), self.q_absorb).transpose(0, 1) + compressed_kv, k_pe = self.compress_kv(hidden_states_kv, kv_position_ids) + + # update & get all compressed_kv, k_pe + if use_cache: + if reuse_cache: + if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor): + # prefix tuning case. attach past_key_value to generate first token. 
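+                        # With MLA, only the compressed KV latent (compressed_kv) and the shared rotary key (k_pe)
+                        # are cached; prepend the prefix cache along the sequence dimension before filling the
+                        # pre-allocated k_cache / v_cache buffers below.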
+ compressed_kv = torch.cat((past_key_value[0], compressed_kv), -2) + k_pe = torch.cat((past_key_value[1], k_pe), -2) + + compressed_kv = self.k_cache(compressed_kv, 1, token_idx) + + k_pe = self.v_cache(k_pe, 1, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + + else: + if past_key_value is None: + dtype_1 = hidden_states.dtype + device_1 = hidden_states.device + past_key = torch.zeros(compressed_kv.shape, dtype=dtype_1, device=device_1) + past_value = torch.zeros(k_pe.shape, dtype=dtype_1, device=device_1) + past_key_value = (past_key, past_value) + compressed_kv = self.k_cache.update( + past_key_value[0], compressed_kv, 1, token_idx, self.inp_seq_len + ) + k_pe = self.v_cache.update(past_key_value[1], k_pe, 1, token_idx, self.inp_seq_len) + + if token_idx is None: + past_key_value = (compressed_kv, k_pe) + + if cache_idx is not None and q_len == 1: + compressed_kv = compressed_kv[:, :cache_idx, :] + + k_pe = k_pe[:, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + + kv_seq_len = compressed_kv.shape[-2] else: - past_key_value[0].index_add_( - 2, token_idx - 1, key_states - torch.index_select(past_key_value[0], 2, token_idx - 1) - ) - past_key_value[1].index_add_( - 2, token_idx - 1, value_states - torch.index_select(past_key_value[1], 2, token_idx - 1) - ) - key_states = past_key_value[0] - value_states = past_key_value[1] - past_key_value = (key_states, value_states) if use_cache else None + past_key_value = None - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + kv_seq_len = compressed_kv.size(1) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - assert attention_mask is not None - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + k_pe = k_pe.view(bsz, 1, kv_seq_len, self.qk_rope_head_dim) + + attn_weights = ( + torch.matmul(q_pe, k_pe.mT) + torch.matmul(q_nope, compressed_kv.unsqueeze(-3).mT) + ) * self.softmax_scale + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask + assert attention_mask is not None + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_nope.dtype) + + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.einsum("bhql,blc->bhqc", attn_weights, compressed_kv) + + attn_output = torch.matmul(attn_output.permute(2, 1, 0, 3), self.out_absorb.mT).permute( + 2, 1, 0, 3 + ) # torch.einsum('bhqc,hdc->bhqd', attn_output, out_absorb) + ####end of inference if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise 
ValueError( @@ -788,6 +1179,15 @@ def __init__(self, config: DeepseekV2Config, layer_idx: int): self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.self_attn.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.self_attn.update_sincos_cache(seq_len) + def forward( self, hidden_states: torch.Tensor, @@ -797,13 +1197,10 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from DeepseekV2DecoderLayer.forward: https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py - The deltas are: - - add token_idx - """ """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -835,6 +1232,8 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, **kwargs, ) hidden_states = residual + hidden_states @@ -842,7 +1241,10 @@ def forward( # Fully Connected residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, DeepseekV2MoE): + hidden_states = self.mlp(hidden_states) + else: + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -881,7 +1283,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["DeepseekV2DecoderLayer"] _skip_keys_device_placement = "past_key_values" - _supports_flash_attn_2 = True + _supports_flash_attn_2 = False _supports_cache_class = True def _init_weights(self, module): @@ -974,6 +1376,7 @@ def __init__(self, config: DeepseekV2Config): self.layers = nn.ModuleList( [DeepseekV2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ) + self._attn_implementation = "eager" self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -981,6 +1384,17 @@ def __init__(self, config: DeepseekV2Config): # Initialize weights and apply final processing self.post_init() + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return tuple(layer.reorder_kv_cache(beam_idx) for layer in self.layers) + + def update_sincos_cache(self, seq_len): + for layer in self.layers: + layer.update_sincos_cache(seq_len) + def get_input_embeddings(self): return self.embed_tokens @@ -998,8 +1412,19 @@ def forward( use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, token_idx: Optional[torch.Tensor] = None, + attn_softmax_bf16: Optional[bool] = 
False, + reuse_cache: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + num_virtual_tokens: int = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1026,32 +1451,53 @@ def forward( ) use_cache = False - past_key_values_length = 0 - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0) - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # 4d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - if attention_mask is not None: - attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) + ignore_cache_position = True # Ignoring cache position for HPU + use_new_cache = False # Ignoring new Cache path for HPU + + past_seen_tokens = 0 + + if past_key_values is not None and use_cache: # kept for BC (cache positions) + if reuse_cache: + if isinstance(past_key_values[0][0], torch.Tensor): + past_seen_tokens = past_key_values[0][0].shape[2] + else: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + if past_key_values[0] is not None: ##added for (None, None) + past_seen_tokens = past_key_values[0][0].shape[2] + + if ignore_cache_position is False: + if cache_position is None: + if isinstance(past_key_values, StaticCache): + raise ValueError("cache_position is a required argument when using StaticCache.") + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None and cache_position: + position_ids = cache_position.unsqueeze(0) + + else: + if position_ids is None: + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=inputs_embeds.device + ) + position_ids = position_ids.unsqueeze(0) + cache_position = None + + causal_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, + input_ids.shape if input_ids is not None else (batch_size, seq_length), + inputs_embeds, + past_seen_tokens, + ) # embed positions hidden_states = inputs_embeds @@ -1059,34 +1505,69 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = () if not use_new_cache else None + + if lazy_mode: + htcore.mark_step() - for idx, decoder_layer in enumerate(self.layers): + for layer_idx, decoder_layer in enumerate(self.layers): if output_hidden_states: all_hidden_states += 
(hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, + output_router_logits, use_cache, + cache_position, + None, + attn_softmax_bf16, + False, + use_flash_attention, + flash_attention_recompute, + flash_attention_causal_mask, + flash_attention_fast_softmax, + None, ) else: + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() + layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], output_attentions=output_attentions, + output_router_logits=output_router_logits, use_cache=use_cache, + cache_position=cache_position, token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, + reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + cache_idx=cache_idx, + num_virtual_tokens=num_virtual_tokens, ) + if ( + lazy_mode + and not self.training + and (torch.distributed.is_initialized() is False or torch.distributed.get_world_size() == 1) + ): + htcore.mark_step() hidden_states = layer_outputs[0] @@ -1096,20 +1577,32 @@ def forward( if output_attentions: all_self_attns += (layer_outputs[1],) + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + hidden_states = self.norm(hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, + router_logits=all_router_logits, ) @@ -1143,8 +1636,64 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + self.kv_cache_len = max_seq_len + + def reorder_kv_cache(self, beam_idx: torch.LongTensor): + return self.model.reorder_kv_cache(beam_idx) + + def update_sincos_cache(self, seq_len): + self.model.update_sincos_cache(seq_len) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, 
os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=False, + proxies=None, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder="", + _from_auto=False, + _from_pipeline=None, + **kwargs, + ) + + return super(DeepseekV2ForCausalLM, cls).from_pretrained( + pretrained_model_name_or_path, + *model_args, + config=config, + cache_dir=cache_dir, + ignore_mismatched_sizes=ignore_mismatched_sizes, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + use_safetensors=use_safetensors, + **kwargs, + ) + def forward( self, input_ids: torch.LongTensor = None, @@ -1155,10 +1704,16 @@ def forward( labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + output_hidden_states: Optional[bool] = False, # None, + output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + reuse_cache: Optional[bool] = None, + flash_attention_recompute: Optional[bool] = False, + cache_idx: int = None, + lazy_mode: Optional[bool] = True, + num_virtual_tokens: int = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1194,8 +1749,14 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + flash_attention_recompute=flash_attention_recompute, + cache_idx=cache_idx, + lazy_mode=lazy_mode, + num_virtual_tokens=num_virtual_tokens, ) hidden_states = outputs[0] @@ -1215,16 +1776,31 @@ def forward( shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + if not return_dict: output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output return (loss,) + output if loss is not None else output - return CausalLMOutputWithPast( + return MoeCausalLMOutputWithPast( loss=loss, + aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + router_logits=outputs.router_logits, ) def prepare_inputs_for_generation( @@ -1237,7 +1813,8 @@ def prepare_inputs_for_generation( ): token_idx = kwargs.get("token_idx") past_length = 0 - max_cache_length = None + reuse_cache = kwargs.get("reuse_cache") + # Omit tokens covered by past_key_values if 
past_key_values is not None: if token_idx is not None: input_ids = torch.index_select(input_ids, 1, token_idx - 1) @@ -1250,26 +1827,29 @@ def prepare_inputs_for_generation( cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. - if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation @@ -1294,6 +1874,134 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": reuse_cache, + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), + # "use_dynamic_moe": kwargs.get("use_dynamic_moe"), + "num_virtual_tokens": kwargs.get("num_virtual_tokens"), } ) return model_inputs + + +@add_start_docstrings( + """ + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. 
+ + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) From 33c0930f2abb31710dcbc5d95d671837e99c0caf Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Mon, 23 Dec 2024 16:59:45 +0800 Subject: [PATCH 02/14] Optimized cache --- .../deepseek_v2/modeling_deepseek_v2.py | 412 ++++++++++++++++-- 1 file changed, 384 insertions(+), 28 deletions(-) diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 339ae93f2c..2de8e57c75 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -708,19 +708,45 @@ def forward(self, hidden_states): return final_hidden_states +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() -# Copied from transformers.models.llama.modeling_llama.repeat_kv -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + def forward(self, x, y): + return torch.matmul(x, y) + +def gaudi_deepseekv2_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): """ - This is the 
equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + Copied from repeat_kv: https://github.com/huggingface/transformers/blob/v4.37.0/src/transformers/models/mixtral/modeling_mixtral.py + The only differences are: + - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. + The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask + + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) + + batch, q_heads, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) + + return query_states, key_states, value_states, attention_mask class KVCache(torch.nn.Module): def __init__(self): @@ -763,6 +789,42 @@ def get_shape(self): def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) +class ModuleFusedSDPA(torch.nn.Module): + def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): + super().__init__() + self._hpu_kernel_fsdpa = fusedSDPA + self.scale = scale + self.attention_dropout = attention_dropout + self.enable_recompute = enable_recompute + self.flash_attention_fp8 = flash_attention_fp8 + + def forward( + self, + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side="left", + ): + return self._hpu_kernel_fsdpa.apply( + query, + key, + value, + attn_mask, + dropout_p, + is_casual, + scale, + softmax_mode, + recompute_mode, + valid_sequence_lengths, + padding_side, + ) # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Module): @@ -778,7 +840,7 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." 
) - + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -791,7 +853,24 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + self.norm_factor = 1.0 / math.sqrt(self.head_dim) + + self.fused_scaled_dot_product_attention = ( + ModuleFusedSDPA( + FusedSDPA, + scale=self.norm_factor, + attention_dropout=self.attention_dropout, + enable_recompute=False, + flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) + + #self.head_dim = self.hidden_size // self.num_heads + + self.is_causal = True if self.q_lora_rank is None: @@ -819,6 +898,10 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): bias=config.attention_bias, ) self._init_rope() + + self.num_key_value_groups = self.num_heads // config.num_key_value_heads + self.matmul_qk = Matmul() + self.matmul_av = Matmul() self.k_cache = KVCache() self.v_cache = KVCache() self.inp_seq_len = -1 @@ -892,7 +975,7 @@ def update_sincos_cache(self, seq_len): # reduce memory consumption and improve performance. if seq_len > self.max_position_embeddings: self.max_position_embeddings = seq_len - _, _ = self.rotary_emb(self.k_proj.weight, seq_len=seq_len) + _, _ = self.rotary_emb(self.k_b_proj.weight, seq_len=seq_len) def reorder(self, tensor, beam_idx, dim_a, dim_b): updated = tensor.index_select(0, beam_idx) @@ -946,13 +1029,22 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + ###### + cache_position: Optional[torch.LongTensor] = None, + attn_softmax_bf16: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + num_virtual_tokens: int = None, + ###### **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Attention masks and past cache are removed. Input: - - hidden_states: [bsz, q_len, hidden_size] - - compressed_kv: [bsz, kv_len, kv_lora_rank] + - hidden_states: [bsz, q_len, hidden_size] - position_ids: [bsz, q_len] """ @@ -991,10 +1083,26 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + ########## + if token_idx is None: + if hasattr(past_key_value, "get_usable_length"): + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + else: + kv_seq_len += past_key_value[0].shape[-2] + else: + if reuse_cache and not isinstance(past_key_value[0], torch.Tensor): + kv_seq_len = past_key_value[0][-2] + else: + if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: + kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len + else: + kv_seq_len = past_key_value[0].shape[-2] + + ############# + #kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids) - + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe @@ -1002,12 +1110,133 @@ def forward( key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx, cache_kwargs + #################optimization + if use_cache: + # reuse k, v, self_attention + if reuse_cache: + if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor): + # prefix tuning case. attach past_key_value to generate first token. + key_states = torch.cat((past_key_value[0], key_states), -2) + value_states = torch.cat((past_key_value[1], value_states), -2) + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.kv_b_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.kv_b_proj.weight.dtype, device=key_states.device + ) + # Return list instead of tuple + past_key_value = [past_key, past_value] + if ( + token_idx is not None + and num_virtual_tokens is not None + and num_virtual_tokens == past_key_value[0].shape[-2] + ): + # prefix tuning case. attach past_key_value to generate first token. 
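+                    # At this point past_key_value holds exactly `num_virtual_tokens` prompt K/V entries, so
+                    # they are prepended along the sequence dim before the token-indexed cache takes over.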
+ key_states = torch.cat((past_key_value[0], key_states), -2) + value_states = torch.cat((past_key_value[1], value_states), -2) + past_key_value = (key_states, value_states) + else: + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None + ############## + # if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update( + # key_states, value_states, self.layer_idx, cache_kwargs + # ) + ##################optimization + if use_flash_attention and FusedSDPA is not None: + if q_len == 1: + # next token + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + "None", + False, + None, + "None", + ) + else: + # first token + softmax_mode = "fast" if flash_attention_fast_softmax else "None" + if flash_attention_causal_mask: + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + None, + 0.0, + True, + None, + softmax_mode, + flash_attention_recompute, + valid_sequence_lengths, + "left", + ) + else: + attn_output = self.fused_scaled_dot_product_attention( + query_states, + key_states, + value_states, + attention_mask, + 0.0, + False, + None, + softmax_mode, + flash_attention_recompute, + None, + "None", + ) + + else: + query_states, key_states, value_states, attention_mask = gaudi_deepseekv2_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) + + #query_states = query_states * self.norm_factor + #attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)).float() + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.softmax_scale + htcore.mark_step() + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask.float() + + if attn_softmax_bf16: + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=query_states.dtype) + else: + # upcast attention to fp32 + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( + query_states.dtype + ) + attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = self.matmul_av(attn_weights, value_states) + #attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + ''' + ############### if FusedSDPA: with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): attn_output = FusedSDPA.apply( @@ -1033,6 +1262,7 @@ def forward( attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) + ''' else: hidden_states_q 
= hidden_states hidden_states_kv = hidden_states @@ -1199,6 +1429,14 @@ def forward( token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, cache_idx: int = None, + cache_position: Optional[torch.LongTensor] = None, + attn_softmax_bf16: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_recompute: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: Optional[torch.Tensor] = None, + num_virtual_tokens: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -1234,6 +1472,14 @@ def forward( token_idx=token_idx, reuse_cache=reuse_cache, cache_idx=cache_idx, + cache_position=cache_position, + attn_softmax_bf16=attn_softmax_bf16, + use_flash_attention=use_flash_attention, + flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, + valid_sequence_lengths=valid_sequence_lengths, + num_virtual_tokens=num_virtual_tokens, **kwargs, ) hidden_states = residual + hidden_states @@ -1424,7 +1670,9 @@ def forward( flash_attention_fast_softmax: Optional[bool] = False, cache_idx: int = None, lazy_mode: Optional[bool] = True, + valid_sequence_lengths: Optional[torch.Tensor] = None, num_virtual_tokens: int = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1522,18 +1770,20 @@ def forward( causal_mask, position_ids, past_key_values, - output_attentions, - output_router_logits, + output_attentions, use_cache, - cache_position, - None, - attn_softmax_bf16, - False, + token_idx, + reuse_cache, + cache_idx, + cache_position, + attn_softmax_bf16, use_flash_attention, flash_attention_recompute, flash_attention_causal_mask, flash_attention_fast_softmax, - None, + valid_sequence_lengths, + num_virtual_tokens, + kwargs, ) else: if ( @@ -1711,6 +1961,15 @@ def forward( reuse_cache: Optional[bool] = None, flash_attention_recompute: Optional[bool] = False, cache_idx: int = None, + ######## + cache_position: Optional[torch.LongTensor] = None, + trim_logits: Optional[bool] = False, + attn_softmax_bf16: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, + flash_attention_causal_mask: Optional[bool] = False, + flash_attention_fast_softmax: Optional[bool] = False, + valid_sequence_lengths: torch.Tensor = None, + ######### lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: @@ -1751,15 +2010,29 @@ def forward( output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, + cache_position=cache_position, token_idx=token_idx, + attn_softmax_bf16=attn_softmax_bf16, reuse_cache=reuse_cache, + use_flash_attention=use_flash_attention, flash_attention_recompute=flash_attention_recompute, + flash_attention_causal_mask=flash_attention_causal_mask, + flash_attention_fast_softmax=flash_attention_fast_softmax, cache_idx=cache_idx, lazy_mode=lazy_mode, + valid_sequence_lengths=valid_sequence_lengths, num_virtual_tokens=num_virtual_tokens, ) hidden_states = outputs[0] + + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, 
token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] + logits = self.lm_head(hidden_states) logits = logits.float() @@ -1802,13 +2075,18 @@ def forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) - + ''' def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + token_idx=None, **kwargs, ): token_idx = kwargs.get("token_idx") @@ -1878,7 +2156,85 @@ def prepare_inputs_for_generation( "flash_attention_recompute": kwargs.get("flash_attention_recompute"), "cache_idx": kwargs.get("cache_idx"), "lazy_mode": kwargs.get("lazy_mode"), - # "use_dynamic_moe": kwargs.get("use_dynamic_moe"), + + "num_virtual_tokens": kwargs.get("num_virtual_tokens"), + } + ) + return model_inputs + ''' + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + token_idx=None, + **kwargs, + ): + reuse_cache = kwargs.get("reuse_cache") + bucket_internal = kwargs.get("bucket_internal") + + if past_key_values is not None: + if token_idx is not None: + idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 + input_ids = torch.index_select(input_ids, 1, idx) + else: + if inputs_embeds is not None: # Exception 1 + input_ids = input_ids[:, -cache_position.shape[0] :] + elif ( + input_ids.shape[1] != cache_position.shape[0] + ): # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + elif (reuse_cache or bucket_internal) and token_idx is not None: + # KV cache is pre allocated with reuse cache or will be padded with bucket internal + # hence for the 1st token we can slice the inputs till token idx for the fwd pass. 
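+            # token_idx equals the number of valid prompt tokens, so everything past it is right-padding that
+            # the prefill step does not need; the pre-allocated/bucketed KV cache covers the remaining positions.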
+ input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + if token_idx is not None: + position_ids = torch.index_select(position_ids, 1, token_idx - 1) + else: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # keep cache_position implementation as None for HPU + cache_position = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + if num_logits_to_keep is not None: + model_inputs["num_logits_to_keep"] = num_logits_to_keep + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "token_idx": token_idx, + "trim_logits": kwargs.get("trim_logits"), + "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"), + "reuse_cache": reuse_cache, + "use_flash_attention": kwargs.get("use_flash_attention"), + "flash_attention_recompute": kwargs.get("flash_attention_recompute"), + "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"), + "flash_attention_fast_softmax": kwargs.get("flash_attention_fast_softmax"), + "valid_sequence_lengths": kwargs.get("valid_sequence_lengths"), + "cache_idx": kwargs.get("cache_idx"), + "lazy_mode": kwargs.get("lazy_mode"), "num_virtual_tokens": kwargs.get("num_virtual_tokens"), } ) From 5b0bc2bd9d87e42197fe36bdffc177ff809db96f Mon Sep 17 00:00:00 2001 From: ranzhejiang Date: Mon, 23 Dec 2024 10:33:05 +0000 Subject: [PATCH 03/14] enable training with sdpa --- .../deepseek_v2/modeling_deepseek_v2.py | 127 ++++-------------- 1 file changed, 25 insertions(+), 102 deletions(-) diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 2de8e57c75..bec7d01546 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -853,24 +853,7 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim - self.norm_factor = 1.0 / math.sqrt(self.head_dim) - - self.fused_scaled_dot_product_attention = ( - ModuleFusedSDPA( - FusedSDPA, - scale=self.norm_factor, - attention_dropout=self.attention_dropout, - enable_recompute=False, - flash_attention_fp8=getattr(config, "flash_attention_fp8", False), - ) - if FusedSDPA - else None - ) - - - #self.head_dim = self.hidden_size // self.num_heads - - + self.is_causal = True if self.q_lora_rank is None: @@ -914,6 +897,19 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale + self.norm_factor = self.softmax_scale + self.fused_scaled_dot_product_attention = ( + ModuleFusedSDPA( + FusedSDPA, + scale=self.norm_factor, + attention_dropout=self.attention_dropout, + enable_recompute=False, + 
flash_attention_fp8=getattr(config, "flash_attention_fp8", False), + ) + if FusedSDPA + else None + ) + def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = DeepseekV2RotaryEmbedding( @@ -1090,13 +1086,11 @@ def forward( else: kv_seq_len += past_key_value[0].shape[-2] else: - if reuse_cache and not isinstance(past_key_value[0], torch.Tensor): - kv_seq_len = past_key_value[0][-2] + ######## zhejiang fix + if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: + kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len else: - if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: - kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len - else: - kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len = past_key_value[0].shape[-2] ############# #kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) @@ -1110,55 +1104,13 @@ def forward( key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - #################optimization - if use_cache: - # reuse k, v, self_attention - if reuse_cache: - if past_key_value is not None and isinstance(past_key_value[0], torch.Tensor): - # prefix tuning case. attach past_key_value to generate first token. - key_states = torch.cat((past_key_value[0], key_states), -2) - value_states = torch.cat((past_key_value[1], value_states), -2) - key_states = self.k_cache(key_states, 2, token_idx) - value_states = self.v_cache(value_states, 2, token_idx) - past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) - else: - if past_key_value is None: - past_key = torch.zeros(key_states.shape, dtype=self.kv_b_proj.weight.dtype, device=key_states.device) - past_value = torch.zeros( - key_states.shape, dtype=self.kv_b_proj.weight.dtype, device=key_states.device - ) - # Return list instead of tuple - past_key_value = [past_key, past_value] - if ( - token_idx is not None - and num_virtual_tokens is not None - and num_virtual_tokens == past_key_value[0].shape[-2] - ): - # prefix tuning case. attach past_key_value to generate first token. 
- key_states = torch.cat((past_key_value[0], key_states), -2) - value_states = torch.cat((past_key_value[1], value_states), -2) - past_key_value = (key_states, value_states) - else: - key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) - value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) - - if token_idx is None: - past_key_value = (key_states, value_states) - - if cache_idx is not None and q_len == 1: - key_states = key_states[:, :, :cache_idx, :] - value_states = value_states[:, :, :cache_idx, :] - if attention_mask is not None: - attention_mask = attention_mask[:, :, :, :cache_idx] - kv_seq_len = key_states.shape[-2] - else: - past_key_value = None - ############## - # if past_key_value is not None: - # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - # key_states, value_states = past_key_value.update( - # key_states, value_states, self.layer_idx, cache_kwargs - # ) + + ##### zhejiang fix + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) ##################optimization if use_flash_attention and FusedSDPA is not None: if q_len == 1: @@ -1213,7 +1165,6 @@ def forward( query_states, key_states, value_states, attention_mask, self.num_key_value_groups ) - #query_states = query_states * self.norm_factor #attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)).float() attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.softmax_scale @@ -1235,34 +1186,6 @@ def forward( attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = self.matmul_av(attn_weights, value_states) #attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) - ''' - ############### - if FusedSDPA: - with sdp_kernel(enable_recompute=False) if SDPContext else contextlib.nullcontext(): - attn_output = FusedSDPA.apply( - query_states, key_states, value_states, attention_mask, 0.0, False, None - ) - else: - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale - - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - assert attention_mask is not None - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = torch.matmul(attn_weights, value_states) - ''' else: hidden_states_q = hidden_states hidden_states_kv = hidden_states From 1b53cf8bc344231a592ca376dff593fb1d09374a Mon Sep 17 00:00:00 2001 From: ranzhejiang Date: Mon, 23 Dec 2024 10:37:47 +0000 Subject: [PATCH 04/14] add gaudi fused kernel config for deepseek_v2 --- optimum/habana/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimum/habana/transformers/trainer.py b/optimum/habana/transformers/trainer.py 
index ec7d31e3a6..4bb8f5d5ca 100644 --- a/optimum/habana/transformers/trainer.py +++ b/optimum/habana/transformers/trainer.py @@ -158,7 +158,7 @@ def _get_input_update_settings(model, lazy_mode: Optional[bool] = None) -> Tuple inputs_update: Dict = {} should_update_inputs = (getattr(model, "generation_config", None) is not None) and ( - model.config.model_type in ("llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm") + model.config.model_type in ("llama", "qwen2", "starcoder2", "gemma", "baichuan", "chatglm", "deepseek_v2") ) if should_update_inputs: if model.generation_config.attn_softmax_bf16: From aaae32098a64ae263398acdec7d046173229f8ad Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Mon, 6 Jan 2025 07:29:55 +0000 Subject: [PATCH 05/14] Added expert slice to support large scale DeepSeek-v2 models such as deepseek-v2-chat-0628 --- .../deepseek_v2/modeling_deepseek_v2.py | 180 +++++------------- 1 file changed, 43 insertions(+), 137 deletions(-) diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index bec7d01546..cce1e141cf 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -29,12 +29,11 @@ import torch import torch.distributed as dist -# import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers import PretrainedConfig # , PreTrainedModel +from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available @@ -58,13 +57,11 @@ logging, ) -# from ....distributed.tensorparallel import _all_reduce from transformers.utils.import_utils import is_torch_fx_available from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask from .configuration_deepseek_v2 import DeepseekV2Config - # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. 
if is_torch_fx_available(): @@ -73,7 +70,6 @@ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" @@ -639,6 +635,9 @@ def __init__(self, config): if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) + SLICE_MAX_EXPERT = 80 + self.expert_slice = math.ceil(config.n_routed_experts/SLICE_MAX_EXPERT) + self.expert_chunk = self.config.n_routed_experts // self.expert_slice def forward(self, hidden_states): identity = hidden_states @@ -673,27 +672,30 @@ def forward(self, hidden_states): final_hidden_states = final_hidden_states.type(hidden_states.dtype) final_hidden_states = final_hidden_states.view(*orig_shape) final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss) - else: - # pre-processing for custom op inputs - experts_range = range(self.config.n_routed_experts) - gate_proj_list = [self.experts[i].gate_proj.weight.squeeze() for i in experts_range] - down_proj_list = [self.experts[i].down_proj.weight.squeeze() for i in experts_range] - up_proj_list = [self.experts[i].up_proj.weight.squeeze() for i in experts_range] - - act_fn_s = self.config.hidden_act - final_hidden_states = torch.ops.hpu.mixture_of_experts( - hidden_states=hidden_states, - expert_routing_table=topk_idx, - router_weights=topk_weight, - w1=gate_proj_list, - w2=up_proj_list, - w3=down_proj_list, - permuted_weights=True, - activation=act_fn_s, - experts_min=0, - experts_max=(self.config.n_routed_experts - 1), + else: + final_hidden_states = torch.zeros( + (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim) + for idx in range(self.expert_slice): + experts_range = range(self.expert_chunk) + gate_proj_list = [self.experts[idx * self.expert_chunk + i].gate_proj.weight.squeeze() for i in experts_range] + down_proj_list = [self.experts[idx * self.expert_chunk + i].down_proj.weight.squeeze() for i in experts_range] + up_proj_list = [self.experts[idx * self.expert_chunk + i].up_proj.weight.squeeze() for i in experts_range] + + hidden_states_slice = torch.ops.hpu.mixture_of_experts( + hidden_states=hidden_states, + expert_routing_table=topk_idx, + router_weights=topk_weight, + w1=gate_proj_list, + w2=up_proj_list, + w3=down_proj_list, + permuted_weights=True, + activation="silu", + experts_min=(self.expert_chunk * idx), + experts_max=(self.expert_chunk * (idx + 1) - 1), + ) + final_hidden_states = final_hidden_states + hidden_states_slice + htcore.mark_step() if is_deepspeed_available(): from deepspeed import comm as dist @@ -702,6 +704,7 @@ def forward(self, hidden_states): dist.all_reduce(final_hidden_states, op=dist.ReduceOp.SUM) final_hidden_states = final_hidden_states.type(hidden_states.dtype) + final_hidden_states = final_hidden_states.reshape(-1, sequence_length, hidden_dim) if self.config.n_shared_experts is not None: final_hidden_states = final_hidden_states + self.shared_experts(identity) @@ -994,9 +997,8 @@ def split_kv_b_proj(self): kv_b_proj_weight = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) self.q_absorb = ( kv_b_proj_weight[:, : self.qk_nope_head_dim, :].unsqueeze(0).transpose(0, 1) - ) # k, head (128) - dim 0,auto_tp split based on dim 0 /head - self.out_absorb = kv_b_proj_weight[:, 
self.qk_nope_head_dim :, :].unsqueeze(0) # v head (128) - dim 0 - # del self.kv_b_proj + ) + self.out_absorb = kv_b_proj_weight[:, self.qk_nope_head_dim :, :].unsqueeze(0) def compress_kv( self, @@ -1024,8 +1026,7 @@ def forward( use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, - cache_idx: int = None, - ###### + cache_idx: int = None, cache_position: Optional[torch.LongTensor] = None, attn_softmax_bf16: Optional[bool] = False, use_flash_attention: Optional[bool] = False, @@ -1033,8 +1034,7 @@ def forward( flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, valid_sequence_lengths: Optional[torch.Tensor] = None, - num_virtual_tokens: int = None, - ###### + num_virtual_tokens: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ @@ -1079,21 +1079,19 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - ########## + if token_idx is None: if hasattr(past_key_value, "get_usable_length"): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: kv_seq_len += past_key_value[0].shape[-2] else: - ######## zhejiang fix + if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len else: kv_seq_len = past_key_value[0].shape[-2] - ############# - #kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids) @@ -1105,13 +1103,13 @@ def forward( key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - ##### zhejiang fix + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) - ##################optimization + #optimization if use_flash_attention and FusedSDPA is not None: if q_len == 1: # next token @@ -1163,10 +1161,8 @@ def forward( else: query_states, key_states, value_states, attention_mask = gaudi_deepseekv2_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) + ) - #query_states = query_states * self.norm_factor - #attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)).float() attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.softmax_scale htcore.mark_step() @@ -1184,8 +1180,7 @@ def forward( query_states.dtype ) attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) - #attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) + attn_output = self.matmul_av(attn_weights, value_states) else: hidden_states_q = hidden_states hidden_states_kv = hidden_states @@ -1292,8 +1287,7 @@ def forward( attn_output = torch.matmul(attn_output.permute(2, 1, 0, 3), self.out_absorb.mT).permute( 2, 1, 0, 3 - ) # torch.einsum('bhqc,hdc->bhqd', attn_output, out_absorb) - ####end of inference + ) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( @@ -1877,22 +1871,20 @@ def forward( labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: 
Optional[bool] = None, - output_hidden_states: Optional[bool] = False, # None, + output_hidden_states: Optional[bool] = False, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = None, flash_attention_recompute: Optional[bool] = False, - cache_idx: int = None, - ######## + cache_idx: int = None, cache_position: Optional[torch.LongTensor] = None, trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, - valid_sequence_lengths: torch.Tensor = None, - ######### + valid_sequence_lengths: torch.Tensor = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: @@ -1998,93 +1990,7 @@ def forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) - ''' - def prepare_inputs_for_generation( - self, - input_ids, - past_key_values=None, - attention_mask=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - use_cache=True, - num_logits_to_keep=None, - token_idx=None, - **kwargs, - ): - token_idx = kwargs.get("token_idx") - past_length = 0 - reuse_cache = kwargs.get("reuse_cache") - # Omit tokens covered by past_key_values - if past_key_values is not None: - if token_idx is not None: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) - else: - if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() - else: - cache_length = past_length = past_key_values[0][0].shape[2] - max_cache_length = None - - # Keep only the unprocessed tokens: - # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as - # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] - # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard - # input_ids based on the past_length. - elif past_length < input_ids.shape[1]: - input_ids = input_ids[:, past_length:] - # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. - - # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
- if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + input_ids.shape[1] > max_cache_length - ): - attention_mask = attention_mask[:, -max_cache_length:] - elif reuse_cache and token_idx is not None: - # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass - input_ids = input_ids[:, :token_idx] - attention_mask = attention_mask[:, :token_idx] - position_ids = kwargs.get("position_ids", None) - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - if past_key_values: - if token_idx is not None: - position_ids = torch.index_select(position_ids, 1, token_idx - 1) - else: - position_ids = position_ids[:, -input_ids.shape[1] :] - - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and past_key_values is None: - model_inputs = {"inputs_embeds": inputs_embeds} - else: - model_inputs = {"input_ids": input_ids.contiguous()} - - model_inputs.update( - { - "position_ids": position_ids, - "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), - "attention_mask": attention_mask, - "token_idx": token_idx, - "reuse_cache": reuse_cache, - "flash_attention_recompute": kwargs.get("flash_attention_recompute"), - "cache_idx": kwargs.get("cache_idx"), - "lazy_mode": kwargs.get("lazy_mode"), - - "num_virtual_tokens": kwargs.get("num_virtual_tokens"), - } - ) - return model_inputs - ''' + def prepare_inputs_for_generation( self, input_ids, From c982121bb3250e9d19ec9591bc36f1f3b01c18ab Mon Sep 17 00:00:00 2001 From: "Ran, Zhejiang" Date: Wed, 8 Jan 2025 06:13:37 +0000 Subject: [PATCH 06/14] add fused kernel config support for run_clm.py --- examples/language-modeling/run_clm.py | 42 +++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index feac065364..1606b02ab4 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -156,6 +156,40 @@ class ModelArguments: ) }, ) + attn_softmax_bf16: bool = field( + default=False, + metadata={ + "help": ( + "Whether to run attention softmax layer in bf16 precision for fine-tuning. The current support is limited to Llama only." + ) + }, + ) + use_flash_attention: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use Habana flash attention for fine-tuning. The current support is limited to Llama only." + ) + }, + ) + flash_attention_recompute: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable recompute in Habana flash attention for fine-tuning." + " It is applicable only when use_flash_attention is True." + ) + }, + ) + flash_attention_causal_mask: bool = field( + default=False, + metadata={ + "help": ( + "Whether to enable causal mask in Habana flash attention for fine-tuning." + " It is applicable only when use_flash_attention is True." 
+ ) + }, + ) low_cpu_mem_usage: bool = field( default=False, metadata={ @@ -482,6 +516,14 @@ def main(): if len(tokenizer) > embedding_size: model.resize_token_embeddings(len(tokenizer)) + # We need to add these fused kernels config + if model_args.attn_softmax_bf16: + model.generation_config.attn_softmax_bf16 = True + if model_args.use_flash_attention: + model.generation_config.use_flash_attention = True + model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute + model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: From 8b8e5c5b385e4be4b34d59e4aa2e753c4349e1e1 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Wed, 8 Jan 2025 15:09:59 +0000 Subject: [PATCH 07/14] Fixed style errors --- .../deepseek_v2/modeling_deepseek_v2.py | 110 +++++++++--------- 1 file changed, 53 insertions(+), 57 deletions(-) diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index cce1e141cf..5f92115b32 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -19,7 +19,6 @@ # limitations under the License. """PyTorch DeepSeekV2 model. Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/resolve/main/modeling_deepseek.py""" -import contextlib import math import os import warnings @@ -28,12 +27,11 @@ import habana_frameworks.torch.core as htcore import torch import torch.distributed as dist - import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers import PretrainedConfig +from transformers import PretrainedConfig from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache, StaticCache from transformers.integrations.deepspeed import is_deepspeed_available @@ -56,12 +54,12 @@ add_start_docstrings_to_model_forward, logging, ) - from transformers.utils.import_utils import is_torch_fx_available from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask from .configuration_deepseek_v2 import DeepseekV2Config + # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. 
if is_torch_fx_available(): @@ -96,13 +94,6 @@ print("Not using HPU fused scaled dot-product attention kernel.") FusedSDPA = None -try: - from habana_frameworks.torch.hpu import sdp_kernel - - SDPContext = True -except ImportError: - SDPContext = False - logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" @@ -636,7 +627,7 @@ def __init__(self, config): intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) SLICE_MAX_EXPERT = 80 - self.expert_slice = math.ceil(config.n_routed_experts/SLICE_MAX_EXPERT) + self.expert_slice = math.ceil(config.n_routed_experts / SLICE_MAX_EXPERT) self.expert_chunk = self.config.n_routed_experts // self.expert_slice def forward(self, hidden_states): @@ -672,15 +663,21 @@ def forward(self, hidden_states): final_hidden_states = final_hidden_states.type(hidden_states.dtype) final_hidden_states = final_hidden_states.view(*orig_shape) final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, aux_loss) - else: + else: final_hidden_states = torch.zeros( (batch * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) for idx in range(self.expert_slice): experts_range = range(self.expert_chunk) - gate_proj_list = [self.experts[idx * self.expert_chunk + i].gate_proj.weight.squeeze() for i in experts_range] - down_proj_list = [self.experts[idx * self.expert_chunk + i].down_proj.weight.squeeze() for i in experts_range] - up_proj_list = [self.experts[idx * self.expert_chunk + i].up_proj.weight.squeeze() for i in experts_range] + gate_proj_list = [ + self.experts[idx * self.expert_chunk + i].gate_proj.weight.squeeze() for i in experts_range + ] + down_proj_list = [ + self.experts[idx * self.expert_chunk + i].down_proj.weight.squeeze() for i in experts_range + ] + up_proj_list = [ + self.experts[idx * self.expert_chunk + i].up_proj.weight.squeeze() for i in experts_range + ] hidden_states_slice = torch.ops.hpu.mixture_of_experts( hidden_states=hidden_states, @@ -711,13 +708,15 @@ def forward(self, hidden_states): return final_hidden_states + class Matmul(torch.nn.Module): def __init__(self): super().__init__() def forward(self, x, y): return torch.matmul(x, y) - + + def gaudi_deepseekv2_repeat_kv( query_states: torch.Tensor, key_states: torch.Tensor, @@ -751,6 +750,7 @@ def gaudi_deepseekv2_repeat_kv( return query_states, key_states, value_states, attention_mask + class KVCache(torch.nn.Module): def __init__(self): super(KVCache, self).__init__() @@ -792,6 +792,7 @@ def get_shape(self): def forward(self, cur, dim, idx): return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + class ModuleFusedSDPA(torch.nn.Module): def __init__(self, fusedSDPA, scale, attention_dropout, enable_recompute, flash_attention_fp8): super().__init__() @@ -829,6 +830,7 @@ def forward( padding_side, ) + # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -843,7 +845,7 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " "when creating this class." 
) - + self.attention_dropout = config.attention_dropout self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -856,7 +858,7 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): self.v_head_dim = config.v_head_dim self.qk_nope_head_dim = config.qk_nope_head_dim self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim - + self.is_causal = True if self.q_lora_rank is None: @@ -884,7 +886,7 @@ def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): bias=config.attention_bias, ) self._init_rope() - + self.num_key_value_groups = self.num_heads // config.num_key_value_heads self.matmul_qk = Matmul() self.matmul_av = Matmul() @@ -995,10 +997,8 @@ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): def split_kv_b_proj(self): kv_b_proj_weight = self.kv_b_proj.weight.view(self.num_heads, -1, self.kv_lora_rank) - self.q_absorb = ( - kv_b_proj_weight[:, : self.qk_nope_head_dim, :].unsqueeze(0).transpose(0, 1) - ) - self.out_absorb = kv_b_proj_weight[:, self.qk_nope_head_dim :, :].unsqueeze(0) + self.q_absorb = kv_b_proj_weight[:, : self.qk_nope_head_dim, :].unsqueeze(0).transpose(0, 1) + self.out_absorb = kv_b_proj_weight[:, self.qk_nope_head_dim :, :].unsqueeze(0) def compress_kv( self, @@ -1026,21 +1026,21 @@ def forward( use_cache: bool = False, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = False, - cache_idx: int = None, + cache_idx: int = None, cache_position: Optional[torch.LongTensor] = None, attn_softmax_bf16: Optional[bool] = False, use_flash_attention: Optional[bool] = False, flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, - valid_sequence_lengths: Optional[torch.Tensor] = None, - num_virtual_tokens: int = None, + valid_sequence_lengths: Optional[torch.Tensor] = None, + num_virtual_tokens: int = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Attention masks and past cache are removed. Input: - - hidden_states: [bsz, q_len, hidden_size] + - hidden_states: [bsz, q_len, hidden_size] - position_ids: [bsz, q_len] """ @@ -1079,22 +1079,21 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - + if token_idx is None: if hasattr(past_key_value, "get_usable_length"): kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) else: kv_seq_len += past_key_value[0].shape[-2] else: - if num_virtual_tokens is not None and num_virtual_tokens == past_key_value[0].shape[-2]: kv_seq_len = past_key_value[0].shape[-2] + kv_seq_len else: kv_seq_len = past_key_value[0].shape[-2] - + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_customized_rope(q_pe, k_pe, cos, sin, position_ids) - + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe @@ -1102,14 +1101,13 @@ def forward( key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) key_states[:, :, :, : self.qk_nope_head_dim] = k_nope key_states[:, :, :, self.qk_nope_head_dim :] = k_pe - - + if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) - #optimization + # optimization if use_flash_attention and FusedSDPA is not None: if q_len == 1: # next token @@ -1161,8 +1159,8 @@ def forward( else: query_states, key_states, value_states, attention_mask = gaudi_deepseekv2_repeat_kv( query_states, key_states, value_states, attention_mask, self.num_key_value_groups - ) - + ) + attn_weights = self.matmul_qk(query_states, key_states.transpose(-2, -1)) * self.softmax_scale htcore.mark_step() @@ -1179,8 +1177,10 @@ def forward( attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query_states.dtype ) - attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - attn_output = self.matmul_av(attn_weights, value_states) + attn_weights = torch.nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = self.matmul_av(attn_weights, value_states) else: hidden_states_q = hidden_states hidden_states_kv = hidden_states @@ -1285,9 +1285,7 @@ def forward( attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.einsum("bhql,blc->bhqc", attn_weights, compressed_kv) - attn_output = torch.matmul(attn_output.permute(2, 1, 0, 3), self.out_absorb.mT).permute( - 2, 1, 0, 3 - ) + attn_output = torch.matmul(attn_output.permute(2, 1, 0, 3), self.out_absorb.mT).permute(2, 1, 0, 3) if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): raise ValueError( @@ -1352,7 +1350,7 @@ def forward( flash_attention_recompute: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, - valid_sequence_lengths: Optional[torch.Tensor] = None, + valid_sequence_lengths: Optional[torch.Tensor] = None, num_virtual_tokens: int = None, **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -1589,7 +1587,6 @@ def forward( lazy_mode: Optional[bool] = True, valid_sequence_lengths: Optional[torch.Tensor] = None, num_virtual_tokens: int = None, - ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1687,20 +1684,19 @@ def forward( causal_mask, position_ids, past_key_values, - output_attentions, + output_attentions, 
use_cache, token_idx, reuse_cache, - cache_idx, - cache_position, - attn_softmax_bf16, + cache_idx, + cache_position, + attn_softmax_bf16, use_flash_attention, flash_attention_recompute, flash_attention_causal_mask, flash_attention_fast_softmax, valid_sequence_lengths, - num_virtual_tokens, - kwargs, + num_virtual_tokens, ) else: if ( @@ -1871,20 +1867,20 @@ def forward( labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, reuse_cache: Optional[bool] = None, flash_attention_recompute: Optional[bool] = False, - cache_idx: int = None, - cache_position: Optional[torch.LongTensor] = None, - trim_logits: Optional[bool] = False, + cache_idx: int = None, + cache_position: Optional[torch.LongTensor] = None, + trim_logits: Optional[bool] = False, attn_softmax_bf16: Optional[bool] = False, - use_flash_attention: Optional[bool] = False, + use_flash_attention: Optional[bool] = False, flash_attention_causal_mask: Optional[bool] = False, flash_attention_fast_softmax: Optional[bool] = False, - valid_sequence_lengths: torch.Tensor = None, + valid_sequence_lengths: torch.Tensor = None, lazy_mode: Optional[bool] = True, num_virtual_tokens: int = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: @@ -1990,7 +1986,7 @@ def forward( attentions=outputs.attentions, router_logits=outputs.router_logits, ) - + def prepare_inputs_for_generation( self, input_ids, From cd79a76ab0fc59b146080cd43e0128efca4eda8d Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Wed, 8 Jan 2025 15:11:40 +0000 Subject: [PATCH 08/14] Updated DeepSeek-V2 in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e44ca5430c..9764ec99ff 100644 --- a/README.md +++ b/README.md @@ -256,7 +256,7 @@ The following model architectures, tasks and device distributions have been vali | Mllama |
<li>LoRA</li> | :heavy_check_mark: | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
 | MiniCPM3 |  | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 | Baichuan2 | <li>DeepSpeed</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
-| DeepSeek-V2 |  | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
+| DeepSeek-V2 | :heavy_check_mark: | :heavy_check_mark: | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
 | ChatGLM | <li>DeepSpeed</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
  • | From 1e04748eeea3dd1daa8f0c617a8e6b146928ad31 Mon Sep 17 00:00:00 2001 From: "Ran, Zhejiang" Date: Sun, 12 Jan 2025 15:31:36 +0000 Subject: [PATCH 09/14] remove some outdate statements --- examples/language-modeling/run_clm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index 1606b02ab4..eaf8b26f30 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -160,7 +160,7 @@ class ModelArguments: default=False, metadata={ "help": ( - "Whether to run attention softmax layer in bf16 precision for fine-tuning. The current support is limited to Llama only." + "Whether to run attention softmax layer in bf16 precision for fine-tuning." ) }, ) @@ -168,7 +168,7 @@ class ModelArguments: default=False, metadata={ "help": ( - "Whether to use Habana flash attention for fine-tuning. The current support is limited to Llama only." + "Whether to use Habana flash attention for fine-tuning." ) }, ) From 71d54d883f55754a732b62ec268b6531f596eeb7 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Fri, 17 Jan 2025 09:45:41 +0000 Subject: [PATCH 10/14] Added the training command of DeepSeek-V2-Lite. --- examples/language-modeling/README.md | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/examples/language-modeling/README.md b/examples/language-modeling/README.md index 9ef27f9e73..693dd49241 100644 --- a/examples/language-modeling/README.md +++ b/examples/language-modeling/README.md @@ -184,6 +184,36 @@ python ../gaudi_spawn.py \ --logging_steps 20 ``` +### Multi-card Training with Deepspeed (DeepSeek-V2-Lite) +```bash +python ../gaudi_spawn.py --world_size 8 --use_deepspeed run_clm.py + --config_name deepseek-ai/DeepSeek-V2-Lite + --tokenizer_name deepseek-ai/DeepSeek-V2-Lite + --dataset_name tatsu-lab/alpaca + --block_size 4096 + --do_train + --num_train_epochs 1 + --max_steps 10 + --per_device_train_batch_size 1 + --gradient_accumulation_steps 1 + --use_flash_attention True + --attn_softmax_bf16 False + --gradient_checkpointing + --learning_rate 2.4e-4 + --gaudi_config_name Habana/gpt2 + --bf16 + --save_strategy no + --no_save_last_ckpt + --output_dir /root/deepseek-v2-lite + --overwrite_output_dir + --logging_strategy steps + --logging_dir /root/deepseek-v2-lite/log + --logging_steps 1 + --evaluation_strategy no + --use_habana + --use_lazy_mode + --deepspeed llama2_ds_zero3_config.json +``` ## Multi-Node Training with Deepspeed (GPT-NeoX) From 57a86b67741fe66d5f45b22328518e2d7c36470b Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Tue, 28 Jan 2025 09:31:17 +0000 Subject: [PATCH 11/14] Added support for deepseek-v2 training --- docs/source/index.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/index.mdx b/docs/source/index.mdx index 51d6dadf0f..18194d4c53 100644 --- a/docs/source/index.mdx +++ b/docs/source/index.mdx @@ -107,7 +107,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be | Mllama |
<li>LoRA</li> | ✅ | <li>[image to text](https://github.com/huggingface/optimum-habana/tree/main/examples/image-to-text)</li> |
| MiniCPM3 | | <li>Single card</li> | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| Baichuan2 | <li>DeepSpeed</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
-| DeepSeek-V2 | | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
+| DeepSeek-V2 | ✅ | ✅ | <li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
| ChatGLM | <li>DeepSpeed</li> | <li>Single card</li> | <li>[language modeling](https://github.com/huggingface/optimum-habana/tree/main/examples/language-modeling)</li><li>[text generation](https://github.com/huggingface/optimum-habana/tree/main/examples/text-generation)</li> |
  • | - Diffusers From 8a4c1a810f72a63abf9f34eaa709e84e88d7d054 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Tue, 28 Jan 2025 10:42:40 +0000 Subject: [PATCH 12/14] Refactor the code. --- .../models/deepseek_v2/modeling_deepseek_v2.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 5f92115b32..bf8c58f47c 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -62,16 +62,22 @@ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. -if is_torch_fx_available(): - if not is_torch_greater_or_equal_than_1_13: - import torch.fx +# if is_torch_fx_available(): +# if not is_torch_greater_or_equal_than_1_13: +# import torch.fx - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) +# _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + +import torch.fx +_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" +#default expert number per slice for dynamic MoE +SLICE_MAX_EXPERT = 80 + try: from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE @@ -626,7 +632,7 @@ def __init__(self, config): if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) - SLICE_MAX_EXPERT = 80 + self.expert_slice = math.ceil(config.n_routed_experts / SLICE_MAX_EXPERT) self.expert_chunk = self.config.n_routed_experts // self.expert_slice From 02a5adc67f6cfe437b5d997f7ccd4dab757cc432 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Tue, 28 Jan 2025 11:13:08 +0000 Subject: [PATCH 13/14] Removed comments --- .../transformers/models/deepseek_v2/modeling_deepseek_v2.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index bf8c58f47c..ed515515b4 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -62,12 +62,6 @@ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. -# if is_torch_fx_available(): -# if not is_torch_greater_or_equal_than_1_13: -# import torch.fx - -# _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) - import torch.fx _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) From 62d63136a0918778a787726aa5f3d8e542f612a4 Mon Sep 17 00:00:00 2001 From: gyou2021 Date: Tue, 28 Jan 2025 11:30:37 +0000 Subject: [PATCH 14/14] Fixed style. 
--- examples/language-modeling/run_clm.py | 20 +++------- .../habana/transformers/generation/utils.py | 37 +++++++++---------- .../deepseek_v2/modeling_deepseek_v2.py | 19 +++++----- 3 files changed, 33 insertions(+), 43 deletions(-) diff --git a/examples/language-modeling/run_clm.py b/examples/language-modeling/run_clm.py index b64f4d10f0..93d85ba54b 100644 --- a/examples/language-modeling/run_clm.py +++ b/examples/language-modeling/run_clm.py @@ -158,19 +158,11 @@ class ModelArguments: ) attn_softmax_bf16: bool = field( default=False, - metadata={ - "help": ( - "Whether to run attention softmax layer in bf16 precision for fine-tuning." - ) - }, + metadata={"help": ("Whether to run attention softmax layer in bf16 precision for fine-tuning.")}, ) use_flash_attention: bool = field( default=False, - metadata={ - "help": ( - "Whether to use Habana flash attention for fine-tuning." - ) - }, + metadata={"help": ("Whether to use Habana flash attention for fine-tuning.")}, ) flash_attention_recompute: bool = field( default=False, @@ -518,11 +510,11 @@ def main(): # We need to add these fused kernels config if model_args.attn_softmax_bf16: - model.generation_config.attn_softmax_bf16 = True + model.generation_config.attn_softmax_bf16 = True if model_args.use_flash_attention: - model.generation_config.use_flash_attention = True - model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute - model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask + model.generation_config.use_flash_attention = True + model.generation_config.flash_attention_recompute = model_args.flash_attention_recompute + model.generation_config.flash_attention_causal_mask = model_args.flash_attention_causal_mask # Preprocessing the datasets. # First we tokenize all the texts. 
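For orientation, the `run_clm.py` hunk above ends by copying the fused-kernel switches into `model.generation_config`; the switches themselves are plain boolean `ModelArguments` fields, so they are toggled from the command line. A minimal illustrative invocation follows (a sketch only: the model, dataset and output paths are placeholders, and the complete DeepSeek-V2-Lite command is the one added to the language-modeling README earlier in this series):

```bash
# Illustrative only: enables the Habana fused kernels whose help strings and
# generation_config wiring are touched in the hunk above.
# <model>, <dataset> and <output_dir> are placeholders.
python run_clm.py \
  --model_name_or_path <model> \
  --dataset_name <dataset> \
  --do_train \
  --use_habana \
  --use_lazy_mode \
  --gaudi_config_name Habana/gpt2 \
  --bf16 \
  --use_flash_attention True \
  --flash_attention_recompute True \
  --attn_softmax_bf16 True \
  --output_dir <output_dir>
```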
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index ac7094f152..e14a3e3091 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -1078,25 +1078,24 @@ def generate( assert generation_config.bucket_size >= 0, "please set bucket_size to use bucket_internal" assert generation_config.use_cache, "please set use_cache flag to use bucket_internal" if generation_config.reuse_cache: - assert ( - self.config.model_type - in [ - "llama", - "mistral", - "falcon", - "mixtral", - "phi", - "qwen2", - "gptj", - "starcoder2", - "qwen2_moe", - "gemma", - "gemma2", - "baichuan", - "chatglm", - "deepseek_v2", - ] - ), "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment" + assert self.config.model_type in [ + "llama", + "mistral", + "falcon", + "mixtral", + "phi", + "qwen2", + "gptj", + "starcoder2", + "qwen2_moe", + "gemma", + "gemma2", + "baichuan", + "chatglm", + "deepseek_v2", + ], ( + "reuse_cache only supported by llama, mistral, falcon, mixtral, phi, qwen2, qwen2_moe, gemma, gemma2, starcoder2, baichuan, chatglm and deepseek_v2 at the moment" + ) if not generation_config.bucket_internal: assert generation_config.bucket_size <= 0, ( "please set bucket_internal along with reuse_cache and bucket_size" diff --git a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py index ed515515b4..b364e75517 100644 --- a/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/optimum/habana/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -27,6 +27,10 @@ import habana_frameworks.torch.core as htcore import torch import torch.distributed as dist + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +import torch.fx import torch.nn.functional as F import torch.utils.checkpoint from torch import nn @@ -47,29 +51,24 @@ from transformers.modeling_utils import PreTrainedModel from transformers.pytorch_utils import ( ALL_LAYERNORM_LAYERS, - is_torch_greater_or_equal_than_1_13, ) from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, logging, ) -from transformers.utils.import_utils import is_torch_fx_available from ...modeling_attn_mask_utils import _gaudi_prepare_4d_causal_attention_mask from .configuration_deepseek_v2 import DeepseekV2Config -# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. -# It means that the function will not be traced through and simply appear as a node in the graph. 
-import torch.fx _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DeepseekV2Config" -#default expert number per slice for dynamic MoE +# default expert number per slice for dynamic MoE SLICE_MAX_EXPERT = 80 try: @@ -626,7 +625,7 @@ def __init__(self, config): if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP(config=config, intermediate_size=intermediate_size) - + self.expert_slice = math.ceil(config.n_routed_experts / SLICE_MAX_EXPERT) self.expert_chunk = self.config.n_routed_experts // self.expert_slice @@ -762,9 +761,9 @@ def allocate(self, inp_seq_len, dtype, device, shape): self.inp_seq_len = inp_seq_len self.cache = torch.zeros(shape, dtype=dtype, device=device) else: - assert ( - self.inp_seq_len == inp_seq_len - ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + assert self.inp_seq_len == inp_seq_len, ( + f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + ) self.cache.fill_(0) def update(self, prev, cur, dim, idx, inp_seq_len):
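A closing note on the refactor above: `SLICE_MAX_EXPERT` caps how many routed experts are processed per slice in the dynamic-MoE path, with `expert_slice = ceil(n_routed_experts / SLICE_MAX_EXPERT)` and `expert_chunk = n_routed_experts // expert_slice`. A small worked example of that arithmetic follows (the expert counts are assumptions based on the published DeepSeek-V2 configurations, not values read from this patch):

```python
# Worked example of the expert-slicing arithmetic shown in the hunk above.
# The expert counts below are assumed from the public DeepSeek-V2 configs.
import math

SLICE_MAX_EXPERT = 80  # default expert number per slice for dynamic MoE


def expert_slicing(n_routed_experts):
    # Number of slices needed so that no slice holds more than SLICE_MAX_EXPERT experts
    expert_slice = math.ceil(n_routed_experts / SLICE_MAX_EXPERT)
    # Number of routed experts handled per slice
    expert_chunk = n_routed_experts // expert_slice
    return expert_slice, expert_chunk


print(expert_slicing(64))   # DeepSeek-V2-Lite: (1, 64) -> one slice covers all routed experts
print(expert_slicing(160))  # DeepSeek-V2:      (2, 80) -> experts handled in two chunks of 80
```

Keeping each chunk at or below 80 experts bounds the per-step MoE work on Gaudi regardless of the model's total routed-expert count.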