Skip to content

Commit

Permalink
added more notes to self
Browse files Browse the repository at this point in the history
  • Loading branch information
aphc14 committed Oct 2, 2024
1 parent 6ba6c8b commit d90c426
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
33 changes: 27 additions & 6 deletions blackjax/optimizers/lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def scan_body(tup, it):
x=x0,
f=value0,
g=grad0,
alpha=jnp.ones_like(x0),
alpha=jnp.ones_like(x0), # line 2 of Algorithm 3
update_mask=jnp.zeros_like(x0, dtype=bool),
)

Expand Down Expand Up @@ -276,6 +276,7 @@ def compute_next_alpha(s_l, z_l, alpha_lm1):
)
return 1.0 / inv_alpha_l

# Q: shouldn't it be "<" ???
pred = s_l.T @ z_l > (epsilon * jnp.linalg.norm(z_l, 2))
alpha_l = lax.cond(
pred, compute_next_alpha, lambda *_: alpha_lm1, s_l, z_l, alpha_lm1
Expand All @@ -296,20 +297,40 @@ def lbfgs_inverse_hessian_factors(S, Z, alpha):
Pathfinder: Parallel quasi-newton variational inference, Lu Zhang et al., arXiv:2108.03782
"""

param_dims = S.shape[-1]
StZ = S.T @ Z
R = jnp.triu(StZ) + jnp.eye(param_dims) * jnp.finfo(S.dtype).eps
# StZ = jnp.einsum("...lj,...lj->", S, Z)
E = jnp.triu(StZ) + jnp.eye(param_dims) * jnp.finfo(S.dtype).eps

# dim(eta) = (L^max, N) the diagonal elements
# dim(diag(eta)) = (L^max, N, N)
eta = jnp.diag(StZ)

beta = jnp.hstack([jnp.diag(alpha) @ Z, S])

minvR = -jnp.linalg.inv(R)
alphaZ = jnp.diag(jnp.sqrt(alpha)) @ Z
block_dd = minvR.T @ (alphaZ.T @ alphaZ + jnp.diag(eta)) @ minvR
# invE = jnp.linalg.inv(E)
invE = jnp.linalg.solve(E, jnp.eye(param_dims))
block_dd = invE.T @ (jnp.diag(eta) + Z.T @ jnp.diag(alpha) @ Z) @ invE
gamma = jnp.block(
[[jnp.zeros((param_dims, param_dims)), minvR], [minvR.T, block_dd]]
[[jnp.zeros((param_dims, param_dims)), -invE], [-invE.T, block_dd]]
)

# param_dims = S.shape[-1]
# StZ = S.T @ Z
# R = jnp.triu(StZ) + jnp.eye(param_dims) * jnp.finfo(S.dtype).eps

# eta = jnp.diag(StZ)

# beta = jnp.hstack([jnp.diag(alpha) @ Z, S])

# minvR = -jnp.linalg.inv(R)
# alphaZ = jnp.diag(jnp.sqrt(alpha)) @ Z
# block_dd = minvR.T @ (alphaZ.T @ alphaZ + jnp.diag(eta)) @ minvR
# gamma = jnp.block(
# [[jnp.zeros((param_dims, param_dims)), minvR], [minvR.T, block_dd]]
# )

return beta, gamma


Expand Down
32 changes: 25 additions & 7 deletions blackjax/vi/pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,25 +96,24 @@ def approximate(
----------
rng_key
PRPNG key
logdensity_fn
logdensity_fn # @log p
(un-normalized) log densify function of target distribution to take
approximate samples from
initial_position
initial_position # @theta^(0) ~ pi_0
starting point of the L-BFGS optimization routine
num_samples
num_samples # @K
number of samples to draw to estimate ELBO
maxiter
maxiter # @L
Maximum number of iterations of the LGBFS algorithm.
maxcor
maxcor # @J
Maximum number of metric corrections of the LGBFS algorithm ("history
size")
ftol # Q: isn't this relative tolerance? it looks like _minimize_lbfgs treats ftol like relative tolerance
The LGBFS algorithm terminates the minimization when `(f_k - f_{k+1}) < ftol`
gtol
The LGBFS algorithm terminates the minimization when `|g_k|_norm < gtol`
maxls
The maximum number of line search steps (per iteration) for the LGBFS
algorithm
The maximum number of line search steps (per iteration) for the LGBFS algorithm
**lbfgs_kwargs
other keyword arguments passed to `jaxopt.LBFGS`.
Expand All @@ -141,6 +140,9 @@ def approximate(
)

# Get postions and gradients of the optimization path (including the starting point).

# NTS: status.iter_num.item() returns the idx it LBFGS converged

position = history.x
grad_position = history.g
alpha = history.alpha
Expand Down Expand Up @@ -169,6 +171,8 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
)
logp = -jax.vmap(objective_fn)(phi)
elbo = (logp - logq).mean() # Algorithm 7 of the paper
# Q: gamma has a very large negative number in one of the indices! -4e+15

return elbo, beta, gamma

# Index and reshape S and Z to be sliding window view shape=(maxiter,
Expand All @@ -179,9 +183,23 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
s_j = s_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
z_j = z_padded[index.reshape(path_size, maxcor)].reshape(path_size, maxcor, -1)
rng_keys = jax.random.split(rng_key, path_size)

# Q: batch_size = (maxiter + 1,) but the s_j, z_j, alpha, position, grad_position can be clipped up till convergence to reduce compute time.
# Q: L <= L^max, therefore, find L and do not use L^max

"""
try this:
clip_from = jnp.argmin(
(jnp.arange(path_size) < (status.iter_num)) & jnp.isfinite(elbo)
)
and return clipped s_j, z_j, alpha, position, grad_position
"""

elbo, beta, gamma = jax.vmap(path_finder_body_fn)(
rng_keys, s_j, z_j, alpha, position, grad_position
)

# TODO: remove -jnp.inf. let's keep all the elbo up till convergence
elbo = jnp.where(
(jnp.arange(path_size) < (status.iter_num)) & jnp.isfinite(elbo),
elbo,
Expand Down

0 comments on commit d90c426

Please sign in to comment.