Skip to content

Commit

Permalink
Migrating HL compile and export to infer APIs
Browse files Browse the repository at this point in the history
Change-Id: Idef53b1886462a05cd991bdbc1e9bf31fcad30cb
  • Loading branch information
asmigosw committed Dec 18, 2024
2 parents ed90daf + d267e40 commit 15684a7
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions QEfficient/cloud/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def main(
hf_token=hf_token,
)

if '--mxfp6' in sys.argv:
if "--mxfp6" in sys.argv:
if args.mxfp6:
logger.warning("mxfp6 is going to be deprecated in a future release, use -mxfp6_matmul instead.")
if '--mxint8' in sys.argv:
if "--mxint8" in sys.argv:
if args.mxint8:
logger.warning("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead.")

Expand All @@ -104,11 +104,13 @@ def main(
#########
# Execute
#########
_ = qeff_model.generate(tokenizer,
prompts=prompt,
device_id=device_group,
prompts_txt_file_path=prompts_txt_file_path,
generation_len=generation_len,)
_ = qeff_model.generate(
tokenizer,
prompts=prompt,
device_id=device_group,
prompts_txt_file_path=prompts_txt_file_path,
generation_len=generation_len,
)


if __name__ == "__main__":
Expand Down Expand Up @@ -139,7 +141,7 @@ def main(
"--mxfp6_matmul",
"--mxfp6-matmul",
action="store_true",
help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression"
help="Compress constant MatMul weights to MXFP6 E2M3, default is no compression",
)
parser.add_argument(
"--mxint8",
Expand Down Expand Up @@ -205,9 +207,13 @@ def main(
args, compiler_options = parser.parse_known_args()
compiler_options_dict = {}
for i in range(0, len(compiler_options)):
if (compiler_options[i].startswith('--')):
key = compiler_options[i].lstrip('-')
value = compiler_options[i+1] if i+1 < len(compiler_options) and not compiler_options[i+1].startswith('-') else True
if compiler_options[i].startswith("--"):
key = compiler_options[i].lstrip("-")
value = (
compiler_options[i + 1]
if i + 1 < len(compiler_options) and not compiler_options[i + 1].startswith("-")
else True
)
compiler_options_dict[key] = value

if args.verbose:
Expand Down

0 comments on commit 15684a7

Please sign in to comment.