Skip to content

Commit

Permalink
Implementing polynomial basis (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm authored Oct 18, 2023
1 parent 9851977 commit 933a8ce
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 1 deletion.
63 changes: 63 additions & 0 deletions src/jaxoplanet/experimental/starry/basis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import math
from collections import defaultdict
from functools import partial

import jax.numpy as jnp
import numpy as np
import scipy.sparse.linalg
from jax.experimental.sparse import BCOO
Expand Down Expand Up @@ -194,3 +196,64 @@ def p_G(p, l, m, n):
data = np.array(data, dtype=float)
idx = np.argsort(indicies)
return indicies[idx], data[idx]


def poly_basis(deg):
N = (deg + 1) * (deg + 1)

@partial(jnp.vectorize, signature=f"(),(),()->({N})")
def impl(x, y, z):
xarr = [None for _ in range(N)]
yarr = [None for _ in range(N)]

# Ensures we get `nan`s off the disk
xterm = 1.0 + 0.0 * z
yterm = 1.0 + 0.0 * z

i0 = 0
di0 = 3
j0 = 0
dj0 = 2
for n in range(deg + 1):
i = i0
di = di0
xarr[i] = xterm
j = j0
dj = dj0
yarr[j] = yterm
i = i0 + di - 1
j = j0 + dj - 1
while i + 1 < N:
xarr[i] = xterm
xarr[i + 1] = xterm
di += 2
i += di
yarr[j] = yterm
yarr[j + 1] = yterm
dj += 2
j += dj - 1
xterm *= x
i0 += 2 * n + 1
di0 += 2
yterm *= y
j0 += 2 * (n + 1) + 1
dj0 += 2

assert all(v is not None for v in xarr)
assert all(v is not None for v in yarr)

inds = []
n = 0
for ell in range(deg + 1):
for m in range(-ell, ell + 1):
if (ell + m) % 2 != 0:
inds.append(n)
n += 1

p = jnp.array(xarr) * jnp.array(yarr)
if len(inds):
return p.at[np.array(inds)].multiply(z)
else:
return p

return impl
23 changes: 22 additions & 1 deletion tests/experimental/starry/basis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from jaxoplanet.experimental.starry.basis import A1, A2_inv, basis
from jaxoplanet.experimental.starry.basis import A1, A2_inv, basis, poly_basis
from jaxoplanet.test_utils import assert_allclose


Expand Down Expand Up @@ -56,6 +56,27 @@ def test_basis_compare_starry(lmax):
assert_allclose(calc, expect)


def test_poly_basis():
expect = np.array([1.0, 0.1, 0.2, 0.3, 0.01, 0.02, 0.03, 0.06, 0.09])
calc = poly_basis(2)(0.1, 0.3, 0.2)
assert_allclose(calc, expect)


@pytest.mark.parametrize("lmax", [10, 5, 2, 1, 0])
def test_poly_basis_compare_starry(lmax):
starry = pytest.importorskip("starry")

x = np.linspace(-1, 1, 100)
y = np.linspace(-0.5, 0.5, 100)
z = np.linspace(-0.1, 0.1, 100)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
m = starry.Map(lmax)
expect = m.ops.pT(x, y, z)
calc = poly_basis(lmax)(x, y, z)
assert_allclose(calc, expect)


def A1_symbolic(lmax):
"""The sympy implementation of the A1 matrix from the starry paper"""
import math
Expand Down

0 comments on commit 933a8ce

Please sign in to comment.