Update and pin pre-commit and ruff to recent versions #1358

Merged · 3 commits · Jan 13, 2025
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,12 +1,12 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.3.3
+    rev: v0.9.0
     hooks:
       - id: ruff
       - id: ruff-format
         args: [--diff]
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.5.0
+    rev: v5.0.0
     hooks:
       - id: check-added-large-files
       - id: check-merge-conflict
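Pinning exact hook revisions keeps every contributor on the same tool versions. After a bump like this, the usual workflow is `pre-commit run --all-files` so the whole tree is reformatted in one commit; that one-off churn is exactly what the Python diffs below contain. (`pre-commit autoupdate` is the standard way to refresh these `rev` pins.)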
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -64,10 +64,10 @@ doc = [
 dev = [
     "ffmpeg",
     # Lint
-    "pre-commit == 3.5.0",
+    "pre-commit == 4.0.1",
     "pyyaml",
     "pyright",
-    "ruff>=0.3.3",
+    "ruff==0.9.0",
     # Test
     "pytest",
     "pytest-cov",
@@ -106,6 +106,7 @@ ignore = [
 [tool.ruff.lint.extend-per-file-ignores]
 "__init__.py" = ["E402", "F401", "F403"] # allow unused imports and undefined names
 "test_*.py" = ["F403", "F405"]
+"tutorials/*.ipynb" = ["E501"] # allow long lines in notebooks
 
 [tool.ruff.lint.isort]
 case-sensitive = true
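The two exact pins above are the substance of this PR; note that `ruff==0.9.0` in the `dev` extra matches the `v0.9.0` hook rev, so the formatter behaves identically locally and in CI. Everything below is mechanical fallout from the new formatter style, which keeps an assert condition on one line and parenthesizes a long message instead of splitting the condition. A minimal before/after sketch with made-up names (`labels` and `samples` are illustrative, not code from the repo):

samples = [[0.1, 0.2], [0.3, 0.4]]
labels = ["prior", "posterior"]

# Old style (ruff-format 0.3.x): the condition is broken across lines so
# the statement fits, and the message trails the closing parenthesis.
assert len(labels) >= len(
    samples
), "Provide at least as many labels as samples."

# New style (ruff-format 0.9): the condition stays intact and the message
# is wrapped in its own parentheses.
assert len(labels) >= len(samples), (
    "Provide at least as many labels as samples."
)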
30 changes: 15 additions & 15 deletions sbi/analysis/plot.py
@@ -775,9 +775,9 @@ def pairplot(
 
     # checks.
     if fig_kwargs_filled["legend"]:
-        assert len(fig_kwargs_filled["samples_labels"]) >= len(
-            samples
-        ), "Provide at least as many labels as samples."
+        assert len(fig_kwargs_filled["samples_labels"]) >= len(samples), (
+            "Provide at least as many labels as samples."
+        )
     if offdiag is not None:
         warn("offdiag is deprecated, use upper or lower instead.", stacklevel=2)
         upper = offdiag
@@ -1594,9 +1594,9 @@ def _sbc_rank_plot(
         ranks_list[idx]: np.ndarray = rank.numpy()  # type: ignore
 
     plot_types = ["hist", "cdf"]
-    assert (
-        plot_type in plot_types
-    ), "plot type {plot_type} not implemented, use one in {plot_types}."
+    assert plot_type in plot_types, (
+        f"plot type {plot_type} not implemented, use one in {plot_types}."
+    )
 
     if legend_kwargs is None:
         legend_kwargs = dict(loc="best", handlelength=0.8)
@@ -1609,9 +1609,9 @@
         params_in_subplots = True
 
         for ranki in ranks_list:
-            assert (
-                ranki.shape == ranks_list[0].shape
-            ), "all ranks in list must have the same shape."
+            assert ranki.shape == ranks_list[0].shape, (
+                "all ranks in list must have the same shape."
+            )
 
         num_rows = int(np.ceil(num_parameters / num_cols))
         if figsize is None:
@@ -1636,9 +1636,9 @@
             )
             ax = np.atleast_1d(ax)  # type: ignore
         else:
-            assert (
-                ax.size >= num_parameters
-            ), "There must be at least as many subplots as parameters."
+            assert ax.size >= num_parameters, (
+                "There must be at least as many subplots as parameters."
+            )
             num_rows = ax.shape[0] if ax.ndim > 1 else 1
         assert ax is not None
 
@@ -2221,9 +2221,9 @@ def pairplot_dep(
 
     # checks.
     if opts["legend"]:
-        assert len(opts["samples_labels"]) >= len(
-            samples
-        ), "Provide at least as many labels as samples."
+        assert len(opts["samples_labels"]) >= len(samples), (
+            "Provide at least as many labels as samples."
+        )
     if opts["upper"] is not None:
         opts["offdiag"] = opts["upper"]
 
12 changes: 6 additions & 6 deletions sbi/analysis/sensitivity_analysis.py
@@ -250,9 +250,9 @@
             prevent exploding gradients. Use `None` for no clipping.
         """
 
-        assert (
-            self._theta is not None and self._emergent_property is not None
-        ), "You must call .add_property() first."
+        assert self._theta is not None and self._emergent_property is not None, (
+            "You must call .add_property() first."
+        )
 
         # Get indices for permutation of the data.
         num_examples = len(self._theta)
@@ -433,9 +433,9 @@
         if posterior_log_prob_as_property:
             predictions = self._posterior.potential(thetas, track_gradients=True)
         else:
-            assert (
-                self._regression_net is not None
-            ), "self._regression_net is None, you must call `.train()` first."
+            assert self._regression_net is not None, (
+                "self._regression_net is None, you must call `.train()` first."
+            )
             predictions = self._regression_net.forward(thetas)
         loss = predictions.mean()
         loss.backward()
30 changes: 15 additions & 15 deletions sbi/diagnostics/lc2st.py
@@ -83,9 +83,9 @@ def __init__(
         [2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
         """
 
-        assert (
-            thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0]
-        ), "Number of samples must match"
+        assert thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0], (
+            "Number of samples must match"
+        )
 
         # set observed data for classification
         self.theta_p = posterior_samples
@@ -283,9 +283,9 @@ def get_statistic_on_observed_data(
         Returns:
             L-C2ST statistic at `x_o`.
         """
-        assert (
-            self.trained_clfs is not None
-        ), "No trained classifiers found. Run `train_on_observed_data` first."
+        assert self.trained_clfs is not None, (
+            "No trained classifiers found. Run `train_on_observed_data` first."
+        )
         _, scores = self.get_scores(
             theta_o=theta_o,
             x_o=x_o,
@@ -372,9 +372,9 @@ def train_under_null_hypothesis(
                     joint_q_perm[:, self.theta_q.shape[1] :],
                 )
             else:
-                assert (
-                    self.null_distribution is not None
-                ), "You need to provide a null distribution"
+                assert self.null_distribution is not None, (
+                    "You need to provide a null distribution"
+                )
                 theta_p_t = self.null_distribution.sample((self.theta_p.shape[0],))
                 theta_q_t = self.null_distribution.sample((self.theta_p.shape[0],))
                 x_p_t, x_q_t = self.x_p, self.x_q
@@ -419,9 +419,9 @@ def get_statistics_under_null_hypothesis(
                 Run `train_under_null_hypothesis`."
             )
         else:
-            assert (
-                len(self.trained_clfs_null) == self.num_trials_null
-            ), "You need one classifier per trial."
+            assert len(self.trained_clfs_null) == self.num_trials_null, (
+                "You need one classifier per trial."
+            )
 
         probs_null, stats_null = [], []
         for t in tqdm(
@@ -433,9 +433,9 @@
             if self.permutation:
                 theta_o_t = theta_o
             else:
-                assert (
-                    self.null_distribution is not None
-                ), "You need to provide a null distribution"
+                assert self.null_distribution is not None, (
+                    "You need to provide a null distribution"
+                )
 
                 theta_o_t = self.null_distribution.sample((theta_o.shape[0],))
 
6 changes: 3 additions & 3 deletions sbi/diagnostics/sbc.py
@@ -69,9 +69,9 @@ def run_sbc(
             stacklevel=2,
         )
 
-    assert (
-        thetas.shape[0] == xs.shape[0]
-    ), "Unequal number of parameters and observations."
+    assert thetas.shape[0] == xs.shape[0], (
+        "Unequal number of parameters and observations."
+    )
 
     if "sbc_batch_size" in kwargs:
         warnings.warn(
6 changes: 3 additions & 3 deletions sbi/diagnostics/tarp.py
@@ -133,9 +133,9 @@ def _run_tarp(
     """
    num_posterior_samples, num_tarp_samples, _ = posterior_samples.shape
 
-    assert (
-        references.shape == thetas.shape
-    ), "references must have the same shape as thetas"
+    assert references.shape == thetas.shape, (
+        "references must have the same shape as thetas"
+    )
 
     if num_bins is None:
         num_bins = num_tarp_samples // 10
6 changes: 3 additions & 3 deletions sbi/inference/abc/mcabc.py
@@ -130,9 +130,9 @@
         """
 
         # Exactly one of eps or quantile need to be passed.
-        assert (eps is not None) ^ (
-            quantile is not None
-        ), "Eps or quantile must be passed, but not both."
+        assert (eps is not None) ^ (quantile is not None), (
+            "Eps or quantile must be passed, but not both."
+        )
         if kde_kwargs is None:
             kde_kwargs = {}
 
24 changes: 12 additions & 12 deletions sbi/inference/abc/smcabc.py
@@ -95,9 +95,9 @@
         )
 
         kernels = ("gaussian", "uniform")
-        assert (
-            kernel in kernels
-        ), f"Kernel '{kernel}' not supported. Choose one from {kernels}."
+        assert kernel in kernels, (
+            f"Kernel '{kernel}' not supported. Choose one from {kernels}."
+        )
         self.kernel = kernel
 
         algorithm_variants = ("A", "B", "C")
@@ -198,13 +198,13 @@
         if kde_kwargs is None:
             kde_kwargs = {}
         assert isinstance(epsilon_decay, float) and epsilon_decay > 0.0
-        assert not (
-            self.distance.requires_iid_data and lra
-        ), "Currently there is no support to run inference "
-        "on multiple observations together with lra."
-        assert not (
-            self.distance.requires_iid_data and sass
-        ), "Currently there is no support to run inference "
-        "on multiple observations together with sass."
+        assert not (self.distance.requires_iid_data and lra), (
+            "Currently there is no support to run inference "
+            "on multiple observations together with lra."
+        )
+        assert not (self.distance.requires_iid_data and sass), (
+            "Currently there is no support to run inference "
+            "on multiple observations together with sass."
+        )
 
         # Pilot run for SASS.
@@ -363,9 +363,9 @@
     ) -> Tuple[Tensor, float, Tensor, Tensor]:
         """Return particles, epsilon and distances of initial population."""
 
-        assert (
-            num_particles <= num_initial_pop
-        ), "number of initial round simulations must be greater than population size"
+        assert num_particles <= num_initial_pop, (
+            "number of initial round simulations must be greater than population size"
+        )
 
         assert (x_o.shape[0] == 1) or self.distance.requires_iid_data, (
             "Your data contain iid data-points, but the choice of "
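A subtlety in the `lra`/`sass` asserts above: each message is built from two adjacent string literals, which are only implicitly concatenated if both sit inside the parentheses. A literal left outside parses as a separate statement that does nothing, and the assertion message is silently truncated. A minimal sketch of the pitfall, using a throwaway `flag` variable:

flag = True

# Buggy: only "first half " is attached to the assert; the second literal
# is a separate expression statement that is evaluated and discarded.
assert flag, "first half "
"of the message."

# Correct: parentheses join the two literals into one message string.
assert flag, (
    "first half "
    "of the message."
)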
4 changes: 1 addition & 3 deletions sbi/inference/posteriors/base_posterior.py
@@ -288,9 +288,7 @@
         return desc
 
     def __str__(self):
-        desc = (
-            f"Posterior p(θ|x) of type {self.__class__.__name__}. " f"{self._purpose}"
-        )
+        desc = f"Posterior p(θ|x) of type {self.__class__.__name__}. {self._purpose}"
         return desc
 
     def __getstate__(self) -> Dict:
6 changes: 3 additions & 3 deletions sbi/inference/posteriors/ensemble_posterior.py
@@ -112,9 +112,9 @@ def ensure_same_device(self, posteriors: List) -> str:
             A device string, that is the same for all posteriors.
         """
         devices = [posterior._device for posterior in posteriors]
-        assert all(
-            device == devices[0] for device in devices
-        ), "Only supported if all posteriors are on the same device."
+        assert all(device == devices[0] for device in devices), (
+            "Only supported if all posteriors are on the same device."
+        )
         return devices[0]
 
     @property
12 changes: 6 additions & 6 deletions sbi/inference/posteriors/mcmc_posterior.py
@@ -418,9 +418,9 @@ def sample_batched(
             else init_strategy_parameters
         )
 
-        assert (
-            method == "slice_np_vectorized"
-        ), "Batched sampling only supported for vectorized samplers!"
+        assert method == "slice_np_vectorized", (
+            "Batched sampling only supported for vectorized samplers!"
+        )
 
         # warn if num_chains is larger than num requested samples
         if num_chains > torch.Size(sample_shape).numel():
@@ -1003,9 +1003,9 @@ def get_arviz_inference_data(self) -> InferenceData:
         Returns:
             inference_data: Arviz InferenceData object.
         """
-        assert (
-            self._posterior_sampler is not None
-        ), """No samples have been generated, call .sample() first."""
+        assert self._posterior_sampler is not None, (
+            """No samples have been generated, call .sample() first."""
+        )
 
         sampler: Union[
             MCMC, SliceSamplerSerial, SliceSamplerVectorized, PyMCSampler
24 changes: 12 additions & 12 deletions sbi/inference/potentials/likelihood_based_potential.py
@@ -104,11 +104,11 @@
         # Calculate likelihood for each (theta,x) pair separately
         theta_batch_size = theta.shape[0]
         x_batch_size = self.x_o.shape[0]
-        assert (
-            theta_batch_size == x_batch_size
-        ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\
+        assert theta_batch_size == x_batch_size, (
+            f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\
             When performing batched sampling for multiple `x`, the batch size of\
             `theta` must match the batch size of `x`."
+        )
         x = self.x_o.unsqueeze(0)
         with torch.set_grad_enabled(track_gradients):
             log_likelihood_batches = self.likelihood_estimator.log_prob(
@@ -143,9 +143,9 @@
         def conditioned_potential(
             theta: Tensor, x_o: Optional[Tensor] = None, track_gradients: bool = True
         ) -> Tensor:
-            assert (
-                len(dims_global_theta) == theta.shape[1]
-            ), "dims_global_theta must match the number of parameters to sample."
+            assert len(dims_global_theta) == theta.shape[1], (
+                "dims_global_theta must match the number of parameters to sample."
+            )
             global_theta = theta[:, dims_global_theta]
             x_o = x_o if x_o is not None else self.x_o
             # x needs shape (sample_dim (iid), batch_dim (xs), *event_shape)
@@ -257,15 +257,15 @@
         theta_batch_dim)`.
     """
     assert x.dim() > 2, "x must have shape (sample_dim, batch_dim, *event_shape)."
-    assert (
-        local_theta.dim() == 2
-    ), "condition must have shape (sample_dim, num_conditions)."
+    assert local_theta.dim() == 2, (
+        "condition must have shape (sample_dim, num_conditions)."
+    )
     assert global_theta.dim() == 2, "theta must have shape (batch_dim, num_parameters)."
     num_trials, num_xs = x.shape[:2]
     num_thetas = global_theta.shape[0]
-    assert (
-        local_theta.shape[0] == num_trials
-    ), "Condition batch size must match the number of iid trials in x."
+    assert local_theta.shape[0] == num_trials, (
+        "Condition batch size must match the number of iid trials in x."
+    )
 
     # move the iid batch dimension onto the batch dimension of theta and repeat it there
     x_repeated = torch.transpose(x, 0, 1).repeat_interleave(num_thetas, dim=1)
6 changes: 3 additions & 3 deletions sbi/inference/potentials/posterior_based_potential.py
@@ -125,11 +125,11 @@ def __call__(self, theta: Tensor, track_gradients: bool = True) -> Tensor:
         theta_batch_size = theta.shape[0]
         x_batch_size = x.shape[0]
 
-        assert (
-            theta_batch_size == x_batch_size or x_batch_size == 1
-        ), f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\
+        assert theta_batch_size == x_batch_size or x_batch_size == 1, (
+            f"Batch size mismatch: {theta_batch_size} and {x_batch_size}.\
             When performing batched sampling for multiple `x`, the batch size of\
             `theta` must match the batch size of `x`."
+        )
 
         if x_batch_size == 1:
             # If a single `x` is passed (i.e. batchsize==1), we squeeze
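This last hunk also documents the potential's batching contract: either one `x` per `theta`, or a single `x` that is broadcast across the whole `theta` batch. A self-contained sketch of the same check, using hypothetical tensor shapes rather than the library's API:

import torch

theta = torch.randn(10, 3)  # batch of 10 parameter sets, 3 parameters each
x = torch.randn(1, 5)       # a single observation, broadcast to every theta

theta_batch_size = theta.shape[0]
x_batch_size = x.shape[0]

# Mirrors the assert above: batch sizes must match unless there is one x.
assert theta_batch_size == x_batch_size or x_batch_size == 1, (
    f"Batch size mismatch: {theta_batch_size} and {x_batch_size}."
)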