From 2cdb61bd342513edb9c2ff2c03ce1f80852c5f57 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Tue, 3 Dec 2019 00:50:55 -0600 Subject: [PATCH] [tensorflow] Allow unspecified NodeDef names `NodeDef` name values can now be `None`, which means that reification will use the next available unique name in the default graph. Closes #93. --- symbolic_pymc/tensorflow/meta.py | 76 ++++++++++++++------------------ tests/tensorflow/test_meta.py | 17 ++++++- tests/tensorflow/test_unify.py | 9 ++-- 3 files changed, 53 insertions(+), 49 deletions(-) diff --git a/symbolic_pymc/tensorflow/meta.py b/symbolic_pymc/tensorflow/meta.py index 5daefe8..f582dd6 100644 --- a/symbolic_pymc/tensorflow/meta.py +++ b/symbolic_pymc/tensorflow/meta.py @@ -366,10 +366,15 @@ def _protobuf_convert(cls, k, v): raise TypeError(f"Could not convert {k}") def __init__(self, op, name, attr, obj=None): + """Create a TF meta NodeDef. + + XXX: Meta NodeDefs with `name == None` have a special meaning; + their names are uniquely generated. We still consider them equal + (when every other property is equal, of course). + """ super().__init__(obj=obj) self.op = metatize(op) - assert name is not None - self.name = name if isvar(name) else str(name) + self.name = name if isvar(name) else name if not isvar(attr): opdef_sig, _ = op_def_lib.get_op_info(self.op) @@ -601,7 +606,7 @@ def reify(self): # try: existing_op = ops.get_default_graph().get_operation_by_name(self.name) - except KeyError: + except (KeyError, TypeError): # # There is no such `Operation`, so we attempt to create it # @@ -613,7 +618,15 @@ def reify(self): # An `Operation` with this name exists, let's make sure it's # equivalent to this meta `Operation` # - if self != mt(existing_op): + existing_op_mt = mt(existing_op) + + # # Since we can't exactly reproduce all NodeDef.attr information + # # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr + # # fields from comparisons with same-named nodes in the graph. + # if op_attrs.keys() != node_attr.keys(): + # existing_op_mt.node_def.attr = node_attr + + if self != existing_op_mt: raise MetaReificationError( f"An Operation with the name {self.name}" " already exists in the graph and is not" @@ -987,48 +1000,22 @@ def __api_call__(self, *args, **kwargs): if not op_args_unreified: - res_var = None - # name = op_args.get("name", None) # - # if name is not None: - # # - # # An operation with this name might already exist in the graph - # # + # We create the `Operation` in the graph # - # from tensorflow.python.framework import ops - # - # try: - # this_op = ops.get_default_graph().get_operation_by_name(name) - # except KeyError: - # pass - # else: - # # TODO: Make sure the existing `Operation` matches our arguments - # assert this_op.type == self.op_def.obj.name - # - # this_op = mt(this_op) - # op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args) - # assert op_inputs == this_op.inputs - # assert op_node_def == this_op.node_def - # res_var = this_op.default_output - - if res_var is None: - # - # We create the `Operation` in the graph - # - - tf_out = self._apply_func(**op_args) - - # Ensure that the original meta objects will be available - # for use in the `metatize` that follows - tf_metatize_cache.update( - { - k: v - for k, v in zip(op_args.values(), apply_arguments.values()) - if isinstance(k, tf.Tensor) - } - ) + tf_out = self._apply_func(**op_args) + + # Ensure that the original meta objects will be available + # for use in the `metatize` that follows + tf_metatize_cache.update( + { + k: v + for k, v in zip(op_args.values(), apply_arguments.values()) + if isinstance(k, tf.Tensor) + } + ) - res_var = metatize(tf_out) + res_var = metatize(tf_out) if "names" in meta._lvar_defaults_enabled: # This should also reset the NodeDef's `obj` @@ -1073,7 +1060,8 @@ def op_args_to_operation_inputs(self, apply_arguments): node_attr = var() if "names" not in meta._lvar_defaults_enabled: - op_name = apply_arguments.get("name", op_def_tf.name) or op_def_tf.name + # default_name = ops.get_default_graph().unique_name(op_def_tf.name, mark_as_used=False) + op_name = apply_arguments.get("name", None) else: op_name = var() diff --git a/tests/tensorflow/test_meta.py b/tests/tensorflow/test_meta.py index 41a0d89..2993dde 100644 --- a/tests/tensorflow/test_meta.py +++ b/tests/tensorflow/test_meta.py @@ -636,7 +636,7 @@ def test_global_options(): with tf.Graph().as_default(), disable_auto_reification(): y_mt = mt.Placeholder('float') assert y_mt.obj is None - assert y_mt.name == 'Placeholder:0' + assert isvar(y_mt.name) assert isinstance(y_mt.op.node_def.attr, dict) with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'): @@ -706,7 +706,7 @@ def test_meta_const(): @run_in_graph_mode def test_meta_existing_names(): - with tf.Graph().as_default(): + with tf.Graph().as_default() as test_graph: one_mt = mt(1) assert one_mt.op.name == 'Const' @@ -723,6 +723,7 @@ def test_meta_existing_names(): # Make sure it's the first base variable we created assert orig_one_tf is one_tf + # FYI: This implicitly creates 'Const_1' two_mt = mt(2) two_mt.op.node_def.name = 'Const' @@ -736,3 +737,15 @@ def test_meta_existing_names(): with pytest.raises(MetaReificationError): two_mt.reify() + + another_one_mt = TFlowMetaOperator('Const', None)(3, var()) + # The following is something that would happen as a result of + # reification (of the lvar in the meta object, not the meta object + # itself). + another_one_mt.op.node_def.attr['dtype'] = tf.int32 + + assert another_one_mt.op.name is None + # We need to make sure that the reified meta object actually uses a + # unique name. + assert isinstance(another_one_mt.reify(), tf.Tensor) + assert another_one_mt.reify().op.name == 'Const_2' diff --git a/tests/tensorflow/test_unify.py b/tests/tensorflow/test_unify.py index f29ebd6..19a669b 100644 --- a/tests/tensorflow/test_unify.py +++ b/tests/tensorflow/test_unify.py @@ -127,8 +127,11 @@ def test_basic_unify_reify(): test_expr = mt.add(tf.constant(1, dtype=tf.float64), mt.mul(tf.constant(2, dtype=tf.float64), - x_l)) - test_reify_res = reify(test_expr, {x_l: a}) + x_l, name=var('mul_name')), + name=var('add_name')) + test_reify_res = reify(test_expr, {x_l: a, + var('add_name'): 'Add_10', + var('mul_name'): 'Mul_10'}) test_base_res = test_reify_res.reify() assert isinstance(test_base_res, tf.Tensor) @@ -141,7 +144,7 @@ def test_basic_unify_reify(): # Simply make sure that unification succeeds meta_expected_res = mt(expected_res) s_test = unify(test_expr, meta_expected_res, {}) - assert len(s_test) == 3 + assert len(s_test) == 5 assert reify(test_expr, s_test) == meta_expected_res