Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable gen fallback model spec for GenerationNode #3209

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 98 additions & 55 deletions ax/modelbridge/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.optimization_config import OptimizationConfig
from ax.core.search_space import SearchSpace
from ax.exceptions.core import UserInputError
from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.best_model_selector import BestModelSelector

from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec
from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase
from ax.modelbridge.registry import (
_extract_model_state_after_gen,
ModelRegistryBase,
Models,
)
from ax.modelbridge.transition_criterion import (
AutoTransitionAfterGen,
MaxGenerationParallelism,
Expand All @@ -54,12 +56,17 @@
"the `BestModelSelector` will be used to select the `ModelSpec` to "
"use for candidate generation."
)
MAX_GEN_DRAWS = 5
MAX_GEN_DRAWS_EXCEEDED_MESSAGE = (
f"GenerationStrategy exceeded `MAX_GEN_DRAWS` of {MAX_GEN_DRAWS} while trying to "
"generate a unique parameterization. This indicates that the search space has "
"likely been fully explored, or that the sweep has converged."
# Default upper bound on the number of generation attempts ``GenerationNode.gen``
# makes while trying to produce arms that are not duplicates of previous arms.
MAX_GEN_ATTEMPTS = 5
# Message attached to ``GenerationStrategyRepeatedPoints`` when that bound is hit.
MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE = (
f"GenerationStrategy exceeded `MAX_GEN_ATTEMPTS` of {MAX_GEN_ATTEMPTS} while "
"trying to generate a unique parameterization. This indicates that the search "
"space has likely been fully explored, or that the sweep has converged."
)
# Default value for ``GenerationNode.fallback_specs``: maps the type of an
# exception raised during ``gen`` to the ``ModelSpec`` to generate from instead.
# NOTE(review): nodes constructed without ``fallback_specs`` are assigned this
# very dict object (and therefore this same ``ModelSpec`` instance), and ``gen``
# calls ``fit`` on that spec -- confirm that sharing it across nodes is intended.
DEFAULT_FALLBACK = {
GenerationStrategyRepeatedPoints: ModelSpec(
model_enum=Models.SOBOL, model_key_override="Fallback_Sobol"
)
}


class GenerationNode(SerializationMixin, SortableBase):
Expand Down Expand Up @@ -98,6 +105,8 @@ class GenerationNode(SerializationMixin, SortableBase):
store the most recent previous ``GenerationNode`` name.
should_skip: Whether to skip this node during generation time. Defaults to
False, and can only currently be set to True via ``NodeInputConstructors``
fallback_specs: Optional dict mapping expected exception types to the
``ModelSpec`` to fall back to when ``gen`` fails with that exception.
Defaults to ``DEFAULT_FALLBACK`` when not specified.

Note for developers: by "model" here we really mean an Ax ModelBridge object, which
contains an Ax Model under the hood. We call it "model" here to simplify and focus
Expand Down Expand Up @@ -145,6 +154,7 @@ def __init__(
previous_node_name: str | None = None,
trial_type: str | None = None,
should_skip: bool = False,
fallback_specs: dict[type[Exception], ModelSpec] | None = None,
) -> None:
self._node_name = node_name
# Check that the model specs have unique model keys.
Expand Down Expand Up @@ -174,6 +184,10 @@ def __init__(
self._previous_node_name = previous_node_name
self._trial_type = trial_type
self._should_skip = should_skip
# pyre-fixme[8]: Incompatible attribute type
self.fallback_specs: dict[type[Exception], ModelSpec] = (
fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK
)

@property
def node_name(self) -> str:
Expand Down Expand Up @@ -279,25 +293,26 @@ def _fitted_model(self) -> ModelBridge | None:
def fit(
self,
experiment: Experiment,
data: Data,
search_space: SearchSpace | None = None,
optimization_config: OptimizationConfig | None = None,
**kwargs: Any,
data: Data | None = None,
) -> None:
"""Fits the specified models to the given experiment + data using
the model kwargs set on each corresponding model spec and the kwargs
passed to this method.

NOTE: During fitting of the ``ModelSpec``, state of this ``ModelSpec``
after its last candidate generation is extracted from the last
``GeneratorRun`` it produced (if any was captured in
``GeneratorRun.model_state_after_gen``) and passed into ``ModelSpec.fit``
as keyword arguments.

Args:
experiment: The experiment to fit the model to.
data: The experiment data used to fit the model.
search_space: An optional overwrite for the experiment search space.
optimization_config: An optional overwrite for the experiment
optimization config.
kwargs: Additional keyword arguments to pass to the model's
``fit`` method. NOTE: Local kwargs take precedence over the ones
stored in ``ModelSpec.model_kwargs``.

"""
data = data if data is not None else experiment.lookup_data()
search_space = experiment.search_space
optimization_config = experiment.optimization_config
if not data.df.empty:
trial_indices_in_data = sorted(data.df["trial_index"].unique())
else:
Expand All @@ -313,12 +328,7 @@ def fit(
data=data,
search_space=search_space,
optimization_config=optimization_config,
**{
**self._get_model_state_from_last_generator_run(
model_spec=model_spec
),
**kwargs,
},
**self._get_model_state_from_last_generator_run(model_spec=model_spec),
)

def _get_model_state_from_last_generator_run(
Expand Down Expand Up @@ -368,7 +378,7 @@ def gen(
self,
n: int | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS,
arms_by_signature_for_deduplication: dict[str, Arm] | None = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
Expand All @@ -387,10 +397,10 @@ def gen(
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
max_gen_draws_for_deduplication: Maximum number of attempts for generating
new candidates without duplicates. If non-duplicate candidates are not
generated with these attempts, a ``GenerationStrategyRepeatedPoints``
exception will be raised.
max_gen_attempts_for_deduplication: Maximum number of attempts for
generating new candidates without duplicates. If non-duplicate
candidates are not generated with these attempts, a
``GenerationStrategyRepeatedPoints`` exception will be raised.
arms_by_signature_for_deduplication: A dictionary mapping arm signatures to
the arms, to be used for deduplicating newly generated arms.
model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``;
Expand All @@ -399,37 +409,70 @@ def gen(
Returns:
A ``GeneratorRun`` containing the newly generated candidates.
"""
should_generate_run = True
generator_run = None
n_gen_draws = 0
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while should_generate_run:
try:
# Keep generating until each of `generator_run.arms` is not a duplicate
# of a previous arm, if `should_deduplicate is True`
while n_gen_draws < max_gen_attempts_for_deduplication:
n_gen_draws += 1
generator_run = self._gen(
n=n,
pending_observations=pending_observations,
**model_gen_kwargs,
)
if not (
self.should_deduplicate
and arms_by_signature_for_deduplication
and any(
arm.signature in arms_by_signature_for_deduplication
for arm in generator_run.arms
)
): # Not deduplicating or generated a non-duplicate arm.
break

logger.info(
"The generator run produced duplicate arms. Re-running the "
"generation step in an attempt to deduplicate. Candidates "
f"produced in the last generator run: {generator_run.arms}."
)

if n_gen_draws >= max_gen_attempts_for_deduplication:
raise GenerationStrategyRepeatedPoints(
MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE
)
except Exception as e:
error_type = type(e)
if error_type not in self.fallback_specs:
raise e

# identify fallback model to use
fallback_model = self.fallback_specs[error_type]
logger.warning(
f"gen failed with error {e}, "
"switching to fallback model with model_enum "
f"{fallback_model.model_enum}"
)

# fit fallback model using information from `self.experiment`
# as ground truth
fallback_model.fit(
experiment=self.experiment,
data=self.experiment.lookup_data(),
search_space=self.experiment.search_space,
optimization_config=self.experiment.optimization_config,
**self._get_model_state_from_last_generator_run(
model_spec=fallback_model
),
)
# Switch _model_spec_to_gen_from to a fallback spec
self._model_spec_to_gen_from = fallback_model
generator_run = self._gen(
n=n,
pending_observations=pending_observations,
**model_gen_kwargs,
)
should_generate_run = (
self.should_deduplicate
and arms_by_signature_for_deduplication
and any(
arm.signature in arms_by_signature_for_deduplication
for arm in generator_run.arms
)
)
n_gen_draws += 1
if should_generate_run:
if n_gen_draws > max_gen_draws_for_deduplication:
raise GenerationStrategyRepeatedPoints(
MAX_GEN_DRAWS_EXCEEDED_MESSAGE
)
else:
logger.info(
"The generator run produced duplicate arms. Re-running the "
"generation step in an attempt to deduplicate. Candidates "
f"produced in the last generator run: {generator_run.arms}."
)

assert generator_run is not None, (
"The GeneratorRun is None which is an unexpected state of this"
" GenerationStrategy. This occurred on GenerationNode: {self.node_name}."
Expand Down Expand Up @@ -895,14 +938,14 @@ def gen(
self,
n: int | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS,
max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS,
arms_by_signature_for_deduplication: dict[str, Arm] | None = None,
**model_gen_kwargs: Any,
) -> GeneratorRun:
gr = super().gen(
n=n,
pending_observations=pending_observations,
max_gen_draws_for_deduplication=max_gen_draws_for_deduplication,
max_gen_attempts_for_deduplication=max_gen_attempts_for_deduplication,
arms_by_signature_for_deduplication=arms_by_signature_for_deduplication,
**model_gen_kwargs,
)
Expand Down
26 changes: 4 additions & 22 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@


MAX_CONDITIONS_GENERATED = 10000
MAX_GEN_DRAWS = 5
MAX_GEN_DRAWS_EXCEEDED_MESSAGE = (
f"GenerationStrategy exceeded `MAX_GEN_DRAWS` of {MAX_GEN_DRAWS} while trying to "
"generate a unique parameterization. This indicates that the search space has "
"likely been fully explored, or that the sweep has converged."
)
T = TypeVar("T")


Expand Down Expand Up @@ -1134,25 +1128,13 @@ def _fit_current_model(
Args:
data: Optional ``Data`` to fit or update with; if not specified, generation
strategy will obtain the data via ``experiment.lookup_data``.
status_quo_features: An ``ObservationFeature`` of the status quo arm,
needed by some models during fit to accomadate relative constraints.
Includes the status quo parameterization and target trial index.
status_quo_features: UNSUPPORTED. This is not used in ``GenerationNode.fit``
and is not needed here. It will be removed in the future.
"""
data = self.experiment.lookup_data() if data is None else data

# Only pass status_quo_features if not None to avoid errors
# with ``ExternalGenerationNode``.
if status_quo_features is not None:
self._curr.fit(
experiment=self.experiment,
data=data,
status_quo_features=status_quo_features,
)
else:
self._curr.fit(
experiment=self.experiment,
data=data,
)
logger.warning("`status_quo_features` is passed in but will not be used!")
self._curr.fit(experiment=self.experiment, data=data)
self._model = self._curr._fitted_model

def _maybe_transition_to_next_node(
Expand Down
Loading
Loading