Commit

Merge pull request #38 from zoyahav/master
Project import generated by Copybara.
KesterTong authored Oct 20, 2017
2 parents 51a0c5f + 03f16fb commit a9d2911
Showing 4 changed files with 52 additions and 57 deletions.
48 changes: 28 additions & 20 deletions tensorflow_transform/analyzers.py
@@ -49,28 +49,36 @@ class Analyzer(object):
Args:
inputs: The inputs to the analyzer.
output_tensors_and_is_asset: List of pairs of `(Tensor, bool)` for each
output. The `Tensor`s are typically placeholders; they will later be
replaced with analysis results. The boolean value states whether this
Tensor represents an asset filename or not.
output_dtype_shape_and_is_asset: List of tuples of `(DType, Shape, bool)`
for each output. A tf.placeholder with the given DType and Shape will be
constructed to represent the output of the analyzer, and this placeholder
will eventually be replaced by the actual value of the analyzer. The
boolean value states whether this Tensor represents an asset filename or
not.
spec: A description of the computation to be done.
name: Similar to a TF op name. Used to define a unique scope for this
analyzer, which can be used for debugging info.
Raises:
ValueError: If the inputs are not all `Tensor`s.
"""

def __init__(self, inputs, output_tensors_and_is_asset, spec):
def __init__(self, inputs, output_dtype_shape_and_is_asset, spec, name):
for tensor in inputs:
if not isinstance(tensor, tf.Tensor):
raise ValueError('Analyzers can only accept `Tensor`s as inputs')
self._inputs = inputs
for output_tensor, is_asset in output_tensors_and_is_asset:
if is_asset and output_tensor.dtype != tf.string:
raise ValueError(('Tensor {} cannot represent an asset, because it is '
'not a string.').format(output_tensor.name))
self._outputs = [output_tensor
for output_tensor, _ in output_tensors_and_is_asset]
self._output_is_asset_map = dict(output_tensors_and_is_asset)
self._outputs = []
self._output_is_asset_map = {}
with tf.name_scope(name) as scope:
self._name = scope
for dtype, shape, is_asset in output_dtype_shape_and_is_asset:
output_tensor = tf.placeholder(dtype, shape)
if is_asset and output_tensor.dtype != tf.string:
raise ValueError(('Tensor {} cannot represent an asset, because it '
'is not a string.').format(output_tensor.name))
self._outputs.append(output_tensor)
self._output_is_asset_map[output_tensor] = is_asset
self._spec = spec
tf.add_to_collection(ANALYZER_COLLECTION, self)

@@ -86,6 +94,10 @@ def outputs(self):
def spec(self):
return self._spec

@property
def name(self):
return self._name

def output_is_asset(self, output_tensor):
return self._output_is_asset_map[output_tensor]

@@ -131,11 +143,9 @@ def _numeric_combine(x, combiner_type, reduce_instance_dims=True):
# If reducing over batch dimensions, with unknown shape, the result will
# also have unknown shape.
shape = None
with tf.name_scope(combiner_type):
spec = NumericCombineSpec(x.dtype, combiner_type, reduce_instance_dims)
return Analyzer([x],
[(tf.placeholder(x.dtype, shape), False)],
spec).outputs[0]
spec = NumericCombineSpec(x.dtype, combiner_type, reduce_instance_dims)
return Analyzer(
[x], [(x.dtype, shape, False)], spec, combiner_type).outputs[0]


def min(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin
@@ -381,9 +391,7 @@ def uniques(x, top_k=None, frequency_threshold=None,

spec = UniquesSpec(tf.string, top_k, frequency_threshold,
vocab_filename, store_frequency)
return Analyzer([x],
[(tf.placeholder(tf.string, []), True)],
spec).outputs[0]
return Analyzer([x], [(tf.string, [], True)], spec, 'uniques').outputs[0]


class QuantilesSpec(object):
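The main change in analyzers.py moves placeholder creation into the Analyzer itself: callers such as _numeric_combine and uniques now pass (DType, Shape, bool) tuples plus a name, and the constructor builds the output placeholders inside a uniquely scoped tf.name_scope. A minimal sketch of that pattern, assuming TensorFlow 1.x and a hypothetical helper name:

    import tensorflow as tf

    def _build_analyzer_outputs(output_dtype_shape_and_is_asset, name):
      # Hypothetical helper mirroring what Analyzer.__init__ now does: build
      # one placeholder per (dtype, shape, is_asset) tuple under a unique
      # name scope, rejecting non-string asset outputs.
      outputs = []
      with tf.name_scope(name) as scope:
        for dtype, shape, is_asset in output_dtype_shape_and_is_asset:
          placeholder = tf.placeholder(dtype, shape)
          if is_asset and placeholder.dtype != tf.string:
            raise ValueError('Tensor %s cannot represent an asset, because '
                             'it is not a string.' % placeholder.name)
          outputs.append(placeholder)
      return scope, outputs

    # e.g. a numeric combine such as min() now only declares dtype and shape:
    # scope, (result,) = _build_analyzer_outputs([(tf.float32, None, False)], 'min')

The returned scope string becomes the analyzer's name, which the Beam implementation below uses to label its pipeline steps.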
13 changes: 6 additions & 7 deletions tensorflow_transform/beam/impl.py
@@ -513,9 +513,8 @@ def __init__(self, analyzers, base_temp_dir):
def expand(self, analyzer_input_values):
# For each analyzer output, look up its input values (by tensor name)
# and run the analyzer on these values.
#
result = {}
for idx, analyzer in enumerate(self._analyzers):
for analyzer in self._analyzers:
temp_assets_dir = _make_unique_temp_dir(self._base_temp_dir)
tf.gfile.MkDir(temp_assets_dir)
analyzer_impl = analyzer_impls._impl_for_analyzer(
@@ -525,10 +524,10 @@ def expand(self, analyzer_input_values):
assert len(analyzer.inputs) == 1
output_pcolls = (
analyzer_input_values
| 'Extract_%d' % idx >> beam.Map(
| 'ExtractInput[%s]' % analyzer.name >> beam.Map(
lambda batch, key: batch[key],
key=analyzer.inputs[0].name)
| 'Analyze_%d' % idx >> analyzer_impl)
| 'Analyze[%s]' % analyzer.name >> analyzer_impl)
assert len(analyzer.outputs) == len(output_pcolls), (
'Analyzer outputs don\'t match the expected outputs from the '
'Analyzer definition: %d != %d' %
@@ -537,7 +536,7 @@ def expand(self, dataset):
for collection_idx, (tensor, pcoll) in enumerate(
zip(analyzer.outputs, output_pcolls)):
is_asset = analyzer.output_is_asset(tensor)
pcoll |= ('WrapAsTensorValue_%d_%d' % (idx, collection_idx)
pcoll |= ('WrapAsTensorValue[%s][%d]' % (analyzer.name, collection_idx)
>> beam.Map(_TensorValue, is_asset))
result[tensor.name] = pcoll
return result
@@ -711,15 +710,15 @@ def expand(self, dataset):
graph, inputs, analyzer_inputs, unbound_saved_model_dir)
saved_model_dir = (
tensor_pcoll_mapping
| 'CreateSavedModelForAnaylzerInputs_%d' % level
| 'CreateSavedModelForAnaylzerInputs[%d]' % level
>> _ReplaceTensorsWithConstants(
unbound_saved_model_dir, base_temp_dir, input_values.pipeline))

# Run this saved model on the input dataset to obtain the inputs to the
# analyzers.
analyzer_input_values = (
input_values
| 'ComputeAnalyzerInputs_%d' % level >> beam.ParDo(
| 'ComputeAnalyzerInputs[%d]' % level >> beam.ParDo(
_RunMetaGraphDoFn(
input_schema,
analyzer_inputs_schema,
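With the analyzer name available, the Beam implementation keys its step labels on that scoped name rather than a positional index, so labels stay stable when analyzers are added or reordered. A small sketch of the resulting label scheme (hypothetical helper; it only reproduces the label strings used above, and scoped names from tf.name_scope end with a slash):

    def _analyzer_step_labels(analyzer_name, collection_idx):
      # Hypothetical helper showing the Beam step labels derived from the
      # analyzer's scoped name and the output index.
      return ('ExtractInput[%s]' % analyzer_name,
              'Analyze[%s]' % analyzer_name,
              'WrapAsTensorValue[%s][%d]' % (analyzer_name, collection_idx))

    # _analyzer_step_labels('min/', 0)
    # -> ('ExtractInput[min/]', 'Analyze[min/]', 'WrapAsTensorValue[min/][0]')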
18 changes: 10 additions & 8 deletions tensorflow_transform/coders/example_proto_coder.py
@@ -122,19 +122,21 @@ class _FixedLenFeatureHandler(object):
def __init__(self, name, feature_spec):
self._name = name
self._np_dtype = feature_spec.dtype.as_numpy_dtype
self._default_value = feature_spec.default_value
self._value_fn = _make_feature_value_fn(feature_spec.dtype)
self._shape = feature_spec.shape
self._rank = len(feature_spec.shape)
if self._rank > 0 and self._default_value:
raise ValueError('FixedLenFeature %r got default value for rank > 0, '
'only scalar default values are supported'
% (self._name,))
if isinstance(self._default_value, list):
raise ValueError('FixedLenFeature %r got non-scalar default value, '
'only scalar default values are supported' %
(self._name,))
self._size = 1
for dim in feature_spec.shape:
self._size *= dim
self._default_value = feature_spec.default_value
if self._default_value:
if list(np.asarray(self._default_value).shape) != self._shape:
raise ValueError(
'FixedLenFeature %r got default value with incorrect shape' %
(self._name,))
self._default_value = np.asarray(self._default_value).reshape(-1).tolist()

@property
def name(self):
@@ -152,7 +154,7 @@ def parse_value(self, feature_map):
feature = feature_map[self._name]
values = self._value_fn(feature)
elif self._default_value is not None:
values = self._default_value
values = [self._default_value]
else:
values = []

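Read together with the updated tests below, the coder-side change restricts FixedLenFeature default values to scalars: both checks raise at construction time, and parse_value falls back to a single-element list for a missing feature. A rough standalone sketch of that behavior, assuming the feature spec mirrors tf.FixedLenFeature(shape, dtype, default_value):

    def _scalar_default_fallback(name, shape, default_value):
      # Hypothetical stand-in combining the constructor checks and the
      # parse-time fallback of _FixedLenFeatureHandler.
      rank = len(shape)
      if rank > 0 and default_value:
        raise ValueError('FixedLenFeature %r got default value for rank > 0, '
                         'only scalar default values are supported' % (name,))
      if isinstance(default_value, list):
        raise ValueError('FixedLenFeature %r got non-scalar default value, '
                         'only scalar default values are supported' % (name,))
      return [default_value] if default_value is not None else []

    # _scalar_default_fallback('scalar_feature_3', [], 1.0)         -> [1.0]
    # _scalar_default_fallback('scalar_feature_2', [2], [1.0, 2.0]) -> ValueError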
30 changes: 8 additions & 22 deletions tensorflow_transform/coders/example_proto_coder_test.py
@@ -168,16 +168,8 @@ def test_example_proto_coder(self):

def test_example_proto_coder_default_value(self):
input_schema = dataset_schema.from_feature_spec({
'scalar_feature_3':
tf.FixedLenFeature(shape=[], dtype=tf.float32, default_value=1.0),
'1d_vector_feature':
tf.FixedLenFeature(
shape=[1], dtype=tf.float32, default_value=[2.0]),
'2d_vector_feature':
tf.FixedLenFeature(
shape=[2, 2],
dtype=tf.float32,
default_value=[[1.0, 2.0], [3.0, 4.0]])
'scalar_feature_3': tf.FixedLenFeature(shape=[], dtype=tf.float32,
default_value=1.0),
})
coder = example_proto_coder.ExampleProtoCoder(input_schema)

@@ -193,31 +185,25 @@ def test_example_proto_coder_default_value(self):
# Assert the data is decoded into the expected format.
expected_decoded = {
'scalar_feature_3': 1.0,
'1d_vector_feature': [2.0],
'2d_vector_feature': [[1.0, 2.0], [3.0, 4.0]]
}
decoded = coder.decode(data)
np.testing.assert_equal(expected_decoded, decoded)

def test_example_proto_coder_bad_default_value(self):
input_schema = dataset_schema.from_feature_spec({
'1d_vector_feature':
tf.FixedLenFeature(
shape=[2], dtype=tf.float32, default_value=[1.0]),
'scalar_feature_2': tf.FixedLenFeature(shape=[2], dtype=tf.float32,
default_value=[1.0, 2.0]),
})
with self.assertRaisesRegexp(ValueError,
'got default value with incorrect shape'):
'only scalar default values are supported'):
example_proto_coder.ExampleProtoCoder(input_schema)

input_schema = dataset_schema.from_feature_spec({
'2d_vector_feature':
tf.FixedLenFeature(
shape=[2, 3],
dtype=tf.float32,
default_value=[[1.0, 1.0], [1.0]]),
'scalar_feature_2': tf.FixedLenFeature(shape=[], dtype=tf.float32,
default_value=[1.0]),
})
with self.assertRaisesRegexp(ValueError,
'got default value with incorrect shape'):
'only scalar default values are supported'):
example_proto_coder.ExampleProtoCoder(input_schema)

def test_example_proto_coder_picklable(self):
