diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index d58e1a1b2..ad446d1a2 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -13,7 +13,7 @@ import QEfficient from QEfficient.cloud.export import get_onnx_model_path from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv -from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists +from QEfficient.utils import check_and_assign_cache_dir, get_embeddings, get_qpc_dir_path, load_hf_tokenizer, qpc_exists from QEfficient.utils.logging_utils import logger @@ -72,6 +72,7 @@ def main( cache_dir=cache_dir, hf_token=hf_token, ) + embeds, config = get_embeddings(model_name, hf_token, cache_dir, local_model_dir) qpc_dir_path = get_qpc_dir_path( model_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size @@ -111,6 +112,8 @@ def main( ######### cloud_ai_100_exec_kv( tokenizer=tokenizer, + config=config, + embeddings=embeds, qpc_path=qpc_dir_path, device_id=device_group, prompt=prompt, diff --git a/QEfficient/exporter/export_hf_to_cloud_ai_100.py b/QEfficient/exporter/export_hf_to_cloud_ai_100.py index 55f2ac3be..0995f3bcb 100644 --- a/QEfficient/exporter/export_hf_to_cloud_ai_100.py +++ b/QEfficient/exporter/export_hf_to_cloud_ai_100.py @@ -199,18 +199,23 @@ def export_kvstyle_transformed_model_to_onnx( raise ValueError(f"Need seq_len to be greater than zero, got seq_len={seq_len}") # Preprocess inputs + embeds = None + if model_name == "CohereForAI/c4ai-command-r-v01": + embeds = transformed_model.get_input_embeddings() + # inputs['inputs_embeds']=embeds(inputs.pop('input_ids')) # Build inputs for prefill input_handler = InputHandler( batch_size=len(Constants.INPUT_STR), tokenizer=tokenizer, + embeddings=embeds, config=transformed_model.config, prompt=Constants.INPUT_STR, prompt_len=Constants.PROMPT_LEN, ctx_len=seq_len, full_batch_size=full_batch_size, ) - inputs = 
input_handler.prepare_pytorch_inputs() + pt_outputs = transformed_model(**inputs) output_names = list(pt_outputs.keys()) diff --git a/QEfficient/exporter/export_utils.py b/QEfficient/exporter/export_utils.py index d7da3ae04..9af404ca0 100644 --- a/QEfficient/exporter/export_utils.py +++ b/QEfficient/exporter/export_utils.py @@ -61,6 +61,7 @@ def export_onnx( # Create dynamic axes dict for inputs that need to have dynamic input shapes seq_len_inputs = { "input_ids", + "inputs_embeds", "attention_mask", "position_ids", "token_type_ids", diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 0ddb0acc9..0a3614128 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -13,8 +13,9 @@ from typing import Dict, List, Optional, Tuple, Union import numpy as np +import torch import transformers -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import padding_check_and_fix @@ -221,6 +222,8 @@ def print_latency_stats_kv(prompt, exec_info, automation: bool = False): def cloud_ai_100_exec_kv( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + config: AutoConfig, + embeddings: torch.Tensor, qpc_path: str, prompt: Optional[str] = None, prompts_txt_file_path: Optional[str] = None, @@ -269,6 +272,8 @@ def cloud_ai_100_exec_kv( generate_text = TextGeneration( tokenizer=tokenizer, prompt=prompt, + embeddings=embeddings, + config=config, qpc_path=qpc_path, device_id=device_id, ctx_len=ctx_len, @@ -310,6 +315,8 @@ class TextGeneration: def __init__( self, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + embeddings: torch.Tensor, + config: AutoConfig, qpc_path: str, prompt: List[str], full_batch_size: Optional[int] = None, @@ -321,6 
+328,8 @@ def __init__( write_io_dir: Optional[str] = None, ) -> None: self.tokenizer = tokenizer + self.embeddings = embeddings + self.config = config self.prompt = prompt self.qpc_path = qpc_path self.device_id = device_id @@ -404,12 +413,20 @@ def _fetch_batch_size_prefill_seq_len( prefill_seq_len: The prefill sequence length fetched from the session's bindings or allowed shapes. """ if self.session.allowed_shapes: - batch_size = max( - [x[self.session.binding_index_map["input_ids"]][1][0] for x in self.session.allowed_shapes] - ) - prefill_seq_len = max( - [x[self.session.binding_index_map["input_ids"]][1][1] for x in self.session.allowed_shapes] - ) + if "input_ids" in self.session.binding_index_map: + batch_size = max( + [x[self.session.binding_index_map["input_ids"]][1][0] for x in self.session.allowed_shapes] + ) + prefill_seq_len = max( + [x[self.session.binding_index_map["input_ids"]][1][1] for x in self.session.allowed_shapes] + ) + else: + batch_size = max( + [x[self.session.binding_index_map["inputs_embeds"]][1][0] for x in self.session.allowed_shapes] + ) + prefill_seq_len = max( + [x[self.session.binding_index_map["inputs_embeds"]][1][1] for x in self.session.allowed_shapes] + ) else: batch_size, prefill_seq_len = self.session.bindings[self.session.binding_index_map["input_ids"]].dims return batch_size, prefill_seq_len @@ -460,7 +477,8 @@ def prepare_decode_inputs(self): decode_inputs["position_ids"] = self.decode_pos_ids if self.batch_index is not None: decode_inputs["batch_index"] = self.batch_index - + if self.config.architectures[0] == "CohereForCausalLM": + decode_inputs["inputs_embeds"] = self.embeddings(torch.tensor(decode_inputs["input_ids"])).detach().numpy() return decode_inputs def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None): @@ -557,6 +575,10 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i chunk_inputs["position_ids"] = inputs["position_ids"][ :, i * 
self.prefill_seq_len : (i + 1) * self.prefill_seq_len ] + if self.config.architectures[0] == "CohereForCausalLM": + chunk_inputs["inputs_embeds"] = ( + self.embeddings(torch.tensor(chunk_inputs.pop("input_ids"))).detach().numpy() + ) outputs = self.session.run(chunk_inputs) if self.write_io_dir is not None: write_io_files(inputs, outputs, self.write_io_dir, "prefill", "aic_batch_io", True, False) @@ -656,6 +678,10 @@ def run_decode(self, decode_inputs, generation_len): for num_token in range(1, generation_len): if self.stream: self.streamer.put(decode_inputs["input_ids"][0]) + if self.config.architectures[0] == "CohereForCausalLM": + decode_inputs["inputs_embeds"] = ( + self.embeddings(torch.tensor(decode_inputs.pop("input_ids"))).detach().numpy() + ) outputs = self.session.run(decode_inputs) if self.write_io_dir is not None: diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 91c886c5f..0354b27ec 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -15,6 +15,13 @@ CodeGenForCausalLM, CodeGenModel, ) +from transformers.models.cohere.modeling_cohere import ( + CohereAttention, + CohereDecoderLayer, + CohereForCausalLM, + CohereModel, + CohereRotaryEmbedding, +) from transformers.models.falcon.modeling_falcon import ( FalconAttention, FalconForCausalLM, @@ -83,6 +90,13 @@ QEffCodeGenForCausalLM, QEffCodeGenModel, ) +from .models.cohere.modeling_cohere import ( + QEffCohereAttention, + QEffCohereDecoderLayer, + QEffCohereForCausalLM, + QEffCohereModel, + QEffCohereRotaryEmbedding, +) from .models.falcon.modeling_falcon import ( QEffFalconAttention, QEffFalconForCausalLM, @@ -154,6 +168,7 @@ MptForCausalLM.__name__, FalconForCausalLM.__name__, GPTBigCodeForCausalLM.__name__, + CohereForCausalLM.__name__, ] ) # Create an instance of the named tuple @@ -174,6 +189,7 @@ Qwen2ForCausalLM.__name__, Starcoder2ForCausalLM.__name__, GPTBigCodeForCausalLM.__name__, + 
CohereForCausalLM.__name__, ] ) @@ -217,6 +233,12 @@ CodeGenModel: QEffCodeGenModel, CodeGenForCausalLM: QEffCodeGenForCausalLM, CodeGenBlock: QeffCodeGenBlock, + # Cohere + CohereForCausalLM: QEffCohereForCausalLM, + CohereAttention: QEffCohereAttention, + CohereModel: QEffCohereModel, + CohereRotaryEmbedding: QEffCohereRotaryEmbedding, + CohereDecoderLayer: QEffCohereDecoderLayer, # Mistral model layers MistralAttention: QEffMistralAttention, MistralDecoderLayer: QEffMistralDecoderLayer, diff --git a/QEfficient/transformers/models/cohere/__init__.py b/QEfficient/transformers/models/cohere/__init__.py new file mode 100644 index 000000000..da26921c5 --- /dev/null +++ b/QEfficient/transformers/models/cohere/__init__.py @@ -0,0 +1,7 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + diff --git a/QEfficient/transformers/models/cohere/modeling_cohere.py b/QEfficient/transformers/models/cohere/modeling_cohere.py new file mode 100644 index 000000000..e2b2ced39 --- /dev/null +++ b/QEfficient/transformers/models/cohere/modeling_cohere.py @@ -0,0 +1,571 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# + + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from transformers.models.cohere.modeling_cohere import ( + CohereAttention, + CohereConfig, + CohereDecoderLayer, + CohereForCausalLM, + CohereModel, + CohereRotaryEmbedding, + logger, + repeat_kv, + rotate_half, +) + +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask + + +class QEffCohereRotaryEmbedding(CohereRotaryEmbedding): + def __init__( + self, + dim=None, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + rope_type="default", + config: Optional[CohereConfig] = None, + ): + super(CohereRotaryEmbedding, self).__init__() + # TODO (joao): remove the `if` below, only used for BC + self.rope_kwargs = {} + if config is None: + logger.warning_once( + "`CohereRotaryEmbedding` can now be fully parameterized by passing the model config through the " + "`config` argument. 
All other arguments will be removed in v4.46" + ) + self.rope_kwargs = { + "rope_type": rope_type, + "factor": scaling_factor, + "dim": dim, + "base": base, + "max_position_embeddings": max_position_embeddings, + } + self.rope_type = rope_type + self.max_seq_len_cached = max_position_embeddings + self.original_max_seq_len = max_position_embeddings + else: + # BC: "rope_type" was originally "type" + if config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.repeat_interleave(freqs, 2, dim=1) # This line differs from Llama's implementation + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + @torch.no_grad() + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling, + ) + + +def 
qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class QEffCohereAttention(CohereAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None): + super().__init__(config, layer_idx) + self.config = config + self.__qeff_init__() + + def __qeff_init__(self): + self.config.rope_scaling = None + self.rotary_emb = QEffCohereRotaryEmbedding(config=self.config) + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + if self.use_qk_norm: + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + kv_seq_len = key_states.shape[-2] + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: + kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = 
self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = qeff_apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids=position_ids + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "position_ids": position_ids, "batch_index": batch_index} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class QEffCohereDecoderLayer(CohereDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: 
Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. 
+ """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states_attention, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + # Fully Connected + hidden_states_mlp = self.mlp(hidden_states) + + # Add everything together + hidden_states = residual + hidden_states_attention + hidden_states_mlp + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffCohereModel(CohereModel): + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same 
time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + past_seen_tokens = 0 + return_legacy_cache = False + if ( + use_cache and not isinstance(past_key_values, Cache) and not self.training + ): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_value=past_key_values, + 
output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. 
This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). 
+ if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + else: + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=target_length) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class QEffCohereForCausalLM(CohereForCausalLM): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: CohereConfig): + super().__init__() + self.__qeff_init__() + + def __qeff_init__(self): + lm_head_weights = self.lm_head.weight.data.split(64000) + self.lm_heads = torch.nn.ModuleList() + for i in range(4): + lm_head_i = torch.nn.Linear(8192, 64000, bias=False) # hiddensize-8192 + lm_head_i.weight.data = lm_head_weights[i] + self.lm_heads.append(lm_head_i) + return + + # Ignore copy + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + batch_index=batch_index, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + 
output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_idx] + logits = torch.cat([head_i(hidden_states) for head_i in self.lm_heads], 2) + logits = logits * self.logit_scale + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/QEfficient/transformers/pytorch_transforms.py b/QEfficient/transformers/pytorch_transforms.py index 8ce30f61f..3a4c7c996 100644 --- a/QEfficient/transformers/pytorch_transforms.py +++ b/QEfficient/transformers/pytorch_transforms.py @@ -15,6 +15,13 @@ CodeGenForCausalLM, CodeGenModel, ) +from transformers.models.cohere.modeling_cohere import ( + CohereAttention, + CohereDecoderLayer, + CohereForCausalLM, + CohereModel, + CohereRotaryEmbedding, +) from transformers.models.falcon.modeling_falcon import ( FalconAttention, FalconDecoderLayer, @@ -97,6 +104,13 @@ QEffCodeGenForCausalLM, QEffCodeGenModel, ) +from QEfficient.transformers.models.cohere.modeling_cohere import ( + QEffCohereAttention, + QEffCohereDecoderLayer, + QEffCohereForCausalLM, + QEffCohereModel, + QEffCohereRotaryEmbedding, +) from 
def get_embeddings(
    model_name: str,
    hf_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    local_model_dir: Optional[str] = None,
):
    """Load a model through ``QEFFCommonLoader`` and return its input embeddings and config.

    Args:
        model_name: Model identifier, used when ``local_model_dir`` is empty or ``None``.
        hf_token: Passed to the loader as ``token``.
        cache_dir: Passed to the loader as ``cache_dir``.
        local_model_dir: When truthy, used instead of ``model_name`` as the load source.

    Returns:
        A 2-tuple: the result of ``get_input_embeddings()`` on the loaded
        wrapper's ``model`` attribute, and that same object's ``config``.
    """
    # Imported here rather than at module level, as in the original.
    from QEfficient.base.common import QEFFCommonLoader

    load_source = local_model_dir or model_name
    qeff_wrapper = QEFFCommonLoader.from_pretrained(
        pretrained_model_name_or_path=load_source,
        token=hf_token,
        cache_dir=cache_dir,
    )
    inner_model = qeff_wrapper.model
    input_embeddings = inner_model.get_input_embeddings()
    return input_embeddings, inner_model.config
b/QEfficient/utils/generate_inputs.py @@ -12,7 +12,7 @@ class InputHandler: - def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size): + def __init__(self, batch_size, tokenizer, embeddings, config, prompt, prompt_len, ctx_len, full_batch_size): """ Initialization @@ -28,10 +28,13 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f # check and fix tokenizer viability padding_check_and_fix(tokenizer) self.tokenizer = tokenizer + self.embeddings = embeddings + self.config = config self.prompt = prompt self.prompt_len = prompt_len self.ctx_len = ctx_len self.full_batch_size = full_batch_size + # breakpoint() self.n_layer = get_num_layers_from_config(config) self.padding_shape = get_padding_shape_from_config( config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len @@ -51,6 +54,7 @@ def prepare_pytorch_inputs(self): padding=True, ) input_ids = inputs["input_ids"] + batch_size, input_len = input_ids.shape inputs.pop("attention_mask") inputs.pop("token_type_ids", None) @@ -75,7 +79,8 @@ def prepare_pytorch_inputs(self): inputs["input_ids"] = input_ids inputs["position_ids"] = torch.arange(input_len).view(1, input_len) inputs["batch_index"] = torch.arange(1).view(-1, 1) - + if self.config.architectures[0] == "CohereForCausalLM": + inputs["inputs_embeds"] = self.embeddings(inputs.pop("input_ids")) past_key_values = [] for i in range(self.n_layer): past_key = torch.zeros((self.padding_shape), dtype=torch.float32) @@ -114,7 +119,8 @@ def update_pytorch_inputs(self, inputs, pt_outputs): else: updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1) updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1 - + if self.config.architectures[0] == "CohereForCausalLM": + updated_inputs["inputs_embeds"] = self.embeddings(updated_inputs.pop("input_ids")) updated_inputs["past_key_values"] = tuple( [(key.detach(), 
value.detach()) for key, value in pt_outputs["past_key_values"]] ) @@ -147,7 +153,9 @@ def prepare_ort_inputs(self): [position_ids, np.full((batch_size, self.prompt_len - input_len), -1)], axis=1, ).astype(np.int64) - + if self.config.architectures[0] == "CohereForCausalLM": + # breakpoint() + inputs["inputs_embeds"] = self.embeddings(torch.tensor(inputs.pop("input_ids"))).numpy() for i in range(self.n_layer): inputs["past_key." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32) inputs["past_value." + str(i)] = np.zeros((self.padding_shape), dtype=np.float32) @@ -169,6 +177,8 @@ def update_ort_inputs(self, inputs, ort_outputs): updated_inputs = {} updated_inputs["input_ids"] = ort_outputs["logits"].argmax(-1) updated_inputs["position_ids"] = np.max(inputs["position_ids"], axis=1, keepdims=True) + 1 + if self.config.architectures[0] == "CohereForCausalLM": + updated_inputs["inputs_embeds"] = self.embeddings(torch.tensor(updated_inputs.pop("input_ids"))).numpy() for i in range(self.n_layer): updated_inputs["past_key." + str(i)] = ort_outputs["past_key_values"][i * 2] updated_inputs["past_value." + str(i)] = ort_outputs["past_key_values"][i * 2 + 1] diff --git a/QEfficient/utils/run_utils.py b/QEfficient/utils/run_utils.py index 6f3e6035b..33740c0cf 100644 --- a/QEfficient/utils/run_utils.py +++ b/QEfficient/utils/run_utils.py @@ -28,7 +28,7 @@ class ApiRunner: 4. 
``ONNX`` model on Cloud AI 100 """ - def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size=None): + def __init__(self, batch_size, tokenizer, embeddings, config, prompt, prompt_len, ctx_len, full_batch_size=None): """ Initialization @@ -43,6 +43,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f self.input_handler = InputHandler( batch_size=batch_size, tokenizer=tokenizer, + embeddings=embeddings, config=config, prompt=prompt, prompt_len=prompt_len, @@ -227,6 +228,8 @@ def run_kv_model_on_cloud_ai_100(self, qpc_path, device_group=None): """ execinfo = TextGeneration( tokenizer=self.input_handler.tokenizer, + embeddings=self.input_handler.embeddings, + config=self.input_handler.config, prompt=self.input_handler.prompt, qpc_path=qpc_path, device_id=device_group, diff --git a/tests/transformers/models/test_causal_lm_models.py b/tests/transformers/models/test_causal_lm_models.py index e8204a38b..25bdf875c 100644 --- a/tests/transformers/models/test_causal_lm_models.py +++ b/tests/transformers/models/test_causal_lm_models.py @@ -38,6 +38,7 @@ "TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model "TheBloke/Llama-2-7B-GPTQ", # GPTQ model "ibm-granite/granite-20b-code-base", + "CohereForAI/c4ai-command-r-v01", ] @@ -90,13 +91,14 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): model_config["n_layer"] = n_layer model_hf, _ = load_causal_lm_model(model_config) - + embeds = model_hf.get_input_embeddings() if model_name == "CohereForAI/c4ai-command-r-v01" else None tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) config = model_hf.config batch_size = len(Constants.INPUT_STR) api_runner = ApiRunner( batch_size, tokenizer, + embeds, config, Constants.INPUT_STR, Constants.PROMPT_LEN, @@ -142,6 +144,7 @@ def test_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): api_runner = ApiRunner( batch_size, tokenizer, + embeds, config, fbs_prompts, Constants.PROMPT_LEN,