-
Notifications
You must be signed in to change notification settings - Fork 32
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
initial commit for new cholesky feature (#33)
New feature: Dispatch rules for Cholesky and LU decompositions. (+ some more consistent type promotion for torch) Includes cases like ```python @dispatch def cholesky_decomposed(A: ScalarMul): return A @dispatch def cholesky_decomposed(A: Kronecker): # see https://www.math.uwaterloo.ca/~hwolkowi/henry/reports/kronthesisschaecke04.pdf return Kronecker(*[cholesky_decomposed(Ai) for Ai in A.Ms]) ``` Because of the way the function returns right now, I'm assuming that the output structure being `Product[Triangular, Triangular]` is not necessarily a desired invariant. For $A = B\otimes C$, we return $L_AL_A^T \otimes L_BL_B^T$ rather than $(L_A\otimes L_B)(L_A \otimes L_B)^T$. Alternatively, the function could just return `L` in `[email protected]`, though the triangular structure is just a wrapper of dense right now (it could be converted to an annotation). One might also consider combining the LU and Cholesky logic, but they are left separate for now.
- Loading branch information
Showing
9 changed files
with
110 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import cola | ||
from cola.utils_test import get_xnp, parametrize, relative_error | ||
from linalg.operator_market import op_names, get_test_operator | ||
|
||
@parametrize(['torch', 'jax'],[op for op in op_names if op.startswith('psd')]) | ||
def test_cholesky(backend, opname): | ||
xnp = get_xnp(backend) | ||
A = get_test_operator(backend, 'float32', opname) | ||
A_decomposed = cola.cholesky_decomposed(A) | ||
Ainv1 = xnp.inv(A_decomposed.to_dense()) | ||
Ainv2 = cola.inverse(A_decomposed).to_dense() | ||
assert relative_error(Ainv1, Ainv2) < 1e-5 | ||
logdet1 = xnp.slogdet(A_decomposed.to_dense())[1] | ||
logdet2 = cola.logdet(A_decomposed) | ||
assert relative_error(logdet1, logdet2) < 1e-5 | ||
|
||
@parametrize(['torch', 'jax'],[op for op in op_names if op.startswith('square')]) | ||
def test_lu(backend, opname): | ||
xnp = get_xnp(backend) | ||
A = get_test_operator(backend, 'float32', opname) | ||
A_decomposed = cola.lu_decomposed(A) | ||
Ainv1 = xnp.inv(A_decomposed.to_dense()) | ||
Ainv2 = cola.inverse(A_decomposed).to_dense() | ||
assert relative_error(xnp.cast(Ainv1,Ainv2.dtype), Ainv2) < 1e-5 | ||
logdet1 = xnp.slogdet(A_decomposed.to_dense())[1] | ||
logdet2 = cola.logdet(A_decomposed) | ||
assert relative_error(logdet1, logdet2) < 1e-5 |