Skip to content

Commit

Permalink
ready for test
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Dec 27, 2024
1 parent 996258c commit 4c7f07b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 14 deletions.
30 changes: 17 additions & 13 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def adjusted_mclmc_find_L_and_step_size(
params=None,
max="avg",
num_windows=1,
tuning_factor=1.0,
tuning_factor=1.3,
):
"""
Finds the optimal value of the parameters for the MH-MCHMC algorithm.
Expand All @@ -50,12 +50,17 @@ def adjusted_mclmc_find_L_and_step_size(
The fraction of tuning for the second step of the adaptation.
frac_tune3
The fraction of tuning for the third step of the adaptation.
desired_energy_va
The desired energy variance for the MCMC algorithm.
trust_in_estimate
The trust in the estimate of optimal stepsize.
num_effective_samples
The number of effective samples for the MCMC algorithm.
diagonal_preconditioning
Whether to do diagonal preconditioning (i.e. a mass matrix)
params
Initial params to start tuning from (optional)
max
whether to calculate L from maximum or average eigenvalue. Average is advised.
num_windows
how many iterations of the tuning are carried out
tuning_factor
multiplicative factor for L
Returns
-------
Expand Down Expand Up @@ -126,7 +131,7 @@ def adjusted_mclmc_make_L_step_size_adaptation(
max="avg",
tuning_factor=1.0,
):
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""
"""Adapts the stepsize and L of the MCLMC kernel. Designed for adjusted MCLMC"""

def dual_avg_step(fix_L, update_da):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
Expand Down Expand Up @@ -199,10 +204,7 @@ def step(iteration_state, weight_and_key):
+ (1 - mask) * params.L,
)

if max != "max_svd":
state_position = None
else:
state_position = state.position
state_position = state.position

return (
state,
Expand Down Expand Up @@ -352,7 +354,9 @@ def step(state, key):
ess = contract(effective_sample_size(flat_samples[None, ...])) / num_steps

return state, params._replace(
L=jnp.clip(Lfactor * params.L / jnp.mean(ess), max=params.L * 2)
L=jnp.clip(
Lfactor * params.L / jnp.mean(ess), max=params.L * Lratio_upperbound
)
)

return adaptation_L
Expand Down
4 changes: 3 additions & 1 deletion blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def mclmc_find_L_and_step_size(
The trust in the estimate of optimal stepsize.
num_effective_samples
The number of effective samples for the MCMC algorithm.
diagonal_preconditioning
Whether to do diagonal preconditioning (i.e. a mass matrix)
Returns
-------
Expand Down Expand Up @@ -137,7 +139,7 @@ def make_L_step_size_adaptation(
trust_in_estimate=1.5,
num_effective_samples=150,
):
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""
"""Adapts the stepsize and L of the MCLMC kernel. Designed for unadjusted MCLMC"""

decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

Expand Down

0 comments on commit 4c7f07b

Please sign in to comment.