Skip to content

Commit

Permalink
fix: issues #234 and #235 (#236)
Browse files Browse the repository at this point in the history
* fix: issue #235

* fix: issue #234
  • Loading branch information
lgrcia authored Nov 13, 2024
1 parent 194c9f6 commit 73c8aac
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 7 deletions.
4 changes: 2 additions & 2 deletions src/jaxoplanet/experimental/starry/light_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def surface_light_curve(
theta_z = jnp.arctan2(x, y)

# trick to avoid nan `x=jnp.where...` grad caused by nan sT
r = jnp.where(b_rot, 0.0, r)
b = jnp.where(b_rot, 0.0, b)
r = jnp.where(b_rot, 1.0, r)
b = jnp.where(b_rot, 1.0, b)

sT = solution_vector(surface.deg, order=order)(b, r)

Expand Down
5 changes: 4 additions & 1 deletion src/jaxoplanet/experimental/starry/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from jaxoplanet.core.limb_dark import kite_area
from jaxoplanet.types import Array
from jaxoplanet.utils import zero_safe_sqrt


def solution_vector(l_max: int, order: int = 20) -> Callable[[Array, Array], Array]:
Expand Down Expand Up @@ -156,7 +157,9 @@ def p_integral(order: int, l_max: int, b: Array, r: Array, kappa0: Array) -> Arr
cond = jnp.less(omz2, 10 * jnp.finfo(omz2.dtype).eps)
omz2 = jnp.where(cond, 1, omz2)
z2 = jnp.maximum(0, 1 - omz2)
result = 2 * r * (r - b * c) * (1 - z2 * jnp.sqrt(z2)) / (3 * omz2)
result = (
2 * r * (r - b * c) * (1 - z2 * zero_safe_sqrt(z2)) / (3 * omz2)
)
integrand.append(jnp.where(cond, 0, 2 * result))
weights.append(high_weights)

Expand Down
4 changes: 0 additions & 4 deletions src/jaxoplanet/experimental/starry/ylm.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ def normalize(self) -> "Ylm":
Raises:
ValueError: if the (0, 0) coefficient is zero.
"""

assert self.data[(0, 0)] != 0.0, ValueError(
"The (0, 0) coefficient must be non-zero to normalize"
)
data = {k: v / self.data[(0, 0)] for k, v in self.data.items()}
return Ylm(data=data)

Expand Down
5 changes: 5 additions & 0 deletions tests/experimental/starry/solution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,8 @@ def test_solution_compare_starry(r, l_max=10, order=500):
err_msg=f"n={n}, l={l}, m={m}, mu={mu}, nu={nu}, case={case}",
atol=1e-6,
)


def test_r_greater_one_grad():
# this is an even more minimal version of what caused issue #235
assert np.isfinite(jax.jacrev(solution_vector(3))(0.5, 10.0)).all()

0 comments on commit 73c8aac

Please sign in to comment.