Skip to content

Commit

Permalink
diffuser model load using model and path params (#264)
Browse files Browse the repository at this point in the history
Co-authored-by: grajguru <[email protected]>
  • Loading branch information
gauravrajguru and grajguru authored Nov 8, 2023
1 parent 1286f45 commit 7d37f27
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
10 changes: 8 additions & 2 deletions mii/legacy/models/providers/diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import os
import torch

from .utils import attempt_load
from mii.config import ModelConfig

def diffusers_provider(model_config):

def diffusers_provider(model_config: ModelConfig):
from diffusers import DiffusionPipeline

local_rank = int(os.getenv("LOCAL_RANK", "0"))
Expand All @@ -16,7 +19,10 @@ def diffusers_provider(model_config):
kwargs["torch_dtype"] = torch.float16
kwargs["revision"] = "fp16"

pipeline = DiffusionPipeline.from_pretrained(model_config.model_name, **kwargs)
pipeline = attempt_load(DiffusionPipeline.from_pretrained,
model_config.model,
model_config.model_path,
kwargs=kwargs)
pipeline = pipeline.to(f"cuda:{local_rank}")
pipeline.set_progress_bar_config(disable=True)
return pipeline
24 changes: 24 additions & 0 deletions mii/legacy/models/providers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from mii.utils import is_aml, mii_cache_path


def attempt_load(load_fn, model_name, model_path, cache_path=None, kwargs={}):
try:
value = load_fn(model_name, **kwargs)
except Exception as ex:
if is_aml():
print(
f"Attempted load but failed - {str(ex)}, retrying using model_path={model_path}"
)
value = load_fn(model_path, **kwargs)
else:
cache_path = cache_path or mii_cache_path()
print(
f"Attempted load but failed - {str(ex)}, retrying using cache_dir={cache_path}"
)
value = load_fn(model_name, cache_dir=cache_path, **kwargs)
return value

0 comments on commit 7d37f27

Please sign in to comment.