Added new model support: Cohere/Command-R #154

Open · wants to merge 4 commits into base: main
5 changes: 4 additions & 1 deletion QEfficient/cloud/infer.py
@@ -13,7 +13,7 @@
import QEfficient
from QEfficient.cloud.export import get_onnx_model_path
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv
from QEfficient.utils import check_and_assign_cache_dir, get_qpc_dir_path, load_hf_tokenizer, qpc_exists
from QEfficient.utils import check_and_assign_cache_dir, get_embeddings, get_qpc_dir_path, load_hf_tokenizer, qpc_exists
from QEfficient.utils.logging_utils import logger


@@ -72,6 +72,7 @@ def main(
cache_dir=cache_dir,
hf_token=hf_token,
)
embeds, config = get_embeddings(model_name, hf_token, cache_dir, local_model_dir)
Contributor: Do we want to generalize this? I don't think this is accessible for all the model categories.

Contributor (Author): Yes, right. We can make this conditional and fetch embeddings only for Cohere.
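
A minimal sketch of that conditional (hedged: the `AutoConfig`-based check is an assumption; the diff above only shows the unconditional call, and the final form may differ):

```python
# Sketch only: fetch the embedding table just for Cohere-style models, which are
# the ones exported with inputs_embeds; all other models keep embeds = None.
from transformers import AutoConfig

config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
embeds = None
if config.architectures and config.architectures[0] == "CohereForCausalLM":
    embeds, _ = get_embeddings(model_name, hf_token, cache_dir, local_model_dir)
```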


qpc_dir_path = get_qpc_dir_path(
model_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size
@@ -111,6 +112,8 @@ def main(
#########
cloud_ai_100_exec_kv(
tokenizer=tokenizer,
config=config,
embeddings=embeds,
qpc_path=qpc_dir_path,
device_id=device_group,
prompt=prompt,
7 changes: 6 additions & 1 deletion QEfficient/exporter/export_hf_to_cloud_ai_100.py
@@ -199,18 +199,23 @@ def export_kvstyle_transformed_model_to_onnx(
raise ValueError(f"Need seq_len to be greater than zero, got seq_len={seq_len}")

# Preprocess inputs
embeds = None
if model_name == "CohereForAI/c4ai-command-r-v01":
Contributor: If we want to do this for all versions of Cohere and not just this specific model_name, then we need to do it at the torch level; it's not good practice to do this at the model_name level.

Contributor (Author): Okay, sure. I have changed it to check for architectures with the CohereForCausalLM head, so it's more generic.
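
For reference, a minimal sketch of what that architecture-based check could look like in place of the model_name comparison (the exact guard is an assumption; the updated commit may differ):

```python
# Sketch: key off the architecture recorded in the config rather than one
# hard-coded Hugging Face model id, so every CohereForCausalLM checkpoint
# (any Command-R variant) takes the inputs_embeds path.
embeds = None
architectures = getattr(transformed_model.config, "architectures", None) or []
if "CohereForCausalLM" in architectures:
    embeds = transformed_model.get_input_embeddings()
```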

embeds = transformed_model.get_input_embeddings()
# inputs['inputs_embeds']=embeds(inputs.pop('input_ids'))
# Build inputs for prefill
input_handler = InputHandler(
batch_size=len(Constants.INPUT_STR),
tokenizer=tokenizer,
embeddings=embeds,
config=transformed_model.config,
prompt=Constants.INPUT_STR,
prompt_len=Constants.PROMPT_LEN,
ctx_len=seq_len,
full_batch_size=full_batch_size,
)

inputs = input_handler.prepare_pytorch_inputs()

pt_outputs = transformed_model(**inputs)
output_names = list(pt_outputs.keys())

1 change: 1 addition & 0 deletions QEfficient/exporter/export_utils.py
@@ -61,6 +61,7 @@ def export_onnx(
# Create dynamic axes dict for inputs that need to have dynamic input shapes
seq_len_inputs = {
"input_ids",
"inputs_embeds",
"attention_mask",
"position_ids",
"token_type_ids",
42 changes: 34 additions & 8 deletions QEfficient/generation/text_generation_inference.py
@@ -13,8 +13,9 @@
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import transformers
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import AutoConfig, PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import padding_check_and_fix
@@ -221,6 +222,8 @@ def print_latency_stats_kv(prompt, exec_info, automation: bool = False):

def cloud_ai_100_exec_kv(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
config: AutoConfig,
embeddings: torch.Tensor,
qpc_path: str,
prompt: Optional[str] = None,
prompts_txt_file_path: Optional[str] = None,
@@ -269,6 +272,8 @@ def cloud_ai_100_exec_kv(
generate_text = TextGeneration(
tokenizer=tokenizer,
prompt=prompt,
embeddings=embeddings,
config=config,
qpc_path=qpc_path,
device_id=device_id,
ctx_len=ctx_len,
@@ -310,6 +315,8 @@ class TextGeneration:
def __init__(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
embeddings: torch.Tensor,
config: AutoConfig,
qpc_path: str,
prompt: List[str],
full_batch_size: Optional[int] = None,
@@ -321,6 +328,8 @@ def __init__(
write_io_dir: Optional[str] = None,
) -> None:
self.tokenizer = tokenizer
self.embeddings = embeddings
self.config = config
self.prompt = prompt
self.qpc_path = qpc_path
self.device_id = device_id
@@ -404,12 +413,20 @@ def _fetch_batch_size_prefill_seq_len(
prefill_seq_len: The prefill sequence length fetched from the session's bindings or allowed shapes.
"""
if self.session.allowed_shapes:
batch_size = max(
[x[self.session.binding_index_map["input_ids"]][1][0] for x in self.session.allowed_shapes]
)
prefill_seq_len = max(
[x[self.session.binding_index_map["input_ids"]][1][1] for x in self.session.allowed_shapes]
)
if "input_ids" in self.session.binding_index_map:
batch_size = max(
[x[self.session.binding_index_map["input_ids"]][1][0] for x in self.session.allowed_shapes]
)
prefill_seq_len = max(
[x[self.session.binding_index_map["input_ids"]][1][1] for x in self.session.allowed_shapes]
)
else:
batch_size = max(
[x[self.session.binding_index_map["inputs_embeds"]][1][0] for x in self.session.allowed_shapes]
)
prefill_seq_len = max(
[x[self.session.binding_index_map["inputs_embeds"]][1][1] for x in self.session.allowed_shapes]
)
else:
batch_size, prefill_seq_len = self.session.bindings[self.session.binding_index_map["input_ids"]].dims
return batch_size, prefill_seq_len
@@ -460,7 +477,8 @@ def prepare_decode_inputs(self):
decode_inputs["position_ids"] = self.decode_pos_ids
if self.batch_index is not None:
decode_inputs["batch_index"] = self.batch_index

if self.config.architectures[0] == "CohereForCausalLM":
decode_inputs["inputs_embeds"] = self.embeddings(torch.tensor(decode_inputs["input_ids"])).detach().numpy()
return decode_inputs

def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None):
@@ -557,6 +575,10 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
chunk_inputs["position_ids"] = inputs["position_ids"][
:, i * self.prefill_seq_len : (i + 1) * self.prefill_seq_len
]
if self.config.architectures[0] == "CohereForCausalLM":
chunk_inputs["inputs_embeds"] = (
self.embeddings(torch.tensor(chunk_inputs.pop("input_ids"))).detach().numpy()
)
outputs = self.session.run(chunk_inputs)
if self.write_io_dir is not None:
write_io_files(inputs, outputs, self.write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -656,6 +678,10 @@ def run_decode(self, decode_inputs, generation_len):
for num_token in range(1, generation_len):
if self.stream:
self.streamer.put(decode_inputs["input_ids"][0])
if self.config.architectures[0] == "CohereForCausalLM":
decode_inputs["inputs_embeds"] = (
self.embeddings(torch.tensor(decode_inputs.pop("input_ids"))).detach().numpy()
)
outputs = self.session.run(decode_inputs)

if self.write_io_dir is not None:
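On the Cohere path above, token ids are converted into embedding vectors immediately before each session.run call. A self-contained sketch of that conversion (the helper name ids_to_embeds is hypothetical; shapes follow the diff above):

```python
import numpy as np
import torch

def ids_to_embeds(embeddings: torch.nn.Embedding, input_ids: np.ndarray) -> np.ndarray:
    # (batch_size, seq_len) int ids -> (batch_size, seq_len, hidden_size) float embeddings;
    # the QAIC session consumes numpy arrays, hence the round trip through torch.
    with torch.no_grad():
        return embeddings(torch.tensor(input_ids)).numpy()
```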
22 changes: 22 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -15,6 +15,13 @@
CodeGenForCausalLM,
CodeGenModel,
)
from transformers.models.cohere.modeling_cohere import (
CohereAttention,
CohereDecoderLayer,
CohereForCausalLM,
CohereModel,
CohereRotaryEmbedding,
)
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconForCausalLM,
@@ -83,6 +90,13 @@
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
from .models.cohere.modeling_cohere import (
QEffCohereAttention,
QEffCohereDecoderLayer,
QEffCohereForCausalLM,
QEffCohereModel,
QEffCohereRotaryEmbedding,
)
from .models.falcon.modeling_falcon import (
QEffFalconAttention,
QEffFalconForCausalLM,
@@ -154,6 +168,7 @@
MptForCausalLM.__name__,
FalconForCausalLM.__name__,
GPTBigCodeForCausalLM.__name__,
CohereForCausalLM.__name__,
]
)
# Create an instance of the named tuple
@@ -174,6 +189,7 @@
Qwen2ForCausalLM.__name__,
Starcoder2ForCausalLM.__name__,
GPTBigCodeForCausalLM.__name__,
CohereForCausalLM.__name__,
]
)

@@ -217,6 +233,12 @@
CodeGenModel: QEffCodeGenModel,
CodeGenForCausalLM: QEffCodeGenForCausalLM,
CodeGenBlock: QeffCodeGenBlock,
# Cohere
CohereForCausalLM: QEffCohereForCausalLM,
CohereAttention: QEffCohereAttention,
CohereModel: QEffCohereModel,
CohereRotaryEmbedding: QEffCohereRotaryEmbedding,
CohereDecoderLayer: QEffCohereDecoderLayer,
# Mistral model layers
MistralAttention: QEffMistralAttention,
MistralDecoderLayer: QEffMistralDecoderLayer,
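The new entries above extend the class map that drives QEfficient's module swap. As a generic illustration of the pattern only (this is not the repository's actual transform code):

```python
import torch.nn as nn

def apply_class_map(model: nn.Module, class_map: dict) -> nn.Module:
    # Walk the module tree and re-class any module whose type appears in the map,
    # so e.g. CohereAttention instances pick up QEffCohereAttention's forward.
    for module in model.modules():
        if type(module) in class_map:
            module.__class__ = class_map[type(module)]
    return model
```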
7 changes: 7 additions & 0 deletions QEfficient/transformers/models/cohere/__init__.py
@@ -0,0 +1,7 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
