You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I am using Torch as the backend with Keras 3. I followed the guide at https://keras.io/guides/custom_train_step_in_torch/, but encountered the error "No module named 'tensorflow'" when trying to save the model after training.
Versions:
keras==3.8.0
torch==2.4.1
torchvision==0.19.1
Error can be reproduced by running the script provided below.
Is tensorflow still required just to save the model?
import os
os.environ["KERAS_BACKEND"] = "torch"
import torch
import keras
from keras import layers
import numpy as np
class CustomModel(keras.Model):
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
# on what you pass to `fit()`.
if len(data) == 3:
x, y, sample_weight = data
else:
sample_weight = None
x, y = data
# Call torch.nn.Module.zero_grad() to clear the leftover gradients
# for the weights from the previous train step.
self.zero_grad()
# Compute loss
y_pred = self(x, training=True) # Forward pass
loss = self.compute_loss(
y=y,
y_pred=y_pred,
sample_weight=sample_weight,
)
# Call torch.Tensor.backward() on the loss to compute gradients
# for the weights.
loss.backward()
trainable_weights = [v for v in self.trainable_weights]
gradients = [v.value.grad for v in trainable_weights]
# Update weights
with torch.no_grad():
self.optimizer.apply(gradients, trainable_weights)
# Update metrics (includes the metric that tracks the loss)
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred, sample_weight=sample_weight)
# Return a dict mapping metric names to current value
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
def test_step(self, data):
# Unpack the data
x, y = data
# Compute predictions
y_pred = self(x, training=False)
# Updates the metrics tracking the loss
loss = self.compute_loss(y=y, y_pred=y_pred)
# Update the metrics.
for metric in self.metrics:
if metric.name == "loss":
metric.update_state(loss)
else:
metric.update_state(y, y_pred)
# Return a dict mapping metric names to current value.
# Note that it will include the loss (tracked in self.metrics).
return {m.name: m.result() for m in self.metrics}
# Construct an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model.compile(loss="categorical_crossentropy",
optimizer="adam",
metrics=["accuracy"])
# Just use `fit` as usual -- you can use callbacks, etc.
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
model.export("./")
Error:
Traceback (most recent call last):
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/utils/module_utils.py", line 27, in initialize
self.module = importlib.import_module(self.name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/usr/lib/python3.11/importlib/__init__.py", line 126, in import_module
return _bootstrap._gcd_import(name[level:], package, level)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<frozen importlib._bootstrap>", line 1204, in _gcd_import
File "<frozen importlib._bootstrap>", line 1176, in _find_and_load
File "<frozen importlib._bootstrap>", line 1140, in _find_and_load_unlocked
ModuleNotFoundError: No module named 'tensorflow'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/perfuser/shailesh/openfl_pytoch_latest_3_feb/torch_demo.py", line 162, in <module>
model.export("./best.pbuf")
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/models/model.py", line 539, in export
export_saved_model(
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/export/saved_model.py", line 624, in export_saved_model
export_archive = ExportArchive()
^^^^^^^^^^^^^^^
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/export/saved_model.py", line 118, in __init__
self.tensorflow_version = tf.__version__
^^^^^^^^^^^^^^
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/utils/module_utils.py", line 35, in __getattr__
self.initialize()
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/utils/module_utils.py", line 29, in initialize
raise ImportError(self.import_error_msg)
ImportError: This requires the tensorflow module. You can install it via `pip install tensorflow`
if I use model.save instead of export, it result in error listed below.
model.save("./best.pbuf")
Traceback (most recent call last):
File "/home/perfuser/shailesh/openfl_pytoch_latest_3_feb/torch_demo.py", line 162, in <module>
model.save("./best.pbuf")
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/perfuser/shailesh/torch_3feb/lib/python3.11/site-packages/keras/src/saving/saving_api.py", line 114, in save_model
raise ValueError(
ValueError: Invalid filepath extension for saving. Please add either a `.keras` extension for the native Keras format (recommended) or a `.h5` extension. Use `model.export(filepath)` if you want to export a SavedModel for use with TFLite/TFServing/etc. Received: filepath=./best.pbuf.
model.export expects an argument format, which can be either tf_saved_model or onnx. TensorFlow is needed for the default option, which is tf_saved_model.
model.save only supports .keras or .h5 file formats.
Hi,
I am using Torch as the backend with Keras 3. I followed the guide at https://keras.io/guides/custom_train_step_in_torch/, but encountered the error "No module named 'tensorflow'" when trying to save the model after training.
Versions:
Error can be reproduced by running the script provided below.
Is tensorflow still required just to save the model?
Error:
if I use model.save instead of export, it result in error listed below.
pip freeze
The text was updated successfully, but these errors were encountered: