From 08e6e380460bdac8f9cd64bcdb2f2ab0a811518e Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Tue, 10 Oct 2023 18:02:28 -0400 Subject: [PATCH 1/6] using sparse linear algebra for the basis computations --- src/jaxoplanet/experimental/starry/basis.py | 72 ++++++++++++++------- tests/experimental/starry/basis_test.py | 30 ++++++--- tests/experimental/starry/solution_test.py | 2 +- 3 files changed, 71 insertions(+), 33 deletions(-) diff --git a/src/jaxoplanet/experimental/starry/basis.py b/src/jaxoplanet/experimental/starry/basis.py index f3afaf54..59c7eee9 100644 --- a/src/jaxoplanet/experimental/starry/basis.py +++ b/src/jaxoplanet/experimental/starry/basis.py @@ -2,35 +2,47 @@ from collections import defaultdict import numpy as np +import scipy.sparse.linalg +from jax.experimental.sparse import BCOO from scipy.special import gamma +try: + from scipy.sparse import csc_array +except ImportError: + from scipy.sparse import csc_matrix as csc_array + def A1(lmax): - """Note: The normalization here matches the starry paper, but not the - code. To get the code's normalization, multiply the result by 2 / - sqrt(pi). - """ - n = (lmax + 1) ** 2 - res = np.zeros((n, n)) - p = {ptilde(m): m for m in range(n)} - n = 0 - for l in range(lmax + 1): - for m in range(-l, l + 1): - p_Y(p, l, m, res[:, n]) - n += 1 - return res + return _A_impl(lmax, p_Y) * 2 / np.sqrt(np.pi) def A2_inv(lmax): + return _A_impl(lmax, p_G) + + +def _A_impl(lmax, func): n = (lmax + 1) ** 2 - res = np.zeros((n, n)) + data = [] + row_ind = [] + col_ind = [] p = {ptilde(m): m for m in range(n)} n = 0 for l in range(lmax + 1): - for _ in range(-l, l + 1): - p_G(p, n, res[:, n]) + for m in range(-l, l + 1): + idx, val = func(p, l, m, n) + data.extend(val) + row_ind.extend(idx) + col_ind.extend([n] * len(idx)) n += 1 - return res + return csc_array((np.array(data), (row_ind, col_ind)), shape=(n, n)) + + +def A(lmax): + matrix = scipy.sparse.linalg.spsolve(A2_inv(lmax), A1(lmax)) + if lmax > 0: + return BCOO.from_scipy_sparse(matrix) + else: + return BCOO.fromdense(np.squeeze(matrix)[None, None]) def ptilde(n): @@ -116,12 +128,19 @@ def Ylm(l, m): return dict(res) -def p_Y(p, l, m, res): +def p_Y(p, l, m, n): + del n + indicies = [] + data = [] for k, v in Ylm(l, m).items(): if k not in p: continue - res[p[k]] = v - return res + indicies.append(p[k]) + data.append(v) + indicies = np.array(indicies, dtype=int) + data = np.array(data, dtype=float) + idx = np.argsort(indicies) + return indicies[idx], data[idx] def gtilde(n): @@ -160,9 +179,16 @@ def gtilde(n): return res -def p_G(p, n, res): +def p_G(p, l, m, n): + del l, m + indicies = [] + data = [] for k, v in gtilde(n).items(): if k not in p: continue - res[p[k]] = v - return res + indicies.append(p[k]) + data.append(v) + indicies = np.array(indicies, dtype=int) + data = np.array(data, dtype=float) + idx = np.argsort(indicies) + return indicies[idx], data[idx] diff --git a/tests/experimental/starry/basis_test.py b/tests/experimental/starry/basis_test.py index 3fe9c58e..0344d00b 100644 --- a/tests/experimental/starry/basis_test.py +++ b/tests/experimental/starry/basis_test.py @@ -3,14 +3,15 @@ import numpy as np import pytest -from jaxoplanet.experimental.starry.basis import A1, A2_inv +from jaxoplanet.experimental.starry.basis import A1, A, A2_inv +from jaxoplanet.test_utils import assert_allclose @pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) def test_A1(lmax): pytest.importorskip("sympy") - expected = A1_symbolic(lmax) - calc = A1(lmax) + expected = A1_symbolic(lmax) / (0.5 * np.sqrt(np.pi)) + calc = A1(lmax).todense() np.testing.assert_allclose(calc, expected, atol=5e-12) @@ -18,30 +19,41 @@ def test_A1(lmax): def test_A2_inv(lmax): pytest.importorskip("sympy") expected = A2_inv_symbolic(lmax) - calc = A2_inv(lmax) + calc = A2_inv(lmax).todense() np.testing.assert_allclose(calc, expected) @pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) -def test_compare_starry_A1(lmax): +def test_A1_compare_starry(lmax): starry = pytest.importorskip("starry") with warnings.catch_warnings(): warnings.simplefilter("ignore") m = starry.Map(lmax) - expect = m.ops.A1.eval().toarray() * (0.5 * np.sqrt(np.pi)) - calc = A1(lmax) + expect = m.ops.A1.eval().toarray() + calc = A1(lmax).todense() np.testing.assert_allclose(calc, expect, atol=5e-12) @pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) -def test_compare_starry_A2_inv(lmax): +def test_A2_inv_compare_starry(lmax): starry = pytest.importorskip("starry") with warnings.catch_warnings(): warnings.simplefilter("ignore") m = starry.Map(lmax) A2 = m.ops.A.eval().toarray() @ m.ops.A1Inv.eval().toarray() inv = A2_inv(lmax) - np.testing.assert_allclose(inv @ A2, np.eye(len(inv)), atol=5e-12) + np.testing.assert_allclose(inv @ A2, np.eye(inv.shape[0]), atol=5e-12) + + +@pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) +def test_A_compare_starry(lmax): + starry = pytest.importorskip("starry") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + m = starry.Map(lmax) + expect = m.ops.A.eval().toarray() + calc = A(lmax).todense() + assert_allclose(calc, expect) def A1_symbolic(lmax): diff --git a/tests/experimental/starry/solution_test.py b/tests/experimental/starry/solution_test.py index 600e11b7..a6564cf1 100644 --- a/tests/experimental/starry/solution_test.py +++ b/tests/experimental/starry/solution_test.py @@ -39,7 +39,7 @@ def _lam(b, r): @pytest.mark.parametrize("r", [0.1, 1.1]) -def test_compare_starry(r, l_max=10, order=20): +def test_solution_compare_starry(r, l_max=10, order=20): starry = pytest.importorskip("starry") theano = pytest.importorskip("theano") theano.config.gcc__cxxflags += " -fexceptions" From 39b941a7ddf33ff8f0735121f4b77f8208fc40b6 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Tue, 10 Oct 2023 18:06:48 -0400 Subject: [PATCH 2/6] rename A -> basis --- src/jaxoplanet/experimental/starry/basis.py | 16 ++++++++-------- tests/experimental/starry/basis_test.py | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/jaxoplanet/experimental/starry/basis.py b/src/jaxoplanet/experimental/starry/basis.py index 59c7eee9..160e689a 100644 --- a/src/jaxoplanet/experimental/starry/basis.py +++ b/src/jaxoplanet/experimental/starry/basis.py @@ -12,6 +12,14 @@ from scipy.sparse import csc_matrix as csc_array +def basis(lmax): + matrix = scipy.sparse.linalg.spsolve(A2_inv(lmax), A1(lmax)) + if lmax > 0: + return BCOO.from_scipy_sparse(matrix) + else: + return BCOO.fromdense(np.squeeze(matrix)[None, None]) + + def A1(lmax): return _A_impl(lmax, p_Y) * 2 / np.sqrt(np.pi) @@ -37,14 +45,6 @@ def _A_impl(lmax, func): return csc_array((np.array(data), (row_ind, col_ind)), shape=(n, n)) -def A(lmax): - matrix = scipy.sparse.linalg.spsolve(A2_inv(lmax), A1(lmax)) - if lmax > 0: - return BCOO.from_scipy_sparse(matrix) - else: - return BCOO.fromdense(np.squeeze(matrix)[None, None]) - - def ptilde(n): l = math.floor(math.sqrt(n)) m = n - l * l - l diff --git a/tests/experimental/starry/basis_test.py b/tests/experimental/starry/basis_test.py index 0344d00b..e0dbba2b 100644 --- a/tests/experimental/starry/basis_test.py +++ b/tests/experimental/starry/basis_test.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from jaxoplanet.experimental.starry.basis import A1, A, A2_inv +from jaxoplanet.experimental.starry.basis import A1, A2_inv, basis from jaxoplanet.test_utils import assert_allclose @@ -46,13 +46,13 @@ def test_A2_inv_compare_starry(lmax): @pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0]) -def test_A_compare_starry(lmax): +def test_basis_compare_starry(lmax): starry = pytest.importorskip("starry") with warnings.catch_warnings(): warnings.simplefilter("ignore") m = starry.Map(lmax) expect = m.ops.A.eval().toarray() - calc = A(lmax).todense() + calc = basis(lmax).todense() assert_allclose(calc, expect) From 463a90995dc32dab71c65a5579c1e3997ec7d9f1 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Tue, 10 Oct 2023 21:30:04 -0400 Subject: [PATCH 3/6] Adding initial flux implementation --- src/jaxoplanet/experimental/starry/basis.py | 2 + src/jaxoplanet/experimental/starry/flux.py | 105 ++++++++++++++++++++ tests/experimental/starry/wigner_test.py | 19 +++- 3 files changed, 122 insertions(+), 4 deletions(-) create mode 100644 src/jaxoplanet/experimental/starry/flux.py diff --git a/src/jaxoplanet/experimental/starry/basis.py b/src/jaxoplanet/experimental/starry/basis.py index 160e689a..8246cf20 100644 --- a/src/jaxoplanet/experimental/starry/basis.py +++ b/src/jaxoplanet/experimental/starry/basis.py @@ -9,6 +9,8 @@ try: from scipy.sparse import csc_array except ImportError: + # With older versions of scipy, the data structures were called "matrices" + # not "arrays"; this allows us to support either. from scipy.sparse import csc_matrix as csc_array diff --git a/src/jaxoplanet/experimental/starry/flux.py b/src/jaxoplanet/experimental/starry/flux.py new file mode 100644 index 00000000..25e951e9 --- /dev/null +++ b/src/jaxoplanet/experimental/starry/flux.py @@ -0,0 +1,105 @@ +import jax.numpy as jnp +import numpy as np + +from jaxoplanet.experimental.starry.basis import A1, basis +from jaxoplanet.experimental.starry.solution import solution_vector +from jaxoplanet.experimental.starry.wigner import dot_rotation_matrix + + +def flux(deg, theta, xo, yo, zo, ro, inc, obl, y, u, f): + b = jnp.sqrt(jnp.square(xo) + jnp.square(yo)) + b_rot = jnp.logical_or(jnp.greater_equal(b, 1.0 + ro), jnp.less_equal(zo, 0.0)) + b_occ = jnp.logical_not(b_rot) + + # Occultation + theta_z = jnp.arctan2(xo, yo) + sT = solution_vector(deg)(b, ro) + sTA = sT @ basis(deg) + sTAR = tensordot_rz(deg, sTA, theta_z) + + x = jnp.where(b_occ, sTAR, rTA1(deg)) + + x = dot_rotation_matrix( + deg, -jnp.cos(obl), -jnp.sin(obl), 0.0, -(0.5 * jnp.pi - inc) + )(x) + x = dot_rotation_matrix(deg, 0.0, 0.0, 1.0, obl)(x) + x = dot_rotation_matrix(deg, 1.0, 0.0, 0.0, -0.5 * jnp.pi)(x) + x = dot_rotation_matrix(deg, 0.0, 0.0, 1.0, theta)(x) + x = dot_rotation_matrix(deg, 1.0, 0.0, 0.0, 0.5 * jnp.pi)(x) + + return x @ y + + +def tensordot_rz(deg, M, theta): + c = jnp.cos(theta) + s = jnp.sin(theta) + cosnt = [1.0, c] + sinnt = [0.0, s] + for n in range(2, deg + 1): + cosnt.append(2.0 * cosnt[n - 1] * c - cosnt[n - 2]) + sinnt.append(2.0 * sinnt[n - 1] * c - sinnt[n - 2]) + + n = 0 + cosmt = [] + sinmt = [] + for ell in range(deg + 1): + for m in range(-ell, 0): + cosmt.append(cosnt[-m]) + sinmt.append(-sinnt[-m]) + for m in range(ell + 1): + cosmt.append(cosnt[m]) + sinmt.append(sinnt[m]) + + result = [0 for _ in range(M.shape[0])] + for ell in range(deg + 1): + for j in range(2 * ell + 1): + result[ell * ell + j] = ( + M[ell * ell + j] * cosmt[ell * ell + j] + + M[ell * ell + 2 * ell - j] * sinmt[ell * ell + j] + ) + + return jnp.array(result) + + +def rT(lmax): + rt = [0 for _ in range((lmax + 1) * (lmax + 1))] + amp0 = jnp.pi + lfac1 = 1.0 + lfac2 = 2.0 / 3.0 + for ell in range(0, lmax + 1, 4): + amp = amp0 + for m in range(0, ell + 1, 4): + mu = ell - m + nu = ell + m + rt[ell * ell + ell + m] = amp * lfac1 + rt[ell * ell + ell - m] = amp * lfac1 + if ell < lmax: + rt[(ell + 1) * (ell + 1) + ell + m + 1] = amp * lfac2 + rt[(ell + 1) * (ell + 1) + ell - m + 1] = amp * lfac2 + amp *= (nu + 2.0) / (mu - 2.0) + lfac1 /= (ell / 2 + 2) * (ell / 2 + 3) + lfac2 /= (ell / 2 + 2.5) * (ell / 2 + 3.5) + amp0 *= 0.0625 * (ell + 2) * (ell + 2) + + amp0 = 0.5 * jnp.pi + lfac1 = 0.5 + lfac2 = 4.0 / 15.0 + for ell in range(2, lmax + 1, 4): + amp = amp0 + for m in range(2, ell + 1, 4): + mu = ell - m + nu = ell + m + rt[ell * ell + ell + m] = amp * lfac1 + rt[ell * ell + ell - m] = amp * lfac1 + if ell < lmax: + rt[(ell + 1) * (ell + 1) + ell + m + 1] = amp * lfac2 + rt[(ell + 1) * (ell + 1) + ell - m + 1] = amp * lfac2 + amp *= (nu + 2.0) / (mu - 2.0) + lfac1 /= (ell / 2 + 2) * (ell / 2 + 3) + lfac2 /= (ell / 2 + 2.5) * (ell / 2 + 3.5) + amp0 *= 0.0625 * ell * (ell + 4) + return np.array(rt) + + +def rTA1(lmax): + return rT(lmax) @ A1(lmax) diff --git a/tests/experimental/starry/wigner_test.py b/tests/experimental/starry/wigner_test.py index 9dee56c8..bbf4cfd8 100644 --- a/tests/experimental/starry/wigner_test.py +++ b/tests/experimental/starry/wigner_test.py @@ -8,16 +8,27 @@ @pytest.mark.parametrize("l_max", [5, 4, 3, 2, 1, 0]) @pytest.mark.parametrize("u", [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]) -def test_dot_rotation(l_max, u): +@pytest.mark.parametrize("theta", [0.1]) +def test_dot_rotation(l_max, u, theta): """Test full rotation matrix against symbolic one""" pytest.importorskip("sympy") ident = np.eye(l_max**2 + 2 * l_max + 1) - theta = 0.1 expected = np.array(R_symbolic(l_max, u, theta)).astype(float) calc = dot_rotation_matrix(l_max, u[0], u[1], u[2], theta)(ident) assert_allclose(calc, expected) +def test_dot_rotation_negative(): + starry = pytest.importorskip("starry") + l_max = 5 + n_max = l_max**2 + 2 * l_max + 1 + y = np.linspace(-1, 1, n_max) + starry_op = starry._core.core.OpsYlm(l_max, 0, 0, 1) + expected = starry_op.dotR(y[None, :], 1.0, 0, 0.0, -0.5 * np.pi)[0] + calc = dot_rotation_matrix(l_max, 1.0, 0.0, 0.0, -0.5 * np.pi)(y) + assert_allclose(calc, expected) + + def test_dot_rotation_edge_cases(): l_max = 5 n_max = l_max**2 + 2 * l_max + 1 @@ -34,11 +45,11 @@ def test_dot_rotation_edge_cases(): @pytest.mark.parametrize("l_max", [10, 7, 5, 4]) @pytest.mark.parametrize("u", [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)]) -def test_dot_rotation_compare_starry(l_max, u): +@pytest.mark.parametrize("theta", [0.1]) +def test_dot_rotation_compare_starry(l_max, u, theta): """Comparison test with starry OpsYlm.dotR""" starry = pytest.importorskip("starry") random = np.random.default_rng(l_max) - theta = 0.1 n_max = l_max**2 + 2 * l_max + 1 M1 = np.eye(n_max) M2 = random.normal(size=(5, n_max)) From ed4887921fdc1f7c5935eb63c2222fa0d11053aa Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Wed, 11 Oct 2023 08:06:34 -0400 Subject: [PATCH 4/6] adding special case of rotation for just z axis --- src/jaxoplanet/experimental/starry/flux.py | 6 +-- src/jaxoplanet/experimental/starry/wigner.py | 50 ++++++++++++++++++++ tests/experimental/starry/wigner_test.py | 9 ++++ 3 files changed, 62 insertions(+), 3 deletions(-) diff --git a/src/jaxoplanet/experimental/starry/flux.py b/src/jaxoplanet/experimental/starry/flux.py index 25e951e9..4c2ba716 100644 --- a/src/jaxoplanet/experimental/starry/flux.py +++ b/src/jaxoplanet/experimental/starry/flux.py @@ -15,16 +15,16 @@ def flux(deg, theta, xo, yo, zo, ro, inc, obl, y, u, f): theta_z = jnp.arctan2(xo, yo) sT = solution_vector(deg)(b, ro) sTA = sT @ basis(deg) - sTAR = tensordot_rz(deg, sTA, theta_z) + sTAR = dot_rotation_matrix(deg, None, None, 1.0, theta_z)(sTA) x = jnp.where(b_occ, sTAR, rTA1(deg)) x = dot_rotation_matrix( deg, -jnp.cos(obl), -jnp.sin(obl), 0.0, -(0.5 * jnp.pi - inc) )(x) - x = dot_rotation_matrix(deg, 0.0, 0.0, 1.0, obl)(x) + x = dot_rotation_matrix(deg, None, None, 1.0, obl)(x) x = dot_rotation_matrix(deg, 1.0, 0.0, 0.0, -0.5 * jnp.pi)(x) - x = dot_rotation_matrix(deg, 0.0, 0.0, 1.0, theta)(x) + x = dot_rotation_matrix(deg, None, None, 1.0, theta)(x) x = dot_rotation_matrix(deg, 1.0, 0.0, 0.0, 0.5 * jnp.pi)(x) return x @ y diff --git a/src/jaxoplanet/experimental/starry/wigner.py b/src/jaxoplanet/experimental/starry/wigner.py index e435133b..902ceb3c 100644 --- a/src/jaxoplanet/experimental/starry/wigner.py +++ b/src/jaxoplanet/experimental/starry/wigner.py @@ -20,6 +20,17 @@ def dot_rotation_matrix(ydeg, x, y, z, theta): ydeg = int(ydeg) except TypeError as e: raise TypeError(f"ydeg must be an integer; got {ydeg}") from e + + if x is None and y is None: + if z is None: + raise ValueError("Either x, y, or z must be specified") + + return dot_rz(ydeg, theta) + + x = 0.0 if x is None else x + y = 0.0 if y is None else y + z = 0.0 if z is None else z + if jnp.shape(x) != (): raise ValueError(f"x must be a scalar; got {jnp.shape(x)}") if jnp.shape(y) != (): @@ -272,3 +283,42 @@ def dlmn(ell, s1, c1, c2, tgbet2, s3, c3, D): cosmal = aux return jnp.asarray(D_), jnp.asarray(R_) + + +def dot_rz(deg, theta): + """Special case for rotation only around z axis""" + c = jnp.cos(theta) + s = jnp.sin(theta) + cosnt = [1.0, c] + sinnt = [0.0, s] + for n in range(2, deg + 1): + cosnt.append(2.0 * cosnt[n - 1] * c - cosnt[n - 2]) + sinnt.append(2.0 * sinnt[n - 1] * c - sinnt[n - 2]) + + n = 0 + cosmt = [] + sinmt = [] + for ell in range(deg + 1): + for m in range(-ell, 0): + cosmt.append(cosnt[-m]) + sinmt.append(-sinnt[-m]) + for m in range(ell + 1): + cosmt.append(cosnt[m]) + sinmt.append(sinnt[m]) + + n_max = deg**2 + 2 * deg + 1 + + @jax.jit + @partial(jnp.vectorize, signature=f"({n_max})->({n_max})") + def impl(M): + result = [0 for _ in range(n_max)] + for ell in range(deg + 1): + for j in range(2 * ell + 1): + result[ell * ell + j] = ( + M[ell * ell + j] * cosmt[ell * ell + j] + + M[ell * ell + 2 * ell - j] * sinmt[ell * ell + j] + ) + + return jnp.array(result, dtype=jnp.dtype(M)) + + return impl diff --git a/tests/experimental/starry/wigner_test.py b/tests/experimental/starry/wigner_test.py index bbf4cfd8..1b317dc2 100644 --- a/tests/experimental/starry/wigner_test.py +++ b/tests/experimental/starry/wigner_test.py @@ -18,6 +18,15 @@ def test_dot_rotation(l_max, u, theta): assert_allclose(calc, expected) +@pytest.mark.parametrize("l_max", [5, 4, 3, 2, 1, 0]) +@pytest.mark.parametrize("theta", [-0.5, 0.0, 0.1, 1.5 * np.pi]) +def test_dot_rotation_z(l_max, theta): + ident = np.eye(l_max**2 + 2 * l_max + 1) + expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident) + calc = dot_rotation_matrix(l_max, None, None, 1.0, theta)(ident) + assert_allclose(calc, expected) + + def test_dot_rotation_negative(): starry = pytest.importorskip("starry") l_max = 5 From 2745727ef8e4a00ff17880777eb81040410cc1db Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Wed, 11 Oct 2023 08:08:18 -0400 Subject: [PATCH 5/6] removing tensordot function --- src/jaxoplanet/experimental/starry/flux.py | 31 ---------------------- 1 file changed, 31 deletions(-) diff --git a/src/jaxoplanet/experimental/starry/flux.py b/src/jaxoplanet/experimental/starry/flux.py index 4c2ba716..78afd65a 100644 --- a/src/jaxoplanet/experimental/starry/flux.py +++ b/src/jaxoplanet/experimental/starry/flux.py @@ -30,37 +30,6 @@ def flux(deg, theta, xo, yo, zo, ro, inc, obl, y, u, f): return x @ y -def tensordot_rz(deg, M, theta): - c = jnp.cos(theta) - s = jnp.sin(theta) - cosnt = [1.0, c] - sinnt = [0.0, s] - for n in range(2, deg + 1): - cosnt.append(2.0 * cosnt[n - 1] * c - cosnt[n - 2]) - sinnt.append(2.0 * sinnt[n - 1] * c - sinnt[n - 2]) - - n = 0 - cosmt = [] - sinmt = [] - for ell in range(deg + 1): - for m in range(-ell, 0): - cosmt.append(cosnt[-m]) - sinmt.append(-sinnt[-m]) - for m in range(ell + 1): - cosmt.append(cosnt[m]) - sinmt.append(sinnt[m]) - - result = [0 for _ in range(M.shape[0])] - for ell in range(deg + 1): - for j in range(2 * ell + 1): - result[ell * ell + j] = ( - M[ell * ell + j] * cosmt[ell * ell + j] - + M[ell * ell + 2 * ell - j] * sinmt[ell * ell + j] - ) - - return jnp.array(result) - - def rT(lmax): rt = [0 for _ in range((lmax + 1) * (lmax + 1))] amp0 = jnp.pi From c1b182d6fffcd3544d2d15129f16abe634cdfcb5 Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Thu, 12 Oct 2023 10:35:43 -0400 Subject: [PATCH 6/6] removing flux function from this PR --- src/jaxoplanet/experimental/starry/flux.py | 74 ---------------------- 1 file changed, 74 deletions(-) delete mode 100644 src/jaxoplanet/experimental/starry/flux.py diff --git a/src/jaxoplanet/experimental/starry/flux.py b/src/jaxoplanet/experimental/starry/flux.py deleted file mode 100644 index 78afd65a..00000000 --- a/src/jaxoplanet/experimental/starry/flux.py +++ /dev/null @@ -1,74 +0,0 @@ -import jax.numpy as jnp -import numpy as np - -from jaxoplanet.experimental.starry.basis import A1, basis -from jaxoplanet.experimental.starry.solution import solution_vector -from jaxoplanet.experimental.starry.wigner import dot_rotation_matrix - - -def flux(deg, theta, xo, yo, zo, ro, inc, obl, y, u, f): - b = jnp.sqrt(jnp.square(xo) + jnp.square(yo)) - b_rot = jnp.logical_or(jnp.greater_equal(b, 1.0 + ro), jnp.less_equal(zo, 0.0)) - b_occ = jnp.logical_not(b_rot) - - # Occultation - theta_z = jnp.arctan2(xo, yo) - sT = solution_vector(deg)(b, ro) - sTA = sT @ basis(deg) - sTAR = dot_rotation_matrix(deg, None, None, 1.0, theta_z)(sTA) - - x = jnp.where(b_occ, sTAR, rTA1(deg)) - - x = dot_rotation_matrix( - deg, -jnp.cos(obl), -jnp.sin(obl), 0.0, -(0.5 * jnp.pi - inc) - )(x) - x = dot_rotation_matrix(deg, None, None, 1.0, obl)(x) - x = dot_rotation_matrix(deg, 1.0, 0.0, 0.0, -0.5 * jnp.pi)(x) - x = dot_rotation_matrix(deg, None, None, 1.0, theta)(x) - x = dot_rotation_matrix(deg, 1.0, 0.0, 0.0, 0.5 * jnp.pi)(x) - - return x @ y - - -def rT(lmax): - rt = [0 for _ in range((lmax + 1) * (lmax + 1))] - amp0 = jnp.pi - lfac1 = 1.0 - lfac2 = 2.0 / 3.0 - for ell in range(0, lmax + 1, 4): - amp = amp0 - for m in range(0, ell + 1, 4): - mu = ell - m - nu = ell + m - rt[ell * ell + ell + m] = amp * lfac1 - rt[ell * ell + ell - m] = amp * lfac1 - if ell < lmax: - rt[(ell + 1) * (ell + 1) + ell + m + 1] = amp * lfac2 - rt[(ell + 1) * (ell + 1) + ell - m + 1] = amp * lfac2 - amp *= (nu + 2.0) / (mu - 2.0) - lfac1 /= (ell / 2 + 2) * (ell / 2 + 3) - lfac2 /= (ell / 2 + 2.5) * (ell / 2 + 3.5) - amp0 *= 0.0625 * (ell + 2) * (ell + 2) - - amp0 = 0.5 * jnp.pi - lfac1 = 0.5 - lfac2 = 4.0 / 15.0 - for ell in range(2, lmax + 1, 4): - amp = amp0 - for m in range(2, ell + 1, 4): - mu = ell - m - nu = ell + m - rt[ell * ell + ell + m] = amp * lfac1 - rt[ell * ell + ell - m] = amp * lfac1 - if ell < lmax: - rt[(ell + 1) * (ell + 1) + ell + m + 1] = amp * lfac2 - rt[(ell + 1) * (ell + 1) + ell - m + 1] = amp * lfac2 - amp *= (nu + 2.0) / (mu - 2.0) - lfac1 /= (ell / 2 + 2) * (ell / 2 + 3) - lfac2 /= (ell / 2 + 2.5) * (ell / 2 + 3.5) - amp0 *= 0.0625 * ell * (ell + 4) - return np.array(rt) - - -def rTA1(lmax): - return rT(lmax) @ A1(lmax)