diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 28840fdf022..bc058b03652 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -1151,13 +1151,13 @@ def cross_validate( posterior = assert_is_instance(posterior, GPyTorchPosterior) pred_mean = posterior.mean pred_var = posterior.variance - pred_Y[i] = pred_mean.view(-1).numpy() - pred_Yvar[i] = pred_var.view(-1).numpy() + pred_Y[i] = pred_mean.view(-1).cpu().numpy() + pred_Yvar[i] = pred_var.view(-1).cpu().numpy() train_mask[i] = 1 # evaluate model fit metric diag_fn = DIAGNOSTIC_FNS[none_throws(self.surrogate_spec.eval_criterion)] return diag_fn( - y_obs=Y.view(-1).numpy(), + y_obs=Y.view(-1).cpu().numpy(), y_pred=pred_Y, se_pred=pred_Yvar, ) diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 736407384b0..10792dd209e 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -277,9 +277,9 @@ def test__make_botorch_input_transform(self) -> None: class SurrogateTest(TestCase): - def setUp(self) -> None: + def setUp(self, cuda: bool = False) -> None: super().setUp() - self.device = torch.device("cpu") + self.device = torch.device("cuda" if cuda else "cpu") self.dtype = torch.float self.tkwargs = {"device": self.device, "dtype": self.dtype} ( @@ -290,7 +290,7 @@ def setUp(self) -> None: _, self.feature_names, _, - ) = get_torch_test_data(dtype=self.dtype) + ) = get_torch_test_data(dtype=self.dtype, cuda=cuda) self.metric_names = ["metric"] self.training_data = [ SupervisedDataset( @@ -309,13 +309,14 @@ def setUp(self) -> None: ) self.fixed_features = {1: 2.0} self.refit = True - self.objective_weights = torch.tensor( - [-1.0, 1.0], dtype=self.dtype, device=self.device + self.objective_weights = torch.tensor([-1.0, 1.0], **self.tkwargs) + self.outcome_constraints = ( + torch.tensor([[1.0]], **self.tkwargs), + torch.tensor([[0.5]], **self.tkwargs), ) - self.outcome_constraints = (torch.tensor([[1.0]]), torch.tensor([[0.5]])) self.linear_constraints = ( - torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), - torch.tensor([[0.5], [1.0]]), + torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], **self.tkwargs), + torch.tensor([[0.5], [1.0]], **self.tkwargs), ) self.options = {} self.torch_opt_config = TorchOptConfig( @@ -782,7 +783,9 @@ def test_construct_model_with_metric_to_model_configs(self) -> None: @mock_botorch_optimize @patch("ax.models.torch.botorch_modular.surrogate.DIAGNOSTIC_FNS") - def test_fit_multiple_model_configs(self, mock_diag_dict: Mock) -> None: + def test_fit_multiple_model_configs( + self, mock_diag_dict: Mock, cuda: bool = False + ) -> None: mse_side_effect = [0.2, 0.1] ll_side_effect = [0.3, 0.05] mock_mse = Mock() # this should select linear kernel @@ -891,15 +894,21 @@ def test_fit_multiple_model_configs(self, mock_diag_dict: Mock) -> None: expected_X = torch.cat( [ torch.cat( - [target_dataset.X, torch.zeros(2, 1)], + [ + target_dataset.X, + torch.zeros(2, 1, **self.tkwargs), + ], + dim=-1, + ), + torch.cat( + [self.ds2.X, torch.ones(2, 1, **self.tkwargs)], dim=-1, ), - torch.cat([self.ds2.X, torch.ones(2, 1)], dim=-1), ], dim=0, ) # check that only target data is used for evaluation - mask = torch.ones(4, dtype=torch.bool) + mask = torch.ones(4, dtype=torch.bool, device=self.device) loo_idx = 0 for i in range(6): # If i in (0,3) then all data is used. @@ -926,6 +935,11 @@ def test_fit_multiple_model_configs(self, mock_diag_dict: Mock) -> None: LinearKernel if eval_criterion == "MSE" else RBFKernel, ) + def test_fit_multiple_model_configs_cuda(self) -> None: + if torch.cuda.is_available(): + self.setUp(cuda=True) + self.test_fit_multiple_model_configs(cuda=True) + def test_cross_validate_error_for_heterogeneous_datasets(self) -> None: # self.ds2.outcome_names[0] = "metric" new_feature_names = copy(self.ds2.feature_names) diff --git a/ax/utils/testing/torch_stubs.py b/ax/utils/testing/torch_stubs.py index fa258b3ebdd..fa26a87b170 100644 --- a/ax/utils/testing/torch_stubs.py +++ b/ax/utils/testing/torch_stubs.py @@ -42,9 +42,12 @@ def get_torch_test_data( ) ] Ys = [torch.tensor([[3.0 + offset], [4.0 + offset]], **tkwargs)] - Yvars = [torch.tensor([[0.0 + offset], [2.0 + offset]], **tkwargs)] if constant_noise: - Yvars[0].fill_(1.0) + Yvar = torch.ones(2, 1, **tkwargs) + else: + Yvar = torch.tensor([[0.0 + offset], [2.0 + offset]], **tkwargs) + Yvars = [Yvar] + bounds = [ (0.0 + offset, 1.0 + offset), (1.0 + offset, 4.0 + offset),