
Merge pull request #1300 from bghira/chore/lycoris-defaults-update
update lycoris defaults, fix regularised training resume
bghira authored Jan 25, 2025
2 parents 4220e58 + cf0c689 commit d67aabe
Showing 3 changed files with 376 additions and 15 deletions.

config/lycoris_config.json.example (2 changes: 1 addition & 1 deletion)

@@ -1,7 +1,7 @@
 {
     "algo": "lokr",
     "multiplier": 1.0,
-    "linear_dim": 10000,
+    "full_matrix": true,
     "linear_alpha": 1,
     "factor": 16,
     "apply_preset": {
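
Here the example LoKr preset drops linear_dim (10000 was effectively a "use the full dimension" sentinel) in favour of the explicit full_matrix flag, which trains the Kronecker factor matrices directly rather than through an inner low-rank decomposition. As a rough sketch of how such a preset reaches LyCORIS, assuming the lycoris-lora package's create_lycoris() factory and a stand-in torch module:

import torch
from lycoris import create_lycoris  # pip install lycoris-lora

# Stand-in module; in SimpleTuner this would be the model being adapted.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))

# Keyword values mirror config/lycoris_config.json.example after this change.
wrapper = create_lycoris(
    model,
    multiplier=1.0,    # adapter output applied at full strength
    linear_alpha=1,
    algo="lokr",
    factor=16,         # Kronecker factorization factor
    full_matrix=True,  # train factor matrices directly; no inner LoRA rank
)
wrapper.apply_to()     # install the adapter hooks into the module
print(sum(p.numel() for p in wrapper.parameters()), "trainable adapter params")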

helpers/training/trainer.py (33 changes: 22 additions & 11 deletions)

@@ -1146,7 +1146,8 @@ def _recalculate_training_steps(self):
                 "You must specify either --max_train_steps or --num_train_epochs with a value > 0"
             )
         self.config.num_train_epochs = math.ceil(
-            self.config.max_train_steps / max(self.config.num_update_steps_per_epoch, 1)
+            self.config.max_train_steps
+            / max(self.config.num_update_steps_per_epoch, 1)
         )
         logger.info(
             f"Calculated our maximum training steps at {self.config.max_train_steps} because we have"

@@ -1616,7 +1617,10 @@ def init_resume_checkpoint(self, lr_scheduler):
             * self.accelerator.num_processes
         )

-        if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
+        if (
+            self.state["current_epoch"] > self.config.num_train_epochs + 1
+            and not self.config.ignore_final_epochs
+        ):
             logger.info(
                 f"Reached the end ({self.state['current_epoch']} epochs) of our training run ({self.config.num_train_epochs} epochs). This run will do zero steps."
             )

@@ -2307,7 +2311,10 @@ def train(self):
         if self.config.ignore_final_epochs:
             num_epochs_to_track += 1000000
         for epoch in range(self.state["first_epoch"], num_epochs_to_track):
-            if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
+            if (
+                self.state["current_epoch"] > self.config.num_train_epochs + 1
+                and not self.config.ignore_final_epochs
+            ):
                 # This might immediately end training, but that's useful for simply exporting the model.
                 logger.info(
                     f"Training run is complete ({self.config.num_train_epochs}/{self.config.num_train_epochs} epochs, {self.state['global_step']}/{self.config.max_train_steps} steps)."

@@ -2633,7 +2640,9 @@ def train(self):
                             training_logger.debug(
                                 "Detaching LyCORIS adapter for parent prediction."
                             )
-                            self.accelerator._lycoris_wrapped_network.restore()
+                            self.accelerator._lycoris_wrapped_network.set_multiplier(
+                                0.0
+                            )
                         else:
                             raise ValueError(
                                 f"Cannot train parent-student networks on {self.config.lora_type} model. Only LyCORIS is supported."

@@ -2651,7 +2660,9 @@ def train(self):
                         training_logger.debug(
                             "Attaching LyCORIS adapter for student prediction."
                         )
-                        self.accelerator._lycoris_wrapped_network.apply_to()
+                        self.accelerator._lycoris_wrapped_network.set_multiplier(
+                            1.0
+                        )

                     training_logger.debug("Predicting noise residual.")
                     model_pred = self.model_predict(
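
These two hunks are the "fix regularised training resume" from the commit message: the LyCORIS adapter now stays attached for the whole run, and only its output multiplier is toggled between the parent (base model) and student predictions, where restore()/apply_to() previously tore the hooks down and re-installed them each step. A minimal sketch of the resulting pattern, with stand-in names for the wrapped network and the prediction call:

# Minimal sketch of the toggle above; `lycoris_net` stands in for
# self.accelerator._lycoris_wrapped_network and `predict` for model_predict.
def predict_parent_and_student(lycoris_net, predict, batch):
    # Zero multiplier: the adapter's output is scaled to nothing, so the
    # forward pass behaves like the frozen parent model. The hooks installed
    # by apply_to() stay in place, keeping checkpoint resume consistent.
    lycoris_net.set_multiplier(0.0)
    parent_pred = predict(batch)

    # Full multiplier: adapter active again for the student prediction.
    lycoris_net.set_multiplier(1.0)
    student_pred = predict(batch)
    return parent_pred, student_pred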

@@ -3077,18 +3088,18 @@ def train(self):
                     )
                     self.accelerator.wait_for_everyone()

-                if (
-                    self.state["global_step"] >= self.config.max_train_steps
-                    or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
+                if self.state["global_step"] >= self.config.max_train_steps or (
+                    epoch > self.config.num_train_epochs
+                    and not self.config.ignore_final_epochs
                 ):
                     logger.info(
                         f"Training has completed."
                         f"\n -> global_step = {self.state['global_step']}, max_train_steps = {self.config.max_train_steps}, epoch = {epoch}, num_train_epochs = {self.config.num_train_epochs}",
                     )
                     break
-            if (
-                self.state["global_step"] >= self.config.max_train_steps
-                or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
+            if self.state["global_step"] >= self.config.max_train_steps or (
+                epoch > self.config.num_train_epochs
+                and not self.config.ignore_final_epochs
             ):
                 logger.info(
                     f"Exiting training loop. Beginning model unwind at epoch {epoch}, step {self.state['global_step']}"
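
The exit condition appears twice on purpose: the first break can only leave the inner step loop, so the epoch loop re-tests the same condition to exit and begin the model unwind. A self-contained toy with placeholder values:

# Placeholder reconstruction of the double-check exit used above.
first_epoch, num_epochs_to_track, num_train_epochs = 1, 20, 10
max_train_steps, ignore_final_epochs = 35, False
global_step = 0

for epoch in range(first_epoch, num_epochs_to_track):
    for step in range(4):  # stand-in for iterating the train dataloader
        global_step += 1
        if global_step >= max_train_steps or (
            epoch > num_train_epochs and not ignore_final_epochs
        ):
            break  # exits only the inner step loop
    if global_step >= max_train_steps or (
        epoch > num_train_epochs and not ignore_final_epochs
    ):
        break  # re-test, since the break above cannot exit both loops
print(epoch, global_step)  # 9 35: stopped mid-epoch at max_train_steps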

(The diff for the third changed file did not load on this page; per the totals above it accounts for the remaining 353 additions and 3 deletions.)
