Skip to content

Commit

Permalink
Code maintenance (#104)
Browse files Browse the repository at this point in the history
Code maintenance: sorted imports, opened up CoLA ops for go to, pytest
incorporates defaults marks, and took out unsued files.
  • Loading branch information
AndPotap authored Sep 25, 2024
1 parent c29c268 commit 2ddcb3d
Show file tree
Hide file tree
Showing 32 changed files with 199 additions and 458 deletions.
24 changes: 18 additions & 6 deletions cola/annotations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from collections.abc import Iterable
from functools import reduce
from typing import Set, Union
from collections.abc import Iterable

from plum import dispatch
from cola.ops import LinearOperator, Array
from cola.ops import Kronecker, Product, Sum
from cola.ops import Transpose, Adjoint
from cola.ops import BlockDiag, Identity, ScalarMul
from cola.ops import Hessian, Permutation, Sliced

from cola.ops import (
Adjoint,
Array,
BlockDiag,
Hessian,
Identity,
Kronecker,
LinearOperator,
Permutation,
Product,
ScalarMul,
Sliced,
Sum,
Transpose,
)
from cola.utils import export

Scalar = Array
Expand Down
1 change: 1 addition & 0 deletions cola/linalg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" High level linear algebra functions, """
import pkgutil

from cola.utils import import_from_all

__all__ = []
Expand Down
5 changes: 4 additions & 1 deletion cola/linalg/algorithm_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from types import SimpleNamespace

from plum import parametric

from cola.ops import LinearOperator
from cola.utils import export
from types import SimpleNamespace

# import pytreeclass as tc


Expand Down
1 change: 1 addition & 0 deletions cola/linalg/decompositions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pkgutil

from cola.utils import import_from_all

__all__ = []
Expand Down
8 changes: 3 additions & 5 deletions cola/linalg/decompositions/arnoldi.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from typing import Tuple
from cola import Stiefel
from cola.ops import LinearOperator
from cola.ops import Array, Dense
from cola.ops import Householder, Product

# from cola.utils import export
from cola import lazify
from cola import Stiefel, lazify
from cola.ops import Array, Dense, Householder, LinearOperator, Product

# def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs):
# val_grads, eig_grads, _ = grads
Expand Down
4 changes: 2 additions & 2 deletions cola/linalg/decompositions/lanczos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import cola
from cola import SelfAdjoint, Unitary
from cola.fns import lazify
from cola.ops import Array, LinearOperator, Dense, Tridiagonal
import cola
from cola.ops import Array, Dense, LinearOperator, Tridiagonal


def lanczos_eig_bwd(res, grads, unflatten, *args, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions cola/linalg/eig/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pkgutil

from cola.utils import import_from_all

__all__ = []
Expand Down
5 changes: 2 additions & 3 deletions cola/linalg/eig/iram.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
from scipy.sparse.linalg import LinearOperator as LO
from scipy.sparse.linalg import eigs
from cola.ops import LinearOperator
from cola.ops import Array
from cola.ops import Dense

from cola.ops import Array, Dense, LinearOperator
from cola.utils import export
from cola.utils.utils_linalg import get_numpy_dtype

Expand Down
9 changes: 5 additions & 4 deletions cola/linalg/eig/lobpcg.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
from dataclasses import dataclass
from cola.linalg.algorithm_base import Algorithm

import numpy as np
from scipy.sparse.linalg import LinearOperator as LO
from scipy.sparse.linalg import lobpcg as lobpcg_sp
from cola.ops import LinearOperator
from cola.ops import Dense

from cola.linalg.algorithm_base import Algorithm
from cola.ops import Dense, LinearOperator
from cola.utils import export


Expand Down
9 changes: 5 additions & 4 deletions cola/linalg/eig/power_iteration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from cola.utils import export
from cola.ops import LinearOperator
from cola.linalg.algorithm_base import Algorithm
from dataclasses import dataclass
from typing import Optional, Any
from typing import Any, Optional

from cola.linalg.algorithm_base import Algorithm
from cola.ops import LinearOperator
from cola.utils import export

PRNGKey = Any

Expand Down
1 change: 1 addition & 0 deletions cola/linalg/inverse/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pkgutil

from cola.utils import import_from_all

__all__ = []
Expand Down
9 changes: 4 additions & 5 deletions cola/linalg/inverse/pinv.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np
from plum import dispatch

import cola
from cola.annotations import PSD
from cola.linalg.algorithm_base import Algorithm, Auto, IterativeOperatorWInfo
from cola.linalg.inverse.cg import CG
from cola.ops.operators import Diagonal, Identity, LinearOperator, Permutation, ScalarMul, I_like
from cola.annotations import PSD
from cola.ops.operators import Diagonal, I_like, Identity, LinearOperator, Permutation, ScalarMul
from cola.utils import export
from cola.utils.utils_linalg import get_precision

Expand Down Expand Up @@ -43,7 +42,7 @@ def pinv(A: LinearOperator, alg: Algorithm = Auto()):
Example:
>>> A = MyLinearOperator()
>>> x = cola.pseudo(A) @ b
>>> x = cola.pinv(A) @ b
"""

Expand All @@ -67,7 +66,7 @@ def pinv(A: LinearOperator, alg: Auto):
def pinv(A: LinearOperator, alg: CG):
xnp = A.xnp
M = A.H @ A
cons = get_precision(xnp, A.dtype) * xnp.sqrt(cola.eigmax(M))
cons = get_precision(xnp, A.dtype) * max(A.shape)
Op = IterativeOperatorWInfo(M, alg)
return PSD(Op + cons * I_like(M)) @ A.H

Expand Down
1 change: 1 addition & 0 deletions cola/linalg/logdet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pkgutil

from cola.utils import import_from_all

__all__ = []
Expand Down
22 changes: 16 additions & 6 deletions cola/linalg/logdet/logdet.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
import numpy as np
from functools import reduce

import numpy as np
from plum import dispatch

from cola.annotations import PSD
from cola.ops.operators import LinearOperator, Triangular, Permutation, Identity, ScalarMul
from cola.ops.operators import Diagonal, Kronecker, BlockDiag, Product
from cola.utils import export
from cola.linalg.algorithm_base import Algorithm, Auto
from cola.linalg.decompositions.decompositions import Cholesky, LU, Arnoldi, Lanczos
from cola.linalg.decompositions.decompositions import plu, cholesky
from cola.linalg.decompositions.decompositions import LU, Arnoldi, Cholesky, Lanczos, cholesky, plu
from cola.linalg.trace.diag_trace import trace
from cola.linalg.unary.unary import log
from cola.ops.operators import (
BlockDiag,
Diagonal,
Identity,
Kronecker,
LinearOperator,
Permutation,
Product,
ScalarMul,
Triangular,
)
from cola.utils import export


def product(xs):
Expand Down
6 changes: 4 additions & 2 deletions cola/linalg/preconditioning/preconditioners.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import Union
from cola.ops import LinearOperator
from cola.linalg.eig.power_iteration import power_iteration

from plum import dispatch

from cola.linalg.eig.power_iteration import power_iteration
from cola.ops import LinearOperator
from cola.utils import export


Expand Down
8 changes: 5 additions & 3 deletions cola/linalg/tbd/nullspace.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from cola.ops import LinearOperator, Array
from cola.backends import get_library_fns
from cola.utils import export
import logging

import numpy as np
from plum import dispatch

from cola.backends import get_library_fns
from cola.ops import Array, LinearOperator
from cola.utils import export

eigmax = None # TODO: fix


Expand Down
37 changes: 0 additions & 37 deletions cola/linalg/tbd/pinv.py

This file was deleted.

3 changes: 2 additions & 1 deletion cola/linalg/tbd/slq.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Callable
from cola.ops import LinearOperator

from cola.linalg.decompositions.lanczos import lanczos
from cola.linalg.inverse.cg import cg
from cola.ops import LinearOperator
from cola.utils import export
from cola.utils.custom_autodiff import iterative_autograd

Expand Down
6 changes: 4 additions & 2 deletions cola/linalg/tbd/svrg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np

import cola

# from cola.linalg.eigs import eigmax
from cola.ops import Sum, Product, Dense
from cola.ops import I_like
from cola.ops import Dense, I_like, Product, Sum
from cola.utils import export

# import standard Union type


Expand Down
1 change: 1 addition & 0 deletions cola/linalg/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pkgutil

from cola.utils import import_from_all

__all__ = []
Expand Down
23 changes: 17 additions & 6 deletions cola/linalg/trace/diag_trace.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
from functools import reduce
from cola.utils import export, dispatch
from cola.ops.operators import LinearOperator, I_like, Diagonal, Identity
from cola.ops.operators import BlockDiag, ScalarMul, Sum, Dense
from cola.ops.operators import Kronecker, KronSum
from cola.linalg.algorithm_base import Algorithm, Auto
from cola.linalg.trace.diagonal_estimation import Hutch, HutchPP, Exact

import numpy as np

from cola.linalg.algorithm_base import Algorithm, Auto
from cola.linalg.trace.diagonal_estimation import Exact, Hutch, HutchPP
from cola.ops.operators import (
BlockDiag,
Dense,
Diagonal,
I_like,
Identity,
Kronecker,
KronSum,
LinearOperator,
ScalarMul,
Sum,
)
from cola.utils import dispatch, export


@export
@dispatch.abstract
Expand Down
10 changes: 6 additions & 4 deletions cola/linalg/trace/diagonal_estimation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
from cola.utils import export
from cola.ops import I_like, LinearOperator

from cola.linalg.algorithm_base import Algorithm
from dataclasses import dataclass
from typing import Optional, Any
from cola.ops import I_like, LinearOperator
from cola.utils import export

PRNGKey = Any

Expand Down
1 change: 1 addition & 0 deletions cola/linalg/unary/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pkgutil

from cola.utils import import_from_all

__all__ = []
Expand Down
32 changes: 21 additions & 11 deletions cola/linalg/unary/unary.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
from plum import dispatch
from dataclasses import dataclass
from functools import reduce
from numbers import Number
from typing import Callable
from functools import reduce

import numpy as np
from plum import parametric
from plum import dispatch, parametric

from cola.annotations import PSD, SelfAdjoint
from cola.fns import lazify
from cola.ops import LinearOperator
from cola.ops import Diagonal, Identity, ScalarMul
from cola.ops import BlockDiag, Kronecker, KronSum, I_like, Transpose, Adjoint
from cola.annotations import SelfAdjoint, PSD
from cola.linalg.algorithm_base import Algorithm, Auto
from cola.linalg.inverse.inv import inv
from cola.linalg.decompositions.arnoldi import arnoldi
from cola.linalg.decompositions.decompositions import LU, Arnoldi, Cholesky, Lanczos
from cola.linalg.decompositions.lanczos import lanczos
from cola.linalg.inverse.cg import CG
from cola.linalg.inverse.gmres import GMRES
from cola.linalg.decompositions.lanczos import lanczos
from cola.linalg.decompositions.arnoldi import arnoldi
from cola.linalg.decompositions.decompositions import Arnoldi, Lanczos, LU, Cholesky
from cola.linalg.inverse.inv import inv
from cola.ops import (
Adjoint,
BlockDiag,
Diagonal,
I_like,
Identity,
Kronecker,
KronSum,
LinearOperator,
ScalarMul,
Transpose,
)
from cola.utils import export


Expand Down
Loading

0 comments on commit 2ddcb3d

Please sign in to comment.