From cc467763c84589fec2e08f21475815f0b4c25bb6 Mon Sep 17 00:00:00 2001
From: Surya
Date: Sat, 8 Feb 2025 03:46:08 +0530
Subject: [PATCH] fix time_distributed layer with mask and partial_batch_size (#20765)

* fix time_distributed layer with mask and partial_batch_size

* Fix test fails for non TF backends

* Fix formatting issue

* test case and inline import of TF

* Disable testcase for Numpy backend

* Fix lint error
---
Review note: the graph-mode guard below calls `tf.executing_eagerly()`.
The bare attribute `tf.executing_eagerly` is a function object and is
always truthy, so `not tf.executing_eagerly` is always False and the
timestep-only mask check would never run on the TensorFlow backend.
Blank context lines in this copy were restored from the hunk line counts.

 keras/src/layers/rnn/time_distributed.py      | 25 ++++++++++++++++---
 keras/src/layers/rnn/time_distributed_test.py | 22 ++++++++++++++++
 2 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/keras/src/layers/rnn/time_distributed.py b/keras/src/layers/rnn/time_distributed.py
index e61274d96c08..d6c7f57fc7d6 100644
--- a/keras/src/layers/rnn/time_distributed.py
+++ b/keras/src/layers/rnn/time_distributed.py
@@ -77,10 +77,29 @@ def call(self, inputs, training=None, mask=None):
         batch_size = input_shape[0]
         timesteps = input_shape[1]
 
-        if mask_shape is not None and mask_shape[:2] != (batch_size, timesteps):
+        # For TF backend with graph mode and `partial_batch_size`, skip
+        # evaluation of `batch_size` as it can be a `strided_slice` and
+        # not a constant.
+        if backend.backend() == "tensorflow":
+            from keras.src.utils.module_utils import tensorflow as tf
+
+            if (
+                not tf.executing_eagerly()
+                and mask_shape is not None
+                and mask_shape[1:2] != (timesteps,)
+            ):
+                raise ValueError(
+                    "`TimeDistributed` Layer should be passed a `mask` of "
+                    f"shape ({batch_size}, {timesteps}, ...), "
+                    f"received: mask.shape={mask_shape}"
+                )
+        elif mask_shape is not None and mask_shape[:2] != (
+            batch_size,
+            timesteps,
+        ):
             raise ValueError(
-                "`TimeDistributed` Layer should be passed a `mask` of shape "
-                f"({batch_size}, {timesteps}, ...), "
+                "`TimeDistributed` Layer should be passed a `mask` of "
+                f"shape ({batch_size}, {timesteps}, ...), "
                 f"received: mask.shape={mask_shape}"
             )
 
diff --git a/keras/src/layers/rnn/time_distributed_test.py b/keras/src/layers/rnn/time_distributed_test.py
index f2ad37e9d110..87cc31fe6197 100644
--- a/keras/src/layers/rnn/time_distributed_test.py
+++ b/keras/src/layers/rnn/time_distributed_test.py
@@ -6,6 +6,7 @@
 from keras.src import layers
 from keras.src import ops
 from keras.src import testing
+from keras.src.models import Sequential
 
 
 class TimeDistributedTest(testing.TestCase):
@@ -77,3 +78,24 @@ def call(self, inputs, training=False, mask=None):
             np.array([[[0], [0.22]], [[0.38], [0]], [[0.7], [0.86]]]),
             output,
         )
+
+    @pytest.mark.requires_trainable_backend
+    def test_with_mask_zero(self):
+        model = Sequential(
+            [
+                layers.Input(shape=(20,)),
+                layers.Embedding(input_dim=10, output_dim=5, mask_zero=True),
+                layers.TimeDistributed(
+                    layers.Dense(units=5, activation="softmax")
+                ),
+            ]
+        )
+        model.compile(
+            optimizer="adam",
+            loss="sparse_categorical_crossentropy",
+            metrics=["accuracy"],
+        )
+        X_train = np.random.uniform(1, 10, size=(22, 20))
+        Y_train = np.random.randint(1, 2, size=(22, 20))
+
+        model.fit(X_train, Y_train, epochs=1, batch_size=16)