Skip to content

Commit

Permalink
add normalization to warping function (#3259)
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/botorch#2692

Pull Request resolved: #3259

Warping only works on the unit cube. This ensures that inputs are first normalized. This doesn't use `ChainedInputTransform`, to avoid nested `ChainedInputTransform`s.

This is to support linear+warping models in MBM when we aren't using `UnitX`. We’d want to 1) ensure data is in the unit cube before applying `Warp`, 2) then center the warped data at 0 (using Normalize). One way to do this would be to apply `Normalize(center=0.5)`, `Warp`, `Normalize(center=0.0)`, but we can’t currently specify different options for two different transforms of the same class. So this instead takes an approach suggested by saitcakmak: include normalization in the `Warp` transform itself, since we always want inputs to be in the unit cube before warping.

Reviewed By: saitcakmak, Balandat

Differential Revision: D68356342

fbshipit-source-id: dbd6682d2aacf779c6c813e1a7b779c04f290af0
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jan 22, 2025
1 parent d8406d3 commit 85b1736
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 22 deletions.
6 changes: 6 additions & 0 deletions ax/models/torch/botorch_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,12 +780,18 @@ def get_warping_transform(
# apply warping to all non-task features, including fidelity features
if task_feature is not None:
del indices[task_feature]
# Legacy Ax models operate in the unit cube
bounds = torch.zeros(2, d, dtype=torch.double)
bounds[1] = 1
# Note: this currently uses the same warping functions for all tasks
tf = Warp(
d=d,
indices=indices,
# prior with a median of 1
concentration1_prior=LogNormalPrior(0.0, 0.75**0.5),
concentration0_prior=LogNormalPrior(0.0, 0.75**0.5),
batch_shape=batch_shape,
# Legacy Ax models operate in the unit cube
bounds=bounds,
)
return tf
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,41 @@
)


def _set_default_bounds(
search_space_digest: SearchSpaceDigest,
input_transform_options: dict[str, Any],
d: int,
torch_device: torch.device | None = None,
torch_dtype: torch.dtype | None = None,
) -> None:
"""Set default bounds in input_transform_options, in-place.
Args:
search_space_digest: Search space digest.
input_transform_options: Input transform kwargs.
d: The dimension of the input space.
torch_device: The device on which the input transform will be used.
torch_dtype: The dtype on which the input transform will be used.
"""
bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T

if (
("bounds" not in input_transform_options)
and (bounds.shape[-1] < d)
and (len(search_space_digest.task_features) == 0)
):
raise NotImplementedError(
"Normalization bounds should be specified explicitly if there"
" are task features outside the search space."
)

input_transform_options.setdefault("bounds", bounds)


@input_transform_argparse.register(InputTransform)
def _input_transform_argparse_base(
input_transform_class: type[InputTransform],
Expand Down Expand Up @@ -63,6 +98,8 @@ def _input_transform_argparse_warp(
dataset: SupervisedDataset,
search_space_digest: SearchSpaceDigest,
input_transform_options: dict[str, Any] | None = None,
torch_device: torch.device | None = None,
torch_dtype: torch.dtype | None = None,
) -> dict[str, Any]:
"""Extract the base input transform kwargs form the given arguments.
Expand All @@ -71,6 +108,8 @@ def _input_transform_argparse_warp(
dataset: Dataset containing feature matrix and the response.
search_space_digest: Search space digest.
input_transform_options: Input transform kwargs.
torch_device: The device on which the input transform will be used.
torch_dtype: The dtype on which the input transform will be used.
Returns:
A dictionary with input transform kwargs.
Expand All @@ -83,7 +122,15 @@ def _input_transform_argparse_warp(
for task_feature in sorted(task_features, reverse=True):
del indices[task_feature]

input_transform_options.setdefault("d", d)
input_transform_options.setdefault("indices", indices)
_set_default_bounds(
search_space_digest=search_space_digest,
input_transform_options=input_transform_options,
d=d,
torch_device=torch_device,
torch_dtype=torch_dtype,
)
return input_transform_options


Expand All @@ -107,6 +154,8 @@ def _input_transform_argparse_normalize(
dataset: Dataset containing feature matrix and the response.
search_space_digest: Search space digest.
input_transform_options: Input transform kwargs.
torch_device: The device on which the input transform will be used.
torch_dtype: The dtype on which the input transform will be used.
Returns:
A dictionary with input transform kwargs.
Expand All @@ -130,23 +179,13 @@ def _input_transform_argparse_normalize(
if ("indices" in input_transform_options) or (len(indices) < d):
input_transform_options.setdefault("indices", indices)

bounds = torch.as_tensor(
search_space_digest.bounds,
dtype=torch_dtype,
device=torch_device,
).T

if (
("bounds" not in input_transform_options)
and (bounds.shape[-1] < d)
and (len(search_space_digest.task_features) == 0)
):
raise NotImplementedError(
"Normalize transform bounds should be specified explicitly if there"
" are task features outside the search space."
)

input_transform_options.setdefault("bounds", bounds)
_set_default_bounds(
search_space_digest=search_space_digest,
input_transform_options=input_transform_options,
d=d,
torch_device=torch_device,
torch_dtype=torch_dtype,
)

return input_transform_options

Expand Down
36 changes: 31 additions & 5 deletions ax/models/torch/tests/test_input_transform_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,13 @@ def test_argparse_warp(self) -> None:
search_space_digest=self.search_space_digest,
)

self.assertEqual(
input_transform_kwargs,
{"indices": [1, 2]},
self.assertEqual(input_transform_kwargs["indices"], [1, 2])
self.assertEqual(input_transform_kwargs["d"], 4)
self.assertTrue(
torch.equal(
input_transform_kwargs["bounds"],
torch.tensor([[0.0, 0.0, 0.0], [1.0, 2.0, 4.0]]),
)
)

input_transform_kwargs = input_transform_argparse(
Expand All @@ -194,8 +198,30 @@ def test_argparse_warp(self) -> None:
search_space_digest=self.search_space_digest,
input_transform_options={"indices": [0, 1]},
)

self.assertEqual(input_transform_kwargs, {"indices": [0, 1]})
self.assertEqual(
input_transform_kwargs["indices"],
[0, 1],
)
self.assertEqual(input_transform_kwargs["d"], 4)
self.assertTrue(
torch.equal(
input_transform_kwargs["bounds"],
torch.tensor([[0.0, 0.0, 0.0], [1.0, 2.0, 4.0]]),
)
)
input_transform_kwargs = input_transform_argparse(
Warp,
dataset=self.dataset,
search_space_digest=self.search_space_digest,
input_transform_options={"indices": [0, 1]},
torch_dtype=torch.float64,
)
self.assertTrue(
torch.equal(
input_transform_kwargs["bounds"],
torch.tensor([[0.0, 0.0, 0.0], [1.0, 2.0, 4.0]], dtype=torch.float64),
)
)

def test_argparse_input_perturbation(self) -> None:
self.search_space_digest.robust_digest = RobustSearchSpaceDigest(
Expand Down

0 comments on commit 85b1736

Please sign in to comment.