Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the basis computations to use sparse linear algebra #72

Merged
merged 6 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions src/jaxoplanet/experimental/starry/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,49 @@
from collections import defaultdict

import numpy as np
import scipy.sparse.linalg
from jax.experimental.sparse import BCOO
from scipy.special import gamma

try:
from scipy.sparse import csc_array
except ImportError:
# With older versions of scipy, the data structures were called "matrices"
# not "arrays"; this allows us to support either.
from scipy.sparse import csc_matrix as csc_array


def basis(lmax):
matrix = scipy.sparse.linalg.spsolve(A2_inv(lmax), A1(lmax))
if lmax > 0:
return BCOO.from_scipy_sparse(matrix)
else:
return BCOO.fromdense(np.squeeze(matrix)[None, None])


def A1(lmax):
"""Note: The normalization here matches the starry paper, but not the
code. To get the code's normalization, multiply the result by 2 /
sqrt(pi).
"""
n = (lmax + 1) ** 2
res = np.zeros((n, n))
p = {ptilde(m): m for m in range(n)}
n = 0
for l in range(lmax + 1):
for m in range(-l, l + 1):
p_Y(p, l, m, res[:, n])
n += 1
return res
return _A_impl(lmax, p_Y) * 2 / np.sqrt(np.pi)


def A2_inv(lmax):
return _A_impl(lmax, p_G)


def _A_impl(lmax, func):
n = (lmax + 1) ** 2
res = np.zeros((n, n))
data = []
row_ind = []
col_ind = []
p = {ptilde(m): m for m in range(n)}
n = 0
for l in range(lmax + 1):
for _ in range(-l, l + 1):
p_G(p, n, res[:, n])
for m in range(-l, l + 1):
idx, val = func(p, l, m, n)
data.extend(val)
row_ind.extend(idx)
col_ind.extend([n] * len(idx))
n += 1
return res
return csc_array((np.array(data), (row_ind, col_ind)), shape=(n, n))


def ptilde(n):
Expand Down Expand Up @@ -116,12 +130,19 @@ def Ylm(l, m):
return dict(res)


def p_Y(p, l, m, res):
def p_Y(p, l, m, n):
del n
indicies = []
data = []
for k, v in Ylm(l, m).items():
if k not in p:
continue
res[p[k]] = v
return res
indicies.append(p[k])
data.append(v)
indicies = np.array(indicies, dtype=int)
data = np.array(data, dtype=float)
idx = np.argsort(indicies)
return indicies[idx], data[idx]


def gtilde(n):
Expand Down Expand Up @@ -160,9 +181,16 @@ def gtilde(n):
return res


def p_G(p, n, res):
def p_G(p, l, m, n):
del l, m
indicies = []
data = []
for k, v in gtilde(n).items():
if k not in p:
continue
res[p[k]] = v
return res
indicies.append(p[k])
data.append(v)
indicies = np.array(indicies, dtype=int)
data = np.array(data, dtype=float)
idx = np.argsort(indicies)
return indicies[idx], data[idx]
50 changes: 50 additions & 0 deletions src/jaxoplanet/experimental/starry/wigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ def dot_rotation_matrix(ydeg, x, y, z, theta):
ydeg = int(ydeg)
except TypeError as e:
raise TypeError(f"ydeg must be an integer; got {ydeg}") from e

if x is None and y is None:
if z is None:
raise ValueError("Either x, y, or z must be specified")

return dot_rz(ydeg, theta)

x = 0.0 if x is None else x
y = 0.0 if y is None else y
z = 0.0 if z is None else z

if jnp.shape(x) != ():
raise ValueError(f"x must be a scalar; got {jnp.shape(x)}")
if jnp.shape(y) != ():
Expand Down Expand Up @@ -272,3 +283,42 @@ def dlmn(ell, s1, c1, c2, tgbet2, s3, c3, D):
cosmal = aux

return jnp.asarray(D_), jnp.asarray(R_)


def dot_rz(deg, theta):
"""Special case for rotation only around z axis"""
c = jnp.cos(theta)
s = jnp.sin(theta)
cosnt = [1.0, c]
sinnt = [0.0, s]
for n in range(2, deg + 1):
cosnt.append(2.0 * cosnt[n - 1] * c - cosnt[n - 2])
sinnt.append(2.0 * sinnt[n - 1] * c - sinnt[n - 2])

n = 0
cosmt = []
sinmt = []
for ell in range(deg + 1):
for m in range(-ell, 0):
cosmt.append(cosnt[-m])
sinmt.append(-sinnt[-m])
for m in range(ell + 1):
cosmt.append(cosnt[m])
sinmt.append(sinnt[m])

n_max = deg**2 + 2 * deg + 1

@jax.jit
@partial(jnp.vectorize, signature=f"({n_max})->({n_max})")
def impl(M):
result = [0 for _ in range(n_max)]
for ell in range(deg + 1):
for j in range(2 * ell + 1):
result[ell * ell + j] = (
M[ell * ell + j] * cosmt[ell * ell + j]
+ M[ell * ell + 2 * ell - j] * sinmt[ell * ell + j]
)

return jnp.array(result, dtype=jnp.dtype(M))

return impl
30 changes: 21 additions & 9 deletions tests/experimental/starry/basis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,57 @@
import numpy as np
import pytest

