Skip to content

Commit

Permalink
Synced arnoldi and lanczos output (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndPotap authored Oct 4, 2023
2 parents 401e791 + f6c59ff commit b04d87e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
17 changes: 9 additions & 8 deletions cola/algorithms/arnoldi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cola import Stiefel
from cola.ops import LinearOperator
from cola.ops import Array
from cola.ops import Array, Dense
from cola.ops import Householder, Product
from cola.utils import export
import cola
Expand Down Expand Up @@ -57,10 +58,8 @@ def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int =
Q, H, info = arnoldi(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, use_householder=use_householder,
pbar=pbar)
xnp = A.xnp
eigvals, vs = xnp.eig(H)
eigvectors = Stiefel(lazify(xnp.cast(Q, dtype=vs.dtype))) @ lazify(vs)
# eigvectors = xnp.cast(eigvectors, dtype=A.dtype)
# eigvals = xnp.cast(eigvals, dtype=A.dtype)
eigvals, vs = xnp.eig(H.to_dense())
eigvectors = Q @ lazify(vs)
return eigvals, eigvectors, info


Expand All @@ -86,7 +85,6 @@ def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e
- info (dict): General information about the iterative procedure.
"""
xnp = A.xnp
xnp = A.xnp
if start_vector is None:
start_vector = xnp.randn(A.shape[-1], dtype=A.dtype, device=A.device)
if len(start_vector.shape) == 1:
Expand All @@ -98,8 +96,11 @@ def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e
else:
Q, H, _, infodict = get_arnoldi_matrix(A=A, rhs=rhs, max_iters=max_iters, tol=tol, pbar=pbar)
if len(start_vector.shape) == 1:
return Q[0], H[0], infodict
return Q, H, infodict
return Stiefel(Dense(Q[0])), Dense(H[0]), infodict
else:
H = xnp.vmap(Dense)(H)
Q = Stiefel(xnp.vmap(Dense)(Q))
return Q, H, infodict


def ArnoldiDecomposition(A: LinearOperator, start_vector=None, max_iters=100, tol=1e-7, use_householder=False,
Expand Down
1 change: 1 addition & 0 deletions cola/algorithms/gmres.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def gmres_fwd(A, rhs, x0, max_iters, tol, P, use_householder, use_triangular, pb
res = rhs - A @ x0 # (m,k)
Q, H, infodict = arnoldi(A=A, start_vector=res, max_iters=max_iters, tol=tol, pbar=pbar,
use_householder=use_householder)
Q, H = Q.to_dense(), H.to_dense()

beta = xnp.norm(res, axis=-2)
e1 = xnp.zeros(shape=(H.shape[1], beta.shape[0]), dtype=rhs.dtype, device=A.device)
Expand Down

0 comments on commit b04d87e

Please sign in to comment.