From 6515f69bc85cdfc6bd9ba16718d467c2e7a21f7a Mon Sep 17 00:00:00 2001 From: Dan F-M Date: Tue, 17 Oct 2023 20:10:40 -0400 Subject: [PATCH] Implementing polynomial basis --- src/jaxoplanet/experimental/starry/basis.py | 63 +++++++++++++++++++++ tests/experimental/starry/basis_test.py | 23 +++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/src/jaxoplanet/experimental/starry/basis.py b/src/jaxoplanet/experimental/starry/basis.py index 8246cf20..5654cbc1 100644 --- a/src/jaxoplanet/experimental/starry/basis.py +++ b/src/jaxoplanet/experimental/starry/basis.py @@ -1,6 +1,8 @@ import math from collections import defaultdict +from functools import partial +import jax.numpy as jnp import numpy as np import scipy.sparse.linalg from jax.experimental.sparse import BCOO @@ -194,3 +196,64 @@ def p_G(p, l, m, n): data = np.array(data, dtype=float) idx = np.argsort(indicies) return indicies[idx], data[idx] + + +def poly_basis(deg): + N = (deg + 1) * (deg + 1) + + @partial(jnp.vectorize, signature=f"(),(),()->({N})") + def impl(x, y, z): + xarr = [None for _ in range(N)] + yarr = [None for _ in range(N)] + + # Ensures we get `nan`s off the disk + xterm = 1.0 + 0.0 * z + yterm = 1.0 + 0.0 * z + + i0 = 0 + di0 = 3 + j0 = 0 + dj0 = 2 + for n in range(deg + 1): + i = i0 + di = di0 + xarr[i] = xterm + j = j0 + dj = dj0 + yarr[j] = yterm + i = i0 + di - 1 + j = j0 + dj - 1 + while i + 1 < N: + xarr[i] = xterm + xarr[i + 1] = xterm + di += 2 + i += di + yarr[j] = yterm + yarr[j + 1] = yterm + dj += 2 + j += dj - 1 + xterm *= x + i0 += 2 * n + 1 + di0 += 2 + yterm *= y + j0 += 2 * (n + 1) + 1 + dj0 += 2 + + assert all(v is not None for v in xarr) + assert all(v is not None for v in yarr) + + inds = [] + n = 0 + for ell in range(deg + 1): + for m in range(-ell, ell + 1): + if (ell + m) % 2 != 0: + inds.append(n) + n += 1 + + p = jnp.array(xarr) * jnp.array(yarr) + if len(inds): + return p.at[np.array(inds)].multiply(z) + else: + return p + + return impl diff --git a/tests/experimental/starry/basis_test.py b/tests/experimental/starry/basis_test.py index e0dbba2b..8c179dc0 100644 --- a/tests/experimental/starry/basis_test.py +++ b/tests/experimental/starry/basis_test.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from jaxoplanet.experimental.starry.basis import A1, A2_inv, basis +from jaxoplanet.experimental.starry.basis import A1, A2_inv, basis, poly_basis from jaxoplanet.test_utils import assert_allclose @@ -56,6 +56,27 @@ def test_basis_compare_starry(lmax): assert_allclose(calc, expect) +def test_poly_basis(): + expect = np.array([1.0, 0.1, 0.2, 0.3, 0.01, 0.02, 0.03, 0.06, 0.09]) + calc = poly_basis(2)(0.1, 0.3, 0.2) + assert_allclose(calc, expect) + + +@pytest.mark.parametrize("lmax", [10, 5, 2, 1, 0]) +def test_poly_basis_compare_starry(lmax): + starry = pytest.importorskip("starry") + + x = np.linspace(-1, 1, 100) + y = np.linspace(-0.5, 0.5, 100) + z = np.linspace(-0.1, 0.1, 100) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + m = starry.Map(lmax) + expect = m.ops.pT(x, y, z) + calc = poly_basis(lmax)(x, y, z) + assert_allclose(calc, expect) + + def A1_symbolic(lmax): """The sympy implementation of the A1 matrix from the starry paper""" import math