From ad4950b4cdec0f96475ead327d1da47f6c0a2238 Mon Sep 17 00:00:00 2001 From: AndPotap Date: Thu, 7 Sep 2023 11:42:34 -0400 Subject: [PATCH 1/3] Achieved Lanczos and Arnoldi parity of output and naming convention --- cola/algorithms/arnoldi.py | 5 ++--- cola/algorithms/lanczos.py | 19 ++++++++----------- cola/algorithms/slq.py | 2 +- cola/linalg/eigs.py | 4 ++-- tests/algorithms/test_lanczos.py | 4 ++-- 5 files changed, 15 insertions(+), 19 deletions(-) diff --git a/cola/algorithms/arnoldi.py b/cola/algorithms/arnoldi.py index 01df4526..54dc7e14 100644 --- a/cola/algorithms/arnoldi.py +++ b/cola/algorithms/arnoldi.py @@ -7,7 +7,6 @@ import cola from cola import Stiefel, lazify - # def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs): # val_grads, eig_grads, _ = grads # op_args, (eig_vals, eig_vecs, _) = res @@ -34,8 +33,8 @@ # return (dA, ) -#@export -#@iterative_autograd(arnoldi_eigs_bwd) +# @export +# @iterative_autograd(arnoldi_eigs_bwd) @export def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, tol: float = 1e-7, use_householder: bool = False, pbar: bool = False): diff --git a/cola/algorithms/lanczos.py b/cola/algorithms/lanczos.py index 59c7150b..49a8f3bb 100644 --- a/cola/algorithms/lanczos.py +++ b/cola/algorithms/lanczos.py @@ -17,7 +17,7 @@ def lanczos_max_eig(A: LinearOperator, rhs: Array, max_iters: int, tol: float = max_iters: int maximum number of iters to run lanczos tol: float: tolerance criteria to stop lanczos """ - eigvals, *_ = lanczos(A=A, start_vector=rhs, max_iters=max_iters, tol=tol) + eigvals, *_ = lanczos_eigs(A=A, start_vector=rhs, max_iters=max_iters, tol=tol) return eigvals[-1] @@ -50,8 +50,8 @@ def altogether(*theta): # @export # @iterative_autograd(lanczos_eig_bwd) @export -def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, tol: float = 1e-7, - pbar: bool = False): +def lanczos_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, + tol: float = 1e-7, pbar: bool = False): """ Computes the eigenvalues and eigenvectors using Lanczos. @@ -71,12 +71,11 @@ def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, """ xnp = A.xnp - Q, T, info = lanczos_decomp(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, - pbar=pbar) + Q, T, info = lanczos(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, pbar=pbar) eigvals, eigvectors = xnp.eigh(T) idx = xnp.argsort(eigvals, axis=-1) - V = lazify(Q) @ lazify(eigvectors[:,idx]) - + V = lazify(Q) @ lazify(eigvectors[:, idx]) + eigvals = eigvals[..., idx] # V = V[..., idx] return eigvals, V, info @@ -85,16 +84,14 @@ def lanczos(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, def LanczosDecomposition(A: LinearOperator, start_vector=None, max_iters=100, tol=1e-7, pbar=False): """ Provides the Lanczos decomposition of a matrix A = Q T Q^*. LinearOperator form of lanczos, see lanczos for arguments.""" - Q, T, info = lanczos_decomp(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, - pbar=pbar) + Q, T, info = lanczos(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, pbar=pbar) A_approx = cola.UnitaryDecomposition(lazify(Q), SelfAdjoint(lazify(T))) A_approx.info = info return A_approx @export -def lanczos_decomp(A: LinearOperator, start_vector: Array = None, max_iters=100, tol=1e-7, - pbar=False): +def lanczos(A: LinearOperator, start_vector: Array = None, max_iters=100, tol=1e-7, pbar=False): """ Computes the Lanczos decomposition of a the operator A, A = Q T Q^*. diff --git a/cola/algorithms/slq.py b/cola/algorithms/slq.py index ae7d6c7a..a8d18236 100644 --- a/cola/algorithms/slq.py +++ b/cola/algorithms/slq.py @@ -50,7 +50,7 @@ def slq_fwd(A, fun, num_samples, max_iters, tol, pbar, key): tau = Q[..., 0, :] # approx = xnp.sum(tau**2 * fun(eigvals), axis=-1) # fn_vals = xnp.where(xnp.abs(eigvals) > _mp, fun(eigvals), xnp.zeros_like(eigvals)) - const = 10*_mp * xnp.max(eigvals, axis=1, keepdims=True) + const = 10 * _mp * xnp.max(eigvals, axis=1, keepdims=True) fn_vals = xnp.where(xnp.abs(eigvals) > const, fun(eigvals), xnp.zeros_like(eigvals)) approx = xnp.sum(tau**2 * fn_vals, axis=-1) estimate = A.shape[-2] * approx diff --git a/cola/linalg/eigs.py b/cola/linalg/eigs.py index 5117602d..b51f431d 100644 --- a/cola/linalg/eigs.py +++ b/cola/linalg/eigs.py @@ -11,7 +11,7 @@ from cola.ops import Identity from cola.ops import Triangular from cola.algorithms import power_iteration -from cola.algorithms.lanczos import lanczos +from cola.algorithms.lanczos import lanczos_eigs from cola.algorithms.arnoldi import arnoldi_eigs from cola.utils import export @@ -75,7 +75,7 @@ def eig(A: LinearOperator, **kwargs): eig_vals, eig_vecs = xnp.eigh(A.to_dense()) return eig_vals[eig_slice], Stiefel(lazify(eig_vecs[:, eig_slice])) elif method in ('lanczos', 'iterative') or (method == 'auto' and prod(A.shape) >= 1e6): - eig_vals, eig_vecs, _ = lanczos(A, **kws) + eig_vals, eig_vecs, _ = lanczos_eigs(A, **kws) return eig_vals, eig_vecs else: raise ValueError(f"Unknown method {method} for SelfAdjoint operator") diff --git a/tests/algorithms/test_lanczos.py b/tests/algorithms/test_lanczos.py index 4674797a..e01e5061 100644 --- a/tests/algorithms/test_lanczos.py +++ b/tests/algorithms/test_lanczos.py @@ -5,7 +5,7 @@ from cola.algorithms.lanczos import construct_tridiagonal_batched from cola.algorithms.lanczos import get_lanczos_coeffs from cola.algorithms.lanczos import lanczos_parts -from cola.algorithms.lanczos import lanczos +from cola.algorithms.lanczos import lanczos_eigs from cola.algorithms.lanczos import lanczos_max_eig from cola.utils_test import get_xnp, parametrize, relative_error from cola.utils_test import generate_spectrum, generate_pd_from_diag @@ -35,7 +35,7 @@ def test_lanczos_vjp(backend): def f(theta): Aop = unflatten([theta]) - out = lanczos(Aop, x0, max_iters=10, tol=1e-6, pbar=False) + out = lanczos_eigs(Aop, x0, max_iters=10, tol=1e-6, pbar=False) eig_vals, eig_vecs, _ = out # loss = xnp.sum(eig_vals ** 2.) + xnp.sum(xnp.abs(eig_vecs), axis=[0, 1]) loss = xnp.sum(eig_vals**2.) From 627ed9b7b495987482858fdf6f2fdf22b191b157 Mon Sep 17 00:00:00 2001 From: AndPotap Date: Thu, 7 Sep 2023 11:54:56 -0400 Subject: [PATCH 2/3] Fixed use_householder --- tests/algorithms/test_arnoldi.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/tests/algorithms/test_arnoldi.py b/tests/algorithms/test_arnoldi.py index c5cee354..7e21d718 100644 --- a/tests/algorithms/test_arnoldi.py +++ b/tests/algorithms/test_arnoldi.py @@ -2,7 +2,6 @@ from cola.ops import Householder from cola.ops import Product from cola.ops import Dense -from cola.algorithms.arnoldi import get_householder_vec from cola.fns import lazify from cola.algorithms.arnoldi import get_arnoldi_matrix from cola.algorithms.arnoldi import arnoldi_eigs @@ -82,7 +81,7 @@ def test_arnoldi(backend): assert rel_error < 1e-3 -@parametrize(['jax']) +@parametrize(['jax', 'torch']) def test_householder_arnoldi_decomp(backend): xnp = get_xnp(backend) dtype = xnp.float32 @@ -92,7 +91,7 @@ def test_householder_arnoldi_decomp(backend): # A_np, rhs_np = np.array(A, dtype=np.complex128), np.array(rhs[:, 0], dtype=np.complex128) A_np, rhs_np = np.array(A, dtype=np.float64), np.array(rhs[:, 0], dtype=np.float64) # Q_sol, H_sol = run_householder_arnoldi(A, rhs, A.shape[0], np.float64, xnp) - Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64, xnp) + Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64) # fn = run_householder_arnoldi fn = xnp.jit(run_householder_arnoldi, static_argnums=(0, 2)) @@ -146,7 +145,7 @@ def test_numpy_arnoldi(backend): rhs = np.random.normal(size=(A.shape[0], )) # rhs = np.random.normal(size=(A.shape[0], 2)).view(np.complex128)[:, 0] - Q, H = run_householder_arnoldi_np(A, rhs, max_iter=A.shape[0], dtype=dtype, xnp=xnp) + Q, H = run_householder_arnoldi_np(A, rhs, max_iter=A.shape[0], dtype=dtype) abs_error = np.linalg.norm(np.eye(A.shape[0]) - Q.T @ Q) assert abs_error < 1e-4 abs_error = np.linalg.norm(Q.T @ A @ Q - H) @@ -159,10 +158,10 @@ def test_numpy_arnoldi(backend): assert abs_error < 1e-10 -def run_householder_arnoldi_np(A, rhs, max_iter, dtype, xnp): +def run_householder_arnoldi_np(A, rhs, max_iter, dtype): H, Q, Ps, zj = initialize_householder_arnoldi(rhs, max_iter, dtype) for jdx in range(1, max_iter + 2): - vec, beta = get_householder_vec(zj, jdx - 1, xnp) + vec, beta = get_householder_vec_np(zj, jdx - 1) Ps[jdx].vec, Ps[jdx].beta = vec[:, None], beta H[:, jdx - 1] = np.array(Ps[jdx] @ zj) if jdx <= max_iter: @@ -186,6 +185,26 @@ def initialize_householder_arnoldi(rhs, max_iter, dtype): return H, Q, Ps, zj +def get_householder_vec_np(x, idx): + sigma_2 = np.linalg.norm(x[idx + 1:])**2. + vec = np.zeros_like(x) + vec[idx:] = x[idx:] + if sigma_2 == 0 and x[idx] >= 0: + beta = 0 + elif sigma_2 == 0 and x[idx] < 0: + beta = -2 + else: + x_norm_partial = np.sqrt(x[idx]**2 + sigma_2) + if x[idx] <= 0: + vec[idx] = x[idx] - x_norm_partial + else: + vec[idx] = -sigma_2 / (x[idx] + x_norm_partial) + beta = 2 * vec[idx]**2 / (sigma_2 + vec[idx]**2) + vec = vec / vec[idx] + vec[idx:] = vec[idx:] / vec[idx] + return vec, beta + + def run_arnoldi(A, rhs, max_iter, tol, dtype): Q, H = initialize_arnoldi(rhs, max_iter=max_iter, dtype=dtype) idx, vec = 0, rhs.copy() From 15c64d6983eaaa4c3b0e02cc83499644c22150e1 Mon Sep 17 00:00:00 2001 From: AndPotap Date: Thu, 7 Sep 2023 12:00:23 -0400 Subject: [PATCH 3/3] Deactivated torch householder test --- tests/algorithms/test_arnoldi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/algorithms/test_arnoldi.py b/tests/algorithms/test_arnoldi.py index 7e21d718..6a3047b5 100644 --- a/tests/algorithms/test_arnoldi.py +++ b/tests/algorithms/test_arnoldi.py @@ -81,7 +81,7 @@ def test_arnoldi(backend): assert rel_error < 1e-3 -@parametrize(['jax', 'torch']) +@parametrize(['jax']) def test_householder_arnoldi_decomp(backend): xnp = get_xnp(backend) dtype = xnp.float32