merge #1296

Merged: 9 commits merged on Jan 22, 2025
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yaml
@@ -32,4 +32,4 @@ jobs:
run: poetry -C install/apple install

- name: Run Tests
run: poetry -C install/apple run python -m unittest discover tests/
run: poetry -C ./ -P install/apple run python -m unittest discover tests/
32 changes: 29 additions & 3 deletions OPTIONS.md
@@ -288,6 +288,21 @@ A lot of settings are instead set through the [dataloader config](/documentation
- **Choices**: constant, constant_with_warmup, cosine, cosine_with_restarts, **polynomial** (recommended), linear
- **Why**: Models benefit from continual learning rate adjustments to further explore the loss landscape. A cosine schedule is used as the default; this allows training to transition smoothly between two extremes. With a constant learning rate, it is common to pick a value that is too high or too low, causing divergence (too high) or getting stuck in a local minimum (too low). A polynomial schedule is best paired with a warmup, where it gradually approaches the `learning_rate` value before slowing down and approaching `--lr_end` by the end of training.

### `--optimizer`

- **What**: The optimizer to use for training.
- **Choices**: adamw_bf16, ao-adamw8bit, ao-adamw4bit, ao-adamfp8, ao-adamwfp8, adamw_schedulefree, adamw_schedulefree+aggressive, adamw_schedulefree+no_kahan, optimi-stableadamw, optimi-adamw, optimi-lion, optimi-radam, optimi-ranger, optimi-adan, optimi-adam, optimi-sgd, soap, bnb-adagrad, bnb-adagrad8bit, bnb-adam, bnb-adam8bit, bnb-adamw, bnb-adamw8bit, bnb-adamw-paged, bnb-adamw8bit-paged, bnb-lion, bnb-lion8bit, bnb-lion-paged, bnb-lion8bit-paged, bnb-ademamix, bnb-ademamix8bit, bnb-ademamix-paged, bnb-ademamix8bit-paged, prodigy

> Note: Some optimisers may not be available on non-NVIDIA hardware.

### `--optimizer_config`

- **What**: Tweak optimizer settings.
- **Why**: Because optimizers have so many different settings, it's not feasible to provide a command-line argument for each one. Instead, you can provide a comma-separated list of values to override any of the default settings.
- **Example**: You may wish to set the `d_coef` for the **prodigy** optimizer: `--optimizer_config=d_coef=0.1`

> Note: Optimizer betas are overridden using the dedicated parameters `--optimizer_beta1` and `--optimizer_beta2`.
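
For illustration, here is a minimal sketch of how a `key=value,key=value` string like the one above could be parsed into optimizer keyword arguments. This is an assumption-laden sketch, not SimpleTuner's actual parser; the helper name and the type-coercion rules are invented for the example.

```python
def parse_optimizer_config(config_str: str) -> dict:
    """Parse a comma-separated key=value string into a kwargs dict (illustrative only)."""
    overrides = {}
    for pair in filter(None, (config_str or "").split(",")):
        key, _, raw = pair.partition("=")
        key, raw = key.strip(), raw.strip()
        # Best-effort coercion: booleans first, then numbers, otherwise keep the string.
        if raw.lower() in ("true", "false"):
            value = raw.lower() == "true"
        else:
            try:
                value = float(raw) if "." in raw else int(raw)
            except ValueError:
                value = raw
        overrides[key] = value
    return overrides

# Example: --optimizer_config=d_coef=0.1,use_bias_correction=true
print(parse_optimizer_config("d_coef=0.1,use_bias_correction=true"))
# {'d_coef': 0.1, 'use_bias_correction': True}
```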

### `--train_batch_size`

- **What**: Batch size for the training data loader.
@@ -298,6 +313,8 @@ A lot of settings are instead set through the [dataloader config](/documentation
- **What**: Number of update steps to accumulate before performing a backward/update pass, essentially splitting the work over multiple batches to save memory at the cost of a higher training runtime.
- **Why**: Useful for handling larger models or datasets.

> Note: Do not enable the fused backward pass for any optimizer when using gradient accumulation steps. A fused backward pass applies parameter updates during backpropagation itself, so gradients cannot first be accumulated across multiple micro-batches.
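
For intuition, here is a minimal, generic PyTorch sketch of gradient accumulation (not SimpleTuner's trainer loop): gradients from several micro-batches are summed before a single optimizer step, which is exactly the step a fused backward pass would have taken early.

```python
import torch
from torch import nn

# Illustrative stand-ins; SimpleTuner's real model, data, and optimizer setup differ.
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()
accumulation_steps = 4

optimizer.zero_grad()
for step in range(8):
    inputs, targets = torch.randn(2, 8), torch.randn(2, 1)  # stand-in micro-batch
    loss = loss_fn(model(inputs), targets) / accumulation_steps  # scale so summed grads match one large batch
    loss.backward()  # gradients accumulate in .grad across micro-batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()       # one update per accumulation window
        optimizer.zero_grad()  # a fused backward pass would already have stepped per-parameter, breaking this pattern
```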

---

## 🛠 Advanced Optimizations
@@ -491,7 +508,7 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--ema_update_interval EMA_UPDATE_INTERVAL]
[--ema_decay EMA_DECAY] [--non_ema_revision NON_EMA_REVISION]
[--offload_param_path OFFLOAD_PARAM_PATH] --optimizer
{adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd,soap,bnb-adagrad,bnb-adagrad8bit,bnb-adam,bnb-adam8bit,bnb-adamw,bnb-adamw8bit,bnb-adamw-paged,bnb-adamw8bit-paged,bnb-lion,bnb-lion8bit,bnb-lion-paged,bnb-lion8bit-paged,bnb-ademamix,bnb-ademamix8bit,bnb-ademamix-paged,bnb-ademamix8bit-paged}
{adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd,soap,bnb-adagrad,bnb-adagrad8bit,bnb-adam,bnb-adam8bit,bnb-adamw,bnb-adamw8bit,bnb-adamw-paged,bnb-adamw8bit-paged,bnb-lion,bnb-lion8bit,bnb-lion-paged,bnb-lion8bit-paged,bnb-ademamix,bnb-ademamix8bit,bnb-ademamix-paged,bnb-ademamix8bit-paged,prodigy}
[--optimizer_config OPTIMIZER_CONFIG]
[--optimizer_cpu_offload_method {none}]
[--optimizer_offload_gradients] [--fuse_optimizer]
@@ -500,7 +517,8 @@ usage: train.py [-h] [--snr_gamma SNR_GAMMA] [--use_soft_min_snr]
[--optimizer_release_gradients] [--adam_beta1 ADAM_BETA1]
[--adam_beta2 ADAM_BETA2]
[--adam_weight_decay ADAM_WEIGHT_DECAY]
[--adam_epsilon ADAM_EPSILON] [--max_grad_norm MAX_GRAD_NORM]
[--adam_epsilon ADAM_EPSILON] [--prodigy_steps PRODIGY_STEPS]
[--max_grad_norm MAX_GRAD_NORM]
[--grad_clip_method {value,norm}] [--push_to_hub]
[--push_checkpoints_to_hub] [--hub_model_id HUB_MODEL_ID]
[--model_card_note MODEL_CARD_NOTE]
@@ -1239,7 +1257,7 @@ options:
When using DeepSpeed ZeRo stage 2 or 3 with NVMe
offload, this may be specified to provide a path for
the offload.
--optimizer {adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd,soap,bnb-adagrad,bnb-adagrad8bit,bnb-adam,bnb-adam8bit,bnb-adamw,bnb-adamw8bit,bnb-adamw-paged,bnb-adamw8bit-paged,bnb-lion,bnb-lion8bit,bnb-lion-paged,bnb-lion8bit-paged,bnb-ademamix,bnb-ademamix8bit,bnb-ademamix-paged,bnb-ademamix8bit-paged}
--optimizer {adamw_bf16,ao-adamw8bit,ao-adamw4bit,ao-adamfp8,ao-adamwfp8,adamw_schedulefree,adamw_schedulefree+aggressive,adamw_schedulefree+no_kahan,optimi-stableadamw,optimi-adamw,optimi-lion,optimi-radam,optimi-ranger,optimi-adan,optimi-adam,optimi-sgd,soap,bnb-adagrad,bnb-adagrad8bit,bnb-adam,bnb-adam8bit,bnb-adamw,bnb-adamw8bit,bnb-adamw-paged,bnb-adamw8bit-paged,bnb-lion,bnb-lion8bit,bnb-lion-paged,bnb-lion8bit-paged,bnb-ademamix,bnb-ademamix8bit,bnb-ademamix-paged,bnb-ademamix8bit-paged,prodigy}
--optimizer_config OPTIMIZER_CONFIG
When setting a given optimizer, this allows a comma-
separated list of key-value pairs to be provided that
@@ -1276,6 +1294,14 @@ options:
Weight decay to use.
--adam_epsilon ADAM_EPSILON
Epsilon value for the Adam optimizer
--prodigy_steps PRODIGY_STEPS
When training with Prodigy, this defines how many
steps it should adjust its learning rate for.
Diffusion models seem to benefit from capping the
adjustments after roughly 25 percent of the training
run (dependent on batch size, repeats, and epochs).
If this value is not supplied, it will be calculated
as 25 percent of your training steps.
--max_grad_norm MAX_GRAD_NORM
Clipping the max gradient norm can help prevent
exploding gradients, but may also harm training by
11 changes: 11 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -1336,6 +1336,17 @@ def get_argument_parser():
default=1e-08,
help="Epsilon value for the Adam optimizer",
)
parser.add_argument(
"--prodigy_steps",
type=int,
default=None,
help=(
"When training with Prodigy, this defines how many steps it should be adjusting its learning rate for."
" It seems to be that Diffusion models benefit from a capping off of the adjustments after 25 percent"
" of the training run (dependent on batch size, repeats, and epochs)."
" It this value is not supplied, it will be calculated at 25 percent of your training steps."
),
)
parser.add_argument(
"--max_grad_norm",
default=2.0,
59 changes: 58 additions & 1 deletion helpers/training/optimizer_param.py
@@ -61,6 +61,18 @@
if "AdEMAMix" in dir(bitsandbytes.optim):
is_ademamix_available = True

is_prodigy_available = False
try:
import prodigyplus

is_prodigy_available = True
except ImportError:
if torch.cuda.is_available():
logger.warning(
"Could not load prodigyplus library. Prodigy will not be available."
)


optimizer_choices = {
"adamw_bf16": {
"precision": "bf16",
@@ -456,6 +468,42 @@
}
)

if is_prodigy_available:
optimizer_choices.update(
{
"prodigy": {
"precision": "any",
"override_lr_scheduler": True,
"can_warmup": False,
"default_settings": {
"lr": 1.0,
"betas": (0.9, 0.99),
"beta3": None,
"weight_decay": 0.0,
"weight_decay_by_lr": True,
"use_bias_correction": False,
"d0": 1e-6,
"d_coef": 1,
"prodigy_steps": 0,
"use_speed": False,
"eps": 1e-8,
"split_groups": True,
"split_groups_mean": True,
"factored": True,
"factored_fp32": True,
"fused_back_pass": False,
"use_stableadamw": True,
"use_muon_pp": False,
"use_cautious": False,
"use_grams": False,
"use_adopt": False,
"stochastic_rounding": True,
},
"class": prodigyplus.prodigy_plus_schedulefree.ProdigyPlusScheduleFree,
}
}
)

args_to_optimizer_mapping = {
"use_adafactor_optimizer": "adafactor",
"use_prodigy_optimizer": "prodigy",
@@ -465,7 +513,6 @@
}

deprecated_optimizers = {
"prodigy": "Prodigy optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"dadaptation": "D-adaptation optimiser has been removed due to issues with precision levels and convergence. Please use adamw_schedulefree instead.",
"adafactor": "Adafactor optimiser has been removed in favour of optimi-stableadamw, which offers improved memory efficiency and convergence.",
"adamw8bit": "AdamW8Bit has been removed in favour of optimi-adamw optimiser, which offers better low-precision support. Please use this or adamw_bf16 instead.",
@@ -512,6 +559,16 @@ def optimizer_parameters(optimizer, args):
if args.optimizer_release_gradients and "optimi-" in optimizer:
optimizer_params["gradient_release"] = True
optimizer_details["default_settings"] = optimizer_params
if args.optimizer == "prodigy":
prodigy_steps = args.prodigy_steps
if prodigy_steps and prodigy_steps > 0:
optimizer_params["prodigy_steps"] = prodigy_steps
else:
# 25% of the total number of steps
optimizer_params["prodigy_steps"] = int(args.max_train_steps * 0.25)
print(
f"Using Prodigy optimiser with {optimizer_params['prodigy_steps']} steps of learning rate adjustment."
)
return optimizer_class, optimizer_details
else:
raise ValueError(f"Optimizer {optimizer} not found.")
44 changes: 30 additions & 14 deletions helpers/training/trainer.py
@@ -1232,9 +1232,7 @@ def init_optimizer(self):
def init_lr_scheduler(self):
self.config.is_schedulefree = is_lr_scheduler_disabled(self.config.optimizer)
if self.config.is_schedulefree:
logger.info(
"Using experimental AdamW ScheduleFree optimiser from Facebook. Experimental due to newly added Kahan summation."
)
logger.info("Using experimental ScheduleFree optimiser..")
# we don't use LR schedulers with schedulefree optimisers
lr_scheduler = None
if not self.config.use_deepspeed_scheduler and not self.config.is_schedulefree:
@@ -2778,12 +2776,14 @@ def train(self):
if param.grad is not None:
param.grad.data = param.grad.data.to(torch.float32)

self.grad_norm = self._max_grad_value()
if (
self.accelerator.sync_gradients
and self.config.optimizer != "optimi-stableadamw"
and self.config.optimizer
not in ["optimi-stableadamw", "prodigy"]
and self.config.max_grad_norm > 0
):
# StableAdamW does not need clipping, similar to Adafactor.
# StableAdamW/Prodigy do not need clipping, similar to Adafactor.
if self.config.grad_clip_method == "norm":
self.grad_norm = self.accelerator.clip_grad_norm_(
self._get_trainable_parameters(),
@@ -2793,7 +2793,6 @@
# deepspeed can only do norm clipping (internally)
pass
elif self.config.grad_clip_method == "value":
self.grad_norm = self._max_grad_value()
self.accelerator.clip_grad_value_(
self._get_trainable_parameters(),
self.config.max_grad_norm,
@@ -2824,7 +2823,22 @@ def train(self):
wandb_logs = {}
if self.accelerator.sync_gradients:
try:
if self.config.is_schedulefree:
if "prodigy" in self.config.optimizer:
self.lr = self.optimizer.param_groups[0]["d"]
wandb_logs.update(
{
"prodigy/d": self.optimizer.param_groups[0]["d"],
"prodigy/d_prev": self.optimizer.param_groups[0][
"d_prev"
],
"prodigy/d0": self.optimizer.param_groups[0]["d0"],
"prodigy/d_coef": self.optimizer.param_groups[0][
"d_coef"
],
"prodigy/k": self.optimizer.param_groups[0]["k"],
}
)
elif self.config.is_schedulefree:
# hackjob method of retrieving LR from accelerated optims
self.lr = StateTracker.get_last_lr()
else:
@@ -2834,12 +2848,14 @@ def train(self):
logger.error(
f"Failed to get the last learning rate from the scheduler. Error: {e}"
)
wandb_logs = {
"train_loss": self.train_loss,
"optimization_loss": loss,
"learning_rate": self.lr,
"epoch": epoch,
}
wandb_logs.update(
{
"train_loss": self.train_loss,
"optimization_loss": loss,
"learning_rate": self.lr,
"epoch": epoch,
}
)
if parent_loss is not None:
wandb_logs["regularisation_loss"] = parent_loss
if self.config.model_family == "flux" and self.guidance_values_list:
@@ -2850,7 +2866,7 @@ def train(self):
if self.grad_norm is not None:
if self.config.grad_clip_method == "norm":
wandb_logs["grad_norm"] = self.grad_norm
elif self.config.grad_clip_method == "value":
else:
wandb_logs["grad_absmax"] = self.grad_norm
if self.validation is not None and hasattr(
self.validation, "evaluation_result"
18 changes: 9 additions & 9 deletions helpers/training/validation.py
@@ -918,9 +918,9 @@ def would_validate(
validation_type="intermediary",
force_evaluation: bool = False,
):
# a wrapper for should_perform_validation that can run in the training loop
# a wrapper for should_perform_intermediary_validation that can run in the training loop
self._update_state()
return self.should_perform_validation(
return self.should_perform_intermediary_validation(
step, self.validation_prompts, validation_type
) or (step == 0 and validation_type == "base_model")

Expand All @@ -932,17 +932,18 @@ def run_validations(
skip_execution: bool = False,
):
self._update_state()
should_validate = self.should_perform_validation(
would_do_intermediary_validation = self.should_perform_intermediary_validation(
step, self.validation_prompts, validation_type
) or (step == 0 and validation_type == "base_model")
logger.debug(
f"Should evaluate: {should_validate}, force evaluation: {force_evaluation}, skip execution: {skip_execution}"
f"Should evaluate: {would_do_intermediary_validation}, force evaluation: {force_evaluation}, skip execution: {skip_execution}"
)
if not should_validate and not force_evaluation:
if not would_do_intermediary_validation and not force_evaluation:
return self
if should_validate and skip_execution:
if would_do_intermediary_validation and validation_type == "final":
# If the validation would have fired off, we'll skip it.
# This is useful at the end of training so we don't validate 2x.
logger.debug("Not running validation because intermediary might have already fired off.")
return self
if StateTracker.get_webhook_handler() is not None:
StateTracker.get_webhook_handler().send(
@@ -970,15 +971,14 @@ def run_validations(

return self

def should_perform_validation(self, step, validation_prompts, validation_type):
def should_perform_intermediary_validation(self, step, validation_prompts, validation_type):
should_do_intermediary_validation = (
validation_prompts
and self.global_step % self.args.validation_steps == 0
and step % self.args.gradient_accumulation_steps == 0
and self.global_step > self.global_resume_step
)
is_final_validation = validation_type == "final"
return (is_final_validation or should_do_intermediary_validation) and (
return should_do_intermediary_validation and (
self.accelerator.is_main_process or self.deepseed
)
