Skip to content

Commit

Permalink
Move counters hook inside model_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaumekln committed Oct 4, 2018
1 parent 5bbc7d6 commit 784efeb
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 14 deletions.
18 changes: 12 additions & 6 deletions opennmt/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

import tensorflow as tf

from opennmt.utils import data
from opennmt.utils import data, hooks
from opennmt.utils.optim import optimize
from opennmt.utils.hooks import add_counter
from opennmt.utils.misc import item_or_tuple
from opennmt.utils.parallel import GraphDispatcher

Expand Down Expand Up @@ -96,7 +95,11 @@ def _extract_loss(loss):
def _model_fn(features, labels, params, mode, config):
"""model_fn implementation."""
if mode == tf.estimator.ModeKeys.TRAIN:
self._register_word_counters(features, labels)
counters = self._register_word_counters(features, labels)
counters_hook = hooks.CountersHook(
every_n_steps=config.save_summary_steps,
output_dir=config.model_dir,
counters=counters)

features_shards = dispatcher.shard(features)
labels_shards = dispatcher.shard(labels)
Expand All @@ -110,7 +113,8 @@ def _model_fn(features, labels, params, mode, config):
return tf.estimator.EstimatorSpec(
mode,
loss=loss,
train_op=train_op)
train_op=train_op,
training_hooks=[counters_hook])
elif mode == tf.estimator.ModeKeys.EVAL:
with tf.variable_scope(self.name):
logits, predictions = self._build(features, labels, params, mode, config=config)
Expand Down Expand Up @@ -207,11 +211,13 @@ def _register_word_counters(self, features, labels):
features_length = self._get_features_length(features)
labels_length = self._get_labels_length(labels)

counters = []
with tf.variable_scope("words_per_sec"):
if features_length is not None:
add_counter("features", tf.reduce_sum(features_length))
counters.append(hooks.add_counter("features", tf.reduce_sum(features_length)))
if labels_length is not None:
add_counter("labels", tf.reduce_sum(labels_length))
counters.append(hooks.add_counter("labels", tf.reduce_sum(labels_length)))
return counters

def _initialize(self, metadata):
"""Runs model specific initialization (e.g. vocabularies loading).
Expand Down
5 changes: 1 addition & 4 deletions opennmt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,7 @@ def _make_eval_prediction_hooks_fn(self):

def _build_train_spec(self):
train_hooks = [
hooks.LogParametersCountHook(),
hooks.CountersHook(
every_n_steps=self._estimator.config.save_summary_steps,
output_dir=self._estimator.model_dir)]
hooks.LogParametersCountHook()]

train_spec = tf.estimator.TrainSpec(
input_fn=self._model.input_fn(
Expand Down
3 changes: 1 addition & 2 deletions opennmt/tests/hooks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ class HooksTest(tf.test.TestCase):

def testAddCounter(self):
a = tf.placeholder(tf.int64, shape=[])
hooks.add_counter("sum_a", a)
sum_a = tf.get_collection(hooks._DEFAULT_COUNTERS_COLLECTION)
sum_a = hooks.add_counter("sum_a", a)
self.assertIsNotNone(sum_a)
with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
Expand Down
11 changes: 9 additions & 2 deletions opennmt/utils/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def add_counter(name, tensor):
name: The name of this counter.
tensor: The integer ``tf.Tensor`` to count.
Returns:
An op that increments the counter.
See Also:
:class:`opennmt.utils.hooks.CountersHook` that fetches these counters
to log their value in TensorBoard.
Expand All @@ -43,6 +46,7 @@ def add_counter(name, tensor):
count,
name=name)
tf.add_to_collection(_DEFAULT_COUNTERS_COLLECTION, total_count)
return total_count


class CountersHook(tf.train.SessionRunHook):
Expand All @@ -55,7 +59,8 @@ def __init__(self,
every_n_steps=100,
every_n_secs=None,
output_dir=None,
summary_writer=None):
summary_writer=None,
counters=None):
if (every_n_steps is None) == (every_n_secs is None):
raise ValueError("exactly one of every_n_steps and every_n_secs should be provided.")
self._timer = tf.train.SecondOrStepTimer(
Expand All @@ -64,9 +69,11 @@ def __init__(self,

self._summary_writer = summary_writer
self._output_dir = output_dir
self._counters = counters

def begin(self):
self._counters = tf.get_collection(_DEFAULT_COUNTERS_COLLECTION)
if self._counters is None:
self._counters = tf.get_collection(_DEFAULT_COUNTERS_COLLECTION)
if not self._counters:
return

Expand Down

0 comments on commit 784efeb

Please sign in to comment.