Skip to content

Commit

Permalink
Arnoldi & Lanczos output and naming convention parity (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndPotap authored Sep 7, 2023
2 parents 76a799a + 15c64d6 commit dc4905b
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 24 deletions.
5 changes: 2 additions & 3 deletions cola/algorithms/arnoldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import cola
from cola import Stiefel, lazify


# def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs):
# val_grads, eig_grads, _ = grads
# op_args, (eig_vals, eig_vecs, _) = res
Expand All @@ -34,8 +33,8 @@
# return (dA, )


#@export
#@iterative_autograd(arnoldi_eigs_bwd)
# @export
# @iterative_autograd(arnoldi_eigs_bwd)
@export
def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
tol: float = 1e-7, use_householder: bool = False, pbar: bool = False):
Expand Down
19 changes: 8 additions & 11 deletions cola/algorithms/lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def lanczos_max_eig(A: LinearOperator, rhs: Array, max_iters: int, tol: float =
max_iters: int maximum number of iters to run lanczos
tol: float: tolerance criteria to stop lanczos
"""
eigvals, *_ = lanczos(A=A, start_vector=rhs, max_iters=max_iters, tol=tol)
eigvals, *_ = lanczos_eigs(A=A, start_vector=rhs, max_iters=max_iters, tol=tol)
return eigvals[-1]


Expand Down Expand Up @@ -50,8 +50,8 @@ def altogether(*theta):
# @export
# @iterative_autograd(lanczos_eig_bwd)
@export
def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, tol: float = 1e-7,
pbar: bool = False):
def lanczos_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
tol: float = 1e-7, pbar: bool = False):
"""
Computes the eigenvalues and eigenvectors using Lanczos.
Expand All @@ -71,12 +71,11 @@ def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
"""
xnp = A.xnp
Q, T, info = lanczos_decomp(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol,
pbar=pbar)
Q, T, info = lanczos(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, pbar=pbar)
eigvals, eigvectors = xnp.eigh(T)
idx = xnp.argsort(eigvals, axis=-1)
V = lazify(Q) @ lazify(eigvectors[:,idx])
V = lazify(Q) @ lazify(eigvectors[:, idx])

eigvals = eigvals[..., idx]
# V = V[..., idx]
return eigvals, V, info
Expand All @@ -85,16 +84,14 @@ def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100,
def LanczosDecomposition(A: LinearOperator, start_vector=None, max_iters=100, tol=1e-7, pbar=False):
""" Provides the Lanczos decomposition of a matrix A = Q T Q^*.
LinearOperator form of lanczos, see lanczos for arguments.

The factors Q and T returned by the underlying routine are wrapped with
``lazify`` (T additionally with ``SelfAdjoint``) and combined into a
``cola.UnitaryDecomposition``. The ``info`` value returned by the routine
is attached to the result as its ``info`` attribute before returning."""
# NOTE(review): this text is a rendered diff, so both sides of the change are
# visible here. The `lanczos_decomp` call appears to be the pre-rename line and
# the `lanczos` call right after it its replacement -- confirm against the repo.
Q, T, info = lanczos_decomp(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol,
pbar=pbar)
Q, T, info = lanczos(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, pbar=pbar)
# Q is wrapped as a lazy operator and T as a self-adjoint lazy operator.
A_approx = cola.UnitaryDecomposition(lazify(Q), SelfAdjoint(lazify(T)))
A_approx.info = info
return A_approx


@export
def lanczos_decomp(A: LinearOperator, start_vector: Array = None, max_iters=100, tol=1e-7,
pbar=False):
def lanczos(A: LinearOperator, start_vector: Array = None, max_iters=100, tol=1e-7, pbar=False):
"""
Computes the Lanczos decomposition of the operator A, A = Q T Q^*.
Expand Down
2 changes: 1 addition & 1 deletion cola/algorithms/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def slq_fwd(A, fun, num_samples, max_iters, tol, pbar, key):
tau = Q[..., 0, :]
# approx = xnp.sum(tau**2 * fun(eigvals), axis=-1)
# fn_vals = xnp.where(xnp.abs(eigvals) > _mp, fun(eigvals), xnp.zeros_like(eigvals))
const = 10*_mp * xnp.max(eigvals, axis=1, keepdims=True)
const = 10 * _mp * xnp.max(eigvals, axis=1, keepdims=True)
fn_vals = xnp.where(xnp.abs(eigvals) > const, fun(eigvals), xnp.zeros_like(eigvals))
approx = xnp.sum(tau**2 * fn_vals, axis=-1)
estimate = A.shape[-2] * approx
Expand Down
4 changes: 2 additions & 2 deletions cola/linalg/eigs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from cola.ops import Identity
from cola.ops import Triangular
from cola.algorithms import power_iteration
from cola.algorithms.lanczos import lanczos
from cola.algorithms.lanczos import lanczos_eigs
from cola.algorithms.arnoldi import arnoldi_eigs
from cola.utils import export

Expand Down Expand Up @@ -75,7 +75,7 @@ def eig(A: LinearOperator, **kwargs):
eig_vals, eig_vecs = xnp.eigh(A.to_dense())
return eig_vals[eig_slice], Stiefel(lazify(eig_vecs[:, eig_slice]))
elif method in ('lanczos', 'iterative') or (method == 'auto' and prod(A.shape) >= 1e6):
eig_vals, eig_vecs, _ = lanczos(A, **kws)
eig_vals, eig_vecs, _ = lanczos_eigs(A, **kws)
return eig_vals, eig_vecs
else:
raise ValueError(f"Unknown method {method} for SelfAdjoint operator")
Expand Down
29 changes: 24 additions & 5 deletions tests/algorithms/test_arnoldi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from cola.ops import Householder
from cola.ops import Product
from cola.ops import Dense
from cola.algorithms.arnoldi import get_householder_vec
from cola.fns import lazify
from cola.algorithms.arnoldi import get_arnoldi_matrix
from cola.algorithms.arnoldi import arnoldi_eigs
Expand Down Expand Up @@ -92,7 +91,7 @@ def test_householder_arnoldi_decomp(backend):
# A_np, rhs_np = np.array(A, dtype=np.complex128), np.array(rhs[:, 0], dtype=np.complex128)
A_np, rhs_np = np.array(A, dtype=np.float64), np.array(rhs[:, 0], dtype=np.float64)
# Q_sol, H_sol = run_householder_arnoldi(A, rhs, A.shape[0], np.float64, xnp)
Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64, xnp)
Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64)

