diff --git a/cola/decompositions.py b/cola/decompositions.py index cc36be61..e490a30c 100644 --- a/cola/decompositions.py +++ b/cola/decompositions.py @@ -3,20 +3,46 @@ from cola import Unitary from cola.fns import lazify from cola.ops.operator_base import LinearOperator -from cola.ops import Triangular, Permutation +from cola.ops import Triangular, Permutation, Diagonal, Identity, ScalarMul, Kronecker, BlockDiag, I_like from cola.utils import export from cola.linalg import inverse, eig, trace, apply_unary +from plum import dispatch - +@dispatch @export def cholesky_decomposed(A: LinearOperator): """ Performs a cholesky decomposition A=LL* of a linear operator A. The returned operator LL* is the same as A, but represented using - the triangular structure """ + the triangular structure. + + (Implicitly assumes A is PSD) + """ L = Triangular(A.xnp.cholesky(A.to_dense()), lower=True) return L @ L.H +@dispatch +def cholesky_decomposed(A: Identity): + return A + +@dispatch +def cholesky_decomposed(A: Diagonal): + return A + +@dispatch +def cholesky_decomposed(A: ScalarMul): + return A + +@dispatch +def cholesky_decomposed(A: Kronecker): + # see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf + return Kronecker(*[cholesky_decomposed(Ai) for Ai in A.Ms]) + +@dispatch +def cholesky_decomposed(A: BlockDiag): + return BlockDiag(*[cholesky_decomposed(Ai) for Ai in A.Ms],multiplicities=A.multiplicities) + +@dispatch @export def lu_decomposed(A: LinearOperator): """ Performs an LU decomposition A=PLU of a linear operator A. 
@@ -27,6 +53,26 @@ def lu_decomposed(A: LinearOperator): P, L, U = P.to(A.device), L.to(A.device), U.to(A.device) return P @ L @ U +@dispatch +def lu_decomposed(A: Identity): + return A + +@dispatch +def lu_decomposed(A: Diagonal): + return A + +@dispatch +def lu_decomposed(A: ScalarMul): + return A + +@dispatch +def lu_decomposed(A: Kronecker): + # see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf + return Kronecker(*[lu_decomposed(Ai) for Ai in A.Ms]) + +@dispatch +def lu_decomposed(A: BlockDiag): + return BlockDiag(*[lu_decomposed(Ai) for Ai in A.Ms], multiplicities=A.multiplicities) @export class UnitaryDecomposition(LinearOperator): diff --git a/cola/jax_fns.py b/cola/jax_fns.py index 27609fb9..7deef11c 100644 --- a/cola/jax_fns.py +++ b/cola/jax_fns.py @@ -82,6 +82,7 @@ slogdet = jnp.linalg.slogdet prod = jnp.prod moveaxis = jnp.moveaxis +promote_types = jnp.promote_types finfo = jnp.finfo def eig(A): diff --git a/cola/linalg/logdet.py b/cola/linalg/logdet.py index be54ccab..196f3183 100644 --- a/cola/linalg/logdet.py +++ b/cola/linalg/logdet.py @@ -1,5 +1,6 @@ from plum import dispatch -from cola.ops import LinearOperator, Triangular, Permutation +from cola.ops import Array +from cola.ops import LinearOperator, Triangular, Permutation, Identity, ScalarMul from cola.ops import Diagonal, Kronecker, BlockDiag, Product from cola.utils import export from cola.annotations import PSD @@ -107,6 +108,19 @@ def slogdet(A: Product, **kwargs): return product(signs), sum(logdets) +@dispatch +def slogdet(A: Identity, **kwargs): + xnp = A.xnp + zero = xnp.array(0., dtype=A.dtype, device=A.device) + return 1. 
+ zero, zero + +@dispatch +def slogdet(A: ScalarMul, **kwargs): + xnp = A.xnp + c = A.c + phase = c/xnp.abs(c) + return phase**A.shape[0], A.shape[0] * xnp.log(xnp.abs(c)) # det(c*I_n) = c**n, so sign = phase**n and log|det| = n*log|c| + @dispatch def slogdet(A: Diagonal, **kwargs): xnp = A.xnp diff --git a/cola/np_fns.py b/cola/np_fns.py index 864b699f..006f9fb9 100644 --- a/cola/np_fns.py +++ b/cola/np_fns.py @@ -65,6 +65,7 @@ def __init__(self): sum = np.sum svd = np.linalg.svd where = np.where +promote_types = np.promote_types finfo = np.finfo def PRNGKey(key): diff --git a/cola/ops/operators.py b/cola/ops/operators.py index f5754ebb..f8c5f5f8 100644 --- a/cola/ops/operators.py +++ b/cola/ops/operators.py @@ -21,10 +21,12 @@ def __init__(self, A: Array): super().__init__(dtype=A.dtype, shape=A.shape) def _matmat(self, X: Array) -> Array: - return self.A @ X + dtype = self.xnp.promote_types(self.dtype, X.dtype) + return self.xnp.cast(self.A,dtype) @ self.xnp.cast(X,dtype) def _rmatmat(self, X: Array) -> Array: - return X @ self.A + dtype = self.xnp.promote_types(self.dtype, X.dtype) + return self.xnp.cast(X,dtype) @ self.xnp.cast(self.A,dtype) def to_dense(self): return self.A @@ -116,7 +118,7 @@ def __init__(self, *Ms): if M1.shape[-1] != M2.shape[-2]: raise ValueError(f"dimension mismatch {M1.shape} vs {M2.shape}") shape = (Ms[0].shape[-2], Ms[-1].shape[-1]) - dtype = Ms[0].dtype + dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms)) super().__init__(dtype, shape) def _matmat(self, v): @@ -176,7 +178,7 @@ class Kronecker(LinearOperator): def __init__(self, *Ms): self.Ms = tuple(cola.fns.lazify(M) for M in Ms) shape = product([Mi.shape[-2] for Mi in Ms]), product([Mi.shape[-1] for Mi in Ms]) - dtype = Ms[0].dtype + dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms)) super().__init__(dtype, shape) def _matmat(self, v): @@ -218,7 +220,7 @@ class KronSum(LinearOperator): def __init__(self, *Ms): self.Ms = tuple(cola.fns.lazify(M) for M in Ms) shape = product([Mi.shape[-2] for Mi in Ms]), product([Mi.shape[-1] for Mi in 
Ms]) - dtype = Ms[0].dtype + dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms)) super().__init__(dtype, shape) def _matmat(self, v): @@ -259,7 +261,8 @@ def __init__(self, *Ms, multiplicities=None): self.multiplicities = [1 for _ in Ms] if multiplicities is None else multiplicities shape = (sum(Mi.shape[-2] * c for Mi, c in zip(Ms, self.multiplicities)), sum(Mi.shape[-1] * c for Mi, c in zip(Ms, self.multiplicities))) - super().__init__(Ms[0].dtype, shape) + dtype = reduce(self.Ms[0].xnp.promote_types, (M.dtype for M in Ms)) + super().__init__(dtype, shape) def _matmat(self, v): # (n,k) # n = v.shape[0] diff --git a/cola/torch_fns.py b/cola/torch_fns.py index 3d0dd68b..a4f73bca 100644 --- a/cola/torch_fns.py +++ b/cola/torch_fns.py @@ -62,6 +62,7 @@ slogdet = torch.linalg.slogdet prod = torch.prod moveaxis = torch.moveaxis +promote_types = torch.promote_types finfo = torch.finfo def max(array, axis, keepdims=False): diff --git a/tests/linalg/operator_market.py b/tests/linalg/operator_market.py index b1a7081f..eaf81570 100644 --- a/tests/linalg/operator_market.py +++ b/tests/linalg/operator_market.py @@ -16,6 +16,7 @@ 'psd_identity', 'psd_prod', 'psd_scalarmul', + 'psd_kron', 'selfadj_hessian', 'selfadj_tridiagonal', # 'square_big', # skipped by default @@ -72,7 +73,11 @@ def get_test_operator(backend: str, precision: str, op_name: str, op = BlockDiag(M1, M2, multiplicities=[2, 3]) case 'prod': op = M1 @ M1.T - + case ('psd', 'kron'): + M1 = Dense(xnp.array([[6., 2], [2, 4]], dtype=dtype, device=device)) + M2 = Dense(xnp.array([[7, 6], [6, 8]], dtype=dtype, device=device)) + op = Kronecker(M1,M2) + case (('selfadj' | 'square') as op_prop, 'tridiagonal'): alpha = xnp.array([1, 2, 3], dtype=dtype, device=device)[:2] beta = xnp.array([4, 5, 6], dtype=dtype, device=device) diff --git a/tests/linalg/test_logdet.py b/tests/linalg/test_logdet.py index 81663c47..4489fd79 100644 --- a/tests/linalg/test_logdet.py +++ b/tests/linalg/test_logdet.py @@ -5,7 +5,7 
@@ from cola.ops import LinearOperator from cola.utils_test import parametrize, relative_error - +# NOTE(review): 'psd_identity' and 'psd_scalarmul' are still excluded from this test; remove the exclusion once they pass @parametrize(['torch', 'jax'], ['float64'], op_names).excluding[:,:,['psd_identity','psd_scalarmul']] def test_logdet(backend, precision, op_name): operator = get_test_operator(backend, precision, op_name) diff --git a/tests/test_decomps.py b/tests/test_decomps.py new file mode 100644 index 00000000..7c829607 --- /dev/null +++ b/tests/test_decomps.py @@ -0,0 +1,27 @@ +import cola +from cola.utils_test import get_xnp, parametrize, relative_error +from linalg.operator_market import op_names, get_test_operator + +@parametrize(['torch', 'jax'],[op for op in op_names if op.startswith('psd')]) +def test_cholesky(backend, opname): + xnp = get_xnp(backend) + A = get_test_operator(backend, 'float32', opname) + A_decomposed = cola.cholesky_decomposed(A) + Ainv1 = xnp.inv(A_decomposed.to_dense()) + Ainv2 = cola.inverse(A_decomposed).to_dense() + assert relative_error(Ainv1, Ainv2) < 1e-5 + logdet1 = xnp.slogdet(A_decomposed.to_dense())[1] + logdet2 = cola.logdet(A_decomposed) + assert relative_error(logdet1, logdet2) < 1e-5 + +@parametrize(['torch', 'jax'],[op for op in op_names if op.startswith('square')]) +def test_lu(backend, opname): + xnp = get_xnp(backend) + A = get_test_operator(backend, 'float32', opname) + A_decomposed = cola.lu_decomposed(A) + Ainv1 = xnp.inv(A_decomposed.to_dense()) + Ainv2 = cola.inverse(A_decomposed).to_dense() + assert relative_error(xnp.cast(Ainv1,Ainv2.dtype), Ainv2) < 1e-5 + logdet1 = xnp.slogdet(A_decomposed.to_dense())[1] + logdet2 = cola.logdet(A_decomposed) + assert relative_error(logdet1, logdet2) < 1e-5 \ No newline at end of file