Skip to content

Commit

Permalink
[tensorflow] Allow unspecified NodeDef names
Browse files Browse the repository at this point in the history
`NodeDef` name values can now be `None`, which means that reification
will use the next available unique name in the default graph.

Closes pymc-devs#93.
  • Loading branch information
brandonwillard committed Dec 3, 2019
1 parent b504958 commit 2cdb61b
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 49 deletions.
76 changes: 32 additions & 44 deletions symbolic_pymc/tensorflow/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,15 @@ def _protobuf_convert(cls, k, v):
raise TypeError(f"Could not convert {k}")

def __init__(self, op, name, attr, obj=None):
"""Create a TF meta NodeDef.
XXX: Meta NodeDefs with `name == None` have a special meaning;
their names are uniquely generated. We still consider them equal
(when every other property is equal, of course).
"""
super().__init__(obj=obj)
self.op = metatize(op)
assert name is not None
self.name = name if isvar(name) else str(name)
self.name = name if isvar(name) else name

if not isvar(attr):
opdef_sig, _ = op_def_lib.get_op_info(self.op)
Expand Down Expand Up @@ -601,7 +606,7 @@ def reify(self):
#
try:
existing_op = ops.get_default_graph().get_operation_by_name(self.name)
except KeyError:
except (KeyError, TypeError):
#
# There is no such `Operation`, so we attempt to create it
#
Expand All @@ -613,7 +618,15 @@ def reify(self):
# An `Operation` with this name exists, let's make sure it's
# equivalent to this meta `Operation`
#
if self != mt(existing_op):
existing_op_mt = mt(existing_op)

# # Since we can't exactly reproduce all NodeDef.attr information
# # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr
# # fields from comparisons with same-named nodes in the graph.
# if op_attrs.keys() != node_attr.keys():
# existing_op_mt.node_def.attr = node_attr

if self != existing_op_mt:
raise MetaReificationError(
f"An Operation with the name {self.name}"
" already exists in the graph and is not"
Expand Down Expand Up @@ -987,48 +1000,22 @@ def __api_call__(self, *args, **kwargs):

if not op_args_unreified:

res_var = None
# name = op_args.get("name", None)
#
# if name is not None:
# #
# # An operation with this name might already exist in the graph
# #
# We create the `Operation` in the graph
#
# from tensorflow.python.framework import ops
#
# try:
# this_op = ops.get_default_graph().get_operation_by_name(name)
# except KeyError:
# pass
# else:
# # TODO: Make sure the existing `Operation` matches our arguments
# assert this_op.type == self.op_def.obj.name
#
# this_op = mt(this_op)
# op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args)
# assert op_inputs == this_op.inputs
# assert op_node_def == this_op.node_def
# res_var = this_op.default_output

if res_var is None:
#
# We create the `Operation` in the graph
#

tf_out = self._apply_func(**op_args)

# Ensure that the original meta objects will be available
# for use in the `metatize` that follows
tf_metatize_cache.update(
{
k: v
for k, v in zip(op_args.values(), apply_arguments.values())
if isinstance(k, tf.Tensor)
}
)
tf_out = self._apply_func(**op_args)

# Ensure that the original meta objects will be available
# for use in the `metatize` that follows
tf_metatize_cache.update(
{
k: v
for k, v in zip(op_args.values(), apply_arguments.values())
if isinstance(k, tf.Tensor)
}
)

res_var = metatize(tf_out)
res_var = metatize(tf_out)

if "names" in meta._lvar_defaults_enabled:
# This should also reset the NodeDef's `obj`
Expand Down Expand Up @@ -1073,7 +1060,8 @@ def op_args_to_operation_inputs(self, apply_arguments):
node_attr = var()

if "names" not in meta._lvar_defaults_enabled:
op_name = apply_arguments.get("name", op_def_tf.name) or op_def_tf.name
# default_name = ops.get_default_graph().unique_name(op_def_tf.name, mark_as_used=False)
op_name = apply_arguments.get("name", None)
else:
op_name = var()

Expand Down
17 changes: 15 additions & 2 deletions tests/tensorflow/test_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def test_global_options():
with tf.Graph().as_default(), disable_auto_reification():
y_mt = mt.Placeholder('float')
assert y_mt.obj is None
assert y_mt.name == 'Placeholder:0'
assert isvar(y_mt.name)
assert isinstance(y_mt.op.node_def.attr, dict)

with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
Expand Down Expand Up @@ -706,7 +706,7 @@ def test_meta_const():
@run_in_graph_mode
def test_meta_existing_names():

with tf.Graph().as_default():
with tf.Graph().as_default() as test_graph:
one_mt = mt(1)
assert one_mt.op.name == 'Const'

Expand All @@ -723,6 +723,7 @@ def test_meta_existing_names():
# Make sure it's the first base variable we created
assert orig_one_tf is one_tf

# FYI: This implicitly creates 'Const_1'
two_mt = mt(2)
two_mt.op.node_def.name = 'Const'

Expand All @@ -736,3 +737,15 @@ def test_meta_existing_names():

with pytest.raises(MetaReificationError):
two_mt.reify()

another_one_mt = TFlowMetaOperator('Const', None)(3, var())
# The following is something that would happen as a result of
# reification (of the lvar in the meta object, not the meta object
# itself).
another_one_mt.op.node_def.attr['dtype'] = tf.int32

assert another_one_mt.op.name is None
# We need to make sure that the reified meta object actually uses a
# unique name.
assert isinstance(another_one_mt.reify(), tf.Tensor)
assert another_one_mt.reify().op.name == 'Const_2'
9 changes: 6 additions & 3 deletions tests/tensorflow/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,11 @@ def test_basic_unify_reify():

test_expr = mt.add(tf.constant(1, dtype=tf.float64),
mt.mul(tf.constant(2, dtype=tf.float64),
x_l))
test_reify_res = reify(test_expr, {x_l: a})
x_l, name=var('mul_name')),
name=var('add_name'))
test_reify_res = reify(test_expr, {x_l: a,
var('add_name'): 'Add_10',
var('mul_name'): 'Mul_10'})
test_base_res = test_reify_res.reify()
assert isinstance(test_base_res, tf.Tensor)

Expand All @@ -141,7 +144,7 @@ def test_basic_unify_reify():
# Simply make sure that unification succeeds
meta_expected_res = mt(expected_res)
s_test = unify(test_expr, meta_expected_res, {})
assert len(s_test) == 3
assert len(s_test) == 5

assert reify(test_expr, s_test) == meta_expected_res

Expand Down

0 comments on commit 2cdb61b

Please sign in to comment.