diff --git a/official/vision/image_classification/resnet/resnet_runnable.py b/official/vision/image_classification/resnet/resnet_runnable.py index e3b49200fba..5b898593e11 100644 --- a/official/vision/image_classification/resnet/resnet_runnable.py +++ b/official/vision/image_classification/resnet/resnet_runnable.py @@ -107,9 +107,12 @@ def __init__(self, flags_obj, time_callback, epoch_steps): .datasets_num_private_threads, dtype=self.dtype, drop_remainder=True) - orbit.StandardTrainer.__init__(self, train_dataset, - flags_obj.use_tf_while_loop, - flags_obj.use_tf_function) + orbit.StandardTrainer.__init__( + self, + train_dataset, + options=orbit.StandardTrainerOptions( + use_tf_while_loop=flags_obj.use_tf_while_loop, + use_tf_function=flags_obj.use_tf_function)) if not flags_obj.skip_eval: eval_dataset = orbit.utils.make_distributed_dataset( self.strategy, @@ -119,8 +122,11 @@ def __init__(self, flags_obj, time_callback, epoch_steps): batch_size=self.batch_size, parse_record_fn=imagenet_preprocessing.parse_record, dtype=self.dtype) - orbit.StandardEvaluator.__init__(self, eval_dataset, - flags_obj.use_tf_function) + orbit.StandardEvaluator.__init__( + self, + eval_dataset, + options=orbit.StandardEvaluatorOptions( + use_tf_function=flags_obj.use_tf_function)) def train_loop_begin(self): """See base class.""" diff --git a/orbit/controller_test.py b/orbit/controller_test.py index 6751e902557..8472de4fba5 100644 --- a/orbit/controller_test.py +++ b/orbit/controller_test.py @@ -221,7 +221,10 @@ def __init__(self): self.strategy.experimental_distribute_datasets_from_function(dataset_fn) ) standard_runner.StandardTrainer.__init__( - self, train_dataset, use_tpu_summary_optimization=True) + self, + train_dataset, + options=standard_runner.StandardTrainerOptions( + use_tpu_summary_optimization=True)) def build_train_dataset(self): return self.strategy.experimental_distribute_datasets_from_function( diff --git a/orbit/standard_runner.py b/orbit/standard_runner.py index 1d37f2cc90a..8d3099b388c 100644 --- a/orbit/standard_runner.py +++ b/orbit/standard_runner.py @@ -23,20 +23,22 @@ @dataclasses.dataclass(frozen=True) -class TrainerOverrides: - """Advanced overrides for Orbit trainers. +class StandardTrainerOptions: + """Advanced options for `orbit.StandardTrainer`. Attributes: - use_tf_while_loop: A boolean indicates whether to wrap the train step with - a `tf.while_loop`. - use_tf_function: A boolean indicates whether a `tf.function` will be used. - If False, training will run on pure eager mode. - use_tpu_summary_optimization: A boolean indicates whether to enable the - performance optimization for summaries in TPUs. In TPUs, writing - summaries with outside compilation inside train step is slow. If True, - it creates two `tf.function` with two XLA programs: one with summaries - and one without, and run the program with summaries (slow one) only if - necessary. + use_tf_while_loop: A boolean indicating whether to run the training loop + using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`. + use_tf_function: A boolean indicating whether to apply `tf.function` to the + training loop. This will only affect the body of the loop (involving + `train_step`); `train_loop_begin` and `train_loop_end` will always be run + in eager mode. + use_tpu_summary_optimization: A boolean indicating whether to enable a + performance optimization for summaries in TPUs. Writing summaries + conditionally with outside compilation on TPUs can be extremely slow. If + `True`, this optimization creates two `tf.function`s with two XLA programs + (one with summary calls, and one without). The program with summaries runs + only for one step when summaries should be recorded. """ use_tf_while_loop: bool = True use_tf_function: bool = True @@ -46,39 +48,29 @@ class TrainerOverrides: class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta): """Implements the standard functionality of AbstractTrainer APIs.""" - def __init__(self, - train_dataset, - use_tf_while_loop=True, - use_tf_function=True, - use_tpu_summary_optimization=False): + def __init__(self, train_dataset, options: StandardTrainerOptions = None): """Construct a `StandardTrainer` object. Args: train_dataset: A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset. - use_tf_while_loop: A boolean indicates whether to wrap the train step with - a `tf.while_loop`. - use_tf_function: A boolean indicates whether a `tf.function` will be used. - If False, training will run on pure eager mode. - use_tpu_summary_optimization: A boolean indicates whether to enable the - performance optimization for summaries in TPUs. In TPUs, writing - summaries with outside compilation inside train step is slow. If True, - it creates two `tf.function` with two XLA programs: one with summaries - and one without, and run the program with summaries (slow one) only if - necessary. + options: An `orbit.StandardTrainerOptions` instance. """ - if use_tf_while_loop and not use_tf_function: + options = options or StandardTrainerOptions() + if options.use_tf_while_loop and not options.use_tf_function: raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` " "is not supported") - if use_tpu_summary_optimization and not use_tf_while_loop: + if options.use_tpu_summary_optimization and not options.use_tf_while_loop: raise ValueError("`use_tpu_summary_optimization=True` and " "`use_tf_while_loop=False` is not supported") - self._use_tf_while_loop = use_tf_while_loop - self._use_tf_function = use_tf_function + + self._use_tf_while_loop = options.use_tf_while_loop + self._use_tf_function = options.use_tf_function + self._use_tpu_summary_optimization = options.use_tpu_summary_optimization + self._train_dataset = train_dataset self._train_iter = None self._train_loop_fn = None - self._use_tpu_summary_optimization = use_tpu_summary_optimization def train(self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]: @@ -168,12 +160,14 @@ def train_dataset(self, train_dataset): @dataclasses.dataclass(frozen=True) -class EvaluatorOverrides: - """Advanced overrides for Orbit evaluators. +class StandardEvaluatorOptions: + """Advanced options for the `orbit.StandardEvaluator`. Attributes: - use_tf_function: A boolean indicates whether a `tf.function` will be used. - If False, training will run on pure eager mode. + use_tf_function: A boolean indicating whether to apply `tf.function` to the + training loop. This will only affect the body of the loop (involving + `train_step`); `train_loop_begin` and `train_loop_end` will always be run + in eager mode. """ use_tf_function: bool = True @@ -181,16 +175,16 @@ class EvaluatorOverrides: class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta): """Implements the standard functionality of AbstractEvaluator APIs.""" - def __init__(self, eval_dataset, use_tf_function=True): + def __init__(self, eval_dataset, options: StandardEvaluatorOptions = None): """Construct a `StandardEvaluator` object. Args: eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or DistributedDataset. - use_tf_function: A boolean indicates whether a `tf.function` will be used. - If False, evaluation will run on pure eager mode. + options: An `orbit.StandardEvaluatorOptions` instance. """ - self._eval_use_tf_function = use_tf_function + options = options or StandardEvaluatorOptions() + self._eval_use_tf_function = options.use_tf_function self._eval_dataset = eval_dataset self._eval_loop_fn = None diff --git a/orbit/standard_runner_test.py b/orbit/standard_runner_test.py index fb98a715df4..7adf65bc76b 100644 --- a/orbit/standard_runner_test.py +++ b/orbit/standard_runner_test.py @@ -15,6 +15,7 @@ """Tests for orbit.standard_runner.""" from orbit import standard_runner +from orbit import utils import tensorflow as tf @@ -32,46 +33,49 @@ def dummy_data(_): return dataset -class TestRunner(standard_runner.StandardTrainer, - standard_runner.StandardEvaluator): - """Implements the training and evaluation APIs for tests.""" +class TestTrainer(standard_runner.StandardTrainer): + """A StandardTrainer subclass for tests.""" - def __init__(self): + def __init__(self, options=None): self.strategy = tf.distribute.get_strategy() - self.global_step = tf.Variable( - 0, - trainable=False, - dtype=tf.int64, - name='global_step', - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) - standard_runner.StandardTrainer.__init__(self, train_dataset=None) - standard_runner.StandardEvaluator.__init__(self, eval_dataset=None) + self.global_step = utils.create_global_step() + distribute = self.strategy.experimental_distribute_datasets_from_function + dataset = distribute(dataset_fn) + super().__init__(train_dataset=dataset, options=options) def train_loop_begin(self): - self.train_dataset = ( - self.strategy.experimental_distribute_datasets_from_function(dataset_fn) - ) + self.global_step.assign(0) def train_step(self, iterator): - def _replicated_step(_): + def replica_step(_): self.global_step.assign_add(1) - self.strategy.run(_replicated_step, args=(next(iterator),)) + self.strategy.run(replica_step, args=(next(iterator),)) def train_loop_end(self): return self.global_step.numpy() + +class TestEvaluator(standard_runner.StandardEvaluator): + """A StandardEvaluator subclass for tests.""" + + def __init__(self, options=None): + self.strategy = tf.distribute.get_strategy() + self.global_step = utils.create_global_step() + distribute = self.strategy.experimental_distribute_datasets_from_function + dataset = distribute(dataset_fn) + super().__init__(eval_dataset=dataset, options=options) + def eval_begin(self): - self.eval_dataset = self.strategy.experimental_distribute_datasets_from_function( - dataset_fn) + self.global_step.assign(0) def eval_step(self, iterator): - def _replicated_step(_): + def replica_step(_): self.global_step.assign_add(1) - self.strategy.run(_replicated_step, args=(next(iterator),)) + self.strategy.run(replica_step, args=(next(iterator),)) def eval_end(self): return self.global_step.numpy() @@ -79,15 +83,19 @@ def eval_end(self): class StandardRunnerTest(tf.test.TestCase): - def test_train(self): - test_runner = TestRunner() - self.assertEqual( - test_runner.train(tf.convert_to_tensor(10, dtype=tf.int32)), 10) + def test_default_trainer(self): + trainer = TestTrainer() + self.assertEqual(trainer.train(tf.constant(10)), 10) + + def test_trainer_with_tpu_summary_optimization(self): + options = standard_runner.StandardTrainerOptions( + use_tpu_summary_optimization=True) + trainer = TestTrainer(options) + self.assertEqual(trainer.train(tf.constant(10)), 10) - def test_eval(self): - test_runner = TestRunner() - self.assertEqual( - test_runner.evaluate(tf.convert_to_tensor(10, dtype=tf.int32)), 10) + def test_default_evaluator(self): + evaluator = TestEvaluator() + self.assertEqual(evaluator.evaluate(tf.constant(10)), 10) if __name__ == '__main__':