diff --git a/src/jaxoplanet/experimental/starry/basis.py b/src/jaxoplanet/experimental/starry/basis.py index f3afaf54..8246cf20 100644 --- a/src/jaxoplanet/experimental/starry/basis.py +++ b/src/jaxoplanet/experimental/starry/basis.py @@ -2,35 +2,49 @@ from collections import defaultdict import numpy as np +import scipy.sparse.linalg +from jax.experimental.sparse import BCOO from scipy.special import gamma +try: + from scipy.sparse import csc_array +except ImportError: + # With older versions of scipy, the data structures were called "matrices" + # not "arrays"; this allows us to support either. + from scipy.sparse import csc_matrix as csc_array + + +def basis(lmax): + matrix = scipy.sparse.linalg.spsolve(A2_inv(lmax), A1(lmax)) + if lmax > 0: + return BCOO.from_scipy_sparse(matrix) + else: + return BCOO.fromdense(np.squeeze(matrix)[None, None]) + def A1(lmax): - """Note: The normalization here matches the starry paper, but not the - code. To get the code's normalization, multiply the result by 2 / - sqrt(pi). - """ - n = (lmax + 1) ** 2 - res = np.zeros((n, n)) - p = {ptilde(m): m for m in range(n)} - n = 0 - for l in range(lmax + 1): - for m in range(-l, l + 1): - p_Y(p, l, m, res[:, n]) - n += 1 - return res + return _A_impl(lmax, p_Y) * 2 / np.sqrt(np.pi) def A2_inv(lmax): + return _A_impl(lmax, p_G) + + +def _A_impl(lmax, func): n = (lmax + 1) ** 2 - res = np.zeros((n, n)) + data = [] + row_ind = [] + col_ind = [] p = {ptilde(m): m for m in range(n)} n = 0 for l in range(lmax + 1): - for _ in range(-l, l + 1): - p_G(p, n, res[:, n]) + for m in range(-l, l + 1): + idx, val = func(p, l, m, n) + data.extend(val) + row_ind.extend(idx) + col_ind.extend([n] * len(idx)) n += 1 - return res + return csc_array((np.array(data), (row_ind, col_ind)), shape=(n, n)) def ptilde(n): @@ -116,12 +130,19 @@ def Ylm(l, m): return dict(res) -def p_Y(p, l, m, res): +def p_Y(p, l, m, n): + del n + indicies = [] + data = [] for k, v in Ylm(l, m).items(): if k not in p: continue - res[p[k]] = v - return res + indicies.append(p[k]) + data.append(v) + indicies = np.array(indicies, dtype=int) + data = np.array(data, dtype=float) + idx = np.argsort(indicies) + return indicies[idx], data[idx] def gtilde(n): @@ -160,9 +181,16 @@ def gtilde(n): return res -def p_G(p, n, res): +def p_G(p, l, m, n): + del l, m + indicies = [] + data = [] for k, v in gtilde(n).items(): if k not in p: continue - res[p[k]] = v - return res + indicies.append(p[k]) + data.append(v) + indicies = np.array(indicies, dtype=int) + data = np.array(data, dtype=float) + idx = np.argsort(indicies) + return indicies[idx], data[idx] diff --git a/src/jaxoplanet/experimental/starry/wigner.py b/src/jaxoplanet/experimental/starry/wigner.py index e435133b..902ceb3c 100644 --- a/src/jaxoplanet/experimental/starry/wigner.py +++ b/src/jaxoplanet/experimental/starry/wigner.py @@ -20,6 +20,17 @@ def dot_rotation_matrix(ydeg, x, y, z, theta): ydeg = int(ydeg) except TypeError as e: raise TypeError(f"ydeg must be an integer; got {ydeg}") from e + + if x is None and y is None: + if z is None: + raise ValueError("Either x, y, or z must be specified") + + return dot_rz(ydeg, theta) + + x = 0.0 if x is None else x + y = 0.0 if y is None else y + z = 0.0 if z is None else z + if jnp.shape(x) != (): raise ValueError(f"x must be a scalar; got {jnp.shape(x)}") if jnp.shape(y) != (): @@ -272,3 +283,42 @@ def dlmn(ell, s1, c1, c2, tgbet2, s3, c3, D): cosmal = aux return jnp.asarray(D_), jnp.asarray(R_) + + +def dot_rz(deg, theta): + """Special case for rotation only around z axis""" + c = jnp.cos(theta) + s = jnp.sin(theta) + cosnt = [1.0, c] + sinnt = [0.0, s] + for n in range(2, deg + 1): + cosnt.append(2.0 * cosnt[n - 1] * c - cosnt[n - 2]) + sinnt.append(2.0 * sinnt[n - 1] * c - sinnt[n - 2]) + + n = 0 + cosmt = [] + sinmt = [] + for ell in range(deg + 1): + for m in range(-ell, 0): + cosmt.append(cosnt[-m]) + sinmt.append(-sinnt[-m]) + for m in range(ell + 1): + cosmt.append(cosnt[m]) + sinmt.append(sinnt[m]) + + n_max = deg**2 + 2 * deg + 1 + + @jax.jit + @partial(jnp.vectorize, signature=f"({n_max})->({n_max})") + def impl(M): + result = [0 for _ in range(n_max)] + for ell in range(deg + 1): + for j in range(2 * ell + 1): + result[ell * ell + j] = ( + M[ell * ell + j] * cosmt[ell * ell + j] + + M[ell * ell + 2 * ell - j] * sinmt[ell * ell + j] + ) + + return jnp.array(result, dtype=jnp.dtype(M)) + + return impl diff --git a/tests/experimental/starry/basis_test.py b/tests/experimental/starry/basis_test.py index 3fe9c58e..e0dbba2b 100644 --- a/tests/experimental/starry/basis_test.py +++ b/tests/experimental/starry/basis_test.py @@ -3,14 +3,15 @@ import numpy as np import pytest -from jaxoplanet.experimental.starry.basis import A1, A2_inv +from jaxoplanet.experimental.starry.basis import A1, A2_inv, basis +from jaxoplanet.test_utils import assert_allclose @pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) def test_A1(lmax): pytest.importorskip("sympy") - expected = A1_symbolic(lmax) - calc = A1(lmax) + expected = A1_symbolic(lmax) / (0.5 * np.sqrt(np.pi)) + calc = A1(lmax).todense() np.testing.assert_allclose(calc, expected, atol=5e-12) @@ -18,30 +19,41 @@ def test_A1(lmax): def test_A2_inv(lmax): pytest.importorskip("sympy") expected = A2_inv_symbolic(lmax) - calc = A2_inv(lmax) + calc = A2_inv(lmax).todense() np.testing.assert_allclose(calc, expected) @pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) -def test_compare_starry_A1(lmax): +def test_A1_compare_starry(lmax): starry = pytest.importorskip("starry") with warnings.catch_warnings(): warnings.simplefilter("ignore") m = starry.Map(lmax) - expect = m.ops.A1.eval().toarray() * (0.5 * np.sqrt(np.pi)) - calc = A1(lmax) + expect = m.ops.A1.eval().toarray() + calc = A1(lmax).todense() np.testing.assert_allclose(calc, expect, atol=5e-12) @pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) -def test_compare_starry_A2_inv(lmax): +def test_A2_inv_compare_starry(lmax): starry = pytest.importorskip("starry") with warnings.catch_warnings(): warnings.simplefilter("ignore") m = starry.Map(lmax) A2 = m.ops.A.eval().toarray() @ m.ops.A1Inv.eval().toarray() inv = A2_inv(lmax) - np.testing.assert_allclose(inv @ A2, np.eye(len(inv)), atol=5e-12) + np.testing.assert_allclose(inv @ A2, np.eye(inv.shape[0]), atol=5e-12) + + +@pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) +def test_basis_compare_starry(lmax): + starry = pytest.importorskip("starry") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + m = starry.Map(lmax) + expect = m.ops.A.eval().toarray() + calc = basis(lmax).todense() + assert_allclose(calc, expect) def A1_symbolic(lmax): diff --git a/tests/experimental/starry/solution_test.py b/tests/experimental/starry/solution_test.py index 600e11b7..a6564cf1 100644 --- a/tests/experimental/starry/solution_test.py +++ b/tests/experimental/starry/solution_test.py @@ -39,7 +39,7 @@ def _lam(b, r): @pytest.mark.parametrize("r", [0.1, 1.1]) -def test_compare_starry(r, l_max=10, order=20): +def test_solution_compare_starry(r, l_max=10, order=20): starry = pytest.importorskip("starry") theano = pytest.importorskip("theano") theano.config.gcc__cxxflags += " -fexceptions" diff --git a/tests/experimental/starry/wigner_test.py b/tests/experimental/starry/wigner_test.py index 9dee56c8..1b317dc2 100644 --- a/tests/experimental/starry/wigner_test.py +++ b/tests/experimental/starry/wigner_test.py @@ -8,16 +8,36 @@ @pytest.mark.parametrize("l_max", [5, 4, 3, 2, 1, 0]) @pytest.mark.parametrize("u", [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]) -def test_dot_rotation(l_max, u): +@pytest.mark.parametrize("theta", [0.1]) +def test_dot_rotation(l_max, u, theta): """Test full rotation matrix against symbolic one""" pytest.importorskip("sympy") ident = np.eye(l_max**2 + 2 * l_max + 1) - theta = 0.1 expected = np.array(R_symbolic(l_max, u, theta)).astype(float) calc = dot_rotation_matrix(l_max, u[0], u[1], u[2], theta)(ident) assert_allclose(calc, expected) +@pytest.mark.parametrize("l_max", [5, 4, 3, 2, 1, 0]) +@pytest.mark.parametrize("theta", [-0.5, 0.0, 0.1, 1.5 * np.pi]) +def test_dot_rotation_z(l_max, theta): + ident = np.eye(l_max**2 + 2 * l_max + 1) + expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident) + calc = dot_rotation_matrix(l_max, None, None, 1.0, theta)(ident) + assert_allclose(calc, expected) + + +def test_dot_rotation_negative(): + starry = pytest.importorskip("starry") + l_max = 5 + n_max = l_max**2 + 2 * l_max + 1 + y = np.linspace(-1, 1, n_max) + starry_op = starry._core.core.OpsYlm(l_max, 0, 0, 1) + expected = starry_op.dotR(y[None, :], 1.0, 0, 0.0, -0.5 * np.pi)[0] + calc = dot_rotation_matrix(l_max, 1.0, 0.0, 0.0, -0.5 * np.pi)(y) + assert_allclose(calc, expected) + + def test_dot_rotation_edge_cases(): l_max = 5 n_max = l_max**2 + 2 * l_max + 1 @@ -34,11 +54,11 @@ def test_dot_rotation_edge_cases(): @pytest.mark.parametrize("l_max", [10, 7, 5, 4]) @pytest.mark.parametrize("u", [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]) -def test_dot_rotation_compare_starry(l_max, u): +@pytest.mark.parametrize("theta", [0.1]) +def test_dot_rotation_compare_starry(l_max, u, theta): """Comparison test with starry OpsYlm.dotR""" starry = pytest.importorskip("starry") random = np.random.default_rng(l_max) - theta = 0.1 n_max = l_max**2 + 2 * l_max + 1 M1 = np.eye(n_max) M2 = random.normal(size=(5, n_max))