Skip to content

Commit

Permalink
Enable gen fallback model spec for GenerationNode
Browse files Browse the repository at this point in the history
Summary: Allow `GenerationNode.gen()` to fall back to a specified model spec (Sobol by default) when generation raises one of the configured error types.

Differential Revision: D67232696
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Dec 22, 2024
1 parent 2d375ab commit d0a1b49
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 87 deletions.
122 changes: 86 additions & 36 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@
from ax.modelbridge.best_model_selector import BestModelSelector

from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.modelbridge.registry import (
_extract_model_state_after_gen,
ModelRegistryBase,
Models,
)
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGen,
MaxGenerationParallelism,
Expand All @@ -54,11 +58,11 @@
"the `BestModelSelector` will be used to select the `ModelSpec` to "
"use for candidate generation."
)
MAX_GEN_DRAWS = 5
MAX_GEN_DRAWS_EXCEEDED_MESSAGE = (
f"GenerationStrategy exceeded `MAX_GEN_DRAWS` of {MAX_GEN_DRAWS} while trying to "
"generate a unique parameterization. This indicates that the search space has "
"likely been fully explored, or that the sweep has converged."
MAX_GEN_ATTEMPTS = 5
MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE = (
f"GenerationStrategy exceeded `MAX_GEN_ATTEMPTS` of {MAX_GEN_ATTEMPTS} while "
"trying to generate a unique parameterization. This indicates that the search "
"space has likely been fully explored, or that the sweep has converged."
)


Expand Down Expand Up @@ -145,6 +149,7 @@ def __init__(
previous_node_name: str | None = None,
trial_type: str | None = None,
should_skip: bool = False,
fallback_specs: None | dict[type[Exception], ModelSpec] = None,
) -> None:
self._node_name = node_name
# Check that the model specs have unique model keys.
Expand Down Expand Up @@ -174,6 +179,12 @@ def __init__(
self._previous_node_name = previous_node_name
self._trial_type = trial_type
self._should_skip = should_skip
# pyre-fixme[8]: Incompatible attribute type
self.fallback_specs: dict[type[Exception], ModelSpec] = (
fallback_specs
if fallback_specs is not None
else {GenerationStrategyRepeatedPoints: ModelSpec(model_enum=Models.SOBOL)}
)

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -321,6 +332,25 @@ def fit(
},
)

# fit fallback models
for fallback_model_spec in self.fallback_specs.values():
logger.debug(
f"Fitting fallback model {fallback_model_spec.model_key} with data for "
f"trials: {trial_indices_in_data}"
)
fallback_model_spec.fit(
experiment=experiment,
data=data,
search_space=search_space,
optimization_config=optimization_config,
**{
**self._get_model_state_from_last_generator_run(
model_spec=fallback_model_spec
),
**kwargs,
},
)

def _get_model_state_from_last_generator_run(
self, model_spec: ModelSpec
) -> dict[str, Any]:
Expand Down Expand Up @@ -368,7 +398,7 @@ def gen(
self,
n: int | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS,
arms_by_signature_for_deduplication: dict[str, Arm] | None = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
Expand All @@ -387,10 +417,10 @@ def gen(
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
max_gen_draws_for_deduplication: Maximum number of attempts for generating
new candidates without duplicates. If non-duplicate candidates are not
generated with these attempts, a ``GenerationStrategyRepeatedPoints``
exception will be raised.
max_gen_attempts_for_deduplication: Maximum number of attempts for
generating new candidates without duplicates. If non-duplicate
candidates are not generated with these attempts, a
``GenerationStrategyRepeatedPoints`` exception will be raised.
arms_by_signature_for_deduplication: A dictionary mapping arm signatures to
the arms, to be used for deduplicating newly generated arms.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
Expand All @@ -402,34 +432,54 @@ def gen(
should_generate_run = True
generator_run = None
n_gen_draws = 0
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while should_generate_run:
try:
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while should_generate_run:
generator_run = self._gen(
n=n,
pending_observations=pending_observations,
**model_gen_kwargs,
)
should_generate_run = (
self.should_deduplicate
and arms_by_signature_for_deduplication
and any(
arm.signature in arms_by_signature_for_deduplication
for arm in generator_run.arms
)
)
n_gen_draws += 1
if should_generate_run:
if n_gen_draws > max_gen_attempts_for_deduplication:
raise GenerationStrategyRepeatedPoints(
MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE
)
else:
logger.info(
"The generator run produced duplicate arms. Re-running the "
"generation step in an attempt to deduplicate. Candidates "
f"produced in the last generator run: {generator_run.arms}."
)
except Exception as e:
error_type = type(e)
if error_type not in self.fallback_specs:
raise e
# Switch _model_spec_to_gen_from to a fallback spec
self._model_spec_to_gen_from = self.fallback_specs[error_type]
logger.warning(
f"gen failed with error {e}, "
"switching to fallback model with model_enum "
f"{self._model_spec_to_gen_from.model_enum}"
)
generator_run = self._gen(
n=n,
pending_observations=pending_observations,
**model_gen_kwargs,
)
should_generate_run = (
self.should_deduplicate
and arms_by_signature_for_deduplication
and any(
arm.signature in arms_by_signature_for_deduplication
for arm in generator_run.arms
)
)
n_gen_draws += 1
if should_generate_run:
if n_gen_draws > max_gen_draws_for_deduplication:
raise GenerationStrategyRepeatedPoints(
MAX_GEN_DRAWS_EXCEEDED_MESSAGE
)
else:
logger.info(
"The generator run produced duplicate arms. Re-running the "
"generation step in an attempt to deduplicate. Candidates "
f"produced in the last generator run: {generator_run.arms}."
)
# If we fell back, we need to re-pick the model to generate from
self._model_spec_to_gen_from = None

assert generator_run is not None, (
"The GeneratorRun is None which is an unexpected state of this"
" GenerationStrategy. This occurred on GenerationNode: {self.node_name}."
Expand Down Expand Up @@ -895,14 +945,14 @@ def gen(
self,
n: int | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS,
arms_by_signature_for_deduplication: dict[str, Arm] | None = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
gr = super().gen(
n=n,
pending_observations=pending_observations,
max_gen_draws_for_deduplication=max_gen_draws_for_deduplication,
max_gen_attempts_for_deduplication=max_gen_attempts_for_deduplication,
arms_by_signature_for_deduplication=arms_by_signature_for_deduplication,
**model_gen_kwargs,
)
Expand Down
120 changes: 72 additions & 48 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ def test_with_model_selection(self, mock_model_state: Mock) -> None:
# Model state is not extracted since there is no past GR.
mock_model_state.assert_not_called()
exp.new_trial(gs.gen(experiment=exp))
# Model state is extracted since there is a past GR.
mock_model_state.assert_called_once()
# Model state is extracted for each model (one main, one default fallback)
# since there is a past GR.
self.assertEqual(mock_model_state.call_count, 2)
mock_model_state.reset_mock()
# Gen with MBM/BO_MIXED.
mbm_gr_1 = gs.gen(experiment=exp)
Expand Down Expand Up @@ -774,49 +775,69 @@ def test_max_parallelism_reached(self) -> None:
with self.assertRaises(MaxParallelismReachedException):
sobol_generation_strategy.gen(experiment=exp)

def test_deduplication(self) -> None:
tiny_parameters = [
FixedParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
value=1.0,
),
ChoiceParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
values=[float(x) for x in range(2)],
),
]
tiny_search_space = SearchSpace(
parameters=cast(list[Parameter], tiny_parameters)
)
exp = get_branin_experiment(search_space=tiny_search_space)
sobol = GenerationStrategy(
name="Sobol",
steps=[
GenerationStep(
model=Models.SOBOL,
num_trials=-1,
# Disable model-level deduplication.
model_kwargs={"deduplicate": False},
should_deduplicate=True,
def test_deduplication_and_fallback(self) -> None:
# `None` uses the default fallback, which catches
# `GenerationStrategyRepeatedPoints` and regenerates with Sobol;
# `{}` configures no fallback model, so the exception is raised.
for fallback_specs in [{}, None]:
tiny_parameters = [
FixedParameter(
name="x1",
parameter_type=ParameterType.FLOAT,
value=1.0,
),
],
)
for _ in range(2):
g = sobol.gen(exp)
exp.new_trial(generator_run=g).run()

self.assertEqual(len(exp.arms_by_signature), 2)

with self.assertRaisesRegex(
GenerationStrategyRepeatedPoints, "exceeded `MAX_GEN_DRAWS`"
), mock.patch("ax.modelbridge.generation_node.logger.info") as mock_logger:
g = sobol.gen(exp)
self.assertEqual(mock_logger.call_count, 5)
self.assertIn(
"The generator run produced duplicate arms.", mock_logger.call_args[0][0]
)
ChoiceParameter(
name="x2",
parameter_type=ParameterType.FLOAT,
values=[float(x) for x in range(2)],
),
]
tiny_search_space = SearchSpace(
parameters=cast(list[Parameter], tiny_parameters)
)
exp = get_branin_experiment(search_space=tiny_search_space)
sobol = GenerationStrategy(
name="Sobol",
nodes=[
GenerationNode(
node_name="sobol",
model_specs=[
ModelSpec(
model_enum=Models.SOBOL,
model_kwargs={"deduplicate": False},
)
],
# Node-level deduplication is enabled here; model-level deduplication
# is disabled via `model_kwargs` in the `ModelSpec` above.
should_deduplicate=True,
fallback_specs=fallback_specs,
),
],
)
for _ in range(2):
g = sobol.gen(exp)
exp.new_trial(generator_run=g).run()

self.assertEqual(len(exp.arms_by_signature), 2)

if fallback_specs is not None:
with self.assertRaisesRegex(
GenerationStrategyRepeatedPoints, "exceeded `MAX_GEN_ATTEMPTS`"
), mock.patch(
"ax.modelbridge.generation_node.logger.info"
) as mock_logger:
g = sobol.gen(exp)
else:
# generation with a fallback model
with self.assertLogs(GenerationNode.__module__, logging.WARNING) as cm:
g = sobol.gen(exp)
self.assertTrue(
any("gen failed with error" in msg for msg in cm.output)
)
self.assertEqual(mock_logger.call_count, 5)
self.assertIn(
"The generator run produced duplicate arms.",
mock_logger.call_args[0][0],
)

def test_current_generator_run_limit(self) -> None:
NUM_INIT_TRIALS = 5
Expand Down Expand Up @@ -906,7 +927,9 @@ def test_hierarchical_search_space(self) -> None:
RandomModelBridge, "_fit"
) as mock_model_fit, patch.object(RandomModelBridge, "gen"):
self.sobol_GS.gen(experiment=experiment)
mock_model_fit.assert_called_once()
# We should fit exactly once per model
# (once for the main model, once for the default fallback model).
self.assertEqual(mock_model_fit.call_count, 2)
observations = mock_model_fit.call_args[1].get("observations")
all_parameter_names = checked_cast(
HierarchicalSearchSpace, experiment.search_space
Expand Down Expand Up @@ -954,9 +977,10 @@ def test_gen_multiple(self) -> None:
# first four become trials.
grs = sobol_MBM_gs._gen_multiple(experiment=exp, num_generator_runs=3)
self.assertEqual(len(grs), 3)
# We should only fit once; refitting for each `gen` would be
# wasteful as there is no new data.
model_spec_fit_mock.assert_called_once()
# We should only fit once for each model
# (one for the main model, another for the default fallback model);
# refitting for each `gen` would be wasteful as there is no new data.
self.assertEqual(model_spec_fit_mock.call_count, 2)
self.assertEqual(model_spec_gen_mock.call_count, 3)
pending_in_each_gen = enumerate(
args_and_kwargs.kwargs.get("pending_observations")
Expand Down
8 changes: 5 additions & 3 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def get_client_with_simple_discrete_moo_problem(
metrics = [-m for m in metrics]
y0, y1, y2 = metrics
raw_data = {"y0": (y0, 0.0), "y1": (y1, 0.0), "y2": (y2, 0.0)}
# pyre-fixme [6]: In call `AxClient.complete_trial`, for 2nd parameter
# pyre-fixme[6]: In call `AxClient.complete_trial`, for 2nd parameter
# `raw_data`
# expected `Union[Dict[str, Union[Tuple[Union[float, floating, integer],
# Union[None, float, floating, integer]], float, floating, integer]],
Expand Down Expand Up @@ -1768,7 +1768,8 @@ def test_trial_completion_with_metadata_with_iso_times(self) -> None:
RandomModelBridge, "_fit", autospec=True, side_effect=RandomModelBridge._fit
) as mock_fit:
ax_client.get_next_trial()
mock_fit.assert_called_once()
# one for the main model, one for the default fallback model
self.assertEqual(mock_fit.call_count, 2)
features = mock_fit.call_args_list[0][1]["observations"][0].features
# we're asserting it's actually created real Timestamp objects
# for the observation features
Expand All @@ -1790,7 +1791,8 @@ def test_trial_completion_with_metadata_millisecond_times(self) -> None:
RandomModelBridge, "_fit", autospec=True, side_effect=RandomModelBridge._fit
) as mock_fit:
ax_client.get_next_trial()
mock_fit.assert_called_once()
# one for the main model, one for the default fallback model
self.assertEqual(mock_fit.call_count, 2)
features = mock_fit.call_args_list[0][1]["observations"][0].features
# we're asserting it's actually created real Timestamp objects
# for the observation features
Expand Down

0 comments on commit d0a1b49

Please sign in to comment.