diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a28b91c21..0cfc5bd15 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.9.0 hooks: - id: ruff - id: ruff-format args: [--diff] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v5.0.0 hooks: - id: check-added-large-files - id: check-merge-conflict diff --git a/pyproject.toml b/pyproject.toml index 072c0b8d1..0dee18135 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,10 +64,10 @@ doc = [ dev = [ "ffmpeg", # Lint - "pre-commit == 3.5.0", + "pre-commit == 4.0.1", "pyyaml", "pyright", - "ruff>=0.3.3", + "ruff==0.9.0", # Test "pytest", "pytest-cov", @@ -106,6 +106,7 @@ ignore = [ [tool.ruff.lint.extend-per-file-ignores] "__init__.py" = ["E402", "F401", "F403"] # allow unused imports and undefined names "test_*.py" = ["F403", "F405"] +"tutorials/*.ipynb" = ["E501"] # allow long lines in notebooks [tool.ruff.lint.isort] case-sensitive = true diff --git a/sbi/analysis/plot.py b/sbi/analysis/plot.py index 5825e3386..745c55131 100644 --- a/sbi/analysis/plot.py +++ b/sbi/analysis/plot.py @@ -775,9 +775,9 @@ def pairplot( # checks. if fig_kwargs_filled["legend"]: - assert len(fig_kwargs_filled["samples_labels"]) >= len( - samples - ), "Provide at least as many labels as samples." + assert len(fig_kwargs_filled["samples_labels"]) >= len(samples), ( + "Provide at least as many labels as samples." + ) if offdiag is not None: warn("offdiag is deprecated, use upper or lower instead.", stacklevel=2) upper = offdiag @@ -1594,9 +1594,9 @@ def _sbc_rank_plot( ranks_list[idx]: np.ndarray = rank.numpy() # type: ignore plot_types = ["hist", "cdf"] - assert ( - plot_type in plot_types - ), "plot type {plot_type} not implemented, use one in {plot_types}." + assert plot_type in plot_types, ( + "plot type {plot_type} not implemented, use one in {plot_types}." + ) if legend_kwargs is None: legend_kwargs = dict(loc="best", handlelength=0.8) @@ -1609,9 +1609,9 @@ def _sbc_rank_plot( params_in_subplots = True for ranki in ranks_list: - assert ( - ranki.shape == ranks_list[0].shape - ), "all ranks in list must have the same shape." + assert ranki.shape == ranks_list[0].shape, ( + "all ranks in list must have the same shape." + ) num_rows = int(np.ceil(num_parameters / num_cols)) if figsize is None: @@ -1636,9 +1636,9 @@ def _sbc_rank_plot( ) ax = np.atleast_1d(ax) # type: ignore else: - assert ( - ax.size >= num_parameters - ), "There must be at least as many subplots as parameters." + assert ax.size >= num_parameters, ( + "There must be at least as many subplots as parameters." + ) num_rows = ax.shape[0] if ax.ndim > 1 else 1 assert ax is not None @@ -2221,9 +2221,9 @@ def pairplot_dep( # checks. if opts["legend"]: - assert len(opts["samples_labels"]) >= len( - samples - ), "Provide at least as many labels as samples." + assert len(opts["samples_labels"]) >= len(samples), ( + "Provide at least as many labels as samples." + ) if opts["upper"] is not None: opts["offdiag"] = opts["upper"] diff --git a/sbi/analysis/sensitivity_analysis.py b/sbi/analysis/sensitivity_analysis.py index 435071f4f..b43d6c51f 100644 --- a/sbi/analysis/sensitivity_analysis.py +++ b/sbi/analysis/sensitivity_analysis.py @@ -250,9 +250,9 @@ def train( prevent exploding gradients. Use `None` for no clipping. """ - assert ( - self._theta is not None and self._emergent_property is not None - ), "You must call .add_property() first." 
+ assert self._theta is not None and self._emergent_property is not None, ( + "You must call .add_property() first." + ) # Get indices for permutation of the data. num_examples = len(self._theta) @@ -433,9 +433,9 @@ def find_directions( if posterior_log_prob_as_property: predictions = self._posterior.potential(thetas, track_gradients=True) else: - assert ( - self._regression_net is not None - ), "self._regression_net is None, you must call `.train()` first." + assert self._regression_net is not None, ( + "self._regression_net is None, you must call `.train()` first." + ) predictions = self._regression_net.forward(thetas) loss = predictions.mean() loss.backward() diff --git a/sbi/diagnostics/lc2st.py b/sbi/diagnostics/lc2st.py index d78d64ccd..a89f65ba2 100644 --- a/sbi/diagnostics/lc2st.py +++ b/sbi/diagnostics/lc2st.py @@ -83,9 +83,9 @@ def __init__( [2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py """ - assert ( - thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0] - ), "Number of samples must match" + assert thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0], ( + "Number of samples must match" + ) # set observed data for classification self.theta_p = posterior_samples @@ -283,9 +283,9 @@ def get_statistic_on_observed_data( Returns: L-C2ST statistic at `x_o`. """ - assert ( - self.trained_clfs is not None - ), "No trained classifiers found. Run `train_on_observed_data` first." + assert self.trained_clfs is not None, ( + "No trained classifiers found. Run `train_on_observed_data` first." + ) _, scores = self.get_scores( theta_o=theta_o, x_o=x_o, @@ -372,9 +372,9 @@ def train_under_null_hypothesis( joint_q_perm[:, self.theta_q.shape[1] :], ) else: - assert ( - self.null_distribution is not None - ), "You need to provide a null distribution" + assert self.null_distribution is not None, ( + "You need to provide a null distribution" + ) theta_p_t = self.null_distribution.sample((self.theta_p.shape[0],)) theta_q_t = self.null_distribution.sample((self.theta_p.shape[0],)) x_p_t, x_q_t = self.x_p, self.x_q @@ -419,9 +419,9 @@ def get_statistics_under_null_hypothesis( Run `train_under_null_hypothesis`." ) else: - assert ( - len(self.trained_clfs_null) == self.num_trials_null - ), "You need one classifier per trial." + assert len(self.trained_clfs_null) == self.num_trials_null, ( + "You need one classifier per trial." + ) probs_null, stats_null = [], [] for t in tqdm( @@ -433,9 +433,9 @@ def get_statistics_under_null_hypothesis( if self.permutation: theta_o_t = theta_o else: - assert ( - self.null_distribution is not None - ), "You need to provide a null distribution" + assert self.null_distribution is not None, ( + "You need to provide a null distribution" + ) theta_o_t = self.null_distribution.sample((theta_o.shape[0],)) diff --git a/sbi/diagnostics/sbc.py b/sbi/diagnostics/sbc.py index ba01fb0a8..0017893c7 100644 --- a/sbi/diagnostics/sbc.py +++ b/sbi/diagnostics/sbc.py @@ -69,9 +69,9 @@ def run_sbc( stacklevel=2, ) - assert ( - thetas.shape[0] == xs.shape[0] - ), "Unequal number of parameters and observations." + assert thetas.shape[0] == xs.shape[0], ( + "Unequal number of parameters and observations." 
+ ) if "sbc_batch_size" in kwargs: warnings.warn( diff --git a/sbi/diagnostics/tarp.py b/sbi/diagnostics/tarp.py index 44ff114f3..9cffb4c4c 100644 --- a/sbi/diagnostics/tarp.py +++ b/sbi/diagnostics/tarp.py @@ -133,9 +133,9 @@ def _run_tarp( """ num_posterior_samples, num_tarp_samples, _ = posterior_samples.shape - assert ( - references.shape == thetas.shape - ), "references must have the same shape as thetas" + assert references.shape == thetas.shape, ( + "references must have the same shape as thetas" + ) if num_bins is None: num_bins = num_tarp_samples // 10 diff --git a/sbi/inference/abc/mcabc.py b/sbi/inference/abc/mcabc.py index 5fd153fe9..943b29270 100644 --- a/sbi/inference/abc/mcabc.py +++ b/sbi/inference/abc/mcabc.py @@ -130,9 +130,9 @@ def __call__( """ # Exactly one of eps or quantile need to be passed. - assert (eps is not None) ^ ( - quantile is not None - ), "Eps or quantile must be passed, but not both." + assert (eps is not None) ^ (quantile is not None), ( + "Eps or quantile must be passed, but not both." + ) if kde_kwargs is None: kde_kwargs = {} diff --git a/sbi/inference/abc/smcabc.py b/sbi/inference/abc/smcabc.py index a4766321a..25b20400c 100644 --- a/sbi/inference/abc/smcabc.py +++ b/sbi/inference/abc/smcabc.py @@ -95,9 +95,9 @@ def __init__( ) kernels = ("gaussian", "uniform") - assert ( - kernel in kernels - ), f"Kernel '{kernel}' not supported. Choose one from {kernels}." + assert kernel in kernels, ( + f"Kernel '{kernel}' not supported. Choose one from {kernels}." + ) self.kernel = kernel algorithm_variants = ("A", "B", "C") @@ -198,13 +198,13 @@ def __call__( if kde_kwargs is None: kde_kwargs = {} assert isinstance(epsilon_decay, float) and epsilon_decay > 0.0 - assert not ( - self.distance.requires_iid_data and lra - ), "Currently there is no support to run inference " + assert not (self.distance.requires_iid_data and lra), ( + "Currently there is no support to run inference " + ) "on multiple observations together with lra." - assert not ( - self.distance.requires_iid_data and sass - ), "Currently there is no support to run inference " + assert not (self.distance.requires_iid_data and sass), ( + "Currently there is no support to run inference " + ) "on multiple observations together with sass." # Pilot run for SASS. @@ -363,9 +363,9 @@ def _set_xo_and_sample_initial_population( ) -> Tuple[Tensor, float, Tensor, Tensor]: """Return particles, epsilon and distances of initial population.""" - assert ( - num_particles <= num_initial_pop - ), "number of initial round simulations must be greater than population size" + assert num_particles <= num_initial_pop, ( + "number of initial round simulations must be greater than population size" + ) assert (x_o.shape[0] == 1) or self.distance.requires_iid_data, ( "Your data contain iid data-points, but the choice of " diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 55d8f0d88..6282a237d 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -288,9 +288,7 @@ def __repr__(self): return desc def __str__(self): - desc = ( - f"Posterior p(θ|x) of type {self.__class__.__name__}. " f"{self._purpose}" - ) + desc = f"Posterior p(θ|x) of type {self.__class__.__name__}. 
{self._purpose}" return desc def __getstate__(self) -> Dict: diff --git a/sbi/inference/posteriors/ensemble_posterior.py b/sbi/inference/posteriors/ensemble_posterior.py index 1b251439f..23323dbc5 100644 --- a/sbi/inference/posteriors/ensemble_posterior.py +++ b/sbi/inference/posteriors/ensemble_posterior.py @@ -112,9 +112,9 @@ def ensure_same_device(self, posteriors: List) -> str: A device string, that is the same for all posteriors. """ devices = [posterior._device for posterior in posteriors] - assert all( - device == devices[0] for device in devices - ), "Only supported if all posteriors are on the same device." + assert all(device == devices[0] for device in devices), ( + "Only supported if all posteriors are on the same device." + ) return devices[0] @property diff --git a/sbi/inference/posteriors/mcmc_posterior.py b/sbi/inference/posteriors/mcmc_posterior.py index c7d462688..674aefb04 100644 --- a/sbi/inference/posteriors/mcmc_posterior.py +++ b/sbi/inference/posteriors/mcmc_posterior.py @@ -418,9 +418,9 @@ def sample_batched( else init_strategy_parameters ) - assert ( - method == "slice_np_vectorized" - ), "Batched sampling only supported for vectorized samplers!" + assert method == "slice_np_vectorized", ( + "Batched sampling only supported for vectorized samplers!" + ) # warn if num_chains is larger than num requested samples if num_chains > torch.Size(sample_shape).numel(): @@ -1003,9 +1003,9 @@ def get_arviz_inference_data(self) -> InferenceData: Returns: inference_data: Arviz InferenceData object. """ - assert ( - self._posterior_sampler is not None - ), """No samples have been generated, call .sample() first.""" + assert self._posterior_sampler is not None, ( + """No samples have been generated, call .sample() first.""" + ) sampler: Union[ MCMC, SliceSamplerSerial, SliceSamplerVectorized, PyMCSampler diff --git a/sbi/inference/potentials/likelihood_based_potential.py b/sbi/inference/potentials/likelihood_based_potential.py index 8066d3dd1..beb1c988c 100644 --- a/sbi/inference/potentials/likelihood_based_potential.py +++ b/sbi/inference/potentials/likelihood_based_potential.py @@ -104,11 +104,11 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: # Calculate likelihood for each (theta,x) pair separately theta_batch_size = theta.shape[0] x_batch_size = self.x_o.shape[0] - assert ( - theta_batch_size == x_batch_size - ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ + assert theta_batch_size == x_batch_size, ( + f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ When performing batched sampling for multiple `x`, the batch size of\ `theta` must match the batch size of `x`." + ) x = self.x_o.unsqueeze(0) with torch.set_grad_enabled(track_gradients): log_likelihood_batches = self.likelihood_estimator.log_prob( @@ -143,9 +143,9 @@ def condition_on_theta( def conditioned_potential( theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True ) -> Tensor: - assert ( - len(dims_global_theta) == theta.shape[1] - ), "dims_global_theta must match the number of parameters to sample." + assert len(dims_global_theta) == theta.shape[1], ( + "dims_global_theta must match the number of parameters to sample." + ) global_theta = theta[:, dims_global_theta] x_o = x_o if x_o is not None else self.x_o # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape) @@ -257,15 +257,15 @@ def _log_likelihood_over_iid_trials_and_local_theta( theta_batch_dim)`. 
""" assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)." - assert ( - local_theta.dim() == 2 - ), "condition must have shape (sample_dim, num_conditions)." + assert local_theta.dim() == 2, ( + "condition must have shape (sample_dim, num_conditions)." + ) assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)." num_trials, num_xs = x.shape[:2] num_thetas = global_theta.shape[0] - assert ( - local_theta.shape[0] == num_trials - ), "Condition batch size must match the number of iid trials in x." + assert local_theta.shape[0] == num_trials, ( + "Condition batch size must match the number of iid trials in x." + ) # move the iid batch dimension onto the batch dimension of theta and repeat it there x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1) diff --git a/sbi/inference/potentials/posterior_based_potential.py b/sbi/inference/potentials/posterior_based_potential.py index 4c0359b02..3a4c8100e 100644 --- a/sbi/inference/potentials/posterior_based_potential.py +++ b/sbi/inference/potentials/posterior_based_potential.py @@ -125,11 +125,11 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: theta_batch_size = theta.shape[0] x_batch_size = x.shape[0] - assert ( - theta_batch_size == x_batch_size or x_batch_size == 1 - ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ + assert theta_batch_size == x_batch_size or x_batch_size == 1, ( + f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ When performing batched sampling for multiple `x`, the batch size of\ `theta` must match the batch size of `x`." + ) if x_batch_size == 1: # If a single `x` is passed (i.e. batchsize==1), we squeeze diff --git a/sbi/inference/potentials/ratio_based_potential.py b/sbi/inference/potentials/ratio_based_potential.py index d45d4c614..1ab2df66d 100644 --- a/sbi/inference/potentials/ratio_based_potential.py +++ b/sbi/inference/potentials/ratio_based_potential.py @@ -95,11 +95,11 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: theta_batch_size = theta.shape[0] x_batch_size = self.x_o.shape[0] - assert ( - theta_batch_size == x_batch_size - ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ + assert theta_batch_size == x_batch_size, ( + f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\ When performing batched sampling for multiple `x`, the batch size of\ `theta` must match the batch size of `x`." + ) with torch.set_grad_enabled(track_gradients): log_ratio_batches = self.ratio_estimator(theta, self.x_o) log_ratio_batches = log_ratio_batches.reshape(-1) @@ -130,9 +130,9 @@ def _log_ratios_over_trials( theta_repeated, x_repeated = match_theta_and_x_batch_shapes( theta=atleast_2d(theta), x=atleast_2d(x) ) - assert ( - x_repeated.shape[0] == theta_repeated.shape[0] - ), "x and theta must match in batch shape." + assert x_repeated.shape[0] == theta_repeated.shape[0], ( + "x and theta must match in batch shape." 
+ ) assert ( next(net.parameters()).device == x.device and x.device == theta.device ), f"""device mismatch: net, x, theta: {next(net.parameters()).device}, {x.device}, diff --git a/sbi/inference/potentials/score_based_potential.py b/sbi/inference/potentials/score_based_potential.py index 50974c18c..9d094323b 100644 --- a/sbi/inference/potentials/score_based_potential.py +++ b/sbi/inference/potentials/score_based_potential.py @@ -41,9 +41,9 @@ def score_estimator_based_potential( score_estimator, prior, x_o, device=device ) - assert ( - enable_transform is False - ), "Transforms are not yet supported for score estimators." + assert enable_transform is False, ( + "Transforms are not yet supported for score estimators." + ) if prior is not None: theta_transform = mcmc_transform( @@ -107,9 +107,9 @@ def __call__( x_density_estimator = reshape_to_batch_event( self.x_o, event_shape=self.score_estimator.condition_shape ) - assert ( - x_density_estimator.shape[0] == 1 - ), "PosteriorScoreBasedPotential supports only x batchsize of 1`." + assert x_density_estimator.shape[0] == 1, ( + "PosteriorScoreBasedPotential supports only x batchsize of 1`." + ) self.score_estimator.eval() diff --git a/sbi/inference/trainers/nle/nle_base.py b/sbi/inference/trainers/nle/nle_base.py index 6bb835859..94f610793 100644 --- a/sbi/inference/trainers/nle/nle_base.py +++ b/sbi/inference/trainers/nle/nle_base.py @@ -177,9 +177,9 @@ def train( theta[self.train_indices].to("cpu"), x[self.train_indices].to("cpu"), ) - assert ( - len(x_shape_from_simulation(x.to("cpu"))) < 3 - ), "SNLE cannot handle multi-dimensional simulator output." + assert len(x_shape_from_simulation(x.to("cpu"))) < 3, ( + "SNLE cannot handle multi-dimensional simulator output." + ) del theta, x self._neural_net.to(self._device) diff --git a/sbi/inference/trainers/npe/npe_a.py b/sbi/inference/trainers/npe/npe_a.py index 670c46a9a..f48cd325a 100644 --- a/sbi/inference/trainers/npe/npe_a.py +++ b/sbi/inference/trainers/npe/npe_a.py @@ -568,9 +568,9 @@ def _posthoc_correction(self, x: Tensor): Mixture components of the posterior. """ # Remove the batch dimension of `x` (SNPE-A always has a single `x`). - assert ( - x.shape[0] == 1 - ), f"Batchsize of `x_o` == {x.shape[0]}. SNPE-A only supports a single `x_o`." + assert x.shape[0] == 1, ( + f"Batchsize of `x_o` == {x.shape[0]}. SNPE-A only supports a single `x_o`." + ) x = x.squeeze(dim=0) # Evaluate the density estimator. diff --git a/sbi/inference/trainers/npse/npse.py b/sbi/inference/trainers/npse/npse.py index 6bc4a62f4..a3bf9ca2b 100644 --- a/sbi/inference/trainers/npse/npse.py +++ b/sbi/inference/trainers/npse/npse.py @@ -129,9 +129,9 @@ def append_simulations( Returns: NeuralInference object (returned so that this function is chainable). """ - assert ( - proposal is None - ), "Multi-round NPSE is not yet implemented. Please use single-round NPSE." + assert proposal is None, ( + "Multi-round NPSE is not yet implemented. Please use single-round NPSE." + ) current_round = 0 if exclude_invalid_x is None: diff --git a/sbi/neural_nets/embedding_nets/cnn.py b/sbi/neural_nets/embedding_nets/cnn.py index aeb9a7e9a..9938a940e 100644 --- a/sbi/neural_nets/embedding_nets/cnn.py +++ b/sbi/neural_nets/embedding_nets/cnn.py @@ -101,9 +101,9 @@ def __init__( """ super(CNNEmbedding, self).__init__() - assert isinstance( - input_shape, Tuple - ), "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height])." 
+ assert isinstance(input_shape, Tuple), ( + "input_shape must be a Tuple of size 1 or 2, e.g., (width, [height])." + ) assert ( 0 < len(input_shape) < 3 ), """input_shape must be a Tuple of size 1 or 2, e.g., @@ -115,9 +115,9 @@ def __init__( if out_channels_per_layer is None: out_channels_per_layer = [6, 12] - assert ( - len(out_channels_per_layer) == num_conv_layers - ), "out_channels needs as many entries as num_cnn_layers." + assert len(out_channels_per_layer) == num_conv_layers, ( + "out_channels needs as many entries as num_cnn_layers." + ) # define input shape with channel self.input_shape = (in_channels, *input_shape) diff --git a/sbi/neural_nets/estimators/mixed_density_estimator.py b/sbi/neural_nets/estimators/mixed_density_estimator.py index dedba1b52..81bed1b26 100644 --- a/sbi/neural_nets/estimators/mixed_density_estimator.py +++ b/sbi/neural_nets/estimators/mixed_density_estimator.py @@ -133,9 +133,9 @@ def log_prob(self, input: Tensor, condition: Tensor) -> Tensor: Sample-wise log probabilities, shape `(input_sample_dim, input_batch_dim)`. """ - assert ( - input.dim() > 2 - ), "Input must be of shape (sample_dim, batch_dim, *event_shape)." + assert input.dim() > 2, ( + "Input must be of shape (sample_dim, batch_dim, *event_shape)." + ) input_sample_dim, input_batch_dim = input.shape[:2] condition_batch_dim = condition.shape[0] combined_batch_size = input_sample_dim * input_batch_dim diff --git a/sbi/neural_nets/estimators/shape_handling.py b/sbi/neural_nets/estimators/shape_handling.py index 7edf3fa33..534df2396 100644 --- a/sbi/neural_nets/estimators/shape_handling.py +++ b/sbi/neural_nets/estimators/shape_handling.py @@ -33,9 +33,9 @@ def reshape_to_sample_batch_event( trailing_theta_or_x_shape = theta_or_x.shape[-event_shape_dim:] leading_theta_or_x_shape = theta_or_x.shape[:-event_shape_dim] - assert ( - trailing_theta_or_x_shape == event_shape - ), "The trailing dimensions of `theta_or_x` do not match the `event_shape`." + assert trailing_theta_or_x_shape == event_shape, ( + "The trailing dimensions of `theta_or_x` do not match the `event_shape`." + ) if len(leading_theta_or_x_shape) == 0: # A single datapoint is passed. Add batch and sample dim artificially. @@ -71,9 +71,9 @@ def reshape_to_batch_event(theta_or_x: Tensor, event_shape: torch.Size) -> Tenso trailing_theta_or_x_shape = theta_or_x.shape[-event_shape_dim:] leading_theta_or_x_shape = theta_or_x.shape[:-event_shape_dim] - assert ( - trailing_theta_or_x_shape == event_shape - ), "The trailing dimensions of `theta_or_x` do not match the `event_shape`." + assert trailing_theta_or_x_shape == event_shape, ( + "The trailing dimensions of `theta_or_x` do not match the `event_shape`." + ) if len(leading_theta_or_x_shape) == 0: # A single datapoint is passed. Add batch artificially. diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py index 35492ca47..26ac254a3 100644 --- a/sbi/samplers/rejection/rejection.py +++ b/sbi/samplers/rejection/rejection.py @@ -179,9 +179,9 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor: # When in case of leakage a batch size was used there could be too many samples. samples = torch.cat(accepted)[:num_samples] - assert ( - samples.shape[0] == num_samples - ), "Number of accepted samples must match required samples." + assert samples.shape[0] == num_samples, ( + "Number of accepted samples must match required samples." 
+ ) return samples, as_tensor(acceptance_rate) @@ -358,8 +358,8 @@ def accept_reject_sample( samples = [torch.cat(accepted[i], dim=0)[:num_samples] for i in range(num_xos)] samples = torch.stack(samples, dim=1) samples = samples.reshape(num_samples, *candidates.shape[1:]) - assert ( - samples.shape[0] == num_samples - ), "Number of accepted samples must match required samples." + assert samples.shape[0] == num_samples, ( + "Number of accepted samples must match required samples." + ) return samples, as_tensor(acceptance_rate) diff --git a/sbi/samplers/score/correctors.py b/sbi/samplers/score/correctors.py index e64b370d7..80b7b44ce 100644 --- a/sbi/samplers/score/correctors.py +++ b/sbi/samplers/score/correctors.py @@ -32,9 +32,9 @@ def register_corrector(name: str) -> Callable: """ def decorator(corrector: Type[Corrector]) -> Callable: - assert issubclass( - corrector, Corrector - ), "Corrector must be a subclass of Corrector." + assert issubclass(corrector, Corrector), ( + "Corrector must be a subclass of Corrector." + ) CORRECTORS[name] = corrector return corrector diff --git a/sbi/samplers/score/predictors.py b/sbi/samplers/score/predictors.py index 3f0a2eba6..1ac16a053 100644 --- a/sbi/samplers/score/predictors.py +++ b/sbi/samplers/score/predictors.py @@ -37,9 +37,9 @@ def register_predictor(name: str) -> Callable: """ def decorator(predictor: Type[Predictor]) -> Callable: - assert issubclass( - predictor, Predictor - ), "Predictor must be a subclass of Predictor." + assert issubclass(predictor, Predictor), ( + "Predictor must be a subclass of Predictor." + ) PREDICTORS[name] = predictor return predictor diff --git a/sbi/samplers/vi/vi_quality_control.py b/sbi/samplers/vi/vi_quality_control.py index 2c9778404..0e23f2abe 100644 --- a/sbi/samplers/vi/vi_quality_control.py +++ b/sbi/samplers/vi/vi_quality_control.py @@ -74,12 +74,12 @@ def basic_checks(posterior, N: int = int(5e4)): assert ( prior.support.check(samples) # type: ignore ).all(), "Some of the samples are not within the prior support." 
- assert ( - torch.isfinite(posterior.log_prob(samples)) - ).all(), "The log probability is not finite for some samples" - assert ( - torch.isfinite(posterior.log_prob(prior_samples)) - ).all(), "The log probability is not finite for some samples" + assert (torch.isfinite(posterior.log_prob(samples))).all(), ( + "The log probability is not finite for some samples" + ) + assert (torch.isfinite(posterior.log_prob(prior_samples))).all(), ( + "The log probability is not finite for some samples" + ) def psis_diagnostics( diff --git a/sbi/samplers/vi/vi_utils.py b/sbi/samplers/vi/vi_utils.py index 98c1a4195..2bff72a8c 100644 --- a/sbi/samplers/vi/vi_utils.py +++ b/sbi/samplers/vi/vi_utils.py @@ -95,9 +95,9 @@ def check_parameters_modules_attribute(q: PyroTransformedDistribution) -> None: else: assert isinstance(q.parameters, Callable), "The parameters must be callable" parameters = q.parameters() - assert isinstance( - parameters, Iterable - ), "The parameters return value must be iterable" + assert isinstance(parameters, Iterable), ( + "The parameters return value must be iterable" + ) trainable = 0 for p in parameters: assert isinstance(p, torch.Tensor) @@ -115,13 +115,13 @@ def check_parameters_modules_attribute(q: PyroTransformedDistribution) -> None: else: assert isinstance(q.modules, Callable), "The parameters must be callable" modules = q.modules() - assert isinstance( - modules, Iterable - ), "The parameters return value must be iterable" + assert isinstance(modules, Iterable), ( + "The parameters return value must be iterable" + ) for m in modules: - assert isinstance( - m, Module - ), "The modules must contain PyTorch Module objects" + assert isinstance(m, Module), ( + "The modules must contain PyTorch Module objects" + ) def check_sample_shape_and_support(q: Distribution, prior: Distribution) -> None: @@ -134,12 +134,12 @@ def check_sample_shape_and_support(q: Distribution, prior: Distribution) -> None prior: Prior to check certain attributes which should be satisfied. """ - assert ( - q.event_shape == prior.event_shape - ), "The event shape of q must match that of the prior" - assert ( - q.batch_shape == prior.batch_shape - ), "The batch sahpe of q must match that of the prior" + assert q.event_shape == prior.event_shape, ( + "The event shape of q must match that of the prior" + ) + assert q.batch_shape == prior.batch_shape, ( + "The batch sahpe of q must match that of the prior" + ) sample_shape = torch.Size((1000,)) samples = q.sample(sample_shape) @@ -153,15 +153,15 @@ def check_sample_shape_and_support(q: Distribution, prior: Distribution) -> None assert all( prior.support.check(samples) # type: ignore ), "The support of q must match that of the prior" - assert ( - samples.shape == samples_prior.shape - ), "Something is wrong with sample shape and event_shape or batch_shape attributes." - assert torch.isfinite( - q.log_prob(samples_prior) - ).all(), "Invalid values in logprob on prior samples." - assert torch.isfinite( - prior.log_prob(samples) - ).all(), "Invalid values in logprob on q samples." + assert samples.shape == samples_prior.shape, ( + "sample_shape and event_shape or batch_shape do not match." + ) + assert torch.isfinite(q.log_prob(samples_prior)).all(), ( + "Invalid values in logprob on prior samples." + ) + assert torch.isfinite(prior.log_prob(samples)).all(), ( + "Invalid values in logprob on q samples." 
+ ) def check_variational_distribution(q: Distribution, prior: Distribution) -> None: diff --git a/sbi/utils/metrics.py b/sbi/utils/metrics.py index 7de4a67cd..f71356dc6 100644 --- a/sbi/utils/metrics.py +++ b/sbi/utils/metrics.py @@ -257,9 +257,9 @@ def wasserstein_2_squared( [1] Peyré, G., & Cuturi, M. (2019). Computational optimal transport: With applications to data science. """ - assert ( - x.ndim == y.ndim - ), "Please make sure that 'x' and 'y' are both either batched or not." + assert x.ndim == y.ndim, ( + "Please make sure that 'x' and 'y' are both either batched or not." + ) if x.ndim == 2: nx, ny = x.shape[0], y.shape[0] a = torch.ones(nx) / nx @@ -313,9 +313,9 @@ def regularized_ot_dual( Optimal transport coupling of shape (B, m, n) or (m, n) """ - assert ( - a.ndim == b.ndim - ), "Please make sure that 'a' and 'b' are both either batched or not." + assert a.ndim == b.ndim, ( + "Please make sure that 'a' and 'b' are both either batched or not." + ) f"currently a.ndim={a.ndim} and b.ndim={b.ndim}" batched = True diff --git a/sbi/utils/restriction_estimator.py b/sbi/utils/restriction_estimator.py index e895d0d1a..2cd885689 100644 --- a/sbi/utils/restriction_estimator.py +++ b/sbi/utils/restriction_estimator.py @@ -560,9 +560,9 @@ def __init__( else: raise NameError(f"`safety_margin` {safety_margin} not supported.") else: - assert ( - allowed_false_negatives is not None - ), "`allowed_false_negatives` must be set." + assert allowed_false_negatives is not None, ( + "`allowed_false_negatives` must be set." + ) quantile_index = floor(num_valid * allowed_false_negatives) self._classifier_thr, _ = torch.kthvalue(clf_probs, quantile_index + 1) diff --git a/sbi/utils/torchutils.py b/sbi/utils/torchutils.py index be11ead4d..ee4a62930 100644 --- a/sbi/utils/torchutils.py +++ b/sbi/utils/torchutils.py @@ -301,9 +301,9 @@ def __init__( """ # Type checks. - assert isinstance(low, Tensor) and isinstance( - high, Tensor - ), f"low and high must be tensors but are {type(low)} and {type(high)}." + assert isinstance(low, Tensor) and isinstance(high, Tensor), ( + f"low and high must be tensors but are {type(low)} and {type(high)}." + ) if not low.device == high.device: raise RuntimeError( "Expected all tensors to be on the same device, but found at least" diff --git a/sbi/utils/user_input_checks.py b/sbi/utils/user_input_checks.py index f407d632f..fbad7c6f0 100644 --- a/sbi/utils/user_input_checks.py +++ b/sbi/utils/user_input_checks.py @@ -386,9 +386,9 @@ def check_prior_return_type( """Check whether prior.sample() returns float32 Tensor.""" prior_dtype = prior.sample().dtype - assert ( - prior_dtype == return_type - ), f"Prior return type must be {return_type}, but is {prior_dtype}." + assert prior_dtype == return_type, ( + f"Prior return type must be {return_type}, but is {prior_dtype}." + ) def check_prior_batch_behavior(prior) -> None: @@ -408,13 +408,13 @@ def check_prior_batch_behavior(prior) -> None: # Using len here because `log_prob` could be `ndarray` or `torch.Tensor`. num_log_probs = len(log_probs) - assert ( - num_sampled == num_samples - ), "prior.sample((batch_size, )) must return batch_size parameters." + assert num_sampled == num_samples, ( + "prior.sample((batch_size, )) must return batch_size parameters." + ) - assert ( - num_log_probs == num_samples - ), "prior.log_prob must return as many log probs as samples." + assert num_log_probs == num_samples, ( + "prior.log_prob must return as many log probs as samples." 
+ ) def check_prior_support(prior): @@ -503,9 +503,9 @@ def wrap_as_pytorch_simulator( # Get data to check input type is consistent with data. theta = prior.sample().numpy() # Cast to numpy because is in PyTorch already. x = simulator(theta) - assert isinstance( - x, ndarray - ), f"Simulator output type {type(x)} must match its input type {type(theta)}" + assert isinstance(x, ndarray), ( + f"Simulator output type {type(x)} must match its input type {type(theta)}" + ) # Define a wrapper function to PyTorch def pytorch_simulator(theta: Tensor) -> Tensor: diff --git a/sbi/utils/user_input_checks_utils.py b/sbi/utils/user_input_checks_utils.py index 554a6f8f2..7c549eed8 100644 --- a/sbi/utils/user_input_checks_utils.py +++ b/sbi/utils/user_input_checks_utils.py @@ -210,9 +210,9 @@ def _check_distributions(self, dists): def _check_distribution(self, dist: Distribution): """Check type and shape of a single input distribution.""" - assert not isinstance( - dist, (MultipleIndependent, Sequence) - ), "Nesting of combined distributions is not possible." + assert not isinstance(dist, (MultipleIndependent, Sequence)), ( + "Nesting of combined distributions is not possible." + ) assert isinstance( dist, Distribution ), """priors passed to MultipleIndependent must be PyTorch distributions. Make @@ -274,15 +274,15 @@ def _prepare_value(self, value) -> Tensor: if value.ndim < 2: value = value.unsqueeze(0) - assert ( - value.ndim == 2 - ), f"value in log_prob must have ndim <= 2, it is {value.ndim}." + assert value.ndim == 2, ( + f"value in log_prob must have ndim <= 2, it is {value.ndim}." + ) batch_shape, num_value_dims = value.shape - assert ( - num_value_dims == self.ndims - ), f"Number of dimensions must match dimensions of this joint: {self.ndims}." + assert num_value_dims == self.ndims, ( + f"Number of dimensions must match dimensions of this joint: {self.ndims}." + ) return value @@ -361,9 +361,9 @@ def build_support( # Both are specified. else: num_dimensions = lower_bound.numel() - assert ( - num_dimensions == upper_bound.numel() - ), "There must be an equal number of independent bounds." + assert num_dimensions == upper_bound.numel(), ( + "There must be an equal number of independent bounds." + ) if num_dimensions > 1: support = constraints._IndependentConstraint( constraints.interval(lower_bound, upper_bound), diff --git a/tests/density_estimator_test.py b/tests/density_estimator_test.py index 0f006dde1..372a8d757 100644 --- a/tests/density_estimator_test.py +++ b/tests/density_estimator_test.py @@ -279,9 +279,9 @@ def test_correctness_of_batched_vs_seperate_sample_and_log_prob( samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32) samples_sep_m = torch.cat([samples_separate1_m, samples_separate2_m], dim=0) - assert torch.allclose( - samples_m, samples_sep_m, atol=0.5, rtol=0.5 - ), "Batched sampling is not consistent with separate sampling." + assert torch.allclose(samples_m, samples_sep_m, atol=0.5, rtol=0.5), ( + "Batched sampling is not consistent with separate sampling." + ) # Batched vs separate log_prob log_probs = density_estimator.log_prob(inputs, condition=condition) @@ -294,9 +294,9 @@ def test_correctness_of_batched_vs_seperate_sample_and_log_prob( ) log_probs_sep = torch.hstack([log_probs_separate1, log_probs_separate2]) - assert torch.allclose( - log_probs, log_probs_sep, atol=1e-2, rtol=1e-2 - ), "Batched log_prob is not consistent with separate log_prob." 
+ assert torch.allclose(log_probs, log_probs_sep, atol=1e-2, rtol=1e-2), ( + "Batched log_prob is not consistent with separate log_prob." + ) def _build_density_estimator_and_tensors( diff --git a/tests/ensemble_test.py b/tests/ensemble_test.py index 837d02e74..21daf2d6a 100644 --- a/tests/ensemble_test.py +++ b/tests/ensemble_test.py @@ -133,9 +133,9 @@ def simulator(theta): num_samples=num_samples, ) max_dkl = 0.15 - assert ( - dkl < max_dkl - ), f"D-KL={dkl} is more than 2 stds above the average performance." + assert dkl < max_dkl, ( + f"D-KL={dkl} is more than 2 stds above the average performance." + ) # test individual log_prob and map posterior.log_prob(samples, individually=True) diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index 1c6653af6..b32d982ea 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -315,9 +315,9 @@ def test_train_with_different_data_and_training_device( density_estimator=estimator if data_device == "cpu" else None, prior=prior, ).set_default_x(x_o) - assert posterior._device == str( - weights_device - ), "inferred posterior device not correct." + assert posterior._device == str(weights_device), ( + "inferred posterior device not correct." + ) @pytest.mark.parametrize("inference_method", [NPE_A, NPE_C, NRE_A, NRE_B, NRE_C, NLE]) @@ -406,12 +406,12 @@ def allow_iid_x(self) -> bool: samples = posterior.sample((1,), method=sampling_method) logprobs = posterior.log_prob(samples) - assert ( - str(samples.device) == device - ), f"The devices after training do not match: {samples.device} vs {device}" - assert ( - str(logprobs.device) == device - ), f"The devices after training do not match: {logprobs.device} vs {device}" + assert str(samples.device) == device, ( + f"The devices after training do not match: {samples.device} vs {device}" + ) + assert str(logprobs.device) == device, ( + f"The devices after training do not match: {logprobs.device} vs {device}" + ) @pytest.mark.gpu diff --git a/tests/inference_with_NaN_simulator_test.py b/tests/inference_with_NaN_simulator_test.py index 7ec7ed898..c2e1636cf 100644 --- a/tests/inference_with_NaN_simulator_test.py +++ b/tests/inference_with_NaN_simulator_test.py @@ -226,9 +226,9 @@ def integrate_grid(distribution): restricted_prior_probs = torch.exp(restricted_prior.log_prob(theta)) valid_thetas = restricted_prior._accept_reject_fn(theta).bool() - assert torch.all( - restricted_prior_probs[valid_thetas] > 0.0 - ), "Accepted theta have zero probability." - assert torch.all( - restricted_prior_probs[torch.logical_not(valid_thetas)] == 0.0 - ), "Rejected theta has non-zero probablity." + assert torch.all(restricted_prior_probs[valid_thetas] > 0.0), ( + "Accepted theta have zero probability." + ) + assert torch.all(restricted_prior_probs[torch.logical_not(valid_thetas)] == 0.0), ( + "Rejected theta has non-zero probablity." + ) diff --git a/tests/lc2st_test.py b/tests/lc2st_test.py index e1e445bac..62f5e633b 100644 --- a/tests/lc2st_test.py +++ b/tests/lc2st_test.py @@ -180,11 +180,11 @@ def test_lc2st_true_positiv_rate(method): proportion_rejected = torch.tensor(results).float().mean() - assert ( - proportion_rejected > confidence_level - ), f"LC2ST p-values too big, test should be rejected \ + assert proportion_rejected > confidence_level, ( + f"LC2ST p-values too big, test should be rejected \ at least {confidence_level * 100}% of the time, but was rejected \ only {proportion_rejected * 100}% of the time." 
+ ) @pytest.mark.slow @@ -259,8 +259,8 @@ def test_lc2st_false_positiv_rate(method): proportion_rejected = torch.tensor(results).float().mean() - assert proportion_rejected < ( - 1 - confidence_level - ), f"LC2ST p-values too small, test should be rejected \ + assert proportion_rejected < (1 - confidence_level), ( + f"LC2ST p-values too small, test should be rejected \ less then {(1 - confidence_level) * 100}% of the time, \ but was rejected {proportion_rejected * 100}% of the time." + ) diff --git a/tests/linearGaussian_fmpe_test.py b/tests/linearGaussian_fmpe_test.py index 804ea46fb..c7280a368 100644 --- a/tests/linearGaussian_fmpe_test.py +++ b/tests/linearGaussian_fmpe_test.py @@ -96,9 +96,9 @@ def test_c2st_fmpe_on_linearGaussian(num_dim: int, prior_str: str): max_dkl = 0.15 - assert ( - dkl < max_dkl - ), f"D-KL={dkl} is more than 2 stds above the average performance." + assert dkl < max_dkl, ( + f"D-KL={dkl} is more than 2 stds above the average performance." + ) # test probs probs = posterior.log_prob(samples).exp() @@ -110,9 +110,9 @@ def test_c2st_fmpe_on_linearGaussian(num_dim: int, prior_str: str): elif prior_str == "uniform": # Check whether the returned probability outside of the support is zero. posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) - assert ( - posterior_prob == 0.0 - ), "The posterior probability outside of the prior support is not zero" + assert posterior_prob == 0.0, ( + "The posterior probability outside of the prior support is not zero" + ) # Check whether normalization (i.e. scaling up the density due # to leakage into regions without prior support) scales up the density by the @@ -377,9 +377,9 @@ def test_fmpe_map(): map_ = posterior.map(show_progress_bars=True, num_iter=20) # Check whether the MAP is close to the ground truth. - assert torch.allclose( - map_, gt_posterior.mean, atol=0.2 - ), f"{map_} != {gt_posterior.mean}" + assert torch.allclose(map_, gt_posterior.mean, atol=0.2), ( + f"{map_} != {gt_posterior.mean}" + ) def test_multi_round_handling_fmpe(): diff --git a/tests/linearGaussian_npse_test.py b/tests/linearGaussian_npse_test.py index 3b3bed9f7..c75da767d 100644 --- a/tests/linearGaussian_npse_test.py +++ b/tests/linearGaussian_npse_test.py @@ -96,9 +96,9 @@ def test_c2st_npse_on_linearGaussian( max_dkl = 0.15 - assert ( - dkl < max_dkl - ), f"D-KL={dkl} is more than 2 stds above the average performance." + assert dkl < max_dkl, ( + f"D-KL={dkl} is more than 2 stds above the average performance." + ) def test_c2st_npse_on_linearGaussian_different_dims(): diff --git a/tests/linearGaussian_simulator_test.py b/tests/linearGaussian_simulator_test.py index 3aec035f9..4457bcf0d 100644 --- a/tests/linearGaussian_simulator_test.py +++ b/tests/linearGaussian_simulator_test.py @@ -31,12 +31,12 @@ def test_standardlinearGaussian_simulator(D: int, N: int): assert xs.shape == torch.Size([N, D]) # Check mean and std. 
- assert torch.allclose( - xs.mean(axis=0), true_parameters, atol=5e-2 - ), f"Expected mean of zero, obtained {xs.mean(axis=0)}" - assert torch.allclose( - xs.std(axis=0), torch.ones(D), atol=5e-2 - ), f"Expected std of one, obtained {xs.std(axis=0)}" + assert torch.allclose(xs.mean(axis=0), true_parameters, atol=5e-2), ( + f"Expected mean of zero, obtained {xs.mean(axis=0)}" + ) + assert torch.allclose(xs.std(axis=0), torch.ones(D), atol=5e-2), ( + f"Expected std of one, obtained {xs.std(axis=0)}" + ) @pytest.mark.parametrize("D, N", ((1, 10000), (5, 100000))) @@ -61,9 +61,9 @@ def test_linearGaussian_simulator(D: int, N: int): assert xs.shape == torch.Size([N, D]) # Check mean and std. - assert torch.allclose( - xs.mean(axis=0), true_parameters, atol=5e-2 - ), f"Expected mean of zero, obtained {xs.mean(axis=0)}" - assert torch.allclose( - xs.std(axis=0), torch.ones(D), atol=5e-2 - ), f"Expected std of one, obtained {xs.std(axis=0)}" + assert torch.allclose(xs.mean(axis=0), true_parameters, atol=5e-2), ( + f"Expected mean of zero, obtained {xs.mean(axis=0)}" + ) + assert torch.allclose(xs.std(axis=0), torch.ones(D), atol=5e-2), ( + f"Expected std of one, obtained {xs.std(axis=0)}" + ) diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index 006e8ac3a..39412f53c 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -218,9 +218,9 @@ def simulator(theta): if prior_str == "uniform": # Check whether the returned probability outside of the support is zero. posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) - assert ( - posterior_prob == 0.0 - ), "The posterior probability outside of the prior support is not zero" + assert posterior_prob == 0.0, ( + "The posterior probability outside of the prior support is not zero" + ) assert ((map_ - ones(num_dim)) ** 2).sum() < 0.5 else: diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index 51dec56f6..a0a0a188b 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -110,18 +110,18 @@ def simulator(theta): max_dkl = 0.15 - assert ( - dkl < max_dkl - ), f"D-KL={dkl} is more than 2 stds above the average performance." + assert dkl < max_dkl, ( + f"D-KL={dkl} is more than 2 stds above the average performance." + ) assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5 elif prior_str == "uniform": # Check whether the returned probability outside of the support is zero. posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) - assert ( - posterior_prob == 0.0 - ), "The posterior probability outside of the prior support is not zero" + assert posterior_prob == 0.0, ( + "The posterior probability outside of the prior support is not zero" + ) # Check whether normalization (i.e. scaling up the density due # to leakage into regions without prior support) scales up the density by the @@ -625,9 +625,9 @@ def test_mdn_conditional_density(num_dim: int = 3, cond_dim: int = 1): cond_dim: Dimensionality of the condition. """ - assert ( - num_dim > cond_dim - ), "The number of dimensions needs to be greater than that of the condition!" + assert num_dim > cond_dim, ( + "The number of dimensions needs to be greater than that of the condition!" 
+ ) x_o = zeros(1, num_dim) num_samples = 1000 diff --git a/tests/linearGaussian_snre_test.py b/tests/linearGaussian_snre_test.py index 04a30ddee..ea2cba5f5 100644 --- a/tests/linearGaussian_snre_test.py +++ b/tests/linearGaussian_snre_test.py @@ -232,18 +232,18 @@ def simulator(theta): max_dkl = 0.15 - assert ( - dkl < max_dkl - ), f"KLd={dkl} is more than 2 stds above the average performance." + assert dkl < max_dkl, ( + f"KLd={dkl} is more than 2 stds above the average performance." + ) assert ((map_ - gt_posterior.mean) ** 2).sum() < 0.5 if prior_str == "uniform": # Check whether the returned probability outside of the support is zero. posterior_prob = get_prob_outside_uniform_prior(posterior, prior, num_dim) - assert ( - posterior_prob == 0.0 - ), "The posterior probability outside of the prior support is not zero" + assert posterior_prob == 0.0, ( + "The posterior probability outside of the prior support is not zero" + ) assert ((map_ - ones(num_dim)) ** 2).sum() < 0.5 diff --git a/tests/mnle_test.py b/tests/mnle_test.py index 099876a3e..778cc44cd 100644 --- a/tests/mnle_test.py +++ b/tests/mnle_test.py @@ -344,9 +344,7 @@ def test_mnle_with_experimental_conditions(mcmc_params_accurate: dict): 1, pytest.param( 2, - marks=pytest.mark.xfail( - reason="Batched theta_condition is not " "supported" - ), + marks=pytest.mark.xfail(reason="Batched theta_condition is not supported"), ), ], ) diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index b236e21e0..ca7d0f103 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -233,9 +233,9 @@ def test_batched_mcmc_sample_log_prob_with_different_x( samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32) samples_sep_m = torch.stack([samples_separate1_m, samples_separate2_m], dim=0) - assert torch.allclose( - samples_m, samples_sep_m, atol=0.2, rtol=0.2 - ), "Batched sampling is not consistent with separate sampling." + assert torch.allclose(samples_m, samples_sep_m, atol=0.2, rtol=0.2), ( + "Batched sampling is not consistent with separate sampling." + ) @pytest.mark.slow diff --git a/tests/sbc_test.py b/tests/sbc_test.py index 788421158..d5a59f70e 100644 --- a/tests/sbc_test.py +++ b/tests/sbc_test.py @@ -131,12 +131,12 @@ def simulator(theta): num_posterior_samples=num_posterior_samples, ) - assert ( - checks["ks_pvals"] > 0.05 - ).all(), f"KS p-values too small: {checks['ks_pvals']}" - assert ( - checks["c2st_ranks"] < 0.6 - ).all(), f"C2ST ranks too large: {checks['c2st_ranks']}" + assert (checks["ks_pvals"] > 0.05).all(), ( + f"KS p-values too small: {checks['ks_pvals']}" + ) + assert (checks["c2st_ranks"] < 0.6).all(), ( + f"C2ST ranks too large: {checks['c2st_ranks']}" + ) assert (checks["c2st_dap"] < 0.6).all(), f"C2ST DAP too large: {checks['c2st_dap']}" diff --git a/tests/test_utils.py b/tests/test_utils.py index 0d1c717da..1e260e12d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -101,9 +101,9 @@ def get_prob_outside_uniform_prior( # Test whether likelihood outside prior support is zero. assert isinstance(prior, BoxUniform) sample_outside_support = 1.1 * prior.base_dist.low - assert not within_support( - prior, sample_outside_support - ).all(), "Samples must be outside of support." + assert not within_support(prior, sample_outside_support).all(), ( + "Samples must be outside of support." 
+ ) return torch.exp(posterior.log_prob(sample_outside_support)) @@ -148,9 +148,9 @@ def check_c2st(x: Tensor, y: Tensor, alg: str, tol: float = 0.1) -> None: score = c2st(x, y).item() print(f"c2st for {alg} is {score:.2f}.") - assert ( - (0.5 - tol) <= score <= (0.5 + tol) - ), f"{alg}'s c2st={score:.2f} is too far from the desired near-chance performance." + assert (0.5 - tol) <= score <= (0.5 + tol), ( + f"{alg}'s c2st={score:.2f} is too far from the desired near-chance performance." + ) class PosteriorPotential(BasePotential): @@ -178,9 +178,9 @@ def __init__( """ super().__init__(prior, x_o, device) - assert ( - x_o is None - ), "No need to pass x_o, passed Posterior must be fixed to x_o." + assert x_o is None, ( + "No need to pass x_o, passed Posterior must be fixed to x_o." + ) self.posterior = posterior def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor: diff --git a/tests/user_input_checks_test.py b/tests/user_input_checks_test.py index cd275a7b0..d7c58a410 100644 --- a/tests/user_input_checks_test.py +++ b/tests/user_input_checks_test.py @@ -193,9 +193,9 @@ def test_process_prior(prior): batch_size, parameter_dim, )), "Number of sampled parameters must match batch size." - assert ( - prior.log_prob(theta).shape[0] == batch_size - ), "Number of log probs must match number of input values." + assert prior.log_prob(theta).shape[0] == batch_size, ( + "Number of log probs must match number of input values." + ) @pytest.mark.parametrize( @@ -236,9 +236,9 @@ def test_process_simulator(simulator: Callable, prior: Distribution, x_shape: Tu x = simulator(prior.sample((n_batch,))) assert isinstance(x, Tensor), "Processed simulator must return Tensor." - assert ( - x.shape[0] == n_batch - ), "Processed simulator must return as many data points as parameters in batch." + assert x.shape[0] == n_batch, ( + "Processed simulator must return as many data points as parameters in batch." + ) assert x.shape[1:] == x_shape @@ -561,9 +561,9 @@ def failing_simulator(theta): simulation_batch_size=simulation_batch_size, num_workers=num_workers, ) - assert ( - theta.numel() == 0 - ), "Theta should be an empty tensor when num_simulations=0" + assert theta.numel() == 0, ( + "Theta should be an empty tensor when num_simulations=0" + ) assert x.numel() == 0, "x should be an empty tensor when num_simulations=0" else: if not use_process_simulator and num_workers > 1: @@ -583,9 +583,9 @@ def failing_simulator(theta): simulation_batch_size=simulation_batch_size, num_workers=num_workers, ) - assert ( - theta.shape[0] == num_simulations - ), "Theta should have num_simulations rows" + assert theta.shape[0] == num_simulations, ( + "Theta should have num_simulations rows" + ) assert x.shape[0] == num_simulations, "x should have num_simulations rows" assert theta.shape[1] == num_dim, "Theta should have num_dim columns" assert x.shape[1] == num_dim, "x should have num_dim columns" diff --git a/tests/vi_test.py b/tests/vi_test.py index e01a6d866..ca26fa559 100644 --- a/tests/vi_test.py +++ b/tests/vi_test.py @@ -215,23 +215,23 @@ def test_deepcopy_support(q: str): ) posterior_copy = deepcopy(posterior) posterior.set_default_x(torch.tensor(np.zeros((num_dim,)).astype(np.float32))) - assert ( - posterior._x != posterior_copy._x - ), "Default x attributed of original and copied but modified VIPosterior must be\ + assert posterior._x != posterior_copy._x, ( + "Default x attributed of original and copied but modified VIPosterior must be\ the different, on change (otherwise it is not a deep copy)." 
+ ) posterior_copy = deepcopy(posterior) - assert ( - posterior._x == posterior_copy._x - ).all(), "Default x attributed of original and copied VIPosterior must be the same." + assert (posterior._x == posterior_copy._x).all(), ( + "Default x attributed of original and copied VIPosterior must be the same." + ) # Try if they are the same torch.manual_seed(0) s1 = posterior._q.rsample() torch.manual_seed(0) s2 = posterior_copy._q.rsample() - assert torch.allclose( - s1, s2 - ), "Samples from original and unpickled VIPosterior must be close." + assert torch.allclose(s1, s2), ( + "Samples from original and unpickled VIPosterior must be close." + ) # Produces nonleaf tensors in the cache... -> Can lead to failure of deepcopy. posterior.q.rsample() @@ -262,9 +262,9 @@ def test_pickle_support(q: str): with tempfile.NamedTemporaryFile(suffix=".pt") as f: torch.save(posterior, f.name) posterior_loaded = torch.load(f.name) - assert ( - posterior._x == posterior_loaded._x - ).all(), "Mhh, something with the pickled is strange" + assert (posterior._x == posterior_loaded._x).all(), ( + "Mhh, something with the pickled is strange" + ) # Try if they are the same torch.manual_seed(0) @@ -291,52 +291,52 @@ def test_vi_posterior_inferface(): posterior2 = VIPosterior(potential_fn) # Raising errors if untrained - assert isinstance( - posterior.q.support, type(posterior2.q.support) - ), "The support indicated by 'theta_transform' is different than that of 'prior'." + assert isinstance(posterior.q.support, type(posterior2.q.support)), ( + "The support indicated by 'theta_transform' is different than that of 'prior'." + ) with pytest.raises(Exception) as execinfo: posterior.sample() - assert ( - "The variational posterior was not fit" in execinfo.value.args[0] - ), "An expected error was raised but the error message is different than expected." + assert "The variational posterior was not fit" in execinfo.value.args[0], ( + "An expected error was raised but the error message is different than expected." + ) with pytest.raises(Exception) as execinfo: posterior.log_prob(prior.sample()) - assert ( - "The variational posterior was not fit" in execinfo.value.args[0] - ), "An expected error was raised but the error message is different than expected." + assert "The variational posterior was not fit" in execinfo.value.args[0], ( + "An expected error was raised but the error message is different than expected." + ) # Passing Hyperparameters in train posterior.train(max_num_iters=20) posterior.train(max_num_iters=20, optimizer=torch.optim.SGD) - assert isinstance( - posterior._optimizer._optimizer, torch.optim.SGD - ), "Assert chaning the optimizer base class did not work" + assert isinstance(posterior._optimizer._optimizer, torch.optim.SGD), ( + "Assert chaning the optimizer base class did not work" + ) posterior.train(max_num_iters=20, stick_the_landing=True) - assert ( - posterior._optimizer.stick_the_landing - ), "The sticking_the_landing argument is not correctly passed." + assert posterior._optimizer.stick_the_landing, ( + "The sticking_the_landing argument is not correctly passed." 
+    )
     posterior.vi_method = "alpha"
     posterior.train(max_num_iters=20)
     posterior.train(max_num_iters=20, alpha=0.9)
-    assert (
-        posterior._optimizer.alpha == 0.9
-    ), "The Hyperparameter alpha is not passed to the corresponding optmizer"
+    assert posterior._optimizer.alpha == 0.9, (
+        "The hyperparameter alpha is not passed to the corresponding optimizer"
+    )
     posterior.vi_method = "IW"
     posterior.train(max_num_iters=20)
     posterior.train(max_num_iters=20, K=32)
-    assert (
-        posterior._optimizer.K == 32
-    ), "The Hyperparameter K is not passed to the corresponding optmizer"
+    assert posterior._optimizer.K == 32, (
+        "The hyperparameter K is not passed to the corresponding optimizer"
+    )
     # Passing Hyperparameters in sample
     posterior.sample()
diff --git a/tutorials/05_conditional_distributions.ipynb b/tutorials/05_conditional_distributions.ipynb
index 9e81fbf5f..fc52bdcb6 100644
--- a/tutorials/05_conditional_distributions.ipynb
+++ b/tutorials/05_conditional_distributions.ipynb
@@ -363,14 +363,8 @@
     "    NPE,\n",
     "    MCMCPosterior,\n",
     "    posterior_estimator_based_potential,\n",
-    "    simulate_for_sbi,\n",
     ")\n",
     "from sbi.utils import BoxUniform\n",
-    "from sbi.utils.user_input_checks import (\n",
-    "    check_sbi_inputs,\n",
-    "    process_prior,\n",
-    "    process_simulator,\n",
-    ")\n",
     "\n",
     "num_dim = 4\n",
     "prior = BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))\n",
diff --git a/tutorials/07_sensitivity_analysis.ipynb b/tutorials/07_sensitivity_analysis.ipynb
index d1d6bf3b7..534849a2b 100644
--- a/tutorials/07_sensitivity_analysis.ipynb
+++ b/tutorials/07_sensitivity_analysis.ipynb
@@ -45,8 +45,8 @@
     "import torch\n",
     "from torch.distributions import MultivariateNormal\n",
     "\n",
-    "from sbi.inference import NPE\n",
     "from sbi.analysis import ActiveSubspace, pairplot\n",
+    "from sbi.inference import NPE\n",
     "from sbi.simulators import linear_gaussian\n",
     "\n",
     "_ = torch.manual_seed(0)"
diff --git a/tutorials/08_crafting_summary_statistics.ipynb b/tutorials/08_crafting_summary_statistics.ipynb
index b7775aaa8..1d7017dd7 100644
--- a/tutorials/08_crafting_summary_statistics.ipynb
+++ b/tutorials/08_crafting_summary_statistics.ipynb
@@ -32,8 +32,8 @@
     "import numpy as np\n",
     "import torch\n",
     "\n",
-    "from sbi.inference import NPE\n",
     "from sbi.analysis import pairplot\n",
+    "from sbi.inference import NPE\n",
     "from sbi.utils import BoxUniform"
    ]
   },
diff --git a/tutorials/11_diagnostics_simulation_based_calibration.ipynb b/tutorials/11_diagnostics_simulation_based_calibration.ipynb
index 0763772b3..c54f0aa9c 100644
--- a/tutorials/11_diagnostics_simulation_based_calibration.ipynb
+++ b/tutorials/11_diagnostics_simulation_based_calibration.ipynb
@@ -1098,6 +1098,7 @@
    "source": [
     "# Or, we can perform a visual check.\n",
     "from sbi.analysis.plot import plot_tarp\n",
+    "\n",
     "plot_tarp(ecp, alpha);"
    ]
   },
diff --git a/tutorials/12_iid_data_and_permutation_invariant_embeddings.ipynb b/tutorials/12_iid_data_and_permutation_invariant_embeddings.ipynb
index 3b4eccaf6..d208180b6 100644
--- a/tutorials/12_iid_data_and_permutation_invariant_embeddings.ipynb
+++ b/tutorials/12_iid_data_and_permutation_invariant_embeddings.ipynb
@@ -380,8 +380,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding\n",
     "from sbi.neural_nets import posterior_nn\n",
+    "from sbi.neural_nets.embedding_nets import FCEmbedding, PermutationInvariantEmbedding\n",
     "\n",
     "# embedding\n",
     "latent_dim = 10\n",
diff --git a/tutorials/15_importance_sampled_posteriors.ipynb b/tutorials/15_importance_sampled_posteriors.ipynb
index f18afde41..15052ffd6 100644
--- a/tutorials/15_importance_sampled_posteriors.ipynb
+++ b/tutorials/15_importance_sampled_posteriors.ipynb
@@ -55,13 +55,13 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from torch import ones, eye\n",
     "import torch\n",
+    "from torch import eye, ones\n",
     "from torch.distributions import MultivariateNormal\n",
     "\n",
+    "from sbi.analysis import marginal_plot\n",
     "from sbi.inference import NPE, ImportanceSamplingPosterior\n",
-    "from sbi.utils import BoxUniform\n",
-    "from sbi.analysis import marginal_plot"
+    "from sbi.utils import BoxUniform"
    ]
   },
   {
diff --git a/tutorials/16_implemented_methods.ipynb b/tutorials/16_implemented_methods.ipynb
index c9730e49a..9dcb78240 100644
--- a/tutorials/16_implemented_methods.ipynb
+++ b/tutorials/16_implemented_methods.ipynb
@@ -71,7 +71,7 @@
     "    x = simulator(theta)\n",
     "    # NPE-A trains a Gaussian density estimator in all but the last round. In the last round,\n",
     "    # it trains a mixture of Gaussians, which is why we have to pass the `final_round` flag.\n",
-    "    final_round = True if r == num_rounds - 1 else False\n",
+    "    final_round = r == num_rounds - 1\n",
     "    _ = inference.append_simulations(theta, x, proposal=proposal).train(final_round=final_round)\n",
     "    posterior = inference.build_posterior().set_default_x(x_o)\n",
     "    proposal = posterior"
@@ -454,8 +454,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from sbi.diagnostics import run_sbc\n",
     "from sbi.analysis import sbc_rank_plot\n",
+    "from sbi.diagnostics import run_sbc\n",
     "\n",
     "thetas = prior.sample((1000,))\n",
     "xs = simulator(thetas)\n",
@@ -532,8 +532,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from sbi.diagnostics.tarp import run_tarp\n",
     "from sbi.analysis import plot_tarp\n",
+    "from sbi.diagnostics.tarp import run_tarp\n",
     "\n",
     "thetas = prior.sample((1000,))\n",
     "xs = simulator(thetas)\n",
diff --git a/tutorials/17_plotting_functionality.ipynb b/tutorials/17_plotting_functionality.ipynb
index fdf0eec7b..418467896 100644
--- a/tutorials/17_plotting_functionality.ipynb
+++ b/tutorials/17_plotting_functionality.ipynb
@@ -22,10 +22,11 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "import torch \n",
-    "from sbi.analysis import pairplot\n",
+    "import torch\n",
     "from toy_posterior_for_07_cc import ExamplePosterior\n",
     "\n",
+    "from sbi.analysis import pairplot\n",
+    "\n",
     "posterior = ExamplePosterior()\n",
     "posterior_samples = posterior.sample((100,))"
    ]
diff --git a/tutorials/18_training_interface.ipynb b/tutorials/18_training_interface.ipynb
index 63e954c7f..6e1a17cdc 100644
--- a/tutorials/18_training_interface.ipynb
+++ b/tutorials/18_training_interface.ipynb
@@ -38,13 +38,14 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "from typing import Callable\n",
+    "\n",
     "import torch\n",
-    "from torch import ones, eye\n",
+    "from torch import eye, ones\n",
     "from torch.optim import Adam, AdamW\n",
     "\n",
-    "from sbi.utils import BoxUniform\n",
     "from sbi.analysis import pairplot\n",
-    "from typing import Callable"
+    "from sbi.utils import BoxUniform"
    ]
   },
   {
@@ -466,7 +467,7 @@
     "num_epochs = 100\n",
     "\n",
     "for ep in range(num_epochs):\n",
-    "    for idx, (theta_batch, x_batch) in enumerate(train_loader):\n",
+    "    for _, (theta_batch, x_batch) in enumerate(train_loader):\n",
     "        optw.zero_grad()\n",
     "        losses = maf_estimator.loss(theta_batch, condition=x_batch)\n",
     "        loss = torch.mean(losses)\n",
@@ -624,9 +625,9 @@
    "metadata": {},
    "outputs": [],
    "source": [
- "from sbi.neural_nets.net_builders import build_resnet_classifier\n", + "from sbi import utils as utils\n", "from sbi.inference.potentials import ratio_estimator_based_potential\n", - "from sbi import utils as utils" + "from sbi.neural_nets.net_builders import build_resnet_classifier" ] }, { diff --git a/tutorials/19_flowmatching_and_scorematching.ipynb b/tutorials/19_flowmatching_and_scorematching.ipynb index deaeb33f4..6768c4da0 100644 --- a/tutorials/19_flowmatching_and_scorematching.ipynb +++ b/tutorials/19_flowmatching_and_scorematching.ipynb @@ -37,9 +37,9 @@ "source": [ "import torch\n", "\n", + "from sbi.analysis import pairplot\n", "from sbi.inference import NPSE\n", - "from sbi.utils import BoxUniform\n", - "from sbi.analysis import pairplot" + "from sbi.utils import BoxUniform" ] }, { diff --git a/tutorials/Example_01_DecisionMakingModel.ipynb b/tutorials/Example_01_DecisionMakingModel.ipynb index eb16182c2..b211b19ff 100644 --- a/tutorials/Example_01_DecisionMakingModel.ipynb +++ b/tutorials/Example_01_DecisionMakingModel.ipynb @@ -77,6 +77,7 @@ "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", + "from example_01_utils import BinomialGammaPotential\n", "from pyro.distributions import InverseGamma\n", "from torch import Tensor\n", "from torch.distributions import Beta, Binomial, Gamma\n", @@ -86,10 +87,7 @@ "from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential\n", "from sbi.neural_nets import likelihood_nn\n", "from sbi.utils import BoxUniform, MultipleIndependent, mcmc_transform\n", - "from sbi.utils.metrics import c2st\n", - "\n", - "\n", - "from example_01_utils import BinomialGammaPotential" + "from sbi.utils.metrics import c2st" ] }, { @@ -722,7 +720,7 @@ " [\"Prior\", \"Reference\", \"MNLE\", r\"$\\theta_o$\"],\n", " frameon=False,\n", " fontsize=12,\n", - ");\n", + ")\n", "print(\"c2st between true and MNLE posterior:\", c2st(true_posterior_samples, conditional_samples).item())" ] },