diff --git a/numpyro/contrib/ecs_proxies.py b/numpyro/contrib/ecs_proxies.py index 43236e6a9..039fb7618 100644 --- a/numpyro/contrib/ecs_proxies.py +++ b/numpyro/contrib/ecs_proxies.py @@ -172,11 +172,11 @@ def log_likelihood_sum(params_flat, subsample_indices=None): raise ValueError("Taylor proxy only defined for first and second degree.") # those stats are dict keyed by subsample names - ref_log_lik_sum = log_likelihood_sum(ref_params_flat) - ref_log_lik_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat) + ref_sum_log_lik = log_likelihood_sum(ref_params_flat) + ref_sum_log_lik_grads = jacobian(log_likelihood_sum)(ref_params_flat) if degree == 2 and not approx: - ref_log_lik_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat) + ref_sum_log_lik_hessians = hessian(log_likelihood_sum)(ref_params_flat) def gibbs_init(rng_key, gibbs_sites): @@ -249,11 +249,8 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state): ) high_order_terms = 0.0 if degree == 2: - if approx: # TODO: fixme - high_order_terms = 0.5 * jnp.dot( - jnp.dot(ref_subsample_log_lik_hessians[name], params_diff), - params_diff, - ) + if approx: # compute z.THz \approx z.T(JJ.T)z = (z.T J)(J.T z) = \sum (z.T J)**2 + high_order_terms = 0.5 * (jnp.dot(params_diff, ref_subsample_log_lik_grads[name].T) ** 2).sum() else: high_order_terms = 0.5 * jnp.dot( jnp.dot(ref_subsample_log_lik_hessians[name], params_diff), @@ -262,21 +259,17 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state): proxy_subsample[name] = proxy_subsample[name] + high_order_terms - proxy_sum[name] = ref_log_lik_sum[name] + jnp.dot( - ref_log_lik_grads_sum[name], params_diff + proxy_sum[name] = ref_sum_log_lik[name] + jnp.dot( + ref_sum_log_lik_grads[name], params_diff ) high_order_terms = 0.0 if degree == 2: - if approx: # TODO: fixme - high_order_terms = 0.5 * jnp.dot( - jnp.dot(ref_log_lik_hessians_sum[name], params_diff), - params_diff, - ) + if approx: + high_order_terms = 0.5 * (jnp.dot(ref_sum_log_lik_grads[name], params_diff) ** 2).sum() else: - high_order_terms = 0.5 * jnp.dot( - jnp.dot(ref_log_lik_hessians_sum[name], params_diff), + jnp.dot(ref_sum_log_lik_hessians[name], params_diff), params_diff, ) proxy_sum[name] = proxy_sum[name] + high_order_terms