# fn = run_householder_arnoldi
fn = xnp.jit(run_householder_arnoldi, static_argnums=(0, 2))
Expand Down Expand Up @@ -146,7 +145,7 @@ def test_numpy_arnoldi(backend):
rhs = np.random.normal(size=(A.shape[0], ))
# rhs = np.random.normal(size=(A.shape[0], 2)).view(np.complex128)[:, 0]

Q, H = run_householder_arnoldi_np(A, rhs, max_iter=A.shape[0], dtype=dtype, xnp=xnp)
Q, H = run_householder_arnoldi_np(A, rhs, max_iter=A.shape[0], dtype=dtype)
abs_error = np.linalg.norm(np.eye(A.shape[0]) - Q.T @ Q)
assert abs_error < 1e-4
abs_error = np.linalg.norm(Q.T @ A @ Q - H)
Expand All @@ -159,10 +158,10 @@ def test_numpy_arnoldi(backend):
assert abs_error < 1e-10


def run_householder_arnoldi_np(A, rhs, max_iter, dtype, xnp):
def run_householder_arnoldi_np(A, rhs, max_iter, dtype):
H, Q, Ps, zj = initialize_householder_arnoldi(rhs, max_iter, dtype)
for jdx in range(1, max_iter + 2):
vec, beta = get_householder_vec(zj, jdx - 1, xnp)
vec, beta = get_householder_vec_np(zj, jdx - 1)
Ps[jdx].vec, Ps[jdx].beta = vec[:, None], beta
H[:, jdx - 1] = np.array(Ps[jdx] @ zj)
if jdx <= max_iter:
Expand All @@ -186,6 +185,26 @@ def initialize_householder_arnoldi(rhs, max_iter, dtype):
return H, Q, Ps, zj


