From 0b12b59e37e4906b5d904a02661a17dd6c2287f3 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Wed, 12 Feb 2025 15:12:38 -0300 Subject: [PATCH 1/3] impl --- blackjax/smc/inner_kernel_tuning.py | 50 +++++++- blackjax/smc/pretuning.py | 8 ++ blackjax/smc/tuning/from_particles.py | 6 +- tests/smc/test_inner_kernel_tuning.py | 172 +++++++++++++++++++++----- 4 files changed, 198 insertions(+), 38 deletions(-) diff --git a/blackjax/smc/inner_kernel_tuning.py b/blackjax/smc/inner_kernel_tuning.py index 2a63fd1ce..334a1488c 100644 --- a/blackjax/smc/inner_kernel_tuning.py +++ b/blackjax/smc/inner_kernel_tuning.py @@ -1,5 +1,7 @@ from typing import Callable, Dict, NamedTuple, Tuple +import jax + from blackjax.base import SamplingAlgorithm from blackjax.smc.base import SMCInfo, SMCState from blackjax.types import ArrayTree, PRNGKey @@ -28,8 +30,11 @@ def build_kernel( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + mcmc_parameter_update_fn: Callable[ + [PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree] + ], num_mcmc_steps: int = 10, + smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> Callable: """In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner @@ -40,7 +45,8 @@ def build_kernel( ---------- smc_algorithm Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of - a sampling algorithm that returns an SMCState and SMCInfo pair). + a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this + to return an StateWithParameterOverride, in such case smc_returns_state_with_parameter_override needs to be True logprior_fn A function that computes the log density of the prior distribution loglikelihood_fn @@ -54,7 +60,30 @@ def build_kernel( A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration. extra_parameters: parameters to be used for the creation of the smc_algorithm. + smc_returns_state_with_parameter_override: + a boolean indicating that the underlying smc_algorithm returns a smc_returns_state_with_parameter_override. + this is used in order to compose different adaptation mechanisms, such as pretuning with tuning. """ + if smc_returns_state_with_parameter_override: + + def extract_state_for_delegate(state): + return state + + def compose_new_state(new_state, new_parameter_override): + composed_parameter_override = ( + new_state.parameter_override | new_parameter_override + ) + return StateWithParameterOverride( + new_state.sampler_state, composed_parameter_override + ) + + else: + + def extract_state_for_delegate(state): + return state.sampler_state + + def compose_new_state(new_state, new_parameter_override): + return StateWithParameterOverride(new_state, new_parameter_override) def kernel( rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters @@ -69,9 +98,14 @@ def kernel( num_mcmc_steps=num_mcmc_steps, **extra_parameters, ).step - new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters) - new_parameter_override = mcmc_parameter_update_fn(new_state, info) - return StateWithParameterOverride(new_state, new_parameter_override), info + parameter_update_key, step_key = jax.random.split(rng_key, 2) + new_state, info = step_fn( + step_key, extract_state_for_delegate(state), **extra_step_parameters + ) + new_parameter_override = mcmc_parameter_update_fn( + parameter_update_key, new_state, info + ) + return compose_new_state(new_state, new_parameter_override), info return kernel @@ -83,9 +117,12 @@ def as_top_level_api( mcmc_step_fn: Callable, mcmc_init_fn: Callable, resampling_fn: Callable, - mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]], + mcmc_parameter_update_fn: Callable[ + [PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree] + ], initial_parameter_value, num_mcmc_steps: int = 10, + smc_returns_state_with_parameter_override=False, **extra_parameters, ) -> SamplingAlgorithm: """In the context of an SMC sampler (whose step_fn returning state @@ -130,6 +167,7 @@ def as_top_level_api( resampling_fn, mcmc_parameter_update_fn, num_mcmc_steps, + smc_returns_state_with_parameter_override, **extra_parameters, ) diff --git a/blackjax/smc/pretuning.py b/blackjax/smc/pretuning.py index f489a0dc2..d4aa56969 100644 --- a/blackjax/smc/pretuning.py +++ b/blackjax/smc/pretuning.py @@ -99,6 +99,14 @@ def update_parameter_distribution( ) +def default_measure_factory(state): + inverse_mass_matrix = state.parameter_override["inverse_mass_matrix"] + if not (len(inverse_mass_matrix.shape) == 3 and inverse_mass_matrix.shape[0] == 1): + raise ValueError("ESJD only works if chains share the inverse_mass_matrix.") + + return esjd(inverse_mass_matrix[0]) + + def build_pretune( mcmc_init_fn: Callable, mcmc_step_fn: Callable, diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index 279a718cb..c027fdf87 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -12,7 +12,7 @@ "particles_means", "particles_stds", "particles_covariance_matrix", - "mass_matrix_from_particles", + "inverse_mass_matrix_from_particles", ] @@ -28,7 +28,7 @@ def particles_covariance_matrix(particles): return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False) -def mass_matrix_from_particles(particles) -> Array: +def inverse_mass_matrix_from_particles(particles) -> Array: """ Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf Computing a mass matrix to be used in HMC from particles. @@ -39,7 +39,7 @@ def mass_matrix_from_particles(particles) -> Array: ------- A mass Matrix """ - return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0)) + return jnp.diag(jnp.var(particles_as_rows(particles), axis=0)) def particles_as_rows(particles): diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py index 7d6190af5..d7daaf839 100644 --- a/tests/smc/test_inner_kernel_tuning.py +++ b/tests/smc/test_inner_kernel_tuning.py @@ -15,9 +15,10 @@ from blackjax.mcmc.random_walk import build_irmh from blackjax.smc import extend_params from blackjax.smc.inner_kernel_tuning import as_top_level_api as inner_kernel_tuning +from blackjax.smc.pretuning import build_pretune from blackjax.smc.tuning.from_kernel_info import update_scale_from_acceptance_rate from blackjax.smc.tuning.from_particles import ( - mass_matrix_from_particles, + inverse_mass_matrix_from_particles, particles_as_rows, particles_covariance_matrix, particles_means, @@ -93,7 +94,7 @@ def smc_inner_kernel_tuning_test_case( proposal_factory = MagicMock() proposal_factory.return_value = 100 - def mcmc_parameter_update_fn(state, info): + def mcmc_parameter_update_fn(key, state, info): return extend_params({"mean": 100}) prior = lambda x: stats.norm.logpdf(x) @@ -186,30 +187,30 @@ def setUp(self): self.key = jax.random.key(42) def test_inverse_mass_matrix_from_particles(self): - inverse_mass_matrix = mass_matrix_from_particles( + inverse_mass_matrix = inverse_mass_matrix_from_particles( np.array([np.array(10.0), np.array(3.0)]) ) np.testing.assert_allclose( - inverse_mass_matrix, np.diag(np.array([0.08163])), rtol=1e-4 + inverse_mass_matrix, np.diag(np.array([12.25])), rtol=1e-4 ) def test_inverse_mass_matrix_from_multivariate_particles(self): - inverse_mass_matrix = mass_matrix_from_particles( + inverse_mass_matrix = inverse_mass_matrix_from_particles( np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) ) np.testing.assert_allclose( - inverse_mass_matrix, np.diag(np.array([0.081633, 0.033058])), rtol=1e-4 + inverse_mass_matrix, np.diag(np.array([12.25, 30.25])), rtol=1e-4 ) def test_inverse_mass_matrix_from_multivariable_particles(self): var1 = np.array([jnp.array([10.0, 15.0]), jnp.array([3.0, 4.0])]) var2 = np.array([jnp.array([10.0]), jnp.array([3.0])]) init_particles = {"var1": var1, "var2": var2} - mass_matrix = mass_matrix_from_particles(init_particles) + mass_matrix = inverse_mass_matrix_from_particles(init_particles) assert mass_matrix.shape == (3, 3) np.testing.assert_allclose( np.diag(mass_matrix), - np.array([0.081633, 0.033058, 0.081633], dtype="float32"), + np.array([12.25, 30.25, 12.25], dtype="float32"), rtol=1e-4, ) @@ -217,10 +218,10 @@ def test_inverse_mass_matrix_from_multivariable_univariate_particles(self): var1 = np.array([3.0, 2.0]) var2 = np.array([10.0, 3.0]) init_particles = {"var1": var1, "var2": var2} - mass_matrix = mass_matrix_from_particles(init_particles) + mass_matrix = inverse_mass_matrix_from_particles(init_particles) assert mass_matrix.shape == (2, 2) np.testing.assert_allclose( - np.diag(mass_matrix), np.array([4, 0.081633], dtype="float32"), rtol=1e-4 + np.diag(mass_matrix), np.array([0.25, 12.25], dtype="float32"), rtol=1e-4 ) @@ -279,10 +280,12 @@ def test_with_adaptive_tempered(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() - def parameter_update(state, info): + def parameter_update(key, state, info): return extend_params( { - "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "inverse_mass_matrix": inverse_mass_matrix_from_particles( + state.particles + ), "step_size": 10e-2, "num_integration_steps": 50, }, @@ -308,21 +311,7 @@ def parameter_update(state, info): ) init_state = init(init_particles) smc_kernel = self.variant(step) - - def inference_loop(kernel, rng_key, initial_state): - def cond(carry): - _, state = carry - return state.sampler_state.lmbda < 1 - - def body(carry): - i, state = carry - subkey = jax.random.fold_in(rng_key, i) - state, _ = kernel(subkey, state) - return i + 1, state - - return jax.lax.while_loop(cond, body, (0, initial_state)) - - _, state = inference_loop(smc_kernel, self.key, init_state) + _, state = adaptive_tempered_loop(smc_kernel, self.key, init_state) assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2) self.assert_linear_regression_test_case(state.sampler_state) @@ -336,10 +325,12 @@ def test_with_tempered_smc(self): loglikelihood_fn, ) = self.particles_prior_loglikelihood() - def parameter_update(state, info): + def parameter_update(key, state, info): return extend_params( { - "inverse_mass_matrix": mass_matrix_from_particles(state.particles), + "inverse_mass_matrix": inverse_mass_matrix_from_particles( + state.particles + ), "step_size": 10e-2, "num_integration_steps": 50, }, @@ -393,5 +384,128 @@ def test_particles_as_rows(self): np.testing.assert_array_equal(np.arange(3 * 5 + 2), flatten_particles[0]) +def adaptive_tempered_loop(kernel, rng_key, initial_state): + def cond(carry): + _, state = carry + return state.sampler_state.lmbda < 1 + + def body(carry): + i, state = carry + subkey = jax.random.fold_in(rng_key, i) + state, _ = kernel(subkey, state) + return i + 1, state + + return jax.lax.while_loop(cond, body, (0, initial_state)) + + +class MultipleTuningTest(SMCLinearRegressionTestCase): + def setUp(self): + super().setUp() + self.key = jax.random.key(42) + + @chex.all_variants(with_pmap=False) + def test_tuning_pretuning(self): + """ + Tests that we can apply tuning on some parameters + and pretuning in some others at the same time. + """ + + ( + init_particles, + logprior_fn, + loglikelihood_fn, + ) = self.particles_prior_loglikelihood() + + n_particles = 100 + dimentions = 2 + + step_size_key, integration_steps_key = jax.random.split(self.key, 2) + + # Set initial samples for integration steps and step sizes. + integration_steps_distribution = jnp.round( + jax.random.uniform( + integration_steps_key, (n_particles,), minval=1, maxval=50 + ) + ).astype(int) + + step_sizes_distribution = jax.random.uniform( + step_size_key, (n_particles,), minval=1e-1 / 2, maxval=1e-1 * 2 + ) + + # Fixes inverse_mass_matrix and distribution for the other two parameters. + initial_parameters = dict( + inverse_mass_matrix=extend_params(jnp.eye(dimentions)), + step_size=step_sizes_distribution, + num_integration_steps=integration_steps_distribution, + ) + + pretune = build_pretune( + blackjax.hmc.init, + blackjax.hmc.build_kernel(), + alpha=2, + n_particles=n_particles, + sigma_parameters={ + "step_size": jnp.array(0.1), + "num_integration_steps": jnp.array(2.0), + }, + natural_parameters=["num_integration_steps"], + positive_parameters=["step_size"], + ) + + def pretuning_factory( + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + mcmc_parameters, + resampling_fn, + num_mcmc_steps, + initial_parameter_value, + target_ess, + ): + # we need to wrap the pretuning into a factory, which is what + # the inner_kernel_tuning expects + return blackjax.pretuning( + blackjax.adaptive_tempered_smc, + logprior_fn, + loglikelihood_fn, + mcmc_step_fn, + mcmc_init_fn, + resampling_fn, + num_mcmc_steps, + initial_parameter_value, + pretune, + target_ess=target_ess, + ) + + def mcmc_parameter_update_fn(key, state, info): + imm = inverse_mass_matrix_from_particles(state.sampler_state.particles) + return {"inverse_mass_matrix": extend_params(imm)} + + step = blackjax.smc.inner_kernel_tuning.build_kernel( + pretuning_factory, + logprior_fn, + loglikelihood_fn, + blackjax.hmc.build_kernel(), + blackjax.hmc.init, + resampling.systematic, + mcmc_parameter_update_fn=mcmc_parameter_update_fn, + initial_parameter_value=initial_parameters, + num_mcmc_steps=10, + target_ess=0.5, + smc_returns_state_with_parameter_override=True, + ) + + def init(position): + return blackjax.smc.inner_kernel_tuning.init( + blackjax.adaptive_tempered_smc.init, position, initial_parameters + ) + + init_state = init(init_particles) + smc_kernel = self.variant(step) + _, state = adaptive_tempered_loop(smc_kernel, self.key, init_state) + self.assert_linear_regression_test_case(state.sampler_state) + + if __name__ == "__main__": absltest.main() From 020b966991490a1341466abfe488d279f8a53549 Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Wed, 12 Feb 2025 15:15:47 -0300 Subject: [PATCH 2/3] rename --- blackjax/smc/pretuning.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/blackjax/smc/pretuning.py b/blackjax/smc/pretuning.py index d4aa56969..374b8f425 100644 --- a/blackjax/smc/pretuning.py +++ b/blackjax/smc/pretuning.py @@ -113,9 +113,7 @@ def build_pretune( alpha: float, sigma_parameters: ArrayLikeTree, n_particles: int, - performance_of_chain_measure_factory: Callable = lambda state: esjd( - state.parameter_override["inverse_mass_matrix"] - ), + performance_of_chain_measure_factory: Callable = default_measure_factory, natural_parameters: Optional[List[str]] = None, positive_parameters: Optional[List[str]] = None, ): From d7934d688fb9d204fc764109db45a0d6993f614b Mon Sep 17 00:00:00 2001 From: Carlos Iguaran Date: Wed, 12 Feb 2025 15:17:45 -0300 Subject: [PATCH 3/3] docs --- blackjax/smc/tuning/from_particles.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/blackjax/smc/tuning/from_particles.py b/blackjax/smc/tuning/from_particles.py index c027fdf87..505e7f3a1 100755 --- a/blackjax/smc/tuning/from_particles.py +++ b/blackjax/smc/tuning/from_particles.py @@ -31,13 +31,11 @@ def particles_covariance_matrix(particles): def inverse_mass_matrix_from_particles(particles) -> Array: """ Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf - Computing a mass matrix to be used in HMC from particles. - Given the particles covariance matrix, set all non-diagonal elements as zero, - take the inverse, and keep the diagonal. + Computing an inverse mass matrix to be used in HMC from particles. Returns ------- - A mass Matrix + An inverse mass matrix """ return jnp.diag(jnp.var(particles_as_rows(particles), axis=0))