Commit: Give all DensityEstimators an input_shape

michaeldeistler committed Apr 23, 2024
1 parent 0b5f931 commit f0aa0f9

Showing 26 changed files with 161 additions and 111 deletions.
2 changes: 1 addition & 1 deletion sbi/inference/abc/mcabc.py
@@ -142,7 +142,7 @@ def simulator(theta):
         x = simulator(theta)
 
         # Infer shape of x to test and set x_o.
-        self.x_shape = x[0].unsqueeze(0).shape
+        self.x_shape = x[0].shape
         self.x_o = process_x(x_o, self.x_shape)
 
         distances = self.distance(self.x_o, x)
2 changes: 1 addition & 1 deletion sbi/inference/abc/smcabc.py
@@ -330,7 +330,7 @@ def _set_xo_and_sample_initial_population(
         x = self._simulate_with_budget(theta)
 
         # Infer x shape to test and set x_o.
-        self.x_shape = x[0].unsqueeze(0).shape
+        self.x_shape = x[0].shape
         self.x_o = process_x(x_o, self.x_shape)
 
         distances = self.distance(self.x_o, x)
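Both ABC files make the same change. A minimal sketch of the shape semantics (the tensors here are illustrative, not from the diff): dropping `unsqueeze(0)` means `x_shape` now stores the event shape of a single simulator output, without a leading batch dimension.

import torch

# Illustrative batch of 10 simulator outputs, each a 3-dim summary vector.
x = torch.randn(10, 3)

old_x_shape = x[0].unsqueeze(0).shape  # (1, 3): event shape plus batch dim
new_x_shape = x[0].shape               # (3,): event shape only

assert old_x_shape == torch.Size([1, 3])
assert new_x_shape == torch.Size([3])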
15 changes: 11 additions & 4 deletions sbi/inference/posteriors/base_posterior.py
@@ -4,6 +4,7 @@
 import inspect
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict, Optional, Union
+from warnings import warn
 
 import torch
 import torch.distributions.transforms as torch_tf
@@ -42,8 +43,15 @@ def __init__(
                 Allows to perform, e.g. MCMC in unconstrained space.
             device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                 `potential_fn.device` is used.
-            x_shape: Shape of the observed data.
+            x_shape: Deprecated, should not be passed.
         """
+        if x_shape is not None:
+            warn(
+                "x_shape is not None. However, passing x_shape to the `Posterior` is "
+                "deprecated and will be removed in a future release of `sbi`.",
+                stacklevel=2,
+            )
+
         if not isinstance(potential_fn, BasePotential):
             kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
             for key in ["theta", "x_o"]:
@@ -74,7 +82,6 @@ def __init__(
 
         self._map = None
         self._purpose = ""
-        self._x_shape = x_shape
 
         # If the sampler interface (#573) is used, the user might have passed `x_o`
         # already to the potential function builder. If so, this `x_o` will be used
@@ -146,7 +153,7 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
             `NeuralPosterior` that will use a default `x` when not explicitly passed.
         """
         self._x = process_x(
-            x, x_shape=self._x_shape, allow_iid_x=self.potential_fn.allow_iid_x
+            x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
         ).to(self._device)
         self._map = None
         return self
@@ -156,7 +163,7 @@ def _x_else_default_x(self, x: Optional[Array]) -> Tensor:
             # New x, reset posterior sampler.
             self._posterior_sampler = None
             return process_x(
-                x, x_shape=self._x_shape, allow_iid_x=self.potential_fn.allow_iid_x
+                x, x_event_shape=None, allow_iid_x=self.potential_fn.allow_iid_x
             )
         elif self.default_x is None:
            raise ValueError(
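The new guard can be exercised in isolation. A minimal sketch with a stand-in class (`ToyPosterior` is hypothetical, not sbi's actual NeuralPosterior):

import warnings

class ToyPosterior:
    def __init__(self, x_shape=None):
        # Mirror of the guard added above: warn if the deprecated
        # argument is passed, but otherwise ignore it.
        if x_shape is not None:
            warnings.warn(
                "x_shape is deprecated and will be removed in a future "
                "release of `sbi`.",
                stacklevel=2,
            )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ToyPosterior(x_shape=(1, 3))  # triggers the warning
    assert "deprecated" in str(caught[0].message)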
24 changes: 10 additions & 14 deletions sbi/inference/posteriors/direct_posterior.py
@@ -52,8 +52,7 @@ def __init__(
                 the proposal at every iteration.
             device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                 `potential_fn.device` is used.
-            x_shape: Shape of a single simulator output. If passed, it is used to check
-                the shape of the observed data and give a descriptive error.
+            x_shape: Deprecated, should not be passed.
             enable_transform: Whether to transform parameters to unconstrained space
                 during MAP optimization. When False, an identity transform will be
                 returned for `theta_transform`.
@@ -106,12 +105,9 @@ def sample(
 
         num_samples = torch.Size(sample_shape).numel()
         x = self._x_else_default_x(x)
-
-        # [1:] because we remove batch dimension for `reshape_to_batch_event`.
-        # Note: This line will break if `x_shape` is `None` and if `x` is passed
-        # without batch dimension.
-        x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]
-        x = reshape_to_batch_event(x, event_shape=x_shape)
+        x = reshape_to_batch_event(
+            x, event_shape=self.posterior_estimator.condition_shape
+        )
 
         max_sampling_batch_size = (
             self.max_sampling_batch_size
@@ -172,14 +168,13 @@ def log_prob(
         """
         x = self._x_else_default_x(x)
 
-        # [1:] to remove batch dimension for `reshape_to_sample_batch_event`.
-        x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]
-
         theta = ensure_theta_batched(torch.as_tensor(theta))
         theta_density_estimator = reshape_to_sample_batch_event(
             theta, theta.shape[1:], leading_is_sample=True
         )
-        x_density_estimator = reshape_to_batch_event(x, x_shape)
+        x_density_estimator = reshape_to_batch_event(
+            x, event_shape=self.posterior_estimator.condition_shape
+        )
         assert (
             x_density_estimator.shape[0] == 1
         ), ".log_prob() supports only `batchsize == 1`."
@@ -244,7 +239,6 @@ def leakage_correction(
 
         def acceptance_at(x: Tensor) -> Tensor:
             # [1:] to remove batch-dimension for `reshape_to_batch_event`.
-            x_shape = self._x_shape[1:] if self._x_shape is not None else x.shape[1:]
             return accept_reject_sample(
                 proposal=self.posterior_estimator,
                 accept_reject_fn=lambda theta: within_support(self.prior, theta),
@@ -253,7 +247,9 @@ def acceptance_at(x: Tensor) -> Tensor:
                 sample_for_correction_factor=True,
                 max_sampling_batch_size=rejection_sampling_batch_size,
                 proposal_sampling_kwargs={
-                    "condition": reshape_to_batch_event(x, x_shape)
+                    "condition": reshape_to_batch_event(
+                        x, event_shape=self.posterior_estimator.condition_shape
+                    )
                 },
             )[1]
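The pattern replacing `self._x_shape` throughout this file: the event shape of the condition now comes from the density estimator itself, so it no longer matters whether `x` arrives with or without a batch dimension. A simplified stand-in for sbi's `reshape_to_batch_event` helper, showing only the intent:

import torch

def reshape_to_batch_event(x: torch.Tensor, event_shape: torch.Size) -> torch.Tensor:
    # Collapse all leading dims into one batch dim in front of event_shape.
    return x.reshape(-1, *event_shape)

# Assumed: an estimator exposing its condition event shape.
condition_shape = torch.Size([3])

x_no_batch = torch.randn(3)    # observation without batch dim
x_batched = torch.randn(5, 3)  # batch of 5 observations

print(reshape_to_batch_event(x_no_batch, condition_shape).shape)  # (1, 3)
print(reshape_to_batch_event(x_batched, condition_shape).shape)   # (5, 3)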
6 changes: 2 additions & 4 deletions sbi/inference/posteriors/ensemble_posterior.py
@@ -96,7 +96,7 @@ def __init__(
             potential_fn=potential_fn,
             theta_transform=theta_transform,
             device=device,
-            x_shape=self.posteriors[0]._x_shape,
+            x_shape=None,
         )
 
     def ensure_same_device(self, posteriors: List) -> str:
@@ -242,9 +242,7 @@ def set_default_x(self, x: Tensor) -> "NeuralPosterior":
             `EnsemblePosterior` that will use a default `x` when not explicitly
             passed.
         """
-        self._x = process_x(
-            x, x_shape=self._x_shape, allow_iid_x=self.potential_fn.allow_iid_x
-        ).to(self._device)
+        self._x = x.to(self._device)
 
         for posterior in self.posteriors:
             posterior.set_default_x(x)
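A sketch of the delegation this hunk introduces (toy classes, not sbi's API): the ensemble stores `x` without an ensemble-level shape check and lets each member posterior validate it in its own `set_default_x`.

import torch

class ToyMember:
    def set_default_x(self, x: torch.Tensor) -> None:
        # Each member would run its own validation (e.g. process_x) here.
        self._x = x

class ToyEnsemble:
    def __init__(self, posteriors, device="cpu"):
        self.posteriors = posteriors
        self._device = device

    def set_default_x(self, x: torch.Tensor) -> "ToyEnsemble":
        self._x = x.to(self._device)    # no ensemble-level process_x anymore
        for posterior in self.posteriors:
            posterior.set_default_x(x)  # members validate individually
        return self

ensemble = ToyEnsemble([ToyMember(), ToyMember()])
ensemble.set_default_x(torch.randn(1, 3))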
3 changes: 1 addition & 2 deletions sbi/inference/posteriors/importance_posterior.py
@@ -52,8 +52,7 @@ def __init__(
                 proposal at every iteration.
             device: Device on which to sample, e.g., "cpu", "cuda" or "cuda:0". If
                 None, `potential_fn.device` is used.
-            x_shape: Shape of a single simulator output. If passed, it is used to check
-                the shape of the observed data and give a descriptive error.
+            x_shape: Deprecated, should not be passed.
         """
         super().__init__(
             potential_fn,
3 changes: 1 addition & 2 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -98,8 +98,7 @@ def __init__(
                 (e.g. Linux and macOS, not Windows).
             device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                 `potential_fn.device` is used.
-            x_shape: Shape of a single simulator output. If passed, it is used to check
-                the shape of the observed data and give a descriptive error.
+            x_shape: Deprecated, should not be passed.
         """
         if method == "slice":
             warn(
3 changes: 1 addition & 2 deletions sbi/inference/posteriors/rejection_posterior.py
@@ -49,8 +49,7 @@ def __init__(
             m: Multiplier to the `potential_fn / proposal` ratio.
             device: Training device, e.g., "cpu", "cuda" or "cuda:0". If None,
                 `potential_fn.device` is used.
-            x_shape: Shape of a single simulator output. If passed, it is used to check
-                the shape of the observed data and give a descriptive error.
+            x_shape: Deprecated, should not be passed.
         """
         super().__init__(
             potential_fn,
3 changes: 1 addition & 2 deletions sbi/inference/posteriors/vi_posterior.py
@@ -91,8 +91,7 @@ def __init__(
                 typically cover all modes (`fKL`, `IW`, `alpha` for alpha < 1).
             device: Training device, e.g., `cpu`, `cuda` or `cuda:0`. We will ensure
                 that all other objects are also on this device.
-            x_shape: Shape of a single simulator output. If passed, it is used to check
-                the shape of the observed data and give a descriptive error.
+            x_shape: Deprecated, should not be passed.
             parameters: List of parameters of the variational posterior. This is only
                 required for user-defined q i.e. if q does not have a `parameters`
                 attribute.
9 changes: 4 additions & 5 deletions sbi/inference/snle/mnle.py
@@ -171,15 +171,13 @@ def build_posterior(
                 proposal=prior,
                 method=mcmc_method,
                 device=device,
-                x_shape=self._x_shape,
                 **mcmc_parameters or {},
             )
         elif sample_with == "rejection":
             self._posterior = RejectionPosterior(
                 potential_fn=potential_fn,
                 proposal=prior,
                 device=device,
-                x_shape=self._x_shape,
                 **rejection_sampling_parameters or {},
             )
         elif sample_with == "vi":
@@ -189,7 +187,6 @@ def build_posterior(
                 prior=prior,  # type: ignore
                 vi_method=vi_method,
                 device=device,
-                x_shape=self._x_shape,
                 **vi_parameters or {},
             )
         else:
@@ -209,6 +206,8 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
         Returns:
             Negative log prob.
         """
-        theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
-        x = reshape_to_sample_batch_event(x, event_shape=self._x_shape[1:])
+        theta = reshape_to_batch_event(
+            theta, event_shape=self._neural_net.condition_shape
+        )
+        x = reshape_to_sample_batch_event(x, event_shape=self._neural_net.input_shape)
         return -self._neural_net.log_prob(x, condition=theta)
14 changes: 6 additions & 8 deletions sbi/inference/snle/snle_base.py
@@ -176,11 +176,10 @@ def train(
             theta[self.train_indices].to("cpu"),
             x[self.train_indices].to("cpu"),
         )
-        self._x_shape = x_shape_from_simulation(x.to("cpu"))
-        del theta, x
         assert (
-            len(self._x_shape) < 3
+            len(x_shape_from_simulation(x.to("cpu"))) < 3
         ), "SNLE cannot handle multi-dimensional simulator output."
+        del theta, x
 
         self._neural_net.to(self._device)
         if not resume_training:
@@ -335,15 +334,13 @@ def build_posterior(
                 proposal=prior,
                 method=mcmc_method,
                 device=device,
-                x_shape=self._x_shape,
                 **mcmc_parameters or {},
             )
         elif sample_with == "rejection":
             self._posterior = RejectionPosterior(
                 potential_fn=potential_fn,
                 proposal=prior,
                 device=device,
-                x_shape=self._x_shape,
                 **rejection_sampling_parameters or {},
             )
         elif sample_with == "vi":
@@ -353,7 +350,6 @@ def build_posterior(
                 prior=prior,  # type: ignore
                 vi_method=vi_method,
                 device=device,
-                x_shape=self._x_shape,
                 **vi_parameters or {},
             )
         else:
@@ -370,8 +366,10 @@ def _loss(self, theta: Tensor, x: Tensor) -> Tensor:
         Returns:
             Negative log prob.
         """
-        theta = reshape_to_batch_event(theta, event_shape=theta.shape[1:])
+        theta = reshape_to_batch_event(
+            theta, event_shape=self._neural_net.condition_shape
+        )
         x = reshape_to_sample_batch_event(
-            x, event_shape=self._x_shape[1:], leading_is_sample=False
+            x, event_shape=self._neural_net.input_shape, leading_is_sample=False
        )
        return self._neural_net.loss(x, condition=theta)
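The likelihood loss now takes its event shapes from the estimator rather than a stored `self._x_shape`. A sketch of the (sample, batch, event) layout this implies; the shapes below are illustrative, not from the diff:

import torch

# Assumed estimator shapes: x is the estimated variable, theta the condition.
input_shape = torch.Size([3])      # would come from self._neural_net.input_shape
condition_shape = torch.Size([2])  # would come from self._neural_net.condition_shape

batch = 8
theta = torch.randn(batch, *condition_shape)
x = torch.randn(batch, *input_shape)

# Condition: (batch, *condition_shape).
theta_be = theta.reshape(batch, *condition_shape)
# Estimated variable: (sample, batch, *input_shape); the training batch
# provides one x per theta, so the sample dim is 1 (leading_is_sample=False).
x_sbe = x.reshape(1, batch, *input_shape)

print(theta_be.shape, x_sbe.shape)  # torch.Size([8, 2]) torch.Size([1, 8, 3])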
4 changes: 2 additions & 2 deletions sbi/inference/snpe/snpe_a.py
@@ -399,7 +399,7 @@ def __init__(
         """
         # Call nn.Module's constructor.
 
-        super().__init__(flow, flow._condition_shape)
+        super().__init__(flow, flow.input_shape, flow.condition_shape)
 
         self._neural_net = flow
         self._prior = prior
@@ -480,7 +480,7 @@ def sample(self, sample_shape: torch.Size, condition: Tensor, **kwargs) -> Tensor:
         # \tilde{p} has already been observed. To analytically calculate the
         # log-prob of the Gaussian, we first need to compute the mixture components.
         num_samples = torch.Size(sample_shape).numel()
-        condition_ndim = len(self._condition_shape)
+        condition_ndim = len(self.condition_shape)
         batch_size = condition.shape[:-condition_ndim]
         batch_size = torch.Size(batch_size).numel()
         return self._sample_approx_posterior_mog(num_samples, condition, batch_size)
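This is the change the commit title refers to: density estimators are now constructed with both an `input_shape` and a `condition_shape`. A toy sketch of the pattern (not sbi's actual base class):

import torch
from torch import nn

class ToyDensityEstimator(nn.Module):
    """Toy illustration: estimators carry both event shapes."""

    def __init__(self, net: nn.Module, input_shape: torch.Size,
                 condition_shape: torch.Size) -> None:
        super().__init__()
        self.net = net
        self.input_shape = input_shape          # event shape of estimated variable
        self.condition_shape = condition_shape  # event shape of the condition

estimator = ToyDensityEstimator(nn.Identity(), torch.Size([2]), torch.Size([3]))
print(len(estimator.condition_shape))  # condition_ndim, as used in sample() above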
18 changes: 8 additions & 10 deletions sbi/inference/snpe/snpe_base.py
@@ -36,7 +36,6 @@
     test_posterior_net_for_multi_d_x,
     validate_theta_and_x,
     warn_if_zscoring_changes_data,
-    x_shape_from_simulation,
 )
 from sbi.utils.sbiutils import ImproperEmpirical, mask_sims_from_prior
 
@@ -320,10 +319,11 @@ def default_calibration_kernel(x):
             theta[self.train_indices].to("cpu"),
             x[self.train_indices].to("cpu"),
         )
-        self._x_shape = x_shape_from_simulation(x.to("cpu"))
 
-        theta = reshape_to_sample_batch_event(theta.to("cpu"), theta.shape[1:])
-        x = reshape_to_batch_event(x.to("cpu"), self._x_shape[1:])
+        theta = reshape_to_sample_batch_event(
+            theta.to("cpu"), self._neural_net.input_shape
+        )
+        x = reshape_to_batch_event(x.to("cpu"), self._neural_net.condition_shape)
         test_posterior_net_for_multi_d_x(self._neural_net, theta, x)
 
         del theta, x
@@ -503,7 +503,6 @@ def build_posterior(
             self._posterior = DirectPosterior(
                 posterior_estimator=posterior_estimator,  # type: ignore
                 prior=prior,
-                x_shape=self._x_shape,
                 device=device,
                 **direct_sampling_parameters or {},
             )
@@ -520,7 +519,6 @@ def build_posterior(
             self._posterior = RejectionPosterior(
                 potential_fn=potential_fn,
                 device=device,
-                x_shape=self._x_shape,
                 **rejection_sampling_parameters,
             )
         elif sample_with == "mcmc":
@@ -530,7 +528,6 @@ def build_posterior(
                 proposal=prior,
                 method=mcmc_method,
                 device=device,
-                x_shape=self._x_shape,
                 **mcmc_parameters or {},
             )
         elif sample_with == "vi":
@@ -540,7 +537,6 @@ def build_posterior(
                 prior=prior,  # type: ignore
                 vi_method=vi_method,
                 device=device,
-                x_shape=self._x_shape,
                 **vi_parameters or {},
             )
         else:
@@ -582,8 +578,10 @@ def _loss(
             distribution different from the prior.
         """
         if self._round == 0 or force_first_round_loss:
-            theta = reshape_to_sample_batch_event(theta, event_shape=theta.shape[1:])
-            x = reshape_to_batch_event(x, event_shape=self._x_shape[1:])
+            theta = reshape_to_sample_batch_event(
+                theta, event_shape=self._neural_net.input_shape
+            )
+            x = reshape_to_batch_event(x, event_shape=self._neural_net.condition_shape)
             # Use posterior log prob (without proposal correction) for first round.
             loss = self._neural_net.loss(theta, x)
         else:
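Note the mirrored roles compared to SNLE above: for posterior estimation, `theta` is the estimated variable (`input_shape`) and `x` is the condition (`condition_shape`). A shape-only sketch with illustrative sizes:

import torch

input_shape = torch.Size([2])      # self._neural_net.input_shape (theta event)
condition_shape = torch.Size([3])  # self._neural_net.condition_shape (x event)

batch = 8
theta = torch.randn(batch, *input_shape)
x = torch.randn(batch, *condition_shape)

# Estimated variable: (sample, batch, *input_shape), one theta per x.
theta_sbe = theta.reshape(1, batch, *input_shape)
# Condition: (batch, *condition_shape).
x_be = x.reshape(batch, *condition_shape)

print(theta_sbe.shape, x_be.shape)  # torch.Size([1, 8, 2]) torch.Size([8, 3])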
8 changes: 5 additions & 3 deletions sbi/inference/snpe/snpe_c.py
@@ -355,7 +355,9 @@ def _log_prob_proposal_posterior_atomic(
         atomic_theta = reshape_to_sample_batch_event(
             atomic_theta, atomic_theta.shape[1:]
         )
-        repeated_x = reshape_to_batch_event(repeated_x, self._x_shape[1:])
+        repeated_x = reshape_to_batch_event(
+            repeated_x, self._neural_net.condition_shape
+        )
         log_prob_posterior = self._neural_net.log_prob(atomic_theta, repeated_x)
         utils.assert_all_finite(log_prob_posterior, "posterior eval")
         log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)
@@ -371,8 +373,8 @@ def _log_prob_proposal_posterior_atomic(
 
         # XXX This evaluates the posterior on _all_ prior samples
         if self._use_combined_loss:
-            theta = reshape_to_sample_batch_event(theta, theta.shape[1:])
-            x = reshape_to_batch_event(x, self._x_shape[1:])
+            theta = reshape_to_sample_batch_event(theta, self._neural_net.input_shape)
+            x = reshape_to_batch_event(x, self._neural_net.condition_shape)
             log_prob_posterior_non_atomic = self._neural_net.log_prob(theta, x)
             # squeeze to remove sample dimension, which is always one during the loss
             # evaluation of `SNPE_C` (because we have one theta vector per x vector).
(Diff truncated; the remaining changed files are not shown.)