Skip to content

Commit

Permalink
reverted onnx file dynamic axes names for QNN
Browse files Browse the repository at this point in the history
Signed-off-by: Onkar Chougule <[email protected]>
  • Loading branch information
ochougul committed Jan 10, 2025
1 parent 7cf607d commit 485bc73
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,12 +200,12 @@ def export(self, export_dir: Optional[str] = None) -> str:
}
if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d
pkv_dynamic_axes = {
0: "kv_cache_batch_size",
0: "full_batch_size" if self.continuous_batching else "batch_size",
1: "ctx_len",
}
else: # pkv is 4d
pkv_dynamic_axes = {
0: "kv_cache_batch_size",
0: "full_batch_size" if self.continuous_batching else "batch_size",
2: "ctx_len",
}
output_names = ["logits"]
Expand Down Expand Up @@ -307,10 +307,14 @@ def compile(
"batch_size": 1 if self.continuous_batching else batch_size,
"seq_len": prefill_seq_len,
"ctx_len": ctx_len,
"kv_cache_batch_size": kv_cache_batch_size,
# TODO: should be renamed to kv_cache_batch_size in specialization too
}
prefill_specialization.update({"num_logits_to_keep": 1}) if self.is_tlm else ...
prefill_specialization.update({"full_batch_size": full_batch_size}) if full_batch_size else ...
if self.continuous_batching:
prefill_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
prefill_specialization.update({"batch_size": kv_cache_batch_size})
prefill_specialization.update({"full_batch_exec_size": full_batch_size}) if full_batch_size else ...
specializations = [
prefill_specialization,
]
Expand All @@ -321,8 +325,11 @@ def compile(
"batch_size": full_batch_size if self.continuous_batching else batch_size,
"seq_len": num_speculative_tokens + 1 if self.is_tlm else 1,
"ctx_len": ctx_len,
"kv_cache_batch_size": kv_cache_batch_size,
}
if self.continuous_batching:
decode_specialization.update({"full_batch_size": kv_cache_batch_size})
else:
decode_specialization.update({"batch_size": kv_cache_batch_size})
decode_specialization.update({"num_logits_to_keep": num_speculative_tokens + 1}) if self.is_tlm else ...
specializations.append(decode_specialization)

Expand Down

0 comments on commit 485bc73

Please sign in to comment.