from jaxoplanet.experimental.starry.basis import A1, A2_inv
from jaxoplanet.experimental.starry.basis import A1, A2_inv, basis
from jaxoplanet.test_utils import assert_allclose


@pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0])
def test_A1(lmax):
pytest.importorskip("sympy")
expected = A1_symbolic(lmax)
calc = A1(lmax)
expected = A1_symbolic(lmax) / (0.5 * np.sqrt(np.pi))
calc = A1(lmax).todense()
np.testing.assert_allclose(calc, expected, atol=5e-12)


@pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0])
def test_A2_inv(lmax):
pytest.importorskip("sympy")
expected = A2_inv_symbolic(lmax)
calc = A2_inv(lmax)
calc = A2_inv(lmax).todense()
np.testing.assert_allclose(calc, expected)


@pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0])
def test_compare_starry_A1(lmax):
def test_A1_compare_starry(lmax):
starry = pytest.importorskip("starry")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
m = starry.Map(lmax)
expect = m.ops.A1.eval().toarray() * (0.5 * np.sqrt(np.pi))
calc = A1(lmax)
expect = m.ops.A1.eval().toarray()
calc = A1(lmax).todense()
np.testing.assert_allclose(calc, expect, atol=5e-12)


@pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0])
def test_compare_starry_A2_inv(lmax):
def test_A2_inv_compare_starry(lmax):
starry = pytest.importorskip("starry")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
m = starry.Map(lmax)
A2 = m.ops.A.eval().toarray() @ m.ops.A1Inv.eval().toarray()
inv = A2_inv(lmax)
np.testing.assert_allclose(inv @ A2, np.eye(len(inv)), atol=5e-12)
np.testing.assert_allclose(inv @ A2, np.eye(inv.shape[0]), atol=5e-12)


@pytest.mark.parametrize("lmax", [10, 7, 5, 4, 3, 2, 1, 0])
def test_basis_compare_starry(lmax):
starry = pytest.importorskip("starry")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
m = starry.Map(lmax)
expect = m.ops.A.eval().toarray()
calc = basis(lmax).todense()
assert_allclose(calc, expect)


def A1_symbolic(lmax):
Expand Down
2 changes: 1 addition & 1 deletion tests/experimental/starry/solution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _lam(b, r):


@pytest.mark.parametrize("r", [0.1, 1.1])
def test_compare_starry(r, l_max=10, order=20):
def test_solution_compare_starry(r, l_max=10, order=20):
starry = pytest.importorskip("starry")
theano = pytest.importorskip("theano")
theano.config.gcc__cxxflags += " -fexceptions"
Expand Down
28 changes: 24 additions & 4 deletions tests/experimental/starry/wigner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,36 @@

@pytest.mark.parametrize("l_max", [5, 4, 3, 2, 1, 0])
@pytest.mark.parametrize("u", [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
def test_dot_rotation(l_max, u):
@pytest.mark.parametrize("theta", [0.1])
def test_dot_rotation(l_max, u, theta):
"""Test full rotation matrix against symbolic one"""
pytest.importorskip("sympy")
ident = np.eye(l_max**2 + 2 * l_max + 1)
theta = 0.1
expected = np.array(R_symbolic(l_max, u, theta)).astype(float)
calc = dot_rotation_matrix(l_max, u[0], u[1], u[2], theta)(ident)
assert_allclose(calc, expected)


@pytest.mark.parametrize("l_max", [5, 4, 3, 2, 1, 0])
@pytest.mark.parametrize("theta", [-0.5, 0.0, 0.1, 1.5 * np.pi])
def test_dot_rotation_z(l_max, theta):
ident = np.eye(l_max**2 + 2 * l_max + 1)
expected = dot_rotation_matrix(l_max, 0.0, 0.0, 1.0, theta)(ident)
calc = dot_rotation_matrix(l_max, None, None, 1.0, theta)(ident)
assert_allclose(calc, expected)


def test_dot_rotation_negative():
starry = pytest.importorskip("starry")
l_max = 5
n_max = l_max**2 + 2 * l_max + 1
y = np.linspace(-1, 1, n_max)
starry_op = starry._core.core.OpsYlm(l_max, 0, 0, 1)
expected = starry_op.dotR(y[None, :], 1.0, 0, 0.0, -0.5 * np.pi)[0]
calc = dot_rotation_matrix(l_max, 1.0, 0.0, 0.0, -0.5 * np.pi)(y)
assert_allclose(calc, expected)


def test_dot_rotation_edge_cases():
l_max = 5
n_max = l_max**2 + 2 * l_max + 1
Expand All @@ -34,11 +54,11 @@ def test_dot_rotation_edge_cases():

@pytest.mark.parametrize("l_max", [10, 7, 5, 4])
@pytest.mark.parametrize("u", [(1, 0, 0), (0, 1, 0), (0, 0, 1), (1, 1, 1)])
def test_dot_rotation_compare_starry(l_max, u):
@pytest.mark.parametrize("theta", [0.1])
def test_dot_rotation_compare_starry(l_max, u, theta):
"""Comparison test with starry OpsYlm.dotR"""
starry = pytest.importorskip("starry")
random = np.random.default_rng(l_max)
theta = 0.1
n_max = l_max**2 + 2 * l_max + 1
M1 = np.eye(n_max)
M2 = random.normal(size=(5, n_max))
Expand Down