Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support validate_shape and validate_tt_ranks of get_variable #40

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion t3f/shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,29 @@ def expand_batch_dim(tt):
for core_idx in range(tt.ndims()):
tt_cores.append(tf.expand_dims(tt.tt_cores[core_idx], 0))
return TensorTrainBatch(tt_cores, tt.get_raw_shape(), tt.get_tt_ranks(),
batch_size=1)
batch_size=1)


def is_compatible_with(tt_a, tt_b):
  """Check compatibility of dtype, shape, and TT-ranks of two TT-objects.

  Args:
    tt_a: `TensorTrain` or `TensorTrainBatch` object.
    tt_b: `TensorTrain` or `TensorTrainBatch` object.

  Returns:
    bool: True iff the raw shapes are element-wise compatible, the TT-ranks
    are compatible, and the dtypes are equal.
  """
  a_shape = tt_a.get_raw_shape()
  b_shape = tt_b.get_raw_shape()
  # Bug fix: the original used `==` here, which returned False for every
  # pair with MATCHING number of shape modes (and fell through to comparing
  # mismatched-length shapes index-by-index). Differing lengths mean
  # incompatible.
  if len(a_shape) != len(b_shape):
    return False
  for i in range(len(a_shape)):
    if not a_shape[i].is_compatible_with(b_shape[i]):
      return False

  if not tt_a.get_tt_ranks().is_compatible_with(tt_b.get_tt_ranks()):
    return False

  if tt_a.dtype != tt_b.dtype:
    return False

  return True


def is_fully_defined(raw_shape):
  """Check whether every TensorShape in a raw shape is fully defined.

  Args:
    raw_shape: a list of TensorShapes.

  Returns:
    bool: True iff each shape in the list is fully defined (vacuously True
    for an empty list).
  """
  for shape in raw_shape:
    if not shape.is_fully_defined():
      return False
  return True
26 changes: 21 additions & 5 deletions t3f/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from tensor_train import TensorTrain
from tensor_train_batch import TensorTrainBatch
import shapes


def get_variable(name,
Expand All @@ -11,7 +12,8 @@ def get_variable(name,
trainable=True,
collections=None,
caching_device=None,
validate_shape=True):
validate_shape=True,
validate_tt_ranks=True):
"""Returns TensorTrain object with tf.Variables as the TT-cores.

Args:
Expand All @@ -37,6 +39,9 @@ def get_variable(name,
validate_shape: If False, allows the variable to be initialized with a value
of unknown shape. If True, the default, the shape of initial_value must be
known.
validate_tt_ranks: If False, allows the variable to be initialized with a
value of unknown TT-ranks. If True, the default, the TT-ranks of
initial_value must be known.

Returns:
The created or existing `TensorTrain` object with tf.Variables TT-cores.
Expand All @@ -46,14 +51,25 @@ def get_variable(name,
violating reuse during variable creation, or when initializer dtype and
dtype don't match. Reuse is set inside variable_scope.
"""
# TODO: support validate shape: check that the tensor dimensions are correct,
# but ignore the ranks.
# TODO: add validate ranks flag.

reuse = tf.get_variable_scope().reuse
if not reuse and initializer is None:
raise ValueError('Scope reuse is False and initializer is not provided.')

if initializer is not None:
if validate_shape:
raw_shape = initializer.get_raw_shape()
if not shapes.is_fully_defined(raw_shape):
raise ValueError('The shape of the initializer (%s) is not fully '
'defined. If you want to create a variable anyway, '
'use validate_shape=False.' % raw_shape)

if validate_tt_ranks:
tt_ranks = initializer.get_tt_ranks()
if not tt_ranks.is_fully_defined():
raise ValueError('The TT-ranks of the initializer (%s) are not fully '
'defined. If you want to create a variable anyway, '
'use validate_tt_ranks=False.' % tt_ranks)

variable_cores = []
if reuse:
# Find an existing variable in the collection.
Expand Down