def get_householder_vec_np(x, idx):
    """Return the Householder vector and scaling ``beta`` for ``x`` at pivot ``idx``.

    Pure-NumPy helper for the Householder Arnoldi reference implementation in
    the tests. In the general case the returned pair satisfies
    ``(I - beta * outer(vec, vec)) @ x`` having zeros below position ``idx``.

    Args:
        x: array indexed along its first axis; ``x[idx]`` must be a real
            scalar because it is compared against zero (presumably a 1-D
            float vector -- confirm against the caller).
        idx: pivot position. Entries of the returned vector before ``idx``
            are zero.

    Returns:
        A ``(vec, beta)`` tuple. When the entries below the pivot are all
        zero, ``vec`` is the unnormalised copy of ``x[idx:]`` and ``beta`` is
        ``0`` (pivot >= 0) or ``-2`` (pivot < 0). Otherwise ``vec`` is scaled
        so that ``vec[idx] == 1``.
    """
    # Squared norm of the entries strictly below the pivot.
    sigma_2 = np.linalg.norm(x[idx + 1:])**2.
    vec = np.zeros_like(x)
    vec[idx:] = x[idx:]
    if sigma_2 == 0 and x[idx] >= 0:
        beta = 0
    elif sigma_2 == 0 and x[idx] < 0:
        # NOTE(review): vec is left unnormalised in this branch while beta is
        # non-zero -- confirm this matches what the Householder operator expects.
        beta = -2
    else:
        x_norm_partial = np.sqrt(x[idx]**2 + sigma_2)
        if x[idx] <= 0:
            vec[idx] = x[idx] - x_norm_partial
        else:
            # Algebraically equal to x[idx] - x_norm_partial, but written to
            # avoid cancellation when the pivot is positive.
            vec[idx] = -sigma_2 / (x[idx] + x_norm_partial)
        beta = 2 * vec[idx]**2 / (sigma_2 + vec[idx]**2)
        # Normalise so the pivot entry is exactly 1. The original repeated
        # this division on vec[idx:] afterwards, which was a no-op because
        # the pivot is already 1 at that point.
        vec = vec / vec[idx]
    return vec, beta


def run_arnoldi(A, rhs, max_iter, tol, dtype):
Q, H = initialize_arnoldi(rhs, max_iter=max_iter, dtype=dtype)
idx, vec = 0, rhs.copy()
Expand Down
4 changes: 2 additions & 2 deletions tests/algorithms/test_lanczos.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cola.algorithms.lanczos import construct_tridiagonal_batched
from cola.algorithms.lanczos import get_lanczos_coeffs
from cola.algorithms.lanczos import lanczos_parts
from cola.algorithms.lanczos import lanczos
from cola.algorithms.lanczos import lanczos_eigs
from cola.algorithms.lanczos import lanczos_max_eig
from cola.utils_test import get_xnp, parametrize, relative_error
from cola.utils_test import generate_spectrum, generate_pd_from_diag
Expand Down Expand Up @@ -35,7 +35,7 @@ def test_lanczos_vjp(backend):

def f(theta):
Aop = unflatten([theta])
out = lanczos(Aop, x0, max_iters=10, tol=1e-6, pbar=False)
out = lanczos_eigs(Aop, x0, max_iters=10, tol=1e-6, pbar=False)
eig_vals, eig_vecs, _ = out
# loss = xnp.sum(eig_vals ** 2.) + xnp.sum(xnp.abs(eig_vecs), axis=[0, 1])
loss = xnp.sum(eig_vals**2.)
Expand Down

0 comments on commit dc4905b

Please sign in to comment.