refactor: update batch size and mcmc defaults. #1221

Merged 4 commits on Aug 20, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -35,7 +35,7 @@ dependencies = [
"matplotlib",
"numpy",
"pillow",
"pyknos>=0.15.1",
"pyknos>=0.16.0",
"pyro-ppl>=1.3.1",
"scikit-learn",
"scipy",
2 changes: 1 addition & 1 deletion sbi/analysis/sensitivity_analysis.py
@@ -228,7 +228,7 @@ def build_mlp(theta):

def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
24 changes: 17 additions & 7 deletions sbi/diagnostics/sbc.py
@@ -25,6 +25,7 @@ def run_sbc(
reduce_fns: Union[str, Callable, List[Callable]] = "marginals",
num_workers: int = 1,
show_progress_bar: bool = True,
+    use_batched_sampling: bool = True,
**kwargs,
):
"""Run simulation-based calibration (SBC) (parallelized across sbc runs).
@@ -47,6 +48,8 @@
num_workers: number of CPU cores to use in parallel for running
`num_sbc_samples` inferences.
show_progress_bar: whether to display a progress over sbc runs.
+        use_batched_sampling: whether to use batched sampling for posterior
+            samples.

Returns:
ranks: ranks of the ground truth parameters under the inferred
@@ -81,13 +84,16 @@

# Get posterior samples, batched or parallelized.
posterior_samples = get_posterior_samples_on_batch(
-        xs, posterior, num_posterior_samples, num_workers, show_progress_bar
+        xs,
+        posterior,
+        (num_posterior_samples,),
+        num_workers,
+        show_progress_bar,
+        use_batched_sampling=use_batched_sampling,
)
-    # for calibration methods its handy to have len(xs) in first dim.
-    posterior_samples = posterior_samples.transpose(0, 1)

# take a random draw from each posterior to get data averaged posterior samples.
-    dap_samples = posterior_samples[:, 0, :]
+    dap_samples = posterior_samples[0, :, :]
assert dap_samples.shape == (num_sbc_samples, thetas.shape[1]), "Wrong dap shape."

ranks = _run_sbc(
@@ -126,8 +132,8 @@ def _run_sbc(

ranks = torch.zeros((num_sbc_samples, len(reduce_fns)))
# Iterate over all sbc samples and calculate ranks.
-    for sbc_idx, (ths, theta_i, x_i) in tqdm(
-        enumerate(zip(posterior_samples, thetas, xs)),
+    for sbc_idx, (true_theta, x_i) in tqdm(
+        enumerate(zip(thetas, xs)),
total=num_sbc_samples,
disable=not show_progress_bar,
desc=f"Calculating ranks for {num_sbc_samples} sbc samples.",
@@ -139,8 +145,12 @@

# For each reduce_fn (e.g., per marginal for SBC)
for dim_idx, reduce_fn in enumerate(reduce_fns):
+            # rank posterior samples against true parameter, reduced to 1D.
ranks[sbc_idx, dim_idx] = (
-            (reduce_fn(ths, x_i) < reduce_fn(theta_i.unsqueeze(0), x_i))
+            (
+                reduce_fn(posterior_samples[:, sbc_idx, :], x_i)
+                < reduce_fn(true_theta.unsqueeze(0), x_i)
+            )
.sum()
.item()
)
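For context, a minimal sketch of how the reworked `run_sbc` is called. This is not part of the diff: `prior`, `simulator`, and a trained `posterior` are assumed to exist from a standard sbi workflow; the keyword names follow the signature above.

```python
# Sketch: with the new use_batched_sampling=True default, all SBC posteriors
# are drawn in one batched call when the posterior supports it; set it to
# False to force the per-x parallel loop instead.
from sbi.diagnostics import run_sbc

num_sbc_samples = 200
thetas = prior.sample((num_sbc_samples,))  # prior/simulator assumed to exist
xs = simulator(thetas)

ranks, dap_samples = run_sbc(
    thetas,
    xs,
    posterior,
    num_posterior_samples=1_000,
    use_batched_sampling=True,  # flag introduced in this PR
)
# ranks: (num_sbc_samples, num_reduce_fns);
# dap_samples: (num_sbc_samples, theta_dim), per the assert in run_sbc above.
```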
2 changes: 1 addition & 1 deletion sbi/diagnostics/tarp.py
@@ -64,7 +64,7 @@ def run_tarp(
posterior_samples = get_posterior_samples_on_batch(
xs,
posterior,
-        num_posterior_samples,
+        (num_posterior_samples,),
num_workers,
show_progress_bar=show_progress_bar,
)
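User-facing behavior of TARP is unchanged by this PR; only the internal call now passes a sample-shape tuple. A hedged usage sketch, assuming `thetas`, `xs`, and `posterior` as in the SBC example and assuming `run_tarp` returns expected coverage probabilities alongside the credibility grid:

```python
# Sketch: same inputs as run_sbc; the (num_posterior_samples,) tuple is
# now built internally when sampling on the batch of xs.
from sbi.diagnostics import run_tarp

ecp, alpha = run_tarp(thetas, xs, posterior, num_posterior_samples=1_000)
# A well-calibrated posterior gives ecp close to alpha across the grid.
```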
4 changes: 2 additions & 2 deletions sbi/inference/base.py
@@ -295,7 +295,7 @@ def append_simulations(
@abstractmethod
def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
@@ -312,7 +312,7 @@ def train(
def get_dataloaders(
self,
starting_round: int = 0,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
validation_fraction: float = 0.1,
resume_training: bool = False,
dataloader_kwargs: Optional[dict] = None,
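The abstract `train()` signature here fixes the default that all the subclasses below inherit. A minimal sketch of the effect; the prior and toy simulator are illustrative, not part of this PR:

```python
# Sketch: train() now batches 200 simulations per gradient step by default.
import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))
theta = prior.sample((2_000,))
x = theta + 0.1 * torch.randn_like(theta)  # toy simulator, illustrative only

inference = SNPE(prior=prior)
density_estimator = inference.append_simulations(theta, x).train()
# Equivalent to train(training_batch_size=200); pass 50 to recover the
# previous default, e.g. for small simulation budgets.
```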
2 changes: 1 addition & 1 deletion sbi/inference/fmpe/fmpe_base.py
@@ -126,7 +126,7 @@ def append_simulations(

def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
4 changes: 2 additions & 2 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -47,10 +47,10 @@ def __init__(
potential_fn: Union[Callable, BasePotential],
proposal: Any,
theta_transform: Optional[TorchTransform] = None,
method: str = "slice_np",
method: str = "slice_np_vectorized",
thin: int = -1,
warmup_steps: int = 200,
-        num_chains: int = 1,
+        num_chains: int = 20,
init_strategy: str = "resample",
init_strategy_parameters: Optional[Dict[str, Any]] = None,
init_strategy_num_candidates: Optional[int] = None,
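A hedged sketch of what the new constructor defaults amount to; `potential_fn`, `prior`, `theta_transform`, and `x_o` are assumed to come from a trained likelihood(-ratio) estimator and are not defined here:

```python
# Sketch: with this PR, the two explicit arguments below are redundant,
# since they are now the defaults.
from sbi.inference.posteriors import MCMCPosterior

posterior = MCMCPosterior(
    potential_fn=potential_fn,          # assumed to exist
    proposal=prior,                     # assumed to exist
    theta_transform=theta_transform,    # assumed to exist
    method="slice_np_vectorized",       # was "slice_np"
    num_chains=20,                      # was 1
)
samples = posterior.sample((1_000,), x=x_o)  # x_o assumed to exist
```

Vectorized slice sampling evaluates all chains in batched potential calls, so running 20 chains costs far less than 20 sequential single-chain runs.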
4 changes: 2 additions & 2 deletions sbi/inference/snle/mnle.py
@@ -66,7 +66,7 @@ def __init__(

def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
@@ -92,7 +92,7 @@ def build_posterior(
density_estimator: Optional[TorchModule] = None,
prior: Optional[Distribution] = None,
sample_with: str = "mcmc",
mcmc_method: str = "slice_np",
mcmc_method: str = "slice_np_vectorized",
vi_method: str = "rKL",
mcmc_parameters: Optional[Dict[str, Any]] = None,
vi_parameters: Optional[Dict[str, Any]] = None,
4 changes: 2 additions & 2 deletions sbi/inference/snle/snle_base.py
@@ -118,7 +118,7 @@ def append_simulations(

def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
@@ -267,7 +267,7 @@ def build_posterior(
density_estimator: Optional[ConditionalDensityEstimator] = None,
prior: Optional[Distribution] = None,
sample_with: str = "mcmc",
mcmc_method: str = "slice_np",
mcmc_method: str = "slice_np_vectorized",
vi_method: str = "rKL",
mcmc_parameters: Optional[Dict[str, Any]] = None,
vi_parameters: Optional[Dict[str, Any]] = None,
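The same default change propagates through `build_posterior`. A sketch assuming `inference` is a trained `SNLE` object (any of the estimators touched in this PR would behave the same way):

```python
# Sketch: both keyword arguments below are now redundant with the defaults.
posterior = inference.build_posterior(
    sample_with="mcmc",
    mcmc_method="slice_np_vectorized",   # was "slice_np"
    mcmc_parameters={"num_chains": 20},  # matches the new MCMCPosterior default
)
```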
2 changes: 1 addition & 1 deletion sbi/inference/snpe/snpe_a.py
@@ -103,7 +103,7 @@ def __init__(
def train(
self,
final_round: bool = False,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
4 changes: 2 additions & 2 deletions sbi/inference/snpe/snpe_base.py
@@ -209,7 +209,7 @@ def append_simulations(

def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
@@ -435,7 +435,7 @@ def build_posterior(
density_estimator: Optional[ConditionalDensityEstimator] = None,
prior: Optional[Distribution] = None,
sample_with: str = "direct",
mcmc_method: str = "slice_np",
mcmc_method: str = "slice_np_vectorized",
vi_method: str = "rKL",
direct_sampling_parameters: Optional[Dict[str, Any]] = None,
mcmc_parameters: Optional[Dict[str, Any]] = None,
2 changes: 1 addition & 1 deletion sbi/inference/snpe/snpe_c.py
@@ -90,7 +90,7 @@ def __init__(
def train(
self,
num_atoms: int = 10,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
2 changes: 1 addition & 1 deletion sbi/inference/snre/bnre.py
@@ -56,7 +56,7 @@ def __init__(
def train(
self,
regularization_strength: float = 100.0,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
2 changes: 1 addition & 1 deletion sbi/inference/snre/snre_a.py
@@ -52,7 +52,7 @@ def __init__(

def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
2 changes: 1 addition & 1 deletion sbi/inference/snre/snre_b.py
@@ -53,7 +53,7 @@ def __init__(
def train(
self,
num_atoms: int = 10,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
4 changes: 2 additions & 2 deletions sbi/inference/snre/snre_base.py
@@ -131,7 +131,7 @@ def append_simulations(
def train(
self,
num_atoms: int = 10,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
@@ -319,7 +319,7 @@ def build_posterior(
density_estimator: Optional[nn.Module] = None,
prior: Optional[Distribution] = None,
sample_with: str = "mcmc",
mcmc_method: str = "slice_np",
mcmc_method: str = "slice_np_vectorized",
vi_method: str = "rKL",
mcmc_parameters: Optional[Dict[str, Any]] = None,
vi_parameters: Optional[Dict[str, Any]] = None,
2 changes: 1 addition & 1 deletion sbi/inference/snre/snre_c.py
@@ -68,7 +68,7 @@ def train(
self,
num_classes: int = 5,
gamma: float = 1.0,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
45 changes: 28 additions & 17 deletions sbi/utils/diagnostics_utils.py
@@ -5,24 +5,26 @@

from sbi.inference.posteriors.base_posterior import NeuralPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
+from sbi.sbi_types import Shape


def get_posterior_samples_on_batch(
xs: Tensor,
posterior: NeuralPosterior,
-    num_samples: int,
+    sample_shape: Shape,
num_workers: int = 1,
show_progress_bar: bool = False,
+    use_batched_sampling: bool = True,
) -> Tensor:
"""Get posterior samples for a batch of xs.

Args:
xs: batch of observations.
posterior: sbi posterior.
-        num_posterior_samples: number of samples to draw from the posterior in each sbc
-            run.
+        num_samples: number of samples to draw from the posterior for each x.
num_workers: number of workers to use for parallelization.
show_progress_bars: whether to show progress bars.
+        use_batched_sampling: whether to use batched sampling if possible.

Returns:
posterior_samples: of shape (num_samples, batch_size, dim_parameters).
@@ -32,35 +34,44 @@
# Try using batched sampling when implemented.
try:
# has shape (num_samples, batch_size, dim_parameters)
-        posterior_samples = posterior.sample_batched(
-            (num_samples,), xs, show_progress_bars=show_progress_bar
-        )
+        if use_batched_sampling:
+            posterior_samples = posterior.sample_batched(
+                sample_shape, x=xs, show_progress_bars=show_progress_bar
+            )
+        else:
+            raise NotImplementedError

except NotImplementedError:
# We need a function with extra training step for new x for VIPosterior.
-        def sample_fun(posterior: NeuralPosterior, sample_shape, x: Tensor) -> Tensor:
+        def sample_fun(
+            posterior: NeuralPosterior, sample_shape: Shape, x: Tensor, seed: int = 0
+        ) -> Tensor:
if isinstance(posterior, VIPosterior):
posterior.set_default_x(x)
posterior.train()
+            torch.manual_seed(seed)
return posterior.sample(sample_shape, x=x, show_progress_bars=False)

# Run in parallel with progress bar.
+        seeds = torch.randint(0, 2**32, (batch_size,))
outputs = list(
tqdm(
Parallel(return_as="generator", n_jobs=num_workers)(
-                    delayed(sample_fun)(posterior, (num_samples,), x=x) for x in xs
+                    delayed(sample_fun)(posterior, sample_shape, x=x, seed=s)
+                    for x, s in zip(xs, seeds)
),
disable=not show_progress_bar,
total=len(xs),
desc=f"Sampling {batch_size} times {num_samples} posterior samples.",
desc=f"Sampling {batch_size} times {sample_shape} posterior samples.",
)
-        )
-        # Transpose to sample_batched shape convention:
-        posterior_samples = torch.stack(outputs).transpose(0, 1)  # type: ignore
+        )  # (batch_size, num_samples, dim_parameters)
+        # Transpose to shape convention: (sample_shape, batch_size, dim_parameters)
+        posterior_samples = torch.stack(
+            outputs  # type: ignore
+        ).permute(1, 0, 2)

-    assert posterior_samples.shape[:2] == (
-        num_samples,
+    assert posterior_samples.shape[:2] == sample_shape + (
         batch_size,
-    ), f"""Expected batched posterior samples of shape {(num_samples, batch_size)} got {
-        posterior_samples.shape[:2]
-    }."""
+    ), f"""Expected batched posterior samples of shape {
+        sample_shape + (batch_size,)
+    } got {posterior_samples.shape[:2]}."""
return posterior_samples
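A self-contained sketch of the shape convention the fallback path now enforces, with illustrative sizes only:

```python
# Sketch of the permute above: per-x sampling stacks to
# (batch_size, num_samples, dim); permute(1, 0, 2) moves the sample
# dimension first so the fallback matches sample_batched's
# (num_samples, batch_size, dim) convention checked by the assert.
import torch

batch_size, num_samples, dim = 8, 100, 3
per_x = [torch.randn(num_samples, dim) for _ in range(batch_size)]
stacked = torch.stack(per_x)                   # (batch_size, num_samples, dim)
posterior_samples = stacked.permute(1, 0, 2)   # (num_samples, batch_size, dim)
assert posterior_samples.shape[:2] == (num_samples, batch_size)
```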
2 changes: 1 addition & 1 deletion sbi/utils/restriction_estimator.py
@@ -228,7 +228,7 @@ def get_simulations(self, starting_round: int = 0) -> Tuple[Tensor, Tensor, Tensor]:

def train(
self,
-        training_batch_size: int = 50,
+        training_batch_size: int = 200,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,