Fix variable creation order mismatch when using a DenseBridge (#267)
guillaumekln committed Nov 19, 2018
1 parent f9f7cdb commit 8fde3c2
Showing 3 changed files with 79 additions and 63 deletions.
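
The change relocates the output-layer construction in `dynamic_decode_and_search` so that it happens after the decoder cell and initial state are built, which is where a `DenseBridge` creates its projection variables. This matters because, in TF 1.x graph mode, auto-named layers get order-dependent names ("dense", "dense_1", ...): if the training and inference graphs create the same layers in a different order, the checkpoint names no longer line up and restoring fails. A minimal illustrative sketch of that naming behavior (not code from this commit):

```python
# Illustrative only: in TF 1.x graph mode, layers created without an explicit
# name are auto-named by creation order ("dense", "dense_1", ...), so two
# graphs that build the same layers in a different order end up with
# mismatched checkpoint variable names.
import tensorflow as tf

def variable_names(bridge_first):
  with tf.Graph().as_default():
    if bridge_first:
      tf.layers.dense(tf.zeros([1, 4]), 8)    # stands in for the DenseBridge projection
      tf.layers.dense(tf.zeros([1, 8]), 10)   # stands in for the output layer
    else:
      tf.layers.dense(tf.zeros([1, 8]), 10)
      tf.layers.dense(tf.zeros([1, 4]), 8)
    return [v.name for v in tf.global_variables()]

print(variable_names(bridge_first=True))   # "dense/..." belongs to the bridge projection
print(variable_names(bridge_first=False))  # "dense/..." now belongs to the output layer
```
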
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ OpenNMT-tf follows [semantic versioning 2.0.0](https://semver.org/). The API cov

### Fixes and improvements

* Fix error when building an inference graph including a `DenseBridge`

## [1.13.0](https://github.com/OpenNMT/OpenNMT-tf/releases/tag/v1.13.0) (2018-11-14)

### New features
8 changes: 4 additions & 4 deletions opennmt/decoders/decoder.py
@@ -305,10 +305,6 @@ def dynamic_decode_and_search(self,
if memory is None:
raise ValueError("dtype argument is required when no memory is set")
dtype = tf.contrib.framework.nest.flatten(memory)[0].dtype
if output_layer is None:
if vocab_size is None:
raise ValueError("vocab_size must be known when the output_layer is not set")
output_layer = build_output_layer(self.output_size, vocab_size, dtype=dtype)

if beam_width > 1:
if initial_state is not None:
@@ -327,6 +323,10 @@
memory=memory,
memory_sequence_length=memory_sequence_length,
dtype=dtype)
if output_layer is None:
if vocab_size is None:
raise ValueError("vocab_size must be known when the output_layer is not set")
output_layer = build_output_layer(self.output_size, vocab_size, dtype=dtype)

state = {"decoder": initial_state}
if self.support_alignment_history and not isinstance(memory, (tuple, list)):
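
With the block moved, `dynamic_decode_and_search` builds the cell and initial state first; when a `DenseBridge` is configured, that step is what creates the bridge's projection variables, so creating the output layer afterwards keeps the inference graph's variable creation order consistent with training. Below is a simplified, hedged sketch of what a DenseBridge-style projection does (the real implementation lives in `opennmt/layers/bridge.py` and handles nested state structures more carefully):

```python
# Simplified sketch (assumed behavior, not the actual DenseBridge code): the
# bridge projects the encoder's final state with a dense layer, which creates
# trainable variables at the point where the decoder's initial state is built.
import tensorflow as tf

def dense_bridge_sketch(encoder_state, decoder_state_size):
  # Flatten the (possibly nested) encoder state and concatenate it.
  flat = tf.concat(tf.contrib.framework.nest.flatten(encoder_state), axis=-1)
  # This dense call creates the bridge's kernel/bias variables; with this fix,
  # it now runs before build_output_layer() creates the output projection.
  return tf.layers.dense(flat, decoder_state_size)
```
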
132 changes: 73 additions & 59 deletions opennmt/tests/decoder_test.py
@@ -1,11 +1,13 @@
import math
import os

import tensorflow as tf
import numpy as np

from opennmt import decoders
from opennmt.decoders import decoder
from opennmt.utils import beam_search
from opennmt.layers import bridge


class DecoderTest(tf.test.TestCase):
@@ -43,7 +45,7 @@ def testSamplingProbability(self):
self.assertAlmostEqual(
1.0 - (1.0 / (1.0 + math.exp(5.0 / 1.0))), sess.run(inv_sig_sample_prob))

def _testDecoderTraining(self, decoder, dtype=tf.float32):
def _testDecoderTraining(self, decoder, initial_state_fn=None, dtype=tf.float32):
batch_size = 4
vocab_size = 10
time_dim = 5
@@ -58,10 +60,15 @@ def _testDecoderTraining(self, decoder, dtype=tf.float32):
memory = tf.placeholder_with_default(
np.random.randn(batch_size, memory_time, depth).astype(dtype.as_numpy_dtype()),
shape=(None, None, depth))
if initial_state_fn is not None:
initial_state = initial_state_fn(tf.shape(memory)[0], dtype)
else:
initial_state = None
outputs, _, _, attention = decoder.decode(
inputs,
sequence_length,
vocab_size=vocab_size,
initial_state=initial_state,
memory=memory,
memory_sequence_length=memory_sequence_length,
return_alignment_history=True)
@@ -72,44 +79,23 @@ def _testDecoderTraining(self, decoder, dtype=tf.float32):
else:
self.assertIsNone(attention)

with self.test_session() as sess:
saver = tf.train.Saver(var_list=tf.global_variables())
with self.test_session(graph=tf.get_default_graph()) as sess:
sess.run(tf.global_variables_initializer())
with self.test_session() as sess:
output_time_dim_val = sess.run(output_time_dim)
self.assertEqual(time_dim, output_time_dim_val)
if decoder.support_alignment_history:
attention_val = sess.run(attention)
self.assertAllEqual([batch_size, time_dim, memory_time], attention_val.shape)

def testRNNDecoderTraining(self):
decoder = decoders.RNNDecoder(2, 20)
self._testDecoderTraining(decoder)

def testAttentionalRNNDecoderTraining(self):
decoder = decoders.AttentionalRNNDecoder(2, 20)
self._testDecoderTraining(decoder)

def testMultiAttentionalRNNDecoderTraining(self):
decoder = decoders.MultiAttentionalRNNDecoder(2, 20, attention_layers=[0])
self._testDecoderTraining(decoder)

def testRNMTPlusDecoderTraining(self):
decoder = decoders.RNMTPlusDecoder(2, 20, 4)
self._testDecoderTraining(decoder)

def testSelfAttentionDecoderTraining(self):
decoder = decoders.SelfAttentionDecoder(2, num_units=6, num_heads=2, ffn_inner_dim=12)
self._testDecoderTraining(decoder)

def testSelfAttentionDecoderFP16Training(self):
decoder = decoders.SelfAttentionDecoder(2, num_units=6, num_heads=2, ffn_inner_dim=12)
self._testDecoderTraining(decoder, dtype=tf.float16)

def _testDecoderGeneric(self,
decoder,
with_beam_search=False,
with_alignment_history=False,
dtype=tf.float32):
return saver.save(sess, os.path.join(self.get_temp_dir(), "model.ckpt"))

def _testDecoderInference(self,
decoder,
initial_state_fn=None,
with_beam_search=False,
with_alignment_history=False,
dtype=tf.float32,
checkpoint_path=None):
batch_size = 4
beam_width = 5
num_hyps = beam_width if with_beam_search else 1
@@ -126,6 +112,10 @@ def _testDecoderGeneric(self,
embedding = tf.placeholder_with_default(
np.random.randn(vocab_size, depth).astype(dtype.as_numpy_dtype()),
shape=(vocab_size, depth))
if initial_state_fn is not None:
initial_state = initial_state_fn(tf.shape(memory)[0], dtype)
else:
initial_state = None

if with_beam_search:
decode_fn = decoder.dynamic_decode_and_search
@@ -143,6 +133,7 @@ def _testDecoderGeneric(self,
start_tokens,
end_token,
vocab_size=vocab_size,
initial_state=initial_state,
maximum_iterations=10,
memory=memory,
memory_sequence_length=memory_sequence_length,
@@ -155,55 +146,71 @@ def _testDecoderGeneric(self,
self.assertEqual(log_probs.dtype, tf.float32)

decode_time = tf.shape(ids)[-1]
saver = tf.train.Saver(var_list=tf.global_variables())

with self.test_session() as sess:
sess.run(tf.global_variables_initializer())
with self.test_session(graph=tf.get_default_graph()) as sess:
if checkpoint_path is not None:
saver.restore(sess, checkpoint_path)
else:
sess.run(tf.global_variables_initializer())

if not with_alignment_history:
self.assertEqual(4, len(outputs))
else:
self.assertEqual(5, len(outputs))
alignment_history = outputs[4]
if decoder.support_alignment_history:
self.assertIsInstance(alignment_history, tf.Tensor)
with self.test_session() as sess:
if not with_alignment_history:
self.assertEqual(4, len(outputs))
else:
self.assertEqual(5, len(outputs))
alignment_history = outputs[4]
if decoder.support_alignment_history:
self.assertIsInstance(alignment_history, tf.Tensor)
alignment_history, decode_time = sess.run([alignment_history, decode_time])
self.assertAllEqual(
[batch_size, num_hyps, decode_time, memory_time], alignment_history.shape)
else:
self.assertIsNone(alignment_history)
else:
self.assertIsNone(alignment_history)

with self.test_session() as sess:
ids, lengths, log_probs = sess.run([ids, lengths, log_probs])
self.assertAllEqual([batch_size, num_hyps], ids.shape[0:2])
self.assertAllEqual([batch_size, num_hyps], lengths.shape)
self.assertAllEqual([batch_size, num_hyps], log_probs.shape)

def _testDecoder(self, decoder, dtype=tf.float32):
with tf.variable_scope(tf.get_variable_scope()):
self._testDecoderGeneric(
def _testDecoder(self, decoder, initial_state_fn=None, dtype=tf.float32):
with tf.Graph().as_default() as g:
checkpoint_path = self._testDecoderTraining(
decoder,
initial_state_fn=initial_state_fn,
dtype=dtype)

with tf.Graph().as_default() as g:
self._testDecoderInference(
decoder,
initial_state_fn=initial_state_fn,
with_beam_search=False,
with_alignment_history=False,
dtype=dtype)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._testDecoderGeneric(
dtype=dtype,
checkpoint_path=checkpoint_path)
with tf.Graph().as_default() as g:
self._testDecoderInference(
decoder,
initial_state_fn=initial_state_fn,
with_beam_search=False,
with_alignment_history=True,
dtype=dtype)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._testDecoderGeneric(
dtype=dtype,
checkpoint_path=checkpoint_path)
with tf.Graph().as_default() as g:
self._testDecoderInference(
decoder,
initial_state_fn=initial_state_fn,
with_beam_search=True,
with_alignment_history=False,
dtype=dtype)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._testDecoderGeneric(
dtype=dtype,
checkpoint_path=checkpoint_path)
with tf.Graph().as_default() as g:
self._testDecoderInference(
decoder,
initial_state_fn=initial_state_fn,
with_beam_search=True,
with_alignment_history=True,
dtype=dtype)
dtype=dtype,
checkpoint_path=checkpoint_path)
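
The restructured test flow above is what actually exercises the regression: `_testDecoderTraining` saves a checkpoint from a dedicated training graph, and `_testDecoderInference` rebuilds the decoder in a fresh graph and restores that checkpoint, which only succeeds if both graphs create identically named variables. A standalone sketch of this save-in-one-graph / restore-in-another pattern (illustrative, not test code):

```python
# Minimal TF 1.x sketch of the pattern used by the updated tests: save from one
# graph, then restore into a second graph built from scratch. Auto-named layers
# must be created in the same order in both graphs for the names to match.
import os
import tensorflow as tf

def build_and_save(checkpoint_dir):
  with tf.Graph().as_default():
    tf.layers.dense(tf.zeros([1, 4]), 8)  # auto-named "dense"
    saver = tf.train.Saver(var_list=tf.global_variables())
    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      return saver.save(sess, os.path.join(checkpoint_dir, "model.ckpt"))

def rebuild_and_restore(checkpoint_path):
  with tf.Graph().as_default():
    tf.layers.dense(tf.zeros([1, 4]), 8)  # must match the saved layer's name and shape
    saver = tf.train.Saver(var_list=tf.global_variables())
    with tf.Session() as sess:
      saver.restore(sess, checkpoint_path)  # raises if a variable name or shape differs
```
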

def testRNNDecoder(self):
decoder = decoders.RNNDecoder(2, 20)
@@ -213,6 +220,13 @@ def testAttentionalRNNDecoder(self):
decoder = decoders.AttentionalRNNDecoder(2, 20)
self._testDecoder(decoder)

def testAttentionalRNNDecoderWithDenseBridge(self):
decoder = decoders.AttentionalRNNDecoder(2, 36, bridge=bridge.DenseBridge())
encoder_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(5),
tf.nn.rnn_cell.LSTMCell(5)])
initial_state_fn = lambda batch_size, dtype: encoder_cell.zero_state(batch_size, dtype)
self._testDecoder(decoder, initial_state_fn=initial_state_fn)
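
For reference, constructing such a decoder outside the test harness uses the same API; a short usage sketch mirroring the imports above (the unit count is illustrative, and wiring the decoder into a full model is out of scope here):

```python
# Usage sketch based on the test above; 512 units is an arbitrary example.
from opennmt import decoders
from opennmt.layers import bridge

decoder = decoders.AttentionalRNNDecoder(2, 512, bridge=bridge.DenseBridge())
```
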

def testMultiAttentionalRNNDecoder(self):
decoder = decoders.MultiAttentionalRNNDecoder(2, 20, attention_layers=[0])
self._testDecoder(decoder)
