FIX BUGS IN MHMCHMC AND TUNING
reubenharry committed Jan 11, 2024
1 parent a158710 commit 98439f8
Showing 3 changed files with 62 additions and 103 deletions.
92 changes: 38 additions & 54 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -36,6 +36,13 @@ class MCLMCAdaptationState(NamedTuple):
L: float
step_size: float

def streaming_average(O, x, streaming_avg, weight, zero_prevention):
    """Compute the weighted streaming average of O(x).

    `zero_prevention` keeps the denominator nonzero while the accumulated
    weight is still zero (i.e. while the step is masked out).
    """
    total, average = streaming_avg
    average = (total * average + weight * O(x)) / (total + weight + zero_prevention)
    total += weight
    return (total, average)
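
For reference, a minimal self-contained sketch (not part of the commit) of how this streaming average accumulates the weighted mean of x and x^2 over a few samples:

import jax.numpy as jnp

def streaming_average(O, x, streaming_avg, weight, zero_prevention):
    total, average = streaming_avg
    average = (total * average + weight * O(x)) / (total + weight + zero_prevention)
    total += weight
    return (total, average)

avg = (0.0, jnp.array([0.0, 0.0]))  # (total weight, [avg of x, avg of x^2])
for s in [1.0, 2.0, 3.0]:
    avg = streaming_average(
        O=lambda x: jnp.array([x, jnp.square(x)]),
        x=jnp.array(s),
        streaming_avg=avg,
        weight=1.0,
        zero_prevention=0.0,
    )
# avg == (3.0, [2.0, 4.6667]): the running mean of x and of x^2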

def mclmc_find_L_and_step_size(
mclmc_kernel,
@@ -184,37 +191,27 @@ def predictor(previous_state, params, adaptive_state, rng_key):

return state, params_new, adaptive_state, success

def streaming_average(O, x, streaming_state, outer_weight, success, step_size):
"""streaming average of f(x)"""
total, average = streaming_state
weight = outer_weight * step_size * success
zero_prevention = 1 - outer_weight
average = (total * average + weight * O(x)) / (total + weight + zero_prevention)
total += weight
streaming_state = (total, average)
return streaming_state

def step(iteration_state, weight_and_key):
    """Take one step of the dynamics and update the estimates of the posterior size and of the optimal step size."""

outer_weight, rng_key = weight_and_key
state, params, adaptive_state, streaming_state = iteration_state
mask, rng_key = weight_and_key
state, params, adaptive_state, streaming_avg = iteration_state

state, params, adaptive_state, success = predictor(
state, params, adaptive_state, rng_key
)

# update the running average of x, x^2
streaming_state = streaming_average(
lambda x: jnp.array([x, jnp.square(x)]),
ravel_pytree(state.position)[0],
streaming_state,
outer_weight,
success,
params.step_size,
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
streaming_avg=streaming_avg,
weight=(1 - mask) * success * params.step_size,
zero_prevention=mask,
)

return (state, params, adaptive_state, streaming_state), None
return (state, params, adaptive_state, streaming_avg), None
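
A quick check of the masking convention above, reusing the streaming_average sketch from earlier (illustrative, not from the commit): when mask = 1 the weight is zero and zero_prevention = 1 keeps the update finite, so the running average passes through unchanged.

avg0 = (0.0, jnp.array([0.0, 0.0]))
avg1 = streaming_average(
    O=lambda x: jnp.array([x, jnp.square(x)]),
    x=jnp.array(5.0),
    streaming_avg=avg0,
    weight=(1 - 1) * 1.0 * 0.1,  # mask = 1, so weight = 0
    zero_prevention=1,           # avoids 0/0 while total == 0
)
# avg1 == avg0: masked steps contribute nothing and produce no NaNs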

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
@@ -223,7 +220,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)

# we use the last num_steps2 to compute the diagonal preconditioner
outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))  # mask == 1 during the first num_steps1 steps, 0 afterwards

# initial state of the Kalman filter

@@ -236,8 +233,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
(0.0, 0.0, jnp.inf),
(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
),
xs=(outer_weights, L_step_size_adaptation_keys),
length=num_steps1 + num_steps2,
xs=(mask, L_step_size_adaptation_keys),
)[0]
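
The scan above threads (state, params, adaptive_state, streaming_avg) through every step while consuming one (mask, key) pair per iteration. A toy version of the same pattern (illustrative only):

import jax
import jax.numpy as jnp

def body(carry, xs):
    mask, _key = xs
    return carry + (1 - mask), None  # count the unmasked steps

num_steps1, num_steps2 = 3, 2
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))
keys = jax.random.split(jax.random.PRNGKey(0), num_steps1 + num_steps2)
total, _ = jax.lax.scan(body, 0.0, xs=(mask, keys))
# total == 2.0: only the final num_steps2 steps enter the average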

L = params.L
@@ -313,9 +309,6 @@ def mhmchmc_find_L_and_step_size(
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
desired_energy_var=5e-4,
trust_in_estimate=1.5,
num_effective_samples=150,
):
"""
Find the optimal values of the parameters for the MH-MCHMC algorithm.
@@ -380,30 +373,23 @@ def mhmchmc_make_L_step_size_adaptation(
target=0.65
)
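
The 0.65 target above is the acceptance rate the dual-averaging controller steers toward: the step size is nudged up when recent acceptance exceeds the target and down otherwise. A deliberately simplified illustration of that feedback (hypothetical toy_da_update, not blackjax's dual averaging):

import jax.numpy as jnp

def toy_da_update(log_step_size, acceptance_rate, target=0.65, eta=0.05):
    # move log(step_size) in the direction of the acceptance error
    return log_step_size + eta * (acceptance_rate - target)

log_eps = jnp.log(0.5)
log_eps = toy_da_update(log_eps, acceptance_rate=0.9)  # too many accepts -> larger step
log_eps = toy_da_update(log_eps, acceptance_rate=0.3)  # too few accepts -> smaller step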

def streaming_average(O, x, streaming_state, outer_weight, success, step_size):
"""streaming average of f(x)"""
total, average = streaming_state
weight = outer_weight * step_size * success
zero_prevention = 1 - outer_weight
average = (total * average + weight * O(x)) / (total + weight + zero_prevention)
total += weight
streaming_state = (total, average)
return streaming_state


def step(iteration_state, weight_and_key):
    """Take one step of the dynamics and update the estimates of the posterior size and of the optimal step size."""

outer_weight, rng_key = weight_and_key
previous_state, params, adaptive_state, streaming_state = iteration_state
mask, rng_key = weight_and_key
previous_state, params, adaptive_state, streaming_avg = iteration_state

step_size_max = 1.0

jax.debug.print("{x}",x=(params.step_size, params.L//params.step_size))

# dynamics
next_state, info = kernel(
rng_key=rng_key,
state=previous_state,
# L=params.L,
num_integration_steps=params.L//1,
num_integration_steps=1 + (params.L // params.step_size),  # at least one step; path length ~ L
step_size=params.step_size,
)
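
The fixed step count above replaces the old params.L // 1. Worked through on sample numbers (a sketch):

L, step_size = 1.7, 0.5
num_integration_steps = 1 + (L // step_size)  # == 4.0
# the trajectory covers roughly L, and the +1 guarantees at least one
# integration step even when L < step_size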

@@ -418,6 +404,7 @@ def step(iteration_state, weight_and_key):
info.energy,
)

# jax.debug.print("{x}",x=(info.acceptance_rate))
adaptive_state = update(
adaptive_state, info.acceptance_rate
)
@@ -426,18 +413,17 @@ def step(iteration_state, weight_and_key):
step_size = jnp.exp(adaptive_state.log_step_size)

# update the running average of x, x^2
streaming_state = streaming_average(
lambda x: jnp.array([x, jnp.square(x)]),
ravel_pytree(state.position)[0],
streaming_state,
outer_weight,
success,
step_size,
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
streaming_avg=streaming_avg,
weight=(1 - mask) * success * step_size,
zero_prevention=mask,
)

params = params._replace(step_size=step_size)

return (state, params, adaptive_state, streaming_state), None
return (state, params, adaptive_state, streaming_avg), None


def L_step_size_adaptation(state, params, num_steps, rng_key):
@@ -446,11 +432,10 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
)
L_step_size_adaptation_keys = jax.random.split(rng_key, num_steps1 + num_steps2)

# we use the last num_steps2 to compute the diagonal preconditioner
outer_weights = jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

# initial state of the kalman filter
# determine which steps to ignore in the streaming average
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

# dual averaging initialization
init_adaptive_state = init(params.step_size)

# run the steps
@@ -459,11 +444,10 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
init=(
state,
params,
init_adaptive_state,
(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])),
init_adaptive_state, # state of the dual averaging algorithm
(0.0, jnp.array([jnp.zeros(dim), jnp.zeros(dim)])), # streaming average: (total weight, [avg of x, avg of x^2])
),
xs=(outer_weights, L_step_size_adaptation_keys),
length=num_steps1 + num_steps2,
xs=(mask, L_step_size_adaptation_keys),
)[0]

L = params.L
64 changes: 17 additions & 47 deletions blackjax/mcmc/mhmchmc.py
@@ -22,7 +22,7 @@
import blackjax.mcmc.metrics as metrics
import blackjax.mcmc.trajectory as trajectory
from blackjax.base import SamplingAlgorithm
from blackjax.mcmc.hmc import HMCInfo, HMCState
from blackjax.mcmc.hmc import HMCInfo, HMCState, flip_momentum
from blackjax.mcmc.proposal import safe_energy_diff, static_binomial_sampling

# from blackjax.mcmc.trajectory import mhmchmc_energy
@@ -35,23 +35,17 @@
"mhmchmc",
]



def init(
position: ArrayLikeTree, logdensity_fn: Callable, random_generator_arg: Array
):
logdensity, logdensity_grad = jax.value_and_grad(logdensity_fn)(position)
return DynamicHMCState(position, logdensity, logdensity_grad, random_generator_arg)




def build_kernel(
integrator: Callable = integrators.isokinetic_mclachlan,
divergence_threshold: float = 1000,
next_random_arg_fn: Callable = lambda key: jax.random.split(key)[1],
integration_steps_fn: Callable = lambda key: jax.random.randint(key, (), 1, 10),
# integration_steps_fn: Callable = lambda key: 10,
):
"""Build a Dynamic HMC kernel where the number of integration steps is chosen randomly.
@@ -83,40 +77,32 @@ def kernel(
**integration_steps_kwargs,
) -> tuple[DynamicHMCState, HMCInfo]:
"""Generate a new sample with the MHMCHMC kernel."""

num_integration_steps = integration_steps_fn(
state.random_generator_arg, **integration_steps_kwargs
)
mhmchmc_state = HMCState(
state.position, state.logdensity, state.logdensity_grad
)
key_momentum, key_integrator = jax.random.split(rng_key, 2)
momentum = generate_unit_vector(key_momentum, state.position)

proposal_generator = mhmchmc_proposal(

proposal, info, _ = mhmchmc_proposal(
integrator(logdensity_fn),
step_size,
num_integration_steps,
divergence_threshold,
)(
key_integrator,
integrators.IntegratorState(
state.position, momentum, state.logdensity, state.logdensity_grad
)
)

key_momentum, key_integrator = jax.random.split(rng_key, 2)

position, logdensity, logdensity_grad = mhmchmc_state
momentum = generate_unit_vector(key_momentum, position)

integrator_state = integrators.IntegratorState(
position, momentum, logdensity, logdensity_grad
)
proposal, info, _ = proposal_generator(key_integrator, integrator_state)
proposal = HMCState(
proposal.position, proposal.logdensity, proposal.logdensity_grad
)

next_random_arg = next_random_arg_fn(state.random_generator_arg)
return (
DynamicHMCState(
proposal.position,
proposal.logdensity,
proposal.logdensity_grad,
next_random_arg,
next_random_arg_fn(state.random_generator_arg),
),
info,
)
@@ -221,25 +207,27 @@ def step(i, vars):
next_state, next_kinetic_energy = integrator(state, step_size=step_size)
return next_state, kinetic_energy + next_kinetic_energy

def build_trajectory(state, step_size, num_integration_steps):
def build_trajectory(state, num_integration_steps):
    # 0 * num_integration_steps gives a zero lower bound whose dtype matches
    # the (possibly traced) upper bound
    return jax.lax.fori_loop(
        0 * num_integration_steps, num_integration_steps, step, (state, 0)
    )

mhmchmc_energy_fn = lambda state, kinetic_energy: state.logdensity + kinetic_energy
mhmchmc_energy_fn = lambda state, kinetic_energy: -state.logdensity + kinetic_energy  # H = -log pi(x) + kinetic energy
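
Since the Hamiltonian is H(x, u) = -log pi(x) + kinetic energy, the sign fix above is what makes lower-energy proposals favoured. A small worked check (sketch, not from the commit):

import jax.numpy as jnp

logdensity_start, ke_start = -2.0, 1.0  # H_start = 3.0
logdensity_end, ke_end = -0.5, 1.0      # H_end = 1.5
delta_energy = (-logdensity_start + ke_start) - (-logdensity_end + ke_end)
p_accept = jnp.minimum(1.0, jnp.exp(delta_energy))  # == 1.0: downhill move always accepted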

def generate(
rng_key, state: integrators.IntegratorState
) -> tuple[integrators.IntegratorState, HMCInfo, ArrayTree]:
"""Generate a new chain state."""
end_state, kinetic_energy = build_trajectory(
state, step_size, num_integration_steps
state, num_integration_steps
)
end_state = flip_momentum(end_state)
proposal_energy = mhmchmc_energy_fn(state, kinetic_energy)
new_energy = mhmchmc_energy_fn(end_state, kinetic_energy)
# jax.debug.print("mhmchmc 225 {x}", x=(proposal_energy, new_energy))
delta_energy = safe_energy_diff(proposal_energy, new_energy)
is_diverging = -delta_energy > divergence_threshold
sampled_state, info = sample_proposal(rng_key, delta_energy, state, end_state)
do_accept, p_accept, other_proposal_info = info
# jax.debug.print("mhmchmc 230 {x}", x=(do_accept, p_accept))

info = HMCInfo(
state.momentum,
@@ -256,24 +244,6 @@ def generate(
return generate


def flip_momentum(
    state: integrators.IntegratorState,
) -> integrators.IntegratorState:
    """Flip the momentum at the end of the trajectory.

    To guarantee time-reversibility (and hence detailed balance) we flip the
    momentum of the final state. Running the Hamiltonian dynamics from that
    state with flipped momentum retraces the trajectory back to the initial
    state (with its momentum flipped).
    """
    flipped_momentum = jax.tree_util.tree_map(lambda m: -1.0 * m, state.momentum)
    return integrators.IntegratorState(
        state.position,
        flipped_momentum,
        state.logdensity,
        state.logdensity_grad,
    )
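
A minimal sanity check of the reversibility argument in the docstring (illustrative):

import jax
import jax.numpy as jnp

momentum = {"p": jnp.array([1.0, -2.0])}
flip = lambda m: jax.tree_util.tree_map(lambda v: -1.0 * v, m)
# flipping twice recovers the original momentum, so integrating from the
# flipped end state retraces the trajectory back to the start
assert jnp.allclose(flip(flip(momentum))["p"], momentum["p"])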


def rescale(mu):
9 changes: 7 additions & 2 deletions explore-mhmchmc.py
@@ -18,7 +18,7 @@
def logdensity_fn(x):
return -0.5 * jnp.sum(jnp.square(x))

initial_position = jnp.array([1.0, 1.0])
initial_position = jnp.array([0.01, 0.01])
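
For this standard Gaussian target the tuned samplers should recover E[x] = 0 and E[x^2] = 1. A quick check of the density's gradient at the new initial position (sketch):

import jax
import jax.numpy as jnp

def logdensity_fn(x):
    return -0.5 * jnp.sum(jnp.square(x))

grad = jax.grad(logdensity_fn)(jnp.array([0.01, 0.01]))
# grad == [-0.01, -0.01]: the gradient -x points back toward the mode at 0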

def run_hmc(initial_position):

@@ -70,6 +70,8 @@ def run_mhmchmc_dynamic(initial_position):
rng_key=tune_key,
)

# raise Exception


# step_size = 1.0784992
# L = 1.7056025
@@ -121,10 +123,13 @@ def run_mclmc(logdensity_fn, num_steps, initial_position):

print(blackjax_mclmc_sampler_params)


# compare_static_dynamic()


# run_mclmc(logdensity_fn, num_steps, initial_position)
out = run_mclmc(logdensity_fn, num_steps, initial_position)
# print(out.position.mean(axis=0) )


# out = run_hmc(initial_position)
out = run_mhmchmc_dynamic(initial_position)
