Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add concept of recoverable errors respected by Scheduler #3262

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions ax/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class Metric(SortableBase, SerializationMixin):
"""

data_constructor: type[Data] = Data
# The set of exception types stored in a ``MetricFetchE.exception`` that are
# recoverable in ``Scheduler._fetch_and_process_trials_data_results()``.
# Exception may be a subclass of any of these types. If you want your metric
# to never fail the trial, set this to ``{Exception}`` in your metric subclass.
recoverable_exceptions: set[type[Exception]] = set()

def __init__(
self,
Expand Down Expand Up @@ -138,6 +143,17 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta:
"""
return timedelta(0)

@classmethod
def is_reconverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
    """Checks whether the given MetricFetchE is recoverable for this metric class
    in ``Scheduler._fetch_and_process_trials_data_results``.
    """
    # NOTE: the name's "reconverable" spelling is kept for caller compatibility.
    exception = metric_fetch_e.exception
    if exception is None:
        # No underlying exception recorded, so nothing can be matched.
        return False
    # ``isinstance`` against a tuple of types matches any of them,
    # including subclasses — same semantics as an any()-of-isinstance loop.
    return isinstance(exception, tuple(cls.recoverable_exceptions))

# NOTE: This is rarely overridden –– only if you want to fetch data in groups
# consisting of multiple different metric classes, for data to be fetched together.
# This makes sense only if `fetch_trial_data_multi` or `fetch_experiment_data_multi`
Expand Down
5 changes: 4 additions & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,11 +2041,14 @@ def _fetch_and_process_trials_data_results(
)

# If the fetch failure was for a metric in the optimization config (an
# objective or constraint) the trial as failed
# objective or constraint) mark the trial as failed
optimization_config = self.experiment.optimization_config
if (
optimization_config is not None
and metric_name in optimization_config.metrics.keys()
and not self.experiment.metrics[
metric_name
].is_reconverable_fetch_e(metric_fetch_e=metric_fetch_e)
):
status = self._mark_err_trial_status(
trial=self.experiment.trials[trial_index],
Expand Down
147 changes: 123 additions & 24 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
get_pending_observation_features_based_on_trial_status,
)
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError
from ax.exceptions.core import (
AxError,
OptimizationComplete,
UnsupportedError,
UserInputError,
)
from ax.exceptions.generation_strategy import AxGenerationException
from ax.metrics.branin import BraninMetric
from ax.metrics.branin_map import BraninTimestampMapMetric
Expand Down Expand Up @@ -1981,43 +1986,137 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None:
experiment=self.branin_experiment,
generation_strategy=self.two_sobol_steps_GS,
)
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=gs,
options=SchedulerOptions(
**self.scheduler_options_kwargs,
),
db_settings=self.db_settings_if_always_needed,
)
with patch(
f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!")
), patch(
f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
return_value=False,
), self.assertLogs(logger="ax.service.scheduler") as lg:
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=gs,
options=SchedulerOptions(
**self.scheduler_options_kwargs,
),
db_settings=self.db_settings_if_always_needed,
# This trial will fail
with self.assertRaises(FailureRateExceededError):
scheduler.run_n_trials(max_trials=1)
self.assertTrue(
any(
re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
is not None
for warning in lg.output
)
)
self.assertTrue(
any(
re.search(
r"Because (branin|m1) is an objective, marking trial 0 as "
"TrialStatus.FAILED",
warning,
)
is not None
for warning in lg.output
)
)
self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED)

def test_fetch_and_process_trials_data_results_failed_objective_but_recoverable(
    self,
) -> None:
    """A fetch failure whose exception type is listed in the metric's
    ``recoverable_exceptions`` must not fail the trial: the scheduler logs the
    error and the trial still completes.
    """
    gs = self._get_generation_strategy_strategy_for_test(
        experiment=self.branin_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )
    scheduler = Scheduler(
        experiment=self.branin_experiment,
        generation_strategy=gs,
        options=SchedulerOptions(
            enforce_immutable_search_space_and_opt_config=False,
            **self.scheduler_options_kwargs,
        ),
        db_settings=self.db_settings_if_always_needed,
    )
    # We're throwing a recoverable exception because UserInputError is a
    # subclass of AxError. Use patch.object (rather than assigning to the
    # class attribute) so the override is undone after the test and does not
    # leak into other tests.
    with patch.object(
        BraninMetric, "recoverable_exceptions", {AxError, TypeError}
    ), patch(
        f"{BraninMetric.__module__}.BraninMetric.f",
        side_effect=UserInputError("yikes!"),
    ), patch(
        f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
        return_value=False,
    ), self.assertLogs(logger="ax.service.scheduler") as lg:
        scheduler.run_n_trials(max_trials=1)
    self.assertTrue(
        any(
            re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
            is not None
            for warning in lg.output
        ),
        lg.output,
    )
    self.assertTrue(
        any(
            re.search(
                "MetricFetchE INFO: Continuing optimization even though "
                "MetricFetchE encountered",
                warning,
            )
            is not None
            for warning in lg.output
        )
    )
    self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.COMPLETED)

def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable(
    self,
) -> None:
    """A fetch failure whose exception type is NOT covered by the metric's
    ``recoverable_exceptions`` must mark the trial as FAILED and ultimately
    raise ``FailureRateExceededError``.
    """
    gs = self._get_generation_strategy_strategy_for_test(
        experiment=self.branin_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )
    scheduler = Scheduler(
        experiment=self.branin_experiment,
        generation_strategy=gs,
        options=SchedulerOptions(
            **self.scheduler_options_kwargs,
        ),
        db_settings=self.db_settings_if_always_needed,
    )
    # We're throwing an unrecoverable exception because Exception is not a
    # subclass of either error type in recoverable_exceptions. Use
    # patch.object so the class-attribute override is undone after the test.
    with patch.object(
        BraninMetric, "recoverable_exceptions", {AxError, TypeError}
    ), patch(
        f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!")
    ), patch(
        f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
        return_value=False,
    ), self.assertLogs(logger="ax.service.scheduler") as lg:
        # This trial will fail.
        with self.assertRaises(FailureRateExceededError):
            scheduler.run_n_trials(max_trials=1)
    self.assertTrue(
        any(
            re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
            is not None
            for warning in lg.output
        )
    )
    self.assertTrue(
        any(
            re.search(
                r"Because (branin|m1) is an objective, marking trial 0 as "
                "TrialStatus.FAILED",
                warning,
            )
            is not None
            for warning in lg.output
        )
    )
    self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED)

def test_should_consider_optimization_complete(self) -> None:
# Tests non-GSS parts of the completion criterion.
Expand Down
Loading