Skip to content

Commit

Permalink
initial commit for new cholesky feature (#33)
Browse files Browse the repository at this point in the history
New feature: 
Dispatch rules for Cholesky and LU decompositions. (+ some more
consistent type promotion for torch)

Includes cases like

```python
@dispatch
def cholesky_decomposed(A: ScalarMul):
    return A

@dispatch
def cholesky_decomposed(A: Kronecker):
    # see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf
    return Kronecker(*[cholesky_decomposed(Ai) for Ai in A.Ms])
```

Because of the way the function returns right now, I'm assuming that the
output structure being `Product[Triangular, Triangular]` is not
necessarily a desired invariant.

For $A = B\otimes C$, we return $L_AL_A^T \otimes L_BL_B^T$ rather than
$(L_A\otimes L_B)(L_A \otimes L_B)^T$.

Alternatively, the function could just return `L` in `[email protected]`, though the
triangular structure is just a wrapper of dense right now (it could be
converted to an annotation).

One might also consider combining the LU and Cholesky logic, but they
are left separate for now.
  • Loading branch information
mfinzi authored Aug 30, 2023
1 parent 74406c9 commit 45cf124
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 12 deletions.
52 changes: 49 additions & 3 deletions cola/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,46 @@
from cola import Unitary
from cola.fns import lazify
from cola.ops.operator_base import LinearOperator
from cola.ops import Triangular, Permutation
from cola.ops import Triangular, Permutation, Diagonal, Identity, ScalarMul, Kronecker, BlockDiag, I_like
from cola.utils import export
from cola.linalg import inverse, eig, trace, apply_unary
from plum import dispatch


@dispatch
@export
def cholesky_decomposed(A: LinearOperator):
""" Performs a cholesky decomposition A=LL* of a linear operator A.
The returned operator LL* is the same as A, but represented using
the triangular structure """
the triangular structure.
(Implicitly assumes A is PSD)
"""
L = Triangular(A.xnp.cholesky(A.to_dense()), lower=True)
return L @ L.H

@dispatch
def cholesky_decomposed(A: Identity):
return A

@dispatch
def cholesky_decomposed(A: Diagonal):
return A

@dispatch
def cholesky_decomposed(A: ScalarMul):
return A

@dispatch
def cholesky_decomposed(A: Kronecker):
# see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf
return Kronecker(*[cholesky_decomposed(Ai) for Ai in A.Ms])

@dispatch
def cholesky_decomposed(A: BlockDiag):
return BlockDiag(*[cholesky_decomposed(Ai) for Ai in A.Ms],multiplicities=A.multiplicities)


@dispatch
@export
def lu_decomposed(A: LinearOperator):
""" Performs a cholesky decomposition A=PLU of a linear operator A.
Expand All @@ -27,6 +53,26 @@ def lu_decomposed(A: LinearOperator):
P, L, U = P.to(A.device), L.to(A.device), U.to(A.device)
return P @ L @ U

@dispatch
def lu_decomposed(A: Identity):
return A

@dispatch
def lu_decomposed(A: Diagonal):
return A

@dispatch
def lu_decomposed(A: ScalarMul):
return A

@dispatch
def lu_decomposed(A: Kronecker):
# see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf
return Kronecker(*[lu_decomposed(Ai) for Ai in A.Ms])

@dispatch
def lu_decomposed(A: BlockDiag):
return BlockDiag(*[lu_decomposed(Ai) for Ai in A.Ms], multiplicities=A.multiplicities)

@export
class UnitaryDecomposition(LinearOperator):
Expand Down
1 change: 1 addition & 0 deletions cola/jax_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
slogdet = jnp.linalg.slogdet
prod = jnp.prod
moveaxis = jnp.moveaxis
promote_types = jnp.promote_types
finfo = jnp.finfo

def eig(A):
Expand Down
16 changes: 15 additions & 1 deletion cola/linalg/logdet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from plum import dispatch
from cola.ops import LinearOperator, Triangular, Permutation
from cola.ops import Array
from cola.ops import LinearOperator, Triangular, Permutation, Identity, ScalarMul
from cola.ops import Diagonal, Kronecker, BlockDiag, Product
from cola.utils import export
from cola.annotations import PSD
Expand Down Expand Up @@ -107,6 +108,19 @@ def slogdet(A: Product, **kwargs):
return product(signs), sum(logdets)


@dispatch
def slogdet(A: Identity, **kwargs):
xnp = A.xnp
zero = xnp.array(0., dtype=A.dtype, device=A.device)
return 1. + zero, zero

@dispatch
def slogdet(A: ScalarMul, **kwargs):
xnp = A.xnp
c = A.c
phase = c/xnp.abs(c)
return phase, xnp.log(xnp.abs(c))

@dispatch
def slogdet(A: Diagonal, **kwargs):
xnp = A.xnp
Expand Down
1 change: 1 addition & 0 deletions cola/np_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(self):
sum = np.sum
svd = np.linalg.svd
where = np.where
promote_types = np.promote_types
finfo = np.finfo

def PRNGKey(key):
Expand Down
15 changes: 9 additions & 6 deletions cola/ops/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ def __init__(self, A: Array):
super().__init__(dtype=A.dtype, shape=A.shape)

def _matmat(self, X: Array) -> Array:
return self.A @ X
dtype = self.xnp.promote_types(self.dtype, X.dtype)
return self.xnp.cast(self.A,dtype) @ self.xnp.cast(X,dtype)

def _rmatmat(self, X: Array) -> Array:
return X @ self.A
dtype = self.xnp.promote_types(self.dtype, X.dtype)
return self.xnp.cast(X,dtype) @ self.xnp.cast(self.A,dtype)

def to_dense(self):
return self.A
Expand Down Expand Up @@ -116,7 +118,7 @@ def __init__(self, *Ms):
if M1.shape[-1] != M2.shape[-2]:
raise ValueError(f"dimension mismatch {M1.shape} vs {M2.shape}")
shape = (Ms[0].shape[-2], Ms[-1].shape[-1])
dtype = Ms[0].dtype
dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
super().__init__(dtype, shape)

def _matmat(self, v):
Expand Down Expand Up @@ -176,7 +178,7 @@ class Kronecker(LinearOperator):
def __init__(self, *Ms):
self.Ms = tuple(cola.fns.lazify(M) for M in Ms)
shape = product([Mi.shape[-2] for Mi in Ms]), product([Mi.shape[-1] for Mi in Ms])
dtype = Ms[0].dtype
dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
super().__init__(dtype, shape)

def _matmat(self, v):
Expand Down Expand Up @@ -218,7 +220,7 @@ class KronSum(LinearOperator):
def __init__(self, *Ms):
self.Ms = tuple(cola.fns.lazify(M) for M in Ms)
shape = product([Mi.shape[-2] for Mi in Ms]), product([Mi.shape[-1] for Mi in Ms])
dtype = Ms[0].dtype
dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
super().__init__(dtype, shape)

def _matmat(self, v):
Expand Down Expand Up @@ -259,7 +261,8 @@ def __init__(self, *Ms, multiplicities=None):
self.multiplicities = [1 for _ in Ms] if multiplicities is None else multiplicities
shape = (sum(Mi.shape[-2] * c for Mi, c in zip(Ms, self.multiplicities)),
sum(Mi.shape[-1] * c for Mi, c in zip(Ms, self.multiplicities)))
super().__init__(Ms[0].dtype, shape)
dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms))
super().__init__(dtype, shape)

