From 8529f134b2bc23d3d7bc28b6c6657000e1fea1e3 Mon Sep 17 00:00:00 2001 From: Jerry Lin Date: Thu, 6 Feb 2025 22:14:08 -0800 Subject: [PATCH] Migrate GS usages to new module location (#3311) Summary: X-link: https://github.com/facebookresearch/aepsych/pull/623 Pull Request resolved: https://github.com/facebook/Ax/pull/3311 Removing the following old files from `ax/modelbridge/` in favor of the new `ax/generation_strategy` directory. ``` best_model_selector.py dispatch_utils.py external_generation_node.py generation_node_input_constructors.py generation_node.py generation_strategy.py model_spec.py transition_criterion.py ``` Reviewed By: saitcakmak Differential Revision: D68645075 --- .../healthcheck/constraints_feasibility.py | 2 +- .../tests/test_constraints_feasibility.py | 6 +- .../plotly/arm_effects/insample_effects.py | 2 +- .../plotly/arm_effects/predicted_effects.py | 2 +- ax/analysis/plotly/cross_validation.py | 2 +- ax/analysis/plotly/surface/contour.py | 2 +- ax/analysis/plotly/surface/slice.py | 2 +- .../plotly/tests/test_predicted_effects.py | 2 +- ax/benchmark/benchmark_method.py | 2 +- ax/benchmark/methods/modular_botorch.py | 4 +- ax/benchmark/methods/sobol.py | 5 +- ax/benchmark/tests/test_benchmark.py | 9 +- .../tests/test_aepsych_criterion.py | 10 +- .../tests/test_best_model_selector.py | 4 +- .../tests/test_dispatch_utils.py | 2 +- .../tests/test_external_generation_node.py | 4 +- .../tests/test_generation_node.py | 15 +- ...test_generation_node_input_constructors.py | 8 +- .../tests/test_generation_strategy.py | 27 +- .../tests/test_model_spec.py | 5 +- .../tests/test_transition_criterion.py | 8 +- .../tests/test_model_fit_metrics.py | 5 +- ax/preview/api/client.py | 2 +- ax/preview/modelbridge/dispatch_utils.py | 9 +- ...tils.py => test_preview_dispatch_utils.py} | 2 +- ax/runners/tests/test_torchx.py | 3 +- ax/service/ax_client.py | 4 +- ax/service/managed_loop.py | 4 +- ax/service/scheduler.py | 2 +- ax/service/tests/scheduler_test_utils.py | 
7 +- ax/service/tests/test_ax_client.py | 10 +- ax/service/tests/test_interactive_loop.py | 5 +- ax/service/tests/test_managed_loop.py | 5 +- ax/service/tests/test_report_utils.py | 4 +- ax/service/tests/test_scheduler.py | 7 +- .../tests/test_with_db_settings_base.py | 2 +- ax/service/utils/best_point.py | 2 +- ax/service/utils/best_point_mixin.py | 2 +- ax/service/utils/report_utils.py | 2 +- ax/service/utils/with_db_settings_base.py | 2 +- ax/storage/json_store/decoder.py | 12 +- ax/storage/json_store/encoders.py | 16 +- ax/storage/json_store/registry.py | 36 +- .../json_store/tests/test_json_store.py | 4 +- ax/storage/sqa_store/decoder.py | 2 +- ax/storage/sqa_store/delete.py | 2 +- ax/storage/sqa_store/encoder.py | 2 +- ax/storage/sqa_store/load.py | 2 +- ax/storage/sqa_store/save.py | 2 +- ax/storage/sqa_store/sqa_config.py | 2 +- ax/storage/sqa_store/tests/test_sqa_store.py | 2 +- ax/utils/testing/benchmark_stubs.py | 4 +- ax/utils/testing/core_stubs.py | 25 +- ax/utils/testing/modeling_stubs.py | 31 +- ax/utils/testing/tests/test_utils.py | 7 +- ax/utils/testing/utils.py | 2 +- docs/api.md | 2 +- docs/glossary.md | 2 +- sphinx/source/generation_strategy.rst | 2 +- tutorials/early_stopping/early_stopping.ipynb | 1358 ++++++++--------- .../external_generation_node.ipynb | 12 +- .../generation_strategy.ipynb | 10 +- tutorials/modular_botax/modular_botax.ipynb | 4 +- tutorials/scheduler/scheduler.ipynb | 6 +- tutorials/sebo/sebo.ipynb | 4 +- 65 files changed, 905 insertions(+), 848 deletions(-) rename ax/{modelbridge => generation_strategy}/tests/test_aepsych_criterion.py (96%) rename ax/{modelbridge => generation_strategy}/tests/test_best_model_selector.py (97%) rename ax/{modelbridge => generation_strategy}/tests/test_dispatch_utils.py (99%) rename ax/{modelbridge => generation_strategy}/tests/test_external_generation_node.py (96%) rename ax/{modelbridge => generation_strategy}/tests/test_generation_node.py (97%) rename ax/{modelbridge => 
generation_strategy}/tests/test_generation_node_input_constructors.py (98%) rename ax/{modelbridge => generation_strategy}/tests/test_generation_strategy.py (99%) rename ax/{modelbridge => generation_strategy}/tests/test_model_spec.py (98%) rename ax/{modelbridge => generation_strategy}/tests/test_transition_criterion.py (99%) rename ax/preview/modelbridge/tests/{test_dispatch_utils.py => test_preview_dispatch_utils.py} (99%) diff --git a/ax/analysis/healthcheck/constraints_feasibility.py b/ax/analysis/healthcheck/constraints_feasibility.py index 61b050ecd1f..552a0843aaa 100644 --- a/ax/analysis/healthcheck/constraints_feasibility.py +++ b/ax/analysis/healthcheck/constraints_feasibility.py @@ -23,8 +23,8 @@ from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.optimization_config import OptimizationConfig from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.transforms.derelativize import Derelativize from pyre_extensions import assert_is_instance, none_throws diff --git a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py index a6e95609685..92067cca5ac 100644 --- a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py +++ b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py @@ -23,10 +23,10 @@ from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_node import GenerationNode +from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.generation_strategy.model_spec import GeneratorSpec from ax.modelbridge.factory import get_sobol -from ax.modelbridge.generation_node import GenerationNode -from 
ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( diff --git a/ax/analysis/plotly/arm_effects/insample_effects.py b/ax/analysis/plotly/arm_effects/insample_effects.py index b556ec6a22d..666b6ba9a78 100644 --- a/ax/analysis/plotly/arm_effects/insample_effects.py +++ b/ax/analysis/plotly/arm_effects/insample_effects.py @@ -22,8 +22,8 @@ from ax.core.generator_run import GeneratorRun from ax.core.outcome_constraint import OutcomeConstraint from ax.exceptions.core import DataRequiredError, UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Generators from ax.modelbridge.transforms.derelativize import Derelativize from ax.utils.common.logger import get_logger diff --git a/ax/analysis/plotly/arm_effects/predicted_effects.py b/ax/analysis/plotly/arm_effects/predicted_effects.py index 75155257dc6..4e04e4e32e8 100644 --- a/ax/analysis/plotly/arm_effects/predicted_effects.py +++ b/ax/analysis/plotly/arm_effects/predicted_effects.py @@ -23,8 +23,8 @@ from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.transforms.derelativize import Derelativize from pyre_extensions import assert_is_instance, none_throws diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index 08a7cfdf5f8..f210dd1f438 100644 --- a/ax/analysis/plotly/cross_validation.py +++ 
b/ax/analysis/plotly/cross_validation.py @@ -14,8 +14,8 @@ from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.cross_validation import cross_validate -from ax.modelbridge.generation_strategy import GenerationStrategy from plotly import express as px, graph_objects as go from pyre_extensions import assert_is_instance, none_throws diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py index 3ce17133302..47139c0ac56 100644 --- a/ax/analysis/plotly/surface/contour.py +++ b/ax/analysis/plotly/surface/contour.py @@ -22,8 +22,8 @@ from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter -from ax.modelbridge.generation_strategy import GenerationStrategy from plotly import graph_objects as go from pyre_extensions import none_throws diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py index 282c19c7c26..b031b022a94 100644 --- a/ax/analysis/plotly/surface/slice.py +++ b/ax/analysis/plotly/surface/slice.py @@ -22,8 +22,8 @@ from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter -from ax.modelbridge.generation_strategy import GenerationStrategy from plotly import express as px, graph_objects as go from pyre_extensions import none_throws diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py 
b/ax/analysis/plotly/tests/test_predicted_effects.py index 862274a3003..da402240e67 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -15,7 +15,7 @@ from ax.core.observation import ObservationFeatures from ax.core.trial import Trial from ax.exceptions.core import UserInputError -from ax.modelbridge.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.dispatch_utils import choose_generation_strategy from ax.modelbridge.prediction_utils import predict_at_point from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase diff --git a/ax/benchmark/benchmark_method.py b/ax/benchmark/benchmark_method.py index 767a49322e0..a00993b6b02 100644 --- a/ax/benchmark/benchmark_method.py +++ b/ax/benchmark/benchmark_method.py @@ -15,7 +15,7 @@ from ax.core.types import TParameterization from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.service.utils.best_point_mixin import BestPointMixin from ax.utils.common.base import Base from pyre_extensions import none_throws diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index 693927fcee5..1e21665b27a 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -8,8 +8,8 @@ from typing import Any from ax.benchmark.benchmark_method import BenchmarkMethod -from ax.modelbridge.generation_node import GenerationStep -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_node import GenerationStep +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from 
botorch.acquisition.acquisition import AcquisitionFunction diff --git a/ax/benchmark/methods/sobol.py b/ax/benchmark/methods/sobol.py index 666cccceecd..05a1d8fa66d 100644 --- a/ax/benchmark/methods/sobol.py +++ b/ax/benchmark/methods/sobol.py @@ -7,7 +7,10 @@ from ax.benchmark.benchmark_method import BenchmarkMethod -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) from ax.modelbridge.registry import Generators diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 4f744976ea5..8f55c5bc1a8 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -47,9 +47,12 @@ from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy -from ax.modelbridge.external_generation_node import ExternalGenerationNode -from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec +from ax.generation_strategy.external_generation_node import ExternalGenerationNode +from ax.generation_strategy.generation_strategy import ( + GenerationNode, + GenerationStrategy, +) +from ax.generation_strategy.model_spec import GeneratorSpec from ax.modelbridge.registry import Generators from ax.service.utils.scheduler_options import TrialType from ax.storage.json_store.load import load_experiment diff --git a/ax/modelbridge/tests/test_aepsych_criterion.py b/ax/generation_strategy/tests/test_aepsych_criterion.py similarity index 96% rename from ax/modelbridge/tests/test_aepsych_criterion.py rename to ax/generation_strategy/tests/test_aepsych_criterion.py index 0d7e9f584c9..6d488fa3615 100644 --- a/ax/modelbridge/tests/test_aepsych_criterion.py +++ 
b/ax/generation_strategy/tests/test_aepsych_criterion.py @@ -10,9 +10,15 @@ import pandas as pd from ax.core.base_trial import TrialStatus from ax.core.data import Data -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) +from ax.generation_strategy.transition_criterion import ( + MinimumPreferenceOccurances, + MinTrials, +) from ax.modelbridge.registry import Generators -from ax.modelbridge.transition_criterion import MinimumPreferenceOccurances, MinTrials from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_experiment diff --git a/ax/modelbridge/tests/test_best_model_selector.py b/ax/generation_strategy/tests/test_best_model_selector.py similarity index 97% rename from ax/modelbridge/tests/test_best_model_selector.py rename to ax/generation_strategy/tests/test_best_model_selector.py index 6247058a7b3..30aee768195 100644 --- a/ax/modelbridge/tests/test_best_model_selector.py +++ b/ax/generation_strategy/tests/test_best_model_selector.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from ax.exceptions.core import UserInputError -from ax.modelbridge.best_model_selector import ( +from ax.generation_strategy.best_model_selector import ( ReductionCriterion, SingleDiagnosticBestModelSelector, ) -from ax.modelbridge.model_spec import GeneratorSpec +from ax.generation_strategy.model_spec import GeneratorSpec from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase diff --git a/ax/modelbridge/tests/test_dispatch_utils.py b/ax/generation_strategy/tests/test_dispatch_utils.py similarity index 99% rename from ax/modelbridge/tests/test_dispatch_utils.py rename to ax/generation_strategy/tests/test_dispatch_utils.py index 4792ea96004..b77d12f3db9 100644 --- a/ax/modelbridge/tests/test_dispatch_utils.py +++ b/ax/generation_strategy/tests/test_dispatch_utils.py @@ 
-13,7 +13,7 @@ import torch from ax.core.objective import MultiObjective from ax.core.optimization_config import MultiObjectiveOptimizationConfig -from ax.modelbridge.dispatch_utils import ( +from ax.generation_strategy.dispatch_utils import ( _make_botorch_step, calculate_num_initialization_trials, choose_generation_strategy, diff --git a/ax/modelbridge/tests/test_external_generation_node.py b/ax/generation_strategy/tests/test_external_generation_node.py similarity index 96% rename from ax/modelbridge/tests/test_external_generation_node.py rename to ax/generation_strategy/tests/test_external_generation_node.py index caadf25475e..7b188d57073 100644 --- a/ax/modelbridge/tests/test_external_generation_node.py +++ b/ax/generation_strategy/tests/test_external_generation_node.py @@ -14,8 +14,8 @@ from ax.core.observation import ObservationFeatures from ax.core.types import TParameterization from ax.exceptions.core import UnsupportedError -from ax.modelbridge.external_generation_node import ExternalGenerationNode -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.external_generation_node import ExternalGenerationNode +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.random import RandomAdapter from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py similarity index 97% rename from ax/modelbridge/tests/test_generation_node.py rename to ax/generation_strategy/tests/test_generation_node.py index ccbb702ba70..3be75b01ffb 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -14,23 +14,26 @@ from ax.core.base_trial import TrialStatus from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError -from ax.modelbridge.best_model_selector import ( 
+from ax.generation_strategy.best_model_selector import ( ReductionCriterion, SingleDiagnosticBestModelSelector, ) -from ax.modelbridge.factory import get_sobol -from ax.modelbridge.generation_node import ( +from ax.generation_strategy.generation_node import ( GenerationNode, GenerationStep, MISSING_MODEL_SELECTOR_MESSAGE, ) -from ax.modelbridge.generation_node_input_constructors import ( +from ax.generation_strategy.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) -from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec +from ax.generation_strategy.model_spec import ( + FactoryFunctionGeneratorSpec, + GeneratorSpec, +) +from ax.generation_strategy.transition_criterion import MinTrials +from ax.modelbridge.factory import get_sobol from ax.modelbridge.registry import Generators -from ax.modelbridge.transition_criterion import MinTrials from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/generation_strategy/tests/test_generation_node_input_constructors.py similarity index 98% rename from ax/modelbridge/tests/test_generation_node_input_constructors.py rename to ax/generation_strategy/tests/test_generation_node_input_constructors.py index de6164870f2..1a55c284e17 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/generation_strategy/tests/test_generation_node_input_constructors.py @@ -16,13 +16,13 @@ from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures from ax.exceptions.generation_strategy import AxGenerationException -from ax.modelbridge.generation_node import GenerationNode -from ax.modelbridge.generation_node_input_constructors import ( +from ax.generation_strategy.generation_node import GenerationNode +from 
ax.generation_strategy.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) -from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec +from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.generation_strategy.model_spec import GeneratorSpec from ax.modelbridge.registry import Generators from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/generation_strategy/tests/test_generation_strategy.py similarity index 99% rename from ax/modelbridge/tests/test_generation_strategy.py rename to ax/generation_strategy/tests/test_generation_strategy.py index a61199028ad..3ef13f543f0 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/generation_strategy/tests/test_generation_strategy.py @@ -29,16 +29,24 @@ GenerationStrategyRepeatedPoints, MaxParallelismReachedException, ) -from ax.modelbridge.best_model_selector import SingleDiagnosticBestModelSelector -from ax.modelbridge.discrete import DiscreteAdapter -from ax.modelbridge.factory import get_sobol -from ax.modelbridge.generation_node import GenerationNode -from ax.modelbridge.generation_node_input_constructors import ( +from ax.generation_strategy.best_model_selector import SingleDiagnosticBestModelSelector +from ax.generation_strategy.generation_node import GenerationNode +from ax.generation_strategy.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import ( + AutoTransitionAfterGen, + 
MaxGenerationParallelism, + MinTrials, +) +from ax.modelbridge.discrete import DiscreteAdapter +from ax.modelbridge.factory import get_sobol from ax.modelbridge.random import RandomAdapter from ax.modelbridge.registry import ( _extract_model_state_after_gen, @@ -48,11 +56,6 @@ MODEL_KEY_TO_MODEL_SETUP, ) from ax.modelbridge.torch import TorchAdapter -from ax.modelbridge.transition_criterion import ( - AutoTransitionAfterGen, - MaxGenerationParallelism, - MinTrials, -) from ax.models.random.sobol import SobolGenerator from ax.utils.common.constants import Keys from ax.utils.common.equality import same_elements diff --git a/ax/modelbridge/tests/test_model_spec.py b/ax/generation_strategy/tests/test_model_spec.py similarity index 98% rename from ax/modelbridge/tests/test_model_spec.py rename to ax/generation_strategy/tests/test_model_spec.py index 15911344919..8278aaced60 100644 --- a/ax/modelbridge/tests/test_model_spec.py +++ b/ax/generation_strategy/tests/test_model_spec.py @@ -12,8 +12,11 @@ from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError +from ax.generation_strategy.model_spec import ( + FactoryFunctionGeneratorSpec, + GeneratorSpec, +) from ax.modelbridge.factory import get_sobol -from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec from ax.modelbridge.modelbridge_utils import extract_search_space_digest from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/generation_strategy/tests/test_transition_criterion.py similarity index 99% rename from ax/modelbridge/tests/test_transition_criterion.py rename to ax/generation_strategy/tests/test_transition_criterion.py index 92b0709143e..a62a7fb9d03 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/generation_strategy/tests/test_transition_criterion.py @@ -15,14 +15,13 @@ from ax.core.base_trial import TrialStatus 
from ax.core.data import Data from ax.exceptions.core import UserInputError -from ax.modelbridge.generation_strategy import ( +from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, GenerationStrategy, ) -from ax.modelbridge.model_spec import GeneratorSpec -from ax.modelbridge.registry import Generators -from ax.modelbridge.transition_criterion import ( +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, AuxiliaryExperimentCheck, IsSingleObjective, @@ -32,6 +31,7 @@ MinimumTrialsInStatus, MinTrials, ) +from ax.modelbridge.registry import Generators from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( diff --git a/ax/modelbridge/tests/test_model_fit_metrics.py b/ax/modelbridge/tests/test_model_fit_metrics.py index c425bf925cc..659fd28e183 100644 --- a/ax/modelbridge/tests/test_model_fit_metrics.py +++ b/ax/modelbridge/tests/test_model_fit_metrics.py @@ -14,6 +14,10 @@ from ax.core.experiment import Experiment from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) from ax.metrics.branin import BraninMetric from ax.modelbridge.cross_validation import ( _predict_on_cross_validation_data, @@ -21,7 +25,6 @@ compute_model_fit_metrics_from_modelbridge, get_fit_and_std_quality_and_generalization_dict, ) -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import Generators from ax.runners.synthetic import SyntheticRunner from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions diff --git a/ax/preview/api/client.py b/ax/preview/api/client.py index c53161377cd..f489a3664dc 100644 --- a/ax/preview/api/client.py +++ b/ax/preview/api/client.py @@ 
-34,7 +34,7 @@ PercentileEarlyStoppingStrategy, ) from ax.exceptions.core import ObjectNotFoundError, UnsupportedError -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.preview.api.configs import ( ExperimentConfig, GenerationStrategyConfig, diff --git a/ax/preview/modelbridge/dispatch_utils.py b/ax/preview/modelbridge/dispatch_utils.py index 47e743cd3ed..d7853290d64 100644 --- a/ax/preview/modelbridge/dispatch_utils.py +++ b/ax/preview/modelbridge/dispatch_utils.py @@ -9,10 +9,13 @@ import torch from ax.core.base_trial import TrialStatus from ax.exceptions.core import UnsupportedError -from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec +from ax.generation_strategy.generation_strategy import ( + GenerationNode, + GenerationStrategy, +) +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import MinTrials from ax.modelbridge.registry import Generators -from ax.modelbridge.transition_criterion import MinTrials from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig from botorch.models.transforms.input import Normalize, Warp diff --git a/ax/preview/modelbridge/tests/test_dispatch_utils.py b/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py similarity index 99% rename from ax/preview/modelbridge/tests/test_dispatch_utils.py rename to ax/preview/modelbridge/tests/test_preview_dispatch_utils.py index fb13853d0a3..a655632580f 100644 --- a/ax/preview/modelbridge/tests/test_dispatch_utils.py +++ b/ax/preview/modelbridge/tests/test_preview_dispatch_utils.py @@ -8,8 +8,8 @@ import torch from ax.core.base_trial import TrialStatus from ax.core.trial import Trial +from ax.generation_strategy.transition_criterion import MinTrials 
from ax.modelbridge.registry import Generators -from ax.modelbridge.transition_criterion import MinTrials from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig from ax.preview.modelbridge.dispatch_utils import choose_generation_strategy diff --git a/ax/runners/tests/test_torchx.py b/ax/runners/tests/test_torchx.py index 9c38df614b1..00e8642e67b 100644 --- a/ax/runners/tests/test_torchx.py +++ b/ax/runners/tests/test_torchx.py @@ -20,8 +20,9 @@ RangeParameter, SearchSpace, ) + +from ax.generation_strategy.dispatch_utils import choose_generation_strategy from ax.metrics.torchx import TorchXMetric -from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.runners.torchx import TorchXRunner from ax.service.scheduler import FailureRateExceededError, Scheduler, SchedulerOptions from ax.utils.common.constants import Keys diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 4821dd67cb1..76d92b5c221 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -56,10 +56,10 @@ UserInputError, ) from ax.exceptions.generation_strategy import MaxParallelismReachedException +from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy from ax.global_stopping.strategies.improvement import constraint_satisfaction -from ax.modelbridge.dispatch_utils import choose_generation_strategy -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.prediction_utils import predict_by_features from ax.plot.base import AxPlotConfig from ax.plot.contour import plot_contour diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py index ee1f9733fe4..379f91c74ab 100644 --- a/ax/service/managed_loop.py +++ b/ax/service/managed_loop.py @@ -27,9 +27,9 
@@ from ax.core.utils import get_pending_observation_features from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION from ax.exceptions.core import SearchSpaceExhausted, UserInputError +from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter -from ax.modelbridge.dispatch_utils import choose_generation_strategy -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Generators from ax.service.utils.best_point import ( get_best_parameters_from_model_predictions_with_trial_index, diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index df1fe924441..7e81c06ca46 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -48,8 +48,8 @@ MaxParallelismReachedException, OptimizationConfigRequired, ) +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.modelbridge_utils import get_fixed_features_from_experiment from ax.service.utils.analysis_base import AnalysisBase from ax.service.utils.best_point_mixin import BestPointMixin diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 2c758413ac1..becba865fa3 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -44,11 +44,14 @@ UserInputError, ) from ax.exceptions.generation_strategy import AxGenerationException +from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) from ax.metrics.branin import BraninMetric from ax.metrics.branin_map import BraninTimestampMapMetric from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge 
-from ax.modelbridge.dispatch_utils import choose_generation_strategy -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import Generators, MBM_MTGP_trans from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin from ax.runners.synthetic import SyntheticRunner diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 2d217a5c096..349367d1b33 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -48,14 +48,14 @@ UserInputError, ) from ax.exceptions.generation_strategy import MaxParallelismReachedException -from ax.metrics.branin import branin, BraninMetric -from ax.modelbridge.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM -from ax.modelbridge.generation_strategy import ( +from ax.generation_strategy.dispatch_utils import DEFAULT_BAYESIAN_PARALLELISM +from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, GenerationStrategy, ) -from ax.modelbridge.model_spec import GeneratorSpec +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.metrics.branin import branin, BraninMetric from ax.modelbridge.random import RandomAdapter from ax.modelbridge.registry import Cont_X_trans, Generators from ax.runners.synthetic import SyntheticRunner @@ -3048,7 +3048,7 @@ def test_SingleTaskGP_log_unordered_categorical_parameters(self) -> None: ] with mock.patch( - "ax.modelbridge.dispatch_utils.logger.info", + "ax.generation_strategy.dispatch_utils.logger.info", side_effect=(lambda log: logs.append(log)), ): ax_client.create_experiment( diff --git a/ax/service/tests/test_interactive_loop.py b/ax/service/tests/test_interactive_loop.py index 8a734ebf8f5..5ef08d05e3f 100644 --- a/ax/service/tests/test_interactive_loop.py +++ b/ax/service/tests/test_interactive_loop.py @@ -16,7 +16,10 @@ import numpy as np from ax.core.types import TEvaluationOutcome, TParameterization -from 
ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) from ax.modelbridge.registry import Generators from ax.service.ax_client import AxClient from ax.service.interactive_loop import ( diff --git a/ax/service/tests/test_managed_loop.py b/ax/service/tests/test_managed_loop.py index 7a0cc14d82f..73a4d06d834 100644 --- a/ax/service/tests/test_managed_loop.py +++ b/ax/service/tests/test_managed_loop.py @@ -11,8 +11,11 @@ import numpy as np import numpy.typing as npt from ax.exceptions.core import UserInputError +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) from ax.metrics.branin import branin -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import Generators from ax.service.managed_loop import OptimizationLoop, optimize from ax.utils.common.testutils import TestCase diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index 6245fd6ff7b..b6a4d9f825b 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -22,8 +22,8 @@ ) from ax.core.outcome_constraint import ObjectiveThreshold from ax.core.types import ComparisonOp -from ax.modelbridge.generation_node import GenerationStep -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_node import GenerationStep +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Generators from ax.service.scheduler import Scheduler from ax.service.utils.report_utils import ( diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index 95cbd3cb916..c7d3941889c 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -9,9 +9,12 @@ from 
ax.core.multi_type_experiment import MultiTypeExperiment from ax.core.objective import Objective from ax.core.optimization_config import OptimizationConfig +from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) from ax.metrics.branin import BraninMetric -from ax.modelbridge.dispatch_utils import choose_generation_strategy -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import Generators from ax.service.tests.scheduler_test_utils import ( AxSchedulerTestCase, diff --git a/ax/service/tests/test_with_db_settings_base.py b/ax/service/tests/test_with_db_settings_base.py index 6a6b48f21e3..4fa00749fbf 100644 --- a/ax/service/tests/test_with_db_settings_base.py +++ b/ax/service/tests/test_with_db_settings_base.py @@ -12,7 +12,7 @@ from ax.core.base_trial import TrialStatus from ax.core.experiment import Experiment -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.service.utils.with_db_settings_base import ( try_load_generation_strategy, WithDBSettingsBase, diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index d50aa1fbfc3..bbd1a5e8d66 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -30,13 +30,13 @@ from ax.core.trial import Trial from ax.core.types import ComparisonOp, TModelPredictArm, TParameterization from ax.exceptions.core import UnsupportedError, UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.base import Adapter from ax.modelbridge.cross_validation import ( assess_model_fit, compute_diagnostics, cross_validate, ) -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.modelbridge_utils import ( observed_pareto_frontier as 
observed_pareto, predicted_pareto_frontier as predicted_pareto, diff --git a/ax/service/utils/best_point_mixin.py b/ax/service/utils/best_point_mixin.py index 4090e026e8d..4d72cfe8df6 100644 --- a/ax/service/utils/best_point_mixin.py +++ b/ax/service/utils/best_point_mixin.py @@ -23,7 +23,7 @@ from ax.core.trial import Trial from ax.core.types import TModelPredictArm, TParameterization from ax.exceptions.core import UserInputError -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.modelbridge_utils import ( extract_objective_thresholds, extract_objective_weights, diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index a9e7cb6b47c..cb1730149c9 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -35,12 +35,12 @@ from ax.core.trial import BaseTrial from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy from ax.exceptions.core import DataRequiredError, UserInputError +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge import Adapter from ax.modelbridge.cross_validation import ( compute_model_fit_metrics_from_modelbridge, cross_validate, ) -from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.random import RandomAdapter from ax.modelbridge.torch import TorchAdapter from ax.plot.contour import interact_contour_plotly diff --git a/ax/service/utils/with_db_settings_base.py b/ax/service/utils/with_db_settings_base.py index fa8df73c864..31010da5cb6 100644 --- a/ax/service/utils/with_db_settings_base.py +++ b/ax/service/utils/with_db_settings_base.py @@ -25,7 +25,7 @@ ObjectNotFoundError, UnsupportedError, ) -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.utils.common.executils import retry_on_exception from 
ax.utils.common.logger import _round_floats_for_logging, get_logger from pyre_extensions import none_throws diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index fa5d1e69471..cb0e435ed94 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -33,19 +33,21 @@ ) from ax.core.search_space import SearchSpace from ax.exceptions.storage import JSONDecodeError, STORAGE_DOCS_SUFFIX -from ax.modelbridge.generation_node_input_constructors import InputConstructorPurpose -from ax.modelbridge.generation_strategy import ( +from ax.generation_strategy.generation_node_input_constructors import ( + InputConstructorPurpose, +) +from ax.generation_strategy.generation_strategy import ( GenerationNode, GenerationStep, GenerationStrategy, ) -from ax.modelbridge.model_spec import GeneratorSpec -from ax.modelbridge.registry import _decode_callables_from_references, ModelRegistryBase -from ax.modelbridge.transition_criterion import ( +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import ( AuxiliaryExperimentCheck, TransitionCriterion, TrialBasedCriterion, ) +from ax.modelbridge.registry import _decode_callables_from_references, ModelRegistryBase from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ModelConfig from ax.storage.json_store.decoders import ( diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 7279021cc4c..03448f72431 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -43,14 +43,20 @@ ) from ax.exceptions.core import AxStorageWarning from ax.exceptions.storage import JSONEncodeError, STORAGE_DOCS_SUFFIX +from ax.generation_strategy.best_model_selector import BestModelSelector +from ax.generation_strategy.generation_node import GenerationNode +from ax.generation_strategy.generation_strategy import ( + 
GenerationStep, + GenerationStrategy, +) +from ax.generation_strategy.model_spec import ( + FactoryFunctionGeneratorSpec, + GeneratorSpec, +) +from ax.generation_strategy.transition_criterion import TransitionCriterion from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy -from ax.modelbridge.best_model_selector import BestModelSelector -from ax.modelbridge.generation_node import GenerationNode -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec from ax.modelbridge.registry import _encode_callables_as_references from ax.modelbridge.transforms.base import Transform -from ax.modelbridge.transition_criterion import TransitionCriterion from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.winsorization_config import WinsorizationConfig diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index 149a3712dbb..7f05d636e11 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -67,30 +67,18 @@ AndEarlyStoppingStrategy, OrEarlyStoppingStrategy, ) -from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy -from ax.metrics.branin import BraninMetric, NegativeBraninMetric -from ax.metrics.branin_map import BraninTimestampMapMetric -from ax.metrics.chemistry import ChemistryMetric, ChemistryProblemType -from ax.metrics.factorial import FactorialMetric -from ax.metrics.hartmann6 import Hartmann6Metric -from ax.metrics.l2norm import L2NormMetric -from ax.metrics.noisy_function import NoisyFunctionMetric -from ax.metrics.sklearn import SklearnDataset, SklearnMetric, SklearnModelType -from ax.modelbridge.best_model_selector import ( +from ax.generation_strategy.best_model_selector import ( ReductionCriterion, SingleDiagnosticBestModelSelector, ) -from 
ax.modelbridge.factory import Generators -from ax.modelbridge.generation_node import GenerationNode, GenerationStep -from ax.modelbridge.generation_node_input_constructors import ( +from ax.generation_strategy.generation_node import GenerationNode, GenerationStep +from ax.generation_strategy.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) -from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec -from ax.modelbridge.registry import ModelRegistryBase -from ax.modelbridge.transforms.base import Transform -from ax.modelbridge.transition_criterion import ( +from ax.generation_strategy.generation_strategy import GenerationStrategy +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, AuxiliaryExperimentCheck, IsSingleObjective, @@ -101,6 +89,18 @@ MinTrials, TransitionCriterion, ) +from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy +from ax.metrics.branin import BraninMetric, NegativeBraninMetric +from ax.metrics.branin_map import BraninTimestampMapMetric +from ax.metrics.chemistry import ChemistryMetric, ChemistryProblemType +from ax.metrics.factorial import FactorialMetric +from ax.metrics.hartmann6 import Hartmann6Metric +from ax.metrics.l2norm import L2NormMetric +from ax.metrics.noisy_function import NoisyFunctionMetric +from ax.metrics.sklearn import SklearnDataset, SklearnMetric, SklearnModelType +from ax.modelbridge.factory import Generators +from ax.modelbridge.registry import ModelRegistryBase +from ax.modelbridge.transforms.base import Transform from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec diff --git a/ax/storage/json_store/tests/test_json_store.py 
b/ax/storage/json_store/tests/test_json_store.py index 295ff949365..9eda3b31617 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -19,8 +19,8 @@ from ax.core.runner import Runner from ax.exceptions.core import AxStorageWarning from ax.exceptions.storage import JSONDecodeError, JSONEncodeError -from ax.modelbridge.generation_node import GenerationStep -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_node import GenerationStep +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.surrogate import SurrogateSpec diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index a6e1a24d24f..bad3318389c 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -47,7 +47,7 @@ from ax.core.search_space import HierarchicalSearchSpace, RobustSearchSpace, SearchSpace from ax.core.trial import Trial from ax.exceptions.storage import JSONDecodeError, SQADecodeError -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.json_store.decoder import object_from_json from ax.storage.sqa_store.db import session_scope from ax.storage.sqa_store.sqa_classes import ( diff --git a/ax/storage/sqa_store/delete.py b/ax/storage/sqa_store/delete.py index 8d719f4f825..a68d05800dc 100644 --- a/ax/storage/sqa_store/delete.py +++ b/ax/storage/sqa_store/delete.py @@ -8,7 +8,7 @@ from logging import Logger from ax.core.experiment import Experiment -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.sqa_store.db import session_scope from 
ax.storage.sqa_store.decoder import Decoder from ax.storage.sqa_store.sqa_classes import SQAExperiment diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 76372fa5e92..aab1a486e26 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -45,7 +45,7 @@ from ax.core.search_space import RobustSearchSpace, SearchSpace from ax.core.trial import Trial from ax.exceptions.storage import SQAEncodeError -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.json_store.encoder import object_to_json from ax.storage.sqa_store.sqa_classes import ( SQAAbandonedArm, diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index 58ebdcce10a..0ea39e8787a 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -16,7 +16,7 @@ from ax.core.metric import Metric from ax.core.trial import Trial from ax.exceptions.core import ExperimentNotFoundError, ObjectNotFoundError -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.sqa_store.db import session_scope from ax.storage.sqa_store.decoder import Decoder from ax.storage.sqa_store.reduced_state import ( diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index a60bb666cf5..05b3deb9a4b 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -25,7 +25,7 @@ from ax.core.trial import Trial from ax.exceptions.core import UserInputError from ax.exceptions.storage import SQADecodeError -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.sqa_store.db import session_scope, SQABase from ax.storage.sqa_store.decoder import Decoder from ax.storage.sqa_store.encoder import Encoder diff --git 
a/ax/storage/sqa_store/sqa_config.py b/ax/storage/sqa_store/sqa_config.py index 3e615151fbf..9b7eb640762 100644 --- a/ax/storage/sqa_store/sqa_config.py +++ b/ax/storage/sqa_store/sqa_config.py @@ -24,7 +24,7 @@ from ax.core.parameter_constraint import ParameterConstraint from ax.core.runner import Runner from ax.core.trial import Trial -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.storage.json_store.registry import ( CORE_CLASS_DECODER_REGISTRY, CORE_CLASS_ENCODER_REGISTRY, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index d98bf51c1e2..84984875d13 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -33,8 +33,8 @@ from ax.core.types import ComparisonOp from ax.exceptions.core import ObjectNotFoundError from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError +from ax.generation_strategy.dispatch_utils import choose_generation_strategy from ax.metrics.branin import BraninMetric -from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.runners.synthetic import SyntheticRunner diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index faf5a0053ae..fdd52a05c1f 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -38,8 +38,8 @@ from ax.core.trial import Trial from ax.core.types import TParameterization, TParamValue from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy -from ax.modelbridge.external_generation_node import ExternalGenerationNode -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.external_generation_node import ExternalGenerationNode 
+from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.modelbridge.torch import TorchAdapter from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index c45001754b7..4b07d949310 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -78,24 +78,27 @@ OrEarlyStoppingStrategy, ) from ax.exceptions.core import UserInputError -from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy -from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy -from ax.metrics.branin import BraninMetric -from ax.metrics.branin_map import BraninTimestampMapMetric -from ax.metrics.factorial import FactorialMetric -from ax.metrics.hartmann6 import Hartmann6Metric -from ax.modelbridge.factory import Cont_X_trans, Generators, get_factorial, get_sobol -from ax.modelbridge.generation_node_input_constructors import ( +from ax.generation_strategy.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) -from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec -from ax.modelbridge.transition_criterion import ( +from ax.generation_strategy.generation_strategy import ( + GenerationNode, + GenerationStrategy, +) +from ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import ( MaxGenerationParallelism, MinTrials, TrialBasedCriterion, ) +from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy +from ax.global_stopping.strategies.improvement import ImprovementGlobalStoppingStrategy +from ax.metrics.branin import BraninMetric +from ax.metrics.branin_map import BraninTimestampMapMetric +from ax.metrics.factorial import FactorialMetric +from ax.metrics.hartmann6 import 
Hartmann6Metric +from ax.modelbridge.factory import Cont_X_trans, Generators, get_factorial, get_sobol from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.model import BoTorchGenerator diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index d117afa0265..5eb93e38c66 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -17,33 +17,36 @@ from ax.core.parameter import FixedParameter, RangeParameter from ax.core.search_space import SearchSpace from ax.exceptions.core import UserInputError -from ax.modelbridge.base import Adapter -from ax.modelbridge.best_model_selector import ( +from ax.generation_strategy.best_model_selector import ( ReductionCriterion, SingleDiagnosticBestModelSelector, ) -from ax.modelbridge.cross_validation import FISHER_EXACT_TEST_P -from ax.modelbridge.dispatch_utils import choose_generation_strategy -from ax.modelbridge.factory import get_sobol -from ax.modelbridge.generation_node import GenerationNode +from ax.generation_strategy.dispatch_utils import choose_generation_strategy +from ax.generation_strategy.generation_node import GenerationNode -from ax.modelbridge.generation_node_input_constructors import ( +from ax.generation_strategy.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) -from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec -from ax.modelbridge.registry import Generators -from ax.modelbridge.transforms.base import Transform -from ax.modelbridge.transforms.int_to_float import IntToFloat -from ax.modelbridge.transforms.transform_to_new_sq import TransformToNewSQ -from ax.modelbridge.transition_criterion import ( +from ax.generation_strategy.generation_strategy import ( + GenerationStep, + GenerationStrategy, +) +from 
ax.generation_strategy.model_spec import GeneratorSpec +from ax.generation_strategy.transition_criterion import ( AutoTransitionAfterGen, IsSingleObjective, MaxGenerationParallelism, MinimumPreferenceOccurances, MinTrials, ) +from ax.modelbridge.base import Adapter +from ax.modelbridge.cross_validation import FISHER_EXACT_TEST_P +from ax.modelbridge.factory import get_sobol +from ax.modelbridge.registry import Generators +from ax.modelbridge.transforms.base import Transform +from ax.modelbridge.transforms.int_to_float import IntToFloat +from ax.modelbridge.transforms.transform_to_new_sq import TransformToNewSQ from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger diff --git a/ax/utils/testing/tests/test_utils.py b/ax/utils/testing/tests/test_utils.py index edf936545ac..a47530d22da 100644 --- a/ax/utils/testing/tests/test_utils.py +++ b/ax/utils/testing/tests/test_utils.py @@ -8,8 +8,11 @@ import numpy as np import torch -from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import GeneratorSpec +from ax.generation_strategy.generation_strategy import ( + GenerationNode, + GenerationStrategy, +) +from ax.generation_strategy.model_spec import GeneratorSpec from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_experiment_with_observations diff --git a/ax/utils/testing/utils.py b/ax/utils/testing/utils.py index 49f9be7711e..fb8840ac337 100644 --- a/ax/utils/testing/utils.py +++ b/ax/utils/testing/utils.py @@ -14,7 +14,7 @@ from ax.core.data import Data from ax.core.experiment import Experiment from ax.exceptions.core import UnsupportedError -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from pyre_extensions import none_throws from torch 
import Tensor diff --git a/docs/api.md b/docs/api.md index d5e3a79e34c..bb93f02a08c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -184,7 +184,7 @@ best_parameters = best_arm.parameters ```py from ax import * -from ax.modelbridge.generation_strategy import GenerationStrategy +from ax.generation_strategy.generation_strategy import GenerationStrategy from ax.service import Scheduler # Full `Experiment` and `GenerationStrategy` instantiation diff --git a/docs/glossary.md b/docs/glossary.md index 39a451021a7..ba08d22e312 100644 --- a/docs/glossary.md +++ b/docs/glossary.md @@ -29,7 +29,7 @@ Object that keeps track of the whole optimization process. Contains a [search sp ### Generation strategy -Abstraction that allows to declaratively specify one or multiple models to use in the course of the optimization and automate transition between them (relevant [tutorial](/docs/tutorials/scheduler)). [`[GenerationStrategy]`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.generation_strategy) +Abstraction that allows to declaratively specify one or multiple models to use in the course of the optimization and automate transition between them (relevant [tutorial](/docs/tutorials/scheduler)). [`[GenerationStrategy]`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.generation_strategy.generation_strategy) ### Generator run diff --git a/sphinx/source/generation_strategy.rst b/sphinx/source/generation_strategy.rst index 096b10cdb85..8e61dcc8d8a 100644 --- a/sphinx/source/generation_strategy.rst +++ b/sphinx/source/generation_strategy.rst @@ -8,7 +8,7 @@ ax.generation_strategy .. 
currentmodule:: ax.generation_strategy -Generation Strategy, Registry, and Factory +Generation Strategy ------------------------------------------ Generation Strategy diff --git a/tutorials/early_stopping/early_stopping.ipynb b/tutorials/early_stopping/early_stopping.ipynb index 2c1541f8cda..26b27fcf942 100644 --- a/tutorials/early_stopping/early_stopping.ipynb +++ b/tutorials/early_stopping/early_stopping.ipynb @@ -1,681 +1,681 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "id": "12fe3797", - "metadata": {}, - "source": [ - "## Trial-level early stopping in Ax\n", - "\n", - "This tutorial illustrates how to add a trial-level early stopping strategy to an Ax hyper-parameter optimization (HPO) loop. The goal of trial-level early stopping is to monitor the results of expensive evaluations and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations.\n", - "\n", - "Most of this tutorial is adapted from the [PyTorch Ax Multiobjective NAS Tutorial](https://pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html). The training job is different from the original in that we do not optimize `batch_size` or `epochs`. This was done for illustrative purposes, as each validation curve now has the same number of points. The companion training file `mnist_train_nas.py` has also been altered to log to Tensorboard during training.\n", - "\n", - "NOTE: Although the original NAS tutorial is for a multi-objective problem, this tutorial focuses on a single objective (validation accuracy) problem. Early stopping currently does not support \\\"true\\\" multi-objective stopping, although one can use [logical compositions of early stopping strategies](https://github.com/facebook/Ax/blob/main/ax/early_stopping/strategies/logical.py) to target multiple objectives separately. Early stopping for the multi-objective case is currently a work in progress." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "779ea790", - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "import plotly.io as pio\n", - "if 'google.colab' in sys.modules:\n", - " pio.renderers.default = \"colab\"\n", - " %pip install ax-platform" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cb953f30", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import tempfile\n", - "\n", - "from pathlib import Path\n", - "\n", - "import torchx\n", - "\n", - "from ax.core import Experiment, Objective, ParameterType, RangeParameter, SearchSpace\n", - "from ax.core.optimization_config import OptimizationConfig\n", - "\n", - "from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy\n", - "from ax.metrics.tensorboard import TensorboardMetric\n", - "\n", - "from ax.modelbridge.dispatch_utils import choose_generation_strategy\n", - "\n", - "from ax.runners.torchx import TorchXRunner\n", - "\n", - "from ax.service.scheduler import Scheduler, SchedulerOptions\n", - "from ax.service.utils.report_utils import exp_to_df\n", - "\n", - "from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer\n", - "\n", - "from torchx import specs\n", - "from torchx.components import utils\n", - "\n", - "from matplotlib import pyplot as plt\n", - "\n", - "\n", - "%matplotlib inline" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8a7bd328", - "metadata": {}, - "outputs": [], - "source": [ - "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "fe2cf6fe", - "metadata": {}, - "source": [ - "## Defining the TorchX App\n", - "\n", - "Our goal is to optimize the PyTorch Lightning training job defined in\n", - "[mnist_train_nas.py](https://github.com/pytorch/tutorials/tree/master/intermediate_source/mnist_train_nas.py)_.\n", - "To do this using TorchX, we write a helper function 
that takes in\n", - "the values of the architcture and hyperparameters of the training\n", - "job and creates a [TorchX AppDef](https://pytorch.org/torchx/latest/basics.html)_\n", - "with the appropriate settings.\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2e21d309", - "metadata": {}, - "outputs": [], - "source": [ - "if SMOKE_TEST:\n", - " epochs = 3\n", - "else:\n", - " epochs = 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b423923c", - "metadata": {}, - "outputs": [], - "source": [ - "def trainer(\n", - " log_path: str,\n", - " hidden_size_1: int,\n", - " hidden_size_2: int,\n", - " learning_rate: float,\n", - " dropout: float,\n", - " trial_idx: int = -1,\n", - ") -> specs.AppDef:\n", - "\n", - " # define the log path so we can pass it to the TorchX AppDef\n", - " if trial_idx >= 0:\n", - " log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()\n", - "\n", - " batch_size = 32\n", - "\n", - " return utils.python(\n", - " # command line args to the training script\n", - " \"--log_path\",\n", - " log_path,\n", - " \"--hidden_size_1\",\n", - " str(hidden_size_1),\n", - " \"--hidden_size_2\",\n", - " str(hidden_size_2),\n", - " \"--learning_rate\",\n", - " str(learning_rate),\n", - " \"--epochs\",\n", - " str(epochs),\n", - " \"--dropout\",\n", - " str(dropout),\n", - " \"--batch_size\",\n", - " str(batch_size),\n", - " # other config options\n", - " name=\"trainer\",\n", - " script=\"tutorials/early_stopping/mnist_train_nas.py\",\n", - " image=torchx.version.TORCHX_IMAGE,\n", - " )" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "65f7011d", - "metadata": {}, - "source": [ - "## Setting up the Runner\n", - "\n", - "Ax’s [Runner](https://ax.dev/api/core.html#ax.core.runner.Runner)\n", - "abstraction allows writing interfaces to various backends.\n", - "Ax already comes with Runner for TorchX, so we just need to\n", - "configure it. 
For the purpose of this tutorial, we run jobs locally\n", - "in a fully asynchronous fashion. In order to launch them on a cluster, you can instead specify a\n", - "different TorchX scheduler and adjust the configuration appropriately.\n", - "For example, if you have a Kubernetes cluster, you just need to change the\n", - "scheduler from ``local_cwd`` to ``kubernetes``.\n", - "\n", - "The training job launched by this runner will log partial results to Tensorboard, which will then be monitored by the early stopping strategy. We will show how this is done using an Ax \n", - "[TensorboardMetric](https://ax.dev/api/metrics.html#module-ax.metrics.tensorboard) below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "998e5835", - "metadata": {}, - "outputs": [], - "source": [ - "# Make a temporary dir to log our results into\n", - "log_dir = tempfile.mkdtemp()\n", - "\n", - "ax_runner = TorchXRunner(\n", - " tracker_base=\"/tmp/\",\n", - " component=trainer,\n", - " # NOTE: To launch this job on a cluster instead of locally you can\n", - " # specify a different scheduler and adjust args appropriately.\n", - " scheduler=\"local_cwd\",\n", - " component_const_params={\"log_path\": log_dir},\n", - " cfg={},\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "2fec7495", - "metadata": {}, - "source": [ - "## Setting up the SearchSpace\n", - "\n", - "First, we define our search space. Ax supports both range parameters\n", - "of type integer and float as well as choice parameters which can have\n", - "non-numerical types such as strings.\n", - "We will tune the hidden sizes, learning rate, and dropout parameters." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "cf6f869f", - "metadata": {}, - "outputs": [], - "source": [ - "parameters = [\n", - " # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2\n", - " # should probably be powers of 2, but in our simple example this\n", - " # would mean that num_params can't take on that many values, which\n", - " # in turn makes the Pareto frontier look pretty weird.\n", - " RangeParameter(\n", - " name=\"hidden_size_1\",\n", - " lower=16,\n", - " upper=128,\n", - " parameter_type=ParameterType.INT,\n", - " log_scale=True,\n", - " ),\n", - " RangeParameter(\n", - " name=\"hidden_size_2\",\n", - " lower=16,\n", - " upper=128,\n", - " parameter_type=ParameterType.INT,\n", - " log_scale=True,\n", - " ),\n", - " RangeParameter(\n", - " name=\"learning_rate\",\n", - " lower=1e-4,\n", - " upper=1e-2,\n", - " parameter_type=ParameterType.FLOAT,\n", - " log_scale=True,\n", - " ),\n", - " RangeParameter(\n", - " name=\"dropout\",\n", - " lower=0.0,\n", - " upper=0.5,\n", - " parameter_type=ParameterType.FLOAT,\n", - " ),\n", - "]\n", - "\n", - "search_space = SearchSpace(\n", - " parameters=parameters,\n", - " # NOTE: In practice, it may make sense to add a constraint\n", - " # hidden_size_2 <= hidden_size_1\n", - " parameter_constraints=[],\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "a8005e80", - "metadata": {}, - "source": [ - "## Setting up Metrics\n", - "\n", - "Ax has the concept of a Metric that defines properties of outcomes and how observations are obtained for these outcomes. This allows e.g. encodig how data is fetched from some distributed execution backend and post-processed before being passed as input to Ax.\n", - "\n", - "We will optimize the validation accuracy, which is a `TensorboardMetric` that points to the logging directory assigned above. Note that we have set `is_available_while_running`, allowing for the metric to be queried as the trial progresses. 
This is critical for the early stopping strategy to monitor partial results." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0775a96e", - "metadata": {}, - "outputs": [], - "source": [ - "class MyTensorboardMetric(TensorboardMetric):\n", - "\n", - " # NOTE: We need to tell the new Tensorboard metric how to get the id /\n", - " # file handle for the tensorboard logs from a trial. In this case\n", - " # our convention is to just save a separate file per trial in\n", - " # the pre-specified log dir.\n", - " def _get_event_multiplexer_for_trial(self, trial):\n", - " mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)\n", - " mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)\n", - " mul.Reload()\n", - "\n", - " return mul\n", - "\n", - " # This indicates whether the metric is queryable while the trial is\n", - " # still running. This is required for early stopping to monitor the\n", - " # progress of the running trial.ArithmeticError\n", - " @classmethod\n", - " def is_available_while_running(cls):\n", - " return True" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a5c5a7d0", - "metadata": {}, - "outputs": [], - "source": [ - "val_acc = MyTensorboardMetric(\n", - " name=\"val_acc\",\n", - " tag=\"val_acc\",\n", - " lower_is_better=False,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "d4f3ba5d", - "metadata": {}, - "source": [ - "## Setting up the OptimizationConfig\n", - "\n", - "The `OptimizationConfig` specifies the objective for Ax to optimize." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ada66cf3", - "metadata": {}, - "outputs": [], - "source": [ - "opt_config = OptimizationConfig(\n", - " objective=Objective(\n", - " metric=val_acc,\n", - " minimize=False,\n", - " )\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "57aa9cf7", - "metadata": {}, - "source": [ - "## Defining an Early Stopping Strategy\n", - "\n", - "A `PercentileEarlyStoppingStrategy` is a simple method that stops a trial if its performance falls below a certain percentile of other trials at the same step (e.g., when `percentile_threshold` is 50, at a given point in time, if a trial ranks in the bottom 50% of trials, it is stopped). \n", - "- We make use of `normalize_progressions` which normalizes the progression column (e.g. timestamp, epochs, training data used) to be in [0, 1]. This is useful because one doesn't need to know the maximum progression values of the curve (which might be, e.g., the total number of data points in the training dataset).\n", - "- The `min_progression` parameter specifies that trials should only be considered for stopping if the latest progression value is greater than this threshold.\n", - "- The `min_curves` parameter specifies the minimum number of completed curves (i.e., fully completed training jobs) before early stopping will be considered. This should be larger than zero if `normalize_progression` is used. In general, we want a few completed curves to have a baseline for comparison.\n", - "\n", - "Note that `PercentileEarlyStoppingStrategy` does not make use of learning curve modeling or prediction. More sophisticated model-based methods will be available in future versions of Ax." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "949e8ab5", - "metadata": {}, - "outputs": [], - "source": [ - "percentile_early_stopping_strategy = PercentileEarlyStoppingStrategy(\n", - " # stop if in bottom 70% of runs at the same progression\n", - " percentile_threshold=70,\n", - " # the trial must have passed `min_progression` steps before early stopping is initiated\n", - " # note that we are using `normalize_progressions`, so this is on a scale of [0, 1]\n", - " min_progression=0.3,\n", - " # there must be `min_curves` completed trials and `min_curves` trials reporting data in\n", - " # order for early stopping to be applicable\n", - " min_curves=5,\n", - " # specify, e.g., [0, 1] if the first two trials should never be stopped\n", - " trial_indices_to_ignore=None,\n", - " normalize_progressions=True,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "2665ca93", - "metadata": {}, - "source": [ - "## Creating the Ax Experiment\n", - "\n", - "In Ax, the Experiment object is the object that stores all the information about the problem setup." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12849b31", - "metadata": {}, - "outputs": [], - "source": [ - "experiment = Experiment(\n", - " name=\"torchx_mnist\",\n", - " search_space=search_space,\n", - " optimization_config=opt_config,\n", - " runner=ax_runner,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "49a4ed0e", - "metadata": {}, - "source": [ - "## Choosing the GenerationStrategy\n", - "\n", - "A [GenerationStrategy](https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy)\n", - "is the abstract representation of how we would like to perform the\n", - "optimization. 
While this can be customized (if you’d like to do so, see\n", - "[this tutorial](https://ax.dev/tutorials/generation_strategy.html)),\n", - "in most cases Ax can automatically determine an appropriate strategy\n", - "based on the search space, optimization config, and the total number\n", - "of trials we want to run.\n", - "\n", - "Typically, Ax chooses to evaluate a number of random configurations\n", - "before starting a model-based Bayesian Optimization strategy.\n", - "\n", - "We remark that in Ax, generation strategies and early stopping strategies are separate, a design decision motivated by ease-of-use. However, we should acknowledge that jointly considering generation and stopping using a single strategy would likely be the \"proper\" formulation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e38d0237", - "metadata": {}, - "outputs": [], - "source": [ - "if SMOKE_TEST:\n", - " total_trials = 6\n", - "else:\n", - " total_trials = 15 # total evaluation budget\n", - "\n", - "gs = choose_generation_strategy(\n", - " search_space=experiment.search_space,\n", - " optimization_config=experiment.optimization_config,\n", - " num_trials=total_trials,\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "78d86fea", - "metadata": {}, - "source": [ - "## Configuring the Scheduler\n", - "\n", - "The `Scheduler` acts as the loop control for the optimization.\n", - "It communicates with the backend to launch trials, check their status, retrieve (partial) results, and importantly for this tutorial, calls the early stopping strategy. If the early stopping strategy suggests a trial to be the stopped, the `Scheduler` communicates with the backend to terminate the trial.\n", - "\n", - "The ``Scheduler`` requires the ``Experiment`` and the ``GenerationStrategy``.\n", - "A set of options can be passed in via ``SchedulerOptions``. 
Here, we\n", - "configure the number of total evaluations as well as ``max_pending_trials``,\n", - "the maximum number of trials that should run concurrently. In our\n", - "local setting, this is the number of training jobs running as individual\n", - "processes, while in a remote execution setting, this would be the number\n", - "of machines you want to use in parallel.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "499fb9b5", - "metadata": {}, - "outputs": [], - "source": [ - "scheduler = Scheduler(\n", - " experiment=experiment,\n", - " generation_strategy=gs,\n", - " options=SchedulerOptions(\n", - " total_trials=total_trials,\n", - " max_pending_trials=5,\n", - " early_stopping_strategy=percentile_early_stopping_strategy,\n", - " ),\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "78257ebb", - "metadata": {}, - "outputs": [], - "source": [ - "%%time\n", - "scheduler.run_all_trials()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "8c5afbe8", - "metadata": {}, - "source": [ - "## Results\n", - "\n", - "First, we examine the data stored on the experiment. This shows that each trial is associated with an entire learning curve, represented by the column \"steps\"." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "764365f0", - "metadata": {}, - "outputs": [], - "source": [ - "experiment.lookup_data().map_df.head(n=10)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "0033ed2e", - "metadata": {}, - "source": [ - "Below is a summary of the experiment, showing that a portion of trials have been early stopped." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "00f2b35f", - "metadata": {}, - "outputs": [], - "source": [ - "exp_to_df(experiment)" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "f8909cf2", - "metadata": {}, - "source": [ - "We can give a very rough estimate of the amount of computational savings due to early stopping, by looking at the total number of steps used when early stopping is used versus the number of steps used if we ran all trials to completion. Note to do a true comparison, one should run full HPO loops with and without early stopping (as early stopping will influence the model and future points selected by the generation strategy). " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5abb3ce8", - "metadata": {}, - "outputs": [], - "source": [ - "map_df = experiment.lookup_data().map_df\n", - "trial_to_max_steps = map_df.groupby(\"trial_index\")[\"step\"].max()\n", - "completed_trial_steps = trial_to_max_steps.iloc[0]\n", - "savings = 1.0 - trial_to_max_steps.sum() / (\n", - " completed_trial_steps * len(trial_to_max_steps)\n", - ")\n", - "# TODO format nicer\n", - "print(f\"A rough estimate of the computational savings is {100 * savings}%.\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "37df6964", - "metadata": {}, - "source": [ - "## Visualizations\n", - "\n", - "Finally, we show a visualization of learning curves versus actual elapsed wall time. This helps to illustrate that stopped trials make room for additional trials to be run." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c88cb8d0", - "metadata": {}, - "outputs": [], - "source": [ - "# helper function for getting trial start times\n", - "def time_started(row):\n", - " trial_index = row[\"trial_index\"]\n", - " return experiment.trials[trial_index].time_run_started\n", - "\n", - "\n", - "# helper function for getting trial completion times\n", - "def time_completed(row):\n", - " trial_index = row[\"trial_index\"]\n", - " return experiment.trials[trial_index].time_completed\n", - "\n", - "\n", - "# helper function for getting relevant data from experiment\n", - "# with early stopping into useful dfs\n", - "def early_stopping_exp_to_df(experiment):\n", - " trials_df = exp_to_df(experiment)\n", - " curve_df = experiment.lookup_data().map_df\n", - " training_row_df = (\n", - " curve_df.groupby(\"trial_index\").max().reset_index()[[\"trial_index\", \"steps\"]]\n", - " )\n", - " trials_df = trials_df.merge(training_row_df, on=\"trial_index\")\n", - " trials_df[\"time_started\"] = trials_df.apply(func=time_started, axis=1)\n", - " trials_df[\"time_completed\"] = trials_df.apply(func=time_completed, axis=1)\n", - " start_time = trials_df[\"time_started\"].min()\n", - " trials_df[\"time_started_rel\"] = (\n", - " trials_df[\"time_started\"] - start_time\n", - " ).dt.total_seconds()\n", - " trials_df[\"time_completed_rel\"] = (\n", - " trials_df[\"time_completed\"] - start_time\n", - " ).dt.total_seconds()\n", - " return trials_df, curve_df\n", - "\n", - "\n", - "def plot_curves_by_wall_time(trials_df, curve_df):\n", - " trials = set(curve_df[\"trial_index\"])\n", - " fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n", - " ax.set(xlabel=\"seconds since start\", ylabel=\"validation accuracy\")\n", - " for trial_index in trials:\n", - " this_trial_df = curve_df[curve_df[\"trial_index\"] == trial_index]\n", - " start_time_rel = trials_df[\"time_started_rel\"].iloc[trial_index]\n", - " completed_time_rel = 
trials_df[\"time_completed_rel\"].iloc[trial_index]\n", - " total_steps = trials_df.loc[trial_index, \"steps\"]\n", - " smoothed_curve = this_trial_df[\"mean\"].rolling(window=3).mean()\n", - " x = (\n", - " start_time_rel\n", - " + (completed_time_rel - start_time_rel)\n", - " / total_steps\n", - " * this_trial_df[\"steps\"]\n", - " )\n", - " ax.plot(\n", - " x,\n", - " smoothed_curve,\n", - " label=f\"trial #{trial_index}\" if trial_index % 2 == 1 else None,\n", - " )\n", - " ax.legend()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d7f52fed", - "metadata": {}, - "outputs": [], - "source": [ - "# wrap in try/except in case of flaky I/O issues\n", - "try:\n", - " trials_df, curve_df = early_stopping_exp_to_df(experiment)\n", - " plot_curves_by_wall_time(trials_df, curve_df)\n", - "except Exception as e:\n", - " print(f\"Encountered exception while plotting results: {e}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "193e2fc7", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.8" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "12fe3797", + "metadata": {}, + "source": [ + "## Trial-level early stopping in Ax\n", + "\n", + "This tutorial illustrates how to add a trial-level early stopping strategy to an Ax hyper-parameter optimization (HPO) loop. 
The goal of trial-level early stopping is to monitor the results of expensive evaluations and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations.\n", + "\n", + "Most of this tutorial is adapted from the [PyTorch Ax Multiobjective NAS Tutorial](https://pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html). The training job is different from the original in that we do not optimize `batch_size` or `epochs`. This was done for illustrative purposes, as each validation curve now has the same number of points. The companion training file `mnist_train_nas.py` has also been altered to log to Tensorboard during training.\n", + "\n", + "NOTE: Although the original NAS tutorial is for a multi-objective problem, this tutorial focuses on a single objective (validation accuracy) problem. Early stopping currently does not support \\\"true\\\" multi-objective stopping, although one can use [logical compositions of early stopping strategies](https://github.com/facebook/Ax/blob/main/ax/early_stopping/strategies/logical.py) to target multiple objectives separately. Early stopping for the multi-objective case is currently a work in progress." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "779ea790", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import plotly.io as pio\n", + "if 'google.colab' in sys.modules:\n", + " pio.renderers.default = \"colab\"\n", + " %pip install ax-platform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cb953f30", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import tempfile\n", + "\n", + "from pathlib import Path\n", + "\n", + "import torchx\n", + "\n", + "from ax.core import Experiment, Objective, ParameterType, RangeParameter, SearchSpace\n", + "from ax.core.optimization_config import OptimizationConfig\n", + "\n", + "from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy\n", + "from ax.metrics.tensorboard import TensorboardMetric\n", + "\n", + "from ax.generation_strategy.dispatch_utils import choose_generation_strategy\n", + "\n", + "from ax.runners.torchx import TorchXRunner\n", + "\n", + "from ax.service.scheduler import Scheduler, SchedulerOptions\n", + "from ax.service.utils.report_utils import exp_to_df\n", + "\n", + "from tensorboard.backend.event_processing import plugin_event_multiplexer as event_multiplexer\n", + "\n", + "from torchx import specs\n", + "from torchx.components import utils\n", + "\n", + "from matplotlib import pyplot as plt\n", + "\n", + "\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a7bd328", + "metadata": {}, + "outputs": [], + "source": [ + "SMOKE_TEST = os.environ.get(\"SMOKE_TEST\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "fe2cf6fe", + "metadata": {}, + "source": [ + "## Defining the TorchX App\n", + "\n", + "Our goal is to optimize the PyTorch Lightning training job defined in\n", + "[mnist_train_nas.py](https://github.com/pytorch/tutorials/tree/master/intermediate_source/mnist_train_nas.py)_.\n", + "To do this using TorchX, we write a helper 
function that takes in\n", + "the values of the architecture and hyperparameters of the training\n", + "job and creates a [TorchX AppDef](https://pytorch.org/torchx/latest/basics.html)_\n", + "with the appropriate settings.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2e21d309", + "metadata": {}, + "outputs": [], + "source": [ + "if SMOKE_TEST:\n", + "    epochs = 3\n", + "else:\n", + "    epochs = 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b423923c", + "metadata": {}, + "outputs": [], + "source": [ + "def trainer(\n", + "    log_path: str,\n", + "    hidden_size_1: int,\n", + "    hidden_size_2: int,\n", + "    learning_rate: float,\n", + "    dropout: float,\n", + "    trial_idx: int = -1,\n", + ") -> specs.AppDef:\n", + "\n", + "    # define the log path so we can pass it to the TorchX AppDef\n", + "    if trial_idx >= 0:\n", + "        log_path = Path(log_path).joinpath(str(trial_idx)).absolute().as_posix()\n", + "\n", + "    batch_size = 32\n", + "\n", + "    return utils.python(\n", + "        # command line args to the training script\n", + "        \"--log_path\",\n", + "        log_path,\n", + "        \"--hidden_size_1\",\n", + "        str(hidden_size_1),\n", + "        \"--hidden_size_2\",\n", + "        str(hidden_size_2),\n", + "        \"--learning_rate\",\n", + "        str(learning_rate),\n", + "        \"--epochs\",\n", + "        str(epochs),\n", + "        \"--dropout\",\n", + "        str(dropout),\n", + "        \"--batch_size\",\n", + "        str(batch_size),\n", + "        # other config options\n", + "        name=\"trainer\",\n", + "        script=\"tutorials/early_stopping/mnist_train_nas.py\",\n", + "        image=torchx.version.TORCHX_IMAGE,\n", + "    )" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "65f7011d", + "metadata": {}, + "source": [ + "## Setting up the Runner\n", + "\n", + "Ax’s [Runner](https://ax.dev/api/core.html#ax.core.runner.Runner)\n", + "abstraction allows writing interfaces to various backends.\n", + "Ax already comes with Runner for TorchX, so we just need to\n", + "configure it. 
For the purpose of this tutorial, we run jobs locally\n", + "in a fully asynchronous fashion. In order to launch them on a cluster, you can instead specify a\n", + "different TorchX scheduler and adjust the configuration appropriately.\n", + "For example, if you have a Kubernetes cluster, you just need to change the\n", + "scheduler from ``local_cwd`` to ``kubernetes``.\n", + "\n", + "The training job launched by this runner will log partial results to Tensorboard, which will then be monitored by the early stopping strategy. We will show how this is done using an Ax \n", + "[TensorboardMetric](https://ax.dev/api/metrics.html#module-ax.metrics.tensorboard) below." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "998e5835", + "metadata": {}, + "outputs": [], + "source": [ + "# Make a temporary dir to log our results into\n", + "log_dir = tempfile.mkdtemp()\n", + "\n", + "ax_runner = TorchXRunner(\n", + " tracker_base=\"/tmp/\",\n", + " component=trainer,\n", + " # NOTE: To launch this job on a cluster instead of locally you can\n", + " # specify a different scheduler and adjust args appropriately.\n", + " scheduler=\"local_cwd\",\n", + " component_const_params={\"log_path\": log_dir},\n", + " cfg={},\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2fec7495", + "metadata": {}, + "source": [ + "## Setting up the SearchSpace\n", + "\n", + "First, we define our search space. Ax supports both range parameters\n", + "of type integer and float as well as choice parameters which can have\n", + "non-numerical types such as strings.\n", + "We will tune the hidden sizes, learning rate, and dropout parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf6f869f", + "metadata": {}, + "outputs": [], + "source": [ + "parameters = [\n", + "    # NOTE: In a real-world setting, hidden_size_1 and hidden_size_2\n", + "    # should probably be powers of 2, but in our simple example this\n", + "    # would mean that num_params can't take on that many values, which\n", + "    # in turn makes the Pareto frontier look pretty weird.\n", + "    RangeParameter(\n", + "        name=\"hidden_size_1\",\n", + "        lower=16,\n", + "        upper=128,\n", + "        parameter_type=ParameterType.INT,\n", + "        log_scale=True,\n", + "    ),\n", + "    RangeParameter(\n", + "        name=\"hidden_size_2\",\n", + "        lower=16,\n", + "        upper=128,\n", + "        parameter_type=ParameterType.INT,\n", + "        log_scale=True,\n", + "    ),\n", + "    RangeParameter(\n", + "        name=\"learning_rate\",\n", + "        lower=1e-4,\n", + "        upper=1e-2,\n", + "        parameter_type=ParameterType.FLOAT,\n", + "        log_scale=True,\n", + "    ),\n", + "    RangeParameter(\n", + "        name=\"dropout\",\n", + "        lower=0.0,\n", + "        upper=0.5,\n", + "        parameter_type=ParameterType.FLOAT,\n", + "    ),\n", + "]\n", + "\n", + "search_space = SearchSpace(\n", + "    parameters=parameters,\n", + "    # NOTE: In practice, it may make sense to add a constraint\n", + "    # hidden_size_2 <= hidden_size_1\n", + "    parameter_constraints=[],\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "a8005e80", + "metadata": {}, + "source": [ + "## Setting up Metrics\n", + "\n", + "Ax has the concept of a Metric that defines properties of outcomes and how observations are obtained for these outcomes. This allows e.g. encoding how data is fetched from some distributed execution backend and post-processed before being passed as input to Ax.\n", + "\n", + "We will optimize the validation accuracy, which is a `TensorboardMetric` that points to the logging directory assigned above. Note that we have set `is_available_while_running`, allowing for the metric to be queried as the trial progresses. 
This is critical for the early stopping strategy to monitor partial results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0775a96e", + "metadata": {}, + "outputs": [], + "source": [ + "class MyTensorboardMetric(TensorboardMetric):\n", + "\n", + "    # NOTE: We need to tell the new Tensorboard metric how to get the id /\n", + "    # file handle for the tensorboard logs from a trial. In this case\n", + "    # our convention is to just save a separate file per trial in\n", + "    # the pre-specified log dir.\n", + "    def _get_event_multiplexer_for_trial(self, trial):\n", + "        mul = event_multiplexer.EventMultiplexer(max_reload_threads=20)\n", + "        mul.AddRunsFromDirectory(Path(log_dir).joinpath(str(trial.index)).as_posix(), None)\n", + "        mul.Reload()\n", + "\n", + "        return mul\n", + "\n", + "    # This indicates whether the metric is queryable while the trial is\n", + "    # still running. This is required for early stopping to monitor the\n", + "    # progress of the running trial.\n", + "    @classmethod\n", + "    def is_available_while_running(cls):\n", + "        return True" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a5c5a7d0", + "metadata": {}, + "outputs": [], + "source": [ + "val_acc = MyTensorboardMetric(\n", + "    name=\"val_acc\",\n", + "    tag=\"val_acc\",\n", + "    lower_is_better=False,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "d4f3ba5d", + "metadata": {}, + "source": [ + "## Setting up the OptimizationConfig\n", + "\n", + "The `OptimizationConfig` specifies the objective for Ax to optimize." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ada66cf3", + "metadata": {}, + "outputs": [], + "source": [ + "opt_config = OptimizationConfig(\n", + " objective=Objective(\n", + " metric=val_acc,\n", + " minimize=False,\n", + " )\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "57aa9cf7", + "metadata": {}, + "source": [ + "## Defining an Early Stopping Strategy\n", + "\n", + "A `PercentileEarlyStoppingStrategy` is a simple method that stops a trial if its performance falls below a certain percentile of other trials at the same step (e.g., when `percentile_threshold` is 50, at a given point in time, if a trial ranks in the bottom 50% of trials, it is stopped). \n", + "- We make use of `normalize_progressions` which normalizes the progression column (e.g. timestamp, epochs, training data used) to be in [0, 1]. This is useful because one doesn't need to know the maximum progression values of the curve (which might be, e.g., the total number of data points in the training dataset).\n", + "- The `min_progression` parameter specifies that trials should only be considered for stopping if the latest progression value is greater than this threshold.\n", + "- The `min_curves` parameter specifies the minimum number of completed curves (i.e., fully completed training jobs) before early stopping will be considered. This should be larger than zero if `normalize_progression` is used. In general, we want a few completed curves to have a baseline for comparison.\n", + "\n", + "Note that `PercentileEarlyStoppingStrategy` does not make use of learning curve modeling or prediction. More sophisticated model-based methods will be available in future versions of Ax." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "949e8ab5", + "metadata": {}, + "outputs": [], + "source": [ + "percentile_early_stopping_strategy = PercentileEarlyStoppingStrategy(\n", + " # stop if in bottom 70% of runs at the same progression\n", + " percentile_threshold=70,\n", + " # the trial must have passed `min_progression` steps before early stopping is initiated\n", + " # note that we are using `normalize_progressions`, so this is on a scale of [0, 1]\n", + " min_progression=0.3,\n", + " # there must be `min_curves` completed trials and `min_curves` trials reporting data in\n", + " # order for early stopping to be applicable\n", + " min_curves=5,\n", + " # specify, e.g., [0, 1] if the first two trials should never be stopped\n", + " trial_indices_to_ignore=None,\n", + " normalize_progressions=True,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2665ca93", + "metadata": {}, + "source": [ + "## Creating the Ax Experiment\n", + "\n", + "In Ax, the Experiment object is the object that stores all the information about the problem setup." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12849b31", + "metadata": {}, + "outputs": [], + "source": [ + "experiment = Experiment(\n", + " name=\"torchx_mnist\",\n", + " search_space=search_space,\n", + " optimization_config=opt_config,\n", + " runner=ax_runner,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "49a4ed0e", + "metadata": {}, + "source": [ + "## Choosing the GenerationStrategy\n", + "\n", + "A [GenerationStrategy](https://ax.dev/api/modelbridge.html#ax.generation_strategy.generation_strategy.GenerationStrategy)\n", + "is the abstract representation of how we would like to perform the\n", + "optimization. 
While this can be customized (if you’d like to do so, see\n", + "[this tutorial](https://ax.dev/tutorials/generation_strategy.html)),\n", + "in most cases Ax can automatically determine an appropriate strategy\n", + "based on the search space, optimization config, and the total number\n", + "of trials we want to run.\n", + "\n", + "Typically, Ax chooses to evaluate a number of random configurations\n", + "before starting a model-based Bayesian Optimization strategy.\n", + "\n", + "We remark that in Ax, generation strategies and early stopping strategies are separate, a design decision motivated by ease-of-use. However, we should acknowledge that jointly considering generation and stopping using a single strategy would likely be the \"proper\" formulation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e38d0237", + "metadata": {}, + "outputs": [], + "source": [ + "if SMOKE_TEST:\n", + "    total_trials = 6\n", + "else:\n", + "    total_trials = 15  # total evaluation budget\n", + "\n", + "gs = choose_generation_strategy(\n", + "    search_space=experiment.search_space,\n", + "    optimization_config=experiment.optimization_config,\n", + "    num_trials=total_trials,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "78d86fea", + "metadata": {}, + "source": [ + "## Configuring the Scheduler\n", + "\n", + "The `Scheduler` acts as the loop control for the optimization.\n", + "It communicates with the backend to launch trials, check their status, retrieve (partial) results, and importantly for this tutorial, calls the early stopping strategy. If the early stopping strategy suggests that a trial should be stopped, the `Scheduler` communicates with the backend to terminate the trial.\n", + "\n", + "The ``Scheduler`` requires the ``Experiment`` and the ``GenerationStrategy``.\n", + "A set of options can be passed in via ``SchedulerOptions``. 
Here, we\n", + "configure the number of total evaluations as well as ``max_pending_trials``,\n", + "the maximum number of trials that should run concurrently. In our\n", + "local setting, this is the number of training jobs running as individual\n", + "processes, while in a remote execution setting, this would be the number\n", + "of machines you want to use in parallel.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "499fb9b5", + "metadata": {}, + "outputs": [], + "source": [ + "scheduler = Scheduler(\n", + " experiment=experiment,\n", + " generation_strategy=gs,\n", + " options=SchedulerOptions(\n", + " total_trials=total_trials,\n", + " max_pending_trials=5,\n", + " early_stopping_strategy=percentile_early_stopping_strategy,\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "78257ebb", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "scheduler.run_all_trials()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "8c5afbe8", + "metadata": {}, + "source": [ + "## Results\n", + "\n", + "First, we examine the data stored on the experiment. This shows that each trial is associated with an entire learning curve, represented by the column \"steps\"." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "764365f0", + "metadata": {}, + "outputs": [], + "source": [ + "experiment.lookup_data().map_df.head(n=10)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0033ed2e", + "metadata": {}, + "source": [ + "Below is a summary of the experiment, showing that a portion of trials have been early stopped." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00f2b35f", + "metadata": {}, + "outputs": [], + "source": [ + "exp_to_df(experiment)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f8909cf2", + "metadata": {}, + "source": [ + "We can give a very rough estimate of the amount of computational savings due to early stopping, by looking at the total number of steps used when early stopping is used versus the number of steps used if we ran all trials to completion. Note to do a true comparison, one should run full HPO loops with and without early stopping (as early stopping will influence the model and future points selected by the generation strategy). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5abb3ce8", + "metadata": {}, + "outputs": [], + "source": [ + "map_df = experiment.lookup_data().map_df\n", + "trial_to_max_steps = map_df.groupby(\"trial_index\")[\"step\"].max()\n", + "completed_trial_steps = trial_to_max_steps.iloc[0]\n", + "savings = 1.0 - trial_to_max_steps.sum() / (\n", + " completed_trial_steps * len(trial_to_max_steps)\n", + ")\n", + "# TODO format nicer\n", + "print(f\"A rough estimate of the computational savings is {100 * savings}%.\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "37df6964", + "metadata": {}, + "source": [ + "## Visualizations\n", + "\n", + "Finally, we show a visualization of learning curves versus actual elapsed wall time. This helps to illustrate that stopped trials make room for additional trials to be run." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c88cb8d0", + "metadata": {}, + "outputs": [], + "source": [ + "# helper function for getting trial start times\n", + "def time_started(row):\n", + " trial_index = row[\"trial_index\"]\n", + " return experiment.trials[trial_index].time_run_started\n", + "\n", + "\n", + "# helper function for getting trial completion times\n", + "def time_completed(row):\n", + " trial_index = row[\"trial_index\"]\n", + " return experiment.trials[trial_index].time_completed\n", + "\n", + "\n", + "# helper function for getting relevant data from experiment\n", + "# with early stopping into useful dfs\n", + "def early_stopping_exp_to_df(experiment):\n", + " trials_df = exp_to_df(experiment)\n", + " curve_df = experiment.lookup_data().map_df\n", + " training_row_df = (\n", + " curve_df.groupby(\"trial_index\").max().reset_index()[[\"trial_index\", \"steps\"]]\n", + " )\n", + " trials_df = trials_df.merge(training_row_df, on=\"trial_index\")\n", + " trials_df[\"time_started\"] = trials_df.apply(func=time_started, axis=1)\n", + " trials_df[\"time_completed\"] = trials_df.apply(func=time_completed, axis=1)\n", + " start_time = trials_df[\"time_started\"].min()\n", + " trials_df[\"time_started_rel\"] = (\n", + " trials_df[\"time_started\"] - start_time\n", + " ).dt.total_seconds()\n", + " trials_df[\"time_completed_rel\"] = (\n", + " trials_df[\"time_completed\"] - start_time\n", + " ).dt.total_seconds()\n", + " return trials_df, curve_df\n", + "\n", + "\n", + "def plot_curves_by_wall_time(trials_df, curve_df):\n", + " trials = set(curve_df[\"trial_index\"])\n", + " fig, ax = plt.subplots(1, 1, figsize=(10, 6))\n", + " ax.set(xlabel=\"seconds since start\", ylabel=\"validation accuracy\")\n", + " for trial_index in trials:\n", + " this_trial_df = curve_df[curve_df[\"trial_index\"] == trial_index]\n", + " start_time_rel = trials_df[\"time_started_rel\"].iloc[trial_index]\n", + " completed_time_rel = 
trials_df[\"time_completed_rel\"].iloc[trial_index]\n", + " total_steps = trials_df.loc[trial_index, \"steps\"]\n", + " smoothed_curve = this_trial_df[\"mean\"].rolling(window=3).mean()\n", + " x = (\n", + " start_time_rel\n", + " + (completed_time_rel - start_time_rel)\n", + " / total_steps\n", + " * this_trial_df[\"steps\"]\n", + " )\n", + " ax.plot(\n", + " x,\n", + " smoothed_curve,\n", + " label=f\"trial #{trial_index}\" if trial_index % 2 == 1 else None,\n", + " )\n", + " ax.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7f52fed", + "metadata": {}, + "outputs": [], + "source": [ + "# wrap in try/except in case of flaky I/O issues\n", + "try:\n", + " trials_df, curve_df = early_stopping_exp_to_df(experiment)\n", + " plot_curves_by_wall_time(trials_df, curve_df)\n", + "except Exception as e:\n", + " print(f\"Encountered exception while plotting results: {e}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "193e2fc7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/tutorials/external_generation_node/external_generation_node.ipynb b/tutorials/external_generation_node/external_generation_node.ipynb index 0f87fdc3be7..61c9c3f7f4a 100644 --- a/tutorials/external_generation_node/external_generation_node.ipynb +++ b/tutorials/external_generation_node/external_generation_node.ipynb @@ -59,12 +59,12 @@ "from ax.core.experiment import Experiment\n", "from ax.core.parameter import RangeParameter\n", "from ax.core.types import TParameterization\n", - "from 
ax.modelbridge.external_generation_node import ExternalGenerationNode\n", - "from ax.modelbridge.generation_node import GenerationNode\n", - "from ax.modelbridge.generation_strategy import GenerationStrategy\n", - "from ax.modelbridge.model_spec import GeneratorSpec\n", + "from ax.generation_strategy.external_generation_node import ExternalGenerationNode\n", + "from ax.generation_strategy.generation_node import GenerationNode\n", + "from ax.generation_strategy.generation_strategy import GenerationStrategy\n", + "from ax.generation_strategy.model_spec import GeneratorSpec\n", + "from ax.generation_strategy.transition_criterion import MaxTrials\n", "from ax.modelbridge.registry import Generators\n", - "from ax.modelbridge.transition_criterion import MaxTrials\n", "from ax.plot.trace import plot_objective_value_vs_trial_index\n", "from ax.service.ax_client import AxClient, ObjectiveProperties\n", "from ax.service.utils.report_utils import exp_to_df\n", @@ -406,7 +406,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/generation_strategy/generation_strategy.ipynb b/tutorials/generation_strategy/generation_strategy.ipynb index 704a1432b7e..eb54dee59c1 100644 --- a/tutorials/generation_strategy/generation_strategy.ipynb +++ b/tutorials/generation_strategy/generation_strategy.ipynb @@ -19,8 +19,8 @@ "metadata": {}, "outputs": [], "source": [ - "from ax.modelbridge.dispatch_utils import choose_generation_strategy\n", - "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", + "from ax.generation_strategy.dispatch_utils import choose_generation_strategy\n", + "from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy\n", "from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n", "from ax.modelbridge.registry import ModelRegistryBase, Generators\n", "\n", @@ -33,7 +33,7 
@@ "source": [ "# Generation Strategy (GS) Tutorial\n", "\n", - "`GenerationStrategy` ([API reference](https://ax.dev/api/modelbridge.html#ax.modelbridge.generation_strategy.GenerationStrategy)) is a key abstraction in Ax:\n", + "`GenerationStrategy` ([API reference](https://ax.dev/api/modelbridge.html#ax.generation_strategy.generation_strategy.GenerationStrategy)) is a key abstraction in Ax:\n", "- It allows for specifying multiple optimization algorithms to chain one after another in the course of the optimization. \n", "- Many higher-level APIs in Ax use generation strategies: Service and Loop APIs, `Scheduler` etc. (tutorials for all those higher-level APIs are here: https://ax.dev/tutorials/).\n", "- Generation strategy allows for storage and resumption of modeling setups, making optimization resumable from SQL or JSON snapshots.\n", @@ -123,7 +123,7 @@ " # Required arguments:\n", " search_space=get_branin_search_space(), # Ax `SearchSpace`\n", " # Some optional arguments (shown with their defaults), see API docs for more settings:\n", - " # https://ax.dev/api/modelbridge.html#module-ax.modelbridge.dispatch_utils\n", + " # https://ax.dev/api/modelbridge.html#module-ax.generation_strategy.dispatch_utils\n", " use_batch_trials=False, # Whether this GS will be used to generate 1-arm `Trial`-s or `BatchTrials`\n", " no_bayesian_optimization=False, # Use quasi-random candidate generation without BayesOpt\n", " max_parallelism_override=None, # Integer, to which to set the `max_parallelism` setting of all steps in this GS\n", @@ -459,7 +459,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/modular_botax/modular_botax.ipynb b/tutorials/modular_botax/modular_botax.ipynb index 8b822a6e6c3..a9014bc8898 100644 --- a/tutorials/modular_botax/modular_botax.ipynb +++ b/tutorials/modular_botax/modular_botax.ipynb @@ -959,7 +959,7 @@ }, "outputs": [], 
"source": [ - "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", + "from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy\n", "from ax.modelbridge.modelbridge_utils import get_pending_observation_features\n", "\n", "gs = GenerationStrategy(\n", @@ -1463,7 +1463,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/tutorials/scheduler/scheduler.ipynb b/tutorials/scheduler/scheduler.ipynb index 650bd3ca127..c1031ff7d43 100644 --- a/tutorials/scheduler/scheduler.ipynb +++ b/tutorials/scheduler/scheduler.ipynb @@ -386,7 +386,7 @@ }, "outputs": [], "source": [ - "from ax.modelbridge.dispatch_utils import choose_generation_strategy\n", + "from ax.generation_strategy.dispatch_utils import choose_generation_strategy\n", "\n", "generation_strategy = choose_generation_strategy(\n", " search_space=experiment.search_space,\n", @@ -925,9 +925,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.12.4" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/tutorials/sebo/sebo.ipynb b/tutorials/sebo/sebo.ipynb index f3e2f100663..ec8cc8c0bbc 100644 --- a/tutorials/sebo/sebo.ipynb +++ b/tutorials/sebo/sebo.ipynb @@ -67,7 +67,7 @@ "from ax.core.objective import Objective\n", "from ax.core.optimization_config import OptimizationConfig\n", "from ax.metrics.noisy_function import NoisyFunctionMetric\n", - "from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy\n", + "from ax.generation_strategy.generation_strategy import GenerationStep, GenerationStrategy\n", "from ax.modelbridge.registry import Generators\n", "from ax.models.torch.botorch_modular.sebo import SEBOAcquisition\n", "from ax.models.torch.botorch_modular.surrogate import Surrogate\n", @@ -653,7 +653,7 @@ "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.16" + "version": "3.12.4" } }, "nbformat": 4,