diff --git a/keras/src/saving/serialization_lib.py b/keras/src/saving/serialization_lib.py index 535478b62bb..48c70808b40 100644 --- a/keras/src/saving/serialization_lib.py +++ b/keras/src/saving/serialization_lib.py @@ -783,7 +783,8 @@ def _retrieve_class_or_fn( # Otherwise, attempt to retrieve the class object given the `module` # and `class_name`. Import the module, find the class. - if module == "keras.src" or module.startswith("keras.src."): + package = module.split(".", maxsplit=1)[0] + if package in {"keras", "keras_hub", "keras_cv", "keras_nlp"}: try: mod = importlib.import_module(module) obj = vars(mod).get(name, None)