From 35203a77df59b094828f4fd5b800a3795018b83a Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 20 Nov 2024 19:54:01 -0800 Subject: [PATCH 1/5] SlicePlot (for use in InteractionPlot) Differential Revision: D65234219 --- ax/analysis/plotly/__init__.py | 2 + ax/analysis/plotly/surface/__init__.py | 8 + ax/analysis/plotly/surface/slice.py | 197 ++++++++++++++++++ .../plotly/surface/tests/test_slice.py | 73 +++++++ ax/analysis/plotly/surface/utils.py | 62 ++++++ 5 files changed, 342 insertions(+) create mode 100644 ax/analysis/plotly/surface/__init__.py create mode 100644 ax/analysis/plotly/surface/slice.py create mode 100644 ax/analysis/plotly/surface/tests/test_slice.py create mode 100644 ax/analysis/plotly/surface/utils.py diff --git a/ax/analysis/plotly/__init__.py b/ax/analysis/plotly/__init__.py index 078ad5b594a..a4f7b934ea6 100644 --- a/ax/analysis/plotly/__init__.py +++ b/ax/analysis/plotly/__init__.py @@ -9,6 +9,7 @@ from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.scatter import ScatterPlot +from ax.analysis.plotly.surface.slice import SlicePlot __all__ = [ "CrossValidationPlot", @@ -16,4 +17,5 @@ "PlotlyAnalysisCard", "ParallelCoordinatesPlot", "ScatterPlot", + "SlicePlot", ] diff --git a/ax/analysis/plotly/surface/__init__.py b/ax/analysis/plotly/surface/__init__.py new file mode 100644 index 00000000000..5427f2ca6ec --- /dev/null +++ b/ax/analysis/plotly/surface/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from ax.analysis.plotly.surface.slice import SlicePlot + +__all__ = ["SlicePlot"] diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py new file mode 100644 index 00000000000..129f4275824 --- /dev/null +++ b/ax/analysis/plotly/surface/slice.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import math +from typing import Optional + +import pandas as pd +from ax.analysis.analysis import AnalysisCardLevel + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.analysis.plotly.surface.utils import ( + get_parameter_values, + is_axis_log_scale, + select_fixed_value, +) +from ax.analysis.plotly.utils import select_metric +from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.core.observation import ObservationFeatures +from ax.exceptions.core import UserInputError +from ax.modelbridge.base import ModelBridge +from ax.modelbridge.generation_strategy import GenerationStrategy +from plotly import express as px, graph_objects as go +from pyre_extensions import none_throws + + +class SlicePlot(PlotlyAnalysis): + """ + Plot a 1D "slice" of the surrogate model's predicted outcomes for a given + parameter, where all other parameters are held fixed at their status-quo value or + mean if no status quo is available. + + The DataFrame computed will contain the following columns: + - PARAMETER_NAME: The value of the parameter specified + - METRIC_NAME_mean: The predected mean of the metric specified + - METRIC_NAME_sem: The predected sem of the metric specified + """ + + def __init__( + self, + parameter_name: str, + metric_name: str | None = None, + ) -> None: + """ + Args: + parameter_name: The name of the parameter to plot on the x axis. 
+ metric_name: The name of the metric to plot on the y axis. If not + specified the objective will be used. + """ + self.parameter_name = parameter_name + self.metric_name = metric_name + + def compute( + self, + experiment: Optional[Experiment] = None, + generation_strategy: Optional[GenerationStrategyInterface] = None, + ) -> PlotlyAnalysisCard: + if experiment is None: + raise UserInputError("SlicePlot requires an Experiment") + + if not isinstance(generation_strategy, GenerationStrategy): + raise UserInputError("SlicePlot requires a GenerationStrategy") + + if generation_strategy.model is None: + generation_strategy._fit_current_model(None) + + metric_name = self.metric_name or select_metric(experiment=experiment) + + df = _prepare_data( + experiment=experiment, + model=none_throws(generation_strategy.model), + parameter_name=self.parameter_name, + metric_name=metric_name, + ) + + fig = _prepare_plot( + df=df, + parameter_name=self.parameter_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[self.parameter_name] + ), + ) + + return self._create_plotly_analysis_card( + title=f"{self.parameter_name} vs. {metric_name}", + subtitle=( + "1D slice of the surrogate model's predicted outcomes for " + f"{metric_name}" + ), + level=AnalysisCardLevel.LOW, + df=df, + fig=fig, + ) + + +def _prepare_data( + experiment: Experiment, + model: ModelBridge, + parameter_name: str, + metric_name: str, +) -> pd.DataFrame: + # Choose which parameter values to predict points for. + xs = get_parameter_values( + parameter=experiment.search_space.parameters[parameter_name] + ) + + # Construct observation features for each parameter value previously chosen by + # fixing all other parameters to their status-quo value or mean. 
+ features = [ + ObservationFeatures( + parameters={ + parameter_name: x, + **{ + parameter.name: select_fixed_value(parameter=parameter) + for parameter in experiment.search_space.parameters.values() + if parameter.name != parameter_name + }, + } + ) + for x in xs + ] + + predictions = model.predict(observation_features=features) + + return pd.DataFrame.from_records( + [ + { + parameter_name: xs[i], + f"{metric_name}_mean": predictions[0][metric_name][i], + f"{metric_name}_sem": predictions[1][metric_name][metric_name][i], + } + for i in range(len(xs)) + ] + ).sort_values(by=parameter_name) + + +def _prepare_plot( + df: pd.DataFrame, + parameter_name: str, + metric_name: str, + log_x: bool = False, +) -> go.Figure: + x = df[parameter_name].tolist() + y = df[f"{metric_name}_mean"].tolist() + y_upper = (df[f"{metric_name}_mean"] + 1.96 * df[f"{metric_name}_sem"]).tolist() + y_lower = (df[f"{metric_name}_mean"] - 1.96 * df[f"{metric_name}_sem"]).tolist() + + plotly_blue = px.colors.qualitative.Plotly[0] + plotly_blue_translucent = "rgba(99, 110, 250, 0.2)" + + # Draw a line at the mean and a shaded region between the upper and lower bounds + line = go.Scatter( + x=x, + y=y, + line={"color": plotly_blue}, + mode="lines", + name=metric_name, + showlegend=False, + ) + error_band = go.Scatter( + # Concatenate x values in reverse order to create a closed polygon + x=x + x[::-1], + # Concatenate upper and lower bounds in reverse order + y=y_upper + y_lower[::-1], + fill="toself", + fillcolor=plotly_blue_translucent, + line={"color": "rgba(255,255,255,0)"}, # Make "line" transparent + hoverinfo="skip", + showlegend=False, + ) + + fig = go.Figure( + [line, error_band], + layout=go.Layout( + xaxis_title=parameter_name, + yaxis_title=metric_name, + ), + ) + + # Set the x-axis scale to log if relevant + if log_x: + fig.update_xaxes( + type="log", + range=[ + math.log10(df[parameter_name].min()), + math.log10(df[parameter_name].max()), + ], + ) + else: + 
fig.update_xaxes(range=[df[parameter_name].min(), df[parameter_name].max()]) + + return fig diff --git a/ax/analysis/plotly/surface/tests/test_slice.py b/ax/analysis/plotly/surface/tests/test_slice.py new file mode 100644 index 00000000000..557a7665c37 --- /dev/null +++ b/ax/analysis/plotly/surface/tests/test_slice.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.surface.slice import SlicePlot +from ax.exceptions.core import UserInputError +from ax.service.ax_client import AxClient, ObjectiveProperties +from ax.utils.common.testutils import TestCase +from ax.utils.testing.mock import mock_botorch_optimize + + +class TestSlicePlot(TestCase): + @mock_botorch_optimize + def setUp(self) -> None: + super().setUp() + self.client = AxClient() + self.client.create_experiment( + is_test=True, + name="foo", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [-1.0, 1.0], + } + ], + objectives={"bar": ObjectiveProperties(minimize=True)}, + ) + + for _ in range(10): + parameterization, trial_index = self.client.get_next_trial() + self.client.complete_trial( + trial_index=trial_index, raw_data={"bar": parameterization["x"] ** 2} + ) + + def test_compute(self) -> None: + analysis = SlicePlot(parameter_name="x", metric_name="bar") + + # Test that it fails if no Experiment is provided + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + # Test that it fails if no GenerationStrategy is provided + with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"): + analysis.compute(experiment=self.client.experiment) + + card = analysis.compute( + experiment=self.client.experiment, + generation_strategy=self.client.generation_strategy, + ) + self.assertEqual( + card.name, 
+ "SlicePlot", + ) + self.assertEqual(card.title, "x vs. bar") + self.assertEqual( + card.subtitle, + "1D slice of the surrogate model's predicted outcomes for bar", + ) + self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual( + {*card.df.columns}, + { + "x", + "bar_mean", + "bar_sem", + }, + ) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "plotly") diff --git a/ax/analysis/plotly/surface/utils.py b/ax/analysis/plotly/surface/utils.py new file mode 100644 index 00000000000..4b8acd632ee --- /dev/null +++ b/ax/analysis/plotly/surface/utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict +import math + +import numpy as np +from ax.core.parameter import ( + ChoiceParameter, + FixedParameter, + Parameter, + RangeParameter, + TParamValue, +) + + +def get_parameter_values(parameter: Parameter, density: int = 100) -> list[TParamValue]: + """ + Get a list of parameter values to predict over for a given parameter. + """ + + # For RangeParameter use linspace for the range of the parameter + if isinstance(parameter, RangeParameter): + if parameter.log_scale: + return np.logspace( + math.log10(parameter.lower), math.log10(parameter.upper), density + ).tolist() + + return np.linspace(parameter.lower, parameter.upper, density).tolist() + + # For ChoiceParameter use the values of the parameter directly + if isinstance(parameter, ChoiceParameter) and parameter.is_ordered: + return parameter.values + + raise ValueError( + f"Parameter {parameter.name} must be a RangeParameter or " + "ChoiceParameter with is_ordered=True to be used in surface plot." + ) + + +def select_fixed_value(parameter: Parameter) -> TParamValue: + """ + Select a fixed value for a parameter. Use mean for RangeParameter, "middle" value + for ChoiceParameter, and value for FixedParameter. 
+ """ + if isinstance(parameter, RangeParameter): + return (parameter.lower * 1.0 + parameter.upper) / 2 + elif isinstance(parameter, ChoiceParameter): + return parameter.values[len(parameter.values) // 2] + elif isinstance(parameter, FixedParameter): + return parameter.value + else: + raise ValueError(f"Got unexpected parameter type {parameter}.") + + +def is_axis_log_scale(parameter: Parameter) -> bool: + """ + Check if the parameter is log scale. + """ + return isinstance(parameter, RangeParameter) and parameter.log_scale From bdae3640f24631d44106ac2a81e40dc347644196 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 20 Nov 2024 20:21:43 -0800 Subject: [PATCH 2/5] ContourPlot (for use in InteractionPlot) Differential Revision: D65233907 --- ax/analysis/plotly/__init__.py | 2 + ax/analysis/plotly/surface/__init__.py | 3 +- ax/analysis/plotly/surface/contour.py | 211 ++++++++++++++++++ .../plotly/surface/tests/test_contour.py | 83 +++++++ 4 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 ax/analysis/plotly/surface/contour.py create mode 100644 ax/analysis/plotly/surface/tests/test_contour.py diff --git a/ax/analysis/plotly/__init__.py b/ax/analysis/plotly/__init__.py index a4f7b934ea6..a194854ca8f 100644 --- a/ax/analysis/plotly/__init__.py +++ b/ax/analysis/plotly/__init__.py @@ -9,9 +9,11 @@ from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.scatter import ScatterPlot +from ax.analysis.plotly.surface.contour import ContourPlot from ax.analysis.plotly.surface.slice import SlicePlot __all__ = [ + "ContourPlot", "CrossValidationPlot", "PlotlyAnalysis", "PlotlyAnalysisCard", diff --git a/ax/analysis/plotly/surface/__init__.py b/ax/analysis/plotly/surface/__init__.py index 5427f2ca6ec..f22e65e2769 100644 --- a/ax/analysis/plotly/surface/__init__.py +++ b/ax/analysis/plotly/surface/__init__.py @@ -3,6 +3,7 @@ # This 
source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from ax.analysis.plotly.surface.contour import ContourPlot from ax.analysis.plotly.surface.slice import SlicePlot -__all__ = ["SlicePlot"] +__all__ = ["ContourPlot", "SlicePlot"] diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py new file mode 100644 index 00000000000..c3386058bfd --- /dev/null +++ b/ax/analysis/plotly/surface/contour.py @@ -0,0 +1,211 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import math +from typing import Optional + +import pandas as pd +from ax.analysis.analysis import AnalysisCardLevel + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.analysis.plotly.surface.utils import ( + get_parameter_values, + is_axis_log_scale, + select_fixed_value, +) +from ax.analysis.plotly.utils import select_metric +from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.core.observation import ObservationFeatures +from ax.exceptions.core import UserInputError +from ax.modelbridge.base import ModelBridge +from ax.modelbridge.generation_strategy import GenerationStrategy +from plotly import graph_objects as go +from pyre_extensions import none_throws + + +class ContourPlot(PlotlyAnalysis): + """ + Plot a 2D surface of the surrogate model's predicted outcomes for a given pair of + parameters, where all other parameters are held fixed at their status-quo value or + mean if no status quo is available. 
+ + The DataFrame computed will contain the following columns: + - PARAMETER_NAME: The value of the x parameter specified + - PARAMETER_NAME: The value of the y parameter specified + - METRIC_NAME: The predected mean of the metric specified + """ + + def __init__( + self, + x_parameter_name: str, + y_parameter_name: str, + metric_name: str | None = None, + ) -> None: + """ + Args: + y_parameter_name: The name of the parameter to plot on the x-axis. + y_parameter_name: The name of the parameter to plot on the y-axis. + metric_name: The name of the metric to plot + """ + self.x_parameter_name = x_parameter_name + self.y_parameter_name = y_parameter_name + self.metric_name = metric_name + + def compute( + self, + experiment: Optional[Experiment] = None, + generation_strategy: Optional[GenerationStrategyInterface] = None, + ) -> PlotlyAnalysisCard: + if experiment is None: + raise UserInputError("ContourPlot requires an Experiment") + + if not isinstance(generation_strategy, GenerationStrategy): + raise UserInputError("ContourPlot requires a GenerationStrategy") + + if generation_strategy.model is None: + generation_strategy._fit_current_model(None) + + metric_name = self.metric_name or select_metric(experiment=experiment) + + df = _prepare_data( + experiment=experiment, + model=none_throws(generation_strategy.model), + x_parameter_name=self.x_parameter_name, + y_parameter_name=self.y_parameter_name, + metric_name=metric_name, + ) + + fig = _prepare_plot( + df=df, + experiment=experiment, + x_parameter_name=self.x_parameter_name, + y_parameter_name=self.y_parameter_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[self.x_parameter_name] + ), + log_y=is_axis_log_scale( + parameter=experiment.search_space.parameters[self.y_parameter_name] + ), + ) + + return self._create_plotly_analysis_card( + title=( + f"{self.x_parameter_name}, {self.y_parameter_name} vs. 
{metric_name}" + ), + subtitle=( + "2D contour of the surrogate model's predicted outcomes for " + f"{metric_name}" + ), + level=AnalysisCardLevel.LOW, + df=df, + fig=fig, + ) + + +def _prepare_data( + experiment: Experiment, + model: ModelBridge, + x_parameter_name: str, + y_parameter_name: str, + metric_name: str, +) -> pd.DataFrame: + # Choose which parameter values to predict points for. + xs = get_parameter_values( + parameter=experiment.search_space.parameters[x_parameter_name], density=10 + ) + ys = get_parameter_values( + parameter=experiment.search_space.parameters[y_parameter_name], density=10 + ) + + # Construct observation features for each parameter value previously chosen by + # fixing all other parameters to their status-quo value or mean. + features = [ + ObservationFeatures( + parameters={ + x_parameter_name: x, + y_parameter_name: y, + **{ + parameter.name: select_fixed_value(parameter=parameter) + for parameter in experiment.search_space.parameters.values() + if not ( + parameter.name == x_parameter_name + or parameter.name == y_parameter_name + ) + }, + } + ) + for x in xs + for y in ys + ] + + predictions = model.predict(observation_features=features) + + return pd.DataFrame.from_records( + [ + { + x_parameter_name: features[i].parameters[x_parameter_name], + y_parameter_name: features[i].parameters[y_parameter_name], + f"{metric_name}_mean": predictions[0][metric_name][i], + } + for i in range(len(features)) + ] + ) + + +def _prepare_plot( + df: pd.DataFrame, + experiment: Experiment, + x_parameter_name: str, + y_parameter_name: str, + metric_name: str, + log_x: bool, + log_y: bool, +) -> go.Figure: + z_grid = df.pivot( + index=y_parameter_name, columns=x_parameter_name, values=f"{metric_name}_mean" + ) + + fig = go.Figure( + data=go.Contour( + z=z_grid.values, + x=z_grid.columns.values, + y=z_grid.index.values, + contours_coloring="heatmap", + showscale=False, + ), + layout=go.Layout( + xaxis_title=x_parameter_name, + 
yaxis_title=y_parameter_name, + ), + ) + + # Set the x-axis scale to log if relevant + if log_x: + fig.update_xaxes( + type="log", + range=[ + math.log10(df[x_parameter_name].min()), + math.log10(df[x_parameter_name].max()), + ], + ) + else: + fig.update_xaxes(range=[df[x_parameter_name].min(), df[x_parameter_name].max()]) + + if log_y: + fig.update_yaxes( + type="log", + range=[ + math.log10(df[y_parameter_name].min()), + math.log10(df[y_parameter_name].max()), + ], + ) + else: + fig.update_yaxes(range=[df[y_parameter_name].min(), df[y_parameter_name].max()]) + + return fig diff --git a/ax/analysis/plotly/surface/tests/test_contour.py b/ax/analysis/plotly/surface/tests/test_contour.py new file mode 100644 index 00000000000..6deec31ae4a --- /dev/null +++ b/ax/analysis/plotly/surface/tests/test_contour.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.surface.contour import ContourPlot +from ax.exceptions.core import UserInputError +from ax.service.ax_client import AxClient, ObjectiveProperties +from ax.utils.common.testutils import TestCase +from ax.utils.testing.mock import mock_botorch_optimize + + +class TestContourPlot(TestCase): + @mock_botorch_optimize + def setUp(self) -> None: + super().setUp() + self.client = AxClient() + self.client.create_experiment( + is_test=True, + name="foo", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [-1.0, 1.0], + }, + { + "name": "y", + "type": "range", + "bounds": [-1.0, 1.0], + }, + ], + objectives={"bar": ObjectiveProperties(minimize=True)}, + ) + + for _ in range(10): + parameterization, trial_index = self.client.get_next_trial() + self.client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": parameterization["x"] ** 2 + parameterization["y"] ** 2 + }, + ) + + def test_compute(self) -> None: + analysis = ContourPlot( + x_parameter_name="x", y_parameter_name="y", metric_name="bar" + ) + + # Test that it fails if no Experiment is provided + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + # Test that it fails if no GenerationStrategy is provided + with self.assertRaisesRegex(UserInputError, "requires a GenerationStrategy"): + analysis.compute(experiment=self.client.experiment) + + card = analysis.compute( + experiment=self.client.experiment, + generation_strategy=self.client.generation_strategy, + ) + self.assertEqual( + card.name, + "ContourPlot", + ) + self.assertEqual(card.title, "x, y vs. 
bar") + self.assertEqual( + card.subtitle, + "2D contour of the surrogate model's predicted outcomes for bar", + ) + self.assertEqual(card.level, AnalysisCardLevel.LOW) + self.assertEqual( + {*card.df.columns}, + { + "x", + "y", + "bar_mean", + }, + ) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "plotly") From 18ad6a47dee988d822a2ba58d7a5fdd479975dcd Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 20 Nov 2024 21:23:56 -0800 Subject: [PATCH 3/5] Move InteractionPlot to OSS Summary: As titled. Also renames InteractionAnalysis --> InteractionPlot to conform to convetions in this ax.analysis.plotly module. This is the second in a series of diffs that will get this Analysis ready for Ax 1.0, including: * Removal of dependency on ax.plot * Easier options in `InteractionPlot.__init__` * Misc tidying, pyre fixmes, etc Differential Revision: D65089145 --- ax/analysis/plotly/interaction.py | 614 +++++++++++++++++++ ax/analysis/plotly/tests/test_interaction.py | 197 ++++++ 2 files changed, 811 insertions(+) create mode 100644 ax/analysis/plotly/interaction.py create mode 100644 ax/analysis/plotly/tests/test_interaction.py diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py new file mode 100644 index 00000000000..ccc5fd98eef --- /dev/null +++ b/ax/analysis/plotly/interaction.py @@ -0,0 +1,614 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + + +import math +from typing import Any + +import numpy as np +import numpy.typing as npt + +import pandas as pd + +import torch +from ax.analysis.analysis import AnalysisCardLevel + +from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.core.data import Data +from ax.core.experiment import Experiment +from ax.core.generation_strategy_interface import GenerationStrategyInterface +from ax.core.observation import ObservationFeatures +from ax.exceptions.core import UserInputError +from ax.modelbridge.registry import Models +from ax.modelbridge.torch import TorchModelBridge +from ax.models.torch.botorch_modular.surrogate import Surrogate +from ax.plot.contour import _get_contour_predictions +from ax.plot.feature_importances import plot_feature_importance_by_feature_plotly +from ax.plot.helper import TNullableGeneratorRunsDict +from ax.plot.slice import _get_slice_predictions +from ax.utils.sensitivity.sobol_measures import ax_parameter_sens +from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel + +from gpytorch.constraints import Positive +from gpytorch.kernels import RBFKernel +from gpytorch.priors import LogNormalPrior +from plotly import graph_objects as go, io as pio +from plotly.subplots import make_subplots +from pyre_extensions import none_throws + + +TOP_K_TOO_LARGE_ERROR = ( + "Interaction Analysis only supports visualizing the slice/contour for" + " up to 6 component defined by the `top_k` argument, but received" + " {} as input." +) +MAX_NUM_PLOT_COMPONENTS: int = 6 +PLOT_SIZE: int = 380 + + +def get_model_kwargs( + use_interaction: bool, + num_parameters: int, + torch_device: torch.device | None = None, +) -> dict[str, Any]: + """Method to get the specific OAK kernel used to identify parameter interactions + in an Ax experiment. 
The kernel is an Orthogonal Additive Kernel (OAK), which + decomposes the objective into an additive sum of main parameter effects and + pairwise interaction effects. The kernel comes with a sparsity-inducing prior, + which attempts explain the data with as few components as possible. The + smoothness of the components is regularized by a lengthscale prior to guard + against excessicely short lengthscales being fit. + + Args: + use_interaction: Whether to use interaction effects. + num_parameters: Number of parameters in the experiment. + torch_device: The type of torch device to use for the model. + """ + # A fairly restrictive prior on the lengthscale avoids spurious + # fits, where a single component is fit to explain all variability. + # NOTE (hvarfner) Imposing a calibrated sparsity-inducing prior is + # probably a good add, but second-order components tend to break down + # even for weak priors. + return { + "covar_module_class": OrthogonalAdditiveKernel, + "covar_module_options": { + "base_kernel": RBFKernel( + ard_num_dims=num_parameters, + lengthscale_prior=LogNormalPrior(2, 1), + ), + "dim": num_parameters, + "dtype": torch.float64, + "device": torch_device, + "second_order": use_interaction, + "coeff_constraint": Positive(transform=torch.exp, inv_transform=torch.log), + }, + "allow_batched_models": False, + } + + +def sort_and_filter_top_k_components( + indices: dict[str, dict[str, npt.NDArray]], + k: int, + most_important: bool = True, +) -> dict[str, dict[str, npt.NDArray]]: + """Sorts and filter the top k components according to Sobol indices, per metric. + + Args: + indices: A dictionary of {metric: {component: sobol_index}} Sobol indices. + k: The number of components to keep. + most_important: Whether to keep the most or least important components. + + Returns: + A dictionary of the top k components. 
+ """ + metrics = list(indices.keys()) + sorted_indices = { + metric: dict( + sorted( + metric_indices.items(), + key=lambda x: x[1], + reverse=most_important, + ) + ) + for metric, metric_indices in indices.items() + } + + # filter to top k components + sorted_indices = { + metric: { + key: value + for _, (key, value) in zip(range(k), sorted_indices[metric].items()) + } + for metric in metrics + } + return sorted_indices + + +class InteractionPlot(PlotlyAnalysis): + """ + Analysis class which tries to explain the data of an experiment as one- or two- + dimensional additive components with a level of sparsity in the components. The + relative importance of each component is quantified by its Sobol index. Each + component may be visualized through slice or contour plots depending on if it is + a first order or second order component, respectively. + """ + + def __init__( + self, + metric_name: str, + top_k: int = 6, + data: Data | None = None, + most_important: bool = True, + fit_interactions: bool = True, + display_components: bool = False, + decompose_components: bool = False, + plots_share_range: bool = True, + num_mc_samples: int = 10_000, + model_fit_seed: int = 0, + torch_device: torch.device | None = None, + ) -> None: + """Constructor for InteractionAnalysis. + + Args: + metric_name: The metric to analyze. + top_k: The 'k' most imortant interactions according to Sobol indices. + Supports up to 6 components visualized at once. + data: The data to analyze. Defaults to None, in which case the data is taken + from the experiment. + most_important: Whether to plot the most or least important interactions. + fit_interactions: Whether to fit interaction effects in addition to main + effects. + display_components: Display individual components instead of the summarized + plot of sobol index values. + decompose_components: Whether to visualize surfaces as the total effect of + x1 & x2 (False) or only the interaction term (True). 
Setting + decompose_components = True thus plots f(x1, x2) - f(x1) - f(x2). + plots_share_range: Whether to have all plots share the same output range in + the final visualization. + num_mc_samples: The number of Monte Carlo samples to use for the Sobol + index calculations. + model_fit_seed: The seed with which to fit the model. Defaults to 0. Used + to ensure that the model fit is identical across the generation of + various plots. + torch_device: The torch device to use for the model. + """ + + super().__init__() + if top_k > 6 and display_components: + raise UserInputError(TOP_K_TOO_LARGE_ERROR.format(str(top_k))) + self.metric_name: str = metric_name + self.top_k: int = top_k + self.data: Data | None = data + self.most_important: bool = most_important + self.fit_interactions: bool = fit_interactions + self.display_components: bool = display_components + self.decompose_components: bool = decompose_components + self.num_mc_samples: int = num_mc_samples + self.model_fit_seed: int = model_fit_seed + self.torch_device: torch.device | None = torch_device + self.plots_share_range: bool = plots_share_range + + def get_model( + self, experiment: Experiment, metric_names: list[str] | None = None + ) -> TorchModelBridge: + """ + Retrieves the modelbridge used for the analysis. The model uses an OAK + (Orthogonal Additive Kernel) with a sparsity-inducing prior, + which decomposes the objective into an additive sum of components. 
+ """ + covar_module_kwargs = get_model_kwargs( + use_interaction=self.fit_interactions, + num_parameters=len(experiment.search_space.tunable_parameters), + torch_device=self.torch_device, + ) + data = experiment.lookup_data() if self.data is None else self.data + if metric_names: + data = data.filter(metric_names=metric_names) + with torch.random.fork_rng(): + # fixing the seed to ensure that the model is fit identically across + # different analyses of the same experiment + torch.torch.manual_seed(self.model_fit_seed) + model_bridge = Models.BOTORCH_MODULAR( + search_space=experiment.search_space, + experiment=experiment, + data=data, + surrogate=Surrogate(**covar_module_kwargs), + ) + return model_bridge # pyre-ignore[7] Return type is always a TorchModelBridge + + # pyre-ignore[14] Must pass in an Experiment (not Experiment | None) + def compute( + self, + experiment: Experiment | None = None, + generation_strategy: GenerationStrategyInterface | None = None, + ) -> PlotlyAnalysisCard: + model_bridge = self.get_model( + experiment=none_throws(experiment), metric_names=[self.metric_name] + ) + """ + Compute Sobol index sensitivity for one metric of an experiment. Sensitivity + is comptuted by component, where a compoent may be either one variable + (main effect) or two variables (interaction effect). The sensitivity is + computed using a model fit with an OAK kernel, which decomposes the objective + to be a sum of components, and where marginal effects can be computed + accurately. 
+ """ + experiment = none_throws(experiment) + model_bridge = self.get_model(experiment, [self.metric_name]) + with torch.random.fork_rng(): + # fixing the seed to ensure that the model is fit identically across + # different analyses of the same experiment + torch.torch.manual_seed(self.model_fit_seed) + sens = ax_parameter_sens( + model_bridge=model_bridge, + metrics=[self.metric_name], + order="second" if self.fit_interactions else "first", + signed=not self.fit_interactions, + num_mc_samples=self.num_mc_samples, + ) + sens = sort_and_filter_top_k_components( + indices=sens, k=self.top_k, most_important=self.most_important + ) + if not self.display_components: + return PlotlyAnalysisCard( + name="Interaction Analysis", + title=f"Feature Importance Analysis for {self.metric_name}", + subtitle=( + "Displays the most important features " + f"for {self.metric_name} by order of importance." + ), + level=AnalysisCardLevel.MID, + df=pd.DataFrame(sens), + blob=pio.to_json( + plot_feature_importance_by_feature_plotly( + sensitivity_values=sens, # pyre-ignore[6] + ) + ), + ) + else: + metric_sens = list(sens[self.metric_name].keys()) + return PlotlyAnalysisCard( + name="OAK Interaction Analysis", + title=( + "Additive Component Feature Importance Analysis " + f"for {self.metric_name}" + ), + subtitle=( + "Displays the most important features' effects " + f"on {self.metric_name} by order of importance." + ), + level=AnalysisCardLevel.MID, + df=pd.DataFrame(sens), + blob=pio.to_json( + plot_component_surfaces_plotly( + features=metric_sens, + model=model_bridge, + metric=self.metric_name, + plots_share_range=self.plots_share_range, + ) + ), + ) + + +def update_plot_range(max_range: list[float], new_range: list[float]) -> list[float]: + """Updates the range to include the value. + Args: + max_range: Current max_range among all considered ranges. + new_range: New range to consider to be included. + + Returns: + The updated max_range. 
+ """ + if max_range[0] > new_range[0]: + max_range[0] = new_range[0] + if max_range[1] < new_range[1]: + max_range[1] = new_range[1] + return max_range + + +def plot_component_surfaces_plotly( + features: list[str], + model: TorchModelBridge, + metric: str, + plots_share_range: bool = True, + generator_runs_dict: TNullableGeneratorRunsDict = None, + density: int = 50, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, + renormalize: bool = True, +) -> go.Figure: + """Plots the interaction surfaces for the given features. + + Args: + features: The features to plot. Can be either 1D or 2D, where 2D features are + specified as "x1 & x2". + model: The modelbridge used for prediction. + metric: The name of the metric to plot. + plots_share_range: Whether to have all plots should share the same output range. + generator_runs_dict: The generator runs dict to use. + density: The density of the grid, i.e. the number of points evaluated in each + dimension. + slice_values: The slice values to use for the parameters that are not plotted. + fixed_features: The fixed features to use. + trial_index: The trial index to include in the plot. + renormalize: Whether to renormalize the surface to have zero mean. + + Returns: + A plotly figure of all the interaction surfaces. + """ + traces = [] + titles = [] + param_names = [] + + # tracks the maximal value range so that all plots of the same type share the same + # signal range in the final visualization. We cannot just check the largest + # component's sobol index, as it may not have the largest co-domain. 
+ surface_range = [float("inf"), -float("inf")] + slice_range = [float("inf"), -float("inf")] + first_surface = True + for feature in features: + if " & " in feature: + component_x, component_y = feature.split(" & ") + trace, minval, maxval = generate_interaction_component( + model=model, + component_x=component_x, + component_y=component_y, + metric=metric, + generator_runs_dict=generator_runs_dict, + density=density, + slice_values=slice_values, + fixed_features=fixed_features, + trial_index=trial_index, + first_surface=first_surface, + ) + first_surface = False + traces.append(trace) + param_names.append((component_x, component_y)) + titles.append(f"Total effect, {component_x} & {component_y}") + surface_range = update_plot_range(surface_range, [minval, maxval]) + else: + trace, minval, maxval = generate_main_effect_component( + model=model, + component=feature, + metric=metric, + generator_runs_dict=generator_runs_dict, + density=density, + slice_values=slice_values, + fixed_features=fixed_features, + trial_index=trial_index, + ) + traces.append(trace) + param_names.append(feature) + titles.append(f"Main Effect, {feature}") + slice_range = update_plot_range(slice_range, [minval, maxval]) + + # 1x3 plots if 3 total, 2x2 plots if 4 total, 3x2 plots if 6 total + num_rows = 1 if len(traces) <= (MAX_NUM_PLOT_COMPONENTS / 2) else 2 + num_cols = math.ceil(len(traces) / num_rows) + + fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=titles) + for plot_idx, trace in enumerate(traces): + row = plot_idx // num_cols + 1 + col = plot_idx % num_cols + 1 + fig.add_trace(trace, row=row, col=col) + fig = set_axis_names( + figure=fig, trace=trace, row=row, col=col, param_names=param_names[plot_idx] + ) + + fig = scale_traces( + figure=fig, + traces=traces, + surface_range=surface_range, + slice_range=slice_range, + plots_share_range=plots_share_range, + ) + fig.update_layout({"width": PLOT_SIZE * num_cols, "height": PLOT_SIZE * num_rows}) + return fig + + +def 
generate_main_effect_component( + model: TorchModelBridge, + component: str, + metric: str, + generator_runs_dict: TNullableGeneratorRunsDict = None, + density: int = 50, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, +) -> tuple[go.Scatter, float, float]: + """Plots a slice "main effect" of the model for a given component. The values are + relative to the mean of all predictions, so that the magnitude of the component is + communicated. + + Args: + model: The modelbridge used for prediction. + component: The name of the component to plot. + metric: The name of the metric to plot. + generator_runs_dict: The generator runs dict to use. + density: The density of the grid, i.e. the number of points evaluated in each + dimension. + slice_values: The slice values to use for the parameters that are not plotted. + fixed_features: The fixed features to use. + trial_index: The trial index to include in the plot. + + Returns: + A contour plot of the component interaction, and the range of the plot. 
+ """ + _, _, slice_mean, _, grid, _, _, _, _, slice_stdev, _ = _get_slice_predictions( + model=model, + param_name=component, + metric_name=metric, + generator_runs_dict=generator_runs_dict, + density=density, + slice_values=slice_values, + fixed_features=fixed_features, + trial_index=trial_index, + ) + # renormalize the slice to have zero mean (done for each component) + slice_mean = np.array(slice_mean) - np.array(slice_mean).mean() + + trace = go.Scatter( + x=grid, + y=slice_mean, + name=component, + line_shape="spline", + showlegend=False, + error_y={ + "type": "data", + "array": slice_stdev, + "visible": True, + "thickness": 0.8, + }, + ) + + return trace, np.min(slice_mean).astype(float), np.max(slice_mean).astype(float) + + +def generate_interaction_component( + model: TorchModelBridge, + component_x: str, + component_y: str, + metric: str, + generator_runs_dict: TNullableGeneratorRunsDict = None, + density: int = 50, + slice_values: dict[str, Any] | None = None, + fixed_features: ObservationFeatures | None = None, + trial_index: int | None = None, + renormalize: bool = True, + first_surface: bool = True, +) -> tuple[go.Contour, float, float]: + """Plots a slice "main effect" of the model for a given component. The values are + relative to the mean of all predictions, so that the magnitude of the component is + communicated. + + Args: + model: The modelbridge used for prediction. + component_x: The name of the component to plot along the x-axis. + component_y: The name of the component to plot along the y-axis. + metric: The name of the metric to plot. + subtract_main_effects: Whether to subtract the main effects from the 2D surface. + If main effects are not subtracted, the 2D surface is the output of + plot_contour and models the effect of each parameter in isolation and their + interaction. If main effects are subtracted, the 2D surface visualizes only + the interaction effect of the two parameters. + generator_runs_dict: The generator runs dict to use. 
+ density: The density of the grid, i.e. the number of points evaluated in each + dimension. + slice_values: The slice values to use for the parameters that are not plotted. + fixed_features: The fixed features to use. + trial_index: The trial index to include in the plot. + renormalize: Whether to renormalize the surface to have zero mean. + first_surface: Whether this is the first surface to be plotted. If so, we plot + its colorbar. + + Returns: + A contour plot of the component interaction, and the range of the plot. + """ + comp_name: str = f"{component_x} & {component_y}" + fixed_kwargs: dict[str, Any] = { + "model": model, + "generator_runs_dict": generator_runs_dict, + "density": density, + "slice_values": slice_values, + "fixed_features": fixed_features, + } + _, contour_mean, _, grid_x, grid_y, _ = _get_contour_predictions( + x_param_name=component_x, + y_param_name=component_y, + metric=metric, + **fixed_kwargs, + ) + contour_mean = np.reshape(contour_mean, (density, density)) + contour_mean = contour_mean - contour_mean.mean() + return ( + go.Contour( + z=contour_mean, + x=grid_x, + y=grid_y, + name=comp_name, + ncontours=50, + showscale=first_surface, + ), + np.min(contour_mean).astype(float), + np.max(contour_mean).astype(float), + ) + + +def scale_traces( + figure: go.Figure, + traces: list[go.Scatter | go.Contour], + surface_range: list[float], + slice_range: list[float], + plots_share_range: bool = True, +) -> go.Figure: + """Scales the traces to have the same range. + + Args: + figure: The main plotly figure to update the traces on. + traces: The traces to scale. + surface_range: The range of the surface traces. + slice_range: The range of the slice traces. + plots_share_range: Whether to have all plots (and not just plots + of the same type) share the same output range. + + Returns: + A figure with the traces of the same type are scaled to have the same range. 
+ """ + if plots_share_range: + total_range = update_plot_range(surface_range, slice_range) + slice_range = total_range + surface_range = total_range + + # plotly axis names in layout are of the form "xaxis{idx}" and "yaxis{idx}" except + # for the first one, which is "xaxis" and "yaxis". We need to keep track of the + # indices of the traces and then use the correct axis names when updating ranges. + axis_names = ["yaxis"] + [f"yaxis{idx}" for idx in range(2, len(traces) + 1)] + slice_axes = [ + axis_name + for trace, axis_name in zip(traces, axis_names) + if isinstance(trace, go.Scatter) + ] + + # scale the surface traces to have the same range + for trace_idx in range(len(figure["data"])): + trace = figure["data"][trace_idx] + if isinstance(trace, go.Contour): + trace["zmin"] = surface_range[0] + trace["zmax"] = surface_range[1] + + # and scale the slice traces to have the same range + figure.update_layout({ax: {"range": slice_range} for ax in slice_axes}) + return figure + + +def set_axis_names( + figure: go.Figure, + trace: go.Contour | go.Scatter, + row: int, + col: int, + param_names: str | tuple[str, str], +) -> go.Figure: + """Sets the axis names for the given row and column. + + Args: + figure: The figure to update the axes on. + trace: The trace of the plot whose axis labels to update. + row: The row index of the trace in `figure`. + col: The column index of the trace in `figure`. + param_names: The parameter names to use for the axis names. + + Returns: + A figure where the trace at (row, col) has its axis names set. 
+ """ + if isinstance(trace, go.Contour): + X_name, Y_name = param_names + figure.update_xaxes(title_text=X_name, row=row, col=col) + figure.update_yaxes(title_text=Y_name, row=row, col=col) + else: + figure.update_xaxes(title_text=param_names, row=row, col=col) + return figure diff --git a/ax/analysis/plotly/tests/test_interaction.py b/ax/analysis/plotly/tests/test_interaction.py new file mode 100644 index 00000000000..d60311a7223 --- /dev/null +++ b/ax/analysis/plotly/tests/test_interaction.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import pandas as pd +import torch +from ax.analysis.analysis import AnalysisCard +from ax.analysis.plotly.interaction import ( + generate_interaction_component, + generate_main_effect_component, + get_model_kwargs, + InteractionPlot, + TOP_K_TOO_LARGE_ERROR, +) +from ax.exceptions.core import DataRequiredError, UserInputError + +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment +from ax.utils.testing.mock import mock_botorch_optimize +from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel +from gpytorch.kernels import RBFKernel +from plotly import graph_objects as go + + +class InteractionTest(TestCase): + def test_interaction_get_model_kwargs(self) -> None: + kwargs = get_model_kwargs( + num_parameters=3, + use_interaction=False, + torch_device=torch.device("cpu"), + ) + self.assertEqual(kwargs["covar_module_class"], OrthogonalAdditiveKernel) + covar_module_options = kwargs["covar_module_options"] + self.assertIsInstance(covar_module_options["base_kernel"], RBFKernel) + self.assertEqual(covar_module_options["dim"], 3) + + # Checks that we can retrieve the modelbridge that has interaction terms + kwargs = get_model_kwargs( + num_parameters=5, 
+ use_interaction=True, + torch_device=torch.device("cpu"), + ) + self.assertEqual(kwargs["covar_module_class"], OrthogonalAdditiveKernel) + self.assertIsInstance(kwargs["covar_module_options"]["base_kernel"], RBFKernel) + + @mock_botorch_optimize + def test_interaction_analysis_without_components(self) -> None: + exp = get_branin_experiment(with_completed_trial=True) + analysis = InteractionPlot( + metric_name="branin", + fit_interactions=False, + num_mc_samples=11, + ) + card = analysis.compute(experiment=exp) + self.assertIsInstance(card, AnalysisCard) + self.assertIsInstance(card.blob, str) + self.assertIsInstance(card.df, pd.DataFrame) + self.assertEqual( + card.name, + "Interaction Analysis", + ) + self.assertEqual( + card.title, + "Feature Importance Analysis for branin", + ) + self.assertEqual( + card.subtitle, + "Displays the most important features for branin by order of importance.", + ) + + # with interaction terms + analysis = InteractionPlot( + metric_name="branin", + fit_interactions=True, + num_mc_samples=11, + ) + card = analysis.compute(experiment=exp) + self.assertIsInstance(card, AnalysisCard) + self.assertIsInstance(card.blob, str) + self.assertIsInstance(card.df, pd.DataFrame) + self.assertEqual(len(card.df), 3) + self.assertEqual( + card.subtitle, + "Displays the most important features for branin by order of importance.", + ) + + with self.assertRaisesRegex(UserInputError, TOP_K_TOO_LARGE_ERROR.format("7")): + InteractionPlot(metric_name="branin", top_k=7, display_components=True) + + analysis = InteractionPlot(metric_name="branout", fit_interactions=False) + with self.assertRaisesRegex( + DataRequiredError, "StandardizeY` transform requires non-empty data." 
+ ): + analysis.compute(experiment=exp) + + @mock_botorch_optimize + def test_interaction_with_components(self) -> None: + exp = get_branin_experiment(with_completed_trial=True) + analysis = InteractionPlot( + metric_name="branin", + fit_interactions=True, + display_components=True, + num_mc_samples=11, + ) + card = analysis.compute(experiment=exp) + self.assertIsInstance(card, AnalysisCard) + self.assertIsInstance(card.blob, str) + self.assertIsInstance(card.df, pd.DataFrame) + self.assertEqual(len(card.df), 3) + + analysis = InteractionPlot( + metric_name="branin", + fit_interactions=True, + display_components=True, + top_k=2, + num_mc_samples=11, + ) + card = analysis.compute(experiment=exp) + self.assertIsInstance(card, AnalysisCard) + self.assertEqual(len(card.df), 2) + analysis = InteractionPlot( + metric_name="branin", + fit_interactions=True, + display_components=True, + model_fit_seed=999, + num_mc_samples=11, + ) + card = analysis.compute(experiment=exp) + self.assertIsInstance(card, AnalysisCard) + self.assertEqual(len(card.df), 3) + + @mock_botorch_optimize + def test_generate_main_effect_component(self) -> None: + exp = get_branin_experiment(with_completed_trial=True) + analysis = InteractionPlot( + metric_name="branin", + fit_interactions=True, + display_components=True, + num_mc_samples=11, + ) + density = 13 + model = analysis.get_model(experiment=exp) + comp, _, _ = generate_main_effect_component( + model=model, + component="x1", + metric="branin", + density=density, + ) + self.assertIsInstance(comp, go.Scatter) + self.assertEqual(comp["x"].shape, (density,)) + self.assertEqual(comp["y"].shape, (density,)) + self.assertEqual(comp["name"], "x1") + + with self.assertRaisesRegex(KeyError, "braninandout"): + generate_main_effect_component( + model=model, + component="x1", + metric="braninandout", + density=density, + ) + + @mock_botorch_optimize + def test_generate_interaction_component(self) -> None: + exp = 
get_branin_experiment(with_completed_trial=True)
+        analysis = InteractionPlot(
+            metric_name="branin",
+            fit_interactions=True,
+            display_components=True,
+            num_mc_samples=11,
+        )
+        density = 3
+        model = analysis.get_model(experiment=exp)
+        comp, _, _ = generate_interaction_component(
+            model=model,
+            component_x="x1",
+            component_y="x2",
+            metric="branin",
+            density=density,
+        )
+        self.assertIsInstance(comp, go.Contour)
+        self.assertEqual(comp["x"].shape, (density,))
+        self.assertEqual(comp["y"].shape, (density,))
+        self.assertEqual(comp["z"].shape, (density, density))
+        self.assertEqual(comp["name"], "x1 & x2")
+
+        with self.assertRaisesRegex(KeyError, "braninandout"):
+            generate_interaction_component(
+                model=model,
+                component_x="x1",
+                component_y="x2",
+                metric="braninandout",
+                density=density,
+            )

From 619a12c618419f95c36499a4d05b1388b1620d08 Mon Sep 17 00:00:00 2001
From: Miles Olson
Date: Wed, 20 Nov 2024 21:23:57 -0800
Subject: [PATCH 4/5] Remove various args from InteractionPlot.__init__

Summary: Simplifying InteractionPlot by removing some arguments from its initializer. If these simplifications seem appropriate then I will continue to simplify the Plot.
Removals:
* top_k: Prefer to show sobol indices for all components on bar chart, slice/surface for top 6 always
* data: Let's always use the data on the experiment
* most_important: Always sort most important to least important, never least to most
* display_components: Always display components
* decompose_components: Never decompose components
* plots_share_range: Always share range
* num_mc_samples: Always use 10k samples
* [RFC] model_fit_seed: Do not bother with seed setting -- we don't do this for any other plots so it's probably not worth the complexity here

The following diffs will restructure the code here to take advantage of the simplified options

Differential Revision: D65148289
---
 ax/analysis/plotly/interaction.py | 103 ++++++++++++------------
 1 file changed, 41 insertions(+), 62 deletions(-)

diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py
index ccc5fd98eef..1d2f5e7eb8c 100644
--- a/ax/analysis/plotly/interaction.py
+++ b/ax/analysis/plotly/interaction.py
@@ -18,11 +18,9 @@
 from ax.analysis.analysis import AnalysisCardLevel
 from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
-from ax.core.data import Data
 from ax.core.experiment import Experiment
 from ax.core.generation_strategy_interface import GenerationStrategyInterface
 from ax.core.observation import ObservationFeatures
-from ax.exceptions.core import UserInputError
 from ax.modelbridge.registry import Models
 from ax.modelbridge.torch import TorchModelBridge
 from ax.models.torch.botorch_modular.surrogate import Surrogate
@@ -140,57 +138,31 @@ class InteractionPlot(PlotlyAnalysis):
     def __init__(
         self,
         metric_name: str,
-        top_k: int = 6,
-        data: Data | None = None,
-        most_important: bool = True,
         fit_interactions: bool = True,
-        display_components: bool = False,
-        decompose_components: bool = False,
-        plots_share_range: bool = True,
-        num_mc_samples: int = 10_000,
-        model_fit_seed: int = 0,
+        most_important: bool = True,
+        seed: int
= 0, torch_device: torch.device | None = None, ) -> None: """Constructor for InteractionAnalysis. Args: metric_name: The metric to analyze. - top_k: The 'k' most imortant interactions according to Sobol indices. - Supports up to 6 components visualized at once. - data: The data to analyze. Defaults to None, in which case the data is taken - from the experiment. - most_important: Whether to plot the most or least important interactions. fit_interactions: Whether to fit interaction effects in addition to main effects. - display_components: Display individual components instead of the summarized - plot of sobol index values. - decompose_components: Whether to visualize surfaces as the total effect of - x1 & x2 (False) or only the interaction term (True). Setting - decompose_components = True thus plots f(x1, x2) - f(x1) - f(x2). - plots_share_range: Whether to have all plots share the same output range in - the final visualization. - num_mc_samples: The number of Monte Carlo samples to use for the Sobol - index calculations. - model_fit_seed: The seed with which to fit the model. Defaults to 0. Used + most_important: Whether to sort by most or least important features in the + bar subplot. Also controls whether the six most or least important + features are plotted in the surface subplots. + seed: The seed with which to fit the model. Defaults to 0. Used to ensure that the model fit is identical across the generation of various plots. torch_device: The torch device to use for the model. 
""" - super().__init__() - if top_k > 6 and display_components: - raise UserInputError(TOP_K_TOO_LARGE_ERROR.format(str(top_k))) - self.metric_name: str = metric_name - self.top_k: int = top_k - self.data: Data | None = data - self.most_important: bool = most_important - self.fit_interactions: bool = fit_interactions - self.display_components: bool = display_components - self.decompose_components: bool = decompose_components - self.num_mc_samples: int = num_mc_samples - self.model_fit_seed: int = model_fit_seed - self.torch_device: torch.device | None = torch_device - self.plots_share_range: bool = plots_share_range + self.metric_name = metric_name + self.fit_interactions = fit_interactions + self.most_important = most_important + self.seed = seed + self.torch_device = torch_device def get_model( self, experiment: Experiment, metric_names: list[str] | None = None @@ -205,19 +177,16 @@ def get_model( num_parameters=len(experiment.search_space.tunable_parameters), torch_device=self.torch_device, ) - data = experiment.lookup_data() if self.data is None else self.data + data = experiment.lookup_data() if metric_names: data = data.filter(metric_names=metric_names) - with torch.random.fork_rng(): - # fixing the seed to ensure that the model is fit identically across - # different analyses of the same experiment - torch.torch.manual_seed(self.model_fit_seed) - model_bridge = Models.BOTORCH_MODULAR( - search_space=experiment.search_space, - experiment=experiment, - data=data, - surrogate=Surrogate(**covar_module_kwargs), - ) + + model_bridge = Models.BOTORCH_MODULAR( + search_space=experiment.search_space, + experiment=experiment, + data=data, + surrogate=Surrogate(**covar_module_kwargs), + ) return model_bridge # pyre-ignore[7] Return type is always a TorchModelBridge # pyre-ignore[14] Must pass in an Experiment (not Experiment | None) @@ -239,19 +208,29 @@ def compute( """ experiment = none_throws(experiment) model_bridge = self.get_model(experiment, [self.metric_name]) 
- with torch.random.fork_rng(): - # fixing the seed to ensure that the model is fit identically across - # different analyses of the same experiment - torch.torch.manual_seed(self.model_fit_seed) - sens = ax_parameter_sens( - model_bridge=model_bridge, - metrics=[self.metric_name], - order="second" if self.fit_interactions else "first", - signed=not self.fit_interactions, - num_mc_samples=self.num_mc_samples, - ) + sens = ax_parameter_sens( + model_bridge=model_bridge, + metrics=[self.metric_name], + order="second" if self.fit_interactions else "first", + signed=not self.fit_interactions, + ) sens = sort_and_filter_top_k_components( - indices=sens, k=self.top_k, most_important=self.most_important + indices=sens, + k=6, + ) + # reformat the keys from tuple to a proper "x1 & x2" string + interaction_name = "Interaction" if self.fit_interactions else "Main Effect" + return PlotlyAnalysisCard( + name="Interaction Analysis", + title="Feature Importance Analysis", + subtitle=f"{interaction_name} Analysis for {self.metric_name}", + level=AnalysisCardLevel.MID, + df=pd.DataFrame(sens), + blob=pio.to_json( + plot_feature_importance_by_feature_plotly( + sensitivity_values=sens, # pyre-ignore[6] + ) + ), ) if not self.display_components: return PlotlyAnalysisCard( From 4e95ad2bd57480a29f94d0a4b351e307a2e0ccb8 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 20 Nov 2024 21:27:10 -0800 Subject: [PATCH 5/5] InteractionPlot refactor (#3097) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3097 When this is landed we will be able to use this plot in Ax 1.0 and Ax UI. Refactor the interaction plot to be in line with our structure for ax.analysis. This includes a massive reduction in overall code (about half) and a full decoupling from ax.plot. Adds robustness features around generating subplots -- a failed surface subplot will no longer fail the full analysis. 
This new version of the plot is slightly more opinionated in that we always plot both the feature importance bar chart AND the top 6 features, always plots top 15 components in the bar chart, never decomposes components, and always has plots share scale. These settings are most useful and help drastically simplify the code, so I think we should keep them for now and only consider adding them back if there is demand. Differential Revision: D65234856 --- ax/analysis/plotly/__init__.py | 2 + ax/analysis/plotly/interaction.py | 753 ++++++------------- ax/analysis/plotly/surface/contour.py | 2 - ax/analysis/plotly/tests/test_interaction.py | 229 ++---- 4 files changed, 308 insertions(+), 678 deletions(-) diff --git a/ax/analysis/plotly/__init__.py b/ax/analysis/plotly/__init__.py index a194854ca8f..51fe32e6274 100644 --- a/ax/analysis/plotly/__init__.py +++ b/ax/analysis/plotly/__init__.py @@ -6,6 +6,7 @@ # pyre-strict from ax.analysis.plotly.cross_validation import CrossValidationPlot +from ax.analysis.plotly.interaction import InteractionPlot from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard from ax.analysis.plotly.scatter import ScatterPlot @@ -15,6 +16,7 @@ __all__ = [ "ContourPlot", "CrossValidationPlot", + "InteractionPlot", "PlotlyAnalysis", "PlotlyAnalysisCard", "ParallelCoordinatesPlot", diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py index 1d2f5e7eb8c..96930499258 100644 --- a/ax/analysis/plotly/interaction.py +++ b/ax/analysis/plotly/interaction.py @@ -6,124 +6,42 @@ # pyre-strict -import math -from typing import Any - -import numpy as np -import numpy.typing as npt +from logging import Logger import pandas as pd - import torch from ax.analysis.analysis import AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard + +from ax.analysis.plotly.surface.contour import ( + 
_prepare_data as _prepare_contour_data, + _prepare_plot as _prepare_contour_plot, +) +from ax.analysis.plotly.surface.slice import ( + _prepare_data as _prepare_slice_data, + _prepare_plot as _prepare_slice_plot, +) +from ax.analysis.plotly.surface.utils import is_axis_log_scale +from ax.analysis.plotly.utils import select_metric from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface -from ax.core.observation import ObservationFeatures +from ax.exceptions.core import UserInputError from ax.modelbridge.registry import Models from ax.modelbridge.torch import TorchModelBridge from ax.models.torch.botorch_modular.surrogate import Surrogate -from ax.plot.contour import _get_contour_predictions -from ax.plot.feature_importances import plot_feature_importance_by_feature_plotly -from ax.plot.helper import TNullableGeneratorRunsDict -from ax.plot.slice import _get_slice_predictions +from ax.utils.common.logger import get_logger from ax.utils.sensitivity.sobol_measures import ax_parameter_sens from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel from gpytorch.constraints import Positive from gpytorch.kernels import RBFKernel from gpytorch.priors import LogNormalPrior -from plotly import graph_objects as go, io as pio +from plotly import express as px, graph_objects as go from plotly.subplots import make_subplots -from pyre_extensions import none_throws - - -TOP_K_TOO_LARGE_ERROR = ( - "Interaction Analysis only supports visualizing the slice/contour for" - " up to 6 component defined by the `top_k` argument, but received" - " {} as input." -) -MAX_NUM_PLOT_COMPONENTS: int = 6 -PLOT_SIZE: int = 380 - - -def get_model_kwargs( - use_interaction: bool, - num_parameters: int, - torch_device: torch.device | None = None, -) -> dict[str, Any]: - """Method to get the specific OAK kernel used to identify parameter interactions - in an Ax experiment. 
The kernel is an Orthogonal Additive Kernel (OAK), which - decomposes the objective into an additive sum of main parameter effects and - pairwise interaction effects. The kernel comes with a sparsity-inducing prior, - which attempts explain the data with as few components as possible. The - smoothness of the components is regularized by a lengthscale prior to guard - against excessicely short lengthscales being fit. - - Args: - use_interaction: Whether to use interaction effects. - num_parameters: Number of parameters in the experiment. - torch_device: The type of torch device to use for the model. - """ - # A fairly restrictive prior on the lengthscale avoids spurious - # fits, where a single component is fit to explain all variability. - # NOTE (hvarfner) Imposing a calibrated sparsity-inducing prior is - # probably a good add, but second-order components tend to break down - # even for weak priors. - return { - "covar_module_class": OrthogonalAdditiveKernel, - "covar_module_options": { - "base_kernel": RBFKernel( - ard_num_dims=num_parameters, - lengthscale_prior=LogNormalPrior(2, 1), - ), - "dim": num_parameters, - "dtype": torch.float64, - "device": torch_device, - "second_order": use_interaction, - "coeff_constraint": Positive(transform=torch.exp, inv_transform=torch.log), - }, - "allow_batched_models": False, - } - - -def sort_and_filter_top_k_components( - indices: dict[str, dict[str, npt.NDArray]], - k: int, - most_important: bool = True, -) -> dict[str, dict[str, npt.NDArray]]: - """Sorts and filter the top k components according to Sobol indices, per metric. - - Args: - indices: A dictionary of {metric: {component: sobol_index}} Sobol indices. - k: The number of components to keep. - most_important: Whether to keep the most or least important components. - - Returns: - A dictionary of the top k components. 
- """ -    metrics = list(indices.keys()) -    sorted_indices = { -        metric: dict( -            sorted( -                metric_indices.items(), -                key=lambda x: x[1], -                reverse=most_important, -            ) -        ) -        for metric, metric_indices in indices.items() -    } +from pyre_extensions import assert_is_instance - # filter to top k components -    sorted_indices = { -        metric: { -            key: value -            for _, (key, value) in zip(range(k), sorted_indices[metric].items()) -        } -        for metric in metrics -    } -    return sorted_indices +logger: Logger = get_logger(__name__) class InteractionPlot(PlotlyAnalysis): @@ -133,18 +51,22 @@ class InteractionPlot(PlotlyAnalysis): relative importance of each component is quantified by its Sobol index. Each component may be visualized through slice or contour plots depending on if it is a first order or second order component, respectively. + + The DataFrame computed will contain just the sensitivity analysis with one row per + parameter and the following columns: + - feature: The name of the first or second order component + - sensitivity: The sensitivity of the component """ def __init__( self, - metric_name: str, + metric_name: str | None = None, fit_interactions: bool = True, most_important: bool = True, seed: int = 0, torch_device: torch.device | None = None, ) -> None: - """Constructor for InteractionAnalysis. - + """ Args: metric_name: The metric to analyze. fit_interactions: Whether to fit interaction effects in addition to main @@ -164,430 +86,255 @@ def __init__( self.seed = seed self.torch_device = torch_device - def get_model( - self, experiment: Experiment, metric_names: list[str] | None = None - ) -> TorchModelBridge: - """ - Retrieves the modelbridge used for the analysis. The model uses an OAK - (Orthogonal Additive Kernel) with a sparsity-inducing prior, - which decomposes the objective into an additive sum of components. 
- """ - covar_module_kwargs = get_model_kwargs( - use_interaction=self.fit_interactions, - num_parameters=len(experiment.search_space.tunable_parameters), - torch_device=self.torch_device, - ) - data = experiment.lookup_data() - if metric_names: - data = data.filter(metric_names=metric_names) - - model_bridge = Models.BOTORCH_MODULAR( - search_space=experiment.search_space, - experiment=experiment, - data=data, - surrogate=Surrogate(**covar_module_kwargs), - ) - return model_bridge # pyre-ignore[7] Return type is always a TorchModelBridge - - # pyre-ignore[14] Must pass in an Experiment (not Experiment | None) def compute( self, experiment: Experiment | None = None, generation_strategy: GenerationStrategyInterface | None = None, ) -> PlotlyAnalysisCard: - model_bridge = self.get_model( - experiment=none_throws(experiment), metric_names=[self.metric_name] + if experiment is None: + raise UserInputError("InteractionPlot requires an Experiment") + + metric_name = self.metric_name or select_metric(experiment=experiment) + + # Fix the seed to ensure that the model is fit identically across different + # analyses of the same experiment. + with torch.random.fork_rng(): + torch.torch.manual_seed(self.seed) + + # Fit the OAK model. + oak_model = self._get_oak_model( + experiment=experiment, metric_name=metric_name + ) + + # Calculate first- or second-order Sobol indices. + sens = ax_parameter_sens( + model_bridge=oak_model, + metrics=[metric_name], + order="second" if self.fit_interactions else "first", + signed=not self.fit_interactions, + )[metric_name] + + sensitivity_df = pd.DataFrame( + [*sens.items()], columns=["feature", "sensitivity"] + ).sort_values(by="sensitivity", key=abs, ascending=self.most_important) + + # Calculate feature importance bar plot. Only plot the top 15 features. + # Plot the absolute value of the sensitivity but color by the sign. 
+ plotting_df = sensitivity_df.head(15).copy() + plotting_df["direction"] = plotting_df["sensitivity"].apply( + lambda x: "Increases Metric" if x >= 0 else "Decreases Metric" ) - """ - Compute Sobol index sensitivity for one metric of an experiment. Sensitivity - is comptuted by component, where a compoent may be either one variable - (main effect) or two variables (interaction effect). The sensitivity is - computed using a model fit with an OAK kernel, which decomposes the objective - to be a sum of components, and where marginal effects can be computed - accurately. - """ - experiment = none_throws(experiment) - model_bridge = self.get_model(experiment, [self.metric_name]) - sens = ax_parameter_sens( - model_bridge=model_bridge, - metrics=[self.metric_name], - order="second" if self.fit_interactions else "first", - signed=not self.fit_interactions, + plotting_df["sensitivity"] = plotting_df["sensitivity"].abs() + + sensitivity_fig = px.bar( + plotting_df.sort_values( + by="sensitivity", key=abs, ascending=self.most_important + ), + x="sensitivity", + y="feature", + color="direction", + # Increase gets blue, decrease gets orange. + color_discrete_sequence=["orange", "blue"], + orientation="h", ) - sens = sort_and_filter_top_k_components( - indices=sens, - k=6, + + # Calculate surface plots for six most or least important features + # Note: We use tail and reverse here because the bar plot is sorted from top to + # bottom. + top_features = [*reversed(sensitivity_df.tail(6)["feature"].to_list())] + surface_figs = [] + for feature_name in top_features: + try: + surface_figs.append( + _prepare_surface_plot( + experiment=experiment, + model=oak_model, + feature_name=feature_name, + metric_name=metric_name, + ) + ) + # Not all features will be able to be plotted, skip them gracefully. 
+ except Exception as e: + logger.error(f"Failed to generate surface plot for {feature_name}: {e}") + + # Create a plotly figure to hold the bar plot in the top row and surface plots + # in a 3x2 grid below. + fig = make_subplots( + rows=4, + cols=3, + specs=[ + [{"colspan": 3}, None, None], + [{}, {}, {}], + [{}, {}, {}], + [{}, {}, {}], + ], ) - # reformat the keys from tuple to a proper "x1 & x2" string - interaction_name = "Interaction" if self.fit_interactions else "Main Effect" - return PlotlyAnalysisCard( - name="Interaction Analysis", - title="Feature Importance Analysis", - subtitle=f"{interaction_name} Analysis for {self.metric_name}", - level=AnalysisCardLevel.MID, - df=pd.DataFrame(sens), - blob=pio.to_json( - plot_feature_importance_by_feature_plotly( - sensitivity_values=sens, # pyre-ignore[6] + + for trace in sensitivity_fig.data: + fig.add_trace(trace, row=1, col=1) + + for i in range(len(surface_figs)): + feature_name = top_features[i] + surface_fig = surface_figs[i] + + row = (i // 3) + 2 + col = (i % 3) + 1 + for trace in surface_fig.data: + fig.add_trace(trace, row=row, col=col) + + # Label and fix axes + if "&" in feature_name: # If the feature is a second-order component + x, y = feature_name.split(" & ") + + # Reapply log scale if necessary + fig.update_xaxes( + title_text=x, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[x] + ) + else "linear" + ), + row=row, + col=col, + ) + fig.update_yaxes( + title_text=y, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[y] + ) + else "linear" + ), + row=row, + col=col, ) + else: # If the feature is a first-order component + fig.update_xaxes( + title_text=feature_name, + type=( + "log" + if is_axis_log_scale( + parameter=experiment.search_space.parameters[feature_name] + ) + else "linear" + ), + row=row, + col=col, + ) + + fig.update_layout( + height=1500, + width=1500, + ) + + subtitle_substring = ( + "one- or 
two-dimensional" if self.fit_interactions else "one-dimensional" + ) + + return self._create_plotly_analysis_card( + title=f"Interaction Analysis for {metric_name}", + subtitle=( + f"Understand an Experiment's data as {subtitle_substring} additive " + "components with sparsity. Important components are visualized through " + "slice or contour plots" ), + level=AnalysisCardLevel.MID, + df=sensitivity_df, + fig=fig, ) - if not self.display_components: - return PlotlyAnalysisCard( - name="Interaction Analysis", - title=f"Feature Importance Analysis for {self.metric_name}", - subtitle=( - "Displays the most important features " - f"for {self.metric_name} by order of importance." - ), - level=AnalysisCardLevel.MID, - df=pd.DataFrame(sens), - blob=pio.to_json( - plot_feature_importance_by_feature_plotly( - sensitivity_values=sens, # pyre-ignore[6] - ) - ), - ) - else: - metric_sens = list(sens[self.metric_name].keys()) - return PlotlyAnalysisCard( - name="OAK Interaction Analysis", - title=( - "Additive Component Feature Importance Analysis " - f"for {self.metric_name}" - ), - subtitle=( - "Displays the most important features' effects " - f"on {self.metric_name} by order of importance." - ), - level=AnalysisCardLevel.MID, - df=pd.DataFrame(sens), - blob=pio.to_json( - plot_component_surfaces_plotly( - features=metric_sens, - model=model_bridge, - metric=self.metric_name, - plots_share_range=self.plots_share_range, - ) - ), - ) + def _get_oak_model( + self, experiment: Experiment, metric_name: str + ) -> TorchModelBridge: + """ + Retrieves the modelbridge used for the analysis. The model uses an OAK + (Orthogonal Additive Kernel) with a sparsity-inducing prior, + which decomposes the objective into an additive sum of components. -def update_plot_range(max_range: list[float], new_range: list[float]) -> list[float]: - """Updates the range to include the value. - Args: - max_range: Current max_range among all considered ranges. 
- new_range: New range to consider to be included. + The kernel comes with a sparsity-inducing prior, which attempts to explain the + data with as few components as possible. The smoothness of the components is + regularized by a lengthscale prior to guard against excessively short + lengthscales being fit. + """ + data = experiment.lookup_data().filter(metric_names=[metric_name]) + model_bridge = Models.BOTORCH_MODULAR( + search_space=experiment.search_space, + experiment=experiment, + data=data, + surrogate=Surrogate( + covar_module_class=OrthogonalAdditiveKernel, + covar_module_options={ + # A fairly restrictive prior on the lengthscale avoids spurious + # fits, where a single component is fit to explain all + # variability. + # NOTE (hvarfner) Imposing a calibrated sparsity-inducing prior + # is probably a good add, but second-order components tend to + # break down even for weak priors. + "base_kernel": RBFKernel( + ard_num_dims=len(experiment.search_space.tunable_parameters), + lengthscale_prior=LogNormalPrior(2, 1), + ), + "dim": len(experiment.search_space.tunable_parameters), + "dtype": torch.float64, + "device": self.torch_device, + "second_order": self.fit_interactions, + "coeff_constraint": Positive( + transform=torch.exp, inv_transform=torch.log + ), + }, + allow_batched_models=False, + ), + ) - Returns: - The updated max_range. 
- """ - if max_range[0] > new_range[0]: - max_range[0] = new_range[0] - if max_range[1] < new_range[1]: - max_range[1] = new_range[1] - return max_range + return assert_is_instance(model_bridge, TorchModelBridge) -def plot_component_surfaces_plotly( - features: list[str], +def _prepare_surface_plot( + experiment: Experiment, model: TorchModelBridge, - metric: str, - plots_share_range: bool = True, - generator_runs_dict: TNullableGeneratorRunsDict = None, - density: int = 50, - slice_values: dict[str, Any] | None = None, - fixed_features: ObservationFeatures | None = None, - trial_index: int | None = None, - renormalize: bool = True, + feature_name: str, + metric_name: str, ) -> go.Figure: - """Plots the interaction surfaces for the given features. - - Args: - features: The features to plot. Can be either 1D or 2D, where 2D features are - specified as "x1 & x2". - model: The modelbridge used for prediction. - metric: The name of the metric to plot. - plots_share_range: Whether to have all plots should share the same output range. - generator_runs_dict: The generator runs dict to use. - density: The density of the grid, i.e. the number of points evaluated in each - dimension. - slice_values: The slice values to use for the parameters that are not plotted. - fixed_features: The fixed features to use. - trial_index: The trial index to include in the plot. - renormalize: Whether to renormalize the surface to have zero mean. - - Returns: - A plotly figure of all the interaction surfaces. - """ - traces = [] - titles = [] - param_names = [] - - # tracks the maximal value range so that all plots of the same type share the same - # signal range in the final visualization. We cannot just check the largest - # component's sobol index, as it may not have the largest co-domain. 
- surface_range = [float("inf"), -float("inf")] - slice_range = [float("inf"), -float("inf")] - first_surface = True - for feature in features: - if " & " in feature: - component_x, component_y = feature.split(" & ") - trace, minval, maxval = generate_interaction_component( - model=model, - component_x=component_x, - component_y=component_y, - metric=metric, - generator_runs_dict=generator_runs_dict, - density=density, - slice_values=slice_values, - fixed_features=fixed_features, - trial_index=trial_index, - first_surface=first_surface, - ) - first_surface = False - traces.append(trace) - param_names.append((component_x, component_y)) - titles.append(f"Total effect, {component_x} & {component_y}") - surface_range = update_plot_range(surface_range, [minval, maxval]) - else: - trace, minval, maxval = generate_main_effect_component( - model=model, - component=feature, - metric=metric, - generator_runs_dict=generator_runs_dict, - density=density, - slice_values=slice_values, - fixed_features=fixed_features, - trial_index=trial_index, - ) - traces.append(trace) - param_names.append(feature) - titles.append(f"Main Effect, {feature}") - slice_range = update_plot_range(slice_range, [minval, maxval]) - - # 1x3 plots if 3 total, 2x2 plots if 4 total, 3x2 plots if 6 total - num_rows = 1 if len(traces) <= (MAX_NUM_PLOT_COMPONENTS / 2) else 2 - num_cols = math.ceil(len(traces) / num_rows) - - fig = make_subplots(rows=num_rows, cols=num_cols, subplot_titles=titles) - for plot_idx, trace in enumerate(traces): - row = plot_idx // num_cols + 1 - col = plot_idx % num_cols + 1 - fig.add_trace(trace, row=row, col=col) - fig = set_axis_names( - figure=fig, trace=trace, row=row, col=col, param_names=param_names[plot_idx] + if "&" in feature_name: + # Plot a contour plot for the second-order component. 
+ x_parameter_name, y_parameter_name = feature_name.split(" & ") + df = _prepare_contour_data( + experiment=experiment, + model=model, + x_parameter_name=x_parameter_name, + y_parameter_name=y_parameter_name, + metric_name=metric_name, ) - fig = scale_traces( - figure=fig, - traces=traces, - surface_range=surface_range, - slice_range=slice_range, - plots_share_range=plots_share_range, - ) - fig.update_layout({"width": PLOT_SIZE * num_cols, "height": PLOT_SIZE * num_rows}) - return fig - + return _prepare_contour_plot( + df=df, + x_parameter_name=x_parameter_name, + y_parameter_name=y_parameter_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[x_parameter_name] + ), + log_y=is_axis_log_scale( + parameter=experiment.search_space.parameters[y_parameter_name] + ), + ) -def generate_main_effect_component( - model: TorchModelBridge, - component: str, - metric: str, - generator_runs_dict: TNullableGeneratorRunsDict = None, - density: int = 50, - slice_values: dict[str, Any] | None = None, - fixed_features: ObservationFeatures | None = None, - trial_index: int | None = None, -) -> tuple[go.Scatter, float, float]: - """Plots a slice "main effect" of the model for a given component. The values are - relative to the mean of all predictions, so that the magnitude of the component is - communicated. - - Args: - model: The modelbridge used for prediction. - component: The name of the component to plot. - metric: The name of the metric to plot. - generator_runs_dict: The generator runs dict to use. - density: The density of the grid, i.e. the number of points evaluated in each - dimension. - slice_values: The slice values to use for the parameters that are not plotted. - fixed_features: The fixed features to use. - trial_index: The trial index to include in the plot. - - Returns: - A contour plot of the component interaction, and the range of the plot. 
- """ - _, _, slice_mean, _, grid, _, _, _, _, slice_stdev, _ = _get_slice_predictions( + # If the feature is a first-order component, plot a slice plot. + df = _prepare_slice_data( + experiment=experiment, model=model, - param_name=component, - metric_name=metric, - generator_runs_dict=generator_runs_dict, - density=density, - slice_values=slice_values, - fixed_features=fixed_features, - trial_index=trial_index, + parameter_name=feature_name, + metric_name=metric_name, ) - # renormalize the slice to have zero mean (done for each component) - slice_mean = np.array(slice_mean) - np.array(slice_mean).mean() - - trace = go.Scatter( - x=grid, - y=slice_mean, - name=component, - line_shape="spline", - showlegend=False, - error_y={ - "type": "data", - "array": slice_stdev, - "visible": True, - "thickness": 0.8, - }, - ) - - return trace, np.min(slice_mean).astype(float), np.max(slice_mean).astype(float) - -def generate_interaction_component( - model: TorchModelBridge, - component_x: str, - component_y: str, - metric: str, - generator_runs_dict: TNullableGeneratorRunsDict = None, - density: int = 50, - slice_values: dict[str, Any] | None = None, - fixed_features: ObservationFeatures | None = None, - trial_index: int | None = None, - renormalize: bool = True, - first_surface: bool = True, -) -> tuple[go.Contour, float, float]: - """Plots a slice "main effect" of the model for a given component. The values are - relative to the mean of all predictions, so that the magnitude of the component is - communicated. - - Args: - model: The modelbridge used for prediction. - component_x: The name of the component to plot along the x-axis. - component_y: The name of the component to plot along the y-axis. - metric: The name of the metric to plot. - subtract_main_effects: Whether to subtract the main effects from the 2D surface. 
- If main effects are not subtracted, the 2D surface is the output of - plot_contour and models the effect of each parameter in isolation and their - interaction. If main effects are subtracted, the 2D surface visualizes only - the interaction effect of the two parameters. - generator_runs_dict: The generator runs dict to use. - density: The density of the grid, i.e. the number of points evaluated in each - dimension. - slice_values: The slice values to use for the parameters that are not plotted. - fixed_features: The fixed features to use. - trial_index: The trial index to include in the plot. - renormalize: Whether to renormalize the surface to have zero mean. - first_surface: Whether this is the first surface to be plotted. If so, we plot - its colorbar. - - Returns: - A contour plot of the component interaction, and the range of the plot. - """ - comp_name: str = f"{component_x} & {component_y}" - fixed_kwargs: dict[str, Any] = { - "model": model, - "generator_runs_dict": generator_runs_dict, - "density": density, - "slice_values": slice_values, - "fixed_features": fixed_features, - } - _, contour_mean, _, grid_x, grid_y, _ = _get_contour_predictions( - x_param_name=component_x, - y_param_name=component_y, - metric=metric, - **fixed_kwargs, - ) - contour_mean = np.reshape(contour_mean, (density, density)) - contour_mean = contour_mean - contour_mean.mean() - return ( - go.Contour( - z=contour_mean, - x=grid_x, - y=grid_y, - name=comp_name, - ncontours=50, - showscale=first_surface, + return _prepare_slice_plot( + df=df, + parameter_name=feature_name, + metric_name=metric_name, + log_x=is_axis_log_scale( + parameter=experiment.search_space.parameters[feature_name] ), - np.min(contour_mean).astype(float), - np.max(contour_mean).astype(float), ) - - -def scale_traces( - figure: go.Figure, - traces: list[go.Scatter | go.Contour], - surface_range: list[float], - slice_range: list[float], - plots_share_range: bool = True, -) -> go.Figure: - """Scales the traces to 
have the same range. - - Args: - figure: The main plotly figure to update the traces on. - traces: The traces to scale. - surface_range: The range of the surface traces. - slice_range: The range of the slice traces. - plots_share_range: Whether to have all plots (and not just plots - of the same type) share the same output range. - - Returns: - A figure with the traces of the same type are scaled to have the same range. - """ - if plots_share_range: - total_range = update_plot_range(surface_range, slice_range) - slice_range = total_range - surface_range = total_range - - # plotly axis names in layout are of the form "xaxis{idx}" and "yaxis{idx}" except - # for the first one, which is "xaxis" and "yaxis". We need to keep track of the - # indices of the traces and then use the correct axis names when updating ranges. - axis_names = ["yaxis"] + [f"yaxis{idx}" for idx in range(2, len(traces) + 1)] - slice_axes = [ - axis_name - for trace, axis_name in zip(traces, axis_names) - if isinstance(trace, go.Scatter) - ] - - # scale the surface traces to have the same range - for trace_idx in range(len(figure["data"])): - trace = figure["data"][trace_idx] - if isinstance(trace, go.Contour): - trace["zmin"] = surface_range[0] - trace["zmax"] = surface_range[1] - - # and scale the slice traces to have the same range - figure.update_layout({ax: {"range": slice_range} for ax in slice_axes}) - return figure - - -def set_axis_names( - figure: go.Figure, - trace: go.Contour | go.Scatter, - row: int, - col: int, - param_names: str | tuple[str, str], -) -> go.Figure: - """Sets the axis names for the given row and column. - - Args: - figure: The figure to update the axes on. - trace: The trace of the plot whose axis labels to update. - row: The row index of the trace in `figure`. - col: The column index of the trace in `figure`. - param_names: The parameter names to use for the axis names. - - Returns: - A figure where the trace at (row, col) has its axis names set. 
- """ - if isinstance(trace, go.Contour): - X_name, Y_name = param_names - figure.update_xaxes(title_text=X_name, row=row, col=col) - figure.update_yaxes(title_text=Y_name, row=row, col=col) - else: - figure.update_xaxes(title_text=param_names, row=row, col=col) - return figure diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py index c3386058bfd..ff96222382e 100644 --- a/ax/analysis/plotly/surface/contour.py +++ b/ax/analysis/plotly/surface/contour.py @@ -82,7 +82,6 @@ def compute( fig = _prepare_plot( df=df, - experiment=experiment, x_parameter_name=self.x_parameter_name, y_parameter_name=self.y_parameter_name, metric_name=metric_name, @@ -160,7 +159,6 @@ def _prepare_data( def _prepare_plot( df: pd.DataFrame, - experiment: Experiment, x_parameter_name: str, y_parameter_name: str, metric_name: str, diff --git a/ax/analysis/plotly/tests/test_interaction.py b/ax/analysis/plotly/tests/test_interaction.py index d60311a7223..196be8e10f3 100644 --- a/ax/analysis/plotly/tests/test_interaction.py +++ b/ax/analysis/plotly/tests/test_interaction.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. 
# # This source code is licensed under the MIT license found in the @@ -6,192 +5,76 @@ # pyre-strict -import pandas as pd -import torch -from ax.analysis.analysis import AnalysisCard -from ax.analysis.plotly.interaction import ( - generate_interaction_component, - generate_main_effect_component, - get_model_kwargs, - InteractionPlot, - TOP_K_TOO_LARGE_ERROR, -) -from ax.exceptions.core import DataRequiredError, UserInputError - +from ax.analysis.analysis import AnalysisCardLevel +from ax.analysis.plotly.interaction import InteractionPlot +from ax.exceptions.core import UserInputError +from ax.service.ax_client import AxClient, ObjectiveProperties from ax.utils.common.testutils import TestCase -from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import mock_botorch_optimize -from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel -from gpytorch.kernels import RBFKernel -from plotly import graph_objects as go -class InteractionTest(TestCase): - def test_interaction_get_model_kwargs(self) -> None: - kwargs = get_model_kwargs( - num_parameters=3, - use_interaction=False, - torch_device=torch.device("cpu"), +class TestInteractionPlot(TestCase): + @mock_botorch_optimize + def setUp(self) -> None: + super().setUp() + self.client = AxClient() + self.client.create_experiment( + is_test=True, + name="foo", + parameters=[ + { + "name": "x", + "type": "range", + "bounds": [-1.0, 1.0], + }, + { + "name": "y", + "type": "range", + "bounds": [-1.0, 1.0], + }, + ], + objectives={"bar": ObjectiveProperties(minimize=True)}, ) - self.assertEqual(kwargs["covar_module_class"], OrthogonalAdditiveKernel) - covar_module_options = kwargs["covar_module_options"] - self.assertIsInstance(covar_module_options["base_kernel"], RBFKernel) - self.assertEqual(covar_module_options["dim"], 3) - # Checks that we can retrieve the modelbridge that has interaction terms - kwargs = get_model_kwargs( - num_parameters=5, - 
use_interaction=True, - torch_device=torch.device("cpu"), - ) - self.assertEqual(kwargs["covar_module_class"], OrthogonalAdditiveKernel) - self.assertIsInstance(kwargs["covar_module_options"]["base_kernel"], RBFKernel) + for _ in range(10): + parameterization, trial_index = self.client.get_next_trial() + self.client.complete_trial( + trial_index=trial_index, + raw_data={ + "bar": parameterization["x"] ** 2 + parameterization["y"] ** 2 + }, + ) - @mock_botorch_optimize - def test_interaction_analysis_without_components(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=False, - num_mc_samples=11, + def test_compute(self) -> None: + analysis = InteractionPlot(metric_name="bar") + + # Test that it fails if no Experiment is provided + with self.assertRaisesRegex(UserInputError, "requires an Experiment"): + analysis.compute() + + card = analysis.compute( + experiment=self.client.experiment, + generation_strategy=self.client.generation_strategy, ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertIsInstance(card.blob, str) - self.assertIsInstance(card.df, pd.DataFrame) self.assertEqual( card.name, - "Interaction Analysis", - ) - self.assertEqual( - card.title, - "Feature Importance Analysis for branin", + "InteractionPlot", ) + self.assertEqual(card.title, "Interaction Analysis for bar") self.assertEqual( card.subtitle, - "Displays the most important features for branin by order of importance.", - ) - - # with interaction terms - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - num_mc_samples=11, + "Understand an Experiment's data as one- or two-dimensional additive " + "components with sparsity. 
Important components are visualized through " + "slice or contour plots", ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertIsInstance(card.blob, str) - self.assertIsInstance(card.df, pd.DataFrame) - self.assertEqual(len(card.df), 3) + self.assertEqual(card.level, AnalysisCardLevel.MID) self.assertEqual( - card.subtitle, - "Displays the most important features for branin by order of importance.", + {*card.df.columns}, + {"feature", "sensitivity"}, ) + self.assertIsNotNone(card.blob) + self.assertEqual(card.blob_annotation, "plotly") - with self.assertRaisesRegex(UserInputError, TOP_K_TOO_LARGE_ERROR.format("7")): - InteractionPlot(metric_name="branin", top_k=7, display_components=True) - - analysis = InteractionPlot(metric_name="branout", fit_interactions=False) - with self.assertRaisesRegex( - DataRequiredError, "StandardizeY` transform requires non-empty data." - ): - analysis.compute(experiment=exp) - - @mock_botorch_optimize - def test_interaction_with_components(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - num_mc_samples=11, - ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertIsInstance(card.blob, str) - self.assertIsInstance(card.df, pd.DataFrame) - self.assertEqual(len(card.df), 3) - - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - top_k=2, - num_mc_samples=11, - ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertEqual(len(card.df), 2) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - model_fit_seed=999, - num_mc_samples=11, - ) - card = analysis.compute(experiment=exp) - self.assertIsInstance(card, AnalysisCard) - self.assertEqual(len(card.df), 3) - - 
@mock_botorch_optimize - def test_generate_main_effect_component(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - num_mc_samples=11, - ) - density = 13 - model = analysis.get_model(experiment=exp) - comp, _, _ = generate_main_effect_component( - model=model, - component="x1", - metric="branin", - density=density, - ) - self.assertIsInstance(comp, go.Scatter) - self.assertEqual(comp["x"].shape, (density,)) - self.assertEqual(comp["y"].shape, (density,)) - self.assertEqual(comp["name"], "x1") - - with self.assertRaisesRegex(KeyError, "braninandout"): - generate_main_effect_component( - model=model, - component="x1", - metric="braninandout", - density=density, - ) - - @mock_botorch_optimize - def test_generate_interaction_component(self) -> None: - exp = get_branin_experiment(with_completed_trial=True) - analysis = InteractionPlot( - metric_name="branin", - fit_interactions=True, - display_components=True, - num_mc_samples=11, - ) - density = 3 - model = analysis.get_model(experiment=exp) - comp, _, _ = generate_interaction_component( - model=model, - component_x="x1", - component_y="x2", - metric="branin", - density=density, - ) - self.assertIsInstance(comp, go.Contour) - self.assertEqual(comp["x"].shape, (density,)) - self.assertEqual(comp["y"].shape, (density,)) - self.assertEqual(comp["z"].shape, (density, density)) - self.assertEqual(comp["name"], "x1 & x2") - - with self.assertRaisesRegex(KeyError, "braninandout"): - generate_interaction_component( - model=model, - component_x="x1", - component_y="x2", - metric="braninandout", - density=density, - ) + fig = card.get_figure() + # Ensure all subplots are present + self.assertEqual(len(fig.data), 6)