diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py
index d58e1a1b2..66eff2f8c 100644
--- a/QEfficient/cloud/infer.py
+++ b/QEfficient/cloud/infer.py
@@ -192,7 +192,7 @@ def main(
     )
     parser.add_argument(
         "--full_batch_size",
-        "--full_batch_size",
+        "--full-batch-size",
         type=int,
         default=None,
         help="Set full batch size to enable continuous batching mode, default is None",
diff --git a/scripts/replicate_kv_head/replicate_kv_heads.py b/scripts/replicate_kv_head/replicate_kv_heads.py
index 2fdbaf883..328628da2 100644
--- a/scripts/replicate_kv_head/replicate_kv_heads.py
+++ b/scripts/replicate_kv_head/replicate_kv_heads.py
@@ -59,7 +59,7 @@ def main(args):
     replace_transformers_quantizers()
     model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        num_hidden_layers=1,
+        # num_hidden_layers=1,  # uncomment to load a truncated 1-layer model for quick testing
         attn_implementation="eager",
     )
     # Undo the effect of replace_transformers_quantizers
@@ -104,12 +104,13 @@ def main(args):
         )
 
     # Export the modified model
-    q_model = QEFFAutoModelForCausalLM(model, model_name)
+    q_model = QEFFAutoModelForCausalLM(model, continuous_batching=bool(args.full_batch_size))
     export(
         model_name,
         q_model,
         tokenizer=tokenizer,
         onnx_dir_path=f"{model_base_name}-{new_kv_heads}kvheads",
+        full_batch_size=args.full_batch_size,
     )
 
 
@@ -117,10 +118,21 @@ def main(args):
     # Set up argument parser
     parser = argparse.ArgumentParser(description="Modify and export a causal language model.")
     parser.add_argument(
-        "--model_name", type=str, default="meta-llama/Meta-Llama-3-8B-Instruct", help="Name of the model to use."
+        "--model_name",
+        "--model-name",
+        type=str,
+        default="meta-llama/Meta-Llama-3-8B-Instruct",
+        help="Name of the model to use.",
     )
     parser.add_argument("--prompt", type=str, default="My name is", help="Prompt to use for the model.")
     parser.add_argument("--repeat", type=int, default=2, help="Factor to repeat key-value heads.")
+    parser.add_argument(
+        "--full_batch_size",
+        "--full-batch-size",
+        type=int,
+        default=None,
+        help="Set full batch size to enable continuous batching mode, default is None",
+    )
 
     args = parser.parse_args()
     main(args)
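
Usage sketch: with the new flag wired from the parser into
QEFFAutoModelForCausalLM and export(), continuous batching is enabled exactly
when a full batch size is supplied. Either flag spelling works; the model name
and batch size below are illustrative, not prescribed by this patch:

    python scripts/replicate_kv_head/replicate_kv_heads.py \
        --model-name meta-llama/Meta-Llama-3-8B-Instruct \
        --repeat 2 \
        --full-batch-size 16

Leaving --full-batch-size unset keeps args.full_batch_size at None, so
continuous_batching evaluates to False and the export behaves as before.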