Commit
Added support for new model Cohere/Command-R
Signed-off-by: Ann <[email protected]>
quic-akuruvil committed Nov 13, 2024
1 parent 244d81f commit d5c622b
Showing 13 changed files with 721 additions and 18 deletions.
7 changes: 5 additions & 2 deletions QEfficient/cloud/infer.py
@@ -13,7 +13,7 @@
import QEfficient
from QEfficient.cloud.export import get_onnx_model_path
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists
from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists, get_embeddings
from QEfficient.utils.logging_utils import logger


@@ -72,11 +72,12 @@ def main(
cache_dir=cache_dir,
hf_token=hf_token,
)
embeds,config = get_embeddings(model_name, hf_token,cache_dir,local_model_dir)

qpc_dir_path = get_qpc_dir_path(
model_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size
)

# Handle qpc generation
if qpc_exists(qpc_dir_path):
logger.info(f"Pre-compiled qpc found at {qpc_dir_path}! Executing with given prompt")
@@ -111,6 +112,8 @@ def main(
#########
cloud_ai_100_exec_kv(
tokenizer=tokenizer,
config=config,
embeddings=embeds,
qpc_path=qpc_dir_path,
device_id=device_group,
prompt=prompt,
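The infer entry point now fetches the model's input-embedding layer and config up front and forwards both into cloud_ai_100_exec_kv. The body of the get_embeddings helper is not part of this hunk; the following is only a rough sketch of what it plausibly does, and every name beyond the call signature shown above is an assumption.

from transformers import AutoConfig, AutoModelForCausalLM

def get_embeddings(model_name, hf_token=None, cache_dir=None, local_model_dir=None):
    # Hypothetical sketch: load the HF model (from a local path if given) and
    # return its input-embedding layer plus the config, so the runtime can
    # compute inputs_embeds on the host for Cohere/Command-R.
    path = local_model_dir if local_model_dir else model_name
    config = AutoConfig.from_pretrained(path, token=hf_token, cache_dir=cache_dir)
    model = AutoModelForCausalLM.from_pretrained(path, token=hf_token, cache_dir=cache_dir)
    return model.get_input_embeddings(), config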
9 changes: 7 additions & 2 deletions QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -199,18 +199,23 @@ def export_kvstyle_transformed_model_to_onnx(
raise ValueError(f"Need seq_len to be greater than zero, got seq_len={seq_len}")

# Preprocess inputs
embeds=None
if model_name == "CohereForAI/c4ai-command-r-v01":
embeds = transformed_model.get_input_embeddings()
# inputs['inputs_embeds']=embeds(inputs.pop('input_ids'))
# Build inputs for prefill
input_handler = InputHandler(
batch_size=len(Constants.INPUT_STR),
tokenizer=tokenizer,
embeddings=embeds,
config=transformed_model.config,
prompt=Constants.INPUT_STR,
prompt_len=Constants.PROMPT_LEN,
ctx_len=seq_len,
full_batch_size=full_batch_size,
)

inputs = input_handler.prepare_pytorch_inputs()

pt_outputs = transformed_model(**inputs)
output_names = list(pt_outputs.keys())

@@ -260,7 +265,7 @@ def export_kvstyle_transformed_model_to_onnx(
for i, (key, value) in enumerate(pkv):
inputs[f"past_key.{i}"] = key
inputs[f"past_value.{i}"] = value

# Run onnxrt inference
input_names, ort_outputs = run_model_on_ort(
onnx_path=os.path.join(onnx_dir_path, f"{model_name}.onnx"),
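For export, InputHandler now receives the embedding layer and the model config so the Cohere prefill inputs can carry inputs_embeds rather than input_ids (the commented-out line above hints at the same substitution). A hedged sketch of that swap follows; the helper name is invented, and the real prepare_pytorch_inputs also builds position_ids and past key/values.

import torch

def swap_ids_for_embeds(inputs, embeddings, config):
    # Illustration of the Cohere-specific branch: replace integer token ids
    # with their embedding vectors so the traced graph consumes inputs_embeds.
    if config.architectures[0] == "CohereForCausalLM":
        with torch.no_grad():
            inputs["inputs_embeds"] = embeddings(inputs.pop("input_ids"))
    return inputs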
1 change: 1 addition & 0 deletions QEfficient/exporter/export_utils.py
@@ -61,6 +61,7 @@ def export_onnx(
# Create dynamic axes dict for inputs that need to have dynamic input shapes
seq_len_inputs = {
"input_ids",
"inputs_embeds",
"attention_mask",
"position_ids",
"token_type_ids",
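Adding "inputs_embeds" to seq_len_inputs lets the exporter mark its sequence dimension as dynamic. Roughly, the set translates into a dynamic_axes mapping along these lines; the axis labels below are illustrative, not necessarily the exact ones export_onnx emits.

# Illustrative dynamic_axes for torch.onnx.export; actual labels may differ.
dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "seq_len"},
    "inputs_embeds": {0: "batch_size", 1: "seq_len"},
    "attention_mask": {0: "batch_size", 1: "seq_len"},
    "position_ids": {0: "batch_size", 1: "seq_len"},
}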
37 changes: 30 additions & 7 deletions QEfficient/generation/text_generation_inference.py
@@ -13,7 +13,9 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import transformers
from transformers import AutoConfig
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.generation.cloud_infer import QAICInferenceSession
@@ -221,6 +223,8 @@ def print_latency_stats_kv(prompt, exec_info, automation: bool = False):

def cloud_ai_100_exec_kv(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
config: AutoConfig,
embeddings: torch.Tensor,
qpc_path: str,
prompt: Optional[str] = None,
prompts_txt_file_path: Optional[str] = None,
@@ -269,6 +273,8 @@ def cloud_ai_100_exec_kv(
generate_text = TextGeneration(
tokenizer=tokenizer,
prompt=prompt,
embeddings=embeddings,
config=config,
qpc_path=qpc_path,
device_id=device_id,
ctx_len=ctx_len,
@@ -310,6 +316,8 @@ class TextGeneration:
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
embeddings: torch.Tensor,
config: AutoConfig,
qpc_path: str,
prompt: List[str],
full_batch_size: Optional[int] = None,
@@ -321,6 +329,8 @@ def __init__(
write_io_dir: Optional[str] = None,
) -> None:
self.tokenizer = tokenizer
self.embeddings = embeddings
self.config = config
self.prompt = prompt
self.qpc_path = qpc_path
self.device_id = device_id
@@ -404,12 +414,20 @@ def _fetch_batch_size_prefill_seq_len(
prefill_seq_len: The prefill sequence length fetched from the session's bindings or allowed shapes.
"""
if self.session.allowed_shapes:
batch_size = max(
[x[self.session.binding_index_map["input_ids"]][1][0] for x in self.session.allowed_shapes]
)
prefill_seq_len = max(
[x[self.session.binding_index_map["input_ids"]][1][1] for x in self.session.allowed_shapes]
)
if "input_ids" in self.session.binding_index_map:
batch_size = max(
[x[self.session.binding_index_map["input_ids"]][1][0] for x in self.session.allowed_shapes]
)
prefill_seq_len = max(
[x[self.session.binding_index_map["input_ids"]][1][1] for x in self.session.allowed_shapes]
)
else:
batch_size = max(
[x[self.session.binding_index_map["inputs_embeds"]][1][0] for x in self.session.allowed_shapes]
)
prefill_seq_len = max(
[x[self.session.binding_index_map["inputs_embeds"]][1][1] for x in self.session.allowed_shapes]
)
else:
batch_size, prefill_seq_len = self.session.bindings[self.session.binding_index_map["input_ids"]].dims
return batch_size, prefill_seq_len
@@ -460,7 +478,8 @@ def prepare_decode_inputs(self):
decode_inputs["position_ids"] = self.decode_pos_ids
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index

if self.config.architectures[0] == 'CohereForCausalLM':
decode_inputs['inputs_embeds'] = self.embeddings(torch.tensor(decode_inputs['input_ids'])).detach().numpy()
return decode_inputs

def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None):
Expand Down Expand Up @@ -557,6 +576,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
chunk_inputs["position_ids"] = inputs["position_ids"][
:, i * self.prefill_seq_len : (i + 1) * self.prefill_seq_len
]
if self.config.architectures[0] == 'CohereForCausalLM':
chunk_inputs['inputs_embeds'] = self.embeddings(torch.tensor(chunk_inputs.pop('input_ids'))).detach().numpy()
outputs = self.session.run(chunk_inputs)
if self.write_io_dir is not None:
write_io_files(inputs, outputs, self.write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -656,6 +677,8 @@ def run_decode(self, decode_inputs, generation_len):
for num_token in range(1, generation_len):
if self.stream:
self.streamer.put(decode_inputs["input_ids"][0])
if self.config.architectures[0] == 'CohereForCausalLM':
decode_inputs['inputs_embeds'] = self.embeddings(torch.tensor(decode_inputs.pop("input_ids"))).detach().numpy()
outputs = self.session.run(decode_inputs)

if self.write_io_dir is not None:
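The same host-side conversion is repeated at three points in this file (decode-input preparation, chunked prefill, and the decode loop): when the architecture is CohereForCausalLM, token ids are embedded on the CPU and the QAIC session receives float inputs_embeds instead of integer input_ids. Condensed into one hedged helper for clarity; the committed code inlines this logic rather than sharing a function.

import torch

def as_session_inputs(inputs, embeddings, config):
    # Sketch of the repeated Cohere branch: embed the token ids on the host
    # and pass the resulting numpy array to the compiled QPC session.
    if config.architectures[0] == "CohereForCausalLM":
        ids = torch.tensor(inputs.pop("input_ids"))
        inputs["inputs_embeds"] = embeddings(ids).detach().numpy()
    return inputs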
22 changes: 22 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -41,6 +41,13 @@
GPTBigCodeForCausalLM,
GPTBigCodeModel,
)
from transformers.models.cohere.modeling_cohere import (
CohereRotaryEmbedding,
CohereAttention,
CohereModel,
CohereForCausalLM,
CohereDecoderLayer,
)
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
@@ -95,6 +102,13 @@
QEffGemma2ForCausalLM,
QEffGemma2Model,
)
from .models.cohere.modeling_cohere import (
QEffCohereAttention,
QEffCohereForCausalLM,
QEffCohereModel,
QEffCohereRotaryEmbedding,
QEffCohereDecoderLayer,
)
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
from .models.gpt_bigcode.modeling_gpt_bigcode import (
QEffGPTBigCodeAttention,
@@ -154,6 +168,7 @@
MptForCausalLM.__name__,
FalconForCausalLM.__name__,
GPTBigCodeForCausalLM.__name__,
CohereForCausalLM.__name__,
]
)
# Create an instance of the named tuple
@@ -174,6 +189,7 @@
Qwen2ForCausalLM.__name__,
Starcoder2ForCausalLM.__name__,
GPTBigCodeForCausalLM.__name__,
CohereForCausalLM.__name__,
]
)

@@ -217,6 +233,12 @@
CodeGenModel: QEffCodeGenModel,
CodeGenForCausalLM: QEffCodeGenForCausalLM,
CodeGenBlock: QeffCodeGenBlock,
# Cohere
CohereForCausalLM: QEffCohereForCausalLM,
CohereAttention: QEffCohereAttention,
CohereModel: QEffCohereModel,
CohereRotaryEmbedding: QEffCohereRotaryEmbedding,
CohereDecoderLayer: QEffCohereDecoderLayer,
# Mistral model layers
MistralAttention: QEffMistralAttention,
MistralDecoderLayer: QEffMistralDecoderLayer,
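The new entries extend the HF-to-QEff module mapping, so each CohereAttention, CohereModel, CohereRotaryEmbedding, and CohereDecoderLayer instance gets swapped for its QEff counterpart at transform time. A hedged sketch of how such a class mapping is commonly applied; the actual QEfficient transform may differ in detail.

import torch.nn as nn

def apply_module_mapping(model: nn.Module, mapping: dict) -> nn.Module:
    # Walk the module tree and retag every mapped HF class with its QEff
    # replacement class; illustration only, not the committed transform code.
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
    return model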
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/cohere/__init__.py
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

(Diffs for the remaining changed files are not rendered in this view.)
