Skip to content

Commit

Permalink
fix #1343 device handling in mog_log_prob (#1356)
Browse files Browse the repository at this point in the history
* fix device in mog_log_prob, add test, fix tests.

* add cpu test for multiround mdn
  • Loading branch information
janfb authored Jan 10, 2025
1 parent a6a220d commit 0ab3a14
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
4 changes: 3 additions & 1 deletion sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,9 @@ def mog_log_prob(

# Split up evaluation into parts.
weights = logits_pp - torch.logsumexp(logits_pp, dim=-1, keepdim=True)
constant = -(output_dim / 2.0) * torch.log(torch.tensor([2 * pi]))
constant = -(output_dim / 2.0) * torch.log(
torch.tensor([2 * pi], device=theta.device)
)
log_det = 0.5 * torch.log(torch.det(precisions_pp))
theta_minus_mean = theta.expand_as(means_pp) - means_pp
exponent = -0.5 * batched_mixture_vmv(precisions_pp, theta_minus_mean)
Expand Down
37 changes: 33 additions & 4 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import Tuple
from typing import Tuple, Union

import pytest
import torch
Expand Down Expand Up @@ -127,7 +127,7 @@ def simulator(theta):
model=model, num_transforms=2, dtype=torch.float32
)
)
train_kwargs = dict(force_first_round_loss=True)
train_kwargs = dict()
elif method == NLE:
kwargs = dict(
density_estimator=likelihood_nn(
Expand All @@ -152,9 +152,12 @@ def simulator(theta):
x = simulator(theta).to(data_device)
theta = theta.to(data_device)

estimator = inferer.append_simulations(theta, x, data_device=data_device).train(
training_batch_size=100, max_num_epochs=max_num_epochs, **train_kwargs
data_kwargs = (
dict(proposal=proposals[-1]) if method in [NPE_A, NPE_C] else dict()
)
estimator = inferer.append_simulations(
theta, x, data_device=data_device, **data_kwargs
).train(max_num_epochs=max_num_epochs, **train_kwargs)

# mcmc cases
if sampling_method in ["slice_np", "slice_np_vectorized", "nuts_pymc"]:
Expand Down Expand Up @@ -436,3 +439,29 @@ def test_boxuniform_device_handling(arg_device, device):
low=zeros(1).to(arg_device), high=ones(1).to(arg_device), device=device
)
NPE_C(prior=prior, device=arg_device)


@pytest.mark.gpu
@pytest.mark.parametrize("method", [NPE_A, NPE_C])
@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_multiround_mdn_training_on_device(method: Union[NPE_A, NPE_C], device: str):
num_dim = 2
num_rounds = 2
num_simulations = 100
device = process_device("gpu")
prior = BoxUniform(-torch.ones(num_dim), torch.ones(num_dim), device=device)
simulator = diagonal_linear_gaussian

estimator = "mdn_snpe_a" if method == NPE_A else "mdn"

trainer = method(prior, density_estimator=estimator, device=device)

theta = prior.sample((num_simulations,))
x = simulator(theta)

proposal = prior
for _ in range(num_rounds):
trainer.append_simulations(theta, x, proposal=proposal).train(max_num_epochs=2)
proposal = trainer.build_posterior().set_default_x(torch.zeros(num_dim))
theta = proposal.sample((num_simulations,))
x = simulator(theta)

0 comments on commit 0ab3a14

Please sign in to comment.