Skip to content

Commit

Permalink
Algorithm implementation and refactor (#73)
Browse files Browse the repository at this point in the history
Co-authored-by: AndPotap <[email protected]>
  • Loading branch information
mfinzi and AndPotap authored Oct 13, 2023
1 parent b04d87e commit 4f3aa16
Show file tree
Hide file tree
Showing 88 changed files with 3,773 additions and 4,481 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ print(F @ v)
```
```
[0.2 0.2 2.2 2.2 4.2 4.2 6.2
6.2 8.2 8.2 7.8121004 2.062 ]
6.2 8.2 8.2 7.8 2.1 ]
```

2. **Performing Linear Algebra**. With these objects we can perform linear algebra operations even when they are very big.
Expand All @@ -70,8 +70,8 @@ print(cola.linalg.trace(F))
Q = F.T @ F + 1e-3 * cola.ops.I_like(F)
b = cola.linalg.inv(Q) @ v
print(jnp.linalg.norm(Q @ b - v))
print(cola.linalg.eig(F)[0][:5])
print(cola.sqrt(A))
print(cola.linalg.eig(F, k=F.shape[0])[0][:5])
print(cola.linalg.sqrt(A))
```

```
Expand Down
1 change: 0 additions & 1 deletion cola/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import_from_all("annotations", globals(), __all__, __name__)
import_from_all("linalg", globals(), __all__, __name__)
import_from_all("utils", globals(), __all__, __name__)
import_from_all("decompositions", globals(), __all__, __name__)

__all__.append("LinearOperator")
# import_from_all("ops", globals(), __all__,__name__)
9 changes: 5 additions & 4 deletions cola/backends/jax_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,12 @@ def is_cuda_available():

def eig(A):
# if GPU, convert to CPU first since jax doesn't support it
device = A.device_buffer.device()
if str(device)[:3] != 'cpu':
A = jax.device_put(A, jax.devices("cpu")[0])
# device = A.device_buffer.device()
# if str(device)[:3] != 'cpu':
# A = jax.device_put(A, jax.devices("cpu")[0])
w, v = jnp.linalg.eig(A)
return jax.device_put(w, device), jax.device_put(v, device)
return w, v
# return jax.device_put(w, device), jax.device_put(v, device)


def eye(n, m, dtype, device):
Expand Down
131 changes: 0 additions & 131 deletions cola/decompositions.py

This file was deleted.

31 changes: 31 additions & 0 deletions cola/linalg/algorithm_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from plum import parametric
from cola.ops import LinearOperator
from cola.utils import export
from types import SimpleNamespace
# import pytreeclass as tc


@export
class Algorithm:
pass


@parametric
class IterativeOperatorWInfo(LinearOperator):
def __init__(self, A: LinearOperator, alg: Algorithm):
super().__init__(A.dtype, A.shape)
self.A = A
self.alg = alg
self.info = {}

def _matmat(self, X):
Y, self.info = self.alg(self.A, X)
return Y

def __str__(self):
return f"{self.alg}({str(self.A)})"


@export
class Auto(SimpleNamespace, Algorithm):
pass
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
""" Low-level algorithms used in CoLA (no dispatch rules). """
import pkgutil
from cola.utils import import_from_all

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from cola.ops import LinearOperator
from cola.ops import Array, Dense
from cola.ops import Householder, Product
from cola.utils import export
import cola
from cola import Stiefel, lazify
# from cola.utils import export
from cola import lazify

# def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs):
# val_grads, eig_grads, _ = grads
Expand Down Expand Up @@ -34,9 +33,8 @@

# @export
# @iterative_autograd(arnoldi_eigs_bwd)
@export
def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int = 100, tol: float = 1e-7,
use_householder: bool = False, pbar: bool = False):
use_householder: bool = False, pbar: bool = False, key=None):
"""
Computes eigenvalues and eigenvectors using Arnoldi.
Expand All @@ -48,6 +46,7 @@ def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int =
tol (float, optional): Stopping criteria.
use_householder (bool, optional): Use Householder Arnoldi variant.
pbar (bool, optional): Show a progress bar.
key (PNRGKey, optional): PRNGKey for random number generation.
Returns:
tuple:
Expand All @@ -56,16 +55,15 @@ def arnoldi_eigs(A: LinearOperator, start_vector: Array = None, max_iters: int =
- info (dict): General information about the iterative procedure.
"""
Q, H, info = arnoldi(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, use_householder=use_householder,
pbar=pbar)
pbar=pbar, key=key)
xnp = A.xnp
eigvals, vs = xnp.eig(H.to_dense())
eigvectors = Q @ lazify(vs)
return eigvals, eigvectors, info


@export
def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e-7, use_householder: bool = False,
pbar: bool = False):
pbar: bool = False, key=None):
"""
Computes the Arnoldi decomposition of the linear operator A, A = QHQ^*.
Expand All @@ -77,6 +75,7 @@ def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e
tol (float, optional): Stopping criteria.
use_householder (bool, optional): Use Householder Arnoldi iteration.
pbar (bool, optional): Show a progress bar.
key (PNRGKey, optional): PRNGKey for random number generation.
Returns:
tuple:
Expand All @@ -86,7 +85,8 @@ def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e
"""
xnp = A.xnp
if start_vector is None:
start_vector = xnp.randn(A.shape[-1], dtype=A.dtype, device=A.device)
key = xnp.PRNGKey(42) if key is None else key
start_vector = xnp.randn(A.shape[-1], dtype=A.dtype, device=A.device, key=key)
if len(start_vector.shape) == 1:
rhs = start_vector[:, None]
else:
Expand All @@ -103,17 +103,6 @@ def arnoldi(A: LinearOperator, start_vector=None, max_iters=100, tol: float = 1e
return Q, H, infodict


def ArnoldiDecomposition(A: LinearOperator, start_vector=None, max_iters=100, tol=1e-7, use_householder=False,
pbar=False):
""" Provides the Arnoldi decomposition of a matrix A = Q H Q^H. LinearOperator form of arnoldi,
see arnoldi for arguments."""
Q, H, info = arnoldi(A=A, start_vector=start_vector, max_iters=max_iters, tol=tol, use_householder=use_householder,
pbar=pbar)
A_approx = cola.UnitaryDecomposition(Q, H)
A_approx.info = info
return A_approx


def get_householder_vec_simple(x, idx, xnp):
indices = xnp.arange(x.shape[0])
vec = xnp.where(indices >= idx, x, 0.)
Expand Down
Loading

0 comments on commit 4f3aa16

Please sign in to comment.