Skip to content

Commit

Permalink
SMC: Joint tuning and pretuning (#776)
Browse files Browse the repository at this point in the history
* impl

* rename

* docs

---------

Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
ciguaran and junpenglao authored Feb 19, 2025
1 parent 3f0cbb7 commit 7e4241f
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 45 deletions.
50 changes: 44 additions & 6 deletions blackjax/smc/inner_kernel_tuning.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Callable, Dict, NamedTuple, Tuple

import jax

from blackjax.base import SamplingAlgorithm
from blackjax.smc.base import SMCInfo, SMCState
from blackjax.types import ArrayTree, PRNGKey
Expand Down Expand Up @@ -28,8 +30,11 @@ def build_kernel(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]],
mcmc_parameter_update_fn: Callable[
[PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree]
],
num_mcmc_steps: int = 10,
smc_returns_state_with_parameter_override=False,
**extra_parameters,
) -> Callable:
"""In the context of an SMC sampler (whose step_fn returning state has a .particles attribute), there's an inner
Expand All @@ -40,7 +45,8 @@ def build_kernel(
----------
smc_algorithm
Either blackjax.adaptive_tempered_smc or blackjax.tempered_smc (or any other implementation of
a sampling algorithm that returns an SMCState and SMCInfo pair).
a sampling algorithm that returns an SMCState and SMCInfo pair). It is also possible for this
to return an StateWithParameterOverride, in such case smc_returns_state_with_parameter_override needs to be True
logprior_fn
A function that computes the log density of the prior distribution
loglikelihood_fn
Expand All @@ -54,7 +60,30 @@ def build_kernel(
A callable that takes the SMCState and SMCInfo at step i and constructs a parameter to be used by the inner kernel in i+1 iteration.
extra_parameters:
parameters to be used for the creation of the smc_algorithm.
smc_returns_state_with_parameter_override:
a boolean indicating that the underlying smc_algorithm returns a smc_returns_state_with_parameter_override.
this is used in order to compose different adaptation mechanisms, such as pretuning with tuning.
"""
if smc_returns_state_with_parameter_override:

def extract_state_for_delegate(state):
return state

def compose_new_state(new_state, new_parameter_override):
composed_parameter_override = (
new_state.parameter_override | new_parameter_override
)
return StateWithParameterOverride(
new_state.sampler_state, composed_parameter_override
)

else:

def extract_state_for_delegate(state):
return state.sampler_state

def compose_new_state(new_state, new_parameter_override):
return StateWithParameterOverride(new_state, new_parameter_override)

def kernel(
rng_key: PRNGKey, state: StateWithParameterOverride, **extra_step_parameters
Expand All @@ -69,9 +98,14 @@ def kernel(
num_mcmc_steps=num_mcmc_steps,
**extra_parameters,
).step
new_state, info = step_fn(rng_key, state.sampler_state, **extra_step_parameters)
new_parameter_override = mcmc_parameter_update_fn(new_state, info)
return StateWithParameterOverride(new_state, new_parameter_override), info
parameter_update_key, step_key = jax.random.split(rng_key, 2)
new_state, info = step_fn(
step_key, extract_state_for_delegate(state), **extra_step_parameters
)
new_parameter_override = mcmc_parameter_update_fn(
parameter_update_key, new_state, info
)
return compose_new_state(new_state, new_parameter_override), info

return kernel

Expand All @@ -83,9 +117,12 @@ def as_top_level_api(
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
mcmc_parameter_update_fn: Callable[[SMCState, SMCInfo], Dict[str, ArrayTree]],
mcmc_parameter_update_fn: Callable[
[PRNGKey, SMCState, SMCInfo], Dict[str, ArrayTree]
],
initial_parameter_value,
num_mcmc_steps: int = 10,
smc_returns_state_with_parameter_override=False,
**extra_parameters,
) -> SamplingAlgorithm:
"""In the context of an SMC sampler (whose step_fn returning state
Expand Down Expand Up @@ -130,6 +167,7 @@ def as_top_level_api(
resampling_fn,
mcmc_parameter_update_fn,
num_mcmc_steps,
smc_returns_state_with_parameter_override,
**extra_parameters,
)

Expand Down
12 changes: 9 additions & 3 deletions blackjax/smc/pretuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,21 @@ def update_parameter_distribution(
)


def default_measure_factory(state):
inverse_mass_matrix = state.parameter_override["inverse_mass_matrix"]
if not (len(inverse_mass_matrix.shape) == 3 and inverse_mass_matrix.shape[0] == 1):
raise ValueError("ESJD only works if chains share the inverse_mass_matrix.")

return esjd(inverse_mass_matrix[0])


def build_pretune(
mcmc_init_fn: Callable,
mcmc_step_fn: Callable,
alpha: float,
sigma_parameters: ArrayLikeTree,
n_particles: int,
performance_of_chain_measure_factory: Callable = lambda state: esjd(
state.parameter_override["inverse_mass_matrix"]
),
performance_of_chain_measure_factory: Callable = default_measure_factory,
natural_parameters: Optional[List[str]] = None,
positive_parameters: Optional[List[str]] = None,
):
Expand Down
12 changes: 5 additions & 7 deletions blackjax/smc/tuning/from_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
"particles_means",
"particles_stds",
"particles_covariance_matrix",
"mass_matrix_from_particles",
"inverse_mass_matrix_from_particles",
]


Expand All @@ -28,18 +28,16 @@ def particles_covariance_matrix(particles):
return jnp.cov(particles_as_rows(particles), ddof=0, rowvar=False)


def mass_matrix_from_particles(particles) -> Array:
def inverse_mass_matrix_from_particles(particles) -> Array:
"""
Implements tuning from section 3.1 from https://arxiv.org/pdf/1808.07730.pdf
Computing a mass matrix to be used in HMC from particles.
Given the particles covariance matrix, set all non-diagonal elements as zero,
take the inverse, and keep the diagonal.
Computing an inverse mass matrix to be used in HMC from particles.
Returns
-------
A mass Matrix
An inverse mass matrix
"""
return jnp.diag(1.0 / jnp.var(particles_as_rows(particles), axis=0))
return jnp.diag(jnp.var(particles_as_rows(particles), axis=0))


def particles_as_rows(particles):
Expand Down
Loading

0 comments on commit 7e4241f

Please sign in to comment.