Skip to content

Commit

Permalink
Introducing units and refactoring Keplerian orbit models (#61)
Browse files Browse the repository at this point in the history
* adding units

* adding example of how we might use units

* adding tests for quantity_input decorator

* adding docstring to decorator

* adding doctring and tests for field

* adding tests for keplerian

* fixing edge cases test

* updating keplerian tests

* adding units to TransitOrbit

* fixing python 3.9 errors

* remove pyright from pre-commit

* less stringent test comparisons

* test tol for soln

* sketching OO interface

* default central

* improving interface for keplerian system
  • Loading branch information
dfm authored Oct 17, 2023
1 parent b9f3cc2 commit 9851977
Show file tree
Hide file tree
Showing 23 changed files with 1,160 additions and 753 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,13 @@ repos:
rev: "0.6.1"
hooks:
- id: nbstripout
# - repo: https://github.com/RobertCraigie/pyright-python
# rev: v1.1.324
# hooks:
# - id: pyright
# additional_dependencies:
# - nox
# - pytest
# - jax
# - equinox
# - jpu>=0.0.2
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = [
"Programming Language :: Python :: 3",
]
dynamic = ["version"]
dependencies = ["jax", "jaxlib", "jpu"]
dependencies = ["jax", "jaxlib", "jpu>=0.0.2", "equinox"]

[project.urls]
"Homepage" = "https://jax.exoplanet.codes"
Expand Down
2 changes: 1 addition & 1 deletion src/jaxoplanet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from jaxoplanet import core as core, orbits as orbits
from jaxoplanet import core as core, orbits as orbits, units as units
from jaxoplanet.jaxoplanet_version import __version__ as __version__
Empty file removed src/jaxoplanet/core/py.typed
Empty file.
24 changes: 13 additions & 11 deletions src/jaxoplanet/light_curves.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
from functools import partial
from typing import NamedTuple, Optional
from typing import Optional

import equinox as eqx
import jax.numpy as jnp

from jaxoplanet import units
from jaxoplanet.core.limb_dark import light_curve
from jaxoplanet.proto import LightCurveOrbit
from jaxoplanet.types import Array
from jaxoplanet.types import Array, Quantity
from jaxoplanet.units import unit_registry as ureg


class LimbDarkLightCurve(NamedTuple):
class LimbDarkLightCurve(eqx.Module):
u: Array

@classmethod
def init(cls, *u: Array) -> "LimbDarkLightCurve":
def __init__(self, *u: Array):
if u:
u = jnp.concatenate([jnp.atleast_1d(u0) for u0 in u], axis=0)
self.u = jnp.concatenate([jnp.atleast_1d(u0) for u0 in u], axis=0)
else:
u = jnp.array([])
return cls(u=u)
self.u = jnp.array([])

@units.quantity_input(t=ureg.d, texp=ureg.s)
def light_curve(
self,
orbit: LightCurveOrbit,
t: Array,
t: Quantity,
*,
texp: Optional[Array] = None,
texp: Optional[Quantity] = None,
oversample: Optional[int] = 7,
texp_order: Optional[int] = 0,
limbdark_order: Optional[int] = 10,
Expand Down Expand Up @@ -100,7 +102,7 @@ def light_curve(
lc_func = partial(light_curve, self.u, order=limbdark_order)
if orbit.shape == ():
b /= r_star
lc = lc_func(b, r)
lc: Array = lc_func(b.magnitude, r.magnitude)
else:
b /= r_star[..., None]
lc = jnp.vectorize(lc_func, signature="(k),()->(k)")(b, r)
Expand Down
6 changes: 1 addition & 5 deletions src/jaxoplanet/orbits/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,2 @@
from jaxoplanet.orbits.keplerian import (
KeplerianBody as KeplerianBody,
KeplerianCentral as KeplerianCentral,
KeplerianOrbit as KeplerianOrbit,
)
from jaxoplanet.orbits import keplerian as keplerian
from jaxoplanet.orbits.transit import TransitOrbit as TransitOrbit
Loading

0 comments on commit 9851977

Please sign in to comment.