From 58d9920a2fa6281334e7832cb4a6c37dbc9181f9 Mon Sep 17 00:00:00 2001
From: =
Date: Mon, 10 Feb 2025 12:11:55 -0500
Subject: [PATCH] remove debug statements

---
 blackjax/adaptation/ensemble_mclmc.py  | 30 ++++++++++--------------------
 blackjax/adaptation/ensemble_umclmc.py | 14 +-------------
 2 files changed, 11 insertions(+), 33 deletions(-)

diff --git a/blackjax/adaptation/ensemble_mclmc.py b/blackjax/adaptation/ensemble_mclmc.py
index 7294f4936..7a117665f 100644
--- a/blackjax/adaptation/ensemble_mclmc.py
+++ b/blackjax/adaptation/ensemble_mclmc.py
@@ -84,13 +84,8 @@ def __init__(
         # adjustment_factor = jnp.power(0.82 / (num_dims * adaptation_state.EEVPD), 0.25) / jnp.sqrt(steps_per_sample)
         step_size = (
             adaptation_state.step_size
-        )  # * integrator_factor * adjustment_factor
+        )
 
-        # steps_per_sample = (int)(jnp.max(jnp.array([Lfull / step_size, 1])))
-
-        # Initialize the dual averaging adaptation #
-        # da_init_fn, self.epsadap_update, _ = dual_averaging_adaptation(target= acc_prob_target)
-        # stepsize_adaptation_state = da_init_fn(step_size)
 
         # Initialize the bisection for finding the step size
         stepsize_adaptation_state, self.epsadap_update = bisection_monotonic_fn(
@@ -169,18 +164,18 @@ def emaus(
     sample_init,
     transform,
     ndims,
-    num_steps1,  # max number in phase 1
-    num_steps2,  # fixed number in phase 2
+    num_steps1,
+    num_steps2,
     num_chains,
     mesh,
     rng_key,
-    alpha=1.9,  # L = sqrt{d}*alpha*vars
-    save_frac=0.2,  # to end stage one, the fraction of stage 1 samples used to estimate fluctuation. min is: save_frac*num_steps1
-    C=0.1,  # constant in stage 1 that determines step size (eq (9) in paper)
-    early_stop=True,  # for stage 1
-    r_end=5e-3,  # stage1 parameters
+    alpha=1.9,
+    save_frac=0.2,
+    C=0.1,
+    early_stop=True,
+    r_end=5e-3,
     diagonal_preconditioning=True,
-    integrator_coefficients=None,  # (for stage 2)
+    integrator_coefficients=None,
     steps_per_sample=10,
     acc_prob=None,
     observables=lambda x: None,
@@ -229,11 +224,9 @@ def emaus(
         C=C,
         power=3.0 / 8.0,
         r_end=r_end,
-        # observables=observables,
         observables_for_bias=lambda position: jnp.square(
             transform(jax.flatten_util.ravel_pytree(position)[0])
         ),
-        # contract=contract,
     )
 
     final_state, final_adaptation_state, info1 = run_eca(
@@ -248,7 +241,7 @@ def emaus(
         early_stop=early_stop,
     )
 
-    # refine the results with the adjusted method #
+    # refine the results with the adjusted method
     _acc_prob = acc_prob
     if integrator_coefficients is None:
         high_dims = ndims > 200
@@ -297,9 +290,6 @@ def emaus(
         num_adaptation_samples,
         steps_per_sample,
         _acc_prob,
-        # observables=observables,
-        # observables_for_bias=observables_for_bias,
-        # contract=contract,
     )
 
     final_state, final_adaptation_state, info2 = run_eca(
diff --git a/blackjax/adaptation/ensemble_umclmc.py b/blackjax/adaptation/ensemble_umclmc.py
index f919a82e8..830099ce6 100644
--- a/blackjax/adaptation/ensemble_umclmc.py
+++ b/blackjax/adaptation/ensemble_umclmc.py
@@ -88,9 +88,7 @@ def summary_statistics_fn(state):
 
 def ensemble_init(key, state, signs):
     """flip the velocity, depending on the equipartition condition"""
-    # velocity = jax.tree_util.tree_map(
-    #     lambda sign, u: sign * u, signs, state.momentum
-    # )
+
     momentum, unflatten = jax.flatten_util.ravel_pytree(state.momentum)
 
     velocity_flat = jax.tree_util.tree_map(
@@ -124,14 +122,8 @@ def ensemble_init(key, state, signs):
 
 
 def update_history(new_vals, history):
-    # new_vals = jax.flatten_util.ravel_pytree(new_vals)[0]
-    # history = jax.flatten_util.ravel_pytree(history)[0]
-    # print(new_vals, "FOOO\n\n")
-
     new_vals, _ = jax.flatten_util.ravel_pytree(new_vals)
-    # print(history, "FOOO\n\n")
     return jnp.concatenate((new_vals[None, :], history[:-1]))
-    # return history  # TODO CHANGE BACK!!!!
 
 
 def update_history_scalar(new_val, history):
@@ -146,8 +138,6 @@ def contract_history(theta, weights):
 
     return jnp.array([jnp.max(r), jnp.average(r)])
 
-
-# used for the early stopping
 class History(NamedTuple):
     observables: Array
     stopping: Array
@@ -224,8 +214,6 @@ def __init__(
         self.contract = contract
        self.bias_type = bias_type
         self.save_num = save_num
-        # sigma = unravel_fn(jnp.ones(flat_pytree.shape, dtype = flat_pytree.dtype))
-
         r_save_num = save_num
 
         history = History(
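
Note on the surviving code: update_history maintains a fixed-length rolling
window of flattened observables (newest row first), which contract_history
then reduces for the stage-1 stopping criterion. A minimal standalone sketch
of that pattern follows; the (3, 2) buffer shape and the example pytree are
hypothetical, chosen only for illustration:

    import jax.flatten_util
    import jax.numpy as jnp

    def update_history(new_vals, history):
        # Flatten the pytree of observables into a single 1D row.
        new_vals, _ = jax.flatten_util.ravel_pytree(new_vals)
        # Prepend the new row and drop the oldest: a fixed-length FIFO buffer.
        return jnp.concatenate((new_vals[None, :], history[:-1]))

    history = jnp.zeros((3, 2))  # (save_num, num_observables), hypothetical
    history = update_history({"x": jnp.array([1.0, 2.0])}, history)
    history = update_history({"x": jnp.array([3.0, 4.0])}, history)
    # history is now [[3., 4.], [1., 2.], [0., 0.]]: the newest entry sits at
    # index 0 and the initial zeros are gradually pushed out of the window.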