diff --git a/onediff_diffusers_extensions/examples/sd3/README.md b/onediff_diffusers_extensions/examples/sd3/README.md
index 70f227234..dd91de9cf 100644
--- a/onediff_diffusers_extensions/examples/sd3/README.md
+++ b/onediff_diffusers_extensions/examples/sd3/README.md
@@ -49,37 +49,54 @@ python3 onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py \
## Performance comparation
Testing on H800-NVL-80GB, with image size of 1024*1024, iterating 28 steps:
-| Metric | |
-| ------------------------------------------------ | ----------------------------------- |
-| Data update date(yyyy-mm-dd) | 2024-06-29 |
-| PyTorch iteration speed | 15.56 it/s |
-| OneDiff iteration speed | 24.12 it/s (+55.0%) |
-| PyTorch E2E time | 1.96 s |
-| OneDiff E2E time | 1.31 s (-33.2%) |
-| PyTorch Max Mem Used | 18.784 GiB |
-| OneDiff Max Mem Used | 18.324 GiB |
-| PyTorch Warmup with Run time | 2.86 s |
-| OneDiff Warmup with Compilation time1 | 889.25 s |
-| OneDiff Warmup with Cache time | 44.38 s |
+| Metric | |
+| ------------------------------------------------ | ------------------- |
+| Data update date(yyyy-mm-dd) | 2024-06-29 |
+| PyTorch iteration speed | 15.56 it/s |
+| OneDiff iteration speed | 24.12 it/s (+55.0%) |
+| PyTorch E2E time | 1.96 s |
+| OneDiff E2E time | 1.31 s (-33.2%) |
+| PyTorch Max Mem Used | 18.784 GiB |
+| OneDiff Max Mem Used | 18.324 GiB |
+| PyTorch Warmup with Run time | 2.86 s |
+| OneDiff Warmup with Compilation time1 | 889.25 s |
+| OneDiff Warmup with Cache time | 44.38 s |
1 OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Platinum 8468. Note this is just for reference, and it varies a lot on different CPU.
-Testing on 4090:
-| Metric | |
-| ------------------------------------------------ | ----------------------------------- |
-| Data update date(yyyy-mm-dd) | 2024-06-29 |
-| PyTorch iteration speed | 6.67 it/s |
-| OneDiff iteration speed | 11.51 it/s (+72.6%) |
-| PyTorch E2E time | 4.90 s |
-| OneDiff E2E time | 2.67 s (-45.5%) |
-| PyTorch Max Mem Used | 18.799 GiB |
-| OneDiff Max Mem Used | 17.902 GiB |
-| PyTorch Warmup with Run time | 4.99 s |
-| OneDiff Warmup with Compilation time2 | 302.79 s |
-| OneDiff Warmup with Cache time | 51.96 s |
-
- 2 AMD EPYC 7543 32-Core Processor
+Testing on RTX 4090:
+| Metric | |
+| ------------------------------------------------ | ------------------- |
+| Data update date(yyyy-mm-dd) | 2024-06-29 |
+| PyTorch iteration speed | 6.67 it/s |
+| OneDiff iteration speed | 11.51 it/s (+72.6%) |
+| PyTorch E2E time | 4.90 s |
+| OneDiff E2E time | 2.67 s (-45.5%) |
+| PyTorch Max Mem Used | 18.799 GiB |
+| OneDiff Max Mem Used | 17.902 GiB |
+| PyTorch Warmup with Run time | 4.99 s |
+| OneDiff Warmup with Compilation time2 | 302.79 s |
+| OneDiff Warmup with Cache time | 51.96 s |
+
+ 2 OneDiff Warmup with Compilation time is tested on AMD EPYC 7543 32-Core Processor
+
+Testing on A100 (NVIDIA A100-PCIE-40GB):
+| Metric | |
+| ------------------------------------------------ | ------------------ |
+| Data update date(yyyy-mm-dd) | 2024-07-04 |
+| PyTorch iteration speed | 6.42 it/s |
+| OneDiff iteration speed | 8.98 it/s (+39.8%) |
+| PyTorch E2E time | 4.69 s |
+| OneDiff E2E time                                 | 3.33 s (-29.0%)    |
+| PyTorch Max Mem Used | 18.765 GiB |
+| OneDiff Max Mem Used | 17.89 GiB |
+| PyTorch Warmup with Run time | 5.73 s |
+| OneDiff Warmup with Compilation time3 | 601.98 s |
+| OneDiff Warmup with Cache time | 54 s |
+
+
+ 3 OneDiff Warmup with Compilation time is tested on Intel(R) Xeon(R) Gold 6348 CPU @ 2.60GHz.
## Dynamic shape for SD3.
diff --git a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
index 4809c9f07..ee8a00c71 100644
--- a/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
+++ b/onediff_diffusers_extensions/examples/sd3/text_to_image_sd3.py
@@ -182,18 +182,22 @@ def main():
"negative_prompt": args.negative_prompt,
}
- sd3.warmup(gen_args)
-
- for prompt in prompt_list:
- gen_args["prompt"] = prompt
- print(f"Processing prompt of length {len(prompt)} characters.")
- image, inference_time = sd3.generate(gen_args)
- assert inference_time < 20, "Prompt inference took too long"
- print(
- f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds."
- )
- cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
- print(f"Max used CUDA memory : {cuda_mem_after_used:.3f}GiB")
+ with torch.profiler.profile() as prof:
+ with torch.profiler.record_function("warmup compile"):
+ sd3.warmup(gen_args)
+
+ with torch.profiler.record_function("sd3 compiled"):
+ for prompt in prompt_list:
+ gen_args["prompt"] = prompt
+ print(f"Processing prompt of length {len(prompt)} characters.")
+ image, inference_time = sd3.generate(gen_args)
+ assert inference_time < 20, "Prompt inference took too long"
+ print(
+ f"Generated image saved to {args.saved_image} in {inference_time:.2f} seconds."
+ )
+ cuda_mem_after_used = torch.cuda.max_memory_allocated() / (1024**3)
+ print(f"Max used CUDA memory : {cuda_mem_after_used:.3f} GiB")
+ prof.export_chrome_trace("sd3_compile_cache.json")
if args.run_multiple_resolutions:
gen_args["prompt"] = args.prompt
diff --git a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py
index 192820c69..665e87dd3 100644
--- a/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py
+++ b/onediff_diffusers_extensions/onediffx/compilers/diffusion_pipeline_compiler.py
@@ -38,6 +38,9 @@ def _recursive_setattr(obj, attr, value):
"vae.decoder",
"vae.encoder",
]
+_PARTS = [
+ "transformer", # for Transformer-based DiffusionPipeline such as DiTPipeline and PixArtAlphaPipeline
+]
def _filter_parts(ignores=()):