From da084721abcc7afbc53d19962ce5ba740eafcda6 Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Wed, 12 Feb 2025 15:17:41 -0800 Subject: [PATCH 1/2] Remove no-longer-used `GenerationNode.fit` search_space and optimization_config arguments (#3360) Summary: As titled. Now the search_space and optimization_config will not be passed and by default will be extracted from the experiment. Reviewed By: lena-kashtelyan Differential Revision: D67803329 --- ax/generation_strategy/generation_node.py | 21 ++++++++----------- .../tests/test_generation_node.py | 2 -- .../tests/test_generation_strategy.py | 4 +--- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 2e1007604ef..9c3c2751f77 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -18,8 +18,6 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures -from ax.core.optimization_config import OptimizationConfig -from ax.core.search_space import SearchSpace from ax.core.trial_status import TrialStatus from ax.exceptions.core import UserInputError from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints @@ -209,9 +207,7 @@ def model_to_gen_from_name(self) -> str | None: return None @property - def generation_strategy( - self, - ) -> GenerationStrategy: + def generation_strategy(self) -> GenerationStrategy: """Returns a backpointer to the GenerationStrategy, useful for obtaining the experiment associated with this GenerationStrategy""" # TODO: @mgarrard remove this property once we make experiment a required @@ -289,24 +285,26 @@ def fit( self, experiment: Experiment, data: Data, - search_space: SearchSpace | None = None, - optimization_config: OptimizationConfig | None = None, **kwargs: Any, ) -> None: """Fits the specified models to the given experiment + data using the model kwargs set on each corresponding model spec and the kwargs passed to this method. + NOTE: During fitting of the ``GeneratorSpec``, state of this ``GeneratorSpec`` + after its last candidate generation is extracted from the last + ``GeneratorRun`` it produced (if any was captured in + ``GeneratorRun.model_state_after_gen``) and passed into ``GeneratorSpec.fit`` + as keyword arguments. + Args: experiment: The experiment to fit the model to. data: The experiment data used to fit the model. - search_space: An optional overwrite for the experiment search space. - optimization_config: An optional overwrite for the experiment - optimization config. kwargs: Additional keyword arguments to pass to the model's ``fit`` method. NOTE: Local kwargs take precedence over the ones stored in ``GeneratorSpec.model_kwargs``. """ + data = data if data is not None else experiment.lookup_data() if not data.df.empty: trial_indices_in_data = sorted(data.df["trial_index"].unique()) else: @@ -317,11 +315,10 @@ def fit( f"Fitting model {model_spec.model_key} with data for " f"trials: {trial_indices_in_data}" ) + # search space and optimization config will come from the experiment model_spec.fit( # Stores the fitted model as `model_spec._fitted_model` experiment=experiment, data=data, - search_space=search_space, - optimization_config=optimization_config, **{ **self._get_model_state_from_last_generator_run( model_spec=model_spec diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 1681df0d2c2..edd4203045b 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -153,8 +153,6 @@ def test_fit(self) -> None: mock_model_spec_fit.assert_called_with( experiment=self.branin_experiment, data=self.branin_data, - search_space=None, - optimization_config=None, ) def test_gen(self) -> None: diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 863bfb98bef..43ced787190 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -408,7 +408,7 @@ def test_validation(self) -> None: ] ) - exp = Experiment( + Experiment( name="test", search_space=SearchSpace(parameters=[get_choice_parameter()]) ) factorial_thompson_generation_strategy = GenerationStrategy( @@ -421,8 +421,6 @@ def test_validation(self) -> None: self.assertFalse( factorial_thompson_generation_strategy.uses_non_registered_models ) - with self.assertRaises(ValueError): - factorial_thompson_generation_strategy._gen_with_multiple_nodes(exp) self.assertEqual(GenerationStep(model=sum, num_trials=1).model_name, "sum") with self.assertRaisesRegex(UserInputError, "Maximum parallelism should be"): GenerationStrategy( From e433d5a9d6aef20bae24adca6bfff8344991ee0a Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Wed, 12 Feb 2025 15:17:41 -0800 Subject: [PATCH 2/2] Enable gen fallback model spec for GenerationNode (#3209) Summary: Allowing for gen fallback (default to sobol) upon running into specified error in GenerationNode.gen() Reviewed By: saitcakmak Differential Revision: D67232696 --- ax/generation_strategy/generation_node.py | 130 ++++++++++++------ ax/generation_strategy/generation_strategy.py | 6 - .../tests/test_generation_strategy.py | 117 +++++++++------- ax/service/tests/test_ax_client.py | 6 +- 4 files changed, 160 insertions(+), 99 deletions(-) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 9c3c2751f77..61d7319ab1f 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -23,7 +23,6 @@ from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints from ax.generation_strategy.best_model_selector import BestModelSelector - if TYPE_CHECKING: from ax.generation_strategy.generation_node_input_constructors import ( InputConstructorPurpose, @@ -43,7 +42,11 @@ TrialBasedCriterion, ) from ax.modelbridge.base import Adapter -from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase +from ax.modelbridge.registry import ( + _extract_model_state_after_gen, + Generators, + ModelRegistryBase, +) from ax.utils.common.base import SortableBase from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger @@ -60,12 +63,17 @@ "the `BestModelSelector` will be used to select the `GeneratorSpec` to " "use for candidate generation." ) -MAX_GEN_DRAWS = 5 -MAX_GEN_DRAWS_EXCEEDED_MESSAGE = ( - f"GenerationStrategy exceeded `MAX_GEN_DRAWS` of {MAX_GEN_DRAWS} while trying to " - "generate a unique parameterization. This indicates that the search space has " - "likely been fully explored, or that the sweep has converged." +MAX_GEN_ATTEMPTS = 5 +MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE = ( + f"GenerationStrategy exceeded `MAX_GEN_ATTEMPTS` of {MAX_GEN_ATTEMPTS} while " + "trying to generate a unique parameterization. This indicates that the search " + "space has likely been fully explored, or that the sweep has converged." ) +DEFAULT_FALLBACK = { + GenerationStrategyRepeatedPoints: GeneratorSpec( + model_enum=Generators.SOBOL, model_key_override="Fallback_Sobol" + ) +} class GenerationNode(SerializationMixin, SortableBase): @@ -105,6 +113,8 @@ class GenerationNode(SerializationMixin, SortableBase): store the most recent previous ``GenerationNode`` name. should_skip: Whether to skip this node during generation time. Defaults to False, and can only currently be set to True via ``NodeInputConstructors`` + fallback_specs: Optional dict mapping expected exception types to `ModelSpec` + fallbacks used when gen fails. Note for developers: by "model" here we really mean an Ax Adapter object, which contains an Ax Model under the hood. We call it "model" here to simplify and focus @@ -150,6 +160,7 @@ def __init__( previous_node_name: str | None = None, trial_type: str | None = None, should_skip: bool = False, + fallback_specs: dict[type[Exception], GeneratorSpec] | None = None, ) -> None: self._node_name = node_name # Check that the model specs have unique model keys. @@ -179,6 +190,10 @@ def __init__( self._previous_node_name = previous_node_name self._trial_type = trial_type self._should_skip = should_skip + # pyre-fixme[8]: Incompatible attribute type + self.fallback_specs: dict[type[Exception], GeneratorSpec] = ( + fallback_specs if fallback_specs is not None else DEFAULT_FALLBACK + ) @property def node_name(self) -> str: @@ -374,7 +389,7 @@ def gen( self, n: int | None = None, pending_observations: dict[str, list[ObservationFeatures]] | None = None, - max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS, + max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS, arms_by_signature_for_deduplication: dict[str, Arm] | None = None, **model_gen_kwargs: Any, ) -> GeneratorRun: @@ -389,54 +404,87 @@ def gen( Args: n: Optional integer representing how many arms should be in the generator run produced by this method. When this is ``None``, ``n`` will be - determined by the ``GeneratorSpec`` that we are generating from. + determined by the ``ModelSpec`` that we are generating from. pending_observations: A map from metric name to pending observations for that metric, used by some models to avoid resuggesting points that are currently being evaluated. - max_gen_draws_for_deduplication: Maximum number of attempts for generating - new candidates without duplicates. If non-duplicate candidates are not - generated with these attempts, a ``GenerationStrategyRepeatedPoints`` - exception will be raised. + max_gen_attempts_for_deduplication: Maximum number of attempts for + generating new candidates without duplicates. If non-duplicate + candidates are not generated with these attempts, a + ``GenerationStrategyRepeatedPoints`` exception will be raised. arms_by_signature_for_deduplication: A dictionary mapping arm signatures to the arms, to be used for deduplicating newly generated arms. model_gen_kwargs: Keyword arguments, passed through to - ``GeneratorSpec.gen``; these override any pre-specified in - ``GeneratorSpec.model_gen_kwargs``. + ``ModelSpec.gen``; these override any pre-specified in + ``ModelSpec.model_gen_kwargs``. Returns: A ``GeneratorRun`` containing the newly generated candidates. """ - should_generate_run = True generator_run = None n_gen_draws = 0 - # Keep generating until each of `generator_run.arms` is not a duplicate - # of a previous arm, if `should_deduplicate is True` - while should_generate_run: + try: + # Keep generating until each of `generator_run.arms` is not a duplicate + # of a previous arm, if `should_deduplicate is True` + while n_gen_draws < max_gen_attempts_for_deduplication: + n_gen_draws += 1 + generator_run = self._gen( + n=n, + pending_observations=pending_observations, + **model_gen_kwargs, + ) + if not ( + self.should_deduplicate + and arms_by_signature_for_deduplication + and any( + arm.signature in arms_by_signature_for_deduplication + for arm in generator_run.arms + ) + ): # Not deduplicating or generated a non-duplicate arm. + break + + logger.info( + "The generator run produced duplicate arms. Re-running the " + "generation step in an attempt to deduplicate. Candidates " + f"produced in the last generator run: {generator_run.arms}." + ) + + if n_gen_draws >= max_gen_attempts_for_deduplication: + raise GenerationStrategyRepeatedPoints( + MAX_GEN_ATTEMPTS_EXCEEDED_MESSAGE + ) + except Exception as e: + error_type = type(e) + if error_type not in self.fallback_specs: + raise e + + # identify fallback model to use + fallback_model = self.fallback_specs[error_type] + logger.warning( + f"gen failed with error {e}, " + "switching to fallback model with model_enum " + f"{fallback_model.model_enum}" + ) + + # fit fallback model using information from `self.experiment` + # as ground truth + fallback_model.fit( + experiment=self.experiment, + data=self.experiment.lookup_data(), + search_space=self.experiment.search_space, + optimization_config=self.experiment.optimization_config, + **self._get_model_state_from_last_generator_run( + model_spec=fallback_model + ), + ) + # Switch _model_spec_to_gen_from to a fallback spec + self._model_spec_to_gen_from = fallback_model generator_run = self._gen( n=n, pending_observations=pending_observations, **model_gen_kwargs, ) - should_generate_run = ( - self.should_deduplicate - and arms_by_signature_for_deduplication - and any( - arm.signature in arms_by_signature_for_deduplication - for arm in generator_run.arms - ) - ) - n_gen_draws += 1 - if should_generate_run: - if n_gen_draws > max_gen_draws_for_deduplication: - raise GenerationStrategyRepeatedPoints( - MAX_GEN_DRAWS_EXCEEDED_MESSAGE - ) - else: - logger.info( - "The generator run produced duplicate arms. Re-running the " - "generation step in an attempt to deduplicate. Candidates " - f"produced in the last generator run: {generator_run.arms}." - ) + assert generator_run is not None, ( "The GeneratorRun is None which is an unexpected state of this" " GenerationStrategy. This occurred on GenerationNode: {self.node_name}." @@ -929,14 +977,14 @@ def gen( self, n: int | None = None, pending_observations: dict[str, list[ObservationFeatures]] | None = None, - max_gen_draws_for_deduplication: int = MAX_GEN_DRAWS, + max_gen_attempts_for_deduplication: int = MAX_GEN_ATTEMPTS, arms_by_signature_for_deduplication: dict[str, Arm] | None = None, **model_gen_kwargs: Any, ) -> GeneratorRun: gr = super().gen( n=n, pending_observations=pending_observations, - max_gen_draws_for_deduplication=max_gen_draws_for_deduplication, + max_gen_attempts_for_deduplication=max_gen_attempts_for_deduplication, arms_by_signature_for_deduplication=arms_by_signature_for_deduplication, **model_gen_kwargs, ) diff --git a/ax/generation_strategy/generation_strategy.py b/ax/generation_strategy/generation_strategy.py index 90ea01ee8d8..641f340cdc5 100644 --- a/ax/generation_strategy/generation_strategy.py +++ b/ax/generation_strategy/generation_strategy.py @@ -43,12 +43,6 @@ MAX_CONDITIONS_GENERATED = 10000 -MAX_GEN_DRAWS = 5 -MAX_GEN_DRAWS_EXCEEDED_MESSAGE = ( - f"GenerationStrategy exceeded `MAX_GEN_DRAWS` of {MAX_GEN_DRAWS} while trying to " - "generate a unique parameterization. This indicates that the search space has " - "likely been fully explored, or that the sweep has converged." -) T = TypeVar("T") diff --git a/ax/generation_strategy/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py index 43ced787190..2acd7ba0158 100644 --- a/ax/generation_strategy/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -117,7 +117,7 @@ def test_with_model_selection(self, mock_model_state: Mock) -> None: # Model state is not extracted since there is no past GR. mock_model_state.assert_not_called() exp.new_trial(gs.gen(experiment=exp)) - # Model state is extracted since there is a past GR. + # Model state is extracted for the model since there is a past GR. mock_model_state.assert_called_once() mock_model_state.reset_mock() # Gen with MBM/BO_MIXED. @@ -770,51 +770,69 @@ def test_max_parallelism_reached(self) -> None: with self.assertRaises(MaxParallelismReachedException): sobol_generation_strategy.gen(experiment=exp) - def test_deduplication(self) -> None: - tiny_parameters = [ - FixedParameter( - name="x1", - parameter_type=ParameterType.FLOAT, - value=1.0, - ), - ChoiceParameter( - name="x2", - parameter_type=ParameterType.FLOAT, - values=[float(x) for x in range(2)], - ), - ] - tiny_search_space = SearchSpace( - parameters=cast(list[Parameter], tiny_parameters) - ) - exp = get_branin_experiment(search_space=tiny_search_space) - sobol = GenerationStrategy( - name="Sobol", - steps=[ - GenerationStep( - model=Generators.SOBOL, - num_trials=-1, - # Disable model-level deduplication. - model_kwargs={"deduplicate": False}, - should_deduplicate=True, + def test_deduplication_and_fallback(self) -> None: + # None uses default fallback, which catches + # GenerationStrategyRepeatedPoints and re-generate with sobol + # {} will not have a fallback model and will raise the exception + for fallback_specs in [{}, None]: + tiny_parameters = [ + FixedParameter( + name="x1", + parameter_type=ParameterType.FLOAT, + value=1.0, ), - ], - ) - for _ in range(2): - g = sobol.gen(exp) - exp.new_trial(generator_run=g).run() - - self.assertEqual(len(exp.arms_by_signature), 2) - - with self.assertRaisesRegex( - GenerationStrategyRepeatedPoints, "exceeded `MAX_GEN_DRAWS`" - ), mock.patch( - "ax.generation_strategy.generation_node.logger.info" - ) as mock_logger: - g = sobol.gen(exp) - self.assertEqual(mock_logger.call_count, 5) - self.assertIn( - "The generator run produced duplicate arms.", mock_logger.call_args[0][0] - ) + ChoiceParameter( + name="x2", + parameter_type=ParameterType.FLOAT, + values=[float(x) for x in range(2)], + ), + ] + tiny_search_space = SearchSpace( + parameters=cast(list[Parameter], tiny_parameters) + ) + exp = get_branin_experiment(search_space=tiny_search_space) + sobol = GenerationStrategy( + name="Sobol", + nodes=[ + GenerationNode( + node_name="sobol", + model_specs=[ + GeneratorSpec( + model_enum=Generators.SOBOL, + model_kwargs={"deduplicate": False}, + ) + ], + # Disable model-level deduplication. + should_deduplicate=True, + fallback_specs=fallback_specs, + ), + ], + ) + for _ in range(2): + g = sobol.gen(exp) + exp.new_trial(generator_run=g).run() + + self.assertEqual(len(exp.arms_by_signature), 2) + + if fallback_specs is not None: + with self.assertRaisesRegex( + GenerationStrategyRepeatedPoints, "exceeded `MAX_GEN_ATTEMPTS`" + ), mock.patch( + "ax.generation_strategy.generation_node.logger.info" + ) as mock_logger: + g = sobol.gen(exp) + else: + # generation with a fallback model + with self.assertLogs(GenerationNode.__module__, logging.WARNING) as cm: + g = sobol.gen(exp) + self.assertTrue( + any("gen failed with error" in msg for msg in cm.output) + ) + self.assertEqual(mock_logger.call_count, 5) + self.assertIn( + "The generator run produced duplicate arms.", + mock_logger.call_args[0][0], + ) def test_current_generator_run_limit(self) -> None: NUM_INIT_TRIALS = 5 @@ -904,7 +922,8 @@ def test_hierarchical_search_space(self) -> None: RandomAdapter, "gen" ): self.sobol_GS.gen(experiment=experiment) - mock_model_fit.assert_called_once() + # We should only fit once for each model + self.assertEqual(mock_model_fit.call_count, 1) observations = mock_model_fit.call_args[1].get("observations") all_parameter_names = assert_is_instance( experiment.search_space, HierarchicalSearchSpace @@ -952,9 +971,9 @@ def test_gen_multiple(self) -> None: # first four become trials. grs = sobol_MBM_gs._gen_multiple(experiment=exp, num_generator_runs=3) self.assertEqual(len(grs), 3) - # We should only fit once; refitting for each `gen` would be - # wasteful as there is no new data. - model_spec_fit_mock.assert_called_once() + # We should only fit once for each model + # refitting for each `gen` would be wasteful as there is no new data. + self.assertEqual(model_spec_fit_mock.call_count, 1) self.assertEqual(model_spec_gen_mock.call_count, 3) pending_in_each_gen = enumerate( args_and_kwargs.kwargs.get("pending_observations") diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 349367d1b33..40064e066c4 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -263,7 +263,7 @@ def get_client_with_simple_discrete_moo_problem( metrics = [-m for m in metrics] y0, y1, y2 = metrics raw_data = {"y0": (y0, 0.0), "y1": (y1, 0.0), "y2": (y2, 0.0)} - # pyre-fixme [6]: In call `AxClient.complete_trial`, for 2nd parameter + # pyre-fixme[6]: In call `AxClient.complete_trial`, for 2nd parameter # `raw_data` # expected `Union[Dict[str, Union[Tuple[Union[float, floating, integer], # Union[None, float, floating, integer]], float, floating, integer]], @@ -1778,7 +1778,7 @@ def test_trial_completion_with_metadata_with_iso_times(self) -> None: RandomAdapter, "_fit", autospec=True, side_effect=RandomAdapter._fit ) as mock_fit: ax_client.get_next_trial() - mock_fit.assert_called_once() + self.assertEqual(mock_fit.call_count, 1) features = mock_fit.call_args_list[0][1]["observations"][0].features # we're asserting it's actually created real Timestamp objects # for the observation features @@ -1800,7 +1800,7 @@ def test_trial_completion_with_metadata_millisecond_times(self) -> None: RandomAdapter, "_fit", autospec=True, side_effect=RandomAdapter._fit ) as mock_fit: ax_client.get_next_trial() - mock_fit.assert_called_once() + self.assertEqual(mock_fit.call_count, 1) features = mock_fit.call_args_list[0][1]["observations"][0].features # we're asserting it's actually created real Timestamp objects # for the observation features