Skip to content

Commit

Permalink
add gpu support for model selection (facebook#3154)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3154

see title

Reviewed By: Balandat

Differential Revision: D66887106

fbshipit-source-id: 4519dfff2cd17df204549f98e1494ea4c198430d
  • Loading branch information
sdaulton authored and facebook-github-bot committed Dec 6, 2024
1 parent f9a9fd6 commit 5f3e8e2
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
6 changes: 3 additions & 3 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,13 +1151,13 @@ def cross_validate(
posterior = assert_is_instance(posterior, GPyTorchPosterior)
pred_mean = posterior.mean
pred_var = posterior.variance
pred_Y[i] = pred_mean.view(-1).numpy()
pred_Yvar[i] = pred_var.view(-1).numpy()
pred_Y[i] = pred_mean.view(-1).cpu().numpy()
pred_Yvar[i] = pred_var.view(-1).cpu().numpy()
train_mask[i] = 1
# evaluate model fit metric
diag_fn = DIAGNOSTIC_FNS[none_throws(self.surrogate_spec.eval_criterion)]
return diag_fn(
y_obs=Y.view(-1).numpy(),
y_obs=Y.view(-1).cpu().numpy(),
y_pred=pred_Y,
se_pred=pred_Yvar,
)
Expand Down
38 changes: 26 additions & 12 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,9 +277,9 @@ def test__make_botorch_input_transform(self) -> None:


class SurrogateTest(TestCase):
def setUp(self) -> None:
def setUp(self, cuda: bool = False) -> None:
super().setUp()
self.device = torch.device("cpu")
self.device = torch.device("cuda" if cuda else "cpu")
self.dtype = torch.float
self.tkwargs = {"device": self.device, "dtype": self.dtype}
(
Expand All @@ -290,7 +290,7 @@ def setUp(self) -> None:
_,
self.feature_names,
_,
) = get_torch_test_data(dtype=self.dtype)
) = get_torch_test_data(dtype=self.dtype, cuda=cuda)
self.metric_names = ["metric"]
self.training_data = [
SupervisedDataset(
Expand All @@ -309,13 +309,14 @@ def setUp(self) -> None:
)
self.fixed_features = {1: 2.0}
self.refit = True
self.objective_weights = torch.tensor(
[-1.0, 1.0], dtype=self.dtype, device=self.device
self.objective_weights = torch.tensor([-1.0, 1.0], **self.tkwargs)
self.outcome_constraints = (
torch.tensor([[1.0]], **self.tkwargs),
torch.tensor([[0.5]], **self.tkwargs),
)
self.outcome_constraints = (torch.tensor([[1.0]]), torch.tensor([[0.5]]))
self.linear_constraints = (
torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]),
torch.tensor([[0.5], [1.0]]),
torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], **self.tkwargs),
torch.tensor([[0.5], [1.0]], **self.tkwargs),
)
self.options = {}
self.torch_opt_config = TorchOptConfig(
Expand Down Expand Up @@ -782,7 +783,9 @@ def test_construct_model_with_metric_to_model_configs(self) -> None:

@mock_botorch_optimize
@patch("ax.models.torch.botorch_modular.surrogate.DIAGNOSTIC_FNS")
def test_fit_multiple_model_configs(self, mock_diag_dict: Mock) -> None:
def test_fit_multiple_model_configs(
self, mock_diag_dict: Mock, cuda: bool = False
) -> None:
mse_side_effect = [0.2, 0.1]
ll_side_effect = [0.3, 0.05]
mock_mse = Mock() # this should select linear kernel
Expand Down Expand Up @@ -891,15 +894,21 @@ def test_fit_multiple_model_configs(self, mock_diag_dict: Mock) -> None:
expected_X = torch.cat(
[
torch.cat(
[target_dataset.X, torch.zeros(2, 1)],
[
target_dataset.X,
torch.zeros(2, 1, **self.tkwargs),
],
dim=-1,
),
torch.cat(
[self.ds2.X, torch.ones(2, 1, **self.tkwargs)],
dim=-1,
),
torch.cat([self.ds2.X, torch.ones(2, 1)], dim=-1),
],
dim=0,
)
# check that only target data is used for evaluation
mask = torch.ones(4, dtype=torch.bool)
mask = torch.ones(4, dtype=torch.bool, device=self.device)
loo_idx = 0
for i in range(6):
# If i in (0,3) then all data is used.
Expand All @@ -926,6 +935,11 @@ def test_fit_multiple_model_configs(self, mock_diag_dict: Mock) -> None:
LinearKernel if eval_criterion == "MSE" else RBFKernel,
)

def test_fit_multiple_model_configs_cuda(self) -> None:
if torch.cuda.is_available():
self.setUp(cuda=True)
self.test_fit_multiple_model_configs(cuda=True)

def test_cross_validate_error_for_heterogeneous_datasets(self) -> None:
# self.ds2.outcome_names[0] = "metric"
new_feature_names = copy(self.ds2.feature_names)
Expand Down
7 changes: 5 additions & 2 deletions ax/utils/testing/torch_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@ def get_torch_test_data(
)
]
Ys = [torch.tensor([[3.0 + offset], [4.0 + offset]], **tkwargs)]
Yvars = [torch.tensor([[0.0 + offset], [2.0 + offset]], **tkwargs)]
if constant_noise:
Yvars[0].fill_(1.0)
Yvar = torch.ones(2, 1, **tkwargs)
else:
Yvar = torch.tensor([[0.0 + offset], [2.0 + offset]], **tkwargs)
Yvars = [Yvar]

bounds = [
(0.0 + offset, 1.0 + offset),
(1.0 + offset, 4.0 + offset),
Expand Down

0 comments on commit 5f3e8e2

Please sign in to comment.