From fa7843f815784f7e615943cd5669307e45c818e4 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Tue, 17 Dec 2024 13:32:45 -0800 Subject: [PATCH 1/6] Introduce new Transform that adds metadata as parameters in an ObservationFeature (#3023) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3023 **Context:** The values corresponding to map keys are propagated as part of the ObservationFeatures' `metadata` dict field. We require a way to place it in the `parameters` dict field so that it can be used later on. This generalized transform is able to take user-specified entries from an `ObservationFeatures`'s `metadata` field and place it within its `parameters` field, and update the search space accordingly to reflect this. This implements a new transform, `MetadataToFloat`, that extracts specified fields from each `ObservationFeature` instance's metadata and incorporates them as parameters. Furthermore, it updates the search space to include the specified field as a `RangeParameter` with bounds determined by observations provided during initialization. This process involves analyzing the metadata of each observation feature and identifying relevant fields that need to be included in the search space. The bounds for these fields are then determined based on the observations provided during initialization. Differential Revision: D65430943 --- .../transforms/metadata_to_float.py | 143 ++++++++++++++ .../tests/test_metadata_to_float_transform.py | 180 ++++++++++++++++++ sphinx/source/modelbridge.rst | 9 + 3 files changed, 332 insertions(+) create mode 100644 ax/modelbridge/transforms/metadata_to_float.py create mode 100644 ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py diff --git a/ax/modelbridge/transforms/metadata_to_float.py b/ax/modelbridge/transforms/metadata_to_float.py new file mode 100644 index 00000000000..d74af4604fe --- /dev/null +++ b/ax/modelbridge/transforms/metadata_to_float.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from __future__ import annotations + +from logging import Logger +from typing import Any, Iterable, Optional, SupportsFloat, TYPE_CHECKING + +from ax.core import ParameterType + +from ax.core.observation import Observation, ObservationFeatures +from ax.core.parameter import RangeParameter +from ax.core.search_space import SearchSpace +from ax.exceptions.core import DataRequiredError +from ax.modelbridge.transforms.base import Transform +from ax.models.types import TConfig +from ax.utils.common.logger import get_logger +from pyre_extensions import assert_is_instance, none_throws + +if TYPE_CHECKING: + # import as module to make sphinx-autodoc-typehints happy + from ax import modelbridge as modelbridge_module # noqa F401 + + +logger: Logger = get_logger(__name__) + + +class MetadataToFloat(Transform): + """ + This transform converts metadata from observation features into range (float) + parameters for a search space. + + It allows the user to specify the `config` with `parameters` as the key, where + each entry maps a metadata key to a dictionary of keyword arguments for the + corresponding RangeParameter constructor. + + Transform is done in-place. + """ + + DEFAULT_LOG_SCALE: bool = False + DEFAULT_LOGIT_SCALE: bool = False + DEFAULT_IS_FIDELITY: bool = False + ENFORCE_BOUNDS: bool = False + + def __init__( + self, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + config: TConfig | None = None, + ) -> None: + if observations is None or not observations: + raise DataRequiredError( + "`MetadataToRange` transform requires non-empty data." + ) + config = config or {} + self.parameters: dict[str, dict[str, Any]] = assert_is_instance( + config.get("parameters", {}), dict + ) + + self._parameter_list: list[RangeParameter] = [] + for name in self.parameters: + lb = ub = None # de facto bounds + for obs in observations: + obsf_metadata = none_throws(obs.features.metadata) + + val = float(assert_is_instance(obsf_metadata[name], SupportsFloat)) + + lb = min(val, lb) if lb is not None else val + ub = max(val, ub) if ub is not None else val + + lower: float = self.parameters[name].get("lower", lb) + upper: float = self.parameters[name].get("upper", ub) + + log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE) + logit_scale = self.parameters[name].get( + "logit_scale", self.DEFAULT_LOGIT_SCALE + ) + digits = self.parameters[name].get("digits") + is_fidelity = self.parameters[name].get( + "is_fidelity", self.DEFAULT_IS_FIDELITY + ) + + target_value = self.parameters[name].get("target_value") + + parameter = RangeParameter( + name=name, + parameter_type=ParameterType.FLOAT, + lower=lower, + upper=upper, + log_scale=log_scale, + logit_scale=logit_scale, + digits=digits, + is_fidelity=is_fidelity, + target_value=target_value, + ) + self._parameter_list.append(parameter) + + def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: + for parameter in self._parameter_list: + search_space.add_parameter(parameter.clone()) + return search_space + + def transform_observation_features( + self, observation_features: list[ObservationFeatures] + ) -> list[ObservationFeatures]: + for obsf in observation_features: + self._transform_observation_feature(obsf) + return observation_features + + def untransform_observation_features( + self, observation_features: list[ObservationFeatures] + ) -> list[ObservationFeatures]: + for obsf in observation_features: + obsf.metadata = obsf.metadata or {} + _transfer( + src=obsf.parameters, + dst=obsf.metadata, + keys=self.parameters.keys(), + ) + return observation_features + + def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: + _transfer( + src=none_throws(obsf.metadata), + dst=obsf.parameters, + keys=self.parameters.keys(), + ) + + +def _transfer( + src: dict[str, Any], + dst: dict[str, Any], + keys: Iterable[str], +) -> None: + """Transfer items in-place from one dictionary to another.""" + for key in keys: + dst[key] = src.pop(key) diff --git a/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py b/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py new file mode 100644 index 00000000000..7c49f1df099 --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_metadata_to_float_transform.py @@ -0,0 +1,180 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from copy import deepcopy +from typing import Iterator + +import numpy as np +from ax.core.observation import Observation, ObservationData, ObservationFeatures +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.exceptions.core import DataRequiredError +from ax.modelbridge.transforms.metadata_to_float import MetadataToFloat +from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance + + +WIDTHS = [2.0, 4.0, 8.0] +HEIGHTS = [4.0, 2.0, 8.0] +STEPS_ENDS = [1, 5, 3] + + +def _enumerate() -> Iterator[tuple[int, float, float, float]]: + yield from ( + (trial_index, width, height, float(i + 1)) + for trial_index, (width, height, steps_end) in enumerate( + zip(WIDTHS, HEIGHTS, STEPS_ENDS) + ) + for i in range(steps_end) + ) + + +class MetadataToFloatTransformTest(TestCase): + def setUp(self) -> None: + super().setUp() + + self.search_space = SearchSpace( + parameters=[ + RangeParameter( + name="width", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + RangeParameter( + name="height", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + ] + ) + + self.observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": 3.0 * steps, + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + self.observations.append(Observation(features=obs_feat, data=obs_data)) + + self.t = MetadataToFloat( + observations=self.observations, + config={ + "parameters": {"bar": {"log_scale": True}}, + }, + ) + + def test_Init(self) -> None: + self.assertEqual(len(self.t._parameter_list), 1) + + p = self.t._parameter_list[0] + + # check that the parameter options are specified in a sensible manner + # by default if the user does not specify them explicitly + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"): + MetadataToFloat(search_space=None, observations=None) + with self.assertRaisesRegex(DataRequiredError, "requires non-empty data"): + MetadataToFloat(search_space=None, observations=[]) + + with self.subTest("infer parameter type"): + observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": int(steps), + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + observations.append(Observation(features=obs_feat, data=obs_data)) + + t = MetadataToFloat( + observations=observations, + config={ + "parameters": {"bar": {}}, + }, + ) + self.assertEqual(len(t._parameter_list), 1) + + p = t._parameter_list[0] + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.INT) + self.assertEqual(p.lower, 1) + self.assertEqual(p.upper, 5) + self.assertFalse(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + def test_TransformSearchSpace(self) -> None: + ss2 = deepcopy(self.search_space) + ss2 = self.t.transform_search_space(ss2) + + self.assertSetEqual( + set(ss2.parameters.keys()), + {"height", "width", "bar"}, + ) + + p = assert_is_instance(ss2.parameters["bar"], RangeParameter) + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + def test_TransformObservationFeatures(self) -> None: + observation_features = [obs.features for obs in self.observations] + obs_ft2 = deepcopy(observation_features) + obs_ft2 = self.t.transform_observation_features(obs_ft2) + + self.assertEqual( + obs_ft2, + [ + ObservationFeatures( + trial_index=trial_index, + parameters={ + "width": width, + "height": height, + "bar": 3.0 * steps, + }, + metadata={"foo": 42}, + ) + for trial_index, width, height, steps in _enumerate() + ], + ) + obs_ft2 = self.t.untransform_observation_features(obs_ft2) + self.assertEqual(obs_ft2, observation_features) diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 98a0c124cc7..b7edec63bce 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -310,6 +310,15 @@ Transforms :undoc-members: :show-inheritance: + +`ax.modelbridge.transforms.metadata\_to\_float` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.metadata_to_float + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.rounding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From a9c2583a46deae041b97f3d00bb54b5c8ca33c05 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Tue, 17 Dec 2024 13:35:31 -0800 Subject: [PATCH 2/6] Implements MapKeyToFloat, a subclass of the MetadataToFloat Transform that provides sensible defaults for MapData (#3155) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3155 This adds a specialization of the `MetadataToFloat` Transform, `MapKeyToFloat`, that provides sensible default settings to allow for intercepting map metric data appearing in the ObservationFeatures' metadata. Additionally, for the purposes of specifying `fixed_features` down the line, when `_transform_observation_feature` is given an empty `ObservationFeatures` (more specifically, an `ObservationFeatures` with an empty `parameters` dict), it will populate it with the *upper bound* associated with each metadata key. Differential Revision: D66945078 --- ax/modelbridge/transforms/map_key_to_float.py | 54 ++++++ .../tests/test_map_key_to_float_transform.py | 174 ++++++++++++++++++ sphinx/source/modelbridge.rst | 9 + 3 files changed, 237 insertions(+) create mode 100644 ax/modelbridge/transforms/map_key_to_float.py create mode 100644 ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py diff --git a/ax/modelbridge/transforms/map_key_to_float.py b/ax/modelbridge/transforms/map_key_to_float.py new file mode 100644 index 00000000000..1ec645aff51 --- /dev/null +++ b/ax/modelbridge/transforms/map_key_to_float.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, Optional, TYPE_CHECKING + +from ax.core.map_metric import MapMetric +from ax.core.observation import Observation, ObservationFeatures +from ax.core.search_space import SearchSpace +from ax.modelbridge.transforms.metadata_to_range import MetadataToFloat +from ax.models.types import TConfig +from pyre_extensions import assert_is_instance + +if TYPE_CHECKING: + # import as module to make sphinx-autodoc-typehints happy + from ax import modelbridge as modelbridge_module # noqa F401 + + +class MapKeyToFloat(MetadataToFloat): + DEFAULT_LOG_SCALE: bool = True + DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key + + def __init__( + self, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + config: TConfig | None = None, + ) -> None: + config = config or {} + self.parameters: dict[str, dict[str, Any]] = assert_is_instance( + config.setdefault("parameters", {}), dict + ) + # TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?) + self.parameters.setdefault(self.DEFAULT_MAP_KEY, {}) + super().__init__( + search_space=search_space, + observations=observations, + modelbridge=modelbridge, + config=config, + ) + + def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: + if not obsf.parameters: + for p in self._parameter_list: + # TODO[tiao]: can we use be p.target_value? + # (not its original intended use but could be advantageous) + obsf.parameters[p.name] = p.upper + return + super()._transform_observation_feature(obsf) diff --git a/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py new file mode 100644 index 00000000000..c1ddb8c1d34 --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_map_key_to_float_transform.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from copy import deepcopy +from typing import Iterator + +import numpy as np +from ax.core.observation import Observation, ObservationData, ObservationFeatures +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.modelbridge.transforms.map_key_to_float import MapKeyToFloat +from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance + + +WIDTHS = [2.0, 4.0, 8.0] +HEIGHTS = [4.0, 2.0, 8.0] +STEPS_ENDS = [1, 5, 3] + + +def _enumerate() -> Iterator[tuple[int, float, float, float]]: + yield from ( + (trial_index, width, height, float(i + 1)) + for trial_index, (width, height, steps_end) in enumerate( + zip(WIDTHS, HEIGHTS, STEPS_ENDS) + ) + for i in range(steps_end) + ) + + +class MapKeyToFloatTransformTest(TestCase): + def setUp(self) -> None: + super().setUp() + + self.search_space = SearchSpace( + parameters=[ + RangeParameter( + name="width", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + RangeParameter( + name="height", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + ] + ) + + self.observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + MapKeyToFloat.DEFAULT_MAP_KEY: steps, + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + self.observations.append(Observation(features=obs_feat, data=obs_data)) + + # does not require explicitly specifying `config` + self.t = MapKeyToFloat( + observations=self.observations, + ) + + def test_Init(self) -> None: + self.assertEqual(len(self.t._parameter_list), 1) + + p = self.t._parameter_list[0] + + self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 1.0) + self.assertEqual(p.upper, 5.0) + self.assertTrue(p.log_scale) + + with self.subTest("infer parameter type"): + observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": int(steps), + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + observations.append(Observation(features=obs_feat, data=obs_data)) + + # test that one is able to override default config + with self.subTest(msg="override default config"): + t = MapKeyToFloat( + observations=self.observations, + config={ + "parameters": {MapKeyToFloat.DEFAULT_MAP_KEY: {"log_scale": False}} + }, + ) + self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}}) + + self.assertEqual(len(t._parameter_list), 1) + + p = t._parameter_list[0] + + self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 1.0) + self.assertEqual(p.upper, 5.0) + self.assertFalse(p.log_scale) + + def test_TransformSearchSpace(self) -> None: + ss2 = deepcopy(self.search_space) + ss2 = self.t.transform_search_space(ss2) + + self.assertSetEqual( + set(ss2.parameters), + {"height", "width", MapKeyToFloat.DEFAULT_MAP_KEY}, + ) + + p = assert_is_instance( + ss2.parameters[MapKeyToFloat.DEFAULT_MAP_KEY], RangeParameter + ) + + self.assertEqual(p.name, MapKeyToFloat.DEFAULT_MAP_KEY) + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 1.0) + self.assertEqual(p.upper, 5.0) + self.assertTrue(p.log_scale) + + def test_TransformObservationFeatures(self) -> None: + observation_features = [obs.features for obs in self.observations] + obs_ft2 = deepcopy(observation_features) + obs_ft2 = self.t.transform_observation_features(obs_ft2) + + self.assertEqual( + obs_ft2, + [ + ObservationFeatures( + trial_index=trial_index, + parameters={ + "width": width, + "height": height, + MapKeyToFloat.DEFAULT_MAP_KEY: steps, + }, + metadata={"foo": 42}, + ) + for trial_index, width, height, steps in _enumerate() + ], + ) + obs_ft2 = self.t.untransform_observation_features(obs_ft2) + self.assertEqual(obs_ft2, observation_features) + + def test_TransformObservationFeaturesWithEmptyParameters(self) -> None: + obsf = ObservationFeatures(parameters={}) + self.t.transform_observation_features([obsf]) + + p = self.t._parameter_list[0] + self.assertEqual( + obsf, + ObservationFeatures(parameters={MapKeyToFloat.DEFAULT_MAP_KEY: p.upper}), + ) diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index b7edec63bce..c35831f380a 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -319,6 +319,15 @@ Transforms :undoc-members: :show-inheritance: + +`ax.modelbridge.transforms.map\_key\_to\_float` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.map_key_to_float + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.rounding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From 2f0dc179b47e09ac115910f55005ed6863e28c80 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Tue, 17 Dec 2024 13:35:31 -0800 Subject: [PATCH 3/6] Simplified and optimized logic for calculating per-metric subsampling rate for MapData (#3106) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3106 This refines the logic for calculating per-metric subsampling rates in `MapData.subsample` and incorporates a (probably premature) performance optimization, achieved by utilizing binary search on a sorted list instead of linear search. Differential Revision: D66366076 Reviewed By: Balandat --- ax/core/map_data.py | 69 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/ax/core/map_data.py b/ax/core/map_data.py index 5b2685b54d1..b55e57344c8 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -7,12 +7,14 @@ from __future__ import annotations +from bisect import bisect_right from collections.abc import Iterable, Sequence from copy import deepcopy from logging import Logger from typing import Any, Generic, TypeVar import numpy as np +import numpy.typing as npt import pandas as pd from ax.core.data import Data from ax.core.types import TMapTrialEvaluation @@ -411,6 +413,48 @@ def subsample( ) +def _ceil_divide( + a: int | np.int_ | npt.NDArray[np.int_], b: int | np.int_ | npt.NDArray[np.int_] +) -> np.int_ | npt.NDArray[np.int_]: + return -np.floor_divide(-a, b) + + +def _subsample_rate( + map_df: pd.DataFrame, + keep_every: int | None = None, + limit_rows_per_group: int | None = None, + limit_rows_per_metric: int | None = None, +) -> int: + if keep_every is not None: + return keep_every + + grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS) + group_sizes = grouped_map_df.size() + max_rows = group_sizes.max() + + if limit_rows_per_group is not None: + return _ceil_divide(max_rows, limit_rows_per_group).item() + + if limit_rows_per_metric is not None: + # search for the `keep_every` such that when you apply it to each group, + # the total number of rows is smaller than `limit_rows_per_metric`. + ks = np.arange(max_rows, 0, -1) + # total sizes in ascending order + total_sizes = np.sum( + _ceil_divide(group_sizes.values, ks[..., np.newaxis]), axis=1 + ) + # binary search + i = bisect_right(total_sizes, limit_rows_per_metric) + # if no such `k` is found, then `derived_keep_every` stays as 1. + if i > 0: + return ks[i - 1].item() + + raise ValueError( + "at least one of `keep_every`, `limit_rows_per_group`, " + "or `limit_rows_per_metric` must be specified." + ) + + def _subsample_one_metric( map_df: pd.DataFrame, map_key: str | None = None, @@ -420,30 +464,21 @@ def _subsample_one_metric( include_first_last: bool = True, ) -> pd.DataFrame: """Helper function to subsample a dataframe that holds a single metric.""" - derived_keep_every = 1 - if keep_every is not None: - derived_keep_every = keep_every - elif limit_rows_per_group is not None: - max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max() - derived_keep_every = np.ceil(max_rows / limit_rows_per_group) - elif limit_rows_per_metric is not None: - group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy() - # search for the `keep_every` such that when you apply it to each group, - # the total number of rows is smaller than `limit_rows_per_metric`. - for k in range(1, group_sizes.max() + 1): - if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric: - derived_keep_every = k - break - # if no such `k` is found, then `derived_keep_every` stays as 1. + + grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS) + + derived_keep_every = _subsample_rate( + map_df, keep_every, limit_rows_per_group, limit_rows_per_metric + ) if derived_keep_every <= 1: filtered_map_df = map_df else: filtered_dfs = [] - for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS): + for _, df_g in grouped_map_df: df_g = df_g.sort_values(map_key) if include_first_last: - rows_per_group = int(np.ceil(len(df_g) / derived_keep_every)) + rows_per_group = _ceil_divide(len(df_g), derived_keep_every) linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group) idcs = np.round(linspace_idcs).astype(int) filtered_df = df_g.iloc[idcs] From fbe48774d99a6d4ed54e5f152629a4d4862a7dfb Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Tue, 17 Dec 2024 13:35:31 -0800 Subject: [PATCH 4/6] Include progression information as metadata when transforming (Map)Data to Observations (#3001) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3001 This updates `observations_from_data` to include progression info as observation feature metadata by default. More specifically: - Updates `observations_from_data` to subsume behavior of `observations_from_map_data` as special case. - Updates calls to `observations_from_map_data` to instead call `observations_from_data` - Removes `observations_from_map_data` which is used exclusively by `MapTorchModelBridge` Differential Revision: D65255312 Reviewed By: saitcakmak --- ax/core/observation.py | 169 +++++++++--------------------- ax/core/tests/test_observation.py | 3 +- ax/modelbridge/base.py | 1 + ax/modelbridge/map_torch.py | 11 +- 4 files changed, 55 insertions(+), 129 deletions(-) diff --git a/ax/core/observation.py b/ax/core/observation.py index 50bd1550572..a6284405837 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -426,7 +426,7 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: feature_cols = OBS_COLS.intersection(data.df.columns) # note we use this check, rather than isinstance, since # only some Modelbridges (e.g. MapTorchModelBridge) - # use observations_from_map_data, which is required + # use observations_from_data, which is required # to properly handle MapData features (e.g. fidelity). if is_map_data: data = checked_cast(MapData, data) @@ -448,174 +448,103 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: def observations_from_data( experiment: experiment.Experiment, - data: Data, - statuses_to_include: set[TrialStatus] | None = None, - statuses_to_include_map_metric: set[TrialStatus] | None = None, -) -> list[Observation]: - """Convert Data to observations. - - Converts a Data object to a list of Observation objects. Pulls arm parameters from - from experiment. Overrides fidelity parameters in the arm with those found in the - Data object. - - Uses a diagonal covariance matrix across metric_names. - - Args: - experiment: Experiment with arm parameters. - data: Data of observations. - statuses_to_include: data from non-MapMetrics will only be included for trials - with statuses in this set. Defaults to all statuses except abandoned. - statuses_to_include_map_metric: data from MapMetrics will only be included for - trials with statuses in this set. Defaults to completed status only. - - Returns: - List of Observation objects. - """ - if statuses_to_include is None: - statuses_to_include = NON_ABANDONED_STATUSES - if statuses_to_include_map_metric is None: - statuses_to_include_map_metric = {TrialStatus.COMPLETED} - feature_cols = get_feature_cols(data) - observations = [] - arm_name_only = len(feature_cols) == 1 # there will always be an arm name - # One DataFrame where all rows have all features. - isnull = data.df[feature_cols].isnull() - isnull_any = isnull.any(axis=1) - incomplete_df_cols = isnull[isnull_any].any() - - # Get the incomplete_df columns that are complete, and usable as groupby keys. - complete_feature_cols = list( - OBS_COLS.intersection(incomplete_df_cols.index[~incomplete_df_cols]) - ) - - if set(feature_cols) == set(complete_feature_cols): - complete_df = data.df - incomplete_df = None - else: - # The groupby and filter is expensive, so do it only if we have to. - grouped = data.df.groupby(by=complete_feature_cols) - complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any()) - incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any()) - - # Get Observations from complete_df - observations.extend( - _observations_from_dataframe( - experiment=experiment, - df=complete_df, - cols=feature_cols, - arm_name_only=arm_name_only, - statuses_to_include=statuses_to_include, - statuses_to_include_map_metric=statuses_to_include_map_metric, - map_keys=[], - ) - ) - if incomplete_df is not None: - # Get Observations from incomplete_df - observations.extend( - _observations_from_dataframe( - experiment=experiment, - df=incomplete_df, - cols=complete_feature_cols, - arm_name_only=arm_name_only, - statuses_to_include=statuses_to_include, - statuses_to_include_map_metric=statuses_to_include_map_metric, - map_keys=[], - ) - ) - return observations - - -def observations_from_map_data( - experiment: experiment.Experiment, - map_data: MapData, + data: Data | MapData, statuses_to_include: set[TrialStatus] | None = None, statuses_to_include_map_metric: set[TrialStatus] | None = None, map_keys_as_parameters: bool = False, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, ) -> list[Observation]: - """Convert MapData to observations. + """Convert Data (or MapData) to observations. - Converts a MapData object to a list of Observation objects. Pulls arm parameters - from experiment. Overrides fidelity parameters in the arm with those found in the - Data object. + Converts a Data (or MapData) object to a list of Observation objects. + Pulls arm parameters from from experiment. Overrides fidelity parameters + in the arm with those found in the Data object. Uses a diagonal covariance matrix across metric_names. Args: experiment: Experiment with arm parameters. - map_data: MapData of observations. + data: Data (or MapData) of observations. statuses_to_include: data from non-MapMetrics will only be included for trials with statuses in this set. Defaults to all statuses except abandoned. statuses_to_include_map_metric: data from MapMetrics will only be included for trials with statuses in this set. Defaults to all statuses except abandoned. map_keys_as_parameters: Whether map_keys should be returned as part of the parameters of the Observation objects. - limit_rows_per_metric: If specified, uses MapData.subsample() with + limit_rows_per_metric: If specified, and if data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_metric` equal to the specified value on the first map_key (map_data.map_keys[0]) to subsample the MapData. This is useful in, e.g., cases where learning curves are frequently updated, leading to an intractable number of Observation objects created. - limit_rows_per_group: If specified, uses MapData.subsample() with + limit_rows_per_group: If specified, and if data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_group` equal to the specified value on the first map_key (map_data.map_keys[0]) to subsample the MapData. Returns: List of Observation objects. """ + is_map_data = isinstance(data, MapData) + if statuses_to_include is None: statuses_to_include = NON_ABANDONED_STATUSES if statuses_to_include_map_metric is None: statuses_to_include_map_metric = NON_ABANDONED_STATUSES - if limit_rows_per_metric is not None or limit_rows_per_group is not None: - map_data = map_data.subsample( - map_key=map_data.map_keys[0], - limit_rows_per_metric=limit_rows_per_metric, - limit_rows_per_group=limit_rows_per_group, - include_first_last=True, - ) - feature_cols = get_feature_cols(map_data, is_map_data=True) - observations = [] + + map_keys = [] + obs_cols = OBS_COLS + if is_map_data: + data = checked_cast(MapData, data) + + if limit_rows_per_metric is not None or limit_rows_per_group is not None: + data = data.subsample( + map_key=data.map_keys[0], + limit_rows_per_metric=limit_rows_per_metric, + limit_rows_per_group=limit_rows_per_group, + include_first_last=True, + ) + + map_keys.extend(data.map_keys) + obs_cols = obs_cols.union(data.map_keys) + df = data.map_df + else: + df = data.df + + feature_cols = get_feature_cols(data, is_map_data=is_map_data) + arm_name_only = len(feature_cols) == 1 # there will always be an arm name # One DataFrame where all rows have all features. - isnull = map_data.map_df[feature_cols].isnull() + isnull = df[feature_cols].isnull() isnull_any = isnull.any(axis=1) incomplete_df_cols = isnull[isnull_any].any() # Get the incomplete_df columns that are complete, and usable as groupby keys. - obs_cols_and_map = OBS_COLS.union(map_data.map_keys) complete_feature_cols = list( - obs_cols_and_map.intersection(incomplete_df_cols.index[~incomplete_df_cols]) + obs_cols.intersection(incomplete_df_cols.index[~incomplete_df_cols]) ) if set(feature_cols) == set(complete_feature_cols): - complete_df = map_data.map_df + complete_df = df incomplete_df = None else: # The groupby and filter is expensive, so do it only if we have to. - grouped = map_data.map_df.groupby( - by=( - complete_feature_cols - if len(complete_feature_cols) > 1 - else complete_feature_cols[0] - ) - ) + grouped = df.groupby(by=complete_feature_cols) complete_df = grouped.filter(lambda r: ~r[feature_cols].isnull().any().any()) incomplete_df = grouped.filter(lambda r: r[feature_cols].isnull().any().any()) # Get Observations from complete_df - observations.extend( - _observations_from_dataframe( - experiment=experiment, - df=complete_df, - cols=feature_cols, - arm_name_only=arm_name_only, - map_keys=map_data.map_keys, - statuses_to_include=statuses_to_include, - statuses_to_include_map_metric=statuses_to_include_map_metric, - map_keys_as_parameters=map_keys_as_parameters, - ) + observations = _observations_from_dataframe( + experiment=experiment, + df=complete_df, + cols=feature_cols, + arm_name_only=arm_name_only, + map_keys=map_keys, + statuses_to_include=statuses_to_include, + statuses_to_include_map_metric=statuses_to_include_map_metric, + map_keys_as_parameters=map_keys_as_parameters, ) if incomplete_df is not None: # Get Observations from incomplete_df @@ -625,7 +554,7 @@ def observations_from_map_data( df=incomplete_df, cols=complete_feature_cols, arm_name_only=arm_name_only, - map_keys=map_data.map_keys, + map_keys=map_keys, statuses_to_include=statuses_to_include, statuses_to_include_map_metric=statuses_to_include_map_metric, map_keys_as_parameters=map_keys_as_parameters, diff --git a/ax/core/tests/test_observation.py b/ax/core/tests/test_observation.py index 2c304353502..849cc69ab2f 100644 --- a/ax/core/tests/test_observation.py +++ b/ax/core/tests/test_observation.py @@ -24,7 +24,6 @@ ObservationData, ObservationFeatures, observations_from_data, - observations_from_map_data, recombine_observations, separate_observations, ) @@ -475,7 +474,7 @@ def test_ObservationsFromMapData(self) -> None: MapKeyInfo(key="timestamp", default_value=0.0), ], ) - observations = observations_from_map_data(experiment, data) + observations = observations_from_data(experiment, data) self.assertEqual(len(observations), 3) diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 54346c39136..798bc76f43b 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -297,6 +297,7 @@ def _prepare_observations( data=data, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, + map_keys_as_parameters=False, ) def _transform_data( diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 9fa0f119147..8acbd31746d 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -20,7 +20,7 @@ Observation, ObservationData, ObservationFeatures, - observations_from_map_data, + observations_from_data, separate_observations, ) from ax.core.optimization_config import OptimizationConfig @@ -252,19 +252,16 @@ def _array_to_observation_features( def _prepare_observations( self, experiment: Experiment | None, data: Data | None ) -> list[Observation]: - """The difference b/t this method and ModelBridge._prepare_observations(...) - is that this one uses `observations_from_map_data`. - """ if experiment is None or data is None: return [] - return observations_from_map_data( + return observations_from_data( experiment=experiment, - map_data=data, # pyre-ignore[6]: Checked in __init__. - map_keys_as_parameters=True, + data=data, limit_rows_per_metric=self._map_data_limit_rows_per_metric, limit_rows_per_group=self._map_data_limit_rows_per_group, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, + map_keys_as_parameters=True, ) def _compute_in_design( From f36a9213d76172cdf653046ee01e4aadb568e486 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Tue, 17 Dec 2024 13:35:31 -0800 Subject: [PATCH 5/6] Adds method to retain the N most recently observed values from MapData, [Ax][WIP/Not Ready] Putting it all together (subclass TorchModelBridge) (#3112) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3112 * Provide a new method `latest` for `MapData` to retrieve the *n* most recently observed values for each (arm, metric) group, determined by the `map_key` values, where higher implies more recent. * Update `observations_from_data` to optionally utilize `latest` and retain only the most recently observed *n* values (the new option, if specified alongside the existing subsampling options, will now take precedence). * Modify the "upcast" `MapData.df` property to leverage `latest`, which is a special case with *n=1*. * Revise the docstring to reflect changes in the pertinent methods, as well as update related methods like `subsample` to ensure uniform and consistent writing. Differential Revision: D66434621 --- ax/core/map_data.py | 56 +++++++++++++++++++++---- ax/core/observation.py | 44 ++++++++++++-------- ax/core/tests/test_map_data.py | 75 +++++++++++++++++++++++++++++++++- ax/modelbridge/base.py | 1 + 4 files changed, 150 insertions(+), 26 deletions(-) diff --git a/ax/core/map_data.py b/ax/core/map_data.py index b55e57344c8..3ec9e8f1044 100644 --- a/ax/core/map_data.py +++ b/ax/core/map_data.py @@ -277,15 +277,15 @@ def from_multiple_data( def df(self) -> pd.DataFrame: """Returns a Data shaped DataFrame""" - # If map_keys is empty just return the df if self._memo_df is not None: return self._memo_df + # If map_keys is empty just return the df if len(self.map_keys) == 0: return self.map_df - self._memo_df = self.map_df.sort_values(self.map_keys).drop_duplicates( - MapData.DEDUPLICATE_BY_COLUMNS, keep="last" + self._memo_df = _tail( + map_df=self.map_df, map_keys=self.map_keys, n=1, sort=True ) return self._memo_df @@ -339,6 +339,32 @@ def clone(self) -> MapData: description=self.description, ) + def latest( + self, + map_keys: list[str] | None = None, + rows_per_group: int = 1, + ) -> MapData: + """Return a new MapData with the most recently observed `rows_per_group` + rows for each (arm, metric) group, determined by the `map_key` values, + where higher implies more recent. + + This function considers only the relative ordering of the `map_key` values, + making it most suitable when these values are equally spaced. + + If `rows_per_group` is greater than the number of rows in a given + (arm, metric) group, then all rows are returned. + """ + if map_keys is None: + map_keys = self.map_keys + + return MapData( + df=_tail( + map_df=self.map_df, map_keys=map_keys, n=rows_per_group, sort=True + ), + map_key_infos=self.map_key_infos, + description=self.description, + ) + def subsample( self, map_key: str | None = None, @@ -347,11 +373,13 @@ def subsample( limit_rows_per_metric: int | None = None, include_first_last: bool = True, ) -> MapData: - """Subsample the `map_key` column in an equally-spaced manner (if there is - a `self.map_keys` is length one, then `map_key` can be set to None). The - values of the `map_key` column are not taken into account, so this function - is most reasonable when those values are equally-spaced. There are three - ways that this can be done: + """Return a new MapData that subsamples the `map_key` column in an + equally-spaced manner. If `self.map_keys` has a length of one, `map_key` + can be set to None. This function considers only the relative ordering + of the `map_key` values, making it most suitable when these values are + equally spaced. + + There are three ways that this can be done: 1. If `keep_every = k` is set, then every kth row of the DataFrame in the `map_key` column is kept after grouping by `DEDUPLICATE_BY_COLUMNS`. In other words, every kth step of each (arm, metric) will be kept. @@ -455,6 +483,18 @@ def _subsample_rate( ) +def _tail( + map_df: pd.DataFrame, + map_keys: list[str], + n: int = 1, + sort: bool = True, +) -> pd.DataFrame: + df = map_df.sort_values(map_keys).groupby(MapData.DEDUPLICATE_BY_COLUMNS).tail(n) + if sort: + df.sort_values(MapData.DEDUPLICATE_BY_COLUMNS, inplace=True) + return df + + def _subsample_one_metric( map_df: pd.DataFrame, map_key: str | None = None, diff --git a/ax/core/observation.py b/ax/core/observation.py index a6284405837..94ee013d7a5 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -452,6 +452,7 @@ def observations_from_data( statuses_to_include: set[TrialStatus] | None = None, statuses_to_include_map_metric: set[TrialStatus] | None = None, map_keys_as_parameters: bool = False, + latest_rows_per_group: int | None = None, limit_rows_per_metric: int | None = None, limit_rows_per_group: int | None = None, ) -> list[Observation]: @@ -472,17 +473,21 @@ def observations_from_data( trials with statuses in this set. Defaults to all statuses except abandoned. map_keys_as_parameters: Whether map_keys should be returned as part of the parameters of the Observation objects. - limit_rows_per_metric: If specified, and if data is an instance of MapData, - uses MapData.subsample() with - `limit_rows_per_metric` equal to the specified value on the first - map_key (map_data.map_keys[0]) to subsample the MapData. This is - useful in, e.g., cases where learning curves are frequently - updated, leading to an intractable number of Observation objects - created. - limit_rows_per_group: If specified, and if data is an instance of MapData, - uses MapData.subsample() with - `limit_rows_per_group` equal to the specified value on the first - map_key (map_data.map_keys[0]) to subsample the MapData. + latest_rows_per_group: If specified and data is an instance of MapData, + uses MapData.latest() with `rows_per_group=latest_rows_per_group` to + retrieve the most recent rows for each group. Useful in cases where + learning curves are frequently updated, preventing an excessive + number of Observation objects. Overrides `limit_rows_per_metric` + and `limit_rows_per_group`. + limit_rows_per_metric: If specified and data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_metric` on the first + map_key (map_data.map_keys[0]) to subsample the MapData. Useful for + managing the number of Observation objects when learning curves are + frequently updated. Ignored if `latest_rows_per_group` is specified. + limit_rows_per_group: If specified and data is an instance of MapData, + uses MapData.subsample() with `limit_rows_per_group` on the first + map_key (map_data.map_keys[0]) to subsample the MapData. Ignored if + `latest_rows_per_group` is specified. Returns: List of Observation objects. @@ -499,13 +504,18 @@ def observations_from_data( if is_map_data: data = checked_cast(MapData, data) - if limit_rows_per_metric is not None or limit_rows_per_group is not None: - data = data.subsample( - map_key=data.map_keys[0], - limit_rows_per_metric=limit_rows_per_metric, - limit_rows_per_group=limit_rows_per_group, - include_first_last=True, + if latest_rows_per_group is not None: + data = data.latest( + map_keys=data.map_keys, rows_per_group=latest_rows_per_group ) + else: + if limit_rows_per_metric is not None or limit_rows_per_group is not None: + data = data.subsample( + map_key=data.map_keys[0], + limit_rows_per_metric=limit_rows_per_metric, + limit_rows_per_group=limit_rows_per_group, + include_first_last=True, + ) map_keys.extend(data.map_keys) obs_cols = obs_cols.union(data.map_keys) diff --git a/ax/core/tests/test_map_data.py b/ax/core/tests/test_map_data.py index ce0576e295c..0b4f1f5fd22 100644 --- a/ax/core/tests/test_map_data.py +++ b/ax/core/tests/test_map_data.py @@ -6,6 +6,7 @@ # pyre-strict +import numpy as np import pandas as pd from ax.core.data import Data from ax.core.map_data import MapData, MapKeyInfo @@ -236,7 +237,17 @@ def test_upcast(self) -> None: self.assertIsNotNone(fresh._memo_df) # Assert df is cached after first call - def test_subsample(self) -> None: + self.assertTrue( + fresh.df.equals( + fresh.map_df.sort_values(fresh.map_keys).drop_duplicates( + MapData.DEDUPLICATE_BY_COLUMNS, keep="last" + ) + ) + ) + + def test_latest(self) -> None: + seed = 8888 + arm_names = ["0_0", "1_0", "2_0", "3_0"] max_epochs = [25, 50, 75, 100] metric_names = ["a", "b"] @@ -259,6 +270,68 @@ def test_subsample(self) -> None: ) large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos) + shuffled_large_map_df = large_map_data.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).sample(frac=1, random_state=seed) + shuffled_large_map_data = MapData( + df=shuffled_large_map_df, map_key_infos=self.map_key_infos + ) + + for rows_per_group in [1, 40]: + large_map_data_latest = large_map_data.latest(rows_per_group=rows_per_group) + + if rows_per_group == 1: + self.assertTrue( + large_map_data_latest.map_df.groupby("metric_name") + .epoch.transform(lambda col: set(col) == set(max_epochs)) + .all() + ) + + # when rows_per_group is larger than the number of rows + # actually observed in a group + actual_rows_per_group = large_map_data_latest.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).size() + expected_rows_per_group = np.minimum( + large_map_data_latest.map_df.groupby( + MapData.DEDUPLICATE_BY_COLUMNS + ).epoch.max(), + rows_per_group, + ) + self.assertTrue(actual_rows_per_group.equals(expected_rows_per_group)) + + # behavior should be consistent even if map_keys are not in ascending order + shuffled_large_map_data_latest = shuffled_large_map_data.latest( + rows_per_group=rows_per_group + ) + self.assertTrue( + shuffled_large_map_data_latest.map_df.equals( + large_map_data_latest.map_df + ) + ) + + def test_subsample(self) -> None: + arm_names = ["0_0", "1_0", "2_0", "3_0"] + max_epochs = [25, 50, 75, 100] + metric_names = ["a", "b"] + large_map_df = pd.DataFrame( + [ + { + "arm_name": arm_name, + "epoch": epoch + 1, + "mean": epoch * 0.1, + "sem": 0.1, + "trial_index": trial_index, + "metric_name": metric_name, + } + for metric_name in metric_names + for trial_index, (arm_name, max_epoch) in enumerate( + zip(arm_names, max_epochs) + ) + for epoch in range(max_epoch) + ] + ) + large_map_data = MapData(df=large_map_df, map_key_infos=self.map_key_infos) large_map_df_sparse_metric = pd.DataFrame( [ { diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 798bc76f43b..c2374c34f8d 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -295,6 +295,7 @@ def _prepare_observations( return observations_from_data( experiment=experiment, data=data, + latest_rows_per_group=1, statuses_to_include=self.statuses_to_fit, statuses_to_include_map_metric=self.statuses_to_fit_map_metric, map_keys_as_parameters=False, From d8b8dd0f4d475462bbced135c0c10a14be7324af Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Tue, 17 Dec 2024 14:25:27 -0800 Subject: [PATCH 6/6] Update `get_sobol_mbm_generation_strategy` to allow specification of transforms Summary: Allows the user to customize the list of transforms to use for the Sobol-MBM benchmarking method. Differential Revision: D67357004 --- ax/benchmark/methods/modular_botorch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index d1f27ef17ad..65d060c0679 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -5,12 +5,13 @@ # pyre-strict -from typing import Any +from typing import Any, Sequence from ax.benchmark.benchmark_method import BenchmarkMethod from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.registry import Models +from ax.modelbridge.transforms.base import Transform from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import LogExpectedImprovement @@ -39,6 +40,7 @@ def get_sobol_mbm_generation_strategy( model_cls: type[Model], acquisition_cls: type[AcquisitionFunction], + transforms: Sequence[type[Transform]] | None, name: str | None = None, num_sobol_trials: int = 5, model_gen_kwargs: dict[str, Any] | None = None, @@ -75,6 +77,8 @@ def get_sobol_mbm_generation_strategy( "botorch_acqf_class": acquisition_cls, "surrogate_spec": SurrogateSpec(botorch_model_class=model_cls), } + if transforms is not None: + model_kwargs["transforms"] = transforms model_name = model_names_abbrevations.get(model_cls.__name__, model_cls.__name__) acqf_name = acqf_name_abbreviations.get( @@ -109,6 +113,7 @@ def get_sobol_botorch_modular_acquisition( model_cls: type[Model], acquisition_cls: type[AcquisitionFunction], distribute_replications: bool, + transforms: Sequence[type[Transform]] | None, name: str | None = None, num_sobol_trials: int = 5, model_gen_kwargs: dict[str, Any] | None = None, @@ -162,6 +167,7 @@ def get_sobol_botorch_modular_acquisition( generation_strategy = get_sobol_mbm_generation_strategy( model_cls=model_cls, acquisition_cls=acquisition_cls, + transforms=transforms, name=name, num_sobol_trials=num_sobol_trials, model_gen_kwargs=model_gen_kwargs,