From 21c11b6b5a34775133e8b7e9f8ee928e3c4b0a70 Mon Sep 17 00:00:00 2001
From: Onkar Chougule <168134249+ochougul@users.noreply.github.com>
Date: Thu, 28 Nov 2024 11:03:59 +0530
Subject: [PATCH] Added full_batch_size to replicate_kv script (#185)

Added a full_batch_size argument to the replicate_kv_heads script and
removed the num_hidden_layers=1 override that truncated the loaded model
to a single layer.

Signed-off-by: Onkar Chougule
---
 QEfficient/cloud/infer.py                   |  2 +-
 .../replicate_kv_head/replicate_kv_heads.py | 18 +++++++++++++++---
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index d58e1a1b2..66eff2f8c 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -192,7 +192,7 @@ def main(
     )
     parser.add_argument(
         "--full_batch_size",
-        "--full_batch_size",
+        "--full-batch-size",
         type=int,
         default=None,
         help="Set full batch size to enable continuous batching mode, default is None",
diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py
index 2fdbaf883..328628da2 100644
--- a/scripts/replicate_kv_head/replicate_kv_heads.py
+++ b/scripts/replicate_kv_head/replicate_kv_heads.py
@@ -59,7 +59,7 @@ def main(args):
     replace_transformers_quantizers()
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        num_hidden_layers=1,
+        # num_hidden_layers=1,  # Use for generating smaller model
         attn_implementation="eager",
     )
     # Undo the effect of replace_transformers_quantizers
@@ -104,12 +104,13 @@
     )
 
     # Export the modified model
-    q_model = QEFFAutoModelForCausalLM(model, model_name)
+    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=(True if args.full_batch_size else False))
     export(
         model_name,
         q_model,
         tokenizer=tokenizer,
         onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
+        full_batch_size=(args.full_batch_size if args.full_batch_size else None),
     )
 
 
@@ -117,10 +118,21 @@
     # Set up argument parser
     parser = argparse.ArgumentParser(description="Modify and export a causal language model.")
     parser.add_argument(
-        "--model_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", help="Name of the model to use."
+        "--model_name",
+        "--model-name",
+        type=str,
+        default="meta-llama/Meta-Llama-3-8B-Instruct",
+        help="Name of the model to use.",
     )
     parser.add_argument("--prompt", type=str, default="My name is", help="Prompt to use for the model.")
     parser.add_argument("--repeat", type=int, default=2, help="Factor to repeat key-value heads.")
+    parser.add_argument(
+        "--full_batch_size",
+        "--full-batch-size",
+        type=int,
+        default=None,
+        help="Set full batch size to enable continuous batching mode, default is None",
+    )
 
     args = parser.parse_args()
     main(args)
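
Both scripts now register the new flag with two option strings, so
--full_batch_size and --full-batch-size are interchangeable on the command
line. A minimal standalone sketch of this argparse aliasing pattern (the
parse_args inputs below are illustrative values, not taken from the patch):

    import argparse

    parser = argparse.ArgumentParser()
    # Two option strings for one argument: argparse accepts either spelling
    # and derives the attribute name (dest) from the first long option, so
    # the value is always read back as args.full_batch_size.
    parser.add_argument(
        "--full_batch_size",
        "--full-batch-size",
        type=int,
        default=None,
        help="Set full batch size to enable continuous batching mode, default is None",
    )

    print(parser.parse_args(["--full_batch_size", "4"]).full_batch_size)  # 4
    print(parser.parse_args(["--full-batch-size", "4"]).full_batch_size)  # 4
    print(parser.parse_args([]).full_batch_size)                          # None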
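
With the patch applied, the replication script can exercise continuous
batching end to end by passing the new flag, for example:

    python scripts/replicate_kv_head/replicate_kv_heads.py \
        --model-name meta-llama/Meta-Llama-3-8B-Instruct \
        --repeat 2 \
        --full-batch-size 4

The batch size of 4 is an arbitrary example. Omitting the flag leaves
full_batch_size as None, so QEFFAutoModelForCausalLM is constructed with
continuous_batching=False and export() receives full_batch_size=None,
preserving the previous single-batch behaviour.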