From 2ddcb3d005ec832b359dde357cc0ce799ef10764 Mon Sep 17 00:00:00 2001 From: Andres Potapczynski Date: Tue, 24 Sep 2024 21:25:19 -0400 Subject: [PATCH] Code maintenance (#104) Code maintenance: sorted imports, opened up CoLA ops for go to, pytest incorporates defaults marks, and took out unsued files. --- cola/annotations.py | 24 +- cola/linalg/__init__.py | 1 + cola/linalg/algorithm_base.py | 5 +- cola/linalg/decompositions/__init__.py | 1 + cola/linalg/decompositions/arnoldi.py | 8 +- cola/linalg/decompositions/lanczos.py | 4 +- cola/linalg/eig/__init__.py | 1 + cola/linalg/eig/iram.py | 5 +- cola/linalg/eig/lobpcg.py | 9 +- cola/linalg/eig/power_iteration.py | 9 +- cola/linalg/inverse/__init__.py | 1 + cola/linalg/inverse/pinv.py | 9 +- cola/linalg/logdet/__init__.py | 1 + cola/linalg/logdet/logdet.py | 22 +- .../linalg/preconditioning/preconditioners.py | 6 +- cola/linalg/tbd/nullspace.py | 8 +- cola/linalg/tbd/pinv.py | 37 -- cola/linalg/tbd/slq.py | 3 +- cola/linalg/tbd/svrg.py | 6 +- cola/linalg/trace/__init__.py | 1 + cola/linalg/trace/diag_trace.py | 23 +- cola/linalg/trace/diagonal_estimation.py | 10 +- cola/linalg/unary/__init__.py | 1 + cola/linalg/unary/unary.py | 32 +- cola/ops/__init__.py | 68 +++- cola/ops/operator_base.py | 6 +- cola/utils/__init__.py | 7 +- cola/utils/dispatch.py | 331 ------------------ cola/utils/jax_tqdm.py | 7 +- cola/utils/torch_tqdm.py | 2 +- cola/utils/utils_for_tests.py | 7 +- pytest.ini | 2 + 32 files changed, 199 insertions(+), 458 deletions(-) delete mode 100644 cola/linalg/tbd/pinv.py delete mode 100644 cola/utils/dispatch.py create mode 100644 pytest.ini diff --git a/cola/annotations.py b/cola/annotations.py index d12985c3..ad857591 100644 --- a/cola/annotations.py +++ b/cola/annotations.py @@ -1,12 +1,24 @@ +from collections.abc import Iterable from functools import reduce from typing import Set, Union -from collections.abc import Iterable + from plum import dispatch -from cola.ops import LinearOperator, Array -from cola.ops import Kronecker, Product, Sum -from cola.ops import Transpose, Adjoint -from cola.ops import BlockDiag, Identity, ScalarMul -from cola.ops import Hessian, Permutation, Sliced + +from cola.ops import ( + Adjoint, + Array, + BlockDiag, + Hessian, + Identity, + Kronecker, + LinearOperator, + Permutation, + Product, + ScalarMul, + Sliced, + Sum, + Transpose, +) from cola.utils import export Scalar = Array diff --git a/cola/linalg/__init__.py b/cola/linalg/__init__.py index d960900e..a41d5be2 100644 --- a/cola/linalg/__init__.py +++ b/cola/linalg/__init__.py @@ -1,5 +1,6 @@ """ High level linear algebra functions, """ import pkgutil + from cola.utils import import_from_all __all__ = [] diff --git a/cola/linalg/algorithm_base.py b/cola/linalg/algorithm_base.py index ee5bb045..1c577cd8 100644 --- a/cola/linalg/algorithm_base.py +++ b/cola/linalg/algorithm_base.py @@ -1,7 +1,10 @@ +from types import SimpleNamespace + from plum import parametric + from cola.ops import LinearOperator from cola.utils import export -from types import SimpleNamespace + # import pytreeclass as tc diff --git a/cola/linalg/decompositions/__init__.py b/cola/linalg/decompositions/__init__.py index ac4dec98..29a4036c 100644 --- a/cola/linalg/decompositions/__init__.py +++ b/cola/linalg/decompositions/__init__.py @@ -1,4 +1,5 @@ import pkgutil + from cola.utils import import_from_all __all__ = [] diff --git a/cola/linalg/decompositions/arnoldi.py b/cola/linalg/decompositions/arnoldi.py index 952d5234..f04ac7d9 100644 --- a/cola/linalg/decompositions/arnoldi.py +++ b/cola/linalg/decompositions/arnoldi.py @@ -1,10 +1,8 @@ from typing import Tuple -from cola import Stiefel -from cola.ops import LinearOperator -from cola.ops import Array, Dense -from cola.ops import Householder, Product + # from cola.utils import export -from cola import lazify +from cola import Stiefel, lazify +from cola.ops import Array, Dense, Householder, LinearOperator, Product # def arnoldi_eigs_bwd(res, grads, unflatten, *args, **kwargs): # val_grads, eig_grads, _ = grads diff --git a/cola/linalg/decompositions/lanczos.py b/cola/linalg/decompositions/lanczos.py index f74d8812..ee5c3909 100644 --- a/cola/linalg/decompositions/lanczos.py +++ b/cola/linalg/decompositions/lanczos.py @@ -1,7 +1,7 @@ +import cola from cola import SelfAdjoint, Unitary from cola.fns import lazify -from cola.ops import Array, LinearOperator, Dense, Tridiagonal -import cola +from cola.ops import Array, Dense, LinearOperator, Tridiagonal def lanczos_eig_bwd(res, grads, unflatten, *args, **kwargs): diff --git a/cola/linalg/eig/__init__.py b/cola/linalg/eig/__init__.py index ac4dec98..29a4036c 100644 --- a/cola/linalg/eig/__init__.py +++ b/cola/linalg/eig/__init__.py @@ -1,4 +1,5 @@ import pkgutil + from cola.utils import import_from_all __all__ = [] diff --git a/cola/linalg/eig/iram.py b/cola/linalg/eig/iram.py index 4b5a346b..f96a5242 100644 --- a/cola/linalg/eig/iram.py +++ b/cola/linalg/eig/iram.py @@ -1,9 +1,8 @@ import numpy as np from scipy.sparse.linalg import LinearOperator as LO from scipy.sparse.linalg import eigs -from cola.ops import LinearOperator -from cola.ops import Array -from cola.ops import Dense + +from cola.ops import Array, Dense, LinearOperator from cola.utils import export from cola.utils.utils_linalg import get_numpy_dtype diff --git a/cola/linalg/eig/lobpcg.py b/cola/linalg/eig/lobpcg.py index de41e58e..6193eaac 100644 --- a/cola/linalg/eig/lobpcg.py +++ b/cola/linalg/eig/lobpcg.py @@ -1,10 +1,11 @@ -import numpy as np from dataclasses import dataclass -from cola.linalg.algorithm_base import Algorithm + +import numpy as np from scipy.sparse.linalg import LinearOperator as LO from scipy.sparse.linalg import lobpcg as lobpcg_sp -from cola.ops import LinearOperator -from cola.ops import Dense + +from cola.linalg.algorithm_base import Algorithm +from cola.ops import Dense, LinearOperator from cola.utils import export diff --git a/cola/linalg/eig/power_iteration.py b/cola/linalg/eig/power_iteration.py index ed1388ab..df474da9 100644 --- a/cola/linalg/eig/power_iteration.py +++ b/cola/linalg/eig/power_iteration.py @@ -1,8 +1,9 @@ -from cola.utils import export -from cola.ops import LinearOperator -from cola.linalg.algorithm_base import Algorithm from dataclasses import dataclass -from typing import Optional, Any +from typing import Any, Optional + +from cola.linalg.algorithm_base import Algorithm +from cola.ops import LinearOperator +from cola.utils import export PRNGKey = Any diff --git a/cola/linalg/inverse/__init__.py b/cola/linalg/inverse/__init__.py index ac4dec98..29a4036c 100644 --- a/cola/linalg/inverse/__init__.py +++ b/cola/linalg/inverse/__init__.py @@ -1,4 +1,5 @@ import pkgutil + from cola.utils import import_from_all __all__ = [] diff --git a/cola/linalg/inverse/pinv.py b/cola/linalg/inverse/pinv.py index c47c3476..3c8c725a 100644 --- a/cola/linalg/inverse/pinv.py +++ b/cola/linalg/inverse/pinv.py @@ -1,11 +1,10 @@ import numpy as np from plum import dispatch -import cola +from cola.annotations import PSD from cola.linalg.algorithm_base import Algorithm, Auto, IterativeOperatorWInfo from cola.linalg.inverse.cg import CG -from cola.ops.operators import Diagonal, Identity, LinearOperator, Permutation, ScalarMul, I_like -from cola.annotations import PSD +from cola.ops.operators import Diagonal, I_like, Identity, LinearOperator, Permutation, ScalarMul from cola.utils import export from cola.utils.utils_linalg import get_precision @@ -43,7 +42,7 @@ def pinv(A: LinearOperator, alg: Algorithm = Auto()): Example: >>> A = MyLinearOperator() - >>> x = cola.pseudo(A) @ b + >>> x = cola.pinv(A) @ b """ @@ -67,7 +66,7 @@ def pinv(A: LinearOperator, alg: Auto): def pinv(A: LinearOperator, alg: CG): xnp = A.xnp M = A.H @ A - cons = get_precision(xnp, A.dtype) * xnp.sqrt(cola.eigmax(M)) + cons = get_precision(xnp, A.dtype) * max(A.shape) Op = IterativeOperatorWInfo(M, alg) return PSD(Op + cons * I_like(M)) @ A.H diff --git a/cola/linalg/logdet/__init__.py b/cola/linalg/logdet/__init__.py index ac4dec98..29a4036c 100644 --- a/cola/linalg/logdet/__init__.py +++ b/cola/linalg/logdet/__init__.py @@ -1,4 +1,5 @@ import pkgutil + from cola.utils import import_from_all __all__ = [] diff --git a/cola/linalg/logdet/logdet.py b/cola/linalg/logdet/logdet.py index 5ae8eec2..fd94a2e5 100644 --- a/cola/linalg/logdet/logdet.py +++ b/cola/linalg/logdet/logdet.py @@ -1,15 +1,25 @@ -import numpy as np from functools import reduce + +import numpy as np from plum import dispatch + from cola.annotations import PSD -from cola.ops.operators import LinearOperator, Triangular, Permutation, Identity, ScalarMul -from cola.ops.operators import Diagonal, Kronecker, BlockDiag, Product -from cola.utils import export from cola.linalg.algorithm_base import Algorithm, Auto -from cola.linalg.decompositions.decompositions import Cholesky, LU, Arnoldi, Lanczos -from cola.linalg.decompositions.decompositions import plu, cholesky +from cola.linalg.decompositions.decompositions import LU, Arnoldi, Cholesky, Lanczos, cholesky, plu from cola.linalg.trace.diag_trace import trace from cola.linalg.unary.unary import log +from cola.ops.operators import ( + BlockDiag, + Diagonal, + Identity, + Kronecker, + LinearOperator, + Permutation, + Product, + ScalarMul, + Triangular, +) +from cola.utils import export def product(xs): diff --git a/cola/linalg/preconditioning/preconditioners.py b/cola/linalg/preconditioning/preconditioners.py index 8c4d3eaf..dd7c8e31 100644 --- a/cola/linalg/preconditioning/preconditioners.py +++ b/cola/linalg/preconditioning/preconditioners.py @@ -1,7 +1,9 @@ from typing import Union -from cola.ops import LinearOperator -from cola.linalg.eig.power_iteration import power_iteration + from plum import dispatch + +from cola.linalg.eig.power_iteration import power_iteration +from cola.ops import LinearOperator from cola.utils import export diff --git a/cola/linalg/tbd/nullspace.py b/cola/linalg/tbd/nullspace.py index 3cf95a53..11f614bf 100644 --- a/cola/linalg/tbd/nullspace.py +++ b/cola/linalg/tbd/nullspace.py @@ -1,10 +1,12 @@ -from cola.ops import LinearOperator, Array -from cola.backends import get_library_fns -from cola.utils import export import logging + import numpy as np from plum import dispatch +from cola.backends import get_library_fns +from cola.ops import Array, LinearOperator +from cola.utils import export + eigmax = None # TODO: fix diff --git a/cola/linalg/tbd/pinv.py b/cola/linalg/tbd/pinv.py deleted file mode 100644 index 01ae1b01..00000000 --- a/cola/linalg/tbd/pinv.py +++ /dev/null @@ -1,37 +0,0 @@ -from cola.ops import LinearOperator, I_like -from plum import dispatch -from cola.utils import export -import cola -from cola.linalg.inv import inv - - -@dispatch -@export -def pinv(A: LinearOperator, **kwargs): - """Computes the Moore-Penrose pseudoinverse of a linear operator A. - - Args: - A (LinearOperator): The linear operator to compute the pseudoinverse for. - **kwargs: Additional keyword arguments, including 'tol', 'P', 'x0', 'pbar', 'info', - and 'max_iters'. These are used to customize the inverse computation. - If not provided, they take default values. - - Returns: - LinearOperator: The pseudoinverse of A. - - Example: - A = LinearOperator((3, 5), jnp.float32, lambda x: x[:3]) - A_pinv = pinv(A, tol=1e-4, max_iters=1000) - - .. warning:: - This function is not yet well tested and does not yet include composition rules. - """ - kws = dict(tol=1e-6, P=None, x0=None, pbar=False, max_iters=5000) - kws.update(kwargs) - n, m = A.shape - if n > m: - M = A.H @ A - return inv(cola.PSD(M + kws['tol'] * I_like(M) / 10), **kws) @ A.H - else: - M = A @ A.H - return A.H @ inv(cola.PSD(M + kws['tol'] * I_like(M) / 10), **kws) diff --git a/cola/linalg/tbd/slq.py b/cola/linalg/tbd/slq.py index 83ca906b..764ba76c 100644 --- a/cola/linalg/tbd/slq.py +++ b/cola/linalg/tbd/slq.py @@ -1,7 +1,8 @@ from typing import Callable -from cola.ops import LinearOperator + from cola.linalg.decompositions.lanczos import lanczos from cola.linalg.inverse.cg import cg +from cola.ops import LinearOperator from cola.utils import export from cola.utils.custom_autodiff import iterative_autograd diff --git a/cola/linalg/tbd/svrg.py b/cola/linalg/tbd/svrg.py index 11673bbf..e14e7725 100644 --- a/cola/linalg/tbd/svrg.py +++ b/cola/linalg/tbd/svrg.py @@ -1,9 +1,11 @@ import numpy as np + import cola + # from cola.linalg.eigs import eigmax -from cola.ops import Sum, Product, Dense -from cola.ops import I_like +from cola.ops import Dense, I_like, Product, Sum from cola.utils import export + # import standard Union type diff --git a/cola/linalg/trace/__init__.py b/cola/linalg/trace/__init__.py index ac4dec98..29a4036c 100644 --- a/cola/linalg/trace/__init__.py +++ b/cola/linalg/trace/__init__.py @@ -1,4 +1,5 @@ import pkgutil + from cola.utils import import_from_all __all__ = [] diff --git a/cola/linalg/trace/diag_trace.py b/cola/linalg/trace/diag_trace.py index 1c15d3e5..8094c5fc 100644 --- a/cola/linalg/trace/diag_trace.py +++ b/cola/linalg/trace/diag_trace.py @@ -1,12 +1,23 @@ from functools import reduce -from cola.utils import export, dispatch -from cola.ops.operators import LinearOperator, I_like, Diagonal, Identity -from cola.ops.operators import BlockDiag, ScalarMul, Sum, Dense -from cola.ops.operators import Kronecker, KronSum -from cola.linalg.algorithm_base import Algorithm, Auto -from cola.linalg.trace.diagonal_estimation import Hutch, HutchPP, Exact + import numpy as np +from cola.linalg.algorithm_base import Algorithm, Auto +from cola.linalg.trace.diagonal_estimation import Exact, Hutch, HutchPP +from cola.ops.operators import ( + BlockDiag, + Dense, + Diagonal, + I_like, + Identity, + Kronecker, + KronSum, + LinearOperator, + ScalarMul, + Sum, +) +from cola.utils import dispatch, export + @export @dispatch.abstract diff --git a/cola/linalg/trace/diagonal_estimation.py b/cola/linalg/trace/diagonal_estimation.py index 318a549a..ade6f03e 100644 --- a/cola/linalg/trace/diagonal_estimation.py +++ b/cola/linalg/trace/diagonal_estimation.py @@ -1,9 +1,11 @@ +from dataclasses import dataclass +from typing import Any, Optional + import numpy as np -from cola.utils import export -from cola.ops import I_like, LinearOperator + from cola.linalg.algorithm_base import Algorithm -from dataclasses import dataclass -from typing import Optional, Any +from cola.ops import I_like, LinearOperator +from cola.utils import export PRNGKey = Any diff --git a/cola/linalg/unary/__init__.py b/cola/linalg/unary/__init__.py index ac4dec98..29a4036c 100644 --- a/cola/linalg/unary/__init__.py +++ b/cola/linalg/unary/__init__.py @@ -1,4 +1,5 @@ import pkgutil + from cola.utils import import_from_all __all__ = [] diff --git a/cola/linalg/unary/unary.py b/cola/linalg/unary/unary.py index 8ba07076..9b83d908 100644 --- a/cola/linalg/unary/unary.py +++ b/cola/linalg/unary/unary.py @@ -1,22 +1,32 @@ -from plum import dispatch from dataclasses import dataclass +from functools import reduce from numbers import Number from typing import Callable -from functools import reduce + import numpy as np -from plum import parametric +from plum import dispatch, parametric + +from cola.annotations import PSD, SelfAdjoint from cola.fns import lazify -from cola.ops import LinearOperator -from cola.ops import Diagonal, Identity, ScalarMul -from cola.ops import BlockDiag, Kronecker, KronSum, I_like, Transpose, Adjoint -from cola.annotations import SelfAdjoint, PSD from cola.linalg.algorithm_base import Algorithm, Auto -from cola.linalg.inverse.inv import inv +from cola.linalg.decompositions.arnoldi import arnoldi +from cola.linalg.decompositions.decompositions import LU, Arnoldi, Cholesky, Lanczos +from cola.linalg.decompositions.lanczos import lanczos from cola.linalg.inverse.cg import CG from cola.linalg.inverse.gmres import GMRES -from cola.linalg.decompositions.lanczos import lanczos -from cola.linalg.decompositions.arnoldi import arnoldi -from cola.linalg.decompositions.decompositions import Arnoldi, Lanczos, LU, Cholesky +from cola.linalg.inverse.inv import inv +from cola.ops import ( + Adjoint, + BlockDiag, + Diagonal, + I_like, + Identity, + Kronecker, + KronSum, + LinearOperator, + ScalarMul, + Transpose, +) from cola.utils import export diff --git a/cola/ops/__init__.py b/cola/ops/__init__.py index c61bd9fe..eb64f86a 100644 --- a/cola/ops/__init__.py +++ b/cola/ops/__init__.py @@ -1,14 +1,56 @@ -""" Linear Operators in CoLA""" -from cola.utils import import_from_all, import_every +from cola.ops.operator_base import Array, LinearOperator +from cola.ops.operators import ( + FFT, + Adjoint, + BlockDiag, + Concatenated, + ConvolveND, + Dense, + Diagonal, + Hessian, + Householder, + I_like, + Identity, + Jacobian, + Kernel, + Kronecker, + KronSum, + Permutation, + Product, + ScalarMul, + Sliced, + Sparse, + Sum, + Transpose, + Triangular, + Tridiagonal, +) -__all__ = [] -import_from_all("operator_base", globals(), __all__, __name__) - - -# is_operator = lambda name,value: isinstance(value,type) and issubclass(value,LinearOperator) -def has_docstring(name, value): - return hasattr(value, "__doc__") and value.__doc__ is not None - - -import_every("operators", globals(), __all__, __name__) # ,has_docstring) -# import_from_all("decompositions", globals(), __all__, __name__) +__all__ = [ + "LinearOperator", + "Array", + "Dense", + "Triangular", + "Sparse", + "ScalarMul", + "Identity", + "Product", + "Sum", + "Kronecker", + "KronSum", + "BlockDiag", + "Diagonal", + "Tridiagonal", + "Transpose", + "Adjoint", + "Sliced", + "Jacobian", + "Hessian", + "Permutation", + "Concatenated", + "ConvolveND", + "Householder", + "Kernel", + "I_like", + "FFT", +] diff --git a/cola/ops/operator_base.py b/cola/ops/operator_base.py index 7d4f2cf4..44f43d8e 100644 --- a/cola/ops/operator_base.py +++ b/cola/ops/operator_base.py @@ -1,10 +1,12 @@ from abc import abstractmethod -from typing import Union, Tuple, Any from numbers import Number +from typing import Any, Tuple, Union + import numpy as np + import cola +from cola.backends import AutoRegisteringPyTree, get_library_fns, np_fns from cola.utils import export -from cola.backends import np_fns, get_library_fns, AutoRegisteringPyTree Array = Dtype = Any export(Array) diff --git a/cola/utils/__init__.py b/cola/utils/__init__.py index 23dfb1dd..9f5fa4a5 100644 --- a/cola/utils/__init__.py +++ b/cola/utils/__init__.py @@ -1,9 +1,10 @@ # from .dispatch import dispatch, parametric -from plum import dispatch, parametric -import sys -import inspect import importlib +import inspect import logging +import sys + +from plum import dispatch, parametric def export(fn): diff --git a/cola/utils/dispatch.py b/cola/utils/dispatch.py deleted file mode 100644 index 24a1f4b5..00000000 --- a/cola/utils/dispatch.py +++ /dev/null @@ -1,331 +0,0 @@ -# from beartype.door import TypeHint -# from beartype.roar import BeartypeDoorNonpepException - -# from plum.dispatcher import Dispatcher -# from plum.function import _owner_transfer -# from plum.type import resolve_type_hint -# from plum.util import repr_short - -# dispatch = Dispatcher() - -# class ParametricTypeMeta(type): -# """Parametric types can be instantiated with indexing. -# A concrete parametric type can be instantiated by calling `Type[Par1, Par2]`. -# If `Type(Arg1, Arg2, **kw_args)` is called, this returns -# `Type[type(Arg1), type(Arg2)](Arg1, Arg2, **kw_args)`. -# """ -# def __getitem__(cls, p): -# if not cls.concrete: -# # Initialise the type parameters. This can perform, e.g., validation. -# p = p if isinstance(p, tuple) else (p, ) # Ensure that it is a tuple. -# p = cls.__init_type_parameter__(*p) -# # Type parameter has been initialised! Proceed to construct the type. -# p = p if isinstance(p, tuple) else (p, ) # Again ensure that it is a tuple. -# return cls.__new__(cls, *p) -# else: -# raise TypeError("Cannot specify type parameters. This type is concrete.") - -# def __concrete_class__(cls, *args, **kw_args): -# """If `cls` is not a concrete class, infer the type parameters and return a -# concrete class. If `cls` is already a concrete class, simply return it. -# Args: -# *args: Positional arguments passed to the `__init__` method. -# **kw_args: Keyword arguments passed to the `__init__` method. -# Returns: -# type: A concrete class. -# """ -# if getattr(cls, "parametric", False): -# if not cls.concrete: -# type_parameter = cls.__infer_type_parameter__(*args, **kw_args) -# cls = cls[type_parameter] -# return cls - -# def __init_type_parameter__(cls, *ps): -# """Function called to initialise the type parameters. -# The default behaviour is to just return `ps`. -# Args: -# *ps (object): Type parameters. -# Returns: -# object: Initialised type parameters. -# """ -# return ps - -# def __infer_type_parameter__(cls, *args, **kw_args): -# """Function called when the constructor of this parametric type is called -# before the parameters have been specified. -# The default behaviour is to take as parameters the type of every argument, -# but this behaviour can be overridden by redefining this function on the -# metaclass. -# Args: -# *args: Positional arguments passed to the `__init__` method. -# **kw_args: Keyword arguments passed to the `__init__` method. -# Returns: -# type or tuple[type]: A type or tuple of types. -# """ -# type_parameter = tuple(type(arg) for arg in args) -# if len(type_parameter) == 1: -# type_parameter = type_parameter[0] -# return type_parameter - -# @property -# def parametric(cls): -# """bool: Check whether the type is a parametric type.""" -# return getattr(cls, "_parametric", False) - -# @property -# def concrete(cls): -# """bool: Check whether the parametric type is instantiated or not.""" -# if cls.parametric: -# return getattr(cls, "_concrete", False) -# else: -# raise RuntimeError("Cannot check whether a non-parametric type is instantiated or not.") - -# @property -# def type_parameter(cls): -# """object: Get the type parameter. Parametric type must be instantiated.""" -# if cls.concrete: -# return cls._type_parameter -# else: -# raise RuntimeError("Cannot get the type parameter of non-instantiated parametric type.") - -# def _default_le_type_par(p_left, p_right): -# if is_type(p_left) and is_type(p_right): -# p_left = TypeHint(resolve_type_hint(p_left)) -# p_right = TypeHint(resolve_type_hint(p_right)) -# return p_left <= p_right -# else: -# return p_left == p_right - -# class CovariantMeta(ParametricTypeMeta): -# """A metaclass that implements *covariance* of parametric types.""" -# def __subclasscheck__(cls, subclass): -# if is_concrete(cls) and is_concrete(subclass): -# # Check that they are instances of the same parametric type. -# if all(issubclass(b, cls.__bases__) for b in subclass.__bases__): -# p_sub = subclass.type_parameter -# p_cls = cls.type_parameter -# # Ensure that both are in tuple form. -# p_sub = p_sub if isinstance(p_sub, tuple) else (p_sub, ) -# p_cls = p_cls if isinstance(p_cls, tuple) else (p_cls, ) -# return cls.__le_type_parameter__(p_sub, p_cls) - -# # Default behaviour to `type`s subclass check. -# return type.__subclasscheck__(cls, subclass) - -# def __le_type_parameter__(cls, p_left, p_right): -# # Check that there are an equal number of parameters. -# if len(p_left) != len(p_right): -# return False -# # Check every pair of parameters. -# return all(_default_le_type_par(p1, p2) for p1, p2 in zip(p_left, p_right)) - -# def parametric(original_class=None): -# """A decorator for parametric classes. -# When the constructor of this parametric type is called before the type parameter -# has been specified, the type parameter is inferred from the arguments of the -# constructor by calling `__inter_type_parameter__`. The default implementation is -# shown here, but it is possible to override it:: -# @classmethod -# def __infer_type_parameter__(cls, *args, **kw_args) -> tuple: -# return tuple(type(arg) for arg in args) -# After the type parameter is given or inferred, `__init_type_parameter__` is called. -# Again, the default implementation is show here, but it is possible to override it:: -# @classmethod -# def __init_type_parameter__(cls, *ps) -> tuple: -# return ps -# To determine which one instance of a parametric class is a subclass of another, -# the type parameters are compared with `__le_type_parameter__`:: -# @classmethod -# def __le_type_parameter__(cls, left, right) -> bool: -# ... # Is `left <= right`? -# """ - -# original_meta = type(original_class) - -# # Make a metaclass that derives from both the metaclass of `original_meta` and -# # `CovariantMeta`, but make sure not to insert `CovariantMeta` twice, because that -# # will error. - -# if CovariantMeta in original_meta.__mro__: -# bases = (original_meta, ) -# name = original_meta.__name__ -# else: -# bases = (CovariantMeta, original_meta) -# name = f"CovariantMeta[{repr_short(original_meta)}]" - -# def __call__(cls, *args, **kw_args): -# cls = cls.__concrete_class__(*args, **kw_args) -# return original_meta.__call__(cls, *args, **kw_args) - -# meta = type(name, bases, {"__call__": __call__}) - -# subclasses = {} - -# def __new__(cls, *ps): -# # Only create a new subclass if it doesn't exist already. -# if ps not in subclasses: - -# if original_class.__new__ is not object.__new__: - -# def __new__(cls, *args, **kw_args): -# return original_class.__new__(cls, *args, **kw_args) -# else: -# __new__ = original_class.__new__ - -# # Create subclass. -# name = original_class.__name__ -# name += "[" + ", ".join(repr_short(p) for p in ps) + "]" -# subclass = meta( -# name, -# (parametric_class, ), -# {"__new__": __new__}, -# ) -# subclass._parametric = True -# subclass._concrete = True -# subclass._type_parameter = ps[0] if len(ps) == 1 else ps -# subclass.__module__ = original_class.__module__ - -# # Attempt to correct docstring. -# try: -# subclass.__doc__ = original_class.__doc__ -# except AttributeError: # pragma: no cover -# pass - -# subclasses[ps] = subclass -# return subclasses[ps] - -# def __init_subclass__(cls, **kw_args): -# cls._parametric = False -# # If the subclass has the same `__new__` as `ParametricClass`, then we should -# # replace it with the `__new__` of `Class`. If the user already defined another -# # `__new__`, then everything is fine. -# if cls.__new__ is __new__: - -# def class_new(cls, *args, **kw_args): -# return original_class.__new__(cls) - -# cls.__new__ = class_new -# original_class.__init_subclass__(**kw_args) - -# # Create parametric class. -# parametric_class = meta( -# original_class.__name__, -# (original_class, ), -# { -# "__new__": __new__, -# "__init_subclass__": __init_subclass__ -# }, -# ) -# parametric_class._parametric = True -# parametric_class._concrete = False -# parametric_class.__module__ = original_class.__module__ - -# # When dispatch is used in methods of `original_class`, because we return -# # `parametric_class`, `parametric_class` will be inferred as the owner of those -# # functions. This is erroneous, because the owner should be `original_class`. What -# # will happen is that `original_class` will be the next in the MRO, which means -# # that, whenever a `NotFoundLookupError` happens, the method will try itself again, -# # resulting in an infinite loop. To prevent this from happening, we must adjust the -# # owner. -# _owner_transfer[parametric_class] = original_class - -# # Attempt to correct docstring. -# try: -# parametric_class.__doc__ = original_class.__doc__ -# except AttributeError: # pragma: no cover -# pass - -# return parametric_class - -# def is_concrete(t): -# """Check if a type `t` is a concrete instance of a parametric type. -# Args: -# t (type): Type to check. -# Returns: -# bool: `True` if `t` is a concrete instance of a parametric type and `False` -# otherwise. -# """ -# return getattr(t, "parametric", False) and t.concrete - -# def is_type(x): -# """Check whether `x` is a type or a type hint. -# Under the hood, this attempts to construct a :class:`beartype.door.TypeHint` from -# `x`. If successful, then `x` is deemed a type or type hint. -# Args: -# x (object): Object to check. -# Returns: -# bool: Whether `x` is a type or a type hint. -# """ -# try: -# TypeHint(x) -# return True -# except BeartypeDoorNonpepException: -# return False - -# def type_parameter(x): -# """Get the type parameter of concrete parametric type or an instance of a concrete -# parametric type. -# Args: -# x (object): Concrete parametric type or instance thereof. -# Returns: -# object: Type parameter. -# """ -# if is_type(x): -# t = x -# else: -# t = type(x) -# if hasattr(t, "parametric"): -# return t.type_parameter -# raise ValueError(f"`{x}` is not a concrete parametric type or an instance of a" -# f" concrete parametric type.") - -# def kind(SuperClass=object): -# """Create a parametric wrapper type for dispatch purposes. -# Args: -# SuperClass (type): Super class. -# Returns: -# object: New parametric type wrapper. -# """ -# @parametric -# class Kind(SuperClass): -# def __init__(self, *xs): -# self.xs = xs - -# def get(self): -# return self.xs[0] if len(self.xs) == 1 else self.xs - -# return Kind - -# Kind = kind() #: A default kind provided for convenience. - -# @parametric -# class Val: -# """A parametric type used to move information from the value domain to the type -# domain.""" -# @classmethod -# def __infer_type_parameter__(cls, *arg): -# """Function called when the constructor of `Val` is called to determine the type -# parameters.""" -# if len(arg) == 0: -# raise ValueError("The value must be specified.") -# elif len(arg) > 1: -# raise ValueError("Too many values. `Val` accepts only one argument.") -# return arg[0] - -# def __init__(self, val=None): -# """Construct a value object with type `Val(arg)` that can be used to dispatch -# based on values. -# Args: -# val (object): The value to be moved to the type domain. -# """ -# if type(self).concrete: -# if val is not None and type_parameter(self) != val: -# raise ValueError("The value must be equal to the type parameter.") -# else: -# raise ValueError("The value must be specified.") - -# def __repr__(self): -# return repr_short(type(self)) + "()" - -# def __eq__(self, other): -# return type(self) is type(other) diff --git a/cola/utils/jax_tqdm.py b/cola/utils/jax_tqdm.py index 1e506112..9481a9d7 100644 --- a/cola/utils/jax_tqdm.py +++ b/cola/utils/jax_tqdm.py @@ -1,12 +1,13 @@ # Credit to Jeremie Coullon # Adapted from https://github.com/jeremiecoullon/jax-tqdm -import typing import time -from tqdm.auto import tqdm -import numpy as np +import typing + import jax +import numpy as np from jax.experimental import host_callback +from tqdm.auto import tqdm def scan_tqdm(n: int, message: typing.Optional[str] = None) -> typing.Callable: diff --git a/cola/utils/torch_tqdm.py b/cola/utils/torch_tqdm.py index c8b5b064..40b6a599 100644 --- a/cola/utils/torch_tqdm.py +++ b/cola/utils/torch_tqdm.py @@ -1,6 +1,6 @@ import time -import numpy as np +import numpy as np from tqdm.auto import tqdm diff --git a/cola/utils/utils_for_tests.py b/cola/utils/utils_for_tests.py index 4fd1718a..7139ee3d 100644 --- a/cola/utils/utils_for_tests.py +++ b/cola/utils/utils_for_tests.py @@ -1,11 +1,12 @@ +import functools import inspect import itertools + +import numpy as np import pytest -from cola.backends import get_library_fns, get_xnp, all_backends +from cola.backends import all_backends, get_library_fns, get_xnp from cola.backends.np_fns import NumpyNotImplementedError -import numpy as np -import functools get_xnp = get_xnp diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..19930f2f --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +addopts = -m "not tricky and not big and not market"