From 2a023c851a83a327c8a1498fca1662b2cdc0b2e1 Mon Sep 17 00:00:00 2001
From: Jou-An Chen <quic_jouachen@quicinc.com>
Date: Wed, 8 Jan 2025 17:29:22 -0800
Subject: [PATCH] Fix finite lorax generation in CB mode
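
Forward prompt_to_lora_id_mapping to generate_text.generate() in the
else branch of cloud_ai_100_exec_kv so the per-prompt adapter mapping
is no longer dropped on that path. Also reorder
QEffAutoLoraModelForCausalLM.generate() so the mandatory
prompt_to_adapter_mapping argument precedes the optional device_id
(now Optional, with automatic device picking when None), and add a CB
compile/generate test alongside the renamed non-CB test.

Minimal usage sketch of the CB flow exercised by the new test (the
import paths, placeholder model/adapter names, and the generation_len
forwarding are illustrative assumptions, not part of this patch):

    from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
    from QEfficient.utils import load_hf_tokenizer

    # Base model with continuous batching enabled, plus two named adapters.
    qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(
        "<base_model_name>", continuous_batching=True
    )
    qeff_model.load_adapter("<adapter_id_0>", "adapter_0")
    qeff_model.load_adapter("<adapter_id_1>", "adapter_1")

    # Compile for CB, then generate with a per-prompt adapter mapping;
    # "base" selects the base model (no adapter).
    qeff_model.compile(prefill_seq_len=32, ctx_len=64, full_batch_size=2)
    qeff_model.generate(
        tokenizer=load_hf_tokenizer(pretrained_model_name_or_path="<base_model_name>"),
        prompts=["hello!", "hi", "hello, my name is", "hey"],
        prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
        generation_len=32,  # assumed to reach cloud_ai_100_exec_kv via **kwargs
    )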

Signed-off-by: Jou-An Chen <quic_jouachen@quicinc.com>
---
 .../generation/text_generation_inference.py   |  4 ++-
 QEfficient/peft/lora/auto.py                  |  6 ++--
 tests/peft/lora/test_lora_model.py            | 31 +++++++++++++++++--
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py
index 4ddd57ada..54b6f057e 100755
--- a/QEfficient/generation/text_generation_inference.py
+++ b/QEfficient/generation/text_generation_inference.py
@@ -341,7 +341,9 @@ def cloud_ai_100_exec_kv(
             perf_metrics=PerfMetrics(prefill_time, decode_perf, total_perf, total_time),
         )
     else:
-        exec_info = generate_text.generate(prompt=prompt, generation_len=generation_len)
+        exec_info = generate_text.generate(
+            prompt=prompt, generation_len=generation_len, prompt_to_lora_id_mapping=prompt_to_lora_id_mapping
+        )
 
     print_latency_stats_kv(prompt, exec_info=exec_info, automation=automation)
     return exec_info
diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py
index 2ccfac12a..c13979968 100644
--- a/QEfficient/peft/lora/auto.py
+++ b/QEfficient/peft/lora/auto.py
@@ -342,9 +342,9 @@ def generate(
         self,
         tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer],
         prompts: List[str],
-        device_id: List[int] = None,
         prompt_to_adapter_mapping: List[str] = None,
-        runtime: str = "AI_100",
+        device_id: Optional[List[int]] = None,
+        runtime: Optional[str] = "AI_100",
         **kwargs,
     ):
         """
@@ -355,9 +355,9 @@ def generate(
         ``Mandatory`` Args:
             :tokenizer (PreTrainedTokenizerFast or PreTrainedTokenizer): The tokenizer used in the inference
             :prompts (List[str]): List of prompts to run the execution.
-            :device_id (List[int]): Ids of devices for running the qpc pass as [0] in case of normal model / [0, 1, 2, 3] in case of tensor slicing model
             :prompt_to_adapter_mapping (List[str]): The sequence of the adapter names will be matched with sequence of prompts and corresponding adapters will be used for the prompts."base" for base model (no adapter).
         ``optional`` Args:
+            :device_id (List[int]): Device IDs to be used for execution. If ``len(device_id) > 1``, a multi-card setup is enabled. If ``None``, auto-device-picker will be used. ``Defaults to None``.
             :runtime (str, optional): Only ``AI_100`` runtime is supported as of now; ``ONNXRT`` and ``PyTorch`` coming soon. Defaults to "AI_100".
 
         """
diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py
index a91555b3a..4726fb8c5 100644
--- a/tests/peft/lora/test_lora_model.py
+++ b/tests/peft/lora/test_lora_model.py
@@ -195,10 +195,12 @@ def test_auto_lora_model_for_causal_lm_load_unload_adapter(base_model_name, adap
     assert qeff_model.unload_adapter("adapter_0")  # valid unload
 
 
-# test the export, export caching, compile, generate workflow
+# test the export, export caching, compile and generate workflow in non-CB mode
 @pytest.mark.on_qaic
 @pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1])
-def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path):
+def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate(
+    base_model_name, adapter_id_0, adapter_id_1, tmp_path
+):
     qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(base_model_name, num_hidden_layers=1)
 
     qeff_model.load_adapter(adapter_id_0, "adapter_0")
@@ -229,6 +231,29 @@ def test_auto_lora_model_for_causal_lm_export_compile_generate(base_model_name,
     qeff_model.generate(
         tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
         prompts=prompts,
-        device_id=[0],
+        prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
+    )
+
+
+# test the compile and generate workflow in CB mode
+@pytest.mark.on_qaic
+@pytest.mark.parametrize("base_model_name,adapter_id_0,adapter_id_1", model_samples[:1])
+def test_auto_lora_model_for_causal_lm_cb_compile_generate(base_model_name, adapter_id_0, adapter_id_1, tmp_path):
+    qeff_model = QEffAutoLoraModelForCausalLM.from_pretrained(
+        base_model_name, continuous_batching=True, num_hidden_layers=1
+    )
+
+    qeff_model.load_adapter(adapter_id_0, "adapter_0")
+    qeff_model.load_adapter(adapter_id_1, "adapter_1")
+
+    # test compile
+    qeff_model.compile(prefill_seq_len=32, ctx_len=64, full_batch_size=2)
+    assert Path(qeff_model.qpc_path).is_dir()
+
+    # test generate
+    prompts = ["hello!", "hi", "hello, my name is", "hey"]
+    qeff_model.generate(
+        tokenizer=load_hf_tokenizer(pretrained_model_name_or_path=base_model_name),
+        prompts=prompts,
         prompt_to_adapter_mapping=["adapter_0", "adapter_1", "adapter_0", "base"],
     )