Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for python3 #9

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 84 additions & 80 deletions dnc/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,20 @@ def initials(self):
sets the initial values of the controller's transformation weight matrices
this method can be overridden to use a different initialization scheme
"""
# defining internal weights of the controller
self.interface_weights = tf.Variable(
tf.random_normal([self.nn_output_size, self.interface_vector_size], stddev=0.1),
name='interface_weights'
)
self.nn_output_weights = tf.Variable(
tf.random_normal([self.nn_output_size, self.output_size], stddev=0.1),
name='nn_output_weights'
)
self.mem_output_weights = tf.Variable(
tf.random_normal([self.word_size * self.read_heads, self.output_size], stddev=0.1),
name='mem_output_weights'
)
with tf.name_scope("controller_initials"):
# defining internal weights of the controller
self.interface_weights = tf.Variable(
tf.random_normal([self.nn_output_size, self.interface_vector_size], stddev=0.1),
name='interface_weights'
)
self.nn_output_weights = tf.Variable(
tf.random_normal([self.nn_output_size, self.output_size], stddev=0.1),
name='nn_output_weights'
)
self.mem_output_weights = tf.Variable(
tf.random_normal([self.word_size * self.read_heads, self.output_size], stddev=0.1),
name='mem_output_weights'
)

def network_vars(self):
"""
Expand Down Expand Up @@ -100,21 +101,21 @@ def get_nn_output_size(self):

Raises: ValueError
"""
with tf.name_scope("controller_get_nn_output_size"):
input_vector = np.zeros([self.batch_size, self.nn_input_size], dtype=np.float32)
output_vector = None

input_vector = np.zeros([self.batch_size, self.nn_input_size], dtype=np.float32)
output_vector = None
if self.has_recurrent_nn:
output_vector,_ = self.network_op(input_vector, self.get_state())
else:
output_vector = self.network_op(input_vector)

if self.has_recurrent_nn:
output_vector,_ = self.network_op(input_vector, self.get_state())
else:
output_vector = self.network_op(input_vector)
shape = output_vector.get_shape().as_list()

shape = output_vector.get_shape().as_list()

if len(shape) > 2:
raise ValueError("Expected the neural network to output a 1D vector, but got %dD" % (len(shape) - 1))
else:
return shape[1]
if len(shape) > 2:
raise ValueError("Expected the neural network to output a 1D vector, but got %dD" % (len(shape) - 1))
else:
return shape[1]


def parse_interface_vector(self, interface_vector):
Expand All @@ -131,44 +132,45 @@ def parse_interface_vector(self, interface_vector):
a dictionary with the components of the interface_vector parsed
"""

parsed = {}

r_keys_end = self.word_size * self.read_heads
r_strengths_end = r_keys_end + self.read_heads
w_key_end = r_strengths_end + self.word_size
erase_end = w_key_end + 1 + self.word_size
write_end = erase_end + self.word_size
free_end = write_end + self.read_heads

r_keys_shape = (-1, self.word_size, self.read_heads)
r_strengths_shape = (-1, self.read_heads)
w_key_shape = (-1, self.word_size, 1)
write_shape = erase_shape = (-1, self.word_size)
free_shape = (-1, self.read_heads)
modes_shape = (-1, 3, self.read_heads)

# parsing the vector into its individual components
parsed['read_keys'] = tf.reshape(interface_vector[:, :r_keys_end], r_keys_shape)
parsed['read_strengths'] = tf.reshape(interface_vector[:, r_keys_end:r_strengths_end], r_strengths_shape)
parsed['write_key'] = tf.reshape(interface_vector[:, r_strengths_end:w_key_end], w_key_shape)
parsed['write_strength'] = tf.reshape(interface_vector[:, w_key_end], (-1, 1))
parsed['erase_vector'] = tf.reshape(interface_vector[:, w_key_end + 1:erase_end], erase_shape)
parsed['write_vector'] = tf.reshape(interface_vector[:, erase_end:write_end], write_shape)
parsed['free_gates'] = tf.reshape(interface_vector[:, write_end:free_end], free_shape)
parsed['allocation_gate'] = tf.expand_dims(interface_vector[:, free_end], 1)
parsed['write_gate'] = tf.expand_dims(interface_vector[:, free_end + 1], 1)
parsed['read_modes'] = tf.reshape(interface_vector[:, free_end + 2:], modes_shape)

# transforming the components to ensure they're in the right ranges
parsed['read_strengths'] = 1 + tf.nn.softplus(parsed['read_strengths'])
parsed['write_strength'] = 1 + tf.nn.softplus(parsed['write_strength'])
parsed['erase_vector'] = tf.nn.sigmoid(parsed['erase_vector'])
parsed['free_gates'] = tf.nn.sigmoid(parsed['free_gates'])
parsed['allocation_gate'] = tf.nn.sigmoid(parsed['allocation_gate'])
parsed['write_gate'] = tf.nn.sigmoid(parsed['write_gate'])
parsed['read_modes'] = tf.nn.softmax(parsed['read_modes'], 1)

