forked from microsoft/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add stable diffusion unit test (microsoft#2496)
Add a unit test for Stable Diffusion that does not use the stable-diffusion model, which requires an HF (Hugging Face) token. The Midjourney model does not need an HF token and has the same pipeline structure; we verified this by printing both models. The structure is as follows: ``` StableDiffusionPipeline { "_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.7.2", "feature_extractor": [ "transformers", "CLIPFeatureExtractor" ], "safety_checker": [ "stable_diffusion", "StableDiffusionSafetyChecker" ], "scheduler": [ "diffusers", "PNDMScheduler" ], "text_encoder": [ "transformers", "CLIPTextModel" ], "tokenizer": [ "transformers", "CLIPTokenizer" ], "unet": [ "diffusers", "UNet2DConditionModel" ], "vae": [ "diffusers", "AutoencoderKL" ] } ``` @cmikeh2 --------- Co-authored-by: Olatunji Ruwase <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Logan Adams <[email protected]> Co-authored-by: Lev Kurilenko <[email protected]> Co-authored-by: Lev Kurilenko <[email protected]>
- Loading branch information
1 parent
eed6f98
commit dfb5126
Showing
4 changed files
with
121 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Stable Diffusion inference tests, run inside an NVIDIA PyTorch container on a
# self-hosted A6000 runner.
name: nv-sd

on:
  # Weekly run: 00:00 UTC every Sunday.
  schedule:
    - cron: "0 0 * * 0"
  # Allow manual runs from the Actions tab.
  workflow_dispatch:
  # Only run on pull requests that touch the diffusers kernels/models or the test itself.
  pull_request:
    paths:
      - "deepspeed/ops/transformer/inference/diffusers_**"
      - "tests/unit/inference/test_stable_diffusion.py"
      - "deepspeed/model_implementations/diffusers/unet.py"
      - "deepspeed/model_implementations/diffusers/vae.py"

# A newer run for the same workflow and ref cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read
  # Needed by the final step, which opens/updates an issue when the scheduled run fails.
  issues: write

jobs:
  sd-tests:
    runs-on: [self-hosted, nvidia, a6000]
    container:
      image: nvcr.io/nvidia/pytorch:23.03-py3
      ports:
        - 80
      options: --gpus all --shm-size "8G"

    steps:
      - uses: actions/checkout@v3

      # Print toolchain / GPU / torch details up front to make failures easier to triage.
      - name: Check container state
        run: |
          ldd --version
          nvcc --version
          nvidia-smi
          python -c "import torch; print('torch:', torch.__version__, torch)"
          python -c "import torch; print('CUDA available:', torch.cuda.is_available())"

      - name: Install transformers
        run: |
          # Only the tip of the default branch is installed, so a shallow clone is enough.
          git clone --depth 1 https://github.com/huggingface/transformers
          cd transformers
          git rev-parse --short HEAD
          python -m pip install .

      - name: Install deepspeed
        run: |
          # Use "python -m pip" throughout so every install targets the same interpreter.
          python -m pip install image-similarity-measures
          python -m pip install opencv-python==4.6.* --force-reinstall
          python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
          # Quoted so the shell never treats the extras list as a glob pattern.
          python -m pip install ".[dev,1bit,autotuning,sd]"
          ds_report

      - name: Python environment
        run: |
          python -m pip list

      - name: Unit tests
        run: |
          unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch
          cd tests
          python -m pytest --color=yes --durations=0 --verbose -rF -m 'stable_diffusion' -k "TestStableDiffusion" unit/ --torch_ver="2.0" --cuda_ver="12"

      # Scheduled runs have no PR to report on, so surface failures as an issue instead.
      - name: Open GitHub issue if weekly CI fails
        if: ${{ failure() && (github.event_name == 'schedule') }}
        uses: JasonEtco/create-an-issue@v2
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        with:
          filename: .github/ISSUE_TEMPLATE/ci_failure_report.md
          update_existing: true
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,2 @@ | ||
diffusers | ||
triton | ||
triton>=2.1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
# DeepSpeed Team | ||
|
||
import os | ||
import torch | ||
import pytest | ||
import deepspeed | ||
import numpy | ||
from unit.common import DistributedTest | ||
from deepspeed.accelerator import get_accelerator | ||
|
||
|
||
# Setup for these models is different from other pipelines, so we add a separate test | ||
@pytest.mark.stable_diffusion
class TestStableDiffusion(DistributedTest):
    """Check that DeepSpeed kernel injection preserves a diffusion pipeline's output.

    The same prompt is rendered twice with the same seed: once with the stock
    ``diffusers`` pipeline and once after wrapping that pipeline with
    ``deepspeed.init_inference``. The two images must match within a small RMSE.
    """

    world_size = 1

    def test(self):
        """Render a baseline and a DeepSpeed image and assert they are near-identical."""
        # Imported inside the test rather than at module level; presumably so that
        # collecting this module does not require these optional packages.
        from diffusers import DiffusionPipeline
        from image_similarity_measures.quality_metrics import rmse

        prompt = "a dog on a rocket"
        model = "prompthero/midjourney-v4-diffusion"
        seed = 0xABEDABE7

        # Derive ONE device from the accelerator abstraction and use it for both the
        # random generator and the pipeline, instead of hard-coding "cuda" for one
        # and asking the accelerator for the other.
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        device = torch.device(get_accelerator().device_name(local_rank))
        generator = torch.Generator(device=device)

        # Baseline: unmodified half-precision pipeline.
        generator.manual_seed(seed)
        pipe = DiffusionPipeline.from_pretrained(model, torch_dtype=torch.half)
        pipe = pipe.to(device)
        baseline_image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0]

        # Same pipeline with DeepSpeed inference kernels injected.
        pipe = deepspeed.init_inference(
            pipe,
            mp_size=1,
            dtype=torch.half,
            replace_with_kernel_inject=True,
            enable_cuda_graph=True,
        )
        # Re-seed so the second run starts from the same generator state as the first.
        generator.manual_seed(seed)
        deepspeed_image = pipe(prompt, guidance_scale=7.5, generator=generator).images[0]

        rmse_value = rmse(org_img=numpy.asarray(baseline_image), pred_img=numpy.asarray(deepspeed_image))

        # RMSE threshold value is arbitrary, may need to adjust as needed
        assert rmse_value <= 0.01