From 8fde3c22aefb77f4c47217d9767c9c233e71f32d Mon Sep 17 00:00:00 2001 From: Guillaume Klein Date: Mon, 19 Nov 2018 11:56:39 +0100 Subject: [PATCH] Fix variable creation order mismatch when using a DenseBridge (#267) --- CHANGELOG.md | 2 + opennmt/decoders/decoder.py | 8 +-- opennmt/tests/decoder_test.py | 132 +++++++++++++++++++--------------- 3 files changed, 79 insertions(+), 63 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 597894a7c..633e0df0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov ### Fixes and improvements +* Fix error when building an inference graph including a `DenseBridge` + ## [1.13.0](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.13.0) (2018-11-14) ### New features diff --git a/opennmt/decoders/decoder.py b/opennmt/decoders/decoder.py index d9fe84ecf..ca3c5cbbc 100644 --- a/opennmt/decoders/decoder.py +++ b/opennmt/decoders/decoder.py @@ -305,10 +305,6 @@ def dynamic_decode_and_search(self, if memory is None: raise ValueError("dtype argument is required when no memory is set") dtype = tf.contrib.framework.nest.flatten(memory)[0].dtype - if output_layer is None: - if vocab_size is None: - raise ValueError("vocab_size must be known when the output_layer is not set") - output_layer = build_output_layer(self.output_size, vocab_size, dtype=dtype) if beam_width > 1: if initial_state is not None: @@ -327,6 +323,10 @@ def dynamic_decode_and_search(self, memory=memory, memory_sequence_length=memory_sequence_length, dtype=dtype) + if output_layer is None: + if vocab_size is None: + raise ValueError("vocab_size must be known when the output_layer is not set") + output_layer = build_output_layer(self.output_size, vocab_size, dtype=dtype) state = {"decoder": initial_state} if self.support_alignment_history and not isinstance(memory, (tuple, list)): diff --git a/opennmt/tests/decoder_test.py b/opennmt/tests/decoder_test.py index 83a19bb7c..ac915ad2e 100644 --- a/opennmt/tests/decoder_test.py +++ b/opennmt/tests/decoder_test.py @@ -1,4 +1,5 @@ import math +import os import tensorflow as tf import numpy as np @@ -6,6 +7,7 @@ from opennmt import decoders from opennmt.decoders import decoder from opennmt.utils import beam_search +from opennmt.layers import bridge class DecoderTest(tf.test.TestCase): @@ -43,7 +45,7 @@ def testSamplingProbability(self): self.assertAlmostEqual( 1.0 - (1.0 / (1.0 + math.exp(5.0 / 1.0))), sess.run(inv_sig_sample_prob)) - def _testDecoderTraining(self, decoder, dtype=tf.float32): + def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32): batch_size = 4 vocab_size = 10 time_dim = 5 @@ -58,10 +60,15 @@ def _testDecoderTraining(self, decoder, dtype=tf.float32): memory = tf.placeholder_with_default( np.random.randn(batch_size, memory_time, depth).astype(dtype.as_numpy_dtype()), shape=(None, None, depth)) + if initial_state_fn is not None: + initial_state = initial_state_fn(tf.shape(memory)[0], dtype) + else: + initial_state = None outputs, _, _, attention = decoder.decode( inputs, sequence_length, vocab_size=vocab_size, + initial_state=initial_state, memory=memory, memory_sequence_length=memory_sequence_length, return_alignment_history=True) @@ -72,44 +79,23 @@ def _testDecoderTraining(self, decoder, dtype=tf.float32): else: self.assertIsNone(attention) - with self.test_session() as sess: + saver = tf.train.Saver(var_list=tf.global_variables()) + with self.test_session(graph=tf.get_default_graph()) as sess: sess.run(tf.global_variables_initializer()) - with self.test_session() as sess: output_time_dim_val = sess.run(output_time_dim) self.assertEqual(time_dim, output_time_dim_val) if decoder.support_alignment_history: attention_val = sess.run(attention) self.assertAllEqual([batch_size, time_dim, memory_time], attention_val.shape) - - def testRNNDecoderTraining(self): - decoder = decoders.RNNDecoder(2, 20) - self._testDecoderTraining(decoder) - - def testAttentionalRNNDecoderTraining(self): - decoder = decoders.AttentionalRNNDecoder(2, 20) - self._testDecoderTraining(decoder) - - def testMultiAttentionalRNNDecoderTraining(self): - decoder = decoders.MultiAttentionalRNNDecoder(2, 20, attention_layers=[0]) - self._testDecoderTraining(decoder) - - def testRNMTPlusDecoderTraining(self): - decoder = decoders.RNMTPlusDecoder(2, 20, 4) - self._testDecoderTraining(decoder) - - def testSelfAttentionDecoderTraining(self): - decoder = decoders.SelfAttentionDecoder(2, num_units=6, num_heads=2, ffn_inner_dim=12) - self._testDecoderTraining(decoder) - - def testSelfAttentionDecoderFP16Training(self): - decoder = decoders.SelfAttentionDecoder(2, num_units=6, num_heads=2, ffn_inner_dim=12) - self._testDecoderTraining(decoder, dtype=tf.float16) - - def _testDecoderGeneric(self, - decoder, - with_beam_search=False, - with_alignment_history=False, - dtype=tf.float32): + return saver.save(sess, os.path.join(self.get_temp_dir(), "model.ckpt")) + + def _testDecoderInference(self, + decoder, + initial_state_fn=None, + with_beam_search=False, + with_alignment_history=False, + dtype=tf.float32, + checkpoint_path=None): batch_size = 4 beam_width = 5 num_hyps = beam_width if with_beam_search else 1 @@ -126,6 +112,10 @@ def _testDecoderGeneric(self, embedding = tf.placeholder_with_default( np.random.randn(vocab_size, depth).astype(dtype.as_numpy_dtype()), shape=(vocab_size, depth)) + if initial_state_fn is not None: + initial_state = initial_state_fn(tf.shape(memory)[0], dtype) + else: + initial_state = None if with_beam_search: decode_fn = decoder.dynamic_decode_and_search @@ -143,6 +133,7 @@ def _testDecoderGeneric(self, start_tokens, end_token, vocab_size=vocab_size, + initial_state=initial_state, maximum_iterations=10, memory=memory, memory_sequence_length=memory_sequence_length, @@ -155,55 +146,71 @@ def _testDecoderGeneric(self, self.assertEqual(log_probs.dtype, tf.float32) decode_time = tf.shape(ids)[-1] + saver = tf.train.Saver(var_list=tf.global_variables()) - with self.test_session() as sess: - sess.run(tf.global_variables_initializer()) + with self.test_session(graph=tf.get_default_graph()) as sess: + if checkpoint_path is not None: + saver.restore(sess, checkpoint_path) + else: + sess.run(tf.global_variables_initializer()) - if not with_alignment_history: - self.assertEqual(4, len(outputs)) - else: - self.assertEqual(5, len(outputs)) - alignment_history = outputs[4] - if decoder.support_alignment_history: - self.assertIsInstance(alignment_history, tf.Tensor) - with self.test_session() as sess: + if not with_alignment_history: + self.assertEqual(4, len(outputs)) + else: + self.assertEqual(5, len(outputs)) + alignment_history = outputs[4] + if decoder.support_alignment_history: + self.assertIsInstance(alignment_history, tf.Tensor) alignment_history, decode_time = sess.run([alignment_history, decode_time]) self.assertAllEqual( [batch_size, num_hyps, decode_time, memory_time], alignment_history.shape) - else: - self.assertIsNone(alignment_history) + else: + self.assertIsNone(alignment_history) - with self.test_session() as sess: ids, lengths, log_probs = sess.run([ids, lengths, log_probs]) self.assertAllEqual([batch_size, num_hyps], ids.shape[0:2]) self.assertAllEqual([batch_size, num_hyps], lengths.shape) self.assertAllEqual([batch_size, num_hyps], log_probs.shape) - def _testDecoder(self, decoder, dtype=tf.float32): - with tf.variable_scope(tf.get_variable_scope()): - self._testDecoderGeneric( + def _testDecoder(self, decoder, initial_state_fn=None, dtype=tf.float32): + with tf.Graph().as_default() as g: + checkpoint_path = self._testDecoderTraining( decoder, + initial_state_fn=initial_state_fn, + dtype=dtype) + + with tf.Graph().as_default() as g: + self._testDecoderInference( + decoder, + initial_state_fn=initial_state_fn, with_beam_search=False, with_alignment_history=False, - dtype=dtype) - with tf.variable_scope(tf.get_variable_scope(), reuse=True): - self._testDecoderGeneric( + dtype=dtype, + checkpoint_path=checkpoint_path) + with tf.Graph().as_default() as g: + self._testDecoderInference( decoder, + initial_state_fn=initial_state_fn, with_beam_search=False, with_alignment_history=True, - dtype=dtype) - with tf.variable_scope(tf.get_variable_scope(), reuse=True): - self._testDecoderGeneric( + dtype=dtype, + checkpoint_path=checkpoint_path) + with tf.Graph().as_default() as g: + self._testDecoderInference( decoder, + initial_state_fn=initial_state_fn, with_beam_search=True, with_alignment_history=False, - dtype=dtype) - with tf.variable_scope(tf.get_variable_scope(), reuse=True): - self._testDecoderGeneric( + dtype=dtype, + checkpoint_path=checkpoint_path) + with tf.Graph().as_default() as g: + self._testDecoderInference( decoder, + initial_state_fn=initial_state_fn, with_beam_search=True, with_alignment_history=True, - dtype=dtype) + dtype=dtype, + checkpoint_path=checkpoint_path) def testRNNDecoder(self): decoder = decoders.RNNDecoder(2, 20) @@ -213,6 +220,13 @@ def testAttentionalRNNDecoder(self): decoder = decoders.AttentionalRNNDecoder(2, 20) self._testDecoder(decoder) + def testAttentionalRNNDecoderWithDenseBridge(self): + decoder = decoders.AttentionalRNNDecoder(2, 36, bridge=bridge.DenseBridge()) + encoder_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(5), + tf.nn.rnn_cell.LSTMCell(5)]) + initial_state_fn = lambda batch_size, dtype: encoder_cell.zero_state(batch_size, dtype) + self._testDecoder(decoder, initial_state_fn=initial_state_fn) + def testMultiAttentionalRNNDecoder(self): decoder = decoders.MultiAttentionalRNNDecoder(2, 20, attention_layers=[0]) self._testDecoder(decoder)