From eb0c7d7aaee197f2282f2fcda6550cf1841a69ce Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Wed, 22 Jan 2025 14:50:53 -0800 Subject: [PATCH] Add concept of recoverable errors respected by Scheduler (#3262) Summary: The motivation is that some metrics are flaky and we don't want to fail the trial just because we encountered one exception fetching. Especially trials with `period_of_new_data_after_trial_completion()` > 0. This alternative to implementing this on the metric is that the set of recoverable errors should be a scheduler option, and it's more a matter of scheduler use case than metric. Reviewed By: Cesar-Cardoso Differential Revision: D68273328 --- ax/core/metric.py | 16 +++ ax/service/scheduler.py | 2 + ax/service/tests/scheduler_test_utils.py | 147 +++++++++++++++++++---- 3 files changed, 141 insertions(+), 24 deletions(-) diff --git a/ax/core/metric.py b/ax/core/metric.py index 6a325bcf9e0..4cefc3529c0 100644 --- a/ax/core/metric.py +++ b/ax/core/metric.py @@ -82,6 +82,11 @@ class Metric(SortableBase, SerializationMixin): """ data_constructor: type[Data] = Data + # The set of exception types stored in a ``MetchFetchE.exception`` that are + # recoverable ``Scheduler._fetch_and_process_trials_data_results()``. + # Exception may be a subclass of any of these types. If you want your metric + # to never fail the trial, set this to ``{Exception}`` in your metric subclass. + recoverable_exceptions: set[type[Exception]] = set() def __init__( self, @@ -138,6 +143,17 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta: """ return timedelta(0) + @classmethod + def is_reconverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool: + """Checks whether the given MetricFetchE is recoverable for this metric class + in ``Scheduler._fetch_and_process_trials_data_results``. + """ + if metric_fetch_e.exception is None: + return False + return any( + isinstance(metric_fetch_e.exception, e) for e in cls.recoverable_exceptions + ) + # NOTE: This is rarely overridden –– oonly if you want to fetch data in groups # consisting of multiple different metric classes, for data to be fetched together. # This makes sense only if `fetch_trial data_multi` or `fetch_experiment_data_multi` diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 19fdf47bc1a..ce3de72e930 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -2046,6 +2046,8 @@ def _fetch_and_process_trials_data_results( if ( optimization_config is not None and metric_name in optimization_config.metrics.keys() + ) and not self.experiment.metrics[metric_name].is_reconverable_fetch_e( + metric_fetch_e=metric_fetch_e ): status = self._mark_err_trial_status( trial=self.experiment.trials[trial_index], diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 92e5a0a6cec..388e59e008e 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -37,7 +37,12 @@ get_pending_observation_features_based_on_trial_status, ) from ax.early_stopping.strategies import BaseEarlyStoppingStrategy -from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError +from ax.exceptions.core import ( + AxError, + OptimizationComplete, + UnsupportedError, + UserInputError, +) from ax.exceptions.generation_strategy import AxGenerationException from ax.metrics.branin import BraninMetric from ax.metrics.branin_map import BraninTimestampMapMetric @@ -1981,43 +1986,137 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None: experiment=self.branin_experiment, generation_strategy=self.two_sobol_steps_GS, ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, + options=SchedulerOptions( + **self.scheduler_options_kwargs, + ), + db_settings=self.db_settings_if_always_needed, + ) with patch( f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!") ), patch( f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", return_value=False, ), self.assertLogs(logger="ax.service.scheduler") as lg: - scheduler = Scheduler( - experiment=self.branin_experiment, - generation_strategy=gs, - options=SchedulerOptions( - **self.scheduler_options_kwargs, - ), - db_settings=self.db_settings_if_always_needed, + # This trial will fail + with self.assertRaises(FailureRateExceededError): + scheduler.run_n_trials(max_trials=1) + self.assertTrue( + any( + re.search(r"Failed to fetch (branin|m1) for trial 0", warning) + is not None + for warning in lg.output + ) + ) + self.assertTrue( + any( + re.search( + r"Because (branin|m1) is an objective, marking trial 0 as " + "TrialStatus.FAILED", + warning, + ) + is not None + for warning in lg.output + ) + ) + self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED) + + def test_fetch_and_process_trials_data_results_failed_objective_but_recoverable( + self, + ) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, + options=SchedulerOptions( + enforce_immutable_search_space_and_opt_config=False, + **self.scheduler_options_kwargs, + ), + db_settings=self.db_settings_if_always_needed, + ) + BraninMetric.recoverable_exceptions = {AxError, TypeError} + # we're throwing a recoverable exception because UserInputError + # is a subclass of AxError + with patch( + f"{BraninMetric.__module__}.BraninMetric.f", + side_effect=UserInputError("yikes!"), + ), patch( + f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", + return_value=False, + ), self.assertLogs(logger="ax.service.scheduler") as lg: + scheduler.run_n_trials(max_trials=1) + self.assertTrue( + any( + re.search(r"Failed to fetch (branin|m1) for trial 0", warning) + is not None + for warning in lg.output + ), + lg.output, + ) + self.assertTrue( + any( + re.search( + "MetricFetchE INFO: Continuing optimization even though " + "MetricFetchE encountered", + warning, + ) + is not None + for warning in lg.output ) + ) + self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.COMPLETED) + def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable( + self, + ) -> None: + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, + options=SchedulerOptions( + **self.scheduler_options_kwargs, + ), + db_settings=self.db_settings_if_always_needed, + ) + # we're throwing a unrecoverable exception because Exception is not subclass + # of either error type in recoverable_exceptions + BraninMetric.recoverable_exceptions = {AxError, TypeError} + with patch( + f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!") + ), patch( + f"{BraninMetric.__module__}.BraninMetric.is_available_while_running", + return_value=False, + ), self.assertLogs(logger="ax.service.scheduler") as lg: # This trial will fail with self.assertRaises(FailureRateExceededError): scheduler.run_n_trials(max_trials=1) - self.assertTrue( - any( - re.search(r"Failed to fetch (branin|m1) for trial 0", warning) - is not None - for warning in lg.output - ) + self.assertTrue( + any( + re.search(r"Failed to fetch (branin|m1) for trial 0", warning) + is not None + for warning in lg.output ) - self.assertTrue( - any( - re.search( - r"Because (branin|m1) is an objective, marking trial 0 as " - "TrialStatus.FAILED", - warning, - ) - is not None - for warning in lg.output + ) + self.assertTrue( + any( + re.search( + r"Because (branin|m1) is an objective, marking trial 0 as " + "TrialStatus.FAILED", + warning, ) + is not None + for warning in lg.output ) - self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED) + ) + self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED) def test_should_consider_optimization_complete(self) -> None: # Tests non-GSS parts of the completion criterion.