Commit

Formatting done
Signed-off-by: Ann <[email protected]>
quic-akuruvil committed Nov 13, 2024
1 parent d5c622b commit 5eaf4bb
Showing 10 changed files with 91 additions and 98 deletions.
6 changes: 3 additions & 3 deletions QEfficient/cloud/infer.py
@@ -13,7 +13,7 @@
import QEfficient
from QEfficient.cloud.export import get_onnx_model_path
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists, get_embeddings
from QEfficient.utils import check_and_assign_cache_dir, get_embeddings, get_qpc_dir_path, load_hf_tokenizer, qpc_exists
from QEfficient.utils.logging_utils import logger


@@ -72,12 +72,12 @@ def main(
cache_dir=cache_dir,
hf_token=hf_token,
)
embeds,config = get_embeddings(model_name, hf_token,cache_dir,local_model_dir)
embeds, config = get_embeddings(model_name, hf_token, cache_dir, local_model_dir)

qpc_dir_path = get_qpc_dir_path(
model_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size
)

# Handle qpc generation
if qpc_exists(qpc_dir_path):
logger.info(f"Pre-compiled qpc found at {qpc_dir_path}! Executing with given prompt")
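
Note: get_embeddings is only invoked at this one site; judging from how its two return values are used later in this diff (config.architectures[0] checks and embedding lookups on token ids in text_generation_inference.py), its contract is presumably "return the HF model's input-embedding module plus its config". A minimal sketch of that presumed contract follows — illustrative only; the real helper lives in QEfficient.utils and everything beyond the call shown above is an assumption:

# Hypothetical sketch, not the QEfficient implementation.
from transformers import AutoConfig, AutoModelForCausalLM

def get_embeddings_sketch(model_name, hf_token=None, cache_dir=None, local_model_dir=None):
    source = local_model_dir if local_model_dir else model_name
    config = AutoConfig.from_pretrained(source, token=hf_token, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(source, token=hf_token, cache_dir=cache_dir)
    # Return the nn.Embedding used for input ids, plus the model config.
    return model.get_input_embeddings(), config
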
6 changes: 3 additions & 3 deletions QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -199,7 +199,7 @@ def export_kvstyle_transformed_model_to_onnx(
raise ValueError(f"Need seq_len to be greater than zero, got seq_len={seq_len}")

# Preprocess inputs
embeds=None
embeds = None
if model_name == "CohereForAI/c4ai-command-r-v01":
embeds = transformed_model.get_input_embeddings()
# inputs['inputs_embeds']=embeds(inputs.pop('input_ids'))
@@ -215,7 +215,7 @@ def export_kvstyle_transformed_model_to_onnx(
full_batch_size=full_batch_size,
)
inputs = input_handler.prepare_pytorch_inputs()

pt_outputs = transformed_model(**inputs)
output_names = list(pt_outputs.keys())

@@ -265,7 +265,7 @@ def export_kvstyle_transformed_model_to_onnx(
for i, (key, value) in enumerate(pkv):
inputs[f"past_key.{i}"] = key
inputs[f"past_value.{i}"] = value

# Run onnxrt inference
input_names, ort_outputs = run_model_on_ort(
onnx_path=os.path.join(onnx_dir_path, f"{model_name}.onnx"),
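
Note: the Cohere-specific branch above fetches the input-embedding module, and the commented-out line shows the intended substitution of input_ids with host-computed inputs_embeds before ONNX export. Spelled out as a hedged sketch (assuming inputs is the dict returned by prepare_pytorch_inputs()):

# Illustrative only: embed token ids on the host so the exported graph
# consumes inputs_embeds instead of input_ids.
if model_name == "CohereForAI/c4ai-command-r-v01":
    embeds = transformed_model.get_input_embeddings()
    inputs["inputs_embeds"] = embeds(inputs.pop("input_ids")).detach()
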
19 changes: 11 additions & 8 deletions QEfficient/generation/text_generation_inference.py
@@ -15,8 +15,7 @@
import numpy as np
import torch
import transformers
from transformers import AutoConfig
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import padding_check_and_fix
@@ -478,8 +477,8 @@ def prepare_decode_inputs(self):
decode_inputs["position_ids"] = self.decode_pos_ids
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index
if self.config.architectures[0] == 'CohereForCausalLM':
decode_inputs['inputs_embeds'] = self.embeddings(torch.tensor(decode_inputs['input_ids'])).detach().numpy()
if self.config.architectures[0] == "CohereForCausalLM":
decode_inputs["inputs_embeds"] = self.embeddings(torch.tensor(decode_inputs["input_ids"])).detach().numpy()
return decode_inputs

def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None):
@@ -576,8 +575,10 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
chunk_inputs["position_ids"] = inputs["position_ids"][
:, i * self.prefill_seq_len : (i + 1) * self.prefill_seq_len
]
if self.config.architectures[0] == 'CohereForCausalLM':
chunk_inputs['inputs_embeds'] = self.embeddings(torch.tensor(chunk_inputs.pop('input_ids'))).detach().numpy()
if self.config.architectures[0] == "CohereForCausalLM":
chunk_inputs["inputs_embeds"] = (
self.embeddings(torch.tensor(chunk_inputs.pop("input_ids"))).detach().numpy()
)
outputs = self.session.run(chunk_inputs)
if self.write_io_dir is not None:
write_io_files(inputs, outputs, self.write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -677,8 +678,10 @@ def run_decode(self, decode_inputs, generation_len):
for num_token in range(1, generation_len):
if self.stream:
self.streamer.put(decode_inputs["input_ids"][0])
if self.config.architectures[0] == 'CohereForCausalLM':
decode_inputs['inputs_embeds'] = self.embeddings(torch.tensor(decode_inputs.pop("input_ids"))).detach().numpy()
if self.config.architectures[0] == "CohereForCausalLM":
decode_inputs["inputs_embeds"] = (
self.embeddings(torch.tensor(decode_inputs.pop("input_ids"))).detach().numpy()
)
outputs = self.session.run(decode_inputs)

if self.write_io_dir is not None:
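
Note: the three changed call sites in this file follow one pattern — when config.architectures[0] is "CohereForCausalLM", token ids are embedded on the host immediately before the AIC session call, because the compiled QPC consumes inputs_embeds rather than input_ids. Condensed into one hedged helper for clarity (illustrative; the class keeps this logic inline):

import torch

def _embed_if_cohere(self, model_inputs):
    # Mirrors the inline pattern in run_prefill and run_decode; note that
    # prepare_decode_inputs keeps input_ids in place instead of popping it.
    if self.config.architectures[0] == "CohereForCausalLM":
        ids = torch.tensor(model_inputs.pop("input_ids"))
        model_inputs["inputs_embeds"] = self.embeddings(ids).detach().numpy()
    return model_inputs
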
21 changes: 14 additions & 7 deletions QEfficient/transformers/modeling_utils.py
@@ -15,6 +15,13 @@
CodeGenForCausalLM,
CodeGenModel,
)
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
CohereForCausalLM,
CohereModel,
CohereRotaryEmbedding,
)
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconForCausalLM,
@@ -41,13 +48,6 @@
GPTBigCodeForCausalLM,
GPTBigCodeModel,
)
from transformers.models.cohere.modeling_cohere import (
CohereRotaryEmbedding,
CohereAttention,
CohereModel,
CohereForCausalLM,
CohereDecoderLayer,
)
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
@@ -90,6 +90,13 @@
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
from .models.cohere.modeling_cohere import (
QEffCohereAttention,
QEffCohereDecoderLayer,
QEffCohereForCausalLM,
QEffCohereModel,
QEffCohereRotaryEmbedding,
)
from .models.falcon.modeling_falcon import (
QEffFalconAttention,
QEffFalconForCausalLM,
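
Note: this file imports both the upstream Cohere classes and their QEff counterparts; presumably the pairs are registered in the module's transformers-to-QEff mapping so the transform machinery can swap them at load time. Rough shape of that registration (the dict name and exact entries are assumed, not shown in this diff):

# Assumed registration -- illustrative only; not visible in this diff.
TransformersToQEffModulesDict.update(
    {
        CohereAttention: QEffCohereAttention,
        CohereDecoderLayer: QEffCohereDecoderLayer,
        CohereModel: QEffCohereModel,
        CohereForCausalLM: QEffCohereForCausalLM,
        CohereRotaryEmbedding: QEffCohereRotaryEmbedding,
    }
)
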
78 changes: 29 additions & 49 deletions QEfficient/transformers/models/cohere/modeling_cohere.py
@@ -12,32 +12,28 @@
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.cohere.modeling_cohere import (
repeat_kv,
rotate_half,
logger
)
from transformers.models.cohere.modeling_cohere import(
CohereRotaryEmbedding,
CohereAttention,
CohereConfig,
CohereDecoderLayer,
CohereForCausalLM,
CohereModel,
CohereLayerNorm,
CohereDecoderLayer,
)
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
CohereRotaryEmbedding,
logger,
repeat_kv,
rotate_half,
)

from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask



class QEffCohereRotaryEmbedding(CohereRotaryEmbedding):
def __init__(
self,
@@ -82,10 +78,11 @@ def __init__(
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

self._set_cos_sin_cache(
seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
@@ -95,7 +92,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
emb = torch.repeat_interleave(freqs, 2, dim=1) # This line differs from Llama's implementation
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

@torch.no_grad()
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
@@ -108,8 +105,6 @@ def forward(self, x, seq_len=None):
)




def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
@@ -130,18 +125,14 @@ def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""

cos = cos[position_ids].unsqueeze(unsqueeze_dim)
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed.to(q.dtype), k_embed.to(k.dtype)






class QEffCohereAttention(CohereAttention):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -150,11 +141,9 @@ def __init__(self, config: CohereConfig, layer_idx: Optional[int] = None):
self.config = config
self.__qeff_init__()


def __qeff_init__(self):
self.config.rope_scaling = None
self.rotary_emb = QEffCohereRotaryEmbedding(config=self.config)


# Ignore copy
def forward(
@@ -187,9 +176,11 @@ def forward(
kv_seq_len = key_states.shape[-2]
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
kv_seq_len = past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin,position_ids=position_ids)
query_states, key_states = qeff_apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids=position_ids
)

if past_key_value is not None:
# sin and cos are specific to RoPE models; position_ids needed for the static cache
@@ -225,8 +216,9 @@ def forward(

return attn_output, attn_weights, past_key_value


class QEffCohereDecoderLayer(CohereDecoderLayer):
def forward(
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -292,7 +284,6 @@ def forward(


class QEffCohereModel(CohereModel):

def forward(
self,
input_ids: torch.LongTensor = None,
@@ -326,7 +317,7 @@ def forward(
use_cache = False

if input_ids is not None:
batch_size, seq_length = input_ids.shape
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
if inputs_embeds is None:
@@ -349,7 +340,7 @@ def forward(
position_ids = cache_position.unsqueeze(0)

causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, position_ids,past_key_values, output_attentions
attention_mask, inputs_embeds, cache_position, position_ids, past_key_values, output_attentions
)

# embed positions
@@ -380,7 +371,7 @@ def forward(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
batch_index= batch_index,
batch_index=batch_index,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
@@ -455,11 +446,7 @@ def _update_causal_mask(
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens
)
target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens

# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
if attention_mask is not None and attention_mask.dim() == 4:
@@ -498,27 +485,24 @@ def _update_causal_mask(
return causal_mask



class QEffCohereForCausalLM(CohereForCausalLM):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config: CohereConfig ):
def __init__(self, config: CohereConfig):
super().__init__()
self.__qeff_init__()


def __qeff_init__(self):
lm_head_weights = self.lm_head.weight.data.split(64000)
self.lm_heads = torch.nn.ModuleList()
for i in range(4):
lm_head_i = torch.nn.Linear(8192, 64000, bias=False) # hiddensize-8192
lm_head_i.weight.data = lm_head_weights[i]
self.lm_heads.append(lm_head_i)
return


return

# Ignore copy

def forward(
self,
input_ids: torch.LongTensor = None,
@@ -534,7 +518,6 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -586,6 +569,3 @@ def forward(
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
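
Note: one detail worth flagging in the code shown above is __qeff_init__ on QEffCohereForCausalLM, which splits the lm_head weight into four nn.Linear heads of 64,000 rows each (4 x 64,000 = 256,000 vocabulary entries at hidden size 8192 for Command-R). The part of forward() that consumes self.lm_heads is not visible in this diff; recombining the shards would presumably look like the following hedged sketch:

# Assumed recombination of the four lm_head shards -- the corresponding
# lines of forward() are not shown in this diff.
logits = torch.cat([lm_head(hidden_states) for lm_head in self.lm_heads], dim=-1)  # [bs, seq, 256000]
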



Loading
