From 32f526516be87a3b549949d770bf562910f92d5f Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Thu, 12 Sep 2024 18:40:05 -0700 Subject: [PATCH 1/3] Add scripts to export mmdit and vae into onnx format --- .../sd3_inference/sd3_mmdit_onnx.py | 130 ++++++++++++++++++ .../sd3_inference/sd3_vae_onnx.py | 130 ++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py create mode 100644 models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py new file mode 100644 index 000000000..fc3a53c06 --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py @@ -0,0 +1,130 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import copy +import os +import sys +import math + +import numpy as np +from shark_turbine.aot import * + +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import SD3Transformer2DModel + + +class MMDiTModel(torch.nn.Module): + def __init__( + self, + hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers", + dtype=torch.float16, + ): + super().__init__() + self.mmdit = SD3Transformer2DModel.from_pretrained( + hf_model_name, + subfolder="transformer", + torch_dtype=dtype, + low_cpu_mem_usage=False, + ) + + def forward( + self, + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + ): + # timestep.expand(hidden_states.shape[0]) + noise_pred = self.mmdit( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + return_dict=False, + )[0] + return noise_pred + +@torch.no_grad() +def export_mmdit_model( + hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers", + batch_size=1, + height=512, + width=512, + precision="fp16", + max_length=77 +): + dtype = torch.float16 if precision == "fp16" else torch.float32 + mmdit_model = MMDiTModel( + dtype=dtype, + ) + file_prefix = "C:/Users/chiz/work/sd3/mmdit/exported/" + safe_name = file_prefix + utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit", + ) + ".onnx" + print(safe_name) + + do_classifier_free_guidance = True + init_batch_dim = 2 if do_classifier_free_guidance else 1 + batch_size = batch_size * init_batch_dim + hidden_states_shape = ( + batch_size, + 16, + height // 8, + width // 8, + ) + encoder_hidden_states_shape = (batch_size, 154, 4096) + pooled_projections_shape = (batch_size, 2048) + hidden_states = torch.empty(hidden_states_shape, dtype=dtype) + encoder_hidden_states = torch.empty(encoder_hidden_states_shape, dtype=dtype) + pooled_projections = torch.empty(pooled_projections_shape, dtype=dtype) + timestep = torch.empty(batch_size, dtype=dtype) + # mmdit_model(hidden_states, encoder_hidden_states, pooled_projections, timestep) + + torch.onnx.export( + mmdit_model, # model being run + ( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep + ), # model input (or a tuple for multiple inputs) + safe_name, # where to save the model (can be a file or file-like object) + export_params=True, # store the trained parameter weights inside the model file + opset_version=17, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=[ + "hidden_states", + "encoder_hidden_states", + "pooled_projections", + "timestep" + ], # the model's input names + output_names=[ + "sample_out", + ], # the model's output names + ) + return safe_name + + + +if __name__ == "__main__": + import logging + + logging.basicConfig(level=logging.DEBUG) + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + onnx_model_name = export_mmdit_model( + args.hf_model_name, + 1, # args.batch_size, + 512, # args.height, + 512, # args.width, + "fp16", # args.precision, + 77, # args.max_length, + ) + + print("Saved to", onnx_model_name) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py new file mode 100644 index 000000000..5d97c623c --- /dev/null +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py @@ -0,0 +1,130 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import copy +import os +import sys + +from iree import runtime as ireert +from iree.compiler.ir import Context +import numpy as np +from shark_turbine.aot import * +from shark_turbine.dynamo.passes import ( + DEFAULT_DECOMPOSITIONS, +) +from turbine_models.custom_models.sd_inference import utils +import torch +import torch._dynamo as dynamo +from diffusers import AutoencoderKL + + +class VaeModel(torch.nn.Module): + def __init__( + self, + hf_model_name, + ): + super().__init__() + self.vae = AutoencoderKL.from_pretrained( + hf_model_name, + subfolder="vae", + ) + + def forward(self, inp): + inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor + image = self.vae.decode(inp, return_dict=False)[0] + image = image.float() + image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + return image + + # def decode(self, inp): + # inp = (inp / self.vae.config.scaling_factor) + self.vae.config.shift_factor + # image = self.vae.decode(inp, return_dict=False)[0] + # image = image.float() + # image = torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)[0] + # return image + + # def encode(self, inp): + # image_np = inp / 255.0 + # image_np = np.moveaxis(image_np, 2, 0) + # batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) + # image_torch = torch.from_numpy(batch_images) + # image_torch = 2.0 * image_torch - 1.0 + # image_torch = image_torch + # latent = self.vae.encode(image_torch) + # return latent + + +def export_vae_model( + vae_model, + hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers", + batch_size=1, + height=512, + width=512, + precision="fp32" +): + dtype = torch.float16 if precision == "fp16" else torch.float32 + file_prefix = "C:/Users/chiz/work/sd3/vae_decoder/exported/" + safe_name = file_prefix + utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{height}x{width}_{precision}_vae", + ) + ".onnx" + print(safe_name) + + if dtype == torch.float16: + vae_model = vae_model.half() + + + # input_image_shape = (height, width, 3) + input_latents_shape = (batch_size, 16, height // 8, width // 8) + input_latents = torch.empty(input_latents_shape, dtype=dtype) + # encode_args = [ + # torch.empty( + # input_image_shape, + # dtype=torch.float32, + # ) + # ] + # decode_args = [ + # torch.empty( + # input_latents_shape, + # dtype=dtype, + # ) + # ] + + torch.onnx.export( + vae_model, # model being run + ( + input_latents + ), # model input (or a tuple for multiple inputs) + safe_name, # where to save the model (can be a file or file-like object) + export_params=True, # store the trained parameter weights inside the model file + opset_version=17, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=[ + "input_latents", + ], # the model's input names + output_names=[ + "sample_out", + ], # the model's output names + ) + return safe_name + + + +if __name__ == "__main__": + from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args + + vae_model = VaeModel( + args.hf_model_name, + ) + onnx_model_name = export_vae_model( + vae_model, + args.hf_model_name, + 1, # args.batch_size, + 512, # height=args.height, + 512, # width=args.width, + "fp32" # precision=args.precision + ) + print("Saved to", onnx_model_name) From 0b70f6a66e201117310826299d2d31c2b093169b Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Thu, 12 Sep 2024 18:40:41 -0700 Subject: [PATCH 2/3] Enable running mmdit in onnx for the sd3 pipeline --- .../custom_models/pipeline_base.py | 97 ++++++++++++++++++- .../custom_models/sd3_inference/sd3_mmdit.py | 1 + .../custom_models/sd_inference/sd_cmd_opts.py | 16 +++ .../custom_models/sd_inference/sd_pipeline.py | 40 +++++++- run.py | 54 +++++++++++ 5 files changed, 203 insertions(+), 5 deletions(-) create mode 100644 run.py diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index c5f550e5d..ac8655b71 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -16,6 +16,8 @@ ) from turbine_models.utils.sdxl_benchmark import run_benchmark from turbine_models.model_runner import vmfbRunner +import onnxruntime +import pdb from PIL import Image import gc @@ -74,6 +76,59 @@ def merge_export_arg(model_map, arg, arg_name): # item = ast.literal_eval(item) # return out +class OnnxPipelineComponent: + def __init__( + self, + printer, + dest_type="numpy", + dest_dtype="fp16", + ): + self.ort_session = None + self.onnx_file_path = None + self.ep = None + self.dest_type = dest_type + self.dest_dtype = dest_dtype + self.printer = printer + self.supported_dtypes = ["fp32"] + self.default_dtype = "fp32" + self.used_dtype = dest_dtype if dest_dtype in self.supported_dtypes else self.default_dtype + def load( + self, + onnx_file_path: str, + ep="CPUExecutionProvider" + ): + self.onnx_file_path = onnx_file_path + self.ep = ep + + self.ort_session = onnxruntime.InferenceSession(onnx_file_path, providers=[ep]) + self.printer.print( + f"Loading {onnx_file_path} into onnxruntime with {ep}." + ) + def unload(self): + self.ort_session = None + gc.collect() + + # input type only support numpy + def _convert_inputs(self, inputs): + for iname in inputs.keys(): + inp = inputs[iname] + if isinstance(inp, ireert.DeviceArray): + inputs[iname] = inp.to_host() + inputs[iname] = inputs[iname].astype(np_dtypes[self.used_dtype]) + return inputs + def _convert_output(self, output): + return output.astype(np_dtypes[self.dest_dtype]) + + def __call__(self, inputs: dict): + converted_inputs = self._convert_inputs(inputs) + # pdb.set_trace() + out = self.ort_session.run( + None, + converted_inputs, + )[0] + return self._convert_output(out) + + class PipelineComponent: """ @@ -268,6 +323,16 @@ def __call__(self, function_name, inputs: list): # def _run_and_validate(self, iree_fn, torch_fn, inputs: list) +class Bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' class Printer: def __init__(self, verbose, start_time, print_time): @@ -284,24 +349,31 @@ def __init__(self, verbose, start_time, print_time): def reset(self): if self.print_time: + print(Bcolors.BOLD + Bcolors.WARNING) if self.verbose: self.print("Will now reset clock for printer to 0.0 [s].") self.last_print = time.time() self.start_time = time.time() if self.verbose: self.print("Clock for printer reset to t = 0.0 [s].") + print(Bcolors.ENDC, end='') def print(self, message): if self.verbose: # Print something like "[t=0.123 dt=0.004] 'message'" + print(Bcolors.BOLD + Bcolors.OKCYAN) if self.print_time: time_now = time.time() print( - f"[t={time_now - self.start_time:.3f} dt={time_now - self.last_print:.3f}] {message}" + f"[ts={time_now - self.start_time:.3f}s] {message}" ) + # print( + # f"[t={time_now - self.start_time:.3f} dt={time_now - self.last_print:.3f}] {message}" + # ) self.last_print = time_now else: print(f"{message}") + print(Bcolors.ENDC, end='') class TurbinePipelineBase: @@ -359,6 +431,8 @@ def __init__( ireec_flags: str | dict[str] = None, precision: str | dict[str] = "fp16", attn_spec: str | dict[str] = None, + onnx_model_path: str | dict[str] = None, + run_onnx_mmdit: bool = False, decomp_attn: bool | dict[bool] = False, external_weights: str | dict[str] = None, pipeline_dir: str = "./shark_vmfbs", @@ -372,6 +446,7 @@ def __init__( self.map = model_map self.verbose = verbose self.printer = Printer(self.verbose, time.time(), True) + self.run_onnx_mmdit=run_onnx_mmdit if isinstance(device, dict): assert isinstance( target, dict @@ -396,6 +471,7 @@ def __init__( map_arguments = { "ireec_flags": ireec_flags, "precision": precision, + "onnx_model_path": onnx_model_path, "attn_spec": attn_spec, "decomp_attn": decomp_attn, "external_weights": external_weights, @@ -412,6 +488,7 @@ def __init__( self.map = merge_arg_into_map( self.map, torch_dtypes[self.map[submodel]["precision"]], "torch_dtype" ) + # pdb.set_trace() for arg in common_export_args.keys(): for submodel in self.map.keys(): self.map[submodel].get("export_args", {})[arg] = self.map[submodel].get( @@ -761,6 +838,8 @@ def load_map(self): self.load_submodel(submodel) def load_submodel(self, submodel): + + if not self.map[submodel].get("vmfb"): raise ValueError(f"VMFB not found for {submodel}.") if not self.map[submodel].get("weights") and self.map[submodel].get( @@ -783,6 +862,22 @@ def load_submodel(self, submodel): ) setattr(self, submodel, self.map[submodel]["runner"]) + # add an onnx runners + if self.run_onnx_mmdit and submodel == "mmdit": + dest_type = "numpy" + dest_dtype = self.map[submodel]["precision"] + onnx_runner = OnnxPipelineComponent( + printer=self.printer, + dest_type=dest_type, + dest_dtype=dest_dtype + ) + ep = "CPUExecutionProvider" + onnx_runner.load( + onnx_file_path=self.map[submodel]["onnx_model_path"], + ep=ep + ) + setattr(self, submodel+"_onnx", onnx_runner) + def unload_submodel(self, submodel): self.map[submodel]["runner"].unload() self.map[submodel]["vmfb"] = None diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py index 40e0f18c4..0c49e61af 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py @@ -159,6 +159,7 @@ def export_mmdit_model( attn_spec=None, input_mlir=None, weights_only=False, + onnx_model_path=None, ): dtype = torch.float16 if precision == "fp16" else torch.float32 mmdit_model = MMDiTModel( diff --git a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py index 87e01467a..d09c5d977 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py +++ b/models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py @@ -445,4 +445,20 @@ def is_valid_file(arg): ) +############################################################################## +# ONNX Options +############################################################################## +p.add_argument( + "--mmdit_onnx_model_path", + type=str, + default="C:/Users/chiz/work/sd3/mmdit/fp32/mmdit_optimized.onnx", + help="Path to mmdit onnx model", +) + +p.add_argument( + "--run_onnx_mmdit", + action="store_true", + help="Run MMDiT in onnx", +) + args, unknown = p.parse_known_args() diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index e37c095c4..c983e468b 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -41,6 +41,7 @@ import numpy as np import time from datetime import datetime as dt +import pdb # These are arguments common among submodel exports. # They are expected to be populated in two steps: @@ -227,6 +228,8 @@ def __init__( target: str | dict[str], ireec_flags: str | dict[str] = None, attn_spec: str | dict[str] = None, + onnx_model_path: str | dict[str] = None, + run_onnx_mmdit: bool = False, decomp_attn: bool | dict[bool] = False, pipeline_dir: str = "./shark_vmfbs", external_weights_dir: str = "./shark_weights", @@ -287,6 +290,8 @@ def __init__( ireec_flags, precision, attn_spec, + onnx_model_path, + run_onnx_mmdit, decomp_attn, external_weights, pipeline_dir, @@ -419,6 +424,7 @@ def load_scheduler( scheduler_id: str = None, steps: int = 30, ): + # pdb.set_trace() if not self.cpu_scheduling: if self.is_sd3: export_fn = sd3_schedulers.export_scheduler_model @@ -460,6 +466,7 @@ def load_scheduler( self.pipeline_dir, utils.create_safe_name(self.base_model_name, scheduler_uid) + ".vmfb", ) + # pdb.set_trace() if not os.path.exists(scheduler_path): self.export_submodel("scheduler") else: @@ -720,10 +727,26 @@ def _produce_latents_sd3( pooled_prompt_embeds, t, ] - noise_pred = self.mmdit( - "run_forward", - mmdit_inputs, - ) + # pdb.set_trace() + if hasattr(self, 'mmdit_onnx'): + # pdb.set_trace() + latent_model_input = latent_model_input.to_host() + batch = latent_model_input.shape[0] + batched_t = np.repeat(t.to_host(), batch) + noise_pred = self.mmdit_onnx( + { + "hidden_states": latent_model_input, + "encoder_hidden_states": prompt_embeds, + "pooled_projections" : pooled_prompt_embeds, + "timestep": batched_t, + + } + ) + else: + noise_pred = self.mmdit( + "run_forward", + mmdit_inputs, + ) latents = self.scheduler( "run_step", [noise_pred, t, latents, guidance_scale, steps_list_gpu[i]] ) @@ -754,6 +777,7 @@ def generate_images( prompt = "" self.cpu_scheduling = cpu_scheduling + # pdb.set_trace() if steps and needs_new_scheduler: self.num_inference_steps = steps self.load_scheduler(scheduler_id, steps) @@ -884,6 +908,9 @@ def numpy_to_pil_image(images): "mmdit": args.mmdit_spec if args.mmdit_spec else args.attn_spec, "vae": args.vae_spec if args.vae_spec else args.attn_spec, } + onnx_model_paths = { + "mmdit": args.mmdit_onnx_model_path + } if not args.pipeline_dir: args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "") benchmark = {} @@ -913,6 +940,7 @@ def numpy_to_pil_image(images): ), "vae": args.vae_decomp_attn if args.vae_decomp_attn else args.decomp_attn, } + # pdb.set_trace() sd_pipe = SharkSDPipeline( args.hf_model_name, args.height, @@ -924,6 +952,8 @@ def numpy_to_pil_image(images): targets, ireec_flags, specs, + onnx_model_paths, + args.run_onnx_mmdit, args.decomp_attn, args.pipeline_dir, args.external_weights_dir, @@ -937,8 +967,10 @@ def numpy_to_pil_image(images): args.verbose, save_outputs=save_outputs, ) + # pdb.set_trace() sd_pipe.prepare_all() sd_pipe.load_map() + # pdb.set_trace() sd_pipe.generate_images( args.prompt, args.negative_prompt, diff --git a/run.py b/run.py new file mode 100644 index 000000000..49e6ab2ea --- /dev/null +++ b/run.py @@ -0,0 +1,54 @@ +import os + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +def print_cmd (cmd, pipeline, flags): + print(bcolors.BOLD + bcolors.OKGREEN) + print (cmd, pipeline) + for f in flags: + print("\t", f) + print(bcolors.ENDC) + +cmd = "python" +pipeline = "models/turbine_models/custom_models/sd_inference/sd_pipeline.py" +prompt = "Photo of a ultra realistic sailing ship, dramatic light, pale sunrise, cinematic lighting, battered, low angle, trending on artstation, 4k, hyper realistic, focused, extreme details, unreal engine 5, cinematic, masterpiece, art by studio ghibli, intricate artwork by john william turner" +height = 512 +width=512 +mmdit_onnx_model_path = "C:/Users/chiz/work/sd3/mmdit/fp32/mmdit_optimized.onnx" +flags = [ + "--hf_model_name=stabilityai/stable-diffusion-3-medium-diffusers", + f"--height={height}", + f"--width={width}", + "--clip_device=local-task", + "--clip_precision=fp16", + "--clip_target=znver4", + "--clip_decomp_attn", + "--mmdit_precision=fp16", + "--mmdit_device=rocm-legacy://0", + "--mmdit_target=gfx1150", + '''--mmdit_flags="masked_attention" ''', + "--run_onnx_mmdit", + f'''--mmdit_onnx_model_path="{mmdit_onnx_model_path}" ''', + "--vae_device=rocm-legacy://0", + "--vae_precision=fp16", + "--vae_target=gfx1150", + '''--vae_flags="masked_attention" ''', + "--external_weights=safetensors", + "--num_inference_steps=28", + "--verbose", + f'''--prompt="{prompt}" ''' + ] + +print_cmd(cmd, pipeline, flags) + +final_cmd = ' '.join([cmd, pipeline]+flags) +os.system(final_cmd) \ No newline at end of file From 2146f54150ad85f2c8ff64b17f9eab92a9f9121b Mon Sep 17 00:00:00 2001 From: Chi Zhang Date: Mon, 16 Sep 2024 14:22:42 -0700 Subject: [PATCH 3/3] formatting and cleaning --- .../custom_models/pipeline_base.py | 67 ++++++++---------- .../sd3_inference/sd3_mmdit_onnx.py | 68 ++++++++++--------- .../sd3_inference/sd3_vae_onnx.py | 54 +++++++-------- .../custom_models/sd_inference/sd_pipeline.py | 23 ++----- .../custom_models/sd_inference/utils.py | 2 +- .../custom_models/sd_inference/vae.py | 1 + models/turbine_models/tests/sdxl_test.py | 25 ++++--- 7 files changed, 114 insertions(+), 126 deletions(-) diff --git a/models/turbine_models/custom_models/pipeline_base.py b/models/turbine_models/custom_models/pipeline_base.py index ac8655b71..b8ef57613 100644 --- a/models/turbine_models/custom_models/pipeline_base.py +++ b/models/turbine_models/custom_models/pipeline_base.py @@ -76,6 +76,7 @@ def merge_export_arg(model_map, arg, arg_name): # item = ast.literal_eval(item) # return out + class OnnxPipelineComponent: def __init__( self, @@ -91,19 +92,17 @@ def __init__( self.printer = printer self.supported_dtypes = ["fp32"] self.default_dtype = "fp32" - self.used_dtype = dest_dtype if dest_dtype in self.supported_dtypes else self.default_dtype - def load( - self, - onnx_file_path: str, - ep="CPUExecutionProvider" - ): + self.used_dtype = ( + dest_dtype if dest_dtype in self.supported_dtypes else self.default_dtype + ) + + def load(self, onnx_file_path: str, ep="CPUExecutionProvider"): self.onnx_file_path = onnx_file_path self.ep = ep - + self.ort_session = onnxruntime.InferenceSession(onnx_file_path, providers=[ep]) - self.printer.print( - f"Loading {onnx_file_path} into onnxruntime with {ep}." - ) + self.printer.print(f"Loading {onnx_file_path} into onnxruntime with {ep}.") + def unload(self): self.ort_session = None gc.collect() @@ -116,18 +115,17 @@ def _convert_inputs(self, inputs): inputs[iname] = inp.to_host() inputs[iname] = inputs[iname].astype(np_dtypes[self.used_dtype]) return inputs + def _convert_output(self, output): return output.astype(np_dtypes[self.dest_dtype]) - + def __call__(self, inputs: dict): converted_inputs = self._convert_inputs(inputs) - # pdb.set_trace() out = self.ort_session.run( None, converted_inputs, )[0] return self._convert_output(out) - class PipelineComponent: @@ -323,16 +321,18 @@ def __call__(self, function_name, inputs: list): # def _run_and_validate(self, iree_fn, torch_fn, inputs: list) + class Bcolors: - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKCYAN = '\033[96m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + HEADER = "\033[95m" + OKBLUE = "\033[94m" + OKCYAN = "\033[96m" + OKGREEN = "\033[92m" + WARNING = "\033[93m" + FAIL = "\033[91m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + class Printer: def __init__(self, verbose, start_time, print_time): @@ -356,7 +356,7 @@ def reset(self): self.start_time = time.time() if self.verbose: self.print("Clock for printer reset to t = 0.0 [s].") - print(Bcolors.ENDC, end='') + print(Bcolors.ENDC, end="") def print(self, message): if self.verbose: @@ -364,16 +364,14 @@ def print(self, message): print(Bcolors.BOLD + Bcolors.OKCYAN) if self.print_time: time_now = time.time() - print( - f"[ts={time_now - self.start_time:.3f}s] {message}" - ) + print(f"[ts={time_now - self.start_time:.3f}s] {message}") # print( # f"[t={time_now - self.start_time:.3f} dt={time_now - self.last_print:.3f}] {message}" # ) self.last_print = time_now else: print(f"{message}") - print(Bcolors.ENDC, end='') + print(Bcolors.ENDC, end="") class TurbinePipelineBase: @@ -446,7 +444,7 @@ def __init__( self.map = model_map self.verbose = verbose self.printer = Printer(self.verbose, time.time(), True) - self.run_onnx_mmdit=run_onnx_mmdit + self.run_onnx_mmdit = run_onnx_mmdit if isinstance(device, dict): assert isinstance( target, dict @@ -488,7 +486,6 @@ def __init__( self.map = merge_arg_into_map( self.map, torch_dtypes[self.map[submodel]["precision"]], "torch_dtype" ) - # pdb.set_trace() for arg in common_export_args.keys(): for submodel in self.map.keys(): self.map[submodel].get("export_args", {})[arg] = self.map[submodel].get( @@ -838,7 +835,6 @@ def load_map(self): self.load_submodel(submodel) def load_submodel(self, submodel): - if not self.map[submodel].get("vmfb"): raise ValueError(f"VMFB not found for {submodel}.") @@ -862,21 +858,18 @@ def load_submodel(self, submodel): ) setattr(self, submodel, self.map[submodel]["runner"]) - # add an onnx runners + # add an onnx runners if self.run_onnx_mmdit and submodel == "mmdit": dest_type = "numpy" dest_dtype = self.map[submodel]["precision"] onnx_runner = OnnxPipelineComponent( - printer=self.printer, - dest_type=dest_type, - dest_dtype=dest_dtype + printer=self.printer, dest_type=dest_type, dest_dtype=dest_dtype ) ep = "CPUExecutionProvider" onnx_runner.load( - onnx_file_path=self.map[submodel]["onnx_model_path"], - ep=ep + onnx_file_path=self.map[submodel]["onnx_model_path"], ep=ep ) - setattr(self, submodel+"_onnx", onnx_runner) + setattr(self, submodel + "_onnx", onnx_runner) def unload_submodel(self, submodel): self.map[submodel]["runner"].unload() diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py index fc3a53c06..33521fdc3 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_mmdit_onnx.py @@ -49,6 +49,7 @@ def forward( )[0] return noise_pred + @torch.no_grad() def export_mmdit_model( hf_model_name="stabilityai/stable-diffusion-3-medium-diffusers", @@ -56,17 +57,21 @@ def export_mmdit_model( height=512, width=512, precision="fp16", - max_length=77 + max_length=77, ): dtype = torch.float16 if precision == "fp16" else torch.float32 mmdit_model = MMDiTModel( dtype=dtype, ) file_prefix = "C:/Users/chiz/work/sd3/mmdit/exported/" - safe_name = file_prefix + utils.create_safe_name( - hf_model_name, - f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit", - ) + ".onnx" + safe_name = ( + file_prefix + + utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{max_length}_{height}x{width}_{precision}_mmdit", + ) + + ".onnx" + ) print(safe_name) do_classifier_free_guidance = True @@ -87,31 +92,30 @@ def export_mmdit_model( # mmdit_model(hidden_states, encoder_hidden_states, pooled_projections, timestep) torch.onnx.export( - mmdit_model, # model being run - ( - hidden_states, - encoder_hidden_states, - pooled_projections, - timestep - ), # model input (or a tuple for multiple inputs) - safe_name, # where to save the model (can be a file or file-like object) - export_params=True, # store the trained parameter weights inside the model file - opset_version=17, # the ONNX version to export the model to - do_constant_folding=True, # whether to execute constant folding for optimization - input_names=[ - "hidden_states", - "encoder_hidden_states", - "pooled_projections", - "timestep" - ], # the model's input names - output_names=[ - "sample_out", - ], # the model's output names - ) + mmdit_model, # model being run + ( + hidden_states, + encoder_hidden_states, + pooled_projections, + timestep, + ), # model input (or a tuple for multiple inputs) + safe_name, # where to save the model (can be a file or file-like object) + export_params=True, # store the trained parameter weights inside the model file + opset_version=17, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=[ + "hidden_states", + "encoder_hidden_states", + "pooled_projections", + "timestep", + ], # the model's input names + output_names=[ + "sample_out", + ], # the model's output names + ) return safe_name - if __name__ == "__main__": import logging @@ -120,11 +124,11 @@ def export_mmdit_model( onnx_model_name = export_mmdit_model( args.hf_model_name, - 1, # args.batch_size, - 512, # args.height, - 512, # args.width, - "fp16", # args.precision, - 77, # args.max_length, + 1, # args.batch_size, + 512, # args.height, + 512, # args.width, + "fp16", # args.precision, + 77, # args.max_length, ) print("Saved to", onnx_model_name) diff --git a/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py b/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py index 5d97c623c..d8e5c6592 100644 --- a/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py +++ b/models/turbine_models/custom_models/sd3_inference/sd3_vae_onnx.py @@ -63,20 +63,23 @@ def export_vae_model( batch_size=1, height=512, width=512, - precision="fp32" + precision="fp32", ): dtype = torch.float16 if precision == "fp16" else torch.float32 file_prefix = "C:/Users/chiz/work/sd3/vae_decoder/exported/" - safe_name = file_prefix + utils.create_safe_name( - hf_model_name, - f"_bs{batch_size}_{height}x{width}_{precision}_vae", - ) + ".onnx" + safe_name = ( + file_prefix + + utils.create_safe_name( + hf_model_name, + f"_bs{batch_size}_{height}x{width}_{precision}_vae", + ) + + ".onnx" + ) print(safe_name) if dtype == torch.float16: vae_model = vae_model.half() - # input_image_shape = (height, width, 3) input_latents_shape = (batch_size, 16, height // 8, width // 8) input_latents = torch.empty(input_latents_shape, dtype=dtype) @@ -94,37 +97,34 @@ def export_vae_model( # ] torch.onnx.export( - vae_model, # model being run - ( - input_latents - ), # model input (or a tuple for multiple inputs) - safe_name, # where to save the model (can be a file or file-like object) - export_params=True, # store the trained parameter weights inside the model file - opset_version=17, # the ONNX version to export the model to - do_constant_folding=True, # whether to execute constant folding for optimization - input_names=[ - "input_latents", - ], # the model's input names - output_names=[ - "sample_out", - ], # the model's output names - ) + vae_model, # model being run + (input_latents), # model input (or a tuple for multiple inputs) + safe_name, # where to save the model (can be a file or file-like object) + export_params=True, # store the trained parameter weights inside the model file + opset_version=17, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=[ + "input_latents", + ], # the model's input names + output_names=[ + "sample_out", + ], # the model's output names + ) return safe_name - if __name__ == "__main__": from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args vae_model = VaeModel( - args.hf_model_name, + args.hf_model_name, ) onnx_model_name = export_vae_model( vae_model, args.hf_model_name, - 1, # args.batch_size, - 512, # height=args.height, - 512, # width=args.width, - "fp32" # precision=args.precision + 1, # args.batch_size, + 512, # height=args.height, + 512, # width=args.width, + "fp32", # precision=args.precision ) print("Saved to", onnx_model_name) diff --git a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py index c983e468b..7b0fe7d4c 100644 --- a/models/turbine_models/custom_models/sd_inference/sd_pipeline.py +++ b/models/turbine_models/custom_models/sd_inference/sd_pipeline.py @@ -396,7 +396,9 @@ def setup_punet(self): self.map["unet"]["mlir"] = None self.map["unet"]["vmfb"] = None self.map["unet"]["weights"] = None - self.map["unet"]["keywords"] = [i for i in self.map["unet"]["keywords"] if i != "!punet"] + self.map["unet"]["keywords"] = [ + i for i in self.map["unet"]["keywords"] if i != "!punet" + ] self.map["unet"]["keywords"] += "punet" if self.use_i8_punet: if self.add_tk_kernels: @@ -424,7 +426,6 @@ def load_scheduler( scheduler_id: str = None, steps: int = 30, ): - # pdb.set_trace() if not self.cpu_scheduling: if self.is_sd3: export_fn = sd3_schedulers.export_scheduler_model @@ -466,7 +467,6 @@ def load_scheduler( self.pipeline_dir, utils.create_safe_name(self.base_model_name, scheduler_uid) + ".vmfb", ) - # pdb.set_trace() if not os.path.exists(scheduler_path): self.export_submodel("scheduler") else: @@ -727,19 +727,16 @@ def _produce_latents_sd3( pooled_prompt_embeds, t, ] - # pdb.set_trace() - if hasattr(self, 'mmdit_onnx'): - # pdb.set_trace() + if hasattr(self, "mmdit_onnx"): latent_model_input = latent_model_input.to_host() batch = latent_model_input.shape[0] batched_t = np.repeat(t.to_host(), batch) noise_pred = self.mmdit_onnx( { - "hidden_states": latent_model_input, + "hidden_states": latent_model_input, "encoder_hidden_states": prompt_embeds, - "pooled_projections" : pooled_prompt_embeds, + "pooled_projections": pooled_prompt_embeds, "timestep": batched_t, - } ) else: @@ -777,7 +774,6 @@ def generate_images( prompt = "" self.cpu_scheduling = cpu_scheduling - # pdb.set_trace() if steps and needs_new_scheduler: self.num_inference_steps = steps self.load_scheduler(scheduler_id, steps) @@ -908,9 +904,7 @@ def numpy_to_pil_image(images): "mmdit": args.mmdit_spec if args.mmdit_spec else args.attn_spec, "vae": args.vae_spec if args.vae_spec else args.attn_spec, } - onnx_model_paths = { - "mmdit": args.mmdit_onnx_model_path - } + onnx_model_paths = {"mmdit": args.mmdit_onnx_model_path} if not args.pipeline_dir: args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "") benchmark = {} @@ -940,7 +934,6 @@ def numpy_to_pil_image(images): ), "vae": args.vae_decomp_attn if args.vae_decomp_attn else args.decomp_attn, } - # pdb.set_trace() sd_pipe = SharkSDPipeline( args.hf_model_name, args.height, @@ -967,10 +960,8 @@ def numpy_to_pil_image(images): args.verbose, save_outputs=save_outputs, ) - # pdb.set_trace() sd_pipe.prepare_all() sd_pipe.load_map() - # pdb.set_trace() sd_pipe.generate_images( args.prompt, args.negative_prompt, diff --git a/models/turbine_models/custom_models/sd_inference/utils.py b/models/turbine_models/custom_models/sd_inference/utils.py index 6667f1fdd..fbd09cd59 100644 --- a/models/turbine_models/custom_models/sd_inference/utils.py +++ b/models/turbine_models/custom_models/sd_inference/utils.py @@ -353,7 +353,7 @@ def compile_to_vmfb( # the TD spec is implemented in C++. if attn_spec in ["default", "mfma", "punet"]: -# if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: + # if any(x in safe_name for x in ["clip", "prompt_encoder"]) == False: use_punet = True if attn_spec in ["punet", "i8"] else False attn_spec = get_mfma_spec_path( target_triple, diff --git a/models/turbine_models/custom_models/sd_inference/vae.py b/models/turbine_models/custom_models/sd_inference/vae.py index add39b353..34cb85661 100644 --- a/models/turbine_models/custom_models/sd_inference/vae.py +++ b/models/turbine_models/custom_models/sd_inference/vae.py @@ -265,6 +265,7 @@ class CompiledVae(CompiledModule): if __name__ == "__main__": from turbine_models.custom_models.sd_inference.sd_cmd_opts import args + mod_str = export_vae_model( args.hf_model_name, args.batch_size, diff --git a/models/turbine_models/tests/sdxl_test.py b/models/turbine_models/tests/sdxl_test.py index 7cec4a661..cf0a05f5e 100644 --- a/models/turbine_models/tests/sdxl_test.py +++ b/models/turbine_models/tests/sdxl_test.py @@ -62,11 +62,15 @@ def command_line_args(request): arguments["compile_to"] = request.config.getoption("--compile_to") arguments["external_weights"] = request.config.getoption("--external_weights") arguments["decomp_attn"] = request.config.getoption("--decomp_attn") - arguments["attn_spec"] = request.config.getoption("--attn_spec") if request.config.getoption("attn_spec") else { - "text_encoder": request.config.getoption("clip_spec"), - "unet": request.config.getoption("unet_spec"), - "vae": request.config.getoption("vae_spec"), - } + arguments["attn_spec"] = ( + request.config.getoption("--attn_spec") + if request.config.getoption("attn_spec") + else { + "text_encoder": request.config.getoption("clip_spec"), + "unet": request.config.getoption("unet_spec"), + "vae": request.config.getoption("vae_spec"), + } + ) arguments["device"] = request.config.getoption("--device") arguments["rt_device"] = request.config.getoption("--rt_device") arguments["iree_target_triple"] = request.config.getoption("--iree_target_triple") @@ -117,9 +121,7 @@ def setUp(self): def test01_PromptEncoder(self): if arguments["device"] in ["vulkan", "cuda"]: - self.skipTest( - "Compilation error on vulkan; To be tested on cuda." - ) + self.skipTest("Compilation error on vulkan; To be tested on cuda.") arguments["vmfb_path"] = self.pipe.map["text_encoder"]["vmfb"] arguments["external_weight_path"] = self.pipe.map["text_encoder"]["weights"] tokenizer_1 = CLIPTokenizer.from_pretrained( @@ -350,9 +352,7 @@ def test05_t2i_generate_images(self): def test06_t2i_generate_images_punet(self): if arguments["device"] in ["vulkan", "cuda"]: - self.skipTest( - "Have issues with submodels on vulkan, cuda" - ) + self.skipTest("Have issues with submodels on vulkan, cuda") if getattr(self.pipe, "unet"): self.pipe.unload_submodel("unet") self.pipe.use_punet = True @@ -372,13 +372,12 @@ def test06_t2i_generate_images_punet(self): True, # return_img ) assert output is not None - + def tearDown(self): del self.pipe gc.collect() - if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) unittest.main()