Edit for RK3588
happyme531 committed Feb 17, 2024 · 1 parent 7caa798 · commit f881416
Showing 11 changed files with 1,198 additions and 10 deletions.
88 changes: 82 additions & 6 deletions build.py
@@ -30,8 +30,9 @@ def _parse_args():
         if system() == "Darwin":
             target = tvm.target.Target("apple/m1-gpu")
         else:
-            has_gpu = tvm.cuda().exist
-            target = tvm.target.Target("cuda" if has_gpu else "llvm")
+            # has_gpu = tvm.cuda().exist
+            # target = tvm.target.Target("cuda" if has_gpu else "llvm")
+            target = tvm.target.Target("opencl -device=mali -max_shared_memory_per_block=32768 -max_threads_per_block=1024 -max_num_threads=1024 -thread_warp_size=16")
         print(f"Automatically configuring target: {target}")
         parsed.target = tvm.target.Target(target, host="llvm")
     elif parsed.target == "webgpu":
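The hard-coded target flags above encode the Mali-G610 limits of the RK3588. A minimal sanity check, assuming TVM was built with OpenCL support (this snippet is not part of the commit), confirms that the runtime sees the GPU and that the flags match what the driver reports:

import tvm

dev = tvm.opencl(0)
if dev.exist:
    # Compare these driver-reported limits with the target flags above.
    print("device:", dev.device_name)
    print("max_threads_per_block:", dev.max_threads_per_block)
    print("max_shared_memory_per_block:", dev.max_shared_memory_per_block)
    print("warp_size:", dev.warp_size)
else:
    print("No OpenCL device visible; check the Mali driver installation.")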
@@ -108,6 +109,7 @@ def legalize_and_lift_params(
     entry_funcs = (
         model_names + scheduler_func_names + ["image_to_rgba", "concat_embeddings"]
     )
+    print(f"Entry functions: {entry_funcs}")
 
     mod = relax.pipeline.get_pipeline()(mod)
     mod = relax.transform.DeadCodeElimination(entry_funcs)(mod)
@@ -128,10 +130,83 @@ def legalize_and_lift_params(
 def build(mod: tvm.IRModule, args: Dict) -> None:
     from tvm import meta_schedule as ms
 
-    db = ms.database.create(work_dir=args.db_path)
-    with args.target, db, tvm.transform.PassContext(opt_level=3):
-        mod_deploy = relax.transform.MetaScheduleApplyDatabase(enable_warning=True)(mod)
 
+    # # tuning part
+    # # Delete the VAE part of the model when tuning the U-Net: it interferes with the tuning. The VAE might also be able to run on the NPU; see https://clehaxze.tw/gemlog/2023/07-15-inexhaustive-list-of-models-that-works-on-rk3588.gmi
+    # entry_funcs = ['clip', 'unet', 'dpm_solver_multistep_scheduler_convert_model_output', 'dpm_solver_multistep_scheduler_step', 'pndm_scheduler_step_0', 'pndm_scheduler_step_1', 'pndm_scheduler_step_2', 'pndm_scheduler_step_3', 'pndm_scheduler_step_4', 'image_to_rgba', 'concat_embeddings']
+    # new_mod = tvm.IRModule()
+    # for gv, func in mod.functions.items():
+    #     try:
+    #         if func.attrs["global_symbol"] == "main" and func.attrs["num_input"] == 1:  # vae
+    #             continue
+    #     except Exception:
+    #         pass
+    #     new_mod[gv] = func
+    # mod = new_mod
+    # mod = relax.transform.DeadCodeElimination(entry_funcs)(mod)
+    # debug_dump_script(mod, "mod_tune.py", args)
+
+    # # Important!! Run `echo 99999999999 > /sys/class/misc/mali0/device/progress_timeout` before tuning to avoid timeout issue 1.
+    # # run tuning
+    # ms.relax_integration.tune_relax(
+    #     mod=mod,
+    #     target=args.target,
+    #     params={},
+    #     builder=ms.builder.LocalBuilder(
+    #         max_workers=7,
+    #         timeout_sec=450,
+    #     ),
+    #     op_names={"softmax2"},
+    #     runner=ms.runner.LocalRunner(
+    #         timeout_sec=120,  # needs to be that long!
+    #         maximum_process_uses=1,  # works around buggy behaviour of the Mali OpenCL driver where subsequent runs fail after the first failure (this code change is not committed yet)
+    #         evaluator_config=ms.runner.config.EvaluatorConfig(
+    #             number=1,  # avoid timeout issue 2
+    #             repeat=1,
+    #             min_repeat_ms=0,  # https://github.com/apache/tvm/issues/16276
+    #         )),
+    #     work_dir="log_db_my",
+    #     max_trials_global=100000,
+    #     max_trials_per_task=8000,
+    #     seed=42,
+    #     num_trials_per_iter=32,
+    # )
+    # mydb = ms.database.create(work_dir="log_db_my")
+    # mydb1 = ms.database.create(work_dir="log_db_my_pruned2_novae")
+    # mydb.dump_pruned(
+    #     mydb1,
+    # )
+    # db = ms.database.create(work_dir=args.db_path)
+    # with args.target, mydb, tvm.transform.PassContext(opt_level=3):
+    #     mod_deploy = relax.transform.MetaScheduleApplyDatabase(enable_warning=True)(mod)
+    mod_deploy = mod
+    print("Applying database 1 =======================")
+    db3 = ms.database.create(work_dir="log_db_my_unet_softmax2")  # For some reason the softmax2 op runs very slowly on the Mali GPU, so it is tuned separately
+    with args.target, db3, tvm.transform.PassContext(opt_level=3):
+        mod_deploy = relax.transform.MetaScheduleApplyDatabase(enable_warning=True)(mod_deploy)
+    print("Applying database 2 =======================")
+    db0 = ms.database.create(work_dir="log_db_my_clip_unet")  # The CLIP and U-Net parts of the model
+    with args.target, db0, tvm.transform.PassContext(opt_level=3):
+        mod_deploy = relax.transform.MetaScheduleApplyDatabase(enable_warning=True)(mod_deploy)
+    print("Applying database 3 =======================")
+    db2 = ms.database.create(work_dir="log_db_my_vae")  # The VAE part of the model (not tuned very well yet)
+    with args.target, db2, tvm.transform.PassContext(opt_level=3):
+        mod_deploy = relax.transform.MetaScheduleApplyDatabase(enable_warning=True)(mod_deploy)
+    print("Generating missing schedules ==============")
+    with tvm.target.Target("cuda"):
+        mod_deploy = tvm.tir.transform.DefaultGPUSchedule()(mod_deploy)  # for any schedules still missing
+
+    # For some unknown reason the u-net, vae and clip symbol names change to "main" and "subgraph_0";
+    # restore the original symbol names here. Delete this part if it turns out to be unnecessary.
+    for gv, func in mod_deploy.functions.items():
+        try:
+            if func.attrs["global_symbol"] == "main" and func.attrs["num_input"] == 3:  # u-net
+                mod_deploy[gv] = func.with_attr("global_symbol", "unet")
+            if func.attrs["global_symbol"] == "main" and func.attrs["num_input"] == 1:  # vae
+                mod_deploy[gv] = func.with_attr("global_symbol", "vae")
+            if func.attrs["global_symbol"] == "subgraph_0":
+                mod_deploy[gv] = func.with_attr("global_symbol", "clip")
+        except Exception:
+            pass
     debug_dump_script(mod_deploy, "mod_build_stage.py", args)
 
     ex = relax.build(mod_deploy, args.target)
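The rename loop above identifies functions heuristically by their `num_input` attribute, so it is worth verifying the result. A small hedged check (an illustration, not code from the commit) lists every exported symbol after the pass; `unet`, `vae` and `clip` should all appear:

# List the exported symbol of each Relax function after the rename pass.
for gv, func in mod_deploy.functions.items():
    try:
        print(gv.name_hint, "->", func.attrs["global_symbol"])
    except Exception:
        pass  # internal functions without a global_symbol attribute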
@@ -145,6 +220,7 @@ def build(mod: tvm.IRModule, args: Dict) -> None:

     debug_dump_shader(ex, f"stable_diffusion_{target_kind}", args)
     ex.export_library(os.path.join(args.artifact_path, output_filename))
+    print(ex.stats())
 
 
 if __name__ == "__main__":
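Once `export_library` has written the artifact, a minimal smoke test can reload it and instantiate the Relax VM on the Mali GPU. This is a sketch under two assumptions: the artifact name below is hypothetical (substitute the actual `output_filename`), and the symbol-rename pass above succeeded:

import tvm
from tvm import relax

# "dist/stable_diffusion_opencl.so" is a placeholder for the exported artifact.
rt_mod = tvm.runtime.load_module("dist/stable_diffusion_opencl.so")
vm = relax.VirtualMachine(rt_mod, tvm.opencl(0))
unet = vm["unet"]  # resolves only if the rename pass restored the symbol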
188 changes: 188 additions & 0 deletions convert_model_from_pth_safetensors.py
@@ -0,0 +1,188 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the LDM checkpoints. """

import argparse
import importlib

import torch

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--config_files",
        default=None,
        type=str,
        help="The YAML config file corresponding to the architecture.",
    )
    parser.add_argument(
        "--num_in_channels",
        default=None,
        type=int,
        help="The number of input channels. If `None`, the number of input channels will be automatically inferred.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
    )
    parser.add_argument(
        "--pipeline_type",
        default=None,
        type=str,
        help=(
            "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'"
            ". If `None`, the pipeline will be automatically inferred."
        ),
    )
    parser.add_argument(
        "--image_size",
        default=None,
        type=int,
        help=(
            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2"
            " Base. Use 768 for Stable Diffusion v2."
        ),
    )
    parser.add_argument(
        "--prediction_type",
        default=None,
        type=str,
        help=(
            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
        ),
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument(
        "--upcast_attention",
        action="store_true",
        help=(
            "Whether the attention computation should always be upcasted. This is necessary when running stable"
            " diffusion 2.1."
        ),
    )
    parser.add_argument(
        "--from_safetensors",
        action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    parser.add_argument(
        "--stable_unclip",
        type=str,
        default=None,
        required=False,
        help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.",
    )
    parser.add_argument(
        "--stable_unclip_prior",
        type=str,
        default=None,
        required=False,
        help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.",
    )
    parser.add_argument(
        "--clip_stats_path",
        type=str,
        help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.",
        required=False,
    )
    parser.add_argument(
        "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
    )
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        required=False,
        help="Set to a path or hub id of an already converted VAE to avoid converting it again.",
    )
    parser.add_argument(
        "--pipeline_class_name",
        type=str,
        default=None,
        required=False,
        help="Specify the pipeline class name",
    )

    args = parser.parse_args()

    if args.pipeline_class_name is not None:
        library = importlib.import_module("diffusers")
        class_obj = getattr(library, args.pipeline_class_name)
        pipeline_class = class_obj
    else:
        pipeline_class = None

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=args.checkpoint_path,
        original_config_file=args.original_config_file,
        config_files=args.config_files,
        image_size=args.image_size,
        prediction_type=args.prediction_type,
        model_type=args.pipeline_type,
        extract_ema=args.extract_ema,
        scheduler_type=args.scheduler_type,
        num_in_channels=args.num_in_channels,
        upcast_attention=args.upcast_attention,
        from_safetensors=args.from_safetensors,
        device=args.device,
        stable_unclip=args.stable_unclip,
        stable_unclip_prior=args.stable_unclip_prior,
        clip_stats_path=args.clip_stats_path,
        controlnet=args.controlnet,
        vae_path=args.vae_path,
        pipeline_class=pipeline_class,
    )

    if args.half:
        pipe.to(dtype=torch.float16)

    if args.controlnet:
        # only save the controlnet model
        pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
    else:
        pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
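To confirm a conversion round-trips, the converted directory can be loaded back with diffusers. A short hedged check (not part of the script; "./converted_model" is a stand-in for whatever `--dump_path` was used):

import torch
from diffusers import StableDiffusionPipeline

# "./converted_model" is a placeholder for the --dump_path argument.
pipe = StableDiffusionPipeline.from_pretrained("./converted_model", torch_dtype=torch.float16)
print(type(pipe.unet).__name__, pipe.unet.config.in_channels)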
6 changes: 3 additions & 3 deletions deploy.py
@@ -23,9 +23,9 @@ def _parse_args():
     args.add_argument("--debug-dump", action="store_true", default=False)
     args.add_argument("--artifact-path", type=str, default="dist")
     args.add_argument(
-        "--prompt", type=str, default="A photo of an astronaut riding a horse on mars."
+        "--prompt", type=str, default="masterpiece, best quality, A photo of an astronaut riding a horse on mars."
     )
-    args.add_argument("--negative-prompt", type=str, default="")
+    args.add_argument("--negative-prompt", type=str, default="cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, blurry")
     args.add_argument(
         "--scheduler",
         type=str,
@@ -114,7 +114,7 @@ def __call__(self, prompt: str, negative_prompt: str = ""):
             noise_pred = self.unet_latents_to_noise_pred(latents, t, text_embeddings)
             self.debug_dump(f"unet_output_{i}", noise_pred)
             latents = self.scheduler.step(self.vm, noise_pred, latents, i)
-
+        latents.numpy()  # Force a copy to avoid memory leak
         self.debug_dump("vae_input", latents)
         image = self.vae_to_image(latents)
         self.debug_dump("vae_output", image)
(Diffs for the remaining 8 changed files were not loaded.)
