Skip to content

Commit

Permalink
constraints feasibilty via GPs (facebook#3152)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3152

Warn users if their constraints aren't satisfied above the given threshold for any of the arms. The constraints feasibility is computed using the GP model fit and the user provided constraint bounds.

Reviewed By: danielcohenlive, Balandat

Differential Revision: D66398437

fbshipit-source-id: 4dc59b6fbf296b1a659fcb951e0730a1a8184320
  • Loading branch information
Jelena Markovic-Voronov authored and facebook-github-bot committed Dec 6, 2024
1 parent d709c5d commit f9a9fd6
Show file tree
Hide file tree
Showing 4 changed files with 426 additions and 0 deletions.
5 changes: 5 additions & 0 deletions ax/analysis/healthcheck/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
from ax.analysis.healthcheck.can_generate_candidates import (
CanGenerateCandidatesAnalysis,
)

from ax.analysis.healthcheck.constraints_feasibility import (
ConstraintsFeasibilityAnalysis,
)
from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
Expand All @@ -16,6 +20,7 @@
from ax.analysis.healthcheck.should_generate_candidates import ShouldGenerateCandidates

__all__ = [
"ConstraintsFeasibilityAnalysis",
"CanGenerateCandidatesAnalysis",
"HealthcheckAnalysis",
"HealthcheckAnalysisCard",
Expand Down
200 changes: 200 additions & 0 deletions ax/analysis/healthcheck/constraints_feasibility.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import json
from typing import Tuple

import pandas as pd

from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.healthcheck.healthcheck_analysis import (
HealthcheckAnalysis,
HealthcheckAnalysisCard,
HealthcheckStatus,
)
from ax.analysis.plotly.arm_effects.utils import get_predictions_by_arm
from ax.analysis.plotly.utils import is_predictive
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.utils.common.typeutils import checked_cast
from pyre_extensions import none_throws


class ConstraintsFeasibilityAnalysis(HealthcheckAnalysis):
"""
Analysis for checking the feasibility of the constraints for the experiment.
A constraint is considered feasible if the probability of constraints violation
is below the threshold for at least one arm.
"""

def compute(
self,
experiment: Experiment | None = None,
generation_strategy: GenerationStrategyInterface | None = None,
prob_threshold: float = 0.90,
) -> HealthcheckAnalysisCard:
r"""
Compute the feasibility of the constraints for the experiment.
Args:
experiment: Ax experiment.
generation_strategy: Ax generation strategy.
prob_threhshold: Threshold for the probability of constraint violation.
Constraints are considered feasible if the probability of constraint
violation is below the threshold for at least one arm.
Returns:
A HealthcheckAnalysisCard object with the information on infeasible metrics,
i.e., metrics for which the constraints are infeasible for all test groups
(arms).
"""
status = HealthcheckStatus.PASS
subtitle = "All constraints are feasible."
title_status = "Success"
level = AnalysisCardLevel.LOW
df = pd.DataFrame({"status": [status]})

if experiment is None:
raise UserInputError(
"ConstraintsFeasibilityAnalysis requires an Experiment."
)

if experiment.optimization_config is None:
raise UserInputError(
"ConstraintsFeasibilityAnalysis requires an Experiment with an "
"optimization config."
)

if (
experiment.optimization_config.outcome_constraints is None
or len(experiment.optimization_config.outcome_constraints) == 0
):
subtitle = "No constraints are specified."
return HealthcheckAnalysisCard(
name="ConstraintsFeasibility",
title=f"Ax Constraints Feasibility {title_status}",
blob=json.dumps({"status": status}),
subtitle=subtitle,
df=df,
level=level,
)

if generation_strategy is None:
raise UserInputError(
"ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
)
generation_strategy = checked_cast(
GenerationStrategy,
generation_strategy,
exception=UserInputError(
"ConstraintsFeasibilityAnalysis requires a GenerationStrategy."
),
)

if generation_strategy.model is None:
generation_strategy._fit_current_model(data=experiment.lookup_data())

model = none_throws(generation_strategy.model)
if not is_predictive(model=model):
raise UserInputError(
"ConstraintsFeasibility requires a GenerationStrategy which is "
"in a state where the current model supports prediction. "
"The current model is {model._model_key} and does not support "
"prediction."
)
optimization_config = checked_cast(
OptimizationConfig, experiment.optimization_config
)
constraints_feasible, df = constraints_feasibility(
optimization_config=optimization_config,
model=model,
prob_threshold=prob_threshold,
)
df["status"] = status

if not constraints_feasible:
status = HealthcheckStatus.WARNING
subtitle = (
"Constraints are infeasible for all test groups (arms) with respect "
f"to the probability threshold {prob_threshold}. "
"We suggest relaxing the constraint bounds for the constraints."
)
title_status = "Warning"
df.loc[
df["overall_probability_constraints_violated"] > prob_threshold,
"status",
] = status

return HealthcheckAnalysisCard(
name="ConstraintsFeasibility",
title=f"Ax Constraints Feasibility {title_status}",
blob=json.dumps({"status": status}),
subtitle=subtitle,
df=df,
level=level,
)


def constraints_feasibility(
optimization_config: OptimizationConfig,
model: ModelBridge,
prob_threshold: float = 0.99,
) -> Tuple[bool, pd.DataFrame]:
r"""
Check the feasibility of the constraints for the experiment.
Args:
optimization_config: Ax optimization config.
model: Ax model to use for predictions.
prob_threshold: Threshold for the probability of constraint violation.
Returns:
A tuple of a boolean indicating whether the constraints are feasible and a
dataframe with information on the probabilities of constraints violation for
each arm.
"""
if (optimization_config.outcome_constraints is None) or (
len(optimization_config.outcome_constraints) == 0
):
raise UserInputError("No constraints are specified.")

derel_optimization_config = optimization_config
outcome_constraints = optimization_config.outcome_constraints

if any(constraint.relative for constraint in outcome_constraints):
derel_optimization_config = Derelativize().transform_optimization_config(
optimization_config=optimization_config,
modelbridge=model,
)

constraint_metric_name = [
constraint.metric.name
for constraint in derel_optimization_config.outcome_constraints
][0]

arm_dict = get_predictions_by_arm(
model=model,
metric_name=constraint_metric_name,
outcome_constraints=derel_optimization_config.outcome_constraints,
)

df = pd.DataFrame(arm_dict)
constraints_feasible = True
if all(
arm_info["overall_probability_constraints_violated"] > prob_threshold
for arm_info in arm_dict
if arm_info["arm_name"] != model.status_quo_name
):
constraints_feasible = False

return constraints_feasible, df
Loading

0 comments on commit f9a9fd6

Please sign in to comment.