diff --git a/benchmarks/text_to_image.py b/benchmarks/text_to_image.py
index 85ec6bb43..32013fcf6 100644
--- a/benchmarks/text_to_image.py
+++ b/benchmarks/text_to_image.py
@@ -46,6 +46,7 @@ def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model", type=str, default=MODEL)
+    parser.add_argument("--dtype", type=str, default="half")
     parser.add_argument("--variant", type=str, default=VARIANT)
     parser.add_argument("--custom-pipeline", type=str, default=CUSTOM_PIPELINE)
     parser.add_argument("--scheduler", type=str, default=SCHEDULER)
@@ -92,6 +93,8 @@ def parse_args():
         default=QUANTIZE_CONFIG,
     )
     parser.add_argument("--quant-submodules-config-path", type=str, default=None)
+    parser.add_argument("--revision", type=str, default=None)
+    parser.add_argument("--local-files-only", action="store_true")
     return parser.parse_args()
 
 
@@ -108,6 +111,8 @@ def load_pipe(
     scheduler=None,
     lora=None,
     controlnet=None,
+    revision=None,
+    local_files_only=False,
 ):
     extra_kwargs = {}
     if custom_pipeline is not None:
@@ -115,6 +120,8 @@ def load_pipe(
     if variant is not None:
         extra_kwargs["variant"] = variant
     if dtype is not None:
+        dtype = getattr(torch, dtype)
+        assert isinstance(dtype, torch.dtype)
         extra_kwargs["torch_dtype"] = dtype
     if controlnet is not None:
         from diffusers import ControlNetModel
@@ -124,6 +131,11 @@ def load_pipe(
             torch_dtype=dtype,
         )
         extra_kwargs["controlnet"] = controlnet
+    if revision is not None:
+        extra_kwargs["revision"] = revision
+    if local_files_only:
+        extra_kwargs["local_files_only"] = True
+
     if os.path.exists(os.path.join(model_name, "calibrate_info.txt")):
         from onediff.quantization import QuantPipeline
 
@@ -231,11 +243,14 @@ def main():
     pipe = load_pipe(
         pipeline_cls,
         args.model,
+        dtype=args.dtype,
         variant=args.variant,
         custom_pipeline=args.custom_pipeline,
         scheduler=args.scheduler,
         lora=args.lora,
         controlnet=args.controlnet,
+        revision=args.revision,
+        local_files_only=args.local_files_only,
     )
 
     core_net = None
@@ -349,6 +364,13 @@ def get_kwarg_inputs():
             kwarg_inputs["cache_block_id"] = args.cache_block_id
         return kwarg_inputs
 
+    kwarg_inputs = get_kwarg_inputs()
+
+    # Patch for FluxPipeline: rename negative_prompt to prompt_2
+    if pipe.__class__.__name__ == "FluxPipeline":
+        kwarg_inputs["prompt_2"] = kwarg_inputs["negative_prompt"]
+        kwarg_inputs.pop("negative_prompt")
+
     # NOTE: Warm it up.
     # The initial calls will trigger compilation and might be very slow.
     # After that, it should be very fast.
@@ -357,7 +379,7 @@ def get_kwarg_inputs():
     print("=======================================")
     print("Begin warmup")
     for _ in range(args.warmups):
-        pipe(**get_kwarg_inputs())
+        pipe(**kwarg_inputs)
     end = time.time()
     print("End warmup")
     print(f"Warmup time: {end - begin:.3f}s")
@@ -365,7 +387,7 @@ def get_kwarg_inputs():
 
     # Let"s see it!
     # Note: Progress bar might work incorrectly due to the async nature of CUDA.
-    kwarg_inputs = get_kwarg_inputs()
+    iter_profiler = IterationProfiler()
     if "callback_on_step_end" in inspect.signature(pipe).parameters:
         kwarg_inputs["callback_on_step_end"] = iter_profiler.callback_on_step_end
@@ -387,6 +409,9 @@ def get_kwarg_inputs():
     else:
         cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
     print(f"Max used CUDA memory : {cuda_mem_after_used:.3f}GiB")
+    if args.compiler != "oneflow":
+        cuda_mem_max_reserved = torch.cuda.max_memory_reserved() / (1024**3)
+        print(f"Peak CUDA memory : {cuda_mem_max_reserved:.3f}GiB")
     print("=======================================")
 
     if args.print_output:
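A note on the `--dtype` flag added above: the string is resolved by attribute lookup on the `torch` module inside `load_pipe`, and the `assert` rejects names that exist on `torch` but are not dtypes. A minimal standalone sketch of that conversion (`resolve_dtype` is a hypothetical helper name, not part of the benchmark script):

```python
import torch

def resolve_dtype(name: str) -> torch.dtype:
    # "half" -> torch.half, "bfloat16" -> torch.bfloat16, "float32" -> torch.float32
    dtype = getattr(torch, name)
    assert isinstance(dtype, torch.dtype), f"torch.{name} is not a dtype"
    return dtype

print(resolve_dtype("bfloat16"))  # torch.bfloat16
```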
diff --git a/onediff_diffusers_extensions/examples/flux/README.md b/onediff_diffusers_extensions/examples/flux/README.md
new file mode 100644
index 000000000..033f12763
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/flux/README.md
@@ -0,0 +1,101 @@
+# Run FLUX with nexfort backend (Beta Release)
+
+1. [Environment Setup](#environment-setup)
+   - [Set Up OneDiff](#set-up-onediff)
+   - [Set Up NexFort Backend](#set-up-nexfort-backend)
+   - [Set Up Diffusers Library](#set-up-diffusers)
+   - [Set Up FLUX](#set-up-flux)
+2. [Execution Instructions](#run)
+   - [Run Without Compilation (Baseline)](#run-without-compilation-baseline)
+   - [Run With Compilation](#run-with-compilation)
+3. [Performance Comparison](#performance-comparison)
+4. [Dynamic Shape for FLUX](#dynamic-shape-for-flux)
+
+## Environment setup
+### Set up onediff
+https://github.com/siliconflow/onediff?tab=readme-ov-file#installation
+
+### Set up nexfort backend
+https://github.com/siliconflow/onediff/tree/main/src/onediff/infer_compiler/backends/nexfort
+
+### Set up diffusers
+
+```
+pip3 install --upgrade diffusers[torch]
+```
+
+### Set up FLUX
+Model version for diffusers: https://huggingface.co/black-forest-labs/FLUX.1-schnell
+
+HF pipeline: https://github.com/huggingface/diffusers/blob/main/docs/source/en/api/pipelines/flux.md
+
+## Run
+
+### Run without compilation (Baseline)
+```shell
+python3 benchmarks/text_to_image.py \
+  --model black-forest-labs/FLUX.1-schnell \
+  --height 1024 --width 1024 \
+  --scheduler none \
+  --steps 4 \
+  --output-image ./flux-schnell.png \
+  --prompt "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," \
+  --compiler none \
+  --dtype bfloat16 \
+  --seed 1 \
+  --print-output
+```
+
+### Run with compilation
+
+```shell
+python3 benchmarks/text_to_image.py \
+  --model black-forest-labs/FLUX.1-schnell \
+  --height 1024 --width 1024 \
+  --scheduler none \
+  --steps 4 \
+  --output-image ./flux-schnell-compile.png \
+  --prompt "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," \
+  --compiler nexfort \
+  --compiler-config '{"mode": "benchmark:cudagraphs:max-autotune:low-precision:cache-all", "memory_format": "channels_last", "options": {"cuda.fuse_timestep_embedding": false, "inductor.force_triton_sdpa": true}}' \
+  --dtype bfloat16 \
+  --seed 1 \
+  --print-output
+```
+
+## Performance comparison
+
+Tested on an NVIDIA A800-SXM4-80GB, with an image size of 1024*1024 and 4 inference steps:
+
+| Metric                                             | A800-SXM4-80GB 1024*1024 |
+| -------------------------------------------------- | ------------------------ |
+| Data update date (yyyy-mm-dd)                      | 2024-08-07               |
+| PyTorch iteration speed                            | 2.18 it/s                |
+| OneDiff iteration speed                            | 2.80 it/s (+28.4%)       |
+| PyTorch E2E time                                   | 2.06 s                   |
+| OneDiff E2E time                                   | 1.53 s (-25.7%)          |
+| PyTorch Max Mem Used                               | 35.79 GiB                |
+| OneDiff Max Mem Used                               | 40.44 GiB                |
+| PyTorch Warmup with Run time                       | 2.81 s                   |
+| OneDiff Warmup with Compilation time <sup>1</sup>  | 253.01 s                 |
+| OneDiff Warmup with Cache time                     | 73.63 s                  |
+
+<sup>1</sup> The OneDiff warmup-with-compilation time was measured on an Intel(R) Xeon(R) Platinum 8358P CPU @ 2.60GHz. It is for reference only; compilation time varies considerably across CPUs.
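+The memory rows above are printed by `benchmarks/text_to_image.py` from PyTorch's CUDA allocator statistics. A minimal sketch of the queries behind them (based on the benchmark change in this PR), runnable after any pipeline call on a CUDA device:
+
+```python
+import torch
+
+# Maximum memory actually allocated by tensors ("Max used CUDA memory"):
+print(f"Max used CUDA memory : {torch.cuda.max_memory_allocated() / (1024**3):.3f}GiB")
+# Maximum memory reserved by the caching allocator (the new "Peak CUDA memory" line):
+print(f"Peak CUDA memory : {torch.cuda.max_memory_reserved() / (1024**3):.3f}GiB")
+```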
+
+## Dynamic shape for FLUX
+
+Run:
+
+```shell
+python3 benchmarks/text_to_image.py \
+  --model black-forest-labs/FLUX.1-schnell \
+  --height 1024 --width 1024 \
+  --scheduler none \
+  --steps 4 \
+  --output-image ./flux-schnell-compile.png \
+  --prompt "beautiful scenery nature glass bottle landscape, , purple galaxy bottle," \
+  --compiler nexfort \
+  --compiler-config '{"mode": "benchmark:cudagraphs:max-autotune:low-precision:cache-all", "memory_format": "channels_last", "options": {"cuda.fuse_timestep_embedding": false, "inductor.force_triton_sdpa": true}, "dynamic": true}' \
+  --run_multiple_resolutions 1 \
+  --dtype bfloat16 \
+  --seed 1
+```
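+With `"dynamic": true`, nexfort compiles the model with symbolic shapes, so one compiled graph can be reused across resolutions instead of recompiling per size. Roughly the same setup from Python looks like the sketch below (modeled on the example script shipped alongside this README; whether `compile_pipe` accepts `"dynamic"` at the top level of its `options` dict is an assumption to check against the nexfort backend docs):
+
+```python
+import torch
+from diffusers import FluxPipeline
+from onediffx import compile_pipe
+
+pipe = FluxPipeline.from_pretrained(
+    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
+).to("cuda")
+# Compile once; later calls with different height/width reuse the compiled graph.
+pipe = compile_pipe(pipe, backend="nexfort", options={"dynamic": True})
+```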
diff --git a/onediff_diffusers_extensions/examples/text_to_image_flux.py b/onediff_diffusers_extensions/examples/text_to_image_flux.py
new file mode 100644
index 000000000..245e46c38
--- /dev/null
+++ b/onediff_diffusers_extensions/examples/text_to_image_flux.py
@@ -0,0 +1,101 @@
+import argparse
+import time
+
+import torch
+
+from diffusers import FluxPipeline
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--base", type=str, default="black-forest-labs/FLUX.1-schnell")
+parser.add_argument(
+    "--prompt",
+    type=str,
+    default="chinese painting style women",
+)
+parser.add_argument("--height", type=int, default=512)
+parser.add_argument("--width", type=int, default=512)
+parser.add_argument("--n_steps", type=int, default=4)
+parser.add_argument("--saved_image", type=str, required=False, default="flux-out.png")
+parser.add_argument("--seed", type=int, default=1)
+parser.add_argument("--warmup", type=int, default=1)
+parser.add_argument("--run", type=int, default=3)
+parser.add_argument(
+    "--compile", type=(lambda x: str(x).lower() in ["true", "1", "yes"]), default=True
+)
+parser.add_argument("--run-multiple-resolutions", action="store_true")
+args = parser.parse_args()
+
+
+# load the FLUX pipeline
+pipe = FluxPipeline.from_pretrained(args.base, torch_dtype=torch.bfloat16)
+# pipe = FluxPipeline.from_pretrained(args.base, torch_dtype=torch.bfloat16, local_files_only=True, revision="93424e3a1530639fefdf08d2a7a954312e5cb254")
+pipe.to("cuda")
+
+if args.compile:
+    from onediffx import compile_pipe
+
+    pipe = compile_pipe(
+        pipe,
+        backend="nexfort",
+        options={
+            "options": {
+                "cuda.fuse_timestep_embedding": False,
+                "inductor.force_triton_sdpa": True,
+            }
+        },
+    )
+
+
+# Each pipe() call below constructs its own seeded generator, so the warmup
+# and timed runs are reproducible.
+print("Warmup")
+for i in range(args.warmup):
+    image = pipe(
+        args.prompt,
+        height=args.height,
+        width=args.width,
+        output_type="pil",
+        num_inference_steps=args.n_steps,  # use a larger number if you are using [dev]
+        generator=torch.Generator("cpu").manual_seed(args.seed),
+    ).images[0]
+
+
+print("Run")
+for i in range(args.run):
+    begin = time.time()
+    image = pipe(
+        args.prompt,
+        height=args.height,
+        width=args.width,
+        output_type="pil",
+        num_inference_steps=args.n_steps,  # use a larger number if you are using [dev]
+        generator=torch.Generator("cpu").manual_seed(args.seed),
+    ).images[0]
+    end = time.time()
+    print(f"Inference time: {end - begin:.3f}s")
+
+    image.save(f"{i}th_{args.saved_image}")
+
+
+if args.run_multiple_resolutions:
+    print("Test run with multiple resolutions...")
+    sizes = [1024, 512, 768, 256]
+    for h in sizes:
+        for w in sizes:
+            print(f"Running at resolution: {h}x{w}")
+            start_time = time.time()
+            image = pipe(
+                args.prompt,
+                height=h,
+                width=w,
+                output_type="pil",
+                num_inference_steps=args.n_steps,  # use a larger number if you are using [dev]
+                generator=torch.Generator("cpu").manual_seed(args.seed),
+            ).images[0]
+            end_time = time.time()
+            print(f"Inference time: {end_time - start_time:.2f} seconds")
+            image.save(f"{h}x{w}_{args.saved_image}")
diff --git a/src/onediff/utils/import_utils.py b/src/onediff/utils/import_utils.py
index 111387966..33d72046c 100644
--- a/src/onediff/utils/import_utils.py
+++ b/src/onediff/utils/import_utils.py
@@ -23,7 +23,7 @@ def check_module_availability(module_name):
     return True
 
 
-_oneflow_available = check_module_availability("oneflow")
+_oneflow_available = None
 _onediff_quant_available = check_module_availability("onediff_quant")
 _nexfort_available = check_module_availability("nexfort")
@@ -33,6 +33,9 @@ def check_module_availability(module_name):
 
 
 def is_oneflow_available():
+    global _oneflow_available
+    if _oneflow_available is None:
+        _oneflow_available = check_module_availability("oneflow")
     return _oneflow_available
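The `import_utils.py` change makes the oneflow availability check lazy: importing `onediff.utils` no longer probes for oneflow at import time; the probe runs once, on the first `is_oneflow_available()` call, and the result is cached. A self-contained sketch of the same memoized pattern (`importlib.util.find_spec` stands in for `check_module_availability` here):

```python
import importlib.util

_oneflow_available = None  # None means "not checked yet"


def is_oneflow_available() -> bool:
    global _oneflow_available
    if _oneflow_available is None:
        # Probe only on first use, then cache the result.
        _oneflow_available = importlib.util.find_spec("oneflow") is not None
    return _oneflow_available


print(is_oneflow_available())  # triggers the probe
print(is_oneflow_available())  # returns the cached value
```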