From 3c14a9cdb08f5b5eaebe4ca20d3188b97ff070be Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 15 Mar 2024 16:28:34 -0300 Subject: [PATCH 1/8] vmaping over parameters in base --- blackjax/smc/base.py | 9 ++-- blackjax/smc/inner_kernel_tuning.py | 12 +++-- blackjax/smc/tempered.py | 6 +-- blackjax/smc/tuning/__init__.py | 31 +++++++++++++ blackjax/smc/tuning/from_particles.py | 1 + tests/mcmc/test_sampling.py | 6 +-- tests/smc/test_inner_kernel_tuning.py | 25 +++++------ tests/smc/test_kernel_compatibility.py | 61 ++++++++++++++++++-------- tests/smc/test_smc.py | 56 ++++++++++++++--------- tests/smc/test_tempered_smc.py | 13 +++--- 10 files changed, 149 insertions(+), 71 deletions(-) diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 409f588d2..5b5085648 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -40,6 +40,7 @@ class SMCState(NamedTuple): particles: ArrayTree weights: Array + update_parameters: ArrayTree class SMCInfo(NamedTuple): @@ -59,12 +60,12 @@ class SMCInfo(NamedTuple): update_info: NamedTuple -def init(particles: ArrayLikeTree): +def init(particles: ArrayLikeTree, init_update_params): # Infer the number of particles from the size of the leading dimension of # the first leaf of the inputted PyTree. num_particles = jax.tree_util.tree_flatten(particles)[0][0].shape[0] weights = jnp.ones(num_particles) / num_particles - return SMCState(particles, weights) + return SMCState(particles, weights, init_update_params) def step( @@ -137,13 +138,13 @@ def step( particles = jax.tree_map(lambda x: x[resampling_idx], state.particles) keys = jax.random.split(updating_key, num_resampled) - particles, update_info = update_fn(keys, particles) + particles, update_info = update_fn(keys, particles, state.update_parameters) log_weights = weight_fn(particles) logsum_weights = jax.scipy.special.logsumexp(log_weights) normalizing_constant = logsum_weights - jnp.log(num_particles) weights = jnp.exp(log_weights - logsum_weights) - return SMCState(particles, weights), SMCInfo( + return SMCState(particles, weights, state.update_parameters), SMCInfo( resampling_idx, normalizing_constant, update_info ) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 6aaf3a5d3..e50bd944d 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -8,8 +8,14 @@ class StateWithParameterOverride(NamedTuple): + """ + Stores both the sampling status and also a dictionary + that contains an dictionary with parameter names as key + and (n_particles, *) arrays as meanings. The latter + represent a parameter per chain for the next mutation step. + """ sampler_state: ArrayTree - parameter_override: ArrayTree + parameter_override: Dict[str, ArrayTree] def init(alg_init_fn, position, initial_parameter_value): @@ -42,7 +48,7 @@ def build_kernel( loglikelihood_fn A function that returns the probability at a given position. mcmc_factory - A callable that can construct an inner kernel out of the newly-computed parameter + A callable that can construct an array of kernels out of newly-computed parameters. mcmc_init_fn A callable that initializes the inner kernel mcmc_parameters @@ -59,7 +65,7 @@ def kernel( step_fn = smc_algorithm( logprior_fn=logprior_fn, loglikelihood_fn=loglikelihood_fn, - mcmc_step_fn=mcmc_factory(state.parameter_override), + mcmc_step_fn=mcmc_factory(**state.parameter_override), mcmc_init_fn=mcmc_init_fn, mcmc_parameters=mcmc_parameters, resampling_fn=resampling_fn, diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py index 49fa21277..561eadecc 100644 --- a/blackjax/smc/tempered.py +++ b/blackjax/smc/tempered.py @@ -127,12 +127,12 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float: tempered_loglikelihood = state.lmbda * loglikelihood_fn(position) return logprior + tempered_loglikelihood - def mcmc_kernel(rng_key, position): + def mcmc_kernel(rng_key, position, step_parameters): state = mcmc_init_fn(position, tempered_logposterior_fn) def body_fn(state, rng_key): new_state, info = mcmc_step_fn( - rng_key, state, tempered_logposterior_fn, **mcmc_parameters + rng_key, state, tempered_logposterior_fn, **step_parameters ) return new_state, info @@ -142,7 +142,7 @@ def body_fn(state, rng_key): smc_state, info = smc.base.step( rng_key, - SMCState(state.particles, state.weights), + SMCState(state.particles, state.weights, mcmc_parameters), jax.vmap(mcmc_kernel), jax.vmap(log_weights_fn), resampling_fn, diff --git a/blackjax/smc/tuning/__init__.py b/blackjax/smc/tuning/__init__.py index e69de29bb..6d53fc3f8 100644 --- a/blackjax/smc/tuning/__init__.py +++ b/blackjax/smc/tuning/__init__.py @@ -0,0 +1,31 @@ +import jax +import jax.numpy as jnp +import numpy as np + + +def extend_to_all_particles(n_particles, tuning_strategy): + """ + given a tuning strategy that returns a single parameter, + that parameter gets extended to be applied to all particles + """ + def extended(state, info): + res = tuning_strategy(state, info) + return jnp.repeat(res, n_particles) + + return extended + +def extend_params_inner_kernel(n_particles, params): + """ + Given a dictionary of params, repeats them for every single particle + Shapes> + scalar, 1000, + 1 . 1000, 1 + 2,2 . 1000,2,2 + """ + def extend(param): + if np.isscalar(param): + return jnp.repeat(param, n_particles, axis=0) + else: + return jnp.repeat(jnp.atleast_1d(param)[np.newaxis, :], n_particles, axis=0) + + return jax.tree_map(extend, params) \ No newline at end of file diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 4c8ca98da..ae5f3c88a 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -48,3 +48,4 @@ def particles_as_rows(particles): as a matrix where each column is a variable, each row a particle. """ return jax.vmap(lambda x: ravel_pytree(x)[0])(particles) + diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 7d20805ab..18b241cf1 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -27,12 +27,12 @@ def sample_orbit(orbit, weights, rng_key): return samples -def irmh_proposal_distribution(rng_key): +def irmh_proposal_distribution(rng_key, mean): """ The proposal distribution is chosen to be wider than the target, so that the RMH rejection doesn't make the sample overemphasize the center of the target distribution. """ - return 1.0 + jax.random.normal(rng_key) * 25.0 + return mean + jax.random.normal(rng_key) * 25.0 def rmh_proposal_distribution(rng_key, position): @@ -657,7 +657,7 @@ def test_univariate_normal( self, algorithm, initial_position, parameters, num_sampling_steps, burnin ): if algorithm == blackjax.irmh: - parameters["proposal_distribution"] = irmh_proposal_distribution + parameters["proposal_distribution"] = functools.partial(irmh_proposal_distribution, mean=1.0) if algorithm == blackjax.rmh: parameters["proposal_generator"] = rmh_proposal_distribution diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 1bbc68970..d8f0752f5 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -92,18 +92,18 @@ def smc_inner_kernel_tuning_test_case( proposal_factory.return_value = 100 def mcmc_parameter_update_fn(state, info): - return 100 + return {"mean":100} mcmc_factory = MagicMock() sampling_algorithm = MagicMock() mcmc_factory.return_value = sampling_algorithm prior = lambda x: stats.norm.logpdf(x) - def kernel_factory(proposal_distribution): + def kernel_factory(mean): kernel = blackjax.irmh.build_kernel() def wrapped_kernel(rng_key, state, logdensity): - return kernel(rng_key, state, logdensity, proposal_distribution) + return kernel(rng_key, state, logdensity, functools.partial(irmh_proposal_distribution, mean=mean)) return wrapped_kernel @@ -116,14 +116,14 @@ def wrapped_kernel(rng_key, state, logdensity): smc_algorithm=smc_algorithm, mcmc_parameters={}, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value=irmh_proposal_distribution, + initial_parameter_value={"mean":1.0}, **smc_parameters, ) new_state, new_info = kernel.step( self.key, state=kernel.init(init_particles), **step_parameters ) - assert new_state.parameter_override == 100 + assert new_state.parameter_override == {"mean":100} class MeanAndStdFromParticlesTest(chex.TestCase): @@ -294,10 +294,9 @@ def test_with_adaptive_tempered(self): blackjax.hmc.init, {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( - state.particles - ), - initial_parameter_value=jnp.eye(2), + mcmc_parameter_update_fn=lambda state, info: {"mass_matrix": + mass_matrix_from_particles(state.particles)}, + initial_parameter_value={"mass_matrix":jnp.eye(2)}, num_mcmc_steps=10, target_ess=0.5, ) @@ -319,7 +318,7 @@ def body(carry): state, _ = inference_loop(smc_kernel, self.key, init_state) - assert state.parameter_override.shape == (2, 2) + assert state.parameter_override["mass_matrix"].shape == (2, 2) self.assert_linear_regression_test_case(state.sampler_state) @chex.all_variants(with_pmap=False) @@ -339,10 +338,10 @@ def test_with_tempered_smc(self): blackjax.hmc.init, {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: mass_matrix_from_particles( + mcmc_parameter_update_fn=lambda state, info: {"mass_matrix":mass_matrix_from_particles( state.particles - ), - initial_parameter_value=jnp.eye(2), + )}, + initial_parameter_value={"mass_matrix":jnp.eye(2)}, num_mcmc_steps=10, ) diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index 3d2469914..e79a02a1e 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -7,6 +7,7 @@ import blackjax from blackjax import adaptive_tempered_smc from blackjax.mcmc.random_walk import normal +from blackjax.smc.tuning import extend_params_inner_kernel class SMCAndMCMCIntegrationTest(unittest.TestCase): @@ -18,10 +19,12 @@ class SMCAndMCMCIntegrationTest(unittest.TestCase): def setUp(self): super().setUp() self.key = jax.random.key(42) + self.n_particles = 3 self.initial_particles = jax.random.multivariate_normal( - self.key, jnp.zeros(2), jnp.eye(2), (3,) + self.key, jnp.zeros(2), jnp.eye(2), (self.n_particles,) ) + def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): """ Runs one SMC step @@ -40,54 +43,74 @@ def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): kernel(self.key, init(self.initial_particles)) def test_compatible_with_rwm(self): + rwm = blackjax.additive_step_random_walk.build_kernel() + + def kernel(rng_key, state, logdensity_fn, proposal_mean): + return rwm(rng_key, state, logdensity_fn, normal(proposal_mean)) + self.check_compatible( - blackjax.additive_step_random_walk.build_kernel(), + kernel, blackjax.additive_step_random_walk.init, - {"random_step": normal(1.0)}, + extend_params_inner_kernel(self.n_particles, {"proposal_mean": 1.0}) ) def test_compatible_with_rmh(self): + rmh = blackjax.rmh.build_kernel() + + def kernel( + rng_key, + state, + logdensity_fn, + proposal_mean, + proposal_logdensity_fn=None + ): + return rmh(rng_key, + state, + logdensity_fn, + lambda a, b: blackjax.mcmc.random_walk.normal(proposal_mean)(a, b), + proposal_logdensity_fn) self.check_compatible( - blackjax.rmh.build_kernel(), + kernel, blackjax.rmh.init, - { - "transition_generator": lambda a, b: blackjax.mcmc.random_walk.normal( - 1.0 - )(a, b) - }, + extend_params_inner_kernel(self.n_particles, {"proposal_mean":1.0}) ) def test_compatible_with_hmc(self): self.check_compatible( blackjax.hmc.build_kernel(), blackjax.hmc.init, - { + extend_params_inner_kernel(self.n_particles,{ "step_size": 0.3, - "inverse_mass_matrix": jnp.array([1]), + "inverse_mass_matrix": jnp.array([1.]), "num_integration_steps": 1, - }, + }), ) def test_compatible_with_irmh(self): + def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None): + return blackjax.irmh.build_kernel()(rng_key, + state, + logdensity_fn, + lambda key: mean + jax.random.normal(key), proposal_logdensity_fn) + self.check_compatible( - blackjax.irmh.build_kernel(), + kernel, blackjax.irmh.init, - { - "proposal_distribution": lambda key: jnp.array([1.0, 1.0]) - + jax.random.normal(key) - }, + extend_params_inner_kernel(self.n_particles,{ + "mean": jnp.array([1.0, 1.0]) + }) ) def test_compatible_with_nuts(self): self.check_compatible( blackjax.nuts.build_kernel(), blackjax.nuts.init, - {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, + extend_params_inner_kernel(self.n_particles, {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}), ) def test_compatible_with_mala(self): self.check_compatible( - blackjax.mala.build_kernel(), blackjax.mala.init, {"step_size": 1e-10} + blackjax.mala.build_kernel(), blackjax.mala.init, extend_params_inner_kernel(self.n_particles,{"step_size": 1e-10}) ) @staticmethod diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 242e11c55..a0ee4c269 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -1,4 +1,6 @@ """Test the generic SMC sampler""" +import unittest + import chex import jax import jax.numpy as jnp @@ -9,6 +11,7 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax.smc.base import init, step +from blackjax.smc.tuning import extend_params_inner_kernel def logdensity_fn(position): @@ -31,14 +34,8 @@ def test_smc(self): num_mcmc_steps = 20 num_particles = 1000 - hmc = blackjax.hmc( - logdensity_fn, - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=50, - ) - - def update_fn(rng_key, position): + def update_fn(rng_key, position, update_params): + hmc = blackjax.hmc(logdensity_fn,**update_params) state = hmc.init(position) def body_fn(state, rng_key): @@ -53,7 +50,10 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - state = init(init_particles) + same_for_all_params = dict(step_size=1e-2, + inverse_mass_matrix=jnp.eye(1), + num_integration_steps=50) + state = init(init_particles, extend_params_inner_kernel(num_particles, same_for_all_params)) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4))( @@ -74,15 +74,12 @@ def test_smc_waste_free(self): num_particles = 1000 num_resampled = num_particles // num_mcmc_steps - hmc = blackjax.hmc( - logdensity_fn, - step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=100, - ) - - def waste_free_update_fn(keys, particles): - def one_particle_fn(rng_key, position): + def waste_free_update_fn(keys, particles, update_params): + def one_particle_fn(rng_key, position, particle_update_params): + hmc = blackjax.hmc( + logdensity_fn, + **particle_update_params + ) state = hmc.init(position) def body_fn(state, rng_key): @@ -93,7 +90,7 @@ def body_fn(state, rng_key): _, (states, info) = jax.lax.scan(body_fn, state, keys) return states.position, info - particles, info = jax.vmap(one_particle_fn)(keys, particles) + particles, info = jax.vmap(one_particle_fn)(keys, particles, update_params) particles = particles.reshape((num_particles,)) return particles, info @@ -101,7 +98,9 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - state = init(init_particles) + state = init(init_particles, extend_params_inner_kernel(num_resampled, dict(step_size=1e-2, + inverse_mass_matrix=jnp.eye(1), + num_integration_steps=100))) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4, 5))( @@ -118,5 +117,22 @@ def body_fn(state, rng_key): np.testing.assert_allclose(1.0, std, atol=1e-1) +class ExtendToParticlesTest(unittest.TestCase): + def test_extend_params_inner_kernel(self): + extended = extend_params_inner_kernel(3, {"a":50, + "b": np.array([50]), + "c": np.array([50, 60]), + "d": np.array([[1,2],[3,4]]) + }) + np.testing.assert_allclose(extended["a"], np.ones((3,))*50) + np.testing.assert_allclose(extended["b"], np.array([[50],[50],[50]])) + np.testing.assert_allclose(extended["c"], np.array([[50, 60], [50, 60], [50, 60]])) + np.testing.assert_allclose(extended["d"], np.array([[[1,2],[3,4]], + [[1, 2], [3, 4]], + [[1, 2], [3, 4]] + ])) + + + if __name__ == "__main__": absltest.main() diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index f4234d117..0729cce88 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -12,6 +12,7 @@ import blackjax.smc.resampling as resampling import blackjax.smc.solver as solver from blackjax import adaptive_tempered_smc, tempered_smc +from blackjax.smc.tuning import extend_params_inner_kernel from tests.smc import SMCLinearRegressionTestCase @@ -64,11 +65,11 @@ def logprior_fn(x): hmc_kernel = blackjax.hmc.build_kernel() hmc_init = blackjax.hmc.init - hmc_parameters = { + hmc_parameters = extend_params_inner_kernel(num_particles, { "step_size": 10e-2, "inverse_mass_matrix": jnp.eye(2), "num_integration_steps": 50, - } + }) for target_ess in [0.5, 0.75]: tempering = adaptive_tempered_smc( @@ -110,11 +111,11 @@ def test_fixed_schedule_tempered_smc(self): lambda_schedule = np.logspace(-5, 0, num_tempering_steps) hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = { + hmc_parameters = extend_params_inner_kernel(100, { "step_size": 10e-2, "inverse_mass_matrix": jnp.eye(2), "num_integration_steps": 50, - } + }) tempering = tempered_smc( logprior_fn, @@ -174,11 +175,11 @@ def test_normalizing_constant(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = { + hmc_parameters = extend_params_inner_kernel(num_particles, { "step_size": 10e-2, "inverse_mass_matrix": jnp.eye(num_dim), "num_integration_steps": 50, - } + }) tempering = adaptive_tempered_smc( logprior_fn, From 2f225d35f3f14597d4ae2db8635f2f505007fdd6 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Thu, 21 Mar 2024 15:17:55 -0300 Subject: [PATCH 2/8] switch from mcmc_factory to just passing in parameters --- blackjax/smc/__init__.py | 6 ++- blackjax/smc/base.py | 13 +++++ blackjax/smc/inner_kernel_tuning.py | 27 ++++------- blackjax/smc/tuning/__init__.py | 31 ------------ tests/smc/test_inner_kernel_tuning.py | 67 +++++++++++++------------- tests/smc/test_kernel_compatibility.py | 2 +- tests/smc/test_smc.py | 3 +- tests/smc/test_tempered_smc.py | 2 +- 8 files changed, 64 insertions(+), 87 deletions(-) diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index 180cd8259..492ed67f0 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -1,3 +1,7 @@ from . import adaptive_tempered, inner_kernel_tuning, tempered +from .base import extend_params_inner_kernel + +__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning", "extend_params_inner_kernel"] + + -__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning"] diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 5b5085648..4e8f761a4 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -148,3 +148,16 @@ def step( return SMCState(particles, weights, state.update_parameters), SMCInfo( resampling_idx, normalizing_constant, update_info ) + +def extend_params_inner_kernel(n_particles, params): + """ + Given a dictionary of params, repeats them for every single particle. The expected + usage is in cases where the aim is to repeat the same parameters for all chains within SMC. + """ + def extend(param): + if jnp.isscalar(param): + return jnp.repeat(param, n_particles, axis=0) + else: + return jnp.repeat(jnp.atleast_1d(param)[jnp.newaxis, :], n_particles, axis=0) + + return jax.tree_map(extend, params) \ No newline at end of file diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index e50bd944d..b309b9369 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -26,9 +26,8 @@ def build_kernel( smc_algorithm, logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_factory: Callable, + mcmc_step_fn: Callable, mcmc_init_fn: Callable, - mcmc_parameters: Dict, resampling_fn: Callable, mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], num_mcmc_steps: int = 10, @@ -47,12 +46,10 @@ def build_kernel( A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. - mcmc_factory - A callable that can construct an array of kernels out of newly-computed parameters. + mcmc_step_fn: + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. mcmc_init_fn A callable that initializes the inner kernel - mcmc_parameters - Other (fixed across SMC iterations) parameters for the inner kernel mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. extra_parameters: @@ -65,9 +62,9 @@ def kernel( step_fn = smc_algorithm( logprior_fn=logprior_fn, loglikelihood_fn=loglikelihood_fn, - mcmc_step_fn=mcmc_factory(**state.parameter_override), + mcmc_step_fn=mcmc_step_fn, mcmc_init_fn=mcmc_init_fn, - mcmc_parameters=mcmc_parameters, + mcmc_parameters=state.parameter_override, resampling_fn=resampling_fn, num_mcmc_steps=num_mcmc_steps, **extra_parameters, @@ -95,17 +92,15 @@ class inner_kernel_tuning: A function that computes the log density of the prior distribution loglikelihood_fn A function that returns the probability at a given position. - mcmc_factory - A callable that can construct an inner kernel out of the newly-computed parameter + mcmc_step_fn + The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. mcmc_init_fn A callable that initializes the inner kernel - mcmc_parameters - Other (fixed across SMC iterations) parameters for the inner kernel step mcmc_parameter_update_fn A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. initial_parameter_value - Paramter to be used by the mcmc_factory before the first iteration. + Parameter to be used by the mcmc_factory before the first iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. @@ -123,9 +118,8 @@ def __new__( # type: ignore[misc] smc_algorithm: Union[adaptive_tempered_smc, tempered_smc], logprior_fn: Callable, loglikelihood_fn: Callable, - mcmc_factory: Callable, + mcmc_step_fn: Callable, mcmc_init_fn: Callable, - mcmc_parameters: Dict, resampling_fn: Callable, mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], initial_parameter_value, @@ -136,9 +130,8 @@ def __new__( # type: ignore[misc] smc_algorithm, logprior_fn, loglikelihood_fn, - mcmc_factory, + mcmc_step_fn, mcmc_init_fn, - mcmc_parameters, resampling_fn, mcmc_parameter_update_fn, num_mcmc_steps, diff --git a/blackjax/smc/tuning/__init__.py b/blackjax/smc/tuning/__init__.py index 6d53fc3f8..e69de29bb 100644 --- a/blackjax/smc/tuning/__init__.py +++ b/blackjax/smc/tuning/__init__.py @@ -1,31 +0,0 @@ -import jax -import jax.numpy as jnp -import numpy as np - - -def extend_to_all_particles(n_particles, tuning_strategy): - """ - given a tuning strategy that returns a single parameter, - that parameter gets extended to be applied to all particles - """ - def extended(state, info): - res = tuning_strategy(state, info) - return jnp.repeat(res, n_particles) - - return extended - -def extend_params_inner_kernel(n_particles, params): - """ - Given a dictionary of params, repeats them for every single particle - Shapes> - scalar, 1000, - 1 . 1000, 1 - 2,2 . 1000,2,2 - """ - def extend(param): - if np.isscalar(param): - return jnp.repeat(param, n_particles, axis=0) - else: - return jnp.repeat(jnp.atleast_1d(param)[np.newaxis, :], n_particles, axis=0) - - return jax.tree_map(extend, params) \ No newline at end of file diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index d8f0752f5..af78ec7ce 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -12,6 +12,8 @@ import blackjax import blackjax.smc.resampling as resampling from blackjax import adaptive_tempered_smc, tempered_smc +from blackjax.mcmc.random_walk import build_irmh +from blackjax.smc import extend_params_inner_kernel from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( @@ -92,38 +94,31 @@ def smc_inner_kernel_tuning_test_case( proposal_factory.return_value = 100 def mcmc_parameter_update_fn(state, info): - return {"mean":100} + return extend_params_inner_kernel(1000,{"mean":100}) - mcmc_factory = MagicMock() - sampling_algorithm = MagicMock() - mcmc_factory.return_value = sampling_algorithm prior = lambda x: stats.norm.logpdf(x) - def kernel_factory(mean): - kernel = blackjax.irmh.build_kernel() + def wrapped_kernel(rng_key, state, logdensity, mean): + return build_irmh()(rng_key, state, logdensity, functools.partial(irmh_proposal_distribution, mean=mean)) - def wrapped_kernel(rng_key, state, logdensity): - return kernel(rng_key, state, logdensity, functools.partial(irmh_proposal_distribution, mean=mean)) - - return wrapped_kernel kernel = inner_kernel_tuning( logprior_fn=prior, loglikelihood_fn=specialized_log_weights_fn, - mcmc_factory=kernel_factory, + mcmc_step_fn=wrapped_kernel, mcmc_init_fn=blackjax.irmh.init, resampling_fn=resampling.systematic, smc_algorithm=smc_algorithm, - mcmc_parameters={}, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value={"mean":1.0}, + initial_parameter_value=extend_params_inner_kernel(1000,{"mean":1.0}), **smc_parameters, ) new_state, new_info = kernel.step( self.key, state=kernel.init(init_particles), **step_parameters ) - assert new_state.parameter_override == {"mean":100} + assert set(new_state.parameter_override.keys()) == {"mean",} + np.testing.assert_allclose(new_state.parameter_override["mean"], 100) class MeanAndStdFromParticlesTest(chex.TestCase): @@ -270,14 +265,6 @@ def setUp(self): super().setUp() self.key = jax.random.key(42) - def mcmc_factory(self, mass_matrix): - return functools.partial( - blackjax.hmc.build_kernel(), - inverse_mass_matrix=mass_matrix, - step_size=10e-2, - num_integration_steps=50, - ) - @chex.all_variants(with_pmap=False) def test_with_adaptive_tempered(self): ( @@ -286,17 +273,23 @@ def test_with_adaptive_tempered(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() + def parameter_update(state, info): + return extend_params_inner_kernel(100, + {"inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps":50}) init, step = blackjax.inner_kernel_tuning( adaptive_tempered_smc, logprior_fn, loglikelihood_fn, - self.mcmc_factory, + blackjax.hmc.build_kernel(), blackjax.hmc.init, - {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: {"mass_matrix": - mass_matrix_from_particles(state.particles)}, - initial_parameter_value={"mass_matrix":jnp.eye(2)}, + mcmc_parameter_update_fn=parameter_update, + initial_parameter_value=extend_params_inner_kernel(100, dict( + inverse_mass_matrix = jnp.eye(2), + step_size= 10e-2, + num_integration_steps = 50)), num_mcmc_steps=10, target_ess=0.5, ) @@ -318,7 +311,7 @@ def body(carry): state, _ = inference_loop(smc_kernel, self.key, init_state) - assert state.parameter_override["mass_matrix"].shape == (2, 2) + assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @chex.all_variants(with_pmap=False) @@ -330,18 +323,24 @@ def test_with_tempered_smc(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() + def parameter_update(state, info): + return extend_params_inner_kernel(100, + {"inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps": 50}) + init, step = blackjax.inner_kernel_tuning( tempered_smc, logprior_fn, loglikelihood_fn, - self.mcmc_factory, + blackjax.hmc.build_kernel(), blackjax.hmc.init, - {}, resampling.systematic, - mcmc_parameter_update_fn=lambda state, info: {"mass_matrix":mass_matrix_from_particles( - state.particles - )}, - initial_parameter_value={"mass_matrix":jnp.eye(2)}, + mcmc_parameter_update_fn=parameter_update, + initial_parameter_value=extend_params_inner_kernel(100, dict( + inverse_mass_matrix = jnp.eye(2), + step_size= 10e-2, + num_integration_steps = 50)), num_mcmc_steps=10, ) diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index e79a02a1e..3afa56425 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -7,7 +7,7 @@ import blackjax from blackjax import adaptive_tempered_smc from blackjax.mcmc.random_walk import normal -from blackjax.smc.tuning import extend_params_inner_kernel +from blackjax.smc import extend_params_inner_kernel class SMCAndMCMCIntegrationTest(unittest.TestCase): diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index a0ee4c269..3ada15292 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -10,8 +10,7 @@ import blackjax import blackjax.smc.resampling as resampling -from blackjax.smc.base import init, step -from blackjax.smc.tuning import extend_params_inner_kernel +from blackjax.smc.base import init, step, extend_params_inner_kernel def logdensity_fn(position): diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index 0729cce88..89257e638 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -12,7 +12,7 @@ import blackjax.smc.resampling as resampling import blackjax.smc.solver as solver from blackjax import adaptive_tempered_smc, tempered_smc -from blackjax.smc.tuning import extend_params_inner_kernel +from blackjax.smc import extend_params_inner_kernel from tests.smc import SMCLinearRegressionTestCase From 372a5841c9594bde6ca9e2e52c8a6a5e320ad028 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Thu, 21 Mar 2024 16:09:26 -0300 Subject: [PATCH 3/8] pre-commit and typing --- blackjax/smc/__init__.py | 10 ++-- blackjax/smc/base.py | 8 +++- blackjax/smc/inner_kernel_tuning.py | 3 +- blackjax/smc/tuning/from_particles.py | 1 - tests/mcmc/test_sampling.py | 4 +- tests/smc/test_inner_kernel_tuning.py | 65 +++++++++++++++++--------- tests/smc/test_kernel_compatibility.py | 59 +++++++++++++---------- tests/smc/test_smc.py | 65 +++++++++++++++----------- tests/smc/test_tempered_smc.py | 39 ++++++++++------ 9 files changed, 158 insertions(+), 96 deletions(-) diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index 492ed67f0..868cdb42f 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -1,7 +1,9 @@ from . import adaptive_tempered, inner_kernel_tuning, tempered from .base import extend_params_inner_kernel -__all__ = ["adaptive_tempered", "tempered", "inner_kernel_tuning", "extend_params_inner_kernel"] - - - +__all__ = [ + "adaptive_tempered", + "tempered", + "inner_kernel_tuning", + "extend_params_inner_kernel", +] diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 4e8f761a4..77cab2f5a 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -149,15 +149,19 @@ def step( resampling_idx, normalizing_constant, update_info ) + def extend_params_inner_kernel(n_particles, params): """ Given a dictionary of params, repeats them for every single particle. The expected usage is in cases where the aim is to repeat the same parameters for all chains within SMC. """ + def extend(param): if jnp.isscalar(param): return jnp.repeat(param, n_particles, axis=0) else: - return jnp.repeat(jnp.atleast_1d(param)[jnp.newaxis, :], n_particles, axis=0) + return jnp.repeat( + jnp.atleast_1d(param)[jnp.newaxis, :], n_particles, axis=0 + ) - return jax.tree_map(extend, params) \ No newline at end of file + return jax.tree_map(extend, params) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index b309b9369..9d3866145 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -14,6 +14,7 @@ class StateWithParameterOverride(NamedTuple): and (n_particles, *) arrays as meanings. The latter represent a parameter per chain for the next mutation step. """ + sampler_state: ArrayTree parameter_override: Dict[str, ArrayTree] @@ -29,7 +30,7 @@ def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], ArrayTree], + mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], num_mcmc_steps: int = 10, **extra_parameters, ) -> Callable: diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index ae5f3c88a..4c8ca98da 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -48,4 +48,3 @@ def particles_as_rows(particles): as a matrix where each column is a variable, each row a particle. """ return jax.vmap(lambda x: ravel_pytree(x)[0])(particles) - diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py index 18b241cf1..51831b587 100644 --- a/tests/mcmc/test_sampling.py +++ b/tests/mcmc/test_sampling.py @@ -657,7 +657,9 @@ def test_univariate_normal( self, algorithm, initial_position, parameters, num_sampling_steps, burnin ): if algorithm == blackjax.irmh: - parameters["proposal_distribution"] = functools.partial(irmh_proposal_distribution, mean=1.0) + parameters["proposal_distribution"] = functools.partial( + irmh_proposal_distribution, mean=1.0 + ) if algorithm == blackjax.rmh: parameters["proposal_generator"] = rmh_proposal_distribution diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index af78ec7ce..c61dbc9b5 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -94,13 +94,17 @@ def smc_inner_kernel_tuning_test_case( proposal_factory.return_value = 100 def mcmc_parameter_update_fn(state, info): - return extend_params_inner_kernel(1000,{"mean":100}) + return extend_params_inner_kernel(1000, {"mean": 100}) prior = lambda x: stats.norm.logpdf(x) def wrapped_kernel(rng_key, state, logdensity, mean): - return build_irmh()(rng_key, state, logdensity, functools.partial(irmh_proposal_distribution, mean=mean)) - + return build_irmh()( + rng_key, + state, + logdensity, + functools.partial(irmh_proposal_distribution, mean=mean), + ) kernel = inner_kernel_tuning( logprior_fn=prior, @@ -110,14 +114,16 @@ def wrapped_kernel(rng_key, state, logdensity, mean): resampling_fn=resampling.systematic, smc_algorithm=smc_algorithm, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value=extend_params_inner_kernel(1000,{"mean":1.0}), + initial_parameter_value=extend_params_inner_kernel(1000, {"mean": 1.0}), **smc_parameters, ) new_state, new_info = kernel.step( self.key, state=kernel.init(init_particles), **step_parameters ) - assert set(new_state.parameter_override.keys()) == {"mean",} + assert set(new_state.parameter_override.keys()) == { + "mean", + } np.testing.assert_allclose(new_state.parameter_override["mean"], 100) @@ -274,10 +280,15 @@ def test_with_adaptive_tempered(self): ) = self.particles_prior_loglikelihood() def parameter_update(state, info): - return extend_params_inner_kernel(100, - {"inverse_mass_matrix": mass_matrix_from_particles(state.particles), - "step_size": 10e-2, - "num_integration_steps":50}) + return extend_params_inner_kernel( + 100, + { + "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps": 50, + }, + ) + init, step = blackjax.inner_kernel_tuning( adaptive_tempered_smc, logprior_fn, @@ -286,10 +297,14 @@ def parameter_update(state, info): blackjax.hmc.init, resampling.systematic, mcmc_parameter_update_fn=parameter_update, - initial_parameter_value=extend_params_inner_kernel(100, dict( - inverse_mass_matrix = jnp.eye(2), - step_size= 10e-2, - num_integration_steps = 50)), + initial_parameter_value=extend_params_inner_kernel( + 100, + dict( + inverse_mass_matrix=jnp.eye(2), + step_size=10e-2, + num_integration_steps=50, + ), + ), num_mcmc_steps=10, target_ess=0.5, ) @@ -324,10 +339,14 @@ def test_with_tempered_smc(self): ) = self.particles_prior_loglikelihood() def parameter_update(state, info): - return extend_params_inner_kernel(100, - {"inverse_mass_matrix": mass_matrix_from_particles(state.particles), - "step_size": 10e-2, - "num_integration_steps": 50}) + return extend_params_inner_kernel( + 100, + { + "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "step_size": 10e-2, + "num_integration_steps": 50, + }, + ) init, step = blackjax.inner_kernel_tuning( tempered_smc, @@ -337,10 +356,14 @@ def parameter_update(state, info): blackjax.hmc.init, resampling.systematic, mcmc_parameter_update_fn=parameter_update, - initial_parameter_value=extend_params_inner_kernel(100, dict( - inverse_mass_matrix = jnp.eye(2), - step_size= 10e-2, - num_integration_steps = 50)), + initial_parameter_value=extend_params_inner_kernel( + 100, + dict( + inverse_mass_matrix=jnp.eye(2), + step_size=10e-2, + num_integration_steps=50, + ), + ), num_mcmc_steps=10, ) diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index 3afa56425..8d1feee6d 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -24,7 +24,6 @@ def setUp(self): self.key, jnp.zeros(2), jnp.eye(2), (self.n_particles,) ) - def check_compatible(self, mcmc_step_fn, mcmc_init_fn, mcmc_parameters): """ Runs one SMC step @@ -51,66 +50,76 @@ def kernel(rng_key, state, logdensity_fn, proposal_mean): self.check_compatible( kernel, blackjax.additive_step_random_walk.init, - extend_params_inner_kernel(self.n_particles, {"proposal_mean": 1.0}) + extend_params_inner_kernel(self.n_particles, {"proposal_mean": 1.0}), ) def test_compatible_with_rmh(self): rmh = blackjax.rmh.build_kernel() def kernel( + rng_key, state, logdensity_fn, proposal_mean, proposal_logdensity_fn=None + ): + return rmh( rng_key, state, logdensity_fn, - proposal_mean, - proposal_logdensity_fn=None - ): - return rmh(rng_key, - state, - logdensity_fn, - lambda a, b: blackjax.mcmc.random_walk.normal(proposal_mean)(a, b), - proposal_logdensity_fn) + lambda a, b: blackjax.mcmc.random_walk.normal(proposal_mean)(a, b), + proposal_logdensity_fn, + ) + self.check_compatible( kernel, blackjax.rmh.init, - extend_params_inner_kernel(self.n_particles, {"proposal_mean":1.0}) + extend_params_inner_kernel(self.n_particles, {"proposal_mean": 1.0}), ) def test_compatible_with_hmc(self): self.check_compatible( blackjax.hmc.build_kernel(), blackjax.hmc.init, - extend_params_inner_kernel(self.n_particles,{ - "step_size": 0.3, - "inverse_mass_matrix": jnp.array([1.]), - "num_integration_steps": 1, - }), + extend_params_inner_kernel( + self.n_particles, + { + "step_size": 0.3, + "inverse_mass_matrix": jnp.array([1.0]), + "num_integration_steps": 1, + }, + ), ) def test_compatible_with_irmh(self): def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None): - return blackjax.irmh.build_kernel()(rng_key, - state, - logdensity_fn, - lambda key: mean + jax.random.normal(key), proposal_logdensity_fn) + return blackjax.irmh.build_kernel()( + rng_key, + state, + logdensity_fn, + lambda key: mean + jax.random.normal(key), + proposal_logdensity_fn, + ) self.check_compatible( kernel, blackjax.irmh.init, - extend_params_inner_kernel(self.n_particles,{ - "mean": jnp.array([1.0, 1.0]) - }) + extend_params_inner_kernel( + self.n_particles, {"mean": jnp.array([1.0, 1.0])} + ), ) def test_compatible_with_nuts(self): self.check_compatible( blackjax.nuts.build_kernel(), blackjax.nuts.init, - extend_params_inner_kernel(self.n_particles, {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}), + extend_params_inner_kernel( + self.n_particles, + {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, + ), ) def test_compatible_with_mala(self): self.check_compatible( - blackjax.mala.build_kernel(), blackjax.mala.init, extend_params_inner_kernel(self.n_particles,{"step_size": 1e-10}) + blackjax.mala.build_kernel(), + blackjax.mala.init, + extend_params_inner_kernel(self.n_particles, {"step_size": 1e-10}), ) @staticmethod diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 3ada15292..caf7d73c1 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -10,7 +10,7 @@ import blackjax import blackjax.smc.resampling as resampling -from blackjax.smc.base import init, step, extend_params_inner_kernel +from blackjax.smc.base import extend_params_inner_kernel, init, step def logdensity_fn(position): @@ -34,7 +34,7 @@ def test_smc(self): num_particles = 1000 def update_fn(rng_key, position, update_params): - hmc = blackjax.hmc(logdensity_fn,**update_params) + hmc = blackjax.hmc(logdensity_fn, **update_params) state = hmc.init(position) def body_fn(state, rng_key): @@ -49,10 +49,13 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - same_for_all_params = dict(step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=50) - state = init(init_particles, extend_params_inner_kernel(num_particles, same_for_all_params)) + same_for_all_params = dict( + step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50 + ) + state = init( + init_particles, + extend_params_inner_kernel(num_particles, same_for_all_params), + ) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4))( @@ -75,10 +78,7 @@ def test_smc_waste_free(self): def waste_free_update_fn(keys, particles, update_params): def one_particle_fn(rng_key, position, particle_update_params): - hmc = blackjax.hmc( - logdensity_fn, - **particle_update_params - ) + hmc = blackjax.hmc(logdensity_fn, **particle_update_params) state = hmc.init(position) def body_fn(state, rng_key): @@ -97,9 +97,17 @@ def body_fn(state, rng_key): # Initialize the state of the SMC sampler init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) - state = init(init_particles, extend_params_inner_kernel(num_resampled, dict(step_size=1e-2, - inverse_mass_matrix=jnp.eye(1), - num_integration_steps=100))) + state = init( + init_particles, + extend_params_inner_kernel( + num_resampled, + dict( + step_size=1e-2, + inverse_mass_matrix=jnp.eye(1), + num_integration_steps=100, + ), + ), + ) # Run the SMC sampler once new_state, info = self.variant(step, static_argnums=(2, 3, 4, 5))( @@ -118,19 +126,24 @@ def body_fn(state, rng_key): class ExtendToParticlesTest(unittest.TestCase): def test_extend_params_inner_kernel(self): - extended = extend_params_inner_kernel(3, {"a":50, - "b": np.array([50]), - "c": np.array([50, 60]), - "d": np.array([[1,2],[3,4]]) - }) - np.testing.assert_allclose(extended["a"], np.ones((3,))*50) - np.testing.assert_allclose(extended["b"], np.array([[50],[50],[50]])) - np.testing.assert_allclose(extended["c"], np.array([[50, 60], [50, 60], [50, 60]])) - np.testing.assert_allclose(extended["d"], np.array([[[1,2],[3,4]], - [[1, 2], [3, 4]], - [[1, 2], [3, 4]] - ])) - + extended = extend_params_inner_kernel( + 3, + { + "a": 50, + "b": np.array([50]), + "c": np.array([50, 60]), + "d": np.array([[1, 2], [3, 4]]), + }, + ) + np.testing.assert_allclose(extended["a"], np.ones((3,)) * 50) + np.testing.assert_allclose(extended["b"], np.array([[50], [50], [50]])) + np.testing.assert_allclose( + extended["c"], np.array([[50, 60], [50, 60], [50, 60]]) + ) + np.testing.assert_allclose( + extended["d"], + np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]), + ) if __name__ == "__main__": diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index 89257e638..ca2fae94b 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -65,11 +65,14 @@ def logprior_fn(x): hmc_kernel = blackjax.hmc.build_kernel() hmc_init = blackjax.hmc.init - hmc_parameters = extend_params_inner_kernel(num_particles, { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(2), - "num_integration_steps": 50, - }) + hmc_parameters = extend_params_inner_kernel( + num_particles, + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) for target_ess in [0.5, 0.75]: tempering = adaptive_tempered_smc( @@ -111,11 +114,14 @@ def test_fixed_schedule_tempered_smc(self): lambda_schedule = np.logspace(-5, 0, num_tempering_steps) hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = extend_params_inner_kernel(100, { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(2), - "num_integration_steps": 50, - }) + hmc_parameters = extend_params_inner_kernel( + 100, + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(2), + "num_integration_steps": 50, + }, + ) tempering = tempered_smc( logprior_fn, @@ -175,11 +181,14 @@ def test_normalizing_constant(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = extend_params_inner_kernel(num_particles, { - "step_size": 10e-2, - "inverse_mass_matrix": jnp.eye(num_dim), - "num_integration_steps": 50, - }) + hmc_parameters = extend_params_inner_kernel( + num_particles, + { + "step_size": 10e-2, + "inverse_mass_matrix": jnp.eye(num_dim), + "num_integration_steps": 50, + }, + ) tempering = adaptive_tempered_smc( logprior_fn, From ea8e59096bb4423fc6bfe1446b898577cd94643c Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 22 Mar 2024 13:29:52 -0300 Subject: [PATCH 4/8] CRU and docs improvement --- blackjax/smc/base.py | 11 ++--------- blackjax/smc/inner_kernel_tuning.py | 1 + 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 77cab2f5a..30730115f 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -151,17 +151,10 @@ def step( def extend_params_inner_kernel(n_particles, params): - """ - Given a dictionary of params, repeats them for every single particle. The expected + """Given a dictionary of params, repeats them for every single particle. The expected usage is in cases where the aim is to repeat the same parameters for all chains within SMC. """ - def extend(param): - if jnp.isscalar(param): - return jnp.repeat(param, n_particles, axis=0) - else: - return jnp.repeat( - jnp.atleast_1d(param)[jnp.newaxis, :], n_particles, axis=0 - ) + return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0) return jax.tree_map(extend, params) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 9d3866145..705a60c35 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -49,6 +49,7 @@ def build_kernel( A function that returns the probability at a given position. mcmc_step_fn: The transition kernel, should take as parameters the dictionary output of mcmc_parameter_update_fn. + mcmc_step_fn(rng_key, state, tempered_logposterior_fn, **mcmc_parameter_update_fn()) mcmc_init_fn A callable that initializes the inner kernel mcmc_parameter_update_fn From 663c4a0dc9f7ab0f8ffa50ef7d7fa71a83ecfcad Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Fri, 22 Mar 2024 13:31:42 -0300 Subject: [PATCH 5/8] pre-commit --- blackjax/smc/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index 30730115f..ba18eba33 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -154,6 +154,7 @@ def extend_params_inner_kernel(n_particles, params): """Given a dictionary of params, repeats them for every single particle. The expected usage is in cases where the aim is to repeat the same parameters for all chains within SMC. """ + def extend(param): return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0) From 5208fc2f132fd46801b53f021ec792a67ec477aa Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 25 Mar 2024 11:08:04 -0300 Subject: [PATCH 6/8] code review updates --- blackjax/smc/__init__.py | 4 ++-- blackjax/smc/base.py | 2 +- tests/smc/test_inner_kernel_tuning.py | 14 +++++++------- tests/smc/test_kernel_compatibility.py | 14 +++++++------- tests/smc/test_smc.py | 12 +++++------- tests/smc/test_tempered_smc.py | 8 ++++---- 6 files changed, 26 insertions(+), 28 deletions(-) diff --git a/blackjax/smc/__init__.py b/blackjax/smc/__init__.py index 868cdb42f..ef10b10e6 100644 --- a/blackjax/smc/__init__.py +++ b/blackjax/smc/__init__.py @@ -1,9 +1,9 @@ from . import adaptive_tempered, inner_kernel_tuning, tempered -from .base import extend_params_inner_kernel +from .base import extend_params __all__ = [ "adaptive_tempered", "tempered", "inner_kernel_tuning", - "extend_params_inner_kernel", + "extend_params", ] diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py index ba18eba33..4a9ff17c3 100644 --- a/blackjax/smc/base.py +++ b/blackjax/smc/base.py @@ -150,7 +150,7 @@ def step( ) -def extend_params_inner_kernel(n_particles, params): +def extend_params(n_particles, params): """Given a dictionary of params, repeats them for every single particle. The expected usage is in cases where the aim is to repeat the same parameters for all chains within SMC. """ diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index c61dbc9b5..cf1db09dd 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -13,7 +13,7 @@ import blackjax.smc.resampling as resampling from blackjax import adaptive_tempered_smc, tempered_smc from blackjax.mcmc.random_walk import build_irmh -from blackjax.smc import extend_params_inner_kernel +from blackjax.smc import extend_params from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( @@ -94,7 +94,7 @@ def smc_inner_kernel_tuning_test_case( proposal_factory.return_value = 100 def mcmc_parameter_update_fn(state, info): - return extend_params_inner_kernel(1000, {"mean": 100}) + return extend_params(1000, {"mean": 100}) prior = lambda x: stats.norm.logpdf(x) @@ -114,7 +114,7 @@ def wrapped_kernel(rng_key, state, logdensity, mean): resampling_fn=resampling.systematic, smc_algorithm=smc_algorithm, mcmc_parameter_update_fn=mcmc_parameter_update_fn, - initial_parameter_value=extend_params_inner_kernel(1000, {"mean": 1.0}), + initial_parameter_value=extend_params(1000, {"mean": 1.0}), **smc_parameters, ) @@ -280,7 +280,7 @@ def test_with_adaptive_tempered(self): ) = self.particles_prior_loglikelihood() def parameter_update(state, info): - return extend_params_inner_kernel( + return extend_params( 100, { "inverse_mass_matrix": mass_matrix_from_particles(state.particles), @@ -297,7 +297,7 @@ def parameter_update(state, info): blackjax.hmc.init, resampling.systematic, mcmc_parameter_update_fn=parameter_update, - initial_parameter_value=extend_params_inner_kernel( + initial_parameter_value=extend_params( 100, dict( inverse_mass_matrix=jnp.eye(2), @@ -339,7 +339,7 @@ def test_with_tempered_smc(self): ) = self.particles_prior_loglikelihood() def parameter_update(state, info): - return extend_params_inner_kernel( + return extend_params( 100, { "inverse_mass_matrix": mass_matrix_from_particles(state.particles), @@ -356,7 +356,7 @@ def parameter_update(state, info): blackjax.hmc.init, resampling.systematic, mcmc_parameter_update_fn=parameter_update, - initial_parameter_value=extend_params_inner_kernel( + initial_parameter_value=extend_params( 100, dict( inverse_mass_matrix=jnp.eye(2), diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index 8d1feee6d..091d8ad88 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -7,7 +7,7 @@ import blackjax from blackjax import adaptive_tempered_smc from blackjax.mcmc.random_walk import normal -from blackjax.smc import extend_params_inner_kernel +from blackjax.smc import extend_params class SMCAndMCMCIntegrationTest(unittest.TestCase): @@ -50,7 +50,7 @@ def kernel(rng_key, state, logdensity_fn, proposal_mean): self.check_compatible( kernel, blackjax.additive_step_random_walk.init, - extend_params_inner_kernel(self.n_particles, {"proposal_mean": 1.0}), + extend_params(self.n_particles, {"proposal_mean": 1.0}), ) def test_compatible_with_rmh(self): @@ -70,14 +70,14 @@ def kernel( self.check_compatible( kernel, blackjax.rmh.init, - extend_params_inner_kernel(self.n_particles, {"proposal_mean": 1.0}), + extend_params(self.n_particles, {"proposal_mean": 1.0}), ) def test_compatible_with_hmc(self): self.check_compatible( blackjax.hmc.build_kernel(), blackjax.hmc.init, - extend_params_inner_kernel( + extend_params( self.n_particles, { "step_size": 0.3, @@ -100,7 +100,7 @@ def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None): self.check_compatible( kernel, blackjax.irmh.init, - extend_params_inner_kernel( + extend_params( self.n_particles, {"mean": jnp.array([1.0, 1.0])} ), ) @@ -109,7 +109,7 @@ def test_compatible_with_nuts(self): self.check_compatible( blackjax.nuts.build_kernel(), blackjax.nuts.init, - extend_params_inner_kernel( + extend_params( self.n_particles, {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)}, ), @@ -119,7 +119,7 @@ def test_compatible_with_mala(self): self.check_compatible( blackjax.mala.build_kernel(), blackjax.mala.init, - extend_params_inner_kernel(self.n_particles, {"step_size": 1e-10}), + extend_params(self.n_particles, {"step_size": 1e-10}), ) @staticmethod diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index caf7d73c1..9cbd45a0c 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -1,6 +1,4 @@ """Test the generic SMC sampler""" -import unittest - import chex import jax import jax.numpy as jnp @@ -10,7 +8,7 @@ import blackjax import blackjax.smc.resampling as resampling -from blackjax.smc.base import extend_params_inner_kernel, init, step +from blackjax.smc.base import extend_params, init, step def logdensity_fn(position): @@ -54,7 +52,7 @@ def body_fn(state, rng_key): ) state = init( init_particles, - extend_params_inner_kernel(num_particles, same_for_all_params), + extend_params(num_particles, same_for_all_params), ) # Run the SMC sampler once @@ -99,7 +97,7 @@ def body_fn(state, rng_key): init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,)) state = init( init_particles, - extend_params_inner_kernel( + extend_params( num_resampled, dict( step_size=1e-2, @@ -124,9 +122,9 @@ def body_fn(state, rng_key): np.testing.assert_allclose(1.0, std, atol=1e-1) -class ExtendToParticlesTest(unittest.TestCase): +class ExtendToParticlesTest(chex.TestCase): def test_extend_params_inner_kernel(self): - extended = extend_params_inner_kernel( + extended = extend_params( 3, { "a": 50, diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py index ca2fae94b..3ab387e14 100644 --- a/tests/smc/test_tempered_smc.py +++ b/tests/smc/test_tempered_smc.py @@ -12,7 +12,7 @@ import blackjax.smc.resampling as resampling import blackjax.smc.solver as solver from blackjax import adaptive_tempered_smc, tempered_smc -from blackjax.smc import extend_params_inner_kernel +from blackjax.smc import extend_params from tests.smc import SMCLinearRegressionTestCase @@ -65,7 +65,7 @@ def logprior_fn(x): hmc_kernel = blackjax.hmc.build_kernel() hmc_init = blackjax.hmc.init - hmc_parameters = extend_params_inner_kernel( + hmc_parameters = extend_params( num_particles, { "step_size": 10e-2, @@ -114,7 +114,7 @@ def test_fixed_schedule_tempered_smc(self): lambda_schedule = np.logspace(-5, 0, num_tempering_steps) hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = extend_params_inner_kernel( + hmc_parameters = extend_params( 100, { "step_size": 10e-2, @@ -181,7 +181,7 @@ def test_normalizing_constant(self): hmc_init = blackjax.hmc.init hmc_kernel = blackjax.hmc.build_kernel() - hmc_parameters = extend_params_inner_kernel( + hmc_parameters = extend_params( num_particles, { "step_size": 10e-2, From 2ee74e145aab144d9e555d4c2b111a820bd133e9 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 25 Mar 2024 11:08:26 -0300 Subject: [PATCH 7/8] pre-commit --- tests/smc/test_kernel_compatibility.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py index 091d8ad88..3e675c2cc 100644 --- a/tests/smc/test_kernel_compatibility.py +++ b/tests/smc/test_kernel_compatibility.py @@ -100,9 +100,7 @@ def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None): self.check_compatible( kernel, blackjax.irmh.init, - extend_params( - self.n_particles, {"mean": jnp.array([1.0, 1.0])} - ), + extend_params(self.n_particles, {"mean": jnp.array([1.0, 1.0])}), ) def test_compatible_with_nuts(self): From f3d9170267a3928f079bcd2c3f08d077570fda36 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Mon, 25 Mar 2024 11:12:15 -0300 Subject: [PATCH 8/8] rename test --- tests/smc/test_smc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 9cbd45a0c..2838e984f 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -122,8 +122,8 @@ def body_fn(state, rng_key): np.testing.assert_allclose(1.0, std, atol=1e-1) -class ExtendToParticlesTest(chex.TestCase): - def test_extend_params_inner_kernel(self): +class ExtendParamsTest(chex.TestCase): + def test_extend_params(self): extended = extend_params( 3, {