Skip to content

Commit

Permalink
Merge pull request #1244 from bghira/main
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
bghira authored Dec 24, 2024
2 parents 6900937 + 933789c commit f56900d
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 58 deletions.
3 changes: 2 additions & 1 deletion helpers/caching/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,8 @@ def encode_images(self, images, filepaths, load_from_cache=True):
) * self.vae.config.scaling_factor
else:
latents_uncached = (
latents_uncached.latent * self.vae.config.scaling_factor
getattr(latents_uncached, "latent", latents_uncached)
* self.vae.config.scaling_factor
)
logger.debug(f"Latents shape: {latents_uncached.shape}")

Expand Down
61 changes: 57 additions & 4 deletions helpers/training/deepspeed.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,45 @@
import accelerate, logging, os
import accelerate, logging, os, contextlib, transformers
from accelerate.state import AcceleratorState
from transformers.integrations import HfDeepSpeedConfig

logger = logging.getLogger(__name__)
logger = logging.getLogger("DeepSpeed")
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL", "INFO"))

from transformers.integrations.deepspeed import (
is_deepspeed_zero3_enabled,
set_hf_deepspeed_config,
unset_hf_deepspeed_config,
)


@contextlib.contextmanager
def temporarily_disable_deepspeed_zero3():
# https://github.com/huggingface/transformers/issues/28106
deepspeed_plugin = (
AcceleratorState().deepspeed_plugin
if accelerate.state.is_initialized()
else None
)
if deepspeed_plugin is None:
print("DeepSpeed was not enabled.")
return []

if deepspeed_plugin and is_deepspeed_zero3_enabled():
print("DeepSpeed being disabled.")
_hf_deepspeed_config_weak_ref = (
transformers.integrations.deepspeed._hf_deepspeed_config_weak_ref
)
unset_hf_deepspeed_config()
yield
print("DeepSpeed being enabled.")
set_hf_deepspeed_config(HfDeepSpeedConfig(deepspeed_plugin.deepspeed_config))
transformers.integrations.deepspeed._hf_deepspeed_config_weak_ref = (
_hf_deepspeed_config_weak_ref
)
else:
print(f"Doing nothing, deepspeed zero3 was not enabled?")
yield


