Skip to content

Commit

Permalink
Merge pull request #1285 from bghira/feature/ignore-end-epochs
Browse files Browse the repository at this point in the history
add ignore_final_epochs to work around epoch tracking oddness when changing dataloader length
  • Loading branch information
bghira authored Jan 20, 2025
2 parents 2248fd0 + c12b7d7 commit 79e2ced
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
9 changes: 9 additions & 0 deletions helpers/configuration/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,15 @@ def get_argument_parser():
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--ignore_final_epochs",
action="store_true",
default=False,
help=(
"When provided, the max epoch counter will not determine the end of the training run."
" Instead, it will end when it hits --max_train_steps."
)
)
parser.add_argument(
"--checkpointing_steps",
type=int,
Expand Down
13 changes: 8 additions & 5 deletions helpers/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,7 +1618,7 @@ def init_resume_checkpoint(self, lr_scheduler):
* self.accelerator.num_processes
)

if self.state["current_epoch"] > self.config.num_train_epochs + 1:
if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
logger.info(
f"Reached the end ({self.state['current_epoch']} epochs) of our training run ({self.config.num_train_epochs} epochs). This run will do zero steps."
)
Expand Down Expand Up @@ -2305,8 +2305,11 @@ def train(self):
current_epoch_step = None
self.bf, fetch_thread = None, None
iterator_fn = random_dataloader_iterator
for epoch in range(self.state["first_epoch"], self.config.num_train_epochs + 1):
if self.state["current_epoch"] > self.config.num_train_epochs + 1:
num_epochs_to_track = self.config.num_train_epochs + 1
if self.config.ignore_final_epochs:
num_epochs_to_track += 1000000
for epoch in range(self.state["first_epoch"], num_epochs_to_track):
if self.state["current_epoch"] > self.config.num_train_epochs + 1 and not self.config.ignore_final_epochs:
# This might immediately end training, but that's useful for simply exporting the model.
logger.info(
f"Training run is complete ({self.config.num_train_epochs}/{self.config.num_train_epochs} epochs, {self.state['global_step']}/{self.config.max_train_steps} steps)."
Expand Down Expand Up @@ -3060,7 +3063,7 @@ def train(self):

if (
self.state["global_step"] >= self.config.max_train_steps
or epoch > self.config.num_train_epochs
or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
):
logger.info(
f"Training has completed."
Expand All @@ -3069,7 +3072,7 @@ def train(self):
break
if (
self.state["global_step"] >= self.config.max_train_steps
or epoch > self.config.num_train_epochs
or (epoch > self.config.num_train_epochs and not self.config.ignore_final_epochs)
):
logger.info(
f"Exiting training loop. Beginning model unwind at epoch {epoch}, step {self.state['global_step']}"
Expand Down

0 comments on commit 79e2ced

Please sign in to comment.