Skip to content

Commit

Permalink
added hessian approximation
Browse files Browse the repository at this point in the history
  • Loading branch information
OlaRonning committed Feb 23, 2024
1 parent 50324cc commit 74b8b9f
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions numpyro/contrib/ecs_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ def log_likelihood_sum(params_flat, subsample_indices=None):
raise ValueError("Taylor proxy only defined for first and second degree.")

# those stats are dict keyed by subsample names
ref_log_lik_sum = log_likelihood_sum(ref_params_flat)
ref_log_lik_grads_sum = jacobian(log_likelihood_sum)(ref_params_flat)
ref_sum_log_lik = log_likelihood_sum(ref_params_flat)
ref_sum_log_lik_grads = jacobian(log_likelihood_sum)(ref_params_flat)

if degree == 2 and not approx:
ref_log_lik_hessians_sum = hessian(log_likelihood_sum)(ref_params_flat)
ref_sum_log_lik_hessians = hessian(log_likelihood_sum)(ref_params_flat)

def gibbs_init(rng_key, gibbs_sites):

Expand Down Expand Up @@ -249,11 +249,8 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state):
)
high_order_terms = 0.0
if degree == 2:
if approx: # TODO: fixme
high_order_terms = 0.5 * jnp.dot(
jnp.dot(ref_subsample_log_lik_hessians[name], params_diff),
params_diff,
)
if approx: # compute z.THz \approx z.T(JJ.T)z = (z.T J)(J.T z) = \sum (z.T J)**2
high_order_terms = 0.5 * (jnp.dot(params_diff, ref_subsample_log_lik_grads[name].T) ** 2).sum()
else:
high_order_terms = 0.5 * jnp.dot(
jnp.dot(ref_subsample_log_lik_hessians[name], params_diff),
Expand All @@ -262,21 +259,17 @@ def proxy_fn(params, subsample_lik_sites, gibbs_state):

proxy_subsample[name] = proxy_subsample[name] + high_order_terms

proxy_sum[name] = ref_log_lik_sum[name] + jnp.dot(
ref_log_lik_grads_sum[name], params_diff
proxy_sum[name] = ref_sum_log_lik[name] + jnp.dot(
ref_sum_log_lik_grads[name], params_diff
)

high_order_terms = 0.0
if degree == 2:
if approx: # TODO: fixme
high_order_terms = 0.5 * jnp.dot(
jnp.dot(ref_log_lik_hessians_sum[name], params_diff),
params_diff,
)
if approx:
high_order_terms = 0.5 * (jnp.dot(ref_sum_log_lik_grads[name], params_diff) ** 2).sum()
else:

high_order_terms = 0.5 * jnp.dot(
jnp.dot(ref_log_lik_hessians_sum[name], params_diff),
jnp.dot(ref_sum_log_lik_hessians[name], params_diff),
params_diff,
)
proxy_sum[name] = proxy_sum[name] + high_order_terms
Expand Down

0 comments on commit 74b8b9f

Please sign in to comment.