
Commit

Internal change
PiperOrigin-RevId: 325088513
dhr authored and tensorflower-gardener committed Aug 5, 2020
1 parent 29d45e8 commit 0edeca5
Showing 4 changed files with 86 additions and 75 deletions.
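In summary, this change migrates `orbit.StandardTrainer` and `orbit.StandardEvaluator` from individual boolean constructor arguments to frozen `StandardTrainerOptions` / `StandardEvaluatorOptions` dataclasses passed via a single `options` argument. Below is a minimal, self-contained sketch of the new calling convention; the `MyTrainer` subclass and its toy dataset are illustrative assumptions rather than part of the commit, and the sketch is untested.

import orbit
import tensorflow as tf


class MyTrainer(orbit.StandardTrainer):
  """Illustrative trainer exercising the options-based constructor (not part of this commit)."""

  def __init__(self, options=None):
    self.global_step = tf.Variable(0, trainable=False, dtype=tf.int64)
    dataset = tf.data.Dataset.range(100).batch(4)  # Toy stand-in for a real input pipeline.
    super().__init__(train_dataset=dataset, options=options)

  def train_loop_begin(self):
    self.global_step.assign(0)

  def train_step(self, iterator):
    _ = next(iterator)  # Consume one batch per step.
    self.global_step.assign_add(1)

  def train_loop_end(self):
    return self.global_step.numpy()


# Spelling out the defaults; both flags can be overridden, subject to the
# constraints validated in orbit/standard_runner.py (see the diff below).
options = orbit.StandardTrainerOptions(use_tf_while_loop=True, use_tf_function=True)
trainer = MyTrainer(options=options)
print(trainer.train(tf.constant(10, dtype=tf.int32)))  # Expected to report 10 completed steps.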
16 changes: 11 additions & 5 deletions official/vision/image_classification/resnet/resnet_runnable.py
@@ -107,9 +107,12 @@ def __init__(self, flags_obj, time_callback, epoch_steps):
.datasets_num_private_threads,
dtype=self.dtype,
drop_remainder=True)
orbit.StandardTrainer.__init__(self, train_dataset,
flags_obj.use_tf_while_loop,
flags_obj.use_tf_function)
orbit.StandardTrainer.__init__(
self,
train_dataset,
options=orbit.StandardTrainerOptions(
use_tf_while_loop=flags_obj.use_tf_while_loop,
use_tf_function=flags_obj.use_tf_function))
if not flags_obj.skip_eval:
eval_dataset = orbit.utils.make_distributed_dataset(
self.strategy,
@@ -119,8 +122,11 @@ def __init__(self, flags_obj, time_callback, epoch_steps):
batch_size=self.batch_size,
parse_record_fn=imagenet_preprocessing.parse_record,
dtype=self.dtype)
orbit.StandardEvaluator.__init__(self, eval_dataset,
flags_obj.use_tf_function)
orbit.StandardEvaluator.__init__(
self,
eval_dataset,
options=orbit.StandardEvaluatorOptions(
use_tf_function=flags_obj.use_tf_function))

def train_loop_begin(self):
"""See base class."""
5 changes: 4 additions & 1 deletion orbit/controller_test.py
@@ -221,7 +221,10 @@ def __init__(self):
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
standard_runner.StandardTrainer.__init__(
self, train_dataset, use_tpu_summary_optimization=True)
self,
train_dataset,
options=standard_runner.StandardTrainerOptions(
use_tpu_summary_optimization=True))

def build_train_dataset(self):
return self.strategy.experimental_distribute_datasets_from_function(
74 changes: 34 additions & 40 deletions orbit/standard_runner.py
@@ -23,20 +23,22 @@


@dataclasses.dataclass(frozen=True)
class TrainerOverrides:
"""Advanced overrides for Orbit trainers.
class StandardTrainerOptions:
"""Advanced options for `orbit.StandardTrainer`.
Attributes:
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tf_function: A boolean indicating whether to apply `tf.function` to the
training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode.
use_tpu_summary_optimization: A boolean indicating whether to enable a
performance optimization for summaries on TPUs. Writing summaries
conditionally with outside compilation on TPUs can be extremely slow. If
`True`, this optimization creates two `tf.function`s with two XLA programs
(one with summary calls, and one without). The program with summaries runs
only for one step when summaries should be recorded.
"""
use_tf_while_loop: bool = True
use_tf_function: bool = True
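Because the options classes are declared with `@dataclasses.dataclass(frozen=True)`, an options instance is immutable once constructed; variants are derived rather than mutated. A short illustrative sketch follows (the `use_tpu_summary_optimization` default of `False` is assumed from the keyword argument it replaces):

import dataclasses
import orbit

opts = orbit.StandardTrainerOptions()  # use_tf_while_loop=True, use_tf_function=True by default.
tpu_opts = dataclasses.replace(opts, use_tpu_summary_optimization=True)  # Derive a variant.

try:
  opts.use_tf_function = False  # Frozen dataclass: in-place mutation is rejected.
except dataclasses.FrozenInstanceError:
  pass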
@@ -46,39 +48,29 @@ class TrainerOverrides:
class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractTrainer APIs."""

def __init__(self,
train_dataset,
use_tf_while_loop=True,
use_tf_function=True,
use_tpu_summary_optimization=False):
def __init__(self, train_dataset, options: StandardTrainerOptions = None):
"""Construct a `StandardTrainer` object.
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
options: An `orbit.StandardTrainerOptions` instance.
"""
if use_tf_while_loop and not use_tf_function:
options = options or StandardTrainerOptions()
if options.use_tf_while_loop and not options.use_tf_function:
raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported")
if use_tpu_summary_optimization and not use_tf_while_loop:
if options.use_tpu_summary_optimization and not options.use_tf_while_loop:
raise ValueError("`use_tpu_summary_optimization=True` and "
"`use_tf_while_loop=False` is not supported")
self._use_tf_while_loop = use_tf_while_loop
self._use_tf_function = use_tf_function

self._use_tf_while_loop = options.use_tf_while_loop
self._use_tf_function = options.use_tf_function
self._use_tpu_summary_optimization = options.use_tpu_summary_optimization

self._train_dataset = train_dataset
self._train_iter = None
self._train_loop_fn = None
self._use_tpu_summary_optimization = use_tpu_summary_optimization

def train(self,
num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
@@ -168,29 +160,31 @@ def train_dataset(self, train_dataset):


@dataclasses.dataclass(frozen=True)
class EvaluatorOverrides:
"""Advanced overrides for Orbit evaluators.
class StandardEvaluatorOptions:
"""Advanced options for the `orbit.StandardEvaluator`.
Attributes:
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tf_function: A boolean indicating whether to apply `tf.function` to the
evaluation loop. This will only affect the body of the loop (involving
`eval_step`); `eval_begin` and `eval_end` will always be run in eager mode.
"""
use_tf_function: bool = True


class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
"""Implements the standard functionality of AbstractEvaluator APIs."""

def __init__(self, eval_dataset, use_tf_function=True):
def __init__(self, eval_dataset, options: StandardEvaluatorOptions = None):
"""Construct a `StandardEvaluator` object.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, evaluation will run on pure eager mode.
options: An `orbit.StandardEvaluatorOptions` instance.
"""
self._eval_use_tf_function = use_tf_function
options = options or StandardEvaluatorOptions()
self._eval_use_tf_function = options.use_tf_function
self._eval_dataset = eval_dataset
self._eval_loop_fn = None

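As a usage note on the validation in `StandardTrainer.__init__` above: inconsistent option combinations are rejected at construction time. An illustrative, untested sketch, reusing the hypothetical `MyTrainer` from the sketch near the top of this commit:

import orbit

# Wrapping the train step in a tf.while_loop requires tf.function, so this
# combination is rejected by StandardTrainer.__init__ with a ValueError:
bad_options = orbit.StandardTrainerOptions(use_tf_while_loop=True, use_tf_function=False)
try:
  MyTrainer(options=bad_options)
except ValueError:
  print("use_tf_while_loop=True requires use_tf_function=True")

# Likewise, the TPU summary optimization requires the tf.while_loop path:
worse_options = orbit.StandardTrainerOptions(
    use_tf_while_loop=False, use_tf_function=True, use_tpu_summary_optimization=True)
try:
  MyTrainer(options=worse_options)
except ValueError:
  print("use_tpu_summary_optimization=True requires use_tf_while_loop=True")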
66 changes: 37 additions & 29 deletions orbit/standard_runner_test.py
@@ -15,6 +15,7 @@
"""Tests for orbit.standard_runner."""

from orbit import standard_runner
from orbit import utils

import tensorflow as tf

@@ -32,62 +33,69 @@ def dummy_data(_):
return dataset


class TestRunner(standard_runner.StandardTrainer,
standard_runner.StandardEvaluator):
"""Implements the training and evaluation APIs for tests."""
class TestTrainer(standard_runner.StandardTrainer):
"""A StandardTrainer subclass for tests."""

def __init__(self):
def __init__(self, options=None):
self.strategy = tf.distribute.get_strategy()
self.global_step = tf.Variable(
0,
trainable=False,
dtype=tf.int64,
name='global_step',
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
standard_runner.StandardTrainer.__init__(self, train_dataset=None)
standard_runner.StandardEvaluator.__init__(self, eval_dataset=None)
self.global_step = utils.create_global_step()
distribute = self.strategy.experimental_distribute_datasets_from_function
dataset = distribute(dataset_fn)
super().__init__(train_dataset=dataset, options=options)

def train_loop_begin(self):
self.train_dataset = (
self.strategy.experimental_distribute_datasets_from_function(dataset_fn)
)
self.global_step.assign(0)

def train_step(self, iterator):

def _replicated_step(_):
def replica_step(_):
self.global_step.assign_add(1)

self.strategy.run(_replicated_step, args=(next(iterator),))
self.strategy.run(replica_step, args=(next(iterator),))

def train_loop_end(self):
return self.global_step.numpy()


class TestEvaluator(standard_runner.StandardEvaluator):
"""A StandardEvaluator subclass for tests."""

def __init__(self, options=None):
self.strategy = tf.distribute.get_strategy()
self.global_step = utils.create_global_step()
distribute = self.strategy.experimental_distribute_datasets_from_function
dataset = distribute(dataset_fn)
super().__init__(eval_dataset=dataset, options=options)

def eval_begin(self):
self.eval_dataset = self.strategy.experimental_distribute_datasets_from_function(
dataset_fn)
self.global_step.assign(0)

def eval_step(self, iterator):

def _replicated_step(_):
def replica_step(_):
self.global_step.assign_add(1)

self.strategy.run(_replicated_step, args=(next(iterator),))
self.strategy.run(replica_step, args=(next(iterator),))

def eval_end(self):
return self.global_step.numpy()


class StandardRunnerTest(tf.test.TestCase):

def test_train(self):
test_runner = TestRunner()
self.assertEqual(
test_runner.train(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
def test_default_trainer(self):
trainer = TestTrainer()
self.assertEqual(trainer.train(tf.constant(10)), 10)

def test_trainer_with_tpu_summary_optimization(self):
options = standard_runner.StandardTrainerOptions(
use_tpu_summary_optimization=True)
trainer = TestTrainer(options)
self.assertEqual(trainer.train(tf.constant(10)), 10)

def test_eval(self):
test_runner = TestRunner()
self.assertEqual(
test_runner.evaluate(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
def test_default_evaluator(self):
evaluator = TestEvaluator()
self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)


if __name__ == '__main__':
