diff --git a/examples/text-generation/README.md b/examples/text-generation/README.md
index 2844e8fb24..734ac3d2e3 100644
--- a/examples/text-generation/README.md
+++ b/examples/text-generation/README.md
@@ -251,7 +251,7 @@ You will also need to add `--torch_compile` in your command.
 
 ### Running with FP8
 
-Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, and Falcon-180B in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
+Llama2-70b, Llama2-7b, Llama3-70b, Llama3-8b, Mixtral-8x7B, Falcon-7B, Falcon-40B, Falcon-180B, and phi-2 in FP8 are enabled using the Quantization Toolkit (HQT), which provides model measurement and quantization capabilities in PyTorch.
 
 More information on enabling fp8 in SynapseAI is available here:
 https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_FP8.html
@@ -363,6 +363,36 @@ QUANT_CONFIG=./quantization_config/maxabs_quant.json python ../gaudi_spawn.py \
 --trim_logits \
 --fp8
 ```
+
+Here is an example to measure the tensor quantization statistics on phi-2 with 1 card:
+
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_measure.json python run_lm_eval.py \
+-o acc_phi-2_bs1_measure.txt \
+--model_name_or_path microsoft/phi-2 \
+--use_hpu_graphs \
+--use_kv_cache \
+--max_new_tokens 100 \
+--batch_size 1 \
+--trim_logits \
+--reuse_cache \
+--bf16
+```
+
+Here is an example to quantize the model based on previous measurements for phi-2 with 1 card:
+```bash
+QUANT_CONFIG=./quantization_config/maxabs_quant_phi.json python run_generation.py \
+--model_name_or_path microsoft/phi-2 \
+--use_hpu_graphs \
+--use_kv_cache \
+--max_new_tokens 100 \
+--batch_size 1 \
+--bf16 \
+--trim_logits \
+--reuse_cache \
+--fp8
+```
+
 `--fp8` is required to enable quantization in fp8.
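For readers who want to script this two-step flow, here is a minimal sketch (not part of the patch) that drives the same measure-then-quantize sequence from Python. It assumes it is launched from `examples/text-generation` on a Gaudi host where the scripts above work; the only mechanism it relies on is the one the README documents, namely passing the HQT JSON config through the `QUANT_CONFIG` environment variable. The `run_step` helper is purely illustrative.

```python
# Minimal sketch (not part of this PR): wrap the README's two commands so the
# measurement pass always runs before the FP8 quantization pass.
# Assumes the working directory is examples/text-generation on a Gaudi host.
import os
import subprocess

COMMON_ARGS = [
    "--model_name_or_path", "microsoft/phi-2",
    "--use_hpu_graphs",
    "--use_kv_cache",
    "--max_new_tokens", "100",
    "--batch_size", "1",
    "--trim_logits",
    "--reuse_cache",
    "--bf16",
]


def run_step(quant_config, script, extra_args):
    # QUANT_CONFIG tells HQT which JSON config (measurement or quantization) to apply.
    env = dict(os.environ, QUANT_CONFIG=quant_config)
    subprocess.run(["python", script, *COMMON_ARGS, *extra_args], check=True, env=env)


# Step 1: collect tensor statistics with the measurement config.
run_step("./quantization_config/maxabs_measure.json", "run_lm_eval.py",
         ["-o", "acc_phi-2_bs1_measure.txt"])
# Step 2: quantize based on those measurements; --fp8 switches execution to fp8.
run_step("./quantization_config/maxabs_quant_phi.json", "run_generation.py", ["--fp8"])
```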
diff --git a/examples/text-generation/quantization_config/maxabs_quant_phi.json b/examples/text-generation/quantization_config/maxabs_quant_phi.json
new file mode 100644
index 0000000000..8f13c2aa38
--- /dev/null
+++ b/examples/text-generation/quantization_config/maxabs_quant_phi.json
@@ -0,0 +1,14 @@
+{
+    "method": "HOOKS",
+    "mode": "QUANTIZE",
+    "observer": "maxabs",
+    "scale_method": "maxabs_hw",
+    "allowlist": {"types": [], "names": []},
+    "blocklist": {"types": [], "names": [
+        "matmul_qk",
+        "matmul_av",
+        "lm_head"
+    ]},
+    "dump_stats_path": "./hqt_output/measure",
+    "dump_stats_xlsx_path": "./hqt_output/measure/fp8stats.xlsx"
+}
diff --git a/examples/text-generation/run_lm_eval.py b/examples/text-generation/run_lm_eval.py
index 3ea74a6a69..8682a28d35 100644
--- a/examples/text-generation/run_lm_eval.py
+++ b/examples/text-generation/run_lm_eval.py
@@ -75,7 +75,7 @@ def __init__(self, tokenizer, model, args, options):
         self.options = options
         self._device = args.device
         self.model_inputs = {"use_cache": self.options.use_cache}
-        if self.model.config.model_type in ["llama", "mistral", "falcon"]:
+        if self.model.config.model_type in ["llama", "mistral", "falcon", "phi"]:
             self.model_inputs.update(
                 {
                     "reuse_cache": self.options.reuse_cache,
diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py
index 50f23e32ba..460d82effc 100755
--- a/optimum/habana/transformers/generation/utils.py
+++ b/optimum/habana/transformers/generation/utils.py
@@ -609,7 +609,8 @@ def generate(
                     "mistral",
                     "falcon",
                     "mixtral",
-                ], "reuse_cache only supported by llama, mistral, falcon and mixtral at the moment"
+                    "phi",
+                ], "reuse_cache only supported by llama, mistral, falcon, mixtral and phi at the moment"
                 if not generation_config.bucket_internal:
                     assert (
                         generation_config.bucket_size <= 0
diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py
index 5e1cbc290d..35ec6c0c1a 100644
--- a/optimum/habana/transformers/modeling_utils.py
+++ b/optimum/habana/transformers/modeling_utils.py
@@ -62,7 +62,10 @@
     GaudiOPTForCausalLM,
     GaudiOPTLearnedPositionalEmbedding,
     GaudiPersimmonForCausalLM,
+    GaudiPhiAttention,
+    GaudiPhiDecoderLayer,
     GaudiPhiForCausalLM,
+    GaudiPhiModel,
     GaudiQwen2DecoderLayer,
     GaudiQwen2ForCausalLM,
     GaudiStableLmForCausalLM,
@@ -132,9 +135,6 @@
     gaudi_persimmon_attention_forward,
     gaudi_persimmon_decoder_layer_forward,
     gaudi_persimmon_model_forward,
-    gaudi_phi_attention_forward,
-    gaudi_phi_decoder_layer_forward,
-    gaudi_phi_model_forward,
     gaudi_qwen2_attention_forward,
     gaudi_qwen2_model_forward,
     gaudi_rot_matmul,
@@ -366,9 +366,9 @@ def adapt_transformers_to_gaudi():
 
     # Optimization for phi on Gaudi
     transformers.models.phi.modeling_phi.PhiForCausalLM = GaudiPhiForCausalLM
-    transformers.models.phi.modeling_phi.PhiAttention.forward = gaudi_phi_attention_forward
-    transformers.models.phi.modeling_phi.PhiDecoderLayer.forward = gaudi_phi_decoder_layer_forward
-    transformers.models.phi.modeling_phi.PhiModel.forward = gaudi_phi_model_forward
+    transformers.models.phi.modeling_phi.PhiAttention = GaudiPhiAttention
+    transformers.models.phi.modeling_phi.PhiDecoderLayer = GaudiPhiDecoderLayer
+    transformers.models.phi.modeling_phi.PhiModel = GaudiPhiModel
 
     # Optimization for gemma on Gaudi
     transformers.models.gemma.modeling_gemma.GemmaForCausalLM = GaudiGemmaForCausalLM
diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py
index 351b482d5a..87dc38b1e5
100644 --- a/optimum/habana/transformers/models/__init__.py +++ b/optimum/habana/transformers/models/__init__.py @@ -134,10 +134,10 @@ gaudi_persimmon_model_forward, ) from .phi import ( + GaudiPhiAttention, + GaudiPhiDecoderLayer, GaudiPhiForCausalLM, - gaudi_phi_attention_forward, - gaudi_phi_decoder_layer_forward, - gaudi_phi_model_forward, + GaudiPhiModel, ) from .qwen2 import ( GaudiQwen2DecoderLayer, diff --git a/optimum/habana/transformers/models/phi/__init__.py b/optimum/habana/transformers/models/phi/__init__.py index 1a98f45f51..f7429a79df 100644 --- a/optimum/habana/transformers/models/phi/__init__.py +++ b/optimum/habana/transformers/models/phi/__init__.py @@ -1,6 +1,6 @@ from .modeling_phi import ( + GaudiPhiAttention, + GaudiPhiDecoderLayer, GaudiPhiForCausalLM, - gaudi_phi_attention_forward, - gaudi_phi_decoder_layer_forward, - gaudi_phi_model_forward, + GaudiPhiModel, ) diff --git a/optimum/habana/transformers/models/phi/modeling_phi.py b/optimum/habana/transformers/models/phi/modeling_phi.py index ff20dfdab6..872d1e7f4b 100644 --- a/optimum/habana/transformers/models/phi/modeling_phi.py +++ b/optimum/habana/transformers/models/phi/modeling_phi.py @@ -27,7 +27,14 @@ from torch.nn import CrossEntropyLoss from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.phi.modeling_phi import PhiForCausalLM, apply_rotary_pos_emb, repeat_kv +from transformers.models.phi.configuration_phi import PhiConfig +from transformers.models.phi.modeling_phi import ( + PhiAttention, + PhiDecoderLayer, + PhiForCausalLM, + PhiModel, + apply_rotary_pos_emb, +) from transformers.utils import logging from ...modeling_attn_mask_utils import ( @@ -38,299 +45,466 @@ logger = logging.get_logger(__name__) -def gaudi_phi_attention_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: +def gaudi_phi_repeat_kv( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + n_rep: int, +): """ - Copied from PhiAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py + Copied from repeat_kv: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/phi/modeling_phi.py The only differences are: - - add new args token_idx + - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. + - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. 
+ The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) + The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) """ - bsz, q_len, _ = hidden_states.size() + batch, num_key_value_heads, kv_len, head_dim = key_states.shape + if n_rep == 1 or num_key_value_heads == 1: + return query_states, key_states, value_states, attention_mask - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) + key_states = key_states.reshape(new_kv_shape) + value_states = value_states.reshape(new_kv_shape) - if self.qk_layernorm: - query_states = self.q_layernorm(query_states) - key_states = self.k_layernorm(key_states) + batch, _, q_len, head_dim = query_states.shape + new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) + query_states = query_states.reshape(new_q_shape) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + if attention_mask is not None: + # Add groups dim and set to 1 + attention_mask = attention_mask.unsqueeze(1) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - kv_seq_len = past_key_value.key_cache[self.layer_idx].shape[-2] + return query_states, key_states, value_states, attention_mask + + +class Matmul(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + +class KVCache(torch.nn.Module): + def __init__(self): + super(KVCache, self).__init__() + self.cache = None + self.inp_seq_len = -1 + + def allocate(self, inp_seq_len, dtype, device, shape): + if self.cache is None or self.cache.shape != shape: + self.inp_seq_len = inp_seq_len + self.cache = torch.zeros(shape, dtype=dtype, device=device) else: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - - # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_emb.dim], - query_states[..., self.rotary_emb.dim :], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_emb.dim], - key_states[..., self.rotary_emb.dim :], - ) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) - - # [batch_size, seq_length, num_heads, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) - - if past_key_value is not None: - if token_idx is not None: - if 0 <= self.layer_idx < len(past_key_value.key_cache): - past_key_value.key_cache[self.layer_idx].index_copy_(2, token_idx - 1, key_states) - past_key_value.value_cache[self.layer_idx].index_copy_(2, token_idx - 1, value_states) - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] - else: - past_key_value.key_cache.append(key_states) - past_key_value.value_cache.append(value_states) + assert ( + self.inp_seq_len == inp_seq_len + ), f"inp_seq_len must be the same. self.inp_seq_len:{self.inp_seq_len} inp_seq_len:{inp_seq_len}" + self.cache.fill_(0) + + def update(self, prev, cur, dim, idx, inp_seq_len): + orig_cur = cur + if prev.shape == cur.shape: + prev.copy_(cur) + return orig_cur + if cur.shape[2] > 1 and cur.shape[2] <= prev.shape[2]: + # Initialize + prev[:, :, :inp_seq_len, :].copy_(cur) + return orig_cur + assert cur.shape[2] == 1, f"Cannot update kv-cache. Unsupported shapes. 
prev:{prev.shape} cur:{cur.shape}" + if idx is not None: + prev.index_copy_(dim, idx - 1, cur) + return prev else: - cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + return torch.cat((prev, cur), dim=dim) + + def get_shape(self): + if self.cache is None: + return None + return self.cache.shape + + def forward(self, cur, dim, idx): + return self.update(self.cache, cur, dim, idx, self.inp_seq_len) + + +class GaudiPhiAttention(PhiAttention): + def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.matmul_qk = Matmul() + self.matmul_av = Matmul() + self.k_cache = KVCache() + self.v_cache = KVCache() + self.inp_seq_len = -1 + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) + device = self.k_proj.weight.device + dtype = self.config.torch_dtype + self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) + self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Copied from PhiAttention.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py + The only differences are: + - add new args token_idx + - optimize KV cache + - add new args reuse_cache + - add new args cache_idx + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_shape = ( + (past_key_value[0][-2] if reuse_cache else past_key_value[0].shape[-2]) + if isinstance(past_key_value, tuple) + else past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + ) + if token_idx is not None: + kv_seq_len = kv_shape + else: + kv_seq_len += kv_shape - # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow - attn_weights = torch.matmul( - query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) - ) / math.sqrt(self.head_dim) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_emb.dim], + query_states[..., self.rotary_emb.dim :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_emb.dim], + key_states[..., self.rotary_emb.dim :], ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if use_cache: + # reuse k, v, self_attention + if reuse_cache: + key_states = self.k_cache(key_states, 2, token_idx) + value_states = self.v_cache(value_states, 2, token_idx) + past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) + else: + if past_key_value is None: + past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) + past_value = torch.zeros( + key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device + ) + past_key_value = (past_key, past_value) + key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) + value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) + if token_idx is None: + past_key_value = (key_states, value_states) + + if cache_idx is not None and q_len == 1: + key_states = key_states[:, :, :cache_idx, :] + value_states = value_states[:, :, :cache_idx, :] + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, :cache_idx] + kv_seq_len = key_states.shape[-2] + else: + past_key_value = None + + query_states, key_states, value_states, attention_mask = gaudi_phi_repeat_kv( + query_states, key_states, value_states, attention_mask, self.num_key_value_groups + ) + + # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow + attn_weights = self.matmul_qk( + query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) + ) / math.sqrt(self.head_dim) + + if attn_weights.size() not in [ + (bsz, self.num_heads, q_len, kv_seq_len), + (bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len), + ]: raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)} or" + f" {(bsz, self.num_key_value_heads, self.num_key_value_groups, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" ) - attn_weights = attn_weights + attention_mask - # upcast attention to fp32 - attn_weights = 
nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + if attention_mask is not None: + if attention_mask.size() not in [(bsz, 1, q_len, kv_seq_len), (bsz, 1, 1, q_len, kv_seq_len)]: + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)} or {(bsz, 1, 1, q_len, kv_seq_len)}," + f" but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask - attn_output = torch.matmul(attn_weights, value_states) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + attn_output = self.matmul_av(attn_weights, value_states) + attn_output = attn_output.reshape(bsz, -1, q_len, self.head_dim) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) - attn_output = self.dense(attn_output) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if not output_attentions: - attn_weights = None + attn_output = self.dense(attn_output) - return attn_output, attn_weights, past_key_value + if not output_attentions: + attn_weights = None + return attn_output, attn_weights, past_key_value -def gaudi_phi_decoder_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - token_idx: Optional[torch.Tensor] = None, - **kwargs, -) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Copied from PhiDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py - The only differences are: - - add new args token_idx - """ - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - attn_outputs, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, - ) - attn_outputs = self.resid_dropout(attn_outputs) - - feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) - hidden_states = attn_outputs + feed_forward_hidden_states + residual - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -def gaudi_phi_model_forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: 
Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - token_idx: Optional[torch.Tensor] = None, -) -> Union[Tuple, BaseModelOutputWithPast]: - """ - Copied from PhiModel.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py - The only differences are: - - add new args token_idx - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False +class GaudiPhiDecoderLayer(PhiDecoderLayer): + def __init__(self, config: PhiConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.self_attn = GaudiPhiAttention(config, layer_idx) + + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.self_attn.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Copied from PhiDecoderLayer.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py + The only differences are: + - add new args token_idx + - add new args reuse_cache + - add new args cache_idx + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) - past_key_values_length = 0 - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) - - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + 
output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) - position_ids = position_ids.unsqueeze(0) + attn_outputs = self.resid_dropout(attn_outputs) - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + outputs = (hidden_states,) - inputs_embeds = self.embed_dropout(inputs_embeds) + if output_attentions: + outputs += (self_attn_weights,) - # 4d mask is passed through the layers - attention_mask = _gaudi_prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + if use_cache: + outputs += (present_key_value,) - hidden_states = inputs_embeds + return outputs - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = None - for decoder_layer in self.layers: - if output_hidden_states: - all_hidden_states += (hidden_states,) +class GaudiPhiModel(PhiModel): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + for layer in self.layers: + layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, - hidden_states, - attention_mask, - position_ids, - past_key_values, - output_attentions, - ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + cache_idx: Optional[int] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + """ + Copied from PhiModel.forward: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py + The only differences are: + - add new args token_idx + - add new args reuse_cache + - add new args cache_idx + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - token_idx=token_idx, + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + 
logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) + use_cache = False - hidden_states = layer_outputs[0] + use_legacy_cache = True + use_new_cache = False + past_seen_tokens = 0 + if past_key_values is not None and use_cache: + if reuse_cache: + past_seen_tokens = past_key_values[0][0][2] + else: + if use_new_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + else: + past_seen_tokens = past_key_values[0][0].shape[2] - if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_seen_tokens, seq_length + past_seen_tokens, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) - if output_attentions: - all_self_attns += (layer_outputs[1],) + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = self.embed_dropout(inputs_embeds) + + # 4d mask is passed through the layers + attention_mask = _gaudi_prepare_4d_causal_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens + ) - hidden_states = self.final_layernorm(hidden_states) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if not use_new_cache else None + + for layer_idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + None if past_key_values is None else past_key_values[layer_idx], + output_attentions, + use_cache, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=None if past_key_values is None else past_key_values[layer_idx], + output_attentions=output_attentions, + use_cache=use_cache, + token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, + ) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) + hidden_states = layer_outputs[0] - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.final_layernorm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + + if not return_dict: + return tuple(v for v in [hidden_states, 
next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) class GaudiPhiForCausalLM(PhiForCausalLM): + def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): + self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) + def forward( self, input_ids: torch.LongTensor = None, @@ -344,11 +518,16 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, token_idx: Optional[torch.Tensor] = None, + reuse_cache: Optional[bool] = False, + trim_logits: Optional[bool] = False, + cache_idx: Optional[int] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: """ Inherits from PhiForCausalLM: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py The only differences are: - add new args token_idx + - add new args reuse_cache + - add new args cache_idx """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -369,9 +548,17 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, token_idx=token_idx, + reuse_cache=reuse_cache, + cache_idx=cache_idx, ) hidden_states = outputs[0] + _, seq_len, _ = hidden_states.shape + if seq_len > 1 and trim_logits and not self.training: + if token_idx is not None: + hidden_states = hidden_states.index_select(1, token_idx - 1) + else: + hidden_states = hidden_states[:, -1, :] logits = self.lm_head(hidden_states) logits = logits.float() @@ -401,7 +588,7 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs ): """ Inherits from PhiForCausalLM: https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py @@ -411,11 +598,13 @@ def prepare_inputs_for_generation( - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx """ - token_idx = kwargs.get("token_idx", None) - + past_length = 0 + reuse_cache = kwargs.get("reuse_cache") # Omit tokens covered by past_key_values if past_key_values is not None: - if token_idx is None: + if token_idx is not None: + input_ids = torch.index_select(input_ids, 1, token_idx - 1) + else: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() past_length = past_key_values.seen_tokens @@ -443,8 +632,10 @@ def prepare_inputs_for_generation( and cache_length + input_ids.shape[1] > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] - else: - input_ids = torch.index_select(input_ids, 1, token_idx - 1) + elif reuse_cache and token_idx is not None: + # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass + input_ids = input_ids[:, :token_idx] + attention_mask = attention_mask[:, :token_idx] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -470,6 +661,9 @@ def prepare_inputs_for_generation( "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "token_idx": token_idx, + "reuse_cache": kwargs.get("reuse_cache"), + "trim_logits": 
kwargs.get("trim_logits"), + "cache_idx": kwargs.get("cache_idx"), } ) return model_inputs diff --git a/tests/test_text_generation_example.py b/tests/test_text_generation_example.py index 5956c3f8bf..88dec45ab1 100644 --- a/tests/test_text_generation_example.py +++ b/tests/test_text_generation_example.py @@ -41,6 +41,7 @@ ("mistralai/Mixtral-8x7B-v0.1", 39.26845661768185), ("meta-llama/Llama-2-7b-hf", 0.0), ("meta-llama/Llama-2-70b-hf", 0.0), + ("microsoft/phi-2", 254.08932787178165), ], "deepspeed": [ ("bigscience/bloomz", 36.77314954096159),