diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py
index f382968cd..8066d3dd1 100644
--- a/sbi/inference/potentials/likelihood_based_potential.py
+++ b/sbi/inference/potentials/likelihood_based_potential.py
@@ -1,7 +1,8 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
-from typing import Callable, Optional, Tuple
+import warnings
+from typing import Callable, List, Optional, Tuple
import torch
from torch import Tensor
@@ -115,6 +116,54 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
)
return log_likelihood_batches + self.prior.log_prob(theta) # type: ignore
+ def condition_on_theta(
+ self, local_theta: Tensor, dims_global_theta: List[int]
+ ) -> Callable:
+ r"""Returns a potential function conditioned on a subset of theta dimensions.
+
+ The goal of this function is to divide the original `theta` into a
+ `global_theta` we do inference over, and a `local_theta` we condition on (in
+ addition to conditioning on `x_o`). The returned potential function calculates
+ $\sum_{i=1}^{N} \log p(x_i | local_theta_i, global_theta)$, where `x_i` and
+ `local_theta_i` are fixed and `global_theta` varies at inference time.
+
+ Args:
+ local_theta: The condition values to condition on, one row per iid trial.
+ dims_global_theta: The indices of the columns in `theta` that will be
+ sampled, i.e., that are *not* conditioned. For example, if the original
+ theta has shape `(batch_dim, 3)` and `dims_global_theta=[0, 1]`, the
+ potential will set `theta[:, 2] = local_theta` at inference time.
+
+ Returns:
+ A potential function of `global_theta`, conditioned on `local_theta`.
+ """
+
+ assert self.x_is_iid, "Conditioning is only supported for iid data."
+
+ def conditioned_potential(
+ theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
+ ) -> Tensor:
+ assert (
+ len(dims_global_theta) == theta.shape[1]
+ ), "dims_global_theta must match the number of parameters to sample."
+ global_theta = theta[:, dims_global_theta]
+ x_o = x_o if x_o is not None else self.x_o
+ # x needs shape (sample_dim (iid trials), batch_dim (xs), *event_shape)
+ if x_o.dim() < 3:
+ x_o = reshape_to_sample_batch_event(
+ x_o, event_shape=x_o.shape[1:], leading_is_sample=self.x_is_iid
+ )
+
+ return _log_likelihood_over_iid_trials_and_local_theta(
+ x=x_o,
+ global_theta=global_theta,
+ local_theta=local_theta,
+ estimator=self.likelihood_estimator,
+ track_gradients=track_gradients,
+ )
+
+ return conditioned_potential
+
def _log_likelihoods_over_trials(
x: Tensor,
@@ -172,6 +221,77 @@ def _log_likelihoods_over_trials(
return log_likelihood_trial_sum
+def _log_likelihood_over_iid_trials_and_local_theta(
+ x: Tensor,
+ global_theta: Tensor,
+ local_theta: Tensor,
+ estimator: ConditionalDensityEstimator,
+ track_gradients: bool = False,
+) -> Tensor:
+ """Returns $\\sum_{i=1}^N \\log p(x_i|\\theta, local_theta_i)$.
+
+ `x` is a batch of iid data, and `local_theta` is a matching batch of condition
+ values that were part of `theta` but are treated as local iid variables at inference
+ time.
+
+ This function is different from `_log_likelihoods_over_trials` in that it moves the
+ iid batch dimension of `x` onto the batch dimension of `theta`. This is needed when
+ the likelihood estimator is conditioned on a batch of conditions that are iid with
+ the batch of `x`. It avoids the evaluation of the likelihood for every combination
+ of `x` and `local_theta`.
+
+ Args:
+ x: data with shape `(sample_dim, x_batch_dim, *x_event_shape)`, where sample_dim
+ holds the i.i.d. trials and x_batch_dim holds a batch of xs, e.g., non-iid
+ observations.
+ global_theta: Batch of parameters `(theta_batch_dim,
+ num_parameters)`.
+ local_theta: Batch of conditions of shape `(sample_dim, num_local_thetas)`, must
+ match x's `sample_dim`.
+ estimator: Conditional density estimator exposing `log_prob(x, condition=...)`.
+ track_gradients: Whether to track gradients.
+
+ Returns:
+ log_likelihood: log likelihood for each x in x_batch_dim, for each theta in
+ theta_batch_dim, summed over all iid trials. Shape `(x_batch_dim,
+ theta_batch_dim)`.
+ """
+ assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
+ assert (
+ local_theta.dim() == 2
+ ), "condition must have shape (sample_dim, num_conditions)."
+ assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."
+ num_trials, num_xs = x.shape[:2]
+ num_thetas = global_theta.shape[0]
+ assert (
+ local_theta.shape[0] == num_trials
+ ), "Condition batch size must match the number of iid trials in x."
+
+ # move the iid trial dimension onto the theta batch dimension and repeat it (AABB)
+ x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
+
+ # pair every global theta with every trial's local theta (same AABB trial order)
+ theta_with_condition = torch.cat(
+ [
+ global_theta.repeat(num_trials, 1), # repeat ABAB
+ local_theta.repeat_interleave(num_thetas, dim=0), # repeat AABB
+ ],
+ dim=-1,
+ )
+
+ with torch.set_grad_enabled(track_gradients):
+ # Calculate likelihood in one batch: (num_xs, num_trials * num_thetas) values.
+ log_likelihood_trial_batch = estimator.log_prob(
+ x_repeated, condition=theta_with_condition
+ )
+ # Reshape to (xs, trials, thetas) and sum the log likelihoods over trials.
+ log_likelihood_trial_sum = log_likelihood_trial_batch.reshape(
+ num_xs, num_trials, num_thetas
+ ).sum(1)
+
+ return log_likelihood_trial_sum
+
+
def mixed_likelihood_estimator_based_potential(
likelihood_estimator: MixedDensityEstimator,
prior: Distribution,
@@ -192,6 +312,13 @@ def mixed_likelihood_estimator_based_potential(
to unconstrained space.
"""
+ warnings.warn(
+ "This function is deprecated and will be removed in a future release. Use "
+ "`likelihood_estimator_based_potential` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
device = str(next(likelihood_estimator.discrete_net.parameters()).device)
potential_fn = MixedLikelihoodBasedPotential(
@@ -212,6 +339,13 @@ def __init__(
):
super().__init__(likelihood_estimator, prior, x_o, device)
+ warnings.warn(
+ "This function is deprecated and will be removed in a future release. Use "
+ "`LikelihoodBasedPotential` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
prior_log_prob = self.prior.log_prob(theta) # type: ignore
@@ -231,7 +365,6 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
with torch.set_grad_enabled(track_gradients):
# Call the specific log prob method of the mixed likelihood estimator as
# this optimizes the evaluation of the discrete data part.
- # TODO log_prob_iid
log_likelihood_trial_batch = self.likelihood_estimator.log_prob(
input=x,
condition=theta.to(self.device),
diff --git a/sbi/inference/trainers/nle/mnle.py b/sbi/inference/trainers/nle/mnle.py
index d01ce1e91..83622eaea 100644
--- a/sbi/inference/trainers/nle/mnle.py
+++ b/sbi/inference/trainers/nle/mnle.py
@@ -7,7 +7,7 @@
from torch.distributions import Distribution
from sbi.inference.posteriors import MCMCPosterior, RejectionPosterior, VIPosterior
-from sbi.inference.potentials import mixed_likelihood_estimator_based_potential
+from sbi.inference.potentials import likelihood_estimator_based_potential
from sbi.inference.trainers.nle.nle_base import LikelihoodEstimator
from sbi.neural_nets.estimators import MixedDensityEstimator
from sbi.sbi_types import TensorboardSummaryWriter, TorchModule
@@ -155,9 +155,7 @@ def build_posterior(
(
potential_fn,
theta_transform,
- ) = mixed_likelihood_estimator_based_potential(
- likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
- )
+ ) = likelihood_estimator_based_potential(likelihood_estimator, prior, x_o=None)
if sample_with == "mcmc":
self._posterior = MCMCPosterior(
diff --git a/sbi/utils/conditional_density_utils.py b/sbi/utils/conditional_density_utils.py
index d6c73b7c9..829f5e1df 100644
--- a/sbi/utils/conditional_density_utils.py
+++ b/sbi/utils/conditional_density_utils.py
@@ -293,7 +293,7 @@ def __init__(
masked outside of prior.
"""
condition = torch.atleast_2d(condition)
- if condition.shape[0] != 1:
+ if condition.shape[0] > 1:
raise ValueError("Condition with batch size > 1 not supported.")
self.potential_fn = potential_fn
diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py
index fcb5953d9..fc01d4dbd 100644
--- a/sbi/utils/sbiutils.py
+++ b/sbi/utils/sbiutils.py
@@ -60,8 +60,8 @@ def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -
if num_unique_z < num_unique * (1 - duplicate_tolerance):
warnings.warn(
- "Z-scoring these simulation outputs resulted in {num_unique_z} unique "
- "datapoints. Before z-scoring, it had been {num_unique}. This can "
+ f"Z-scoring these simulation outputs resulted in {num_unique_z} unique "
+ f"datapoints. Before z-scoring, it had been {num_unique}. This can "
"occur due to numerical inaccuracies when the data covers a large "
"range of values. Consider either setting `z_score_x=False` (but "
"beware that this can be problematic for training the NN) or exclude "
diff --git a/tests/mnle_test.py b/tests/mnle_test.py
index a95a2a6ac..099876a3e 100644
--- a/tests/mnle_test.py
+++ b/tests/mnle_test.py
@@ -1,29 +1,32 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see
+from typing import Union
+
import pytest
import torch
from pyro.distributions import InverseGamma
-from torch.distributions import Beta, Binomial, Categorical, Gamma
+from torch import Tensor
+from torch.distributions import Beta, Binomial, Distribution, Gamma
from sbi.inference import MNLE, MCMCPosterior
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.likelihood_based_potential import (
- MixedLikelihoodBasedPotential,
+ _log_likelihood_over_iid_trials_and_local_theta,
+ likelihood_estimator_based_potential,
)
from sbi.neural_nets import likelihood_nn
from sbi.neural_nets.embedding_nets import FCEmbedding
from sbi.utils import BoxUniform, mcmc_transform
-from sbi.utils.conditional_density_utils import ConditionedPotential
from sbi.utils.torchutils import atleast_2d, process_device
from sbi.utils.user_input_checks_utils import MultipleIndependent
from tests.test_utils import check_c2st
# toy simulator for mixed data
-def mixed_simulator(theta, stimulus_condition=2.0):
+def mixed_simulator(theta: Tensor, stimulus_condition: Union[Tensor, float] = 2.0):
"""Simulator for mixed data."""
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]
@@ -37,6 +40,15 @@ def mixed_simulator(theta, stimulus_condition=2.0):
return torch.cat((rts, choices), dim=1)
+def wrapped_simulator(
+ theta_and_condition: Tensor, last_idx_parameters: int = 2
+) -> Tensor:
+ # split columns into parameters and experimental conditions, then simulate
+ theta = theta_and_condition[:, :last_idx_parameters]
+ condition = theta_and_condition[:, last_idx_parameters:]
+ return mixed_simulator(theta, condition)
+
+
@pytest.mark.mcmc
@pytest.mark.gpu
@pytest.mark.parametrize("device", ("cpu", "gpu"))
@@ -190,11 +202,28 @@ def test_mnle_accuracy_with_different_samplers_and_trials(
class BinomialGammaPotential(BasePotential):
- def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
+ """Binomial-Gamma potential for mixed data."""
+
+ def __init__(
+ self,
+ prior: Distribution,
+ x_o: Tensor,
+ concentration_scaling: Union[Tensor, float] = 1.0,
+ device="cpu",
+ ):
super().__init__(prior, x_o, device)
+
+ # concentration_scaling must be a float or a tensor with one row per trial
+ if isinstance(concentration_scaling, Tensor):
+ num_trials = x_o.shape[0]
+ assert concentration_scaling.shape[0] == num_trials
+
+ # Reshape to (1, num_trials, -1) so it broadcasts over the theta batch
+ concentration_scaling = concentration_scaling.reshape(1, num_trials, -1)
+
self.concentration_scaling = concentration_scaling
- def __call__(self, theta, track_gradients: bool = True):
+ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
theta = atleast_2d(theta)
with torch.set_grad_enabled(track_gradients):
@@ -202,11 +231,12 @@ def __call__(self, theta, track_gradients: bool = True):
return iid_ll + self.prior.log_prob(theta)
- def iid_likelihood(self, theta):
+ def iid_likelihood(self, theta: Tensor) -> Tensor:
batch_size = theta.shape[0]
num_trials = self.x_o.shape[0]
theta = theta.reshape(batch_size, 1, -1)
beta, rho = theta[:, :, :1], theta[:, :, 1:]
+
# vectorized
logprob_choices = Binomial(probs=rho).log_prob(
self.x_o[:, 1:].reshape(1, num_trials, -1)
@@ -233,43 +263,44 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict):
categorical parameter is set to a fixed value (conditioned posterior), and the
accuracy of the conditioned posterior is tested against the true posterior.
"""
- num_simulations = 6000
- num_samples = 500
-
- def sim_wrapper(theta):
- # simulate with experiment conditions
- return mixed_simulator(theta[:, :2], theta[:, 2:] + 1)
+ num_simulations = 10000
+ num_samples = 1000
proposal = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
- Categorical(probs=torch.ones(1, 3)),
+ BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
],
validate_args=False,
)
theta = proposal.sample((num_simulations,))
- x = sim_wrapper(theta)
+ x = wrapped_simulator(theta)
assert x.shape == (num_simulations, 2)
num_trials = 10
- theta_o = proposal.sample((1,))
- theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.
- x_o = sim_wrapper(theta_o.repeat(num_trials, 1))
+ theta_and_condition = proposal.sample((num_trials,))
+ # use only a single parameter (iid trials)
+ theta_o = theta_and_condition[:1, :2].repeat(num_trials, 1)
+ # but different conditions
+ condition_o = theta_and_condition[:, 2:]
+ theta_and_conditions_o = torch.cat((theta_o, condition_o), dim=1)
+
+ x_o = wrapped_simulator(theta_and_conditions_o)
mcmc_kwargs = dict(
method="slice_np_vectorized", init_strategy="proposal", **mcmc_params_accurate
)
# MNLE
- trainer = MNLE(proposal)
- estimator = trainer.append_simulations(theta, x).train(training_batch_size=1000)
-
- potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o)
+ estimator_fun = likelihood_nn(model="mnle", z_score_x=None)
+ trainer = MNLE(proposal, estimator_fun)
+ estimator = trainer.append_simulations(theta, x).train()
- conditioned_potential_fn = ConditionedPotential(
- potential_fn, condition=theta_o, dims_to_sample=[0, 1]
+ potential_fn, _ = likelihood_estimator_based_potential(estimator, proposal, x_o)
+ conditioned_potential_fn = potential_fn.condition_on_theta(
+ condition_o, dims_global_theta=[0, 1]
)
# True posterior samples
@@ -283,10 +314,7 @@ def sim_wrapper(theta):
prior_transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
BinomialGammaPotential(
- prior,
- atleast_2d(x_o),
- concentration_scaling=float(theta_o[0, 2])
- + 1.0, # add one because the sim_wrapper adds one (see above)
+ prior, atleast_2d(x_o), concentration_scaling=condition_o
),
theta_transform=prior_transform,
proposal=prior,
@@ -303,5 +331,86 @@ def sim_wrapper(theta):
check_c2st(
cond_samples,
true_posterior_samples,
- alg=f"MNLE trained with {num_simulations}",
+ alg=f"MNLE trained with {num_simulations} simulations",
+ )
+
+
+@pytest.mark.parametrize("num_thetas", [1, 10])
+@pytest.mark.parametrize("num_trials", [1, 5])
+@pytest.mark.parametrize("num_xs", [1, 3])
+@pytest.mark.parametrize(
+ "num_conditions",
+ [
+ 1,
+ pytest.param(
+ 2,
+ marks=pytest.mark.xfail(
+ reason="Batched theta_condition is not " "supported"
+ ),
+ ),
+ ],
+)
+def test_log_likelihood_over_local_iid_theta(
+ num_thetas, num_trials, num_xs, num_conditions
+):
+ """Test log likelihood over iid conditions using MNLE.
+
+ Args:
+ num_thetas: number of global thetas the likelihood is evaluated for.
+ num_trials: number of i.i.d. trials in x
+ num_xs: batch of x, e.g., different subjects in a study.
+ num_conditions: number of batches of conditions, e.g., different conditions
+ for each x (not implemented yet).
+ """
+
+ # train mnle on mixed data
+ trainer = MNLE(
+ density_estimator=likelihood_nn(model="mnle", z_score_x=None),
)
+ proposal = MultipleIndependent(
+ [
+ Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
+ Beta(torch.tensor([2.0]), torch.tensor([2.0])),
+ BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),
+ ],
+ validate_args=False,
+ )
+
+ num_simulations = 100
+ theta = proposal.sample((num_simulations,))
+ x = wrapped_simulator(theta)
+ estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)
+
+ # draw one data-generating theta (without condition column) per x in the batch
+ theta_o = proposal.sample((num_xs,))[:, :2]
+
+ x_o = torch.zeros(num_trials, num_xs, 2)
+ condition_o = proposal.sample((
+ num_conditions,
+ num_trials,
+ ))[:, 2:].reshape(num_trials, 1)
+ for i in range(num_xs):
+ # simulate with same iid theta but different conditions
+ x_o[:, i, :] = mixed_simulator(theta_o[i].repeat(num_trials, 1), condition_o)
+
+ # batched conditioning
+ theta = proposal.sample((num_thetas,))[:, :2]
+ # x_o has shape (iid, batch, *event)
+ # condition_o has shape (iid, 1)
+ ll_batched = _log_likelihood_over_iid_trials_and_local_theta(
+ x_o, theta, condition_o, estimator
+ )
+
+ # looped conditioning
+ ll_single = []
+ for i in range(num_trials):
+ theta_and_condition = torch.cat(
+ (theta, condition_o[i].repeat(num_thetas, 1)), dim=1
+ )
+ x_i = x_o[i].reshape(num_xs, 1, -1).repeat(1, num_thetas, 1)
+ ll_single.append(estimator.log_prob(input=x_i, condition=theta_and_condition))
+ ll_single = torch.stack(ll_single).sum(0) # sum over trials
+
+ assert ll_batched.shape == torch.Size([num_xs, num_thetas])
+ assert ll_batched.shape == ll_single.shape
+ assert torch.allclose(ll_batched, ll_single, atol=1e-5)
diff --git a/tutorials/Example_01_DecisionMakingModel.ipynb b/tutorials/Example_01_DecisionMakingModel.ipynb
index fcfa10ced..eb16182c2 100644
--- a/tutorials/Example_01_DecisionMakingModel.ipynb
+++ b/tutorials/Example_01_DecisionMakingModel.ipynb
@@ -73,32 +73,23 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import torch\n",
"from pyro.distributions import InverseGamma\n",
"from torch import Tensor\n",
- "from torch.distributions import Beta, Binomial, Categorical, Gamma\n",
+ "from torch.distributions import Beta, Binomial, Gamma\n",
"\n",
"from sbi.analysis import pairplot\n",
"from sbi.inference import MNLE, MCMCPosterior\n",
- "from sbi.inference.potentials.base_potential import BasePotential\n",
- "from sbi.inference.potentials.likelihood_based_potential import (\n",
- " MixedLikelihoodBasedPotential,\n",
- ")\n",
- "from sbi.utils import MultipleIndependent, mcmc_transform\n",
- "from sbi.utils.conditional_density_utils import ConditionedPotential\n",
+ "from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential\n",
+ "from sbi.neural_nets import likelihood_nn\n",
+ "from sbi.utils import BoxUniform, MultipleIndependent, mcmc_transform\n",
"from sbi.utils.metrics import c2st\n",
- "from sbi.utils.torchutils import atleast_2d"
+ "\n",
+ "\n",
+ "from example_01_utils import BinomialGammaPotential"
]
},
{
@@ -124,44 +115,7 @@
" concentration=concentration_scaling * torch.ones_like(beta), rate=beta\n",
" ).sample()\n",
"\n",
- " return torch.cat((rts, choices), dim=1)\n",
- "\n",
- "\n",
- "# The potential function defines the ground truth likelihood and allows us to\n",
- "# obtain reference posterior samples via MCMC.\n",
- "class BinomialGammaPotential(BasePotential):\n",
- "\n",
- " def __init__(self, prior, x_o, concentration_scaling=1.0, device=\"cpu\"):\n",
- " super().__init__(prior, x_o, device)\n",
- " self.concentration_scaling = concentration_scaling\n",
- "\n",
- " def __call__(self, theta, track_gradients: bool = True):\n",
- " theta = atleast_2d(theta)\n",
- "\n",
- " with torch.set_grad_enabled(track_gradients):\n",
- " iid_ll = self.iid_likelihood(theta)\n",
- "\n",
- " return iid_ll + self.prior.log_prob(theta)\n",
- "\n",
- " def iid_likelihood(self, theta):\n",
- " batch_size = theta.shape[0]\n",
- " num_trials = self.x_o.shape[0]\n",
- " theta = theta.reshape(batch_size, 1, -1)\n",
- " beta, rho = theta[:, :, :1], theta[:, :, 1:]\n",
- " # vectorized\n",
- " logprob_choices = Binomial(probs=rho).log_prob(\n",
- " self.x_o[:, 1:].reshape(1, num_trials, -1)\n",
- " )\n",
- "\n",
- " logprob_rts = InverseGamma(\n",
- " concentration=self.concentration_scaling * torch.ones_like(beta),\n",
- " rate=beta,\n",
- " ).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))\n",
- "\n",
- " joint_likelihood = (logprob_choices + logprob_rts).squeeze()\n",
- "\n",
- " assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])\n",
- " return joint_likelihood.sum(1)"
+ " return torch.cat((rts, choices), dim=1)"
]
},
{
@@ -205,18 +159,10 @@
"execution_count": 5,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. This might cause the results differ from the last benchmark.\n",
- " thin = _process_thin_default(thin)\n"
- ]
- },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "8070275b9eac45d1991d5be41935c145",
+ "model_id": "92513794bbd148b29b5d60d566338bf6",
"version_major": 2,
"version_minor": 0
},
@@ -234,6 +180,7 @@
" warmup_steps=50,\n",
" method=\"slice_np_vectorized\",\n",
" init_strategy=\"proposal\",\n",
+ " thin=1,\n",
")\n",
"\n",
"true_posterior = MCMCPosterior(\n",
@@ -269,13 +216,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " Neural network successfully converged after 65 epochs."
+ " Neural network successfully converged after 75 epochs."
]
}
],
"source": [
"# Training data\n",
- "num_simulations = 20000\n",
+ "num_simulations = 10000\n",
"# For training the MNLE emulator we need to define a proposal distribution, the prior is\n",
"# a good choice.\n",
"proposal = prior\n",
@@ -284,7 +231,7 @@
"\n",
"# Train MNLE and obtain MCMC-based posterior.\n",
"trainer = MNLE()\n",
- "estimator = trainer.append_simulations(theta, x).train(training_batch_size=1000)"
+ "estimator = trainer.append_simulations(theta, x).train()"
]
},
{
@@ -292,10 +239,18 @@
"execution_count": 7,
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. This might cause the results differ from the last benchmark.\n",
+ " thin = _process_thin_default(thin)\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "1a7792c605404a11a586681fcd3c0a32",
+ "model_id": "548e67900bd3494481dc61d0f11db250",
"version_major": 2,
"version_minor": 0
},
@@ -328,7 +283,7 @@
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -390,7 +345,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "fb02120c58a54d029953b4c589f24eca",
+ "model_id": "21f980fc4f794fe1ab2090ad53a0e323",
"version_major": 2,
"version_minor": 0
},
@@ -404,7 +359,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "1cd3bc58ca8e4a21b1df2812fad8bf45",
+ "model_id": "418d47f1f4864c089bf68e1a119ebb7d",
"version_major": 2,
"version_minor": 0
},
@@ -430,7 +385,7 @@
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -477,7 +432,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
- "c2st between true and MNLE posterior: 0.593\n"
+ "c2st between true and MNLE posterior: 0.5155000000000001\n"
]
}
],
@@ -515,16 +470,26 @@
"metadata": {},
"outputs": [],
"source": [
+ "# Define a proposal that contains both priors for the parameters and a uniform\n",
+ "# prior over experimental conditions.\n",
+ "proposal = MultipleIndependent(\n",
+ " [\n",
+ " Gamma(torch.tensor([1.0]), torch.tensor([0.5])),\n",
+ " Beta(torch.tensor([2.0]), torch.tensor([2.0])),\n",
+ " BoxUniform(torch.tensor([0.0]), torch.tensor([1.0])),\n",
+ " ],\n",
+ " validate_args=False,\n",
+ ")\n",
+ "\n",
"# define a simulator wrapper in which the experimental condition are contained\n",
"# in theta and passed to the simulator.\n",
- "def sim_wrapper(theta):\n",
+ "def sim_wrapper(theta_and_conditions):\n",
" # simulate with experiment conditions\n",
" return mixed_simulator(\n",
" # we assume the first two parameters are beta and rho\n",
- " theta=theta[:, :2],\n",
+ " theta=theta_and_conditions[:, :2],\n",
" # we treat the third concentration parameter as an experimental condition\n",
- " # add 1 to deal with 0 values from Categorical distribution\n",
- " concentration_scaling=theta[:, 2:] + 1,\n",
+ " concentration_scaling=theta_and_conditions[:, 2:],\n",
" )"
]
},
@@ -534,17 +499,6 @@
"metadata": {},
"outputs": [],
"source": [
- "# Define a proposal that contains both, priors for the parameters and a discrte\n",
- "# prior over experimental conditions.\n",
- "proposal = MultipleIndependent(\n",
- " [\n",
- " Gamma(torch.tensor([1.0]), torch.tensor([0.5])),\n",
- " Beta(torch.tensor([2.0]), torch.tensor([2.0])),\n",
- " Categorical(probs=torch.ones(1, 3)), # 3 discrete conditions\n",
- " ],\n",
- " validate_args=False,\n",
- ")\n",
- "\n",
"# Simulated data\n",
"num_simulations = 10000\n",
"num_samples = 1000\n",
@@ -554,10 +508,13 @@
"\n",
"# simulate observed data and define ground truth parameters\n",
"num_trials = 10\n",
- "theta_o = proposal.sample((1,))\n",
- "theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.\n",
- "# NOTE: we use the same experimental condition for all trials.\n",
- "x_o = sim_wrapper(theta_o.repeat(num_trials, 1))"
+ "# draw one ground truth parameter\n",
+ "theta_o = proposal.sample((1,))[:, :2]\n",
+ "# draw num_trials many different conditions\n",
+ "conditions = proposal.sample((num_trials,))[:, 2:]\n",
+ "# Theta is repeated for each trial, conditions are different for each trial.\n",
+ "theta_and_conditions_o = torch.cat((theta_o.repeat(num_trials, 1), conditions), dim=1)\n",
+ "x_o = sim_wrapper(theta_and_conditions_o)"
]
},
{
@@ -566,11 +523,15 @@
"source": [
"#### Obtain ground truth posterior via MCMC\n",
"\n",
- "We obtain a ground-truth posterior via MCMC by using the PotentialFunctionProvider.\n",
+ "We obtain a ground-truth posterior via MCMC by using the analytical Binomial-Gamma\n",
+ "likelihood as before. \n",
"\n",
- "For that, we first the define the actual prior, i.e., the distribution over the parameter we want to infer (not the proposal).\n",
+ "For that, we first define the actual prior, i.e., the distribution over the\n",
+ "parameters we want to infer (not the proposal), dropping the uniform prior over\n",
+ "experimental conditions.\n",
"\n",
- "Thus, we leave out the discrete prior over experimental conditions.\n"
+ "Additionally, we pass the entire batch of i.i.d. data `x_o` and matching batch of i.i.d.\n",
+ "`conditions`.\n"
]
},
{
@@ -578,18 +539,10 @@
"execution_count": 14,
"metadata": {},
"outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/janteusen/qode/sbi/sbi/inference/posteriors/mcmc_posterior.py:115: UserWarning: The default value for thinning in MCMC sampling has been changed from 10 to 1. This might cause the results differ from the last benchmark.\n",
- " thin = _process_thin_default(thin)\n"
- ]
- },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "ad169fdca3da40649e6e1c329460e355",
+ "model_id": "ee7db79e47674ed3b2574c26b09eb0b2",
"version_major": 2,
"version_minor": 0
},
@@ -617,8 +570,7 @@
" BinomialGammaPotential(\n",
" prior,\n",
" x_o,\n",
- " concentration_scaling=float(theta_o[0, 2])\n",
- " + 1.0, # add one because the sim_wrapper adds one (see above)\n",
+ " concentration_scaling=conditions,\n",
" ),\n",
" theta_transform=prior_transform,\n",
" proposal=prior,\n",
@@ -630,7 +582,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Train MNLE including experimental conditions\n"
+ "### Train MNLE including experimental conditions\n",
+ "\n",
+ "Next, we use the combined parameters and conditions (`theta`) and the corresponding\n",
+ "simulated data to train `MNLE`.\n"
]
},
{
@@ -642,6 +597,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
+ "/Users/janteusen/qode/sbi/sbi/inference/trainers/base.py:271: UserWarning: Z-scoring these simulation outputs resulted in 4 unique datapoints. Before z-scoring, it had been 19872. This can occur due to numerical inaccuracies when the data covers a large range of values. Consider either setting `z_score_x=False` (but beware that this can be problematic for training the NN) or exclude outliers from your dataset. Note: if you have already set `z_score_x=False`, this warning will still be displayed, but you can ignore it.\n",
+ " warn_if_zscoring_changes_data(x)\n",
"/Users/janteusen/qode/sbi/sbi/neural_nets/factory.py:205: UserWarning: The mixed neural likelihood estimator assumes that x contains continuous data in the first n-1 columns (e.g., reaction times) and categorical data in the last column (e.g., corresponding choices). If this is not the case for the passed `x` do not use this function.\n",
" return model_builders[model](batch_x=batch_x, batch_y=batch_theta, **kwargs)\n"
]
@@ -650,12 +607,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
- " Neural network successfully converged after 60 epochs."
+ " Neural network successfully converged after 75 epochs."
]
}
],
"source": [
- "trainer = MNLE(proposal)\n",
+ "estimator_builder = likelihood_nn(model=\"mnle\", z_score_x=None) # we don't want to z-score the binary data.\n",
+ "trainer = MNLE(proposal, estimator_builder)\n",
"estimator = trainer.append_simulations(theta, x).train()"
]
},
@@ -681,28 +639,147 @@
"outputs": [
{
"data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "4f887f2ba37a4782964e838895cfc39e",
+ "version_major": 2,
+ "version_minor": 0
+ },
"text/plain": [
- "torch.Size([1, 3])"
+ "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00, ?it/s]"
]
},
- "execution_count": 16,
"metadata": {},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
- "theta_o.shape"
+ "# First, we define the potential function for the complete, unconditional MNLE-likelihood\n",
+ "potential_fn = LikelihoodBasedPotential(estimator, proposal)\n",
+ "# Then, we condition on the experimental conditions.\n",
+ "conditioned_potential_fn = potential_fn.condition_on_theta(\n",
+ " conditions, # pass only the conditions, must match the batch of iid data in x_o\n",
+ " dims_global_theta=[0, 1] # pass the dimensions in the original theta that correspond to beta and rho\n",
+ ")\n",
+ "\n",
+ "# Using this potential function, we can now obtain conditional samples.\n",
+ "mnle_posterior = MCMCPosterior(\n",
+ " potential_fn=conditioned_potential_fn, # pass the conditioned potential function\n",
+ " theta_transform=prior_transform,\n",
+ " proposal=prior, # pass the prior, not the proposal.\n",
+ " **mcmc_kwargs\n",
+ ")\n",
+ "conditional_samples = mnle_posterior.sample((num_samples,), x=x_o)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "c2st between true and MNLE posterior: 0.551\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Finally, we can compare the ground truth conditional posterior with the\n",
+ "# MNLE-conditional posterior.\n",
+ "fig, ax = pairplot(\n",
+ " [\n",
+ " prior.sample((1000,)),\n",
+ " true_posterior_samples,\n",
+ " conditional_samples,\n",
+ " ],\n",
+ " points=theta_o,\n",
+ " diag=\"kde\",\n",
+ " upper=\"contour\",\n",
+ " diag_kwargs=dict(bins=100),\n",
+ " upper_kwargs=dict(levels=[0.95]),\n",
+ " fig_kwargs=dict(\n",
+ " points_offdiag=dict(marker=\"*\", markersize=10),\n",
+ " points_colors=[\"k\"],\n",
+ "\n",
+ " ),\n",
+ " labels=[r\"$\\beta$\", r\"$\\rho$\"],\n",
+ " figsize=(6, 6),\n",
+ ")\n",
+ "\n",
+ "plt.sca(ax[1, 1])\n",
+ "plt.legend(\n",
+ " [\"Prior\", \"Reference\", \"MNLE\", r\"$\\theta_o$\"],\n",
+ " frameon=False,\n",
+ " fontsize=12,\n",
+ ");\n",
+ "print(\"c2st between true and MNLE posterior:\", c2st(true_posterior_samples, conditional_samples).item())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "They match accurately, showing that we can indeed post-hoc condition the trained MNLE likelihood on different experimental conditions.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Inference with multiple subjects, trials, and conditions\n",
+ "\n",
+ "Note that we can also do inference for multiple `x_os` (e.g., subjects) with varying\n",
+ "numbers of trials and experimental conditions - all without retraining the MNLE.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "54115a1a0f534028b377fa5aa4661dc4",
+ "model_id": "ed79d139f3804547ab14ea8dcdea856e",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "b821d7229001426f9d541c3dd4fdcead",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Running vectorized MCMC with 20 chains: 0%| | 0/3000 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "abb52c1815e14845acf1ec9c50b91000",
"version_major": 2,
"version_minor": 0
},
@@ -715,36 +792,50 @@
}
],
"source": [
- "# We define the potential function for the complete, unconditional MNLE-likelihood\n",
- "potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o)\n",
"\n",
- "# Then we use the potential to construct the conditional potential function.\n",
- "# Here, we tell the constructor to condition on the last dimension (index 2) by\n",
- "# passing dims_to_sample=[0, 1].\n",
- "conditioned_potential_fn = ConditionedPotential(\n",
- " potential_fn,\n",
- " condition=theta_o,\n",
- " dims_to_sample=[0, 1],\n",
- ")\n",
- "\n",
- "# Using this potential function, we can now obtain conditional samples.\n",
- "mnle_posterior = MCMCPosterior(\n",
- " potential_fn=conditioned_potential_fn,\n",
- " theta_transform=prior_transform,\n",
- " proposal=prior,\n",
- " **mcmc_kwargs\n",
- ")\n",
- "conditional_samples = mnle_posterior.sample((num_samples,), x=x_o)"
+ "torch.manual_seed(42)\n",
+ "num_subjects = 3\n",
+ "num_trials = [10, 20, 30]\n",
+ "# draw one ground-truth parameter set (beta, rho) per subject\n",
+ "theta_o = proposal.sample((num_subjects,))[:, :2]\n",
+ "# Note that each subject gets its own trial conditions, one per trial.\n",
+ "\n",
+ "# Simulate observed data for all subjects and trials.\n",
+ "x_os = []\n",
+ "conditions = []\n",
+ "for i in range(num_subjects):\n",
+ " conditions.append(proposal.sample((num_trials[i],))[:, 2:])\n",
+ " # Theta is repeated for each trial, conditions are different for each trial.\n",
+ " theta_and_condition = torch.cat((theta_o[i].repeat(num_trials[i], 1), conditions[i]), dim=-1)\n",
+ " x_os.append(sim_wrapper(theta_and_condition))\n",
+ "\n",
+ "# loop over subjects (vectorized batched x and batched conditions is not supported yet)\n",
+ "posterior_samples = []\n",
+ "for idx in range(num_subjects):\n",
+ " # condition the potential\n",
+ " conditioned_potential_fn = potential_fn.condition_on_theta(\n",
+ " conditions[idx],\n",
+ " dims_global_theta=[0, 1]\n",
+ " )\n",
+ "\n",
+ " # pass potential to sampler\n",
+ " mnle_posterior = MCMCPosterior(\n",
+ " potential_fn=conditioned_potential_fn, # pass the conditioned potential function\n",
+ " theta_transform=prior_transform,\n",
+ " proposal=prior, # pass the prior, not the proposal.\n",
+ " **mcmc_kwargs\n",
+ " )\n",
+ " posterior_samples.append(mnle_posterior.sample((num_samples,), x=x_os[idx], show_progress_bars=True))"
]
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
""
]
@@ -754,15 +845,11 @@
}
],
"source": [
- "# Finally, we can compare the ground truth conditional posterior with the\n",
- "# MNLE-conditional posterior.\n",
+ "# Plotting all three posteriors in one pairplot.\n",
+ "\n",
"fig, ax = pairplot(\n",
- " [\n",
- " prior.sample((1000,)),\n",
- " true_posterior_samples,\n",
- " conditional_samples,\n",
- " ],\n",
- " points=theta_o,\n",
+ " [prior.sample((1000,))] + posterior_samples,\n",
+ " # points=theta_o,\n",
" diag=\"kde\",\n",
" upper=\"contour\",\n",
" diag_kwargs=dict(bins=100),\n",
@@ -770,13 +857,15 @@
" fig_kwargs=dict(\n",
" points_offdiag=dict(marker=\"*\", markersize=10),\n",
" points_colors=[\"k\"],\n",
+ "\n",
" ),\n",
" labels=[r\"$\\beta$\", r\"$\\rho$\"],\n",
+ " figsize=(10, 10),\n",
")\n",
"\n",
"plt.sca(ax[1, 1])\n",
"plt.legend(\n",
- " [\"Prior\", \"Reference\", \"MNLE\", r\"$\\theta_o$\"],\n",
+ " [\"prior\"] + [f\"Subject {idx+1}\" for idx in range(num_subjects)],\n",
" frameon=False,\n",
" fontsize=12,\n",
");"
@@ -786,13 +875,14 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "They match accurately, showing that we can indeed post-hoc condition the trained MNLE likelihood on different experimental conditions.\n"
+ "Note how the posteriors become narrower as the number of trials increases\n",
+ "(subject 1: 10 trials vs. subject 3: 30 trials)."
]
}
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3.8.13 ('sbi')",
+ "display_name": "sbi_env",
"language": "python",
"name": "python3"
},
@@ -806,12 +896,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.12.4"
- },
- "vscode": {
- "interpreter": {
- "hash": "9ef9b53a5ce850816b9705a866e49207a37a04a71269aa157d9f9ab944ea42bf"
- }
+ "version": "3.10.13"
}
},
"nbformat": 4,
diff --git a/tutorials/example_01_utils.py b/tutorials/example_01_utils.py
new file mode 100644
index 000000000..620058d05
--- /dev/null
+++ b/tutorials/example_01_utils.py
@@ -0,0 +1,60 @@
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch.distributions import Binomial, Distribution, InverseGamma
+
+from sbi.inference.potentials.base_potential import BasePotential
+from sbi.utils.torchutils import atleast_2d
+
+
+class BinomialGammaPotential(BasePotential):
+ """Binomial-Gamma potential for mixed data."""
+
+ def __init__(
+ self,
+ prior: Distribution,
+ x_o: Tensor,
+ concentration_scaling: Union[Tensor, float] = 1.0,
+ device="cpu",
+ ):
+ super().__init__(prior, x_o, device)
+
+ # concentration_scaling needs to be a float or match the batch size
+ if isinstance(concentration_scaling, Tensor):
+ num_trials = x_o.shape[0]
+ assert concentration_scaling.shape[0] == num_trials
+
+ # Reshape to match convention (batch_size, num_trials, *event_shape)
+ concentration_scaling = concentration_scaling.reshape(1, num_trials, -1)
+
+ self.concentration_scaling = concentration_scaling
+
+ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
+ theta = atleast_2d(theta)
+
+ with torch.set_grad_enabled(track_gradients):
+ iid_ll = self.iid_likelihood(theta)
+
+ return iid_ll + self.prior.log_prob(theta)
+
+ def iid_likelihood(self, theta: Tensor) -> Tensor:
+ batch_size = theta.shape[0]
+ num_trials = self.x_o.shape[0]
+ theta = theta.reshape(batch_size, 1, -1)
+ beta, rho = theta[:, :, :1], theta[:, :, 1:]
+
+ # vectorized
+ logprob_choices = Binomial(probs=rho).log_prob(
+ self.x_o[:, 1:].reshape(1, num_trials, -1)
+ )
+
+ logprob_rts = InverseGamma(
+ concentration=self.concentration_scaling * torch.ones_like(beta),
+ rate=beta,
+ ).log_prob(self.x_o[:, :1].reshape(1, num_trials, -1))
+
+ joint_likelihood = (logprob_choices + logprob_rts).squeeze()
+
+ assert joint_likelihood.shape == torch.Size([theta.shape[0], self.x_o.shape[0]])
+ return joint_likelihood.sum(1)