diff --git a/keras_tuner/engine/base_tuner.py b/keras_tuner/engine/base_tuner.py index b17bf58f3..4d0c43959 100644 --- a/keras_tuner/engine/base_tuner.py +++ b/keras_tuner/engine/base_tuner.py @@ -279,6 +279,8 @@ def _try_run_and_update_trial(self, trial, *fit_args, **fit_kwargs): if isinstance(e, errors.FailedTrialError): trial.status = trial_module.TrialStatus.FAILED + elif isinstance(e, errors.SkipModelError): + trial.status = trial_module.TrialStatus.SKIPPED else: trial.status = trial_module.TrialStatus.INVALID diff --git a/keras_tuner/engine/hyperparameters/hyperparameters.py b/keras_tuner/engine/hyperparameters/hyperparameters.py index 9ed4ed832..a1b9a8baa 100644 --- a/keras_tuner/engine/hyperparameters/hyperparameters.py +++ b/keras_tuner/engine/hyperparameters/hyperparameters.py @@ -20,6 +20,7 @@ import six from keras_tuner import protos +from keras_tuner import errors from keras_tuner.api_export import keras_tuner_export from keras_tuner.engine import conditions as conditions_mod from keras_tuner.engine.hyperparameters import hp_types @@ -250,6 +251,12 @@ def __contains__(self, name): except (KeyError, ValueError): return False + def skip_model(self, message): + if len(self._hps) == 0: + # Registration stage + return + raise errors.SkipModelError(message) + def Choice( self, name, diff --git a/keras_tuner/engine/trial.py b/keras_tuner/engine/trial.py index 5dde623af..283a08378 100644 --- a/keras_tuner/engine/trial.py +++ b/keras_tuner/engine/trial.py @@ -43,6 +43,7 @@ class TrialStatus: COMPLETED = "COMPLETED" # The Trial is failed. No more retries needed. FAILED = "FAILED" + SKIPPED = "SKIPPED" @staticmethod def to_proto(status): @@ -61,6 +62,8 @@ def to_proto(status): return ts.COMPLETED elif status == TrialStatus.FAILED: return ts.FAILED + elif status == TrialStatus.SKIPPED: + return ts.SKIPPED else: raise ValueError(f"Unknown status {status}") @@ -81,6 +84,8 @@ def from_proto(proto): return TrialStatus.COMPLETED elif proto == ts.FAILED: return TrialStatus.FAILED + elif proto == ts.SKIPPED: + return TrialStatus.SKIPPED else: raise ValueError(f"Unknown status {proto}") diff --git a/keras_tuner/errors.py b/keras_tuner/errors.py index ee5c3a41c..153b9affc 100644 --- a/keras_tuner/errors.py +++ b/keras_tuner/errors.py @@ -40,6 +40,11 @@ def build(self, hp): pass +@keras_tuner_export(["keras_tuner.errors.SkipModelError"]) +class SkipModelError(Exception): + pass + + @keras_tuner_export(["keras_tuner.errors.FatalError"]) class FatalError(Exception): """A fatal error during search to terminate the program.