return parsed
with tf.name_scope("controller_parse_interface_vector"):
parsed = {}

r_keys_end = self.word_size * self.read_heads
r_strengths_end = r_keys_end + self.read_heads
w_key_end = r_strengths_end + self.word_size
erase_end = w_key_end + 1 + self.word_size
write_end = erase_end + self.word_size
free_end = write_end + self.read_heads

r_keys_shape = (-1, self.word_size, self.read_heads)
r_strengths_shape = (-1, self.read_heads)
w_key_shape = (-1, self.word_size, 1)
write_shape = erase_shape = (-1, self.word_size)
free_shape = (-1, self.read_heads)
modes_shape = (-1, 3, self.read_heads)

# parsing the vector into its individual components
parsed['read_keys'] = tf.reshape(interface_vector[:, :r_keys_end], r_keys_shape)
parsed['read_strengths'] = tf.reshape(interface_vector[:, r_keys_end:r_strengths_end], r_strengths_shape)
parsed['write_key'] = tf.reshape(interface_vector[:, r_strengths_end:w_key_end], w_key_shape)
parsed['write_strength'] = tf.reshape(interface_vector[:, w_key_end], (-1, 1))
parsed['erase_vector'] = tf.reshape(interface_vector[:, w_key_end + 1:erase_end], erase_shape)
parsed['write_vector'] = tf.reshape(interface_vector[:, erase_end:write_end], write_shape)
parsed['free_gates'] = tf.reshape(interface_vector[:, write_end:free_end], free_shape)
parsed['allocation_gate'] = tf.expand_dims(interface_vector[:, free_end], 1)
parsed['write_gate'] = tf.expand_dims(interface_vector[:, free_end + 1], 1)
parsed['read_modes'] = tf.reshape(interface_vector[:, free_end + 2:], modes_shape)

# transforming the components to ensure they're in the right ranges
parsed['read_strengths'] = 1 + tf.nn.softplus(parsed['read_strengths'])
parsed['write_strength'] = 1 + tf.nn.softplus(parsed['write_strength'])
parsed['erase_vector'] = tf.nn.sigmoid(parsed['erase_vector'])
parsed['free_gates'] = tf.nn.sigmoid(parsed['free_gates'])
parsed['allocation_gate'] = tf.nn.sigmoid(parsed['allocation_gate'])
parsed['write_gate'] = tf.nn.sigmoid(parsed['write_gate'])
parsed['read_modes'] = tf.nn.softmax(parsed['read_modes'], 1)

return parsed

def process_input(self, X, last_read_vectors, state=None):
"""
Expand All @@ -189,23 +191,24 @@ def process_input(self, X, last_read_vectors, state=None):
parsed_interface_vector: dict
"""

flat_read_vectors = tf.reshape(last_read_vectors, (-1, self.word_size * self.read_heads))
complete_input = tf.concat(1, [X, flat_read_vectors])
nn_output, nn_state = None, None
with tf.name_scope("controller_process_input"):
flat_read_vectors = tf.reshape(last_read_vectors, (-1, self.word_size * self.read_heads))
complete_input = tf.concat(1, [X, flat_read_vectors])
nn_output, nn_state = None, None

if self.has_recurrent_nn:
nn_output, nn_state = self.network_op(complete_input, state)
else:
nn_output = self.network_op(complete_input)
if self.has_recurrent_nn:
nn_output, nn_state = self.network_op(complete_input, state)
else:
nn_output = self.network_op(complete_input)

pre_output = tf.matmul(nn_output, self.nn_output_weights)
interface = tf.matmul(nn_output, self.interface_weights)
parsed_interface = self.parse_interface_vector(interface)
pre_output = tf.matmul(nn_output, self.nn_output_weights)
interface = tf.matmul(nn_output, self.interface_weights)
parsed_interface = self.parse_interface_vector(interface)

if self.has_recurrent_nn:
return pre_output, parsed_interface, nn_state
else:
return pre_output, parsed_interface
if self.has_recurrent_nn:
return pre_output, parsed_interface, nn_state
else:
return pre_output, parsed_interface


def final_output(self, pre_output, new_read_vectors):
Expand All @@ -222,8 +225,9 @@ def final_output(self, pre_output, new_read_vectors):
Returns: Tensor (batch_size, output_size)
"""

flat_read_vectors = tf.reshape(new_read_vectors, (-1, self.word_size * self.read_heads))
with tf.name_scope("controller_final_output"):
flat_read_vectors = tf.reshape(new_read_vectors, (-1, self.word_size * self.read_heads))

final_output = pre_output + tf.matmul(flat_read_vectors, self.mem_output_weights)
final_output = pre_output + tf.matmul(flat_read_vectors, self.mem_output_weights)

return final_output
return final_output
Loading