diff --git a/cola/algorithms/arnoldi.py b/cola/algorithms/arnoldi.py index 361df2e5..8db81b97 100644 --- a/cola/algorithms/arnoldi.py +++ b/cola/algorithms/arnoldi.py @@ -1,5 +1,6 @@ +from cola import Stiefel from cola.ops import LinearOperator -from cola.ops import Array +from cola.ops import Array, Dense from cola.ops import Householder, Product from cola.utils import export import cola @@ -57,10 +58,8 @@ def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = Q, H, info = arnoldi(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, use_householder=use_householder, pbar=pbar) xnp = A.xnp - eigvals, vs = xnp.eig(H) - eigvectors = Stiefel(lazify(xnp.cast(Q, dtype=vs.dtype))) @ lazify(vs) - # eigvectors = xnp.cast(eigvectors, dtype=A.dtype) - # eigvals = xnp.cast(eigvals, dtype=A.dtype) + eigvals, vs = xnp.eig(H.to_dense()) + eigvectors = Q @ lazify(vs) return eigvals, eigvectors, info @@ -86,7 +85,6 @@ def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e - info (dict): General information about the iterative procedure. """ xnp = A.xnp - xnp = A.xnp if start_vector is None: start_vector = xnp.randn(A.shape[-1], dtype=A.dtype, device=A.device) if len(start_vector.shape) == 1: @@ -98,8 +96,11 @@ def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e else: Q, H, _, infodict = get_arnoldi_matrix(A=A, rhs=rhs, max_iters=max_iters, tol=tol, pbar=pbar) if len(start_vector.shape) == 1: - return Q[0], H[0], infodict - return Q, H, infodict + return Stiefel(Dense(Q[0])), Dense(H[0]), infodict + else: + H = xnp.vmap(Dense)(H) + Q = Stiefel(xnp.vmap(Dense)(Q)) + return Q, H, infodict def ArnoldiDecomposition(A: LinearOperator, start_vector=None, max_iters=100, tol=1e-7, use_householder=False, diff --git a/cola/algorithms/gmres.py b/cola/algorithms/gmres.py index 49a50c75..5f79a521 100644 --- a/cola/algorithms/gmres.py +++ b/cola/algorithms/gmres.py @@ -66,6 +66,7 @@ def gmres_fwd(A, rhs, x0, max_iters, tol, P, use_householder, use_triangular, pb res = rhs - A @ x0 # (m,k) Q, H, infodict = arnoldi(A=A, start_vector=res, max_iters=max_iters, tol=tol, pbar=pbar, use_householder=use_householder) + Q, H = Q.to_dense(), H.to_dense() beta = xnp.norm(res, axis=-2) e1 = xnp.zeros(shape=(H.shape[1], beta.shape[0]), dtype=rhs.dtype, device=A.device)