
Commit

remove debug statements
reubenharry committed Feb 10, 2025
1 parent cc7bfbd commit 58d9920
Showing 2 changed files with 11 additions and 33 deletions.
30 changes: 10 additions & 20 deletions blackjax/adaptation/ensemble_mclmc.py
@@ -84,13 +84,8 @@ def __init__(
         # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample)
         step_size = (
             adaptation_state.step_size
-        ) # * integrator_factor * adjustment_factor
+        )
 
-        # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1])))
-
-        # Initialize the dual averaging adaptation #
-        # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target)
-        # stepsize_adaptation_state = da_init_fn(step_size)
 
         # Initialize the bisection for finding the step size
         stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn(
@@ -169,18 +164,18 @@ def emaus(
     sample_init,
     transform,
     ndims,
-    num_steps1, # max number in phase 1
-    num_steps2, # fixed number in phase 2
+    num_steps1,
+    num_steps2,
     num_chains,
     mesh,
     rng_key,
-    alpha=1.9, # L = sqrt{d}*alpha*vars
-    save_frac=0.2, # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. min is: save_frac*num_steps1
-    C=0.1, # constant in stage 1 that determines step size (eq (9) in paper)
-    early_stop=True, # for stage 1
-    r_end=5e-3, # stage1 parameters
+    alpha=1.9,
+    save_frac=0.2,
+    C=0.1,
+    early_stop=True,
+    r_end=5e-3,
     diagonal_preconditioning=True,
-    integrator_coefficients=None, # (for stage 2)
+    integrator_coefficients=None,
     steps_per_sample=10,
     acc_prob=None,
     observables=lambda x: None,
@@ -229,11 +224,9 @@ def emaus(
         C=C,
         power=3.0 / 8.0,
         r_end=r_end,
-        # observables=observables,
         observables_for_bias=lambda position: jnp.square(
             transform(jax.flatten_util.ravel_pytree(position)[0])
         ),
-        # contract=contract,
     )
 
     final_state, final_adaptation_state, info1 = run_eca(
@@ -248,7 +241,7 @@ def emaus(
         early_stop=early_stop,
     )
 
-    # refine the results with the adjusted method #
+    # refine the results with the adjusted method
     _acc_prob = acc_prob
     if integrator_coefficients is None:
         high_dims = ndims > 200
Expand Down Expand Up @@ -297,9 +290,6 @@ def emaus(
num_adaptation_samples,
steps_per_sample,
_acc_prob,
# observables=observables,
# observables_for_bias=observables_for_bias,
# contract=contract,
)

final_state, final_adaptation_state, info2 = run_eca(
Expand Down
14 changes: 1 addition & 13 deletions blackjax/adaptation/ensemble_umclmc.py
@@ -88,9 +88,7 @@ def summary_statistics_fn(state):
 
 def ensemble_init(key, state, signs):
     """flip the velocity, depending on the equipartition condition"""
-    # velocity = jax.tree_util.tree_map(
-    #     lambda sign, u: sign * u, signs, state.momentum
-    # )
+
     momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum)
 
     velocity_flat = jax.tree_util.tree_map(
@@ -124,14 +122,8 @@ def ensemble_init(key, state, signs):
 
 
 def update_history(new_vals, history):
-    # new_vals = jax.flatten_util.ravel_pytree(new_vals)[0]
-    # history = jax.flatten_util.ravel_pytree(history)[0]
-    # print(new_vals, "FOOO\n\n")
-
     new_vals, _ = jax.flatten_util.ravel_pytree(new_vals)
-    # print(history, "FOOO\n\n")
     return jnp.concatenate((new_vals[None, :], history[:-1]))
-    # return history # TODO CHANGE BACK!!!!
 
 
 def update_history_scalar(new_val, history):
@@ -146,8 +138,6 @@ def contract_history(theta, weights):
 
     return jnp.array([jnp.max(r), jnp.average(r)])
 
-
-# used for the early stopping
 class History(NamedTuple):
     observables: Array
     stopping: Array
Expand Down Expand Up @@ -224,8 +214,6 @@ def __init__(
self.contract = contract
self.bias_type = bias_type
self.save_num = save_num
# sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype))

r_save_num = save_num

history = History(
Expand Down
