diff --git a/blackjax/smc/partial_posteriors_path.py b/blackjax/smc/partial_posteriors_path.py index 1279ad245..81f19716d 100644 --- a/blackjax/smc/partial_posteriors_path.py +++ b/blackjax/smc/partial_posteriors_path.py @@ -86,7 +86,10 @@ def log_weights_fn(x): key, state, num_mcmc_steps, mcmc_parameters, logposterior_fn, log_weights_fn ) - return PartialPosteriorsSMCState(state.particles, state.weights, data_mask), info + return ( + PartialPosteriorsSMCState(state.particles, state.weights, data_mask), + info, + ) return step diff --git a/tests/smc/test_partial_posteriors_smc.py b/tests/smc/test_partial_posteriors_smc.py index 5d5a5e0ed..78d57a934 100644 --- a/tests/smc/test_partial_posteriors_smc.py +++ b/tests/smc/test_partial_posteriors_smc.py @@ -62,7 +62,12 @@ def partial_logposterior(x): data_masks = jnp.array( [ - jnp.concat([jnp.ones(datapoints_chosen), jnp.zeros(dataset_size - datapoints_chosen)]) + jnp.concat( + [ + jnp.ones(datapoints_chosen), + jnp.zeros(dataset_size - datapoints_chosen), + ] + ) for datapoints_chosen in np.arange(100, 1001, 50) ] )