
Commit

Merge pull request #13 from tensorflow/tft-0.1.9
Project import generated by Copybara.

PiperOrigin-RevId: 156233589
elmer-garduno authored May 17, 2017
2 parents 14ee57f + 76bfb40 commit 3206f45
Showing 16 changed files with 1,359 additions and 1,490 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@
from setuptools import setup

# Tensorflow transform version.
__version__ = '0.1.8'
__version__ = '0.1.9'


def _make_required_install_packages():
242 changes: 154 additions & 88 deletions tensorflow_transform/analyzers.py
@@ -11,164 +11,234 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF.Transform analyzers."""
"""Functions that involve a full pass over the dataset.
This module contains functions that are used in the preprocessing function, to
define a full pass operation such as computing the sum, min, max or unique
values of a tensor over the entire dataset. This is implemented by a reduction
operation in the Beam implementation.
From the user's point of view, an analyzer appears as a regular TensorFlow
function, i.e. it accepts and returns tensors. However, it is represented in
the graph as an `Analyzer`, which is not a TensorFlow op, but a placeholder for
the computation that takes place outside of TensorFlow.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_transform import api


def _get_output_shape(x, reduce_instance_dims):
"""Determines the shape of the output of a numerical analyzer.
ANALYZER_COLLECTION = 'tft_analyzers'


class Analyzer(object):
"""An operation-like class for full-pass analyses of data.
An Analyzer is like a tf.Operation except that it requires computation over
the full dataset. E.g. sum(my_tensor) will compute the sum of the value of
my_tensor over all instances in the dataset. The Analyzer class contains the
inputs to this computation, and placeholders which will later be converted to
constants during a call to AnalyzeDataset.
Args:
x: An input `Column' wrapping a `Tensor`.
reduce_instance_dims: If true, collapses the batch and instance dimensions
to arrive at a single scalar output. If False, only collapses the batch
dimension and outputs a vector of the same shape as the output.
inputs: The inputs to the analyzer.
output_shapes_and_dtypes: List of pairs of (shape, dtype) for each output.
spec: A description of the computation to be done.
Returns:
The shape to use for the output placeholder.
Raises:
ValueError: If the inputs are not all `Tensor`s.
"""

def __init__(self, inputs, output_shapes_and_dtypes, spec):
for tensor in inputs:
if not isinstance(tensor, tf.Tensor):
raise ValueError('Analyzers can only accept `Tensor`s as inputs')
self._inputs = inputs
self._outputs = [tf.placeholder(shape, dtype)
for shape, dtype in output_shapes_and_dtypes]
self._spec = spec
tf.add_to_collection(ANALYZER_COLLECTION, self)

@property
def inputs(self):
return self._inputs

@property
def outputs(self):
return self._outputs

@property
def spec(self):
return self._spec
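# Illustrative note (not part of this commit): because __init__ calls
# tf.add_to_collection(ANALYZER_COLLECTION, self), every Analyzer instance can
# later be retrieved with tf.get_collection(ANALYZER_COLLECTION), and its
# placeholder outputs replaced by constants once AnalyzeDataset has run the
# full-pass computation described by `spec`.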


class NumericCombineSpec(object):
"""Operation to combine numeric values."""

MIN = 'min'
MAX = 'max'
SUM = 'sum'

def __init__(self, dtype, combiner_type, reduce_instance_dims):
self._dtype = dtype
self._combiner_type = combiner_type
self._reduce_instance_dims = reduce_instance_dims

@property
def dtype(self):
return self._dtype

@property
def combiner_type(self):
return self._combiner_type

@property
def reduce_instance_dims(self):
return self._reduce_instance_dims


def _numeric_combine(x, combiner_type, reduce_instance_dims=True):
"""Apply an analyzer with NumericCombineSpec to given input."""
if not isinstance(x, tf.Tensor):
raise TypeError('Expected a Tensor, but got %r' % x)

if reduce_instance_dims:
# Numerical analyzers produce scalar output by default
return ()
# If reducing over all dimensions, result is scalar.
shape = ()
elif x.shape.dims is not None:
# If reducing over batch dimensions, with known shape, the result will be
# the same shape as the input, but without the batch.
shape = x.shape.as_list()[1:]
else:
in_shape = x.tensor.shape
if in_shape:
# The output will be the same shape as the input, but without the batch.
return in_shape.as_list()[1:]
else:
return None
# If reducing over batch dimensions, with unknown shape, the result will
# also have unknown shape.
shape = None
with tf.name_scope(combiner_type):
spec = NumericCombineSpec(x.dtype, combiner_type, reduce_instance_dims)
return Analyzer([x], [(x.dtype, shape)], spec).outputs[0]
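# Illustrative note (not part of this commit): for an input of known shape
# [batch_size, 5], reduce_instance_dims=True gives a scalar output placeholder,
# while reduce_instance_dims=False gives one of shape [5]; if the input shape
# is unknown, the output shape is left as None (unknown) as well.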


def min(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin
"""Computes the minimum of a `Column`.
"""Computes the minimum of the values of a `Tensor` over the whole dataset.
Args:
x: An input `Column' wrapping a `Tensor`.
x: A `Tensor`.
reduce_instance_dims: By default collapses the batch and instance dimensions
to arrive at a single scalar output. If False, only collapses the batch
dimension and outputs a vector of the same shape as the output.
dimension and outputs a `Tensor` of the same shape as the input.
Returns:
A `Statistic`.
A `Tensor`.
"""
if not isinstance(x.tensor, tf.Tensor):
raise TypeError('Expected a Tensor, but got %r' % x.tensor)

arg_dict = {'reduce_instance_dims': reduce_instance_dims}

# pylint: disable=protected-access
return api._AnalyzerOutput(
tf.placeholder(x.tensor.dtype, _get_output_shape(
x, reduce_instance_dims)), api.CanonicalAnalyzers.MIN, [x], arg_dict)
return _numeric_combine(x, NumericCombineSpec.MIN, reduce_instance_dims)


def max(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin
"""Computes the maximum of a `Column`.
"""Computes the maximum of the values of a `Tensor` over the whole dataset.
Args:
x: An input `Column' wrapping a `Tensor`.
x: A `Tensor`.
reduce_instance_dims: By default collapses the batch and instance dimensions
to arrive at a single scalar output. If False, only collapses the batch
dimension and outputs a vector of the same shape as the output.
Returns:
A `Statistic`.
A `Tensor`.
"""
if not isinstance(x.tensor, tf.Tensor):
raise TypeError('Expected a Tensor, but got %r' % x.tensor)

arg_dict = {'reduce_instance_dims': reduce_instance_dims}
# pylint: disable=protected-access
return api._AnalyzerOutput(
tf.placeholder(x.tensor.dtype, _get_output_shape(
x, reduce_instance_dims)), api.CanonicalAnalyzers.MAX, [x], arg_dict)
return _numeric_combine(x, NumericCombineSpec.MAX, reduce_instance_dims)


def sum(x, reduce_instance_dims=True): # pylint: disable=redefined-builtin
"""Computes the sum of a `Column`.
"""Computes the sum of the values of a `Tensor` over the whole dataset.
Args:
x: An input `Column' wrapping a `Tensor`.
x: A `Tensor`.
reduce_instance_dims: By default collapses the batch and instance dimensions
to arrive at a single scalar output. If False, only collapses the batch
dimension and outputs a vector of the same shape as the output.
Returns:
A `Statistic`.
A `Tensor`.
"""
if not isinstance(x.tensor, tf.Tensor):
raise TypeError('Expected a Tensor, but got %r' % x.tensor)

arg_dict = {'reduce_instance_dims': reduce_instance_dims}
# pylint: disable=protected-access
return api._AnalyzerOutput(
tf.placeholder(x.tensor.dtype, _get_output_shape(
x, reduce_instance_dims)), api.CanonicalAnalyzers.SUM, [x], arg_dict)
return _numeric_combine(x, NumericCombineSpec.SUM, reduce_instance_dims)


def size(x, reduce_instance_dims=True):
"""Computes the total size of instances in a `Column`.
"""Computes the total size of instances in a `Tensor` over the whole dataset.
Args:
x: An input `Column' wrapping a `Tensor`.
x: A `Tensor`.
reduce_instance_dims: By default collapses the batch and instance dimensions
to arrive at a single scalar output. If False, only collapses the batch
dimension and outputs a vector of the same shape as the output.
Returns:
A `Statistic`.
A `Tensor`.
"""
if not isinstance(x.tensor, tf.Tensor):
raise TypeError('Expected a Tensor, but got %r' % x.tensor)

# Note: Calling `sum` defined in this module, not the builtin.
return sum(api.map(tf.ones_like, x), reduce_instance_dims)
with tf.name_scope('size'):
# Note: Calling `sum` defined in this module, not the builtin.
return sum(tf.ones_like(x), reduce_instance_dims)


def mean(x, reduce_instance_dims=True):
"""Computes the mean of the values in a `Column`.
"""Computes the mean of the values of a `Tensor` over the whole dataset.
Args:
x: An input `Column' wrapping a `Tensor`.
x: A `Tensor`.
reduce_instance_dims: By default collapses the batch and instance dimensions
to arrive at a single scalar output. If False, only collapses the batch
dimension and outputs a vector of the same shape as the output.
Returns:
A `Column` with an underlying `Tensor` of shape [1], containing the mean.
A `Tensor` containing the mean.
"""
if not isinstance(x.tensor, tf.Tensor):
raise TypeError('Expected a Tensor, but got %r' % x.tensor)
with tf.name_scope('mean'):
# Note: Calling `sum` defined in this module, not the builtin.
return tf.divide(
sum(x, reduce_instance_dims), size(x, reduce_instance_dims))
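# Illustrative note (not part of this commit): for a batch of two instances
# [1, 2] and [3, 4], sum is 10 and size is 4, so mean is 2.5 with
# reduce_instance_dims=True; with reduce_instance_dims=False the per-position
# result is [2.0, 3.0] (sums [4, 6] divided by sizes [2, 2]).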


class UniquesSpec(object):
"""Operation to compute unique values."""

def __init__(self, dtype, top_k, frequency_threshold):
self._dtype = dtype
self._top_k = top_k
self._frequency_threshold = frequency_threshold

# Note: Calling `sum` defined in this module, not the builtin.
return api.map_statistics(tf.divide,
sum(x, reduce_instance_dims),
size(x, reduce_instance_dims))
@property
def dtype(self):
return self._dtype

@property
def top_k(self):
return self._top_k

@property
def frequency_threshold(self):
return self._frequency_threshold


def uniques(x, top_k=None, frequency_threshold=None):
"""Returns the unique values of the input tensor.
"""Computes the unique values of a `Tensor` over the whole dataset.
Computes the unique values taken by the input column `x`, which can be backed
by a `Tensor` or `SparseTensor` of any size. The unique values will be
aggregated over all dimensions of `x` and all instances.
Computes the unique values taken by `x`, which can be a `Tensor` or
`SparseTensor` of any size. The unique values will be aggregated over all
dimensions of `x` and all instances.
The unique values are sorted by decreasing frequency and then decreasing
value.
Args:
x: An input `Column` wrapping a `Tensor` or `SparseTensor`.
x: An input `Tensor` or `SparseTensor`.
top_k: Limit the generated vocabulary to the first `top_k` elements. If set
to None, the full vocabulary is generated.
frequency_threshold: Limit the generated vocabulary only to elements whose
frequency is >= to the supplied threshold. If set to None, the full
vocabulary is generated.
vocabulary is generated
Returns:
The unique values of `x`.
@@ -184,15 +254,11 @@ def uniques(x, top_k=None, frequency_threshold=None):
if frequency_threshold is not None:
frequency_threshold = int(frequency_threshold)
if frequency_threshold < 0:
raise ValueError('frequency_threshold must be non-negative, but got: %r' %
frequency_threshold)

if isinstance(x.tensor, tf.SparseTensor):
values = x.tensor.values
else:
values = x.tensor
arg_dict = {'top_k': top_k, 'frequency_threshold': frequency_threshold}
# Create output placeholder whose shape is a 1-d tensor of unknown size.
# pylint: disable=protected-access
return api._AnalyzerOutput(tf.placeholder(values.dtype, (None,)),
api.CanonicalAnalyzers.UNIQUES, [x], arg_dict)
raise ValueError(
'frequency_threshold must be non-negative, but got: %r' %
frequency_threshold)
if isinstance(x, tf.SparseTensor):
x = x.values
with tf.name_scope('uniques'):
spec = UniquesSpec(x.dtype, top_k, frequency_threshold)
return Analyzer([x], [(x.dtype, [None])], spec).outputs[0]
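
As a rough usage sketch of the Tensor-based analyzer API introduced by this commit (not part of the diff): a preprocessing function can mix these analyzers with ordinary TensorFlow ops. The function name, the input keys, and how the result is later handed to AnalyzeDataset are assumptions made here for illustration; only `analyzers.min`, `analyzers.max`, and `analyzers.uniques` come from the code above.

import tensorflow as tf
from tensorflow_transform import analyzers


def preprocessing_fn(inputs):
  """Hypothetical preprocessing function using the new analyzers."""
  x = inputs['x']  # a numeric `Tensor` with shape [batch_size]
  s = inputs['s']  # a string `Tensor` with shape [batch_size]

  # Full-pass statistics: in the graph these are Analyzer placeholders that
  # are replaced by constants once the dataset has been analyzed.
  x_min = analyzers.min(x)
  x_max = analyzers.max(x)
  x_scaled = tf.divide(x - x_min, x_max - x_min)

  # Vocabulary of the most frequent values of `s` (shown only to illustrate
  # the call; a real pipeline would typically feed it to a lookup table).
  vocab = analyzers.uniques(s, top_k=1000)

  return {'x_scaled': x_scaled}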
