diff --git a/sbi/inference/posteriors/score_posterior.py b/sbi/inference/posteriors/score_posterior.py
index bbf5aa812..034bdefda 100644
--- a/sbi/inference/posteriors/score_posterior.py
+++ b/sbi/inference/posteriors/score_posterior.py
@@ -177,6 +177,7 @@ def sample(
     def _sample_via_diffusion(
         self,
         sample_shape: Shape = torch.Size(),
+        x: Optional[Tensor] = None,
         predictor: Union[str, Predictor] = "euler_maruyama",
         corrector: Optional[Union[str, Corrector]] = None,
         predictor_params: Optional[Dict] = None,
@@ -313,13 +314,67 @@ def sample_batched(
         self,
         sample_shape: torch.Size,
         x: Tensor,
+        predictor: Union[str, Predictor] = "euler_maruyama",
+        corrector: Optional[Union[str, Corrector]] = None,
+        predictor_params: Optional[Dict] = None,
+        corrector_params: Optional[Dict] = None,
+        steps: int = 500,
+        ts: Optional[Tensor] = None,
         max_sampling_batch_size: int = 10000,
         show_progress_bars: bool = True,
     ) -> Tensor:
-        raise NotImplementedError(
-            "Batched sampling is not implemented for ScorePosterior."
+        num_samples = torch.Size(sample_shape).numel()
+        x = reshape_to_batch_event(x, self.score_estimator.condition_shape)
+        condition_dim = len(self.score_estimator.condition_shape)
+        batch_shape = x.shape[:-condition_dim]
+        batch_size = batch_shape.numel()
+        self.potential_fn.set_x(x)
+
+        max_sampling_batch_size = (
+            self.max_sampling_batch_size
+            if max_sampling_batch_size is None
+            else max_sampling_batch_size
         )
+        if self.sample_with == "ode":
+            samples = rejection.accept_reject_sample(
+                proposal=self.sample_via_zuko,
+                accept_reject_fn=lambda theta: within_support(self.prior, theta),
+                num_samples=num_samples,
+                num_xos=batch_size,
+                show_progress_bars=show_progress_bars,
+                max_sampling_batch_size=max_sampling_batch_size,
+                proposal_sampling_kwargs={"x": x},
+            )[0]
+            samples = samples.reshape(
+                sample_shape + batch_shape + self.score_estimator.input_shape
+            )
+        elif self.sample_with == "sde":
+            proposal_sampling_kwargs = {
+                "predictor": predictor,
+                "corrector": corrector,
+                "predictor_params": predictor_params,
+                "corrector_params": corrector_params,
+                "steps": steps,
+                "ts": ts,
+                "max_sampling_batch_size": max_sampling_batch_size,
+                "show_progress_bars": show_progress_bars,
+            }
+            samples = rejection.accept_reject_sample(
+                proposal=self._sample_via_diffusion,
+                accept_reject_fn=lambda theta: within_support(self.prior, theta),
+                num_samples=num_samples,
+                num_xos=batch_size,
+                show_progress_bars=show_progress_bars,
+                max_sampling_batch_size=max_sampling_batch_size,
+                proposal_sampling_kwargs=proposal_sampling_kwargs,
+            )[0]
+            samples = samples.reshape(
+                sample_shape + batch_shape + self.score_estimator.input_shape
+            )
+
+        return samples
+
     def map(
         self,
         x: Optional[Tensor] = None,
diff --git a/sbi/inference/potentials/score_based_potential.py b/sbi/inference/potentials/score_based_potential.py
index cbc1df73f..4b2fe3706 100644
--- a/sbi/inference/potentials/score_based_potential.py
+++ b/sbi/inference/potentials/score_based_potential.py
@@ -86,9 +86,6 @@ def set_x(
         x_density_estimator = reshape_to_batch_event(
             self.x_o, event_shape=self.score_estimator.condition_shape
         )
-        assert x_density_estimator.shape[0] == 1, (
-            "PosteriorScoreBasedPotential supports only x batchsize of 1`."
-        )
        # For large number of evals, we want a high-tolerance flow.
        # This flow will be used mainly for MAP calculations, hence we want to save
        # it instead of rebuilding it every time.
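A minimal usage sketch of the new batched sampling path, for review purposes only. It assumes the standard `sbi` NPSE training pipeline; the prior, the toy simulator, the observation batch `x_batch`, and all shapes are illustrative rather than taken from the patch:

```python
# Hypothetical usage sketch (not part of the patch). Assumes sbi's NPSE API.
import torch
from sbi.inference import NPSE
from sbi.utils import BoxUniform

# Illustrative setup: 2-d uniform prior and a trivial Gaussian simulator.
prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
theta = prior.sample((500,))
x = theta + 0.1 * torch.randn_like(theta)

inference = NPSE(prior=prior)
inference.append_simulations(theta, x).train(max_num_epochs=2)
posterior = inference.build_posterior(sample_with="sde")

# The new batched call: 100 samples for each of 4 observations in one pass.
x_batch = torch.randn(4, 2)
samples = posterior.sample_batched(torch.Size([100]), x=x_batch)
print(samples.shape)  # expected (100, 4, 2): sample_shape + batch_shape + event
```

Per the diff, the returned tensor is reshaped to `sample_shape + batch_shape + input_shape`, so each slice `samples[:, i]` holds the posterior samples for condition `x_batch[i]`.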
diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py
index 458c4e22a..1a90d449e 100644
--- a/sbi/samplers/rejection/rejection.py
+++ b/sbi/samplers/rejection/rejection.py
@@ -190,6 +190,7 @@ def accept_reject_sample(
     proposal: Callable,
     accept_reject_fn: Callable,
     num_samples: int,
+    num_xos: int = 1,
     show_progress_bars: bool = False,
     warn_acceptance: float = 0.01,
     sample_for_correction_factor: bool = False,
@@ -218,6 +219,8 @@ def accept_reject_sample(
             rejected. Must take a batch of parameters and return a boolean tensor
             which indicates which parameters get accepted.
         num_samples: Desired number of samples.
+        num_xos: Number of conditions for batched_sampling (currently only accepting
+            one batch dimension for the condition).
         show_progress_bars: Whether to show a progressbar during sampling.
         warn_acceptance: A minimum acceptance rate under which to warn about slowness.
         sample_for_correction_factor: True if this function was called by
@@ -263,8 +266,6 @@ def accept_reject_sample(
     # But this would require giving the method the condition_shape explicitly...
     if "condition" in proposal_sampling_kwargs:
         num_xos = proposal_sampling_kwargs["condition"].shape[0]
-    else:
-        num_xos = 1

     accepted = [[] for _ in range(num_xos)]
     acceptance_rate = torch.full((num_xos,), float("Nan"))
diff --git a/sbi/samplers/score/score.py b/sbi/samplers/score/score.py
index 48337292f..ee2c98ef2 100644
--- a/sbi/samplers/score/score.py
+++ b/sbi/samplers/score/score.py
@@ -101,11 +101,11 @@ def initialize(self, num_samples: int) -> Tensor:
         # batched sampling setting with a flag.
         # TODO: this fixes the iid setting shape problems, but iid inference via
         # iid_bridge is not accurate.
-        # num_batch = self.batch_shape.numel()
-        # init_shape = (num_batch, num_samples) + self.input_shape
-        init_shape = (
-            num_samples,
-        ) + self.input_shape  # just use num_samples, not num_batch
+        num_batch = self.batch_shape.numel()
+        init_shape = (num_samples, num_batch) + self.input_shape
+        # init_shape = (
+        #     num_samples,
+        # ) + self.input_shape  # just use num_samples, not num_batch
         # NOTE: for the IID setting we might need to scale the noise with iid batch
         # size, as in equation (7) in the paper.
         eps = torch.randn(init_shape, device=self.device)
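To make the bookkeeping behind `num_xos` and the new `init_shape` concrete, here is a standalone toy, not the library code: candidates carry a per-condition axis, and accept/reject decisions and acceptance rates are tracked independently for each of the `num_xos` conditions. The accept rule below is a stand-in for `within_support(prior, theta)`:

```python
# Standalone toy illustrating the per-condition bookkeeping (illustrative only).
import torch

num_samples, num_xos, theta_dim = 1000, 4, 2

# Mirrors the new init_shape in score.py: (num_samples, num_batch) + input_shape.
candidates = torch.randn(num_samples, num_xos, theta_dim)

# Toy accept/reject rule standing in for within_support(prior, theta).
mask = (candidates.abs() < 2.0).all(dim=-1)  # shape (num_samples, num_xos)

accepted = [[] for _ in range(num_xos)]      # one bucket per condition
acceptance_rate = torch.full((num_xos,), float("nan"))
for i in range(num_xos):
    accepted[i].append(candidates[mask[:, i], i])
    acceptance_rate[i] = mask[:, i].float().mean()
```

This is why the `else: num_xos = 1` branch in `accept_reject_sample` can be dropped: the caller now passes `num_xos` explicitly (defaulting to 1), so the per-condition buckets are sized correctly even when the proposal kwargs carry `x` rather than `condition`.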