Add parameter for the total_train_steps (separate from train_steps) for training jobs that are split across multiple launches.

PiperOrigin-RevId: 347615010
katelee168 authored and Mesh TensorFlow Team committed Dec 15, 2020
1 parent 57f038d commit d914606
Showing 1 changed file with 11 additions and 3 deletions.
mesh_tensorflow/transformer/utils.py: 14 changes (11 additions & 3 deletions)
@@ -2261,6 +2261,7 @@ def run(tpu_job_name,
         eval_summary_dir=None,
         batch_size=("tokens_per_replica", 2048),
         train_steps=auto_train_steps,
+        total_run_steps=None,
         sequence_length=None,
         mesh_shape=gin.REQUIRED,
         mesh_devices=None,
@@ -2315,7 +2316,10 @@ def run(tpu_job_name,
       compute_batch_size(). Note that this is the global batch size and not the
       per-shard batch size.
     train_steps: An integer or a function with the same signature as
-      auto_train_steps(). Total number of training steps.
+      auto_train_steps(). Total number of training steps in this run.
+    total_run_steps: An integer, used when training is split over multiple
+      runs. This value is gin-configurable and used to set the total_run_steps
+      for the learning_rate_schedule.
     sequence_length: an integer or a dict from feature-key to integer
       the (packed) sequence length, e.g. {"inputs": 512, "targets": 128}.
       May also be set to `None` in eval mode to automatically compute the
@@ -2353,20 +2357,24 @@ def run(tpu_job_name,
   if not isinstance(train_steps, int):
     train_steps = train_steps(batch_size, sequence_length)

+  if total_run_steps is None:
+    total_run_steps = train_steps
   if isinstance(learning_rate_schedule, list):
     learning_rate_schedule = functools.partial(
         learning_rate_schedules.product_learning_rate,
-        factors=learning_rate_schedule)
+        total_train_steps=total_run_steps, factors=learning_rate_schedule)

   if callable(learning_rate_schedule):
     learning_rate_schedule = functools.partial(
-        learning_rate_schedule, total_train_steps=train_steps)
+        learning_rate_schedule, total_train_steps=total_run_steps)

   tf.logging.info("model_type=%s" % model_type,)
   tf.logging.info("mode=%s" % mode,)
   tf.logging.info("sequence_length=%s" % sequence_length,)
   tf.logging.info("batch_size=%s" % batch_size,)
   tf.logging.info("train_steps=%s" % train_steps,)
+  if total_run_steps is not None:
+    tf.logging.info("total_run_steps=%s" % total_run_steps,)
   tf.logging.info("mesh_shape=%s" % mesh_shape,)
   tf.logging.info("layout_rules=%s" % layout_rules,)

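To illustrate what the change buys, here is a minimal self-contained sketch (my own, not Mesh TensorFlow code) of a schedule that decays over a fixed step budget. The linear_decay function and the step counts are hypothetical; the point is that a schedule parameterized by total_train_steps restarts its decay on every launch if it is bound to the per-launch train_steps, but stays continuous across launches when bound to total_run_steps.

# Hypothetical stand-in for a learning rate schedule that decays over
# total_train_steps, like the schedules bound via functools.partial above.
def linear_decay(step, total_train_steps, start_lr=1e-3):
  """Decays linearly from start_lr to zero over total_train_steps."""
  return start_lr * max(0.0, 1.0 - step / total_train_steps)

# Suppose training runs for 300k steps overall, split into three launches of
# 100k steps each. Before this change, the schedule was bound with
# total_train_steps=train_steps (100k), so the decay completed and restarted
# within every launch. Binding total_train_steps=total_run_steps (300k)
# keeps one continuous decay across all launches.
train_steps = 100_000       # steps executed by a single launch
total_run_steps = 300_000   # steps across all launches combined

# Per-launch horizon: the learning rate hits zero at the end of launch 1.
print(linear_decay(step=100_000, total_train_steps=train_steps))      # 0.0
# Global horizon: at the same global step, the decay is only a third done.
print(linear_decay(step=100_000, total_train_steps=total_run_steps))  # ~0.000667

Since the docstring notes that total_run_steps is gin-configurable, a split job would presumably bind both values in each launch's gin file, e.g. run.train_steps and run.total_run_steps (binding names assumed from the run() signature shown in the diff).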
