From 0ab3a147a698fc9dfbe47e96af23cfd4f9a93b68 Mon Sep 17 00:00:00 2001
From: Jan
Date: Fri, 10 Jan 2025 15:24:39 +0100
Subject: [PATCH] fix #1343 device handling in mog_log_prob (#1356)

* fix device in mog_log_prob, add test, fix tests.

* add cpu test for multiround mdn
---
 sbi/utils/sbiutils.py             |  4 +++-
 tests/inference_on_device_test.py | 37 +++++++++++++++++++++++++++----
 2 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index fc01d4dbd..75d83dbbd 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -841,7 +841,9 @@ def mog_log_prob(
 
     # Split up evaluation into parts.
     weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)
-    constant = -(output_dim / 2.0) * torch.log(torch.tensor([2 * pi]))
+    constant = -(output_dim / 2.0) * torch.log(
+        torch.tensor([2 * pi], device=theta.device)
+    )
     log_det = 0.5 * torch.log(torch.det(precisions_pp))
     theta_minus_mean = theta.expand_as(means_pp) - means_pp
     exponent = -0.5 * batched_mixture_vmv(precisions_pp, theta_minus_mean)
diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py
index 1c6653af6..6e213acb7 100644
--- a/tests/inference_on_device_test.py
+++ b/tests/inference_on_device_test.py
@@ -3,7 +3,7 @@
 
 from __future__ import annotations
 
-from typing import Tuple
+from typing import Tuple, Union
 
 import pytest
 import torch
@@ -127,7 +127,7 @@ def simulator(theta):
                 model=model, num_transforms=2, dtype=torch.float32
             )
         )
-        train_kwargs = dict(force_first_round_loss=True)
+        train_kwargs = dict()
     elif method == NLE:
         kwargs = dict(
             density_estimator=likelihood_nn(
@@ -152,9 +152,12 @@ def simulator(theta):
         x = simulator(theta).to(data_device)
         theta = theta.to(data_device)
 
-        estimator = inferer.append_simulations(theta, x, data_device=data_device).train(
-            training_batch_size=100, max_num_epochs=max_num_epochs, **train_kwargs
+        data_kwargs = (
+            dict(proposal=proposals[-1]) if method in [NPE_A, NPE_C] else dict()
         )
+        estimator = inferer.append_simulations(
+            theta, x, data_device=data_device, **data_kwargs
+        ).train(max_num_epochs=max_num_epochs, **train_kwargs)
 
         # mcmc cases
         if sampling_method in ["slice_np", "slice_np_vectorized", "nuts_pymc"]:
@@ -436,3 +439,29 @@ def test_boxuniform_device_handling(arg_device, device):
         low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
     )
     NPE_C(prior=prior, device=arg_device)
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize("method", [NPE_A, NPE_C])
+@pytest.mark.parametrize("device", ["cpu", "gpu"])
+def test_multiround_mdn_training_on_device(method: Union[NPE_A, NPE_C], device: str):
+    num_dim = 2
+    num_rounds = 2
+    num_simulations = 100
+    device = process_device(device)  # honor the parametrized device (cpu and gpu)
+    prior = BoxUniform(-torch.ones(num_dim), torch.ones(num_dim), device=device)
+    simulator = diagonal_linear_gaussian
+
+    estimator = "mdn_snpe_a" if method == NPE_A else "mdn"
+
+    trainer = method(prior, density_estimator=estimator, device=device)
+
+    theta = prior.sample((num_simulations,))
+    x = simulator(theta)
+
+    proposal = prior
+    for _ in range(num_rounds):
+        trainer.append_simulations(theta, x, proposal=proposal).train(max_num_epochs=2)
+        proposal = trainer.build_posterior().set_default_x(torch.zeros(num_dim))
+        theta = proposal.sample((num_simulations,))
+        x = simulator(theta)