def _matmat(self, v): # (n,k)
# n = v.shape[0]
Expand Down
1 change: 1 addition & 0 deletions cola/torch_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
slogdet = torch.linalg.slogdet
prod = torch.prod
moveaxis = torch.moveaxis
promote_types = torch.promote_types
finfo = torch.finfo

def max(array, axis, keepdims=False):
Expand Down
7 changes: 6 additions & 1 deletion tests/linalg/operator_market.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'psd_identity',
'psd_prod',
'psd_scalarmul',
'psd_kron',
'selfadj_hessian',
'selfadj_tridiagonal',
# 'square_big', # skipped by default
Expand Down Expand Up @@ -72,7 +73,11 @@ def get_test_operator(backend: str, precision: str, op_name: str,
op = BlockDiag(M1, M2, multiplicities=[2, 3])
case 'prod':
op = M1 @ M1.T

case ('psd', 'kron'):
M1 = Dense(xnp.array([[6., 2], [2, 4]], dtype=dtype, device=device))
M2 = Dense(xnp.array([[7, 6], [6, 8]], dtype=dtype, device=device))
op = Kronecker(M1,M2)

case (('selfadj' | 'square') as op_prop, 'tridiagonal'):
alpha = xnp.array([1, 2, 3], dtype=dtype, device=device)[:2]
beta = xnp.array([4, 5, 6], dtype=dtype, device=device)
Expand Down
2 changes: 1 addition & 1 deletion tests/linalg/test_logdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from cola.ops import LinearOperator
from cola.utils_test import parametrize, relative_error


# tests should be passing now
@parametrize(['torch', 'jax'], ['float64'], op_names).excluding[:,:,['psd_identity','psd_scalarmul']]
def test_logdet(backend, precision, op_name):
operator = get_test_operator(backend, precision, op_name)
Expand Down
27 changes: 27 additions & 0 deletions tests/test_decomps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import cola
from cola.utils_test import get_xnp, parametrize, relative_error
from linalg.operator_market import op_names, get_test_operator

@parametrize(['torch', 'jax'],[op for op in op_names if op.startswith('psd')])
def test_cholesky(backend, opname):
xnp = get_xnp(backend)
A = get_test_operator(backend, 'float32', opname)
A_decomposed = cola.cholesky_decomposed(A)
Ainv1 = xnp.inv(A_decomposed.to_dense())
Ainv2 = cola.inverse(A_decomposed).to_dense()
assert relative_error(Ainv1, Ainv2) < 1e-5
logdet1 = xnp.slogdet(A_decomposed.to_dense())[1]
logdet2 = cola.logdet(A_decomposed)
assert relative_error(logdet1, logdet2) < 1e-5

@parametrize(['torch', 'jax'],[op for op in op_names if op.startswith('square')])
def test_lu(backend, opname):
xnp = get_xnp(backend)
A = get_test_operator(backend, 'float32', opname)
A_decomposed = cola.lu_decomposed(A)
Ainv1 = xnp.inv(A_decomposed.to_dense())
Ainv2 = cola.inverse(A_decomposed).to_dense()
assert relative_error(xnp.cast(Ainv1,Ainv2.dtype), Ainv2) < 1e-5
logdet1 = xnp.slogdet(A_decomposed.to_dense())[1]
logdet2 = cola.logdet(A_decomposed)
assert relative_error(logdet1, logdet2) < 1e-5

0 comments on commit 45cf124

Please sign in to comment.