Skip to content

Commit

Permalink
Fixed the controller to update prev_c and prev_h after each LSTM step
Browse files (browse the repository at this point in the history)
  • Loading branch information
hyhieu committed Apr 20, 2018
1 parent 2734eb2 commit 2d0cd4b
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/ptb/ptb_enas_controller.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -22,6 +22,7 @@ def __init__(self,
lstm_keep_prob=1.0,
tanh_constant=None,
temperature=None,
num_funcs=2,
lr_init=1e-3,
lr_dec_start=0,
lr_dec_every=100,
Expand All @@ -46,6 +47,7 @@ def __init__(self,
self.lstm_keep_prob = lstm_keep_prob
self.tanh_constant = tanh_constant
self.temperature = temperature
self.num_funcs = num_funcs
self.lr_init = lr_init
self.lr_dec_start = lr_dec_start
self.lr_dec_every = lr_dec_every
Expand Down Expand Up @@ -74,7 +76,7 @@ def _create_params(self):
w = tf.get_variable("w", [2 * self.lstm_size, 4 * self.lstm_size])
self.w_lstm.append(w)

num_funcs = 4
num_funcs = self.num_funcs
with tf.variable_scope("embedding"):
self.g_emb = tf.get_variable("g_emb", [1, self.lstm_size])
self.w_emb = tf.get_variable("w", [num_funcs, self.lstm_size])
Expand Down Expand Up @@ -106,6 +108,7 @@ def _build_sampler(self):
# used = tf.zeros([self.rhn_depth, 2], dtype=tf.int32)
for layer_id in xrange(self.rhn_depth):
next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
prev_c, prev_h = next_c, next_h
all_h.append(next_h[-1])
all_h_w.append(tf.matmul(next_h[-1], self.attn_w_1))

Expand All @@ -121,7 +124,7 @@ def _build_sampler(self):
if self.tanh_constant is not None:
logits = self.tanh_constant * tf.tanh(logits)
diff = tf.to_float(layer_id - tf.range(0, layer_id)) ** 2
logits -= tf.reshape(diff, [1, layer_id]) / 12.0
logits -= tf.reshape(diff, [1, layer_id]) / 6.0

skip_index = tf.multinomial(logits, 1)
skip_index = tf.to_int32(skip_index)
Expand All @@ -142,6 +145,7 @@ def _build_sampler(self):
inputs = self.g_emb

next_c, next_h = stack_lstm(inputs, prev_c, prev_h, self.w_lstm)
prev_c, prev_h = next_c, next_h
logits = tf.matmul(next_h[-1], self.w_soft)
if self.temperature is not None:
logits /= self.temperature
Expand Down Expand Up @@ -173,7 +177,6 @@ def build_trainer(self, child_model):
# actor
self.valid_loss = tf.to_float(child_model.rl_loss)
self.valid_loss = tf.stop_gradient(self.valid_loss)
self.valid_loss = tf.minimum(self.valid_loss, 10.0)
self.valid_ppl = tf.exp(self.valid_loss)
self.reward = 80.0 / self.valid_ppl

Expand Down

0 comments on commit 2d0cd4b

Please sign in to comment.