From 784efeb3a04a444dbfe8c1801f573c358108b8fa Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Thu, 4 Oct 2018 17:30:21 +0200 Subject: [PATCH] Move counters hook inside model_fn --- opennmt/models/model.py | 18 ++++++++++++------ opennmt/runner.py | 5 +---- opennmt/tests/hooks_test.py | 3 +-- opennmt/utils/hooks.py | 11 +++++++++-- 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/opennmt/models/model.py b/opennmt/models/model.py index da0bf64ff..33f25a730 100644 --- a/opennmt/models/model.py +++ b/opennmt/models/model.py @@ -7,9 +7,8 @@ import tensorflow as tf -from opennmt.utils import data +from opennmt.utils import data, hooks from opennmt.utils.optim import optimize -from opennmt.utils.hooks import add_counter from opennmt.utils.misc import item_or_tuple from opennmt.utils.parallel import GraphDispatcher @@ -96,7 +95,11 @@ def _extract_loss(loss): def _model_fn(features, labels, params, mode, config): """model_fn implementation.""" if mode == tf.estimator.ModeKeys.TRAIN: - self._register_word_counters(features, labels) + counters = self._register_word_counters(features, labels) + counters_hook = hooks.CountersHook( + every_n_steps=config.save_summary_steps, + output_dir=config.model_dir, + counters=counters) features_shards = dispatcher.shard(features) labels_shards = dispatcher.shard(labels) @@ -110,7 +113,8 @@ def _model_fn(features, labels, params, mode, config): return tf.estimator.EstimatorSpec( mode, loss=loss, - train_op=train_op) + train_op=train_op, + training_hooks=[counters_hook]) elif mode == tf.estimator.ModeKeys.EVAL: with tf.variable_scope(self.name): logits, predictions = self._build(features, labels, params, mode, config=config) @@ -207,11 +211,13 @@ def _register_word_counters(self, features, labels): features_length = self._get_features_length(features) labels_length = self._get_labels_length(labels) + counters = [] with tf.variable_scope("words_per_sec"): if features_length is not None: - add_counter("features", tf.reduce_sum(features_length)) + counters.append(hooks.add_counter("features", tf.reduce_sum(features_length))) if labels_length is not None: - add_counter("labels", tf.reduce_sum(labels_length)) + counters.append(hooks.add_counter("labels", tf.reduce_sum(labels_length))) + return counters def _initialize(self, metadata): """Runs model specific initialization (e.g. vocabularies loading). diff --git a/opennmt/runner.py b/opennmt/runner.py index 4d7ecdced..d2f499b9d 100644 --- a/opennmt/runner.py +++ b/opennmt/runner.py @@ -121,10 +121,7 @@ def _make_eval_prediction_hooks_fn(self): def _build_train_spec(self): train_hooks = [ - hooks.LogParametersCountHook(), - hooks.CountersHook( - every_n_steps=self._estimator.config.save_summary_steps, - output_dir=self._estimator.model_dir)] + hooks.LogParametersCountHook()] train_spec = tf.estimator.TrainSpec( input_fn=self._model.input_fn( diff --git a/opennmt/tests/hooks_test.py b/opennmt/tests/hooks_test.py index 6b1f4f54d..e54bc0007 100644 --- a/opennmt/tests/hooks_test.py +++ b/opennmt/tests/hooks_test.py @@ -7,8 +7,7 @@ class HooksTest(tf.test.TestCase): def testAddCounter(self): a = tf.placeholder(tf.int64, shape=[]) - hooks.add_counter("sum_a", a) - sum_a = tf.get_collection(hooks._DEFAULT_COUNTERS_COLLECTION) + sum_a = hooks.add_counter("sum_a", a) self.assertIsNotNone(sum_a) with self.test_session() as sess: sess.run(tf.global_variables_initializer()) diff --git a/opennmt/utils/hooks.py b/opennmt/utils/hooks.py index 9c5b35222..e858890ea 100644 --- a/opennmt/utils/hooks.py +++ b/opennmt/utils/hooks.py @@ -28,6 +28,9 @@ def add_counter(name, tensor): name: The name of this counter. tensor: The integer ``tf.Tensor`` to count. + Returns: + An op that increments the counter. + See Also: :meth:`opennmt.utils.misc.WordCounterHook` that fetches these counters to log their value in TensorBoard. @@ -43,6 +46,7 @@ def add_counter(name, tensor): count, name=name) tf.add_to_collection(_DEFAULT_COUNTERS_COLLECTION, total_count) + return total_count class CountersHook(tf.train.SessionRunHook): @@ -55,7 +59,8 @@ def __init__(self, every_n_steps=100, every_n_secs=None, output_dir=None, - summary_writer=None): + summary_writer=None, + counters=None): if (every_n_steps is None) == (every_n_secs is None): raise ValueError("exactly one of every_n_steps and every_n_secs should be provided.") self._timer = tf.train.SecondOrStepTimer( @@ -64,9 +69,11 @@ def __init__(self, self._summary_writer = summary_writer self._output_dir = output_dir + self._counters = counters def begin(self): - self._counters = tf.get_collection(_DEFAULT_COUNTERS_COLLECTION) + if self._counters is None: + self._counters = tf.get_collection(_DEFAULT_COUNTERS_COLLECTION) if not self._counters: return