diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f12d93bb..be2401ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,3 +19,13 @@ repos: rev: "0.6.1" hooks: - id: nbstripout + # - repo: https://github.com/RobertCraigie/pyright-python + # rev: v1.1.324 + # hooks: + # - id: pyright + # additional_dependencies: + # - nox + # - pytest + # - jax + # - equinox + # - jpu>=0.0.2 diff --git a/pyproject.toml b/pyproject.toml index 516dada2..c3915987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers = [ "Programming Language :: Python :: 3", ] dynamic = ["version"] -dependencies = ["jax", "jaxlib", "jpu"] +dependencies = ["jax", "jaxlib", "jpu>=0.0.2", "equinox"] [project.urls] "Homepage" = "https://jax.exoplanet.codes" diff --git a/src/jaxoplanet/__init__.py b/src/jaxoplanet/__init__.py index 16819df3..4e6aa1bb 100644 --- a/src/jaxoplanet/__init__.py +++ b/src/jaxoplanet/__init__.py @@ -1,2 +1,2 @@ -from jaxoplanet import core as core, orbits as orbits +from jaxoplanet import core as core, orbits as orbits, units as units from jaxoplanet.jaxoplanet_version import __version__ as __version__ diff --git a/src/jaxoplanet/core/py.typed b/src/jaxoplanet/core/py.typed deleted file mode 100644 index e69de29b..00000000 diff --git a/src/jaxoplanet/light_curves.py b/src/jaxoplanet/light_curves.py index 5f617545..db739442 100644 --- a/src/jaxoplanet/light_curves.py +++ b/src/jaxoplanet/light_curves.py @@ -1,30 +1,32 @@ from functools import partial -from typing import NamedTuple, Optional +from typing import Optional +import equinox as eqx import jax.numpy as jnp +from jaxoplanet import units from jaxoplanet.core.limb_dark import light_curve from jaxoplanet.proto import LightCurveOrbit -from jaxoplanet.types import Array +from jaxoplanet.types import Array, Quantity +from jaxoplanet.units import unit_registry as ureg -class LimbDarkLightCurve(NamedTuple): +class LimbDarkLightCurve(eqx.Module): u: Array - @classmethod - def init(cls, *u: Array) -> "LimbDarkLightCurve": + def __init__(self, *u: Array): if u: - u = jnp.concatenate([jnp.atleast_1d(u0) for u0 in u], axis=0) + self.u = jnp.concatenate([jnp.atleast_1d(u0) for u0 in u], axis=0) else: - u = jnp.array([]) - return cls(u=u) + self.u = jnp.array([]) + @units.quantity_input(t=ureg.d, texp=ureg.s) def light_curve( self, orbit: LightCurveOrbit, - t: Array, + t: Quantity, *, - texp: Optional[Array] = None, + texp: Optional[Quantity] = None, oversample: Optional[int] = 7, texp_order: Optional[int] = 0, limbdark_order: Optional[int] = 10, @@ -100,7 +102,7 @@ def light_curve( lc_func = partial(light_curve, self.u, order=limbdark_order) if orbit.shape == (): b /= r_star - lc = lc_func(b, r) + lc: Array = lc_func(b.magnitude, r.magnitude) else: b /= r_star[..., None] lc = jnp.vectorize(lc_func, signature="(k),()->(k)")(b, r) diff --git a/src/jaxoplanet/orbits/__init__.py b/src/jaxoplanet/orbits/__init__.py index f08e226e..83f77eee 100644 --- a/src/jaxoplanet/orbits/__init__.py +++ b/src/jaxoplanet/orbits/__init__.py @@ -1,6 +1,2 @@ -from jaxoplanet.orbits.keplerian import ( - KeplerianBody as KeplerianBody, - KeplerianCentral as KeplerianCentral, - KeplerianOrbit as KeplerianOrbit, -) +from jaxoplanet.orbits import keplerian as keplerian from jaxoplanet.orbits.transit import TransitOrbit as TransitOrbit diff --git a/src/jaxoplanet/orbits/keplerian.py b/src/jaxoplanet/orbits/keplerian.py index 1da35e85..5da74561 100644 --- a/src/jaxoplanet/orbits/keplerian.py +++ b/src/jaxoplanet/orbits/keplerian.py @@ -1,47 +1,35 @@ -from functools import partial -from typing import TYPE_CHECKING, Any, NamedTuple, Optional +from typing import Any, Optional +import equinox as eqx import jax import jax.numpy as jnp +import jpu.numpy as jnpu +from jaxoplanet import units from jaxoplanet.core.kepler import kepler -from jaxoplanet.types import Scalar +from jaxoplanet.types import Quantity +from jaxoplanet.units import unit_registry as ureg -# FIXME: Switch to constants from astropy -GRAVITATIONAL_CONSTANT = 2942.2062175044193 / (4 * jnp.pi**2) -AU_PER_R_SUN = 0.00465046726096215 +class Central(eqx.Module): + mass: Quantity = units.field(units=ureg.M_sun) + radius: Quantity = units.field(units=ureg.R_sun) + density: Quantity = units.field(units=ureg.M_sun / ureg.R_sun**3) -class KeplerianCentral(NamedTuple): - mass: Scalar - radius: Scalar - density: Scalar - - @property - def shape(self) -> tuple[int, ...]: - leaves, _ = jax.tree_util.tree_flatten(self) - shape = leaves[0].shape - if any(leaf.shape != shape for leaf in leaves[1:]): - raise ValueError("Inconsistent shapes for parameters of 'KeplerianCentral'") - return shape - - @classmethod - def init( - cls, + @units.quantity_input( + mass=ureg.M_sun, radius=ureg.R_sun, density=ureg.M_sun / ureg.R_sun**3 + ) + def __init__( + self, *, - mass: Optional[Scalar] = None, - radius: Optional[Scalar] = None, - density: Optional[Scalar] = None, - ) -> "KeplerianCentral": + mass: Optional[Quantity] = None, + radius: Optional[Quantity] = None, + density: Optional[Quantity] = None, + ): if radius is None and mass is None: - radius = 1.0 + radius = 1.0 * ureg.R_sun if density is None: - mass = 1.0 - - if sum(arg is None for arg in (mass, radius, density)) != 1: - raise ValueError( - "Values must be provided for exactly two of mass, radius, and density" - ) + mass = 1.0 * ureg.M_sun # Check that all the input values are scalars; we don't support Scalars # here @@ -51,39 +39,37 @@ def init( raise ValueError("All parameters of a KeplerianCentral must be scalars") # Compute all three parameters based on the input values + error_msg = ( + "Values must be provided for exactly two of mass, radius, and density" + ) if density is None: - if TYPE_CHECKING: - assert mass is not None - assert radius is not None - density = 3 * mass / (4 * jnp.pi * radius**3) + if mass is None or radius is None: + raise ValueError(error_msg) + self.mass = mass + self.radius = radius + self.density = 3 * mass / (4 * jnp.pi * radius**3) elif radius is None: - if TYPE_CHECKING: - assert mass is not None - assert density is not None - radius = (3 * mass / (4 * jnp.pi * density)) ** (1 / 3) + if mass is None or density is None: + raise ValueError(error_msg) + self.mass = mass + self.radius = (3 * mass / (4 * jnp.pi * density)) ** (1 / 3) + self.density = density elif mass is None: - if TYPE_CHECKING: - assert density is not None - assert radius is not None - mass = 4 * jnp.pi * radius**3 * density / 3.0 - - # Convert dtypes to be at least float32 - dtype = jnp.result_type(mass, radius, density, jnp.float32) - mass = jnp.asarray(mass, dtype=dtype) - radius = jnp.asarray(radius, dtype=dtype) - density = jnp.asarray(density, dtype=dtype) - - return cls(mass=mass, radius=radius, density=density) + if radius is None or density is None: + raise ValueError(error_msg) + self.mass = 4 * jnp.pi * radius**3 * density / 3.0 + self.radius = radius + self.density = density @classmethod def from_orbital_properties( cls, *, - period: Scalar, - semimajor: Scalar, - radius: Optional[Scalar] = None, - body_mass: Optional[Scalar] = None, - ) -> "KeplerianCentral": + period: Quantity, + semimajor: Quantity, + radius: Optional[Quantity] = None, + body_mass: Optional[Quantity] = None, + ) -> "Central": if jnp.ndim(semimajor) != 0: raise ValueError( "The 'semimajor' argument to " @@ -92,123 +78,101 @@ def from_orbital_properties( "use 'jax.vmap'" ) - radius = 1.0 if radius is None else radius + radius = 1.0 * ureg.R_sun if radius is None else radius - mass = semimajor**3 / (GRAVITATIONAL_CONSTANT * period**2) + mass = semimajor**3 / (ureg.gravitational_constant * period**2) if body_mass is not None: mass -= body_mass - return cls.init(mass=mass, radius=radius) - - -class KeplerianBody(NamedTuple): - central: KeplerianCentral - time_ref: Scalar - time_transit: Scalar - period: Scalar - semimajor: Scalar - sin_inclination: Scalar - cos_inclination: Scalar - impact_param: Scalar - mass: Scalar - radius: Scalar - eccentricity: Optional[Scalar] - sin_omega_peri: Optional[Scalar] - cos_omega_peri: Optional[Scalar] - sin_asc_node: Optional[Scalar] - cos_asc_node: Optional[Scalar] + return cls(mass=mass, radius=radius) @property def shape(self) -> tuple[int, ...]: - leaves, _ = jax.tree_util.tree_flatten(self) - shape = leaves[0].shape - if any(leaf.shape != shape for leaf in leaves[1:]): - raise ValueError("Inconsistent shapes for parameters of 'KeplerianBody'") - return shape - - @property - def time_peri(self) -> Scalar: - return self.time_transit + self.time_ref - - @property - def total_mass(self) -> Scalar: - return self.central.mass if self.mass is None else self.central.mass + self.mass - - @property - def _baseline_rv_semiamplitude(self) -> Scalar: - k0 = 2 * jnp.pi * self.semimajor / (self.total_mass * self.period) - if self.eccentricity is None: - return k0 - return k0 / jnp.sqrt(1 - self.eccentricity**2) - - @classmethod - def _central_from_orbital_properties( - cls, + return self.mass.shape + + +class Body(eqx.Module): + central: Central + time_ref: Quantity = units.field(units=ureg.d) + time_transit: Quantity = units.field(units=ureg.d) + period: Quantity = units.field(units=ureg.d) + semimajor: Quantity = units.field(units=ureg.R_sun) + sin_inclination: Quantity = units.field(units=ureg.dimensionless) + cos_inclination: Quantity = units.field(units=ureg.dimensionless) + impact_param: Quantity = units.field(units=ureg.dimensionless) + mass: Optional[Quantity] = units.field(units=ureg.M_sun) + radius: Optional[Quantity] = units.field(units=ureg.R_sun) + eccentricity: Optional[Quantity] = units.field(units=ureg.dimensionless) + sin_omega_peri: Optional[Quantity] = units.field(units=ureg.dimensionless) + cos_omega_peri: Optional[Quantity] = units.field(units=ureg.dimensionless) + sin_asc_node: Optional[Quantity] = units.field(units=ureg.dimensionless) + cos_asc_node: Optional[Quantity] = units.field(units=ureg.dimensionless) + radial_velocity_semiamplitude: Optional[Quantity] = units.field( + units=ureg.R_sun / ureg.d + ) + parallax: Optional[Quantity] = units.field(units=ureg.arcsec) + + @units.quantity_input( + time_transit=ureg.d, + time_peri=ureg.d, + period=ureg.d, + semimajor=ureg.R_sun, + inclination=ureg.radian, + impact_param=ureg.dimensionless, + eccentricity=ureg.dimensionless, + omega_peri=ureg.radian, + sin_omega_peri=ureg.dimensionless, + cos_omega_peri=ureg.dimensionless, + asc_node=ureg.radian, + sin_asc_node=ureg.dimensionless, + cos_asc_node=ureg.dimensionless, + mass=ureg.M_sun, + radius=ureg.R_sun, + central_radius=ureg.R_sun, + radial_velocity_semiamplitude=ureg.R_sun / ureg.d, + parallax=ureg.arcsec, + ) + def __init__( + self, + central: Optional[Central] = None, *, - period: Optional[Scalar] = None, - semimajor: Optional[Scalar] = None, - mass: Optional[Scalar] = None, - central: Optional[KeplerianCentral] = None, - central_radius: Optional[Scalar] = None, - ) -> KeplerianCentral: - if period is None and semimajor is None: - raise ValueError("Either a period or a semimajor axis must be provided") - - # When both period and semimajor axis are provided, we set up the - # central with the density implied by these parameters - if period is not None and semimajor is not None: - if central is not None: - raise ValueError( - "Cannot specify both period and semimajor axis when also " - "providing a central body" + time_transit: Optional[Quantity] = None, + time_peri: Optional[Quantity] = None, + period: Optional[Quantity] = None, + semimajor: Optional[Quantity] = None, + inclination: Optional[Quantity] = None, + impact_param: Optional[Quantity] = None, + eccentricity: Optional[Quantity] = None, + omega_peri: Optional[Quantity] = None, + sin_omega_peri: Optional[Quantity] = None, + cos_omega_peri: Optional[Quantity] = None, + asc_node: Optional[Quantity] = None, + sin_asc_node: Optional[Quantity] = None, + cos_asc_node: Optional[Quantity] = None, + mass: Optional[Quantity] = None, + radius: Optional[Quantity] = None, + central_radius: Optional[Quantity] = None, + radial_velocity_semiamplitude: Optional[Quantity] = None, + parallax: Optional[Quantity] = None, + ): + # Handle the special case when passing both `period` and `semimajor`. + # This occurs sometimes when doing transit fits, and we want to fit for + # the "photoeccentric effect". In this case, the central ends up with an + # implied density. + if central is None: + if period is not None and semimajor is not None: + central = Central.from_orbital_properties( + period=period, + semimajor=semimajor, + radius=central_radius, + body_mass=mass, ) - central = KeplerianCentral.from_orbital_properties( - period=period, - semimajor=semimajor, - radius=central_radius, - body_mass=mass, - ) - - return ( - KeplerianCentral.init(radius=central_radius) if central is None else central - ) + semimajor = None + else: + central = Central() + self.central = central - @classmethod - def init( - cls, - *, - time_transit: Optional[Scalar] = None, - time_peri: Optional[Scalar] = None, - period: Optional[Scalar] = None, - semimajor: Optional[Scalar] = None, - inclination: Optional[Scalar] = None, - impact_param: Optional[Scalar] = None, - eccentricity: Optional[Scalar] = None, - omega_peri: Optional[Scalar] = None, - sin_omega_peri: Optional[Scalar] = None, - cos_omega_peri: Optional[Scalar] = None, - asc_node: Optional[Scalar] = None, - sin_asc_node: Optional[Scalar] = None, - cos_asc_node: Optional[Scalar] = None, - mass: Optional[Scalar] = None, - radius: Optional[Scalar] = None, - central: Optional[KeplerianCentral] = None, - central_radius: Optional[Scalar] = None, - ) -> "KeplerianBody": - central = cls._central_from_orbital_properties( - period=period, - semimajor=semimajor, - mass=mass, - central=central, - central_radius=central_radius, - ) - if central.shape != (): - raise ValueError( - "The central body for a 'KeplerianBody' must have scalar " - "parameters; for multi-planet systems, use 'jax.vmap'" - ) - - # Check the input arguments + # Check that all the input arguments have the right shape provided_input_arguments = [ arg for arg in ( @@ -228,25 +192,41 @@ def init( mass, radius, central_radius, + radial_velocity_semiamplitude, + parallax, ) if arg is not None ] - dtype = jnp.result_type(*provided_input_arguments, jnp.float32) if any(jnp.ndim(arg) != 0 for arg in provided_input_arguments): raise ValueError( - "All input arguments to 'KeplerianBody.init' must be scalars; " - "for multi-planet systems, use 'jax.vmap'" + "All input arguments to 'Body' must be scalars; " + "for multi-planet systems, use a 'System'" ) + # Save the input mass and radius + self.radius = radius + self.mass = mass + self.radial_velocity_semiamplitude = radial_velocity_semiamplitude + self.parallax = parallax + # Work out the period and semimajor axis to be consistent - total_mass = central.mass + mass if mass is not None else central.mass - mass_factor = GRAVITATIONAL_CONSTANT * total_mass + mass_factor = ureg.gravitational_constant * self.total_mass if semimajor is None: - if TYPE_CHECKING: - assert period is not None - semimajor = (mass_factor * period**2) ** (1.0 / 3) + if period is None: + raise ValueError( + "Either `period` or `semimajor` must be specified when constructing " + "a Keplerian 'Body'" + ) + self.semimajor = jnpu.cbrt(mass_factor * period**2 / (4 * jnp.pi**2)) + self.period = period elif period is None: - period = semimajor**1.5 / jnp.sqrt(mass_factor) + self.semimajor = semimajor + self.period = 2 * jnp.pi * semimajor * jnpu.sqrt(semimajor / mass_factor) + else: + raise ValueError( + "`period` or `semimajor` cannot both be specified when constructing " + "a Keplerian 'Body'" + ) # Handle treatment and normalization of angles if omega_peri is not None: @@ -254,154 +234,106 @@ def init( raise ValueError( "Cannot specify both omega_peri and sin_omega_peri or cos_omega_peri" ) - sin_omega_peri = jnp.sin(omega_peri) - cos_omega_peri = jnp.cos(omega_peri) + self.sin_omega_peri = jnpu.sin(omega_peri) + self.cos_omega_peri = jnpu.cos(omega_peri) elif (sin_omega_peri is None) != (cos_omega_peri is None): raise ValueError("Must specify both sin_omega_peri and cos_omega_peri") + else: + self.sin_omega_peri = sin_omega_peri + self.cos_omega_peri = cos_omega_peri if asc_node is not None: if sin_asc_node is not None or cos_asc_node is not None: raise ValueError( "Cannot specify both asc_node and sin_asc_node or cos_asc_node" ) - sin_asc_node = jnp.sin(asc_node) - cos_asc_node = jnp.cos(asc_node) + self.sin_asc_node = jnpu.sin(asc_node) + self.cos_asc_node = jnpu.cos(asc_node) elif (sin_asc_node is None) != (cos_asc_node is None): raise ValueError("Must specify both sin_asc_node and cos_asc_node") + else: + self.sin_asc_node = sin_asc_node + self.cos_asc_node = cos_asc_node # Handle eccentric and circular orbits + self.eccentricity = eccentricity if eccentricity is None: if sin_omega_peri is not None: raise ValueError("Cannot specify omega_peri without eccentricity") - M0 = jnp.full_like(period, 0.5 * jnp.pi) + M0 = jnpu.full_like(self.period, 0.5 * jnp.pi) incl_factor = 1 else: - if sin_omega_peri is None: + if self.sin_omega_peri is None: raise ValueError("Must specify omega_peri for eccentric orbits") - opsw = 1 + sin_omega_peri - E0 = 2 * jnp.arctan2( - jnp.sqrt(1 - eccentricity) * cos_omega_peri, - jnp.sqrt(1 + eccentricity) * opsw, + opsw = 1 + self.sin_omega_peri + E0 = 2 * jnpu.arctan2( + jnpu.sqrt(1 - eccentricity) * self.cos_omega_peri, + jnpu.sqrt(1 + eccentricity) * opsw, ) - M0 = E0 - eccentricity * jnp.sin(E0) + M0 = E0 - eccentricity * jnpu.sin(E0) ome2 = 1 - eccentricity**2 - incl_factor = (1 + eccentricity * sin_omega_peri) / ome2 + incl_factor = (1 + eccentricity * self.sin_omega_peri) / ome2 # Handle inclined orbits - dcosidb = incl_factor * central.radius / semimajor + dcosidb = incl_factor * central.radius / self.semimajor if impact_param is not None: if inclination is not None: raise ValueError("Cannot specify both inclination and impact_param") - cos_inclination = dcosidb * impact_param - sin_inclination = jnp.sqrt(1 - cos_inclination**2) + self.impact_param = impact_param + self.cos_inclination = dcosidb * impact_param + self.sin_inclination = jnpu.sqrt(1 - self.cos_inclination**2) elif inclination is not None: - cos_inclination = jnp.cos(inclination) - sin_inclination = jnp.sin(inclination) - impact_param = cos_inclination / dcosidb + self.cos_inclination = jnpu.cos(inclination) + self.sin_inclination = jnpu.sin(inclination) + self.impact_param = self.cos_inclination / dcosidb else: - impact_param = jnp.zeros_like(period) - cos_inclination = jnp.zeros_like(period) - sin_inclination = jnp.ones_like(period) + self.impact_param = jnpu.zeros_like(self.period) + self.cos_inclination = jnpu.zeros_like(self.period) + self.sin_inclination = jnpu.ones_like(self.period) # Work out all the relevant reference times - time_ref = -M0 * period / (2 * jnp.pi) + self.time_ref = -M0 * self.period / (2 * jnp.pi) if time_transit is not None and time_peri is not None: raise ValueError("Cannot specify both time_transit or time_peri") elif time_transit is not None: - time_peri = time_transit + time_ref + self.time_transit = time_transit elif time_peri is not None: - time_transit = time_peri - time_ref + self.time_transit = time_peri - self.time_ref else: - time_transit = jnp.zeros_like(period) - time_peri = time_ref - - return cls( - central=central, - time_ref=jnp.asarray(time_ref, dtype=dtype), - time_transit=jnp.asarray(time_transit, dtype=dtype), - period=jnp.asarray(period, dtype=dtype), - semimajor=jnp.asarray(semimajor, dtype=dtype), - sin_inclination=jnp.asarray(sin_inclination, dtype=dtype), - cos_inclination=jnp.asarray(cos_inclination, dtype=dtype), - impact_param=jnp.asarray(impact_param, dtype=dtype), - mass=jnp.asarray(0, dtype=dtype) - if mass is None - else jnp.asarray(mass, dtype=dtype), - radius=jnp.asarray(0, dtype=dtype) - if radius is None - else jnp.asarray(radius, dtype=dtype), - eccentricity=None - if eccentricity is None - else jnp.asarray(eccentricity, dtype=dtype), - sin_omega_peri=None - if sin_omega_peri is None - else jnp.asarray(sin_omega_peri, dtype=dtype), - cos_omega_peri=None - if cos_omega_peri is None - else jnp.asarray(cos_omega_peri, dtype=dtype), - sin_asc_node=None - if sin_asc_node is None - else jnp.asarray(sin_asc_node, dtype=dtype), - cos_asc_node=None - if cos_asc_node is None - else jnp.asarray(cos_asc_node, dtype=dtype), - ) + self.time_transit = jnpu.zeros_like(self.period) - @classmethod - def circular_from_duration( - cls, - *, - period: Scalar, - duration: Scalar, - impact_param: Scalar, - radius: Scalar, - mass: Optional[Scalar] = None, - central: Optional[KeplerianCentral] = None, - central_radius: Optional[KeplerianCentral] = None, - **kwargs: Any, - ) -> "KeplerianBody": - central = cls._central_from_orbital_properties( - period=period, - mass=mass, - central=central, - central_radius=central_radius, - ) + @property + def shape(self) -> tuple[int, ...]: + return self.period.shape - # Get the semimajor axis implied by the duration, period, and impact - # parameter - b2 = jnp.square(impact_param) - opk2 = jnp.square(1 + radius / central.radius) - phi = jnp.pi * duration / period - sinp = jnp.sin(phi) - cosp = jnp.cos(phi) - semimajor = jnp.sqrt(opk2 - b2 * cosp**2) / sinp - - central = KeplerianCentral.from_orbital_properties( - period=period, - semimajor=semimajor, - radius=central_radius, - body_mass=mass, - ) + @property + def central_radius(self) -> Quantity: + return self.central.radius - return cls.init( - period=period, - impact_param=impact_param, - radius=radius, - mass=mass, - central=central, - **kwargs, - ) + @property + def time_peri(self) -> Quantity: + return self.time_transit + self.time_ref # type: ignore @property - def central_radius(self) -> Scalar: - return self.central.radius + def inclination(self) -> Quantity: + return jnpu.arctan2(self.sin_inclination, self.cos_inclination) + + @property + def omega_peri(self) -> Optional[Quantity]: + if self.eccentricity is None: + return None + return jnpu.arctan2(self.sin_omega_peri, self.cos_omega_peri) + + @property + def total_mass(self) -> Quantity: + return self.central.mass if self.mass is None else self.mass + self.central.mass def position( - self, t: Scalar, parallax: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: """This body's position in the barycentric frame Args: @@ -417,8 +349,8 @@ def position( )[0] def central_position( - self, t: Scalar, parallax: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: """The central's position in the barycentric frame Args: @@ -434,8 +366,8 @@ def central_position( )[0] def relative_position( - self, t: Scalar, parallax: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: """This body's position relative to the central in the X,Y,Z frame Args: @@ -446,12 +378,12 @@ def relative_position( ``R_sun``. """ return self._get_position_and_velocity( - t, semimajor=-self.semimajor, parallax=parallax + t, semimajor=-self.semimajor, parallax=parallax # type: ignore )[0] def relative_angles( - self, t: Scalar, parallax: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar]: + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity]: """This body's relative position to the central in the sky plane, in separation, position angle coordinates @@ -463,13 +395,13 @@ def relative_angles( east of north) of the planet relative to the star. """ X, Y, _ = self.relative_position(t, parallax=parallax)[0] - rho = jnp.sqrt(X**2 + Y**2) - theta = jnp.arctan2(Y, X) - return (rho, theta) + rho = jnpu.sqrt(X**2 + Y**2) + theta = jnpu.arctan2(Y, X) + return rho, theta def velocity( - self, t: Scalar, semiamplitude: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: + self, t: Quantity, semiamplitude: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: """This body's velocity in the barycentric frame Args: @@ -484,14 +416,14 @@ def velocity( ``M_sun/day``. """ if semiamplitude is None: - mass = -self.central.mass + mass: Quantity = -self.central.mass # type: ignore return self._get_position_and_velocity(t, mass=mass)[1] k = -semiamplitude * self.central.mass / self.mass return self._get_position_and_velocity(t, semiamplitude=k)[1] def central_velocity( - self, t: Scalar, semiamplitude: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: + self, t: Quantity, semiamplitude: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: """The central's velocity in the barycentric frame Args: @@ -511,8 +443,8 @@ def central_velocity( return v def relative_velocity( - self, t: Scalar, semiamplitude: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: + self, t: Quantity, semiamplitude: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: """This body's velocity relative to the central Args: @@ -527,14 +459,15 @@ def relative_velocity( ``M_sun/day``. """ if semiamplitude is None: - return self._get_position_and_velocity(t, mass=-self.total_mass)[1] + mass: Quantity = -self.total_mass # type: ignore + return self._get_position_and_velocity(t, mass=mass)[1] k = -semiamplitude * self.total_mass / self.mass _, v = self._get_position_and_velocity(t, semiamplitude=k) return v def radial_velocity( - self, t: Scalar, semiamplitude: Optional[Scalar] = None - ) -> Scalar: + self, t: Quantity, semiamplitude: Optional[Quantity] = None + ) -> Quantity: """Get the radial velocity of the central .. note:: The convention is that positive `z` points *towards* the @@ -551,20 +484,21 @@ def radial_velocity( Returns: The reflex radial velocity evaluated at ``t``. """ - return -self.central_velocity(t, semiamplitude=semiamplitude)[2] + return -self.central_velocity(t, semiamplitude=semiamplitude)[2] # type: ignore - def _warp_times(self, t: Scalar) -> Scalar: - return t - self.time_transit + def _warp_times(self, t: Quantity) -> Quantity: + return t - self.time_transit # type: ignore - def _get_true_anomaly(self, t: Scalar) -> tuple[Scalar, Scalar]: + @units.quantity_input(t=ureg.d) + def _get_true_anomaly(self, t: Quantity) -> tuple[Quantity, Quantity]: M = 2 * jnp.pi * (self._warp_times(t) - self.time_ref) / self.period if self.eccentricity is None: - return jnp.sin(M), jnp.cos(M) - return kepler(M, self.eccentricity) + return jnpu.sin(M), jnpu.cos(M) + return kepler(M.magnitude, self.eccentricity.magnitude) def _rotate_vector( - self, x: Scalar, y: Scalar, *, include_inclination: bool = True - ) -> tuple[Scalar, Scalar, Scalar]: + self, x: Quantity, y: Quantity, *, include_inclination: bool = True + ) -> tuple[Quantity, Quantity, Quantity]: """Apply the rotation matrices to go from orbit to observer frame In order, @@ -603,33 +537,62 @@ def _rotate_vector( # 3) rotate about z2 axis by Omega if self.cos_asc_node is None: - return (x2, y2, Z) + return x2, y2, Z # type: ignore X = self.cos_asc_node * x2 - self.sin_asc_node * y2 Y = self.sin_asc_node * x2 + self.cos_asc_node * y2 - return X, Y, Z - + return X, Y, Z # type: ignore + + @units.quantity_input( + t=ureg.d, + semimajor=ureg.R_sun, + mass=ureg.M_sun, + parallax=ureg.arcsec, + ) def _get_position_and_velocity( self, - t: Scalar, - semimajor: Optional[Scalar] = None, - mass: Optional[Scalar] = None, - semiamplitude: Optional[Scalar] = None, - parallax: Optional[Scalar] = None, - ) -> tuple[tuple[Scalar, Scalar, Scalar], tuple[Scalar, Scalar, Scalar]]: - sinf, cosf = self._get_true_anomaly(t) + t: Quantity, + semimajor: Optional[Quantity] = None, + mass: Optional[Quantity] = None, + semiamplitude: Optional[Quantity] = None, + parallax: Optional[Quantity] = None, + ) -> tuple[ + tuple[Quantity, Quantity, Quantity], tuple[Quantity, Quantity, Quantity] + ]: + if self.shape != (): + raise ValueError( + "Cannot evaluate the position or velocity of a Keplerian 'Body' " + "with multiple planets. Use 'jax.vmap' instead." + ) + + if semiamplitude is None: + semiamplitude = self.radial_velocity_semiamplitude + + if parallax is None: + parallax = self.parallax if semiamplitude is None: - factor = 1 if mass is None else mass - factor *= 1 if parallax is None else parallax * AU_PER_R_SUN - k0 = factor * self._baseline_rv_semiamplitude + if self.radial_velocity_semiamplitude is None: + m = 1 * ureg.M_sun if mass is None else mass + k0 = 2 * jnp.pi * self.semimajor * m / (self.total_mass * self.period) + if self.eccentricity is not None: + k0 /= jnpu.sqrt(1 - self.eccentricity**2) + else: + k0 = self.radial_velocity_semiamplitude + + if parallax is not None: + k0 = k0.to(ureg.au / ureg.d).magnitude * parallax / ureg.day else: k0 = semiamplitude r0 = 1 if semimajor is not None: - r0 *= semimajor if parallax is None else semimajor * parallax * AU_PER_R_SUN + if parallax is None: + r0 = semimajor + else: + r0 = semimajor.to(ureg.au).magnitude * parallax + sinf, cosf = self._get_true_anomaly(t) if self.eccentricity is None: v1, v2 = -k0 * sinf, k0 * cosf else: @@ -641,71 +604,90 @@ def _get_position_and_velocity( v1, v2, include_inclination=semiamplitude is None ) - return (x, y, z), (vx, vy, vz) + # In the case a semiamplitude was passed without units, we strip the + # units of the output + if semiamplitude is not None and not ( + hasattr(semiamplitude, "_magnitude") and hasattr(semiamplitude, "_units") + ): + vx = vx.magnitude + vy = vy.magnitude + vz = vz.magnitude + return (x, y, z), (vx, vy, vz) -class KeplerianOrbit(NamedTuple): - bodies: KeplerianBody - def __len__(self) -> int: - return len(self.bodies.period) +class System(eqx.Module): + central: Central + bodies: tuple[Body, ...] + _body_stack: Optional[Body] + + def __init__( + self, central: Optional[Central] = None, *, bodies: tuple[Body, ...] = () + ): + self.central = Central() if central is None else central + self.bodies = bodies + + # If all the bodies have matching Pytree structure then we save a + # stacked version that we can use for vmaps below. This allows for more + # efficient evaluations in the case of multiple bodies. + self._body_stack = None + if len(bodies): + spec = list(map(jax.tree_util.tree_structure, bodies)) + if spec.count(spec[0]) == len(spec): + self._body_stack = jax.tree_util.tree_map( + lambda *x: jnp.stack(x, axis=0), *bodies + ) @property def shape(self) -> tuple[int, ...]: - return self.bodies.shape + return (len(self.bodies),) @property - def radius(self) -> Scalar: - return self.bodies.radius + def radius(self) -> Quantity: + return jax.tree_util.tree_map( + lambda *x: jnp.stack(x, axis=0), + *[body.radius for body in self.bodies], + ) @property - def central_radius(self) -> Scalar: - return self.bodies.central.radius - - @classmethod - def init( - cls, central: Optional[KeplerianCentral] = None, **body_args: Scalar - ) -> "KeplerianOrbit": - names, values = zip(*body_args.items()) - values = jnp.broadcast_arrays(*(jnp.atleast_1d(x) for x in values)) - bodies = jax.vmap(partial(KeplerianBody.init, central=central))( - **dict(zip(names, values)) + def central_radius(self) -> Quantity: + return jax.tree_util.tree_map( + lambda *x: jnp.stack(x, axis=0), + *[body.central_radius for body in self.bodies], ) - return cls(bodies=bodies) - def position( - self, t: Scalar, parallax: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: - return jax.vmap(partial(KeplerianBody.position, t=t, parallax=parallax))( - self.bodies + def add_body(self, body: Optional[Body] = None, **kwargs: Any) -> "System": + if body is None: + body = Body(self.central, **kwargs) + return System(central=self.central, bodies=self.bodies + (body,)) + + def _body_vmap(self, func_name: str, t: Quantity) -> Any: + if self._body_stack is not None: + return jax.vmap(getattr(Body, func_name), in_axes=(0, None))( + self._body_stack, t + ) + return jax.tree_util.tree_map( + lambda *x: jnp.stack(x, axis=0), + *[getattr(body, func_name)(t) for body in self.bodies], ) - def central_position( - self, t: Scalar, parallax: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: - return jax.vmap( - partial(KeplerianBody.central_position, t=t, parallax=parallax) - )(self.bodies) + def position(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: + return self._body_vmap("position", t) - def relative_position( - self, t: Scalar, parallax: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: - return jax.vmap( - partial(KeplerianBody.relative_position, t=t, parallax=parallax) - )(self.bodies) + def central_position(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: + return self._body_vmap("central_position", t) - def velocity(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]: - return jax.vmap(partial(KeplerianBody.velocity, t=t))(self.bodies) + def relative_position(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: + return self._body_vmap("relative_position", t) - def central_velocity(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]: - return jax.vmap(partial(KeplerianBody.central_velocity, t=t))(self.bodies) + def velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: + return self._body_vmap("velocity", t) - def relative_velocity(self, t: Scalar) -> tuple[Scalar, Scalar, Scalar]: - return jax.vmap(partial(KeplerianBody.relative_velocity, t=t))(self.bodies) + def central_velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: + return self._body_vmap("central_velocity", t) - def radial_velocity( - self, t: Scalar, semiamplitude: Optional[Scalar] = None - ) -> tuple[Scalar, Scalar, Scalar]: - return jax.vmap(partial(KeplerianBody.radial_velocity, t=t))( - self.bodies, semiamplitude=semiamplitude - ) + def relative_velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: + return self._body_vmap("relative_velocity", t) + + def radial_velocity(self, t: Quantity) -> Quantity: + return self._body_vmap("radial_velocity", t) diff --git a/src/jaxoplanet/orbits/transit.py b/src/jaxoplanet/orbits/transit.py index 4008228c..3f956b55 100644 --- a/src/jaxoplanet/orbits/transit.py +++ b/src/jaxoplanet/orbits/transit.py @@ -1,104 +1,93 @@ -from typing import NamedTuple, Optional +from typing import Optional -import jax +import equinox as eqx import jax.numpy as jnp +import jpu.numpy as jnpu -from jaxoplanet.types import Array +from jaxoplanet import units +# from jaxoplanet.types import Array +from jaxoplanet.types import Quantity +from jaxoplanet.units import unit_registry as ureg -class TransitOrbit(NamedTuple): - period: Array - speed: Array - duration: Array - time_transit: Array - impact_param: Array - radius: Array - @property - def shape(self) -> tuple[int, ...]: - return self.period.shape +class TransitOrbit(eqx.Module): + period: Quantity = units.field(units=ureg.d) + speed: Quantity = units.field(units=1 / ureg.d) + duration: Quantity = units.field(units=ureg.d) + time_transit: Quantity = units.field(units=ureg.d) + impact_param: Quantity = units.field(units=ureg.dimensionless) + radius: Quantity = units.field(units=ureg.dimensionless) - @classmethod - def init( - cls, + @units.quantity_input( + period=ureg.d, + duration=ureg.d, + speed=1 / ureg.d, + time_transit=ureg.d, + impact_param=ureg.dimensionless, + radius=ureg.dimensionless, + ) + def __init__( + self, *, - period: Array, - duration: Optional[Array] = None, - speed: Optional[Array] = None, - time_transit: Optional[Array] = None, - impact_param: Optional[Array] = None, - radius: Optional[Array] = None, - ) -> "TransitOrbit": + period: Quantity, + duration: Optional[Quantity] = None, + speed: Optional[Quantity] = None, + time_transit: Optional[Quantity] = None, + impact_param: Optional[Quantity] = None, + radius: Optional[Quantity] = None, + ): if duration is None: if speed is None: raise ValueError("Either 'speed' or 'duration' must be provided") - period, speed = jnp.broadcast_arrays( - jnp.atleast_1d(period), jnp.atleast_1d(speed) - ) + self.period = period + self.speed = speed + else: + self.period = period + self.duration = duration + + if time_transit is None: + self.time_transit = 0.0 * ureg.d else: - period, duration = jnp.broadcast_arrays( - jnp.atleast_1d(period), jnp.atleast_1d(duration) - ) - - shape = period.shape - time_transit = ( - jnp.zeros_like(period) - if time_transit is None - else jnp.broadcast_to(time_transit, shape) - ) - impact_param = ( - jnp.zeros_like(period) - if impact_param is None - else jnp.broadcast_to(impact_param, shape) - ) - radius = ( - jnp.zeros_like(period) - if radius is None - else jnp.broadcast_to(radius, shape) - ) - - x2 = (1 + radius) ** 2 - impact_param**2 + self.time_transit = time_transit + + if impact_param is None: + self.impact_param = 0.0 * ureg.dimensionless + else: + self.impact_param = impact_param + + if radius is None: + self.radius = 0.0 * ureg.dimensionless + else: + self.radius = radius + + x2 = jnpu.square(1 + self.radius) - jnpu.square(self.impact_param) if duration is None: - assert speed is not None - duration = 2 * jnp.sqrt(jnp.maximum(0, x2)) / speed + self.duration = 2 * jnpu.sqrt(jnpu.maximum(0, x2)) / self.speed else: - speed = 2 * jnp.sqrt(jnp.maximum(0, x2)) / duration + self.speed = 2 * jnpu.sqrt(jnpu.maximum(0, x2)) / self.duration - return cls( - period=period, - speed=speed, - duration=duration, - time_transit=time_transit, - impact_param=impact_param, - radius=radius, - ) + @property + def shape(self) -> tuple[int, ...]: + return jnp.shape(self.period) @property - def central_radius(self) -> Array: - return jnp.ones_like(self.period) + def central_radius(self) -> Quantity: + return jnp.ones_like(self.period) * ureg.dimensionless + @units.quantity_input(t=ureg.d, parallax=ureg.arcsec) def relative_position( - self, t: Array, parallax: Optional[Array] = None - ) -> tuple[Array, Array, Array]: + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: del parallax - def impl( - period: Array, - speed: Array, - duration: Array, - time_transit: Array, - impact_param: Array, - radius: Array, - ) -> tuple[Array, Array, Array]: - del radius - half_period = 0.5 * period - ref_time = time_transit - half_period - dt = (t - ref_time) % period - half_period - - x = speed * dt - y = jnp.broadcast_to(impact_param, dt.shape) - m = jnp.abs(dt) < 0.5 * duration - z = m * 1.0 - (~m) * 1.0 - return x, y, z - - return jax.vmap(impl)(*self) + half_period = 0.5 * self.period + ref_time = self.time_transit - half_period + dt = jnpu.mod(t - ref_time, self.period) - half_period + + x = self.speed * dt + y = jnpu.full_like(dt, self.impact_param) + m = jnpu.fabs(dt) < 0.5 * self.duration + z = (m * 1.0 - (~m) * 1.0) * ureg.dimensionless + + return x, y, z diff --git a/src/jaxoplanet/proto.py b/src/jaxoplanet/proto.py index 3d4a8c58..4ad991a2 100644 --- a/src/jaxoplanet/proto.py +++ b/src/jaxoplanet/proto.py @@ -1,8 +1,6 @@ -from typing import Optional +from typing import Optional, Protocol -from typing_extensions import Protocol - -from jaxoplanet.types import Array +from jaxoplanet.types import Quantity class LightCurveBody(Protocol): @@ -11,42 +9,65 @@ def shape(self) -> tuple[int, ...]: ... @property - def radius(self) -> Array: + def radius(self) -> Quantity: ... -class LightCurveOrbit(LightCurveBody): - def relative_position( - self, t: Array, parallax: Optional[Array] = None - ) -> tuple[Array, Array, Array]: +class LightCurveOrbit(Protocol): + @property + def shape(self) -> tuple[int, ...]: + ... + + @property + def radius(self) -> Quantity: + ... + + def relative_position(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: + ... + + @property + def central_radius(self) -> Quantity: + ... + + +class Orbit(Protocol): + @property + def shape(self) -> tuple[int, ...]: ... @property - def central_radius(self) -> Array: + def radius(self) -> Quantity: ... + def relative_position( + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: + ... + + @property + def central_radius(self) -> Quantity: + ... -class Orbit(LightCurveOrbit): def position( - self, t: Array, parallax: Optional[Array] = None - ) -> tuple[Array, Array, Array]: + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: ... def central_position( - self, t: Array, parallax: Optional[Array] = None - ) -> tuple[Array, Array, Array]: + self, t: Quantity, parallax: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: ... - def velocity(self, t: Array) -> tuple[Array, Array, Array]: + def velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... - def central_velocity(self, t: Array) -> tuple[Array, Array, Array]: + def central_velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... - def relative_velocity(self, t: Array) -> tuple[Array, Array, Array]: + def relative_velocity(self, t: Quantity) -> tuple[Quantity, Quantity, Quantity]: ... def radial_velocity( - self, t: Array, semiamplitude: Optional[Array] = None - ) -> tuple[Array, Array, Array]: + self, t: Quantity, semiamplitude: Optional[Quantity] = None + ) -> tuple[Quantity, Quantity, Quantity]: ... diff --git a/src/jaxoplanet/test_utils.py b/src/jaxoplanet/test_utils.py index b4f23086..c5391cbd 100644 --- a/src/jaxoplanet/test_utils.py +++ b/src/jaxoplanet/test_utils.py @@ -1,11 +1,25 @@ -import jax.numpy as jnp -import numpy as np +from jax._src.public_test_util import check_close def assert_allclose(calculated, expected, *args, **kwargs): - dtype = jnp.result_type(calculated, expected) - if dtype == jnp.float64: - kwargs["atol"] = kwargs.get("atol", 2e-5) + kwargs["rtol"] = kwargs.get( + "rtol", + { + "float32": 5e-4, + "float64": 5e-7, + }, + ) + check_close(calculated, expected, *args, **kwargs) + + +def assert_quantity_allclose(calculated, expected, *args, convert=False, **kwargs): + if convert: + assert_allclose( + calculated.magnitude, + expected.to(calculated.units).magnitude, + *args, + **kwargs, + ) else: - kwargs["atol"] = kwargs.get("atol", 2e-2) - np.testing.assert_allclose(calculated, expected, *args, **kwargs) + assert calculated.units == expected.units + assert_allclose(calculated.magnitude, expected.magnitude, *args, **kwargs) diff --git a/src/jaxoplanet/types.py b/src/jaxoplanet/types.py index dd69190c..0b1899a0 100644 --- a/src/jaxoplanet/types.py +++ b/src/jaxoplanet/types.py @@ -3,3 +3,4 @@ Scalar = Any Array = Any PyTree = Any +Quantity = Any diff --git a/src/jaxoplanet/units/__init__.py b/src/jaxoplanet/units/__init__.py new file mode 100644 index 00000000..83035fec --- /dev/null +++ b/src/jaxoplanet/units/__init__.py @@ -0,0 +1,3 @@ +from jaxoplanet.units.decorator import quantity_input as quantity_input +from jaxoplanet.units.field import field as field +from jaxoplanet.units.registry import unit_registry as unit_registry diff --git a/src/jaxoplanet/units/astro_constants_and_units.txt b/src/jaxoplanet/units/astro_constants_and_units.txt new file mode 100644 index 00000000..efe638c7 --- /dev/null +++ b/src/jaxoplanet/units/astro_constants_and_units.txt @@ -0,0 +1,21 @@ +# Constants from astropy.constants for consistency + +# IAU astronomical constants +AU = 1.49597870700e11 m = au = astronomical_unit +gravitational_constant = 6.67430e-11 m^3/kg/s^2 = G +parsec = astronomical_unit / tansec = pc +L_bol0 = 3.0128e28 W = _ + +# Solar constants +L_sun = 3.828e26 W = L_sun = L_solar = solar_luminosity +GM_sun = 1.3271244e20 m^3/s^2 = GM_sun = GM_solar = solar_mass_parameter +M_sun = solar_mass_parameter / gravitational_constant = M_sun = M_solar = solar_mass +R_sun = 6.957e8 m = R_sun = R_solar = solar_radius + +# Solar system constants +GM_jupiter = 1.2668653e17 m^3/s^2 = GM_jup = jupiter_mass_parameter +M_jupiter = jupiter_mass_parameter / gravitational_constant = M_jup = jupiter_mass +R_jupiter = 7.1492e7 m = R_jup = jupiter_radius +GM_earth = 3.986004e14 m^3/s^2 = GM_earth = earth_mass_parameter +M_earth = earth_mass_parameter / gravitational_constant = M_earth = earth_mass +R_earth = 6.3781e6 m = R_earth = earth_radius diff --git a/src/jaxoplanet/units/decorator.py b/src/jaxoplanet/units/decorator.py new file mode 100644 index 00000000..a9b8b646 --- /dev/null +++ b/src/jaxoplanet/units/decorator.py @@ -0,0 +1,201 @@ +import inspect +from functools import partial, wraps +from typing import Any, Callable, Optional + +import jax +from pint import DimensionalityError + +from jaxoplanet.units.registry import unit_registry + + +def quantity_input( + func: Optional[Callable[..., Any]] = None, + *, + _strict: bool = False, + **kwargs: Any, +) -> Any: + """A decorator to wrap functions that require quantities as inputs + + Please note, this is similar to the decorator of the same name from + ``astropy.units``, but the behavior is slightly different, in ways that + we'll try to highlight here. + + This decorator will wrap a function and check or convert the units of all + inputs, such that the wrapped function can assume that all input units are + correct. Note that all arguments must be specified by name (even when they) + are positional, and variable ``*args`` and ``**kwargs`` arguments are not + supported. + + By default, if a non-``Quantity`` is provided, it will be assumed to be in + the correct units, and converted to a ``Quantity``. If the ``_strict`` flag + is instead set to ``True``, inputting a non-``Quantity`` will raise a + ``ValueError``. + + Examples + -------- + + The following function expects a length in meters and a time in seconds, and + it returns a speed in meters per second: + + .. code-block:: python + + from jaxoplanet.units import unit_registry as ureg + + @units.quantity_input(a=ureg.m, b=ureg.s) + def speed(a, b): + return a / b + + If we call this function with a length and a time, it will work as expected: + + .. code-block:: python + + speed(1.5 * ureg.m, 0.5 * ureg.s) + + And it will also handle unit conversions: + + .. code-block:: python + + speed(1.5 * ureg.AU, 0.5 * ureg.day) # The result will still be in m/s + + To skip validating specific inputs, you can set the unit to ``None``, or + omit it from the decorator arguments: + + .. code-block:: python + + @units.quantity_input(x=ureg.m) # optionally include flag=None + def condition(x, flag): + if flag: + return x + else: + return 0.0 * x + + JAX Pytree support + ------------------ + + This decorator also supports JAX Pytrees, so you can use it to wrap + functions with structured inputs. For example, we could rewrite the + ``speed`` example from above as: + + .. code-block:: python + + @units.quantity_input(params={"distance": ureg.m, "time": ureg.s}) + def speed(params): + return params["distance"] / params["time"] + + This will work with arbitrary Pytrees, as long as structure of the input + fully matches the decorator argument. In other words, since the full Pytree + structure must be specified, you'll need to explicitly list ``None`` for any + Pytree nodes that you want to skip during validation: + + .. code-block:: python + + # Omitting `flag` from the decorator wouldn't work here + @units.quantity_input(params={"x": ureg.m, "flag": None}) + def condition(params): + if params["flag"]: + return params["x"] + else: + return 0.0 * params["x"] + """ + if func is None: + return QuantityInput(_strict=_strict, **kwargs) + else: + if not callable(func): + raise TypeError( + "The first argument to 'quantity_input' must be a callable function, " + "and all unit specifications must be passed as keyword arguments " + "by name" + ) + return QuantityInput(_strict=_strict, **kwargs)(func) + + +class QuantityInput: + """This helper class implements the logic for ``quantity_input`` + + Typically users should expect to interact primarily with the + ``quantity_input`` decorator directly instead of this class, but this + enables the use of ``quantity_input`` either as a decorator with arguments + or as a single function call. + """ + + def __init__( + self, + *, + _strict: bool = False, + **kwargs: Any, + ): + self.decorator_kwargs = kwargs + self.strict = _strict + + def __call__(self, func: Callable) -> Callable: + signature = inspect.signature(func) + bound_input_units = signature.bind_partial(**self.decorator_kwargs) + + input_unit_map = {} + for param in signature.parameters.values(): + if param.kind in ( + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + ): + # We typically ignore *args and **kwargs, but if they were + # explicitly specified, we raise an error since we don't support + # that use case + if param.name in bound_input_units.arguments: + raise TypeError( + "Units for general variable arguments and keyword " + "arguments are not supported" + ) + else: + input_unit_map[param.name] = bound_input_units.arguments.get( + param.name, None + ) + + @wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + bound_args = signature.bind(*args, **kwargs) + bound_args.apply_defaults() + + for name, value in bound_args.arguments.items(): + # If the value and the default value is None, pass through + # ignoring units + if value is None and signature.parameters[name].default is None: + continue + + unit = input_unit_map.get(name, None) + if unit is not None: + bound_args.arguments[name] = jax.tree_util.tree_map( + partial(_apply_units, name=name, strict=self.strict), + value, + unit, + is_leaf=_is_quantity, + ) + + return func(*bound_args.args, **bound_args.kwargs) + + return wrapped + + +def _apply_units( + value: Any, units: Any, strict: bool = False, name: Optional[str] = None +) -> Any: + if units is None: + return value + if _is_quantity(value): + try: + return value.to(units) + except DimensionalityError as e: + raise DimensionalityError( + e.units1, + e.units2, + e.dim1, + e.dim2, + "" if name is None else f" for input '{name}'", + ) from None + elif strict: + raise ValueError("Arguments must be quantities for strict parsing") + else: + return unit_registry.Quantity(value, units) + + +def _is_quantity(x: Any) -> bool: + return hasattr(x, "_magnitude") and hasattr(x, "_units") diff --git a/src/jaxoplanet/units/field.py b/src/jaxoplanet/units/field.py new file mode 100644 index 00000000..b8e41396 --- /dev/null +++ b/src/jaxoplanet/units/field.py @@ -0,0 +1,31 @@ +from functools import partial +from typing import Any + +import equinox as eqx +import jax + +from jaxoplanet.units.decorator import _apply_units, _is_quantity + + +def field(*, units: Any = None, strict: bool = False, **kwargs: Any) -> Any: + """A custom Equinox field with support for units + + This is a wrapper around :func:`equinox.field` that adds support for units. + The ``units`` argument is used to specify the units of the field either as a + plain ``Unit`` or string, or as a Pytree with the same structure as the + expected input. + """ + if units is None: + return eqx.field(**kwargs) + + original_converter = kwargs.pop("converter", lambda x: x) + + def converter(value: Any) -> Any: + return jax.tree_util.tree_map( + partial(_apply_units, strict=strict), + original_converter(value), + units, + is_leaf=_is_quantity, + ) + + return eqx.field(converter=converter, **kwargs) diff --git a/src/jaxoplanet/units/registry.py b/src/jaxoplanet/units/registry.py new file mode 100644 index 00000000..78ce6815 --- /dev/null +++ b/src/jaxoplanet/units/registry.py @@ -0,0 +1,9 @@ +from importlib.resources import as_file, files + +import jpu + +unit_registry = jpu.UnitRegistry() +with as_file( + files("jaxoplanet.units").joinpath("astro_constants_and_units.txt") +) as path: + unit_registry.load_definitions(path) diff --git a/tests/core/limb_dark_test.py b/tests/core/limb_dark_test.py index dd2c9224..51e1f0eb 100644 --- a/tests/core/limb_dark_test.py +++ b/tests/core/limb_dark_test.py @@ -31,13 +31,13 @@ def test_edge_cases(u, r): assert np.isfinite(calc) if len(u) == 2: expect = exoplanet_core.quad_limbdark_light_curve(*u, b, r) - assert_allclose(calc, expect) + assert_allclose(calc, expect[0]) for n in range(3): g = jax.grad(light_curve, argnums=n)(u, b, r) assert np.all(np.isfinite(g)) - if jax.config.jax_enable_x64: + if jax.config.jax_enable_x64: # type: ignore for b in [0.0, 0.5, 1.0, r, 1 + 2 * r]: if np.allclose(b, r) or np.allclose(np.abs(b - r), 1): continue diff --git a/tests/experimental/starry/solution_test.py b/tests/experimental/starry/solution_test.py index a6564cf1..02dc0a4d 100644 --- a/tests/experimental/starry/solution_test.py +++ b/tests/experimental/starry/solution_test.py @@ -74,4 +74,5 @@ def test_solution_compare_starry(r, l_max=10, order=20): s_calc[:, n], s_expect[:, n], err_msg=f"n={n}, l={l}, m={m}, mu={mu}, nu={nu}, case={case}", + atol=1e-6, ) diff --git a/tests/keplerian_test.py b/tests/keplerian_test.py deleted file mode 100644 index 24f19caf..00000000 --- a/tests/keplerian_test.py +++ /dev/null @@ -1,12 +0,0 @@ -import jax.numpy as jnp - -from jaxoplanet import orbits - - -def test_keplerian_central_shape(): - assert orbits.KeplerianCentral.init(mass=0.98, radius=0.93).shape == () - - -def test_casting_dtype(): - orbit = orbits.KeplerianBody.init(period=1) - assert orbit.period.dtype in (jnp.float32, jnp.float64) diff --git a/tests/light_curves_test.py b/tests/light_curves_test.py index 87444b5a..e69de29b 100644 --- a/tests/light_curves_test.py +++ b/tests/light_curves_test.py @@ -1,40 +0,0 @@ -import jax.numpy as jnp - -from jaxoplanet import light_curves, orbits -from jaxoplanet.test_utils import assert_allclose - - -def test_keplerian_basic(): - t = jnp.linspace(-1.0, 10.0, 1000) - - mstar = 0.98 - rstar = 0.93 - ld = [0.1, 0.3] - - time_transit = jnp.array([0.0, 0.5]) - period = jnp.array([1, 4.5]) - b = jnp.array([0.5, 0.2]) - r = jnp.array([0.1, 0.3]) - - host_star = orbits.KeplerianCentral.init(mass=mstar, radius=rstar) - orbit_both = orbits.KeplerianOrbit.init( - central=host_star, - time_transit=time_transit, - period=period, - impact_param=b, - radius=r, - ) - lc_both = light_curves.LimbDarkLightCurve.init(ld[0], ld[1]).light_curve( - orbit_both, t - ) - - for n in range(2): - orbit = orbits.KeplerianBody.init( - central=host_star, - time_transit=time_transit[n], - period=period[n], - impact_param=b[n], - radius=r[n], - ) - lc = light_curves.LimbDarkLightCurve.init(ld[0], ld[1]).light_curve(orbit, t) - assert_allclose(lc, lc_both[n]) diff --git a/tests/orbits/keplerian_test.py b/tests/orbits/keplerian_test.py new file mode 100644 index 00000000..22941fde --- /dev/null +++ b/tests/orbits/keplerian_test.py @@ -0,0 +1,206 @@ +import jax +import jax.experimental +import jax.numpy as jnp +import jpu.numpy as jnpu +import numpy as np +import pytest + +from jaxoplanet.orbits import keplerian +from jaxoplanet.test_utils import assert_allclose, assert_quantity_allclose +from jaxoplanet.units import unit_registry as ureg + + +@pytest.fixture( + params=[ + { + "central": keplerian.Central(mass=1.3, radius=1.1), + "mass": 0.1, + "time_transit": 0.1, + "period": 12.5, + "inclination": 0.3, + }, + { + "central": keplerian.Central(mass=1.3, radius=1.1), + "mass": 0.1, + "time_transit": 0.1, + "period": 12.5, + "inclination": 0.3, + "eccentricity": 0.3, + "omega_peri": -1.5, + "asc_node": 0.3, + }, + ] +) +def keplerian_body(request): + return keplerian.Body(**request.param) + + +@pytest.fixture +def time(): + return jnp.linspace(-50.0, 50.0, 500) * ureg.day + + +def test_keplerian_central_density(): + star = keplerian.Central() + assert_quantity_allclose( + star.density, 1.4 * ureg.g / ureg.cm**3, atol=0.01, convert=True + ) + + +def test_keplerian_body_keplers_law(): + orbit = keplerian.Body(semimajor=1.0 * ureg.au) + assert_quantity_allclose(orbit.period, 1.0 * ureg.year, atol=0.01, convert=True) + + orbit = keplerian.Body(period=1.0 * ureg.year) + assert_quantity_allclose(orbit.semimajor, 1.0 * ureg.au, atol=0.01, convert=True) + + +@pytest.mark.parametrize("prefix", ["", "central_", "relative_"]) +def test_keplerian_body_velocity(time, keplerian_body, prefix): + v = getattr(keplerian_body, f"{prefix}velocity")(time) + for i, v_ in enumerate(v): + pos_func = getattr(keplerian_body, f"{prefix}position") + assert_allclose( + v_.magnitude, + jax.vmap(jax.grad(lambda t: pos_func(t)[i].magnitude))(time).magnitude, + ) + + +def test_keplerian_body_radial_velocity(time, keplerian_body): + computed = keplerian_body.radial_velocity(time) + assert computed.units == ureg.R_sun / ureg.d + computed.to(ureg.m / ureg.s) + + computed = keplerian_body.radial_velocity(time, semiamplitude=1.0) + assert not hasattr(computed, "_magnitude") + assert not hasattr(computed, "_units") + + +def test_keplerian_body_impact_parameter(keplerian_body): + x, y, z = keplerian_body.relative_position(keplerian_body.time_transit) + assert_quantity_allclose( + (jnpu.sqrt(x**2 + y**2) / keplerian_body.central.radius), + keplerian_body.impact_param, + ) + assert jnpu.all(z > 0) + + +def test_keplerian_body_coordinates_match_batman(time, keplerian_body): + _rsky = pytest.importorskip("batman._rsky") + with jax.experimental.enable_x64(True): + r_batman = _rsky._rsky( + np.array(time.magnitude, dtype=np.float64), + float(keplerian_body.time_transit.magnitude), + float(keplerian_body.period.magnitude), + float(keplerian_body.semimajor.magnitude), + float(keplerian_body.inclination.magnitude), + float(keplerian_body.eccentricity.magnitude) + if keplerian_body.eccentricity + else 0.0, + float(keplerian_body.omega_peri.magnitude) + if keplerian_body.omega_peri + else 0.0, + 1, + 1, + ) + m = r_batman < 100.0 + assert m.sum() > 0 + + x, y, z = keplerian_body.relative_position(time) + r = jnpu.sqrt(x**2 + y**2) + + # Make sure that the in-transit impact parameter matches batman + assert_allclose(r_batman[m], r.magnitude[m]) + + # In-transit should correspond to positive z in our parameterization + assert np.all(z.magnitude[m] > 0) + + # Therefore, when batman doesn't see a transit we shouldn't be transiting + no_transit = z.magnitude[~m] < 0 + no_transit |= r.magnitude[~m] > 2 + assert np.all(no_transit) + + +def test_keplerian_body_positions_small_star(time): + _rsky = pytest.importorskip("batman._rsky") + with jax.experimental.enable_x64(True): + keplerian_body = keplerian.Body( + central=keplerian.Central(radius=0.189, mass=0.151), + period=0.4626413, + time_transit=0.2, + impact_param=0.5, + eccentricity=0.1, + omega_peri=0.1, + ) + + r_batman = _rsky._rsky( + np.array(time.magnitude, dtype=np.float64), + float(keplerian_body.time_transit.magnitude), + float(keplerian_body.period.magnitude), + float(keplerian_body.semimajor.magnitude), + float(keplerian_body.inclination.magnitude), + float(keplerian_body.eccentricity.magnitude) + if keplerian_body.eccentricity + else 0.0, + float(keplerian_body.omega_peri.magnitude) + if keplerian_body.omega_peri + else 0.0, + 1, + 1, + ) + m = r_batman < 100.0 + assert m.sum() > 0 + + x, y, _ = keplerian_body.relative_position(time) + r = jnpu.sqrt(x**2 + y**2) + assert_allclose(r_batman[m], r[m].magnitude) + + +def test_keplerian_system_stack_construction(): + sys = keplerian.System().add_body(period=0.1).add_body(period=0.2) + assert_quantity_allclose(sys._body_stack.period, jnp.array([0.1, 0.2]) * ureg.day) + + sys = ( + keplerian.System() + .add_body(period=0.1) + .add_body(period=0.2, eccentricity=0.1, omega_peri=-1.5) + ) + assert sys._body_stack is None + + +def test_keplerian_system_radial_velocity(): + sys1 = keplerian.System().add_body(period=0.1).add_body(period=0.2) + sys2 = ( + keplerian.System() + .add_body(period=0.1) + .add_body(period=0.2, eccentricity=0.0, omega_peri=0.0) + ) + assert sys1._body_stack is not None + assert sys2._body_stack is None + with pytest.raises(ValueError): + sys1._body_stack.radial_velocity(0.0) + + for t in [0.0, jnp.linspace(0, 1, 5)]: + assert_quantity_allclose( + sys1.radial_velocity(t), + sys2.radial_velocity(t), + # TODO(dfm): I'm not sure why we need to loosen the tolerance here, + # but the ecc=0 model doesn't give the same results as the ecc=None + # model otherwise. + atol={jnp.float32: 5e-6, jnp.float64: 1e-12}, + ) + + +def test_keplerian_system_position(): + sys1 = keplerian.System().add_body(period=0.1).add_body(period=0.2) + sys2 = ( + keplerian.System() + .add_body(period=0.1) + .add_body(period=0.2, eccentricity=0.0, omega_peri=0.0) + ) + for t in [0.0, jnp.linspace(0, 1, 5)]: + x1, y1, z1 = sys1.position(t) + x2, y2, z2 = sys2.position(t) + assert_quantity_allclose(x1, x2) + assert_quantity_allclose(y1, y2) + assert_quantity_allclose(z1, z2) diff --git a/tests/orbits_test.py b/tests/orbits_test.py deleted file mode 100644 index a9ad5a33..00000000 --- a/tests/orbits_test.py +++ /dev/null @@ -1,202 +0,0 @@ -# mypy: ignore-errors - - -import jax -import jax.numpy as jnp -import numpy as np -import pytest - -from jaxoplanet.orbits import KeplerianBody, KeplerianCentral, KeplerianOrbit -from jaxoplanet.test_utils import assert_allclose - - -def test_sky_coords(): - _rsky = pytest.importorskip("batman._rsky") - - t = np.linspace(-100, 100, 1000) - - t0, period, a, e, omega, incl = ( - x.flatten() - for x in np.meshgrid( - np.linspace(-5.0, 5.0, 2), - np.exp(np.linspace(np.log(5.0), np.log(50.0), 3)), - np.linspace(50.0, 100.0, 2), - np.linspace(0.0, 0.9, 5), - np.linspace(-np.pi, np.pi, 3), - np.arccos(np.linspace(0, 1, 5)[:-1]), - ) - ) - r_batman = np.empty((len(t), len(t0))) - - for i in range(len(t0)): - r_batman[:, i] = _rsky._rsky( - t, t0[i], period[i], a[i], incl[i], e[i], omega[i], 1, 1 - ) - m = r_batman < 100.0 - assert m.sum() > 0 - - def get_r(**kwargs): - orbit = KeplerianBody.init(**kwargs) - x, y, z = orbit.relative_position(t) - return jnp.sqrt(x**2 + y**2), z - - r, z = jax.vmap(get_r, out_axes=1)( - period=period, - semimajor=a, - time_transit=t0, - eccentricity=e, - omega_peri=omega, - inclination=incl, - ) - - # Make sure that the in-transit impact parameter matches batman - assert_allclose(r_batman[m], r[m]) - - # In-transit should correspond to positive z in our parameterization - assert np.all(z[m] > 0) - - # Therefore, when batman doesn't see a transit we shouldn't be transiting - no_transit = z[~m] < 0 - no_transit |= r[~m] > 2 - assert np.all(no_transit) - - -def test_center_of_mass(): - t = np.linspace(0, 100, 1000) - m_planet = np.array([0.5, 0.1]) - m_star = 1.45 - - orbit = KeplerianOrbit.init( - central=KeplerianCentral.init(mass=m_star, radius=1.0), - time_transit=np.array([0.5, 17.4]), - period=np.array([100.0, 37.3]), - eccentricity=np.array([0.1, 0.8]), - omega_peri=np.array([0.5, 1.3]), - asc_node=np.array([0.0, 1.0]), - inclination=np.array([0.25 * np.pi, 0.3 * np.pi]), - mass=m_planet, - ) - - coords = np.asarray(orbit.position(t)) - central_coords = np.asarray(orbit.central_position(t)) - com = np.sum( - (m_planet[..., None] * coords + m_star * central_coords) - / (m_star + m_planet)[..., None], - axis=0, - ) - assert_allclose(com, 0.0) - - -def test_velocity(): - t = np.linspace(0, 100, 1000) - m_planet = 0.1 - m_star = 1.3 - orbit = KeplerianOrbit.init( - central=KeplerianCentral.init(mass=m_star, radius=1.0), - time_transit=0.5, - period=100.0, - eccentricity=0.1, - omega_peri=0.5, - asc_node=1.0, - inclination=0.25 * np.pi, - mass=m_planet, - ) - - computed = orbit.central_velocity(t) - for n in range(3): - for i in range(len(orbit)): - expected = jax.vmap(jax.grad(lambda t: orbit.central_position(t)[n][i]))(t) - assert_allclose(computed[n][i], expected) - - computed = orbit.velocity(t) - for n in range(3): - for i in range(len(orbit)): - expected = jax.vmap(jax.grad(lambda t: orbit.position(t)[n][i]))(t) - assert_allclose(computed[n][i], expected) - - computed = orbit.relative_velocity(t) - for n in range(3): - for i in range(len(orbit)): - expected = jax.vmap(jax.grad(lambda t: orbit.relative_position(t)[n][i]))(t) - assert_allclose(computed[n][i], expected) - - -def test_radial_velocity(): - t = np.linspace(0, 100, 1000) - m_planet = 0.1 - m_star = 1.3 - orbit = KeplerianOrbit.init( - central=KeplerianCentral.init(mass=m_star, radius=1.0), - time_transit=0.5, - period=100.0, - eccentricity=0.1, - omega_peri=0.5, - asc_node=1.0, - inclination=0.25 * np.pi, - mass=m_planet, - ) - - computed = orbit.radial_velocity(t) - expected = orbit.radial_velocity( - t, - semiamplitude=orbit.bodies._baseline_rv_semiamplitude - * orbit.bodies.mass - * orbit.bodies.sin_inclination, - ) - assert_allclose(expected, computed) - - -def test_small_star(): - _rsky = pytest.importorskip("batman._rsky") - - m_star = 0.151 - r_star = 0.189 - period = 0.4626413 - t0 = 0.2 - b = 0.5 - ecc = 0.1 - omega = 0.1 - t = np.linspace(0, period, 500) - - orbit = KeplerianOrbit.init( - central=KeplerianCentral.init(radius=r_star, mass=m_star), - period=period, - time_transit=t0, - impact_param=b, - eccentricity=ecc, - omega_peri=omega, - ) - a = float(orbit.bodies.semimajor[0]) - incl = float( - jnp.arctan2(orbit.bodies.sin_inclination, orbit.bodies.cos_inclination)[0] - ) - - r_batman = _rsky._rsky(t, t0, period, a, incl, ecc, omega, 1, 1) - m = r_batman < 100.0 - assert m.sum() > 0 - - x, y, _ = orbit.relative_position(t) - r = np.sqrt(x**2 + y**2) - assert_allclose(r_batman[m], r[0, m]) - - -def test_impact(): - m_star = 0.151 - r_star = 0.189 - period = 0.4626413 - t0 = 0.2 - b = 0.5 - ecc = 0.8 - omega = 0.1 - - orbit = KeplerianOrbit.init( - central=KeplerianCentral.init(radius=r_star, mass=m_star), - period=period, - time_transit=t0, - impact_param=b, - eccentricity=ecc, - omega_peri=omega, - ) - x, y, z = orbit.relative_position(t0) - assert_allclose((jnp.sqrt(x**2 + y**2) / r_star), b) - assert jnp.all(z > 0) diff --git a/tests/units_test.py b/tests/units_test.py new file mode 100644 index 00000000..a0d65a79 --- /dev/null +++ b/tests/units_test.py @@ -0,0 +1,174 @@ +from typing import Any + +import equinox as eqx +import jax +import pytest +from pint import DimensionalityError + +from jaxoplanet import units +from jaxoplanet.test_utils import assert_quantity_allclose +from jaxoplanet.units import unit_registry as ureg + + +def test_quantity_input(): + @units.quantity_input(a=ureg.m, b=ureg.s) + def func(a, b): + assert a.units == ureg.m + assert b.units == ureg.s + return a / b + + assert_quantity_allclose(func(1.5 * ureg.m, 0.5 * ureg.s), 3.0 * ureg.m / ureg.s) + + +def test_quantity_input_functional(): + def func(a, b): + assert a.units == ureg.m + assert b.units == ureg.s + return a / b + + func = units.quantity_input(func, a=ureg.m, b=ureg.s) + assert_quantity_allclose(func(1.5 * ureg.m, 0.5 * ureg.s), 3.0 * ureg.m / ureg.s) + + +def test_quantity_input_string(): + @units.quantity_input(a="m", b="s") + def func(a, b): + assert a.units == ureg.m + assert b.units == ureg.s + return a / b + + assert_quantity_allclose(func(1.5 * ureg.m, 0.5 * ureg.s), 3.0 * ureg.m / ureg.s) + + +def test_quantity_input_varargs(): + @units.quantity_input(a=ureg.m, b=ureg.s) + def func(a, b, *args, **kwargs): + assert a.units == ureg.m + assert b.units == ureg.s + return a / b + + assert_quantity_allclose(func(1.5 * ureg.m, 0.5 * ureg.s), 3.0 * ureg.m / ureg.s) + + +def test_quantity_input_without_units(): + @units.quantity_input(a=ureg.m, b=ureg.s) + def func(a, b): + return a / b + + assert_quantity_allclose(func(1.5 * ureg.m, 0.5), 3.0 * ureg.m / ureg.s) + assert_quantity_allclose(func(1.5, 0.5 * ureg.s), 3.0 * ureg.m / ureg.s) + + +def test_quantity_input_skip_parameter_explicit(): + @units.quantity_input(a=ureg.m, b=None) + def func(a, b): + return a * b + + assert_quantity_allclose(func(1.5 * ureg.m, 2.0), 3.0 * ureg.m) + + +def test_quantity_input_skip_parameter_implicit(): + @units.quantity_input(a=ureg.m) + def func(a, b): + return a * b + + assert_quantity_allclose(func(1.5 * ureg.m, 2.0), 3.0 * ureg.m) + + +def test_quantity_input_conversion(): + @units.quantity_input(a=ureg.m, b=ureg.s) + def func(a, b): + return a / b + + assert_quantity_allclose( + func(150.0 * ureg.cm, 500.0 * ureg.ms), 3.0 * ureg.m / ureg.s + ) + + +@pytest.mark.parametrize("with_jit", [False, True]) +def test_quantity_input_pytree(with_jit): + @units.quantity_input(params={"a": ureg.m, "b": ureg.s}) + def func(params): + assert params["a"].units == ureg.m + assert params["b"].units == ureg.s + return params["a"] / params["b"] + + if with_jit: + func = jax.jit(func) + + assert_quantity_allclose( + func({"a": 1.5 * ureg.m, "b": 0.5 * ureg.s}), 3.0 * ureg.m / ureg.s + ) + assert_quantity_allclose(func({"a": 1.5 * ureg.m, "b": 0.5}), 3.0 * ureg.m / ureg.s) + assert_quantity_allclose( + func({"a": 150.0 * ureg.cm, "b": 500.0 * ureg.ms}), 3.0 * ureg.m / ureg.s + ) + + +def test_quantity_input_invalid(): + @units.quantity_input(a=ureg.m, b=ureg.s) + def func(a, b): + return a / b + + with pytest.raises(DimensionalityError): + func(150.0 * ureg.hr, 0.5 * ureg.s) + + +def test_quantity_input_unrecognized_argument(): + with pytest.raises(TypeError, match="got an unexpected keyword argument 'x'"): + + @units.quantity_input(a=ureg.m, b=ureg.s, x=ureg.cm) + def func(a, b): + return a / b + + +def test_quantity_input_positional_unit(): + with pytest.raises(TypeError): + + @units.quantity_input(ureg.m, b=ureg.s) # type: ignore + def func(a, b): + return a / b + + +def test_quantity_input_varargs_error(): + with pytest.raises(TypeError): + + @units.quantity_input(a=ureg.m, b=ureg.s, args=ureg.cm) + def func(a, b, *args, **kwargs): + assert a.units == ureg.m + assert b.units == ureg.s + return a / b + + +def test_field(): + class Model(eqx.Module): + x: Any = units.field(units=ureg.m) + + assert_quantity_allclose(Model(x=1.5).x, 1.5 * ureg.m) + assert_quantity_allclose(Model(x=1.5 * ureg.m).x, 1.5 * ureg.m) + assert_quantity_allclose(Model(x=150.0 * ureg.cm).x, 1.5 * ureg.m) + + +def test_field_converter(): + class Model(eqx.Module): + x: Any = units.field(units=ureg.m, converter=lambda x: 2 * x) + + assert_quantity_allclose(Model(x=1.5).x, 3 * ureg.m) + + +def test_field_pytree(): + class Model(eqx.Module): + x: Any = units.field(units={"a": ureg.m, "b": ureg.s}) + + model = Model(x={"a": 150.0 * ureg.cm, "b": 0.5}) + assert_quantity_allclose(model.x["a"], 1.5 * ureg.m) + assert_quantity_allclose(model.x["b"], 0.5 * ureg.s) + + +def test_field_optional(): + class Model(eqx.Module): + x: Any = units.field(units=ureg.m) + + model = Model(x=None) + assert model.x is None + assert_quantity_allclose(Model(x=1.5).x, 1.5 * ureg.m)