From 132160b8ad9a61c63fe52e36f2d8e178ada323a2 Mon Sep 17 00:00:00 2001
From: Alexander Novikov
Date: Tue, 28 Mar 2017 19:37:56 +0300
Subject: [PATCH] support validate_shape and validate_tt_ranks of get_variable

---
 t3f/shapes.py    | 27 ++++++++++++++++++++++++++-
 t3f/variables.py | 26 +++++++++++++++++++++-----
 2 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/t3f/shapes.py b/t3f/shapes.py
index 3e0ff172..7a6e2942 100644
--- a/t3f/shapes.py
+++ b/t3f/shapes.py
@@ -269,4 +269,29 @@ def expand_batch_dim(tt):
   for core_idx in range(tt.ndims()):
     tt_cores.append(tf.expand_dims(tt.tt_cores[core_idx], 0))
   return TensorTrainBatch(tt_cores, tt.get_raw_shape(), tt.get_tt_ranks(),
-                          batch_size=1)
\ No newline at end of file
+                          batch_size=1)
+
+
+def is_compatible_with(tt_a, tt_b):
+  """Check compatibility of dtypes, shapes, and TT-ranks of two TT-objects."""
+
+  a_shape = tt_a.get_raw_shape()
+  b_shape = tt_b.get_raw_shape()
+  if len(a_shape) != len(b_shape):
+    return False
+  for i in range(len(a_shape)):
+    if not a_shape[i].is_compatible_with(b_shape[i]):
+      return False
+
+  if not tt_a.get_tt_ranks().is_compatible_with(tt_b.get_tt_ranks()):
+    return False
+
+  if tt_a.dtype != tt_b.dtype:
+    return False
+
+  return True
+
+
+def is_fully_defined(raw_shape):
+  """Check that a raw shape (a list of TensorShapes) is fully defined."""
+  return all([s.is_fully_defined() for s in raw_shape])
diff --git a/t3f/variables.py b/t3f/variables.py
index 33b5512c..39b186dc 100644
--- a/t3f/variables.py
+++ b/t3f/variables.py
@@ -2,6 +2,7 @@
 
 from tensor_train import TensorTrain
 from tensor_train_batch import TensorTrainBatch
+import shapes
 
 
 def get_variable(name,
@@ -11,7 +12,8 @@ def get_variable(name,
                  trainable=True,
                  collections=None,
                  caching_device=None,
-                 validate_shape=True):
+                 validate_shape=True,
+                 validate_tt_ranks=True):
   """Returns TensorTrain object with tf.Variables as the TT-cores.
 
   Args:
@@ -37,6 +39,9 @@ def get_variable(name,
     validate_shape: If False, allows the variable to be initialized with a
       value of unknown shape. If True, the default, the shape of initial_value
      must be known.
+    validate_tt_ranks: If False, allows the variable to be initialized with a
+      value of unknown TT-ranks. If True, the default, the TT-ranks of
+      initial_value must be known.
 
   Returns:
     The created or existing `TensorTrain` object with tf.Variables TT-cores.
@@ -46,14 +51,25 @@ def get_variable(name,
    violating reuse during variable creation, or when initializer dtype and
    dtype don't match. Reuse is set inside variable_scope.
   """
-  # TODO: support validate shape: check that the tensor dimensions are correct,
-  # but ignore the ranks.
-  # TODO: add validate ranks flag.
-
   reuse = tf.get_variable_scope().reuse
   if not reuse and initializer is None:
     raise ValueError('Scope reuse is False and initializer is not provided.')
 
+  if initializer is not None:
+    if validate_shape:
+      raw_shape = initializer.get_raw_shape()
+      if not shapes.is_fully_defined(raw_shape):
+        raise ValueError('The shape of the initializer (%s) is not fully '
+                         'defined. If you want to create a variable anyway, '
+                         'use validate_shape=False.' % raw_shape)
+
+    if validate_tt_ranks:
+      tt_ranks = initializer.get_tt_ranks()
+      if not tt_ranks.is_fully_defined():
+        raise ValueError('The TT-ranks of the initializer (%s) are not fully '
+                         'defined. If you want to create a variable anyway, '
+                         'use validate_tt_ranks=False.' % tt_ranks)
+
   variable_cores = []
   if reuse:
     # Find an existing variable in the collection.
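
Usage sketch (not part of the patch): a minimal example of how the new flags and the shapes.py helpers might be exercised. It assumes the same flat in-package imports the patch itself uses (tensor_train, shapes, variables), that a TensorTrain initializer can be built directly from a list of TT-cores as elsewhere in the codebase, and an arbitrary variable name 'tt_var'.

import tensorflow as tf

from tensor_train import TensorTrain
import shapes
import variables

# A rank-2 TT initializer for a 4 x 5 x 6 tensor, built from explicit cores
# of shape [r_{i-1}, n_i, r_i].
cores = [tf.random_normal([1, 4, 2]),
         tf.random_normal([2, 5, 2]),
         tf.random_normal([2, 6, 1])]
init = TensorTrain(cores)

# The new helper reports whether every dimension of the raw shape is known.
assert shapes.is_fully_defined(init.get_raw_shape())

# With the defaults validate_shape=True and validate_tt_ranks=True this
# raises ValueError if the shape or TT-ranks of `init` are not fully defined.
tt_var = variables.get_variable('tt_var', initializer=init)

# is_compatible_with compares dtypes, raw shapes, and TT-ranks of two
# TT-objects; the created variable is compatible with its initializer.
assert shapes.is_compatible_with(init, tt_var)

# To create a variable from an initializer with, e.g., unknown TT-ranks,
# opt out of the corresponding check:
# tt_var = variables.get_variable('tt_var', initializer=init,
#                                 validate_tt_ranks=False)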