diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index beb1c988c..ab89d4378 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -143,9 +143,14 @@ def condition_on_theta(
         def conditioned_potential(
             theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
         ) -> Tensor:
-            assert len(dims_global_theta) == theta.shape[1], (
+            assert len(dims_global_theta) == theta.shape[-1], (
                 "dims_global_theta must match the number of parameters to sample."
             )
+            if theta.dim() > 2:
+                assert theta.shape[0] == 1, (
+                    "condition_on_theta does not support sample shape for theta."
+                )
+                theta = theta.squeeze(0)
             global_theta = theta[:, dims_global_theta]
             x_o = x_o if x_o is not None else self.x_o
             # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
@@ -155,7 +160,7 @@ def conditioned_potential(
             )
 
             return _log_likelihood_over_iid_trials_and_local_theta(
-                x=x_o,
+                x=x_o.to(self.device),
                 global_theta=global_theta,
                 local_theta=local_theta,
                 estimator=self.likelihood_estimator,
@@ -266,6 +271,10 @@ def _log_likelihood_over_iid_trials_and_local_theta(
     assert local_theta.shape[0] == num_trials, (
         "Condition batch size must match the number of iid trials in x."
    )
+    if num_xs > 1:
+        raise NotImplementedError(
+            "Batched sampling for multiple `x` is not supported for iid conditions."
+        )
 
     # move the iid batch dimension onto the batch dimension of theta and repeat it there
     x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
@@ -289,7 +298,8 @@
         num_xs, num_trials, num_thetas
     ).sum(1)
 
-    return log_likelihood_trial_sum
+    # remove xs batch dimension
+    return log_likelihood_trial_sum.squeeze(0)
 
 
 def mixed_likelihood_estimator_based_potential(
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index 4da4c47bf..ca91174fb 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -947,7 +947,7 @@ def gradient_ascent(
             )
             best_theta_iter = optimize_inits[  # type: ignore
                 torch.argmax(log_probs_of_optimized)
-            ].view(1, -1)
+            ].unsqueeze(0)  # add batch dim
             best_log_prob_iter = potential_fn(
                 theta_transform.inv(best_theta_iter)
             )
diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py
index 74e560770..d4894733e 100644
--- a/tests/inference_on_device_test.py
+++ b/tests/inference_on_device_test.py
@@ -24,6 +24,7 @@
     ratio_estimator_based_potential,
 )
 from sbi.inference.posteriors.importance_posterior import ImportanceSamplingPosterior
+from sbi.inference.posteriors.mcmc_posterior import MCMCPosterior
 from sbi.inference.potentials.base_potential import BasePotential
 from sbi.neural_nets.embedding_nets import FCEmbedding
 from sbi.neural_nets.factory import (
@@ -33,7 +34,11 @@
     posterior_nn,
 )
 from sbi.simulators import diagonal_linear_gaussian, linear_gaussian
-from sbi.utils.torchutils import BoxUniform, gpu_available, process_device
+from sbi.utils.torchutils import (
+    BoxUniform,
+    gpu_available,
+    process_device,
+)
 from sbi.utils.user_input_checks import (
     validate_theta_and_x,
 )
@@ -465,3 +470,49 @@ def test_multiround_mdn_training_on_device(method: Union[NPE_A, NPE_C], device:
     proposal = trainer.build_posterior().set_default_x(torch.zeros(num_dim))
     theta = proposal.sample((num_simulations,))
     x = simulator(theta)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+def test_conditioned_posterior_on_gpu(device: str, mcmc_params_fast: dict):
+    device = process_device(device)
+    num_dims = 3
+
+    proposal = BoxUniform(
+        low=-torch.ones(num_dims, device=device),
+        high=torch.ones(num_dims, device=device),
+    )
+
+    inference = NPE_C(device=device, show_progress_bars=False)
+
+    num_simulations = 100
+    theta = proposal.sample((num_simulations,))
+    x = torch.randn_like(theta)
+    x_o = torch.zeros(1, num_dims).to(device)
+    inference = inference.append_simulations(theta, x)
+
+    estimator = inference.train(max_num_epochs=2)
+
+    # condition on one dim of theta
+    condition_o = torch.ones(1, 1).to(device)
+    prior = BoxUniform(
+        low=-torch.ones(num_dims - 1, device=device),
+        high=torch.ones(num_dims - 1, device=device),
+    )
+    prior_transform = utils.mcmc_transform(prior, device=device)
+
+    potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
+    conditioned_potential_fn = potential_fn.condition_on_theta(
+        condition_o, dims_global_theta=[0, 1]
+    )
+
+    conditional_posterior = MCMCPosterior(
+        potential_fn=conditioned_potential_fn,
+        theta_transform=prior_transform,
+        proposal=prior,
+        device=device,
+        **mcmc_params_fast,
+    ).set_default_x(x_o)
+    samples = conditional_posterior.sample((1,), x=x_o)
+    conditional_posterior.potential_fn(samples)
+    conditional_posterior.map()
diff --git a/tests/mnle_test.py b/tests/mnle_test.py
index 8245bbf08..53f1e862c 100644
--- a/tests/mnle_test.py
+++ b/tests/mnle_test.py
@@ -40,9 +40,10 @@ def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.
     return torch.cat((rts, choices), dim=1)
 
 
-def wrapped_simulator(
+def mixed_simulator_with_conditions(
     theta_and_condition: Tensor, last_idx_parameters: int = 2
 ) -> Tensor:
+    """Simulator for mixed data with experimental conditions."""
     # simulate with experiment conditions
     theta = theta_and_condition[:, :last_idx_parameters]
     condition = theta_and_condition[:, last_idx_parameters:]
@@ -278,7 +279,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     )
 
     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
     assert x.shape == (num_simulations, 2)
 
     num_trials = 10
@@ -289,7 +290,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
     condition_o = theta_and_condition[:, 2:]
     theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)
 
-    x_o = wrapped_simulator(theta_and_conditions_o)
+    x_o = mixed_simulator_with_conditions(theta_and_conditions_o)
 
     mcmc_kwargs = dict(
         method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
@@ -313,6 +314,9 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
         ],
         validate_args=False,
     )
+    # test theta with sample shape.
+    conditioned_potential_fn(prior.sample((10,)).unsqueeze(0))
+
     prior_transform = mcmc_transform(prior)
     true_posterior_samples = MCMCPosterior(
         BinomialGammaPotential(
@@ -339,14 +343,28 @@
 
 @pytest.mark.parametrize("num_thetas", [1, 10])
 @pytest.mark.parametrize("num_trials", [1, 5])
-@pytest.mark.parametrize("num_xs", [1, 3])
+@pytest.mark.parametrize(
+    "num_xs",
+    [
+        1,
+        pytest.param(
+            2,
+            marks=pytest.mark.xfail(
+                reason="Batched x not supported for iid trials.",
+                raises=NotImplementedError,
+            ),
+        ),
+    ],
+)
 @pytest.mark.parametrize(
     "num_conditions",
     [
         1,
         pytest.param(
             2,
-            marks=pytest.mark.xfail(reason="Batched theta_condition is not supported"),
+            marks=pytest.mark.xfail(
+                reason="Batched theta_condition is not supported",
+            ),
         ),
     ],
 )
@@ -376,7 +394,7 @@ def test_log_likelihood_over_local_iid_theta(
     num_simulations = 100
     theta = proposal.sample((num_simulations,))
-    x = wrapped_simulator(theta)
+    x = mixed_simulator_with_conditions(theta)
 
     estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)
 
     # condition on multiple conditions
@@ -407,8 +425,10 @@ def test_log_likelihood_over_local_iid_theta(
         )
         x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
         ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
-    ll_single = torch.stack(ll_single).sum(0)  # sum over trials
+    ll_single = (
+        torch.stack(ll_single).sum(0).squeeze(0)
+    )  # sum over trials, squeeze x batch.
 
-    assert ll_batched.shape == torch.Size([num_xs, num_thetas])
+    assert ll_batched.shape == torch.Size([num_thetas])
     assert ll_batched.shape == ll_single.shape
     assert torch.allclose(ll_batched, ll_single, atol=1e-5)