batched sampling for score-based posteriors
gmoss13 committed Jan 29, 2025
1 parent 362fb64 commit 9f24294
Showing 4 changed files with 65 additions and 12 deletions.
59 changes: 57 additions & 2 deletions sbi/inference/posteriors/score_posterior.py
@@ -177,6 +177,7 @@ def sample(
def _sample_via_diffusion(
self,
sample_shape: Shape = torch.Size(),
x: Optional[Tensor] = None,
predictor: Union[str, Predictor] = "euler_maruyama",
corrector: Optional[Union[str, Corrector]] = None,
predictor_params: Optional[Dict] = None,
@@ -313,13 +314,67 @@ def sample_batched(
self,
sample_shape: torch.Size,
x: Tensor,
predictor: Union[str, Predictor] = "euler_maruyama",
corrector: Optional[Union[str, Corrector]] = None,
predictor_params: Optional[Dict] = None,
corrector_params: Optional[Dict] = None,
steps: int = 500,
ts: Optional[Tensor] = None,
max_sampling_batch_size: int = 10000,
show_progress_bars: bool = True,
) -> Tensor:
raise NotImplementedError(
"Batched sampling is not implemented for ScorePosterior."
num_samples = torch.Size(sample_shape).numel()
x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
condition_dim = len(self.score_estimator.condition_shape)
batch_shape = x.shape[:-condition_dim]
batch_size = batch_shape.numel()
self.potential_fn.set_x(x)

max_sampling_batch_size = (
self.max_sampling_batch_size
if max_sampling_batch_size is None
else max_sampling_batch_size
)

if self.sample_with == "ode":
samples = rejection.accept_reject_sample(
proposal=self.sample_via_zuko,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
num_xos=batch_size,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs={"x": x},
)[0]
samples = samples.reshape(
sample_shape + batch_shape + self.score_estimator.input_shape
)
elif self.sample_with == "sde":
proposal_sampling_kwargs = {
"predictor": predictor,
"corrector": corrector,
"predictor_params": predictor_params,
"corrector_params": corrector_params,
"steps": steps,
"ts": ts,
"max_sampling_batch_size": max_sampling_batch_size,
"show_progress_bars": show_progress_bars,
}
samples = rejection.accept_reject_sample(
proposal=self._sample_via_diffusion,
accept_reject_fn=lambda theta: within_support(self.prior, theta),
num_samples=num_samples,
num_xos=batch_size,
show_progress_bars=show_progress_bars,
max_sampling_batch_size=max_sampling_batch_size,
proposal_sampling_kwargs=proposal_sampling_kwargs,
)[0]
samples = samples.reshape(
sample_shape + batch_shape + self.score_estimator.input_shape
)

return samples

def map(
self,
x: Optional[Tensor] = None,
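A minimal usage sketch of the new sample_batched path above, assuming a trained
score-based posterior (the variable posterior and the observation shape (3,)
are hypothetical):

import torch

# Batch of 8 conditions, each of shape (3,).
x_batch = torch.randn(8, 3)

# 100 samples per condition. With sample_with == "sde" this routes through
# _sample_via_diffusion wrapped in rejection.accept_reject_sample; with "ode"
# it goes through sample_via_zuko instead.
samples = posterior.sample_batched(torch.Size([100]), x=x_batch, steps=500)

# Output shape: sample_shape + batch_shape + input_shape, e.g. (100, 8, theta_dim).
print(samples.shape)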
3 changes: 0 additions & 3 deletions sbi/inference/potentials/score_based_potential.py
@@ -86,9 +86,6 @@ def set_x(
x_density_estimator = reshape_to_batch_event(
self.x_o, event_shape=self.score_estimator.condition_shape
)
assert x_density_estimator.shape[0] == 1, (
"PosteriorScoreBasedPotential supports only x batchsize of 1`."
)
# For large number of evals, we want a high-tolerance flow.
# This flow will be used mainly for MAP calculations, hence we want to save
# it instead of rebuilding it every time.
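Dropping this assert is what allows a batch of conditions to be set on the
potential at once; a hedged sketch of what the change enables (potential_fn
and the shapes are assumed):

import torch

# Previously any x with batch size > 1 tripped the assert removed above.
x_batch = torch.randn(8, 3)  # 8 conditions, condition_shape assumed to be (3,)
potential_fn.set_x(x_batch)  # now accepted; downstream code handles the batch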
5 changes: 3 additions & 2 deletions sbi/samplers/rejection/rejection.py
@@ -190,6 +190,7 @@ def accept_reject_sample(
proposal: Callable,
accept_reject_fn: Callable,
num_samples: int,
num_xos: int = 1,
show_progress_bars: bool = False,
warn_acceptance: float = 0.01,
sample_for_correction_factor: bool = False,
@@ -218,6 +219,8 @@
rejected. Must take a batch of parameters and return a boolean tensor which
indicates which parameters get accepted.
num_samples: Desired number of samples.
num_xos: Number of conditions for batched sampling (currently only a single
batch dimension for the condition is supported).
show_progress_bars: Whether to show a progressbar during sampling.
warn_acceptance: A minimum acceptance rate under which to warn about slowness.
sample_for_correction_factor: True if this function was called by
@@ -263,8 +266,6 @@
# But this would require giving the method the condition_shape explicitly...
if "condition" in proposal_sampling_kwargs:
num_xos = proposal_sampling_kwargs["condition"].shape[0]
else:
num_xos = 1

accepted = [[] for _ in range(num_xos)]
acceptance_rate = torch.full((num_xos,), float("Nan"))
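The per-condition bookkeeping introduced here can be sketched in isolation (a
toy stand-in for accept_reject_sample, not the library implementation; the
proposal is assumed to return one candidate batch per condition):

import torch

def toy_accept_reject(proposal, accept_fn, num_samples, num_xos=1):
    # One accept list and one acceptance rate per condition, as in the diff.
    accepted = [[] for _ in range(num_xos)]
    acceptance_rate = torch.full((num_xos,), float("nan"))
    while min(sum(t.shape[0] for t in a) for a in accepted) < num_samples:
        candidates = proposal(num_samples)          # (num_samples, num_xos, dim)
        mask = accept_fn(candidates)                # (num_samples, num_xos) bools
        for i in range(num_xos):
            accepted[i].append(candidates[mask[:, i], i])
        acceptance_rate = mask.float().mean(dim=0)  # per-condition rate
    return torch.stack([torch.cat(a)[:num_samples] for a in accepted], dim=1)

# E.g. a standard-normal proposal for 2 conditions, accepting positive thetas:
samples = toy_accept_reject(
    lambda n: torch.randn(n, 2, 3), lambda c: (c > 0).all(-1), 50, num_xos=2
)  # -> shape (50, 2, 3)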
10 changes: 5 additions & 5 deletions sbi/samplers/score/score.py
@@ -101,11 +101,11 @@ def initialize(self, num_samples: int) -> Tensor:
# batched sampling setting with a flag.
# TODO: this fixes the iid setting shape problems, but iid inference via
# iid_bridge is not accurate.
# num_batch = self.batch_shape.numel()
# init_shape = (num_batch, num_samples) + self.input_shape
init_shape = (
num_samples,
) + self.input_shape # just use num_samples, not num_batch
num_batch = self.batch_shape.numel()
init_shape = (num_samples, num_batch) + self.input_shape
# init_shape = (
# num_samples,
# ) + self.input_shape # just use num_samples, not num_batch
# NOTE: for the IID setting we might need to scale the noise with iid batch
# size, as in equation (7) in the paper.
eps = torch.randn(init_shape, device=self.device)
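In isolation, the new initialization draws one diffusion starting point per
(sample, condition) pair instead of per sample only (shapes below are assumed
for illustration):

import torch

num_samples, num_batch = 100, 8  # 8 conditions, i.e. batch_shape.numel() == 8
input_shape = (2,)               # event shape of theta

# Old: init_shape = (num_samples,) + input_shape            -> (100, 2)
# New: init_shape = (num_samples, num_batch) + input_shape  -> (100, 8, 2)
eps = torch.randn((num_samples, num_batch) + input_shape)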
