From 485bc73820df2509869f3bc562c453fbd45d7880 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Sat, 11 Jan 2025 04:40:41 +0530 Subject: [PATCH] reverted onnx file dynamic axes names for QNN Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ef061e1c..ff657d29 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -200,12 +200,12 @@ def export(self, export_dir: Optional[str] = None) -> str: } if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d pkv_dynamic_axes = { - 0: "kv_cache_batch_size", + 0: "full_batch_size" if self.continuous_batching else "batch_size", 1: "ctx_len", } else: # pkv is 4d pkv_dynamic_axes = { - 0: "kv_cache_batch_size", + 0: "full_batch_size" if self.continuous_batching else "batch_size", 2: "ctx_len", } output_names = ["logits"] @@ -307,10 +307,14 @@ def compile( "batch_size": 1 if self.continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, - "kv_cache_batch_size": kv_cache_batch_size, + # TODO: should be renamed to kv_cache_batch_size in specialization too } prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ... - prefill_specialization.update({"full_batch_size": full_batch_size}) if full_batch_size else ... + if self.continuous_batching: + prefill_specialization.update({"full_batch_size": kv_cache_batch_size}) + else: + prefill_specialization.update({"batch_size": kv_cache_batch_size}) + prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ... 
specializations = [ prefill_specialization, ] @@ -321,8 +325,11 @@ def compile( "batch_size": full_batch_size if self.continuous_batching else batch_size, "seq_len": num_speculative_tokens + 1 if self.is_tlm else 1, "ctx_len": ctx_len, - "kv_cache_batch_size": kv_cache_batch_size, } + if self.continuous_batching: + decode_specialization.update({"full_batch_size": kv_cache_batch_size}) + else: + decode_specialization.update({"batch_size": kv_cache_batch_size}) decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ... specializations.append(decode_specialization)