def deepspeed_zero_init_disabled_context_manager():
"""
Expand All @@ -15,9 +51,16 @@ def deepspeed_zero_init_disabled_context_manager():
else None
)
if deepspeed_plugin is None:
logger.debug("DeepSpeed context manager disabled, no DeepSpeed detected.")
return []

return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
logger.debug(
f"DeepSpeed context manager enabled, DeepSpeed detected: {deepspeed_plugin}"
)
return [
deepspeed_plugin.zero3_init_context_manager(enable=False),
temporarily_disable_deepspeed_zero3(),
]


def prepare_model_for_deepspeed(accelerator, args):
Expand All @@ -38,9 +81,19 @@ def prepare_model_for_deepspeed(accelerator, args):
if offload_param["nvme_path"] == "none":
if args.offload_param_path is None:
raise ValueError(
f"DeepSpeed is using {offload_param['device']} but nvme_path is not specified."
f"DeepSpeed is using {offload_param['device']} but nvme_path is not specified. The configuration has '{offload_param['nvme_path']}' for 'nvme_path'."
)
else:
offload_buffer = 100000000.0
if args.model_family in ["flux"]:
# flux is big
offload_buffer = 131600000.0
logger.info(
f"Attempting to allocate {offload_buffer} size byte buffer."
)
accelerator.state.deepspeed_plugin.deepspeed_config[
"zero_optimization"
]["offload_param"]["buffer_size"] = offload_buffer
accelerator.state.deepspeed_plugin.deepspeed_config[
"zero_optimization"
]["offload_param"]["nvme_path"] = args.offload_param_path
Expand Down
6 changes: 5 additions & 1 deletion helpers/training/text_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ def load_tes(
"EleutherAI/pile-t5-base",
torch_dtype=weight_dtype,
).encoder
text_encoder_1.eval()

if tokenizer_2 is not None:
if args.model_family.lower() == "flux":
Expand All @@ -287,4 +286,9 @@ def load_tes(
variant=args.variant,
)

for te in [text_encoder_1, text_encoder_2, text_encoder_3]:
if te is None:
continue
te.eval()

return text_encoder_variant, text_encoder_1, text_encoder_2, text_encoder_3
91 changes: 54 additions & 37 deletions helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,23 +470,24 @@ def init_vae(self, move_to_accelerator: bool = True):
else:
from diffusers import AutoencoderKL as AutoencoderClass

try:
self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
except:
logger.warning(
"Couldn't load VAE with default path. Trying without a subfolder.."
)
self.config.vae_kwargs["subfolder"] = None
self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
if (
self.vae is not None
and self.config.vae_enable_tiling
and hasattr(self.vae, "enable_tiling")
):
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
try:
self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
except:
logger.warning(
"Enabling VAE tiling for greatly reduced memory consumption due to --vae_enable_tiling which may result in VAE tiling artifacts in encoded latents."
"Couldn't load VAE with default path. Trying without a subfolder.."
)
self.vae.enable_tiling()
self.config.vae_kwargs["subfolder"] = None
self.vae = AutoencoderClass.from_pretrained(**self.config.vae_kwargs)
if (
self.vae is not None
and self.config.vae_enable_tiling
and hasattr(self.vae, "enable_tiling")
):
logger.warning(
"Enabling VAE tiling for greatly reduced memory consumption due to --vae_enable_tiling which may result in VAE tiling artifacts in encoded latents."
)
self.vae.enable_tiling()
if not move_to_accelerator:
logger.debug("Not moving VAE to accelerator.")
return
Expand Down Expand Up @@ -530,28 +531,28 @@ def init_text_encoder(self, move_to_accelerator: bool = True):
None,
None,
)
if self.tokenizer_1 is not None:
self.text_encoder_cls_1 = import_model_class_from_model_name_or_path(
self.config.text_encoder_path,
self.config.revision,
self.config,
subfolder=self.config.text_encoder_subfolder,
)
if self.tokenizer_2 is not None:
self.text_encoder_cls_2 = import_model_class_from_model_name_or_path(
self.config.pretrained_model_name_or_path,
self.config.revision,
self.config,
subfolder="text_encoder_2",
)
if self.tokenizer_3 is not None and self.config.model_family == "sd3":
self.text_encoder_cls_3 = import_model_class_from_model_name_or_path(
self.config.pretrained_model_name_or_path,
self.config.revision,
self.config,
subfolder="text_encoder_3",
)
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
if self.tokenizer_1 is not None:
self.text_encoder_cls_1 = import_model_class_from_model_name_or_path(
self.config.text_encoder_path,
self.config.revision,
self.config,
subfolder=self.config.text_encoder_subfolder,
)
if self.tokenizer_2 is not None:
self.text_encoder_cls_2 = import_model_class_from_model_name_or_path(
self.config.pretrained_model_name_or_path,
self.config.revision,
self.config,
subfolder="text_encoder_2",
)
if self.tokenizer_3 is not None and self.config.model_family == "sd3":
self.text_encoder_cls_3 = import_model_class_from_model_name_or_path(
self.config.pretrained_model_name_or_path,
self.config.revision,
self.config,
subfolder="text_encoder_3",
)
tokenizers = [self.tokenizer_1, self.tokenizer_2, self.tokenizer_3]
text_encoder_classes = [
self.text_encoder_cls_1,
Expand Down Expand Up @@ -669,7 +670,13 @@ def init_data_backend(self):

raise e

self.init_validation_prompts()
try:
self.init_validation_prompts()
except Exception as e:
logger.error("Could not generate validation prompts.")
logger.error(e)
raise e

# We calculate the number of steps per epoch by dividing the number of images by the effective batch divisor.
# Gradient accumulation steps mean that we only update the model weights every /n/ steps.
collected_data_backend_str = list(StateTracker.get_data_backends().keys())
Expand All @@ -695,6 +702,16 @@ def init_data_backend(self):
self.accelerator.wait_for_everyone()

def init_validation_prompts(self):
if (
hasattr(self.accelerator, "state")
and hasattr(self.accelerator.state, "deepspeed_plugin")
and getattr(self.accelerator.state.deepspeed_plugin, "deepspeed_config", {})
.get("zero_optimization", {})
.get("stage")
== 3
):
logger.error("Cannot run validations with DeepSpeed ZeRO stage 3.")
return
if self.accelerator.is_main_process:
if self.config.model_family == "flux":
(
Expand Down
24 changes: 16 additions & 8 deletions helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
from helpers.image_manipulation.brightness import calculate_luminance
from PIL import Image, ImageDraw, ImageFont
from diffusers import SanaPipeline
from helpers.training.deepspeed import (
deepspeed_zero_init_disabled_context_manager,
prepare_model_for_deepspeed,
)
from transformers.utils import ContextManagers

logger = logging.getLogger(__name__)
logger.setLevel(os.environ.get("SIMPLETUNER_LOG_LEVEL") or "INFO")
Expand Down Expand Up @@ -523,14 +528,17 @@ def init_vae(self):
self.vae = precached_vae
if self.vae is None:
logger.info(f"Initialising {AutoencoderClass}")
self.vae = AutoencoderClass.from_pretrained(
vae_path,
subfolder=(
"vae" if args.pretrained_vae_model_name_or_path is None else None
),
revision=args.revision,
force_upcast=False,
).to(self.inference_device)
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
self.vae = AutoencoderClass.from_pretrained(
vae_path,
subfolder=(
"vae"
if args.pretrained_vae_model_name_or_path is None
else None
),
revision=args.revision,
force_upcast=False,
).to(self.inference_device)
StateTracker.set_vae(self.vae)

return self.vae
Expand Down
10 changes: 5 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ wandb = "^0.19.1"
requests = "^2.32.3"
pillow = "^11.0.0"
opencv-python = "^4.10.0.84"
deepspeed = "^0.16.1"
deepspeed = "^0.16.2"
accelerate = "^1.2.1"
safetensors = "^0.4.5"
compel = "^2.0.1"
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
trainer.init_preprocessing_models()
trainer.init_precision(preprocessing_models_only=True)
trainer.init_data_backend()
trainer.init_validation_prompts()
# trainer.init_validation_prompts()
trainer.init_unload_text_encoder()
trainer.init_unload_vae()

Expand Down

0 comments on commit f56900d

Please sign in to comment.