Inplace modification error during back propagation when training a stateful LSTM using PyTorch as backend #20875

Open
Arian96669 opened this issue Feb 7, 2025 · 0 comments

Arian96669 commented Feb 7, 2025

When creating a stateful LSTM model in Keras 3 with PyTorch as the backend, PyTorch performs in-place operations during backpropagation in the LSTM layer.

Error description: RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 50]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
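For reference, the anomaly detection suggested in the hint can be switched on before training starts; this is standard PyTorch API and works unchanged under the torch backend:

import torch
torch.autograd.set_detect_anomaly(True)  # backward() will then report the forward op that produced the failing gradient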

Further investigation with detect_anomaly:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 50]] is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later

When the same test code is run with either JAX or TensorFlow as the backend, training completes without any errors. The problem is confined to Keras with the PyTorch backend.

Test script:

import numpy as np
import os
os.environ["KERAS_BACKEND"] = "torch"
import keras

# Sample dataset generator for demonstration
def generate_time_series_data(batch_size, time_steps, num_features):
    while True:
        x = np.random.rand(batch_size, time_steps, num_features)
        y = np.sum(x, axis=2)  # Just an example: target is the sum of the features at each time step
        yield x, y

# Parameters
batch_size = 32  # Number of sequences per batch
time_steps = 10  # Length of each sequence
num_features = 3  # Number of features per time step
epochs = 10  # Number of epochs

# Build the LSTM model
model = keras.Sequential()
model.add(keras.Input(shape=(time_steps, num_features), batch_size=batch_size))
lstm_layer = keras.layers.LSTM(50,
               stateful=True,
               return_sequences=False) # return_sequences can be True if another LSTM is added
model.add(lstm_layer)
model.add(keras.layers.Dense(1, activation='linear'))  # For scalar output

# Compile the model with optimizer and loss function
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001), loss='mse', metrics=['mae'])

# Print model summary
model.summary()

# Generate dummy training data
train_generator = generate_time_series_data(batch_size, time_steps, num_features)
steps_per_epoch = 100  # Number of batches per epoch

# Train the model with stateful data
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    model.fit(train_generator, steps_per_epoch=steps_per_epoch, epochs=1, verbose=1, shuffle=False)
    # Reset states after each epoch
    lstm_layer.reset_states()
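
For context on the underlying mechanism (this is not the Keras internals, just an illustration), below is a minimal plain-PyTorch sketch of a stateful training loop. The usual precaution in PyTorch is to detach the carried hidden/cell state from the previous batch's graph and to re-bind fresh tensors rather than write into the old ones in place; the error above points at exactly that kind of in-place modification of a [32, 50]-shaped state tensor. Layer sizes mirror the script above; the shapes and names are illustrative assumptions, not the backend's actual code path.

import torch

lstm = torch.nn.LSTM(input_size=3, hidden_size=50, batch_first=True)
head = torch.nn.Linear(50, 1)
opt = torch.optim.Adam(list(lstm.parameters()) + list(head.parameters()), lr=1e-3)

state = None  # carried (h, c) across batches, analogous to stateful=True

for step in range(100):
    x = torch.rand(32, 10, 3)                     # (batch, time, features)
    y = x.sum(dim=2)[:, -1:]                      # dummy target, shape (32, 1)

    out, (h, c) = lstm(x, state)                  # out: (32, 10, 50)
    pred = head(out[:, -1, :])                    # last time step -> (32, 1)
    loss = torch.nn.functional.mse_loss(pred, y)

    opt.zero_grad()
    loss.backward()
    opt.step()

    # Carry the state to the next batch, but cut it from this batch's graph.
    # Re-binding new detached tensors (instead of writing into the old ones
    # in place) leaves autograd's saved tensors untouched.
    state = (h.detach(), c.detach())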

PraveenH has created a temporary fix in a Colab notebook. Please have a look at: (https://colab.research.google.com/drive/1_8HKONyYbWMRLuEJfIS3PSReMEA9xfu6#scrollTo=bPjFIAarl0wL)

We're not yet 100% certain it is the best solution (so we will be guided by what team-keras comes up with), but it is excellent work in my view. Thank you, Praveen, for the deep dive into this!

Can I ask for this to be logged as a bug?
