diff --git a/docs/_autoapi_templates/index.rst b/docs/_autoapi_templates/index.rst new file mode 100644 index 00000000..e48a192b --- /dev/null +++ b/docs/_autoapi_templates/index.rst @@ -0,0 +1,15 @@ +:orphan: + +Full API Reference +================== + +This page contains auto-generated API reference documentation [#f1]_. + +.. toctree:: + :titlesonly: + + {% for page in pages|selectattr("is_top_level_object") %} + {{ page.include_path }} + {% endfor %} + +.. [#f1] Created with `sphinx-autoapi `_ diff --git a/docs/_static/favicon.png b/docs/_static/favicon.png index 787d536f..f77812b6 100644 Binary files a/docs/_static/favicon.png and b/docs/_static/favicon.png differ diff --git a/docs/_static/logo.png b/docs/_static/logo.png index 89c49452..9f0aa21a 100644 Binary files a/docs/_static/logo.png and b/docs/_static/logo.png differ diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 00000000..b523409e --- /dev/null +++ b/docs/api.md @@ -0,0 +1,55 @@ +# API reference + +These pages contain the API reference for the `jaxoplanet` package. + +## Orbital systems + +The modules to define orbital systems are: + +- [keplerian](jaxoplanet.orbits.keplerian) : a module to define a Keplerian system made of a central object and its orbiting bodies. +- [transit](jaxoplanet.orbits.transit) : a module to define a system made of a star and its transiting planet, defined by the transit parameters assuming the planet transits at a constant velocity. + +## Limb-darkened stars + +Given an orbital system, the following modules can be used to compute different kind of light curves: + +- [limb-darkened](jaxoplanet.light_curves.limb_dark) : a module to compute occultation light curve of a star with polynomial limb darkening. +- [transforms](jaxoplanet.light_curves.transforms) : a module providing decorators to transform light curve functions. + + +## Non-uniform surfaces + + +```{warning} +While being stable, computing *starry* light curves of non-uniform surfaces is still experimental. +``` + +The following modules can be used to create systems of bodies with limb-darkened and non-uniform emitting surfaces: + +- [starry orbit](jaxoplanet.experimental.starry.orbit) : a module to define a system made of a central object and orbiting bodies, all with non-uniform emitting surfaces. (experimental) +- [starry light curve](jaxoplanet.experimental.starry.light_curves) : a module to compute the light curve of a non-uniform star whose surface is represented by a sum of spherical harmonics. + +And the following are lower-level modules to define and manipulate non-uniform surfaces: + +- [Ylm](jaxoplanet.experimental.starry.ylm) : a lower-level module to create and manipulate vectors in the spherical harmonic basis. +- [Pijk](jaxoplanet.experimental.starry.pijk) : a lower-level module to create and manipulate vectors in the polynomial basis. +- [Surface](jaxoplanet.experimental.starry.surface) : a module to manipulate the oriented surface of spherical bodies, represented by a sum of spherical harmonics. +- [visualization](jaxoplanet.experimental.starry.visualization) : a module to visualize the surface of non-uniform spherical bodies. + + +```{toctree} +:hidden: + +keplerian orbit +transit orbit +starry orbit +limb-darkened light curve +transforms +starry light curve +Surface +Ylm +Pijk +visualization +``` + +**Missing something?** Check the [full API reference](autoapi/jaxoplanet/index). diff --git a/docs/conf.py b/docs/conf.py index ee54b318..55abee52 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -19,7 +19,7 @@ ] autoapi_dirs = ["../src"] -autoapi_ignore = ["*/experimental/*", "*_version*", "*/types*"] +autoapi_ignore = ["*_version*", "*/types*"] autoapi_options = [ "members", "undoc-members", @@ -29,6 +29,9 @@ "special-members", # "imported-members", ] +# autoapi_add_toctree_entry = False +autoapi_template_dir = "_autoapi_templates" + suppress_warnings = ["autoapi.python_import_resolution"] myst_enable_extensions = ["dollarmath", "colon_fence"] @@ -44,7 +47,7 @@ version = jaxoplanet.__version__ release = jaxoplanet.__version__ -exclude_patterns = ["_build"] +exclude_patterns = ["_build", "_autoapi_templates"] html_theme = "sphinx_book_theme" html_title = "jaxoplanet documentation" html_logo = "_static/logo.png" diff --git a/docs/guide.md b/docs/guide.md deleted file mode 100644 index 6e35cf20..00000000 --- a/docs/guide.md +++ /dev/null @@ -1,15 +0,0 @@ -(guide)= - -# User Guide - -The following pages give some background on the context within which `jaxoplanet` -exists, as well as detailed installation and API documentation. Click through -for all the details, or head over to the {ref}`tutorials` for a more hands-on -experience. - -```{toctree} -:maxdepth: 1 - -install -troubleshooting -``` diff --git a/docs/index.md b/docs/index.md index 85d139e2..67da25cd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,11 +4,29 @@ ## Table of contents ```{toctree} ---- -maxdepth: 1 ---- +:caption: User guide -guide -tutorials +install +troubleshooting +tutorials/getting-started.ipynb +``` + +```{toctree} +:caption: Tutorials + +tutorials/about.ipynb +tutorials/transit.ipynb +tutorials/rv.ipynb +tutorials/starry.ipynb +``` + +```{toctree} +:caption: Reference + +api +tutorials/autodiff.ipynb +tutorials/introduction-to-jax.ipynb +tutorials/core-from-scratch.ipynb contributing + ``` diff --git a/docs/install.md b/docs/install.md index 47634b1d..f51fccd1 100644 --- a/docs/install.md +++ b/docs/install.md @@ -1,6 +1,6 @@ (install)= -# Installation Guide +# Installation `jaxoplanet` is built on top of [`jax`](https://github.com/google/jax) so that's the primary dependency that you'll need. All of the methods below will install any diff --git a/docs/tutorials.md b/docs/tutorials.md deleted file mode 100644 index b3be97da..00000000 --- a/docs/tutorials.md +++ /dev/null @@ -1,28 +0,0 @@ -(tutorials)= - -# Tutorials - -The tutorials in this section are automatically executed with every change of -the code to make sure that they are always up-to-date. As a result, they are -designed to require only a relatively small amount of computation time; when -using `jaxoplanet` for real problems you will probably find that your run -times are longer. - -To execute a tutorial on your own, you can click on the buttons at the top right -or this page to launch the notebook using [Binder](https://mybinder.org), -[Colab](https://colab.research.google.com), or download the `.ipynb` file -directly. - -## Introductory Topics - -```{toctree} -:maxdepth: 1 - -tutorials/getting-started.ipynb -tutorials/autodiff.ipynb -tutorials/introduction-to-jax.ipynb -tutorials/transit.ipynb -tutorials/rv.ipynb -tutorials/starry.ipynb -tutorials/core-from-scratch.ipynb -``` diff --git a/docs/tutorials/about.ipynb b/docs/tutorials/about.ipynb new file mode 100644 index 00000000..eba2e11c --- /dev/null +++ b/docs/tutorials/about.ipynb @@ -0,0 +1,79 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# About these tutorials\n", + "\n", + "The tutorials in this section are automatically executed with every change of\n", + "the code to make sure that they are always up-to-date. As a result, they are\n", + "designed to require only a relatively small amount of computation time; when\n", + "using `jaxoplanet` for real problems you will probably find that your run\n", + "times are longer.\n", + "\n", + "To execute a tutorial on your own, you can click on the buttons at the top right\n", + "or this page to launch the notebook using [Binder](https://mybinder.org),\n", + "[Colab](https://colab.research.google.com), or download the `.ipynb` file\n", + "directly.\n", + "\n", + "## Extra dependencies\n", + "\n", + "Most of these tutorials use the probabilistic programming library `NumPyro` to make \n", + "inference based on real or synthetic datasets. Hence, you will need the following extra packages:\n", + "- [NumPyro](https://num.pyro.ai/en/stable/getting_started.html) : a lightweight probabilistic programming library.\n", + "- [NumPyro-ext](https://github.com/dfm/numpyro-ext) : a package that extends NumPyro with a set of helper functions, custom distributions and other utilities.\n", + "- [corner](https://corner.readthedocs.io/en/latest/) : a package to visualize multidimensional samples using a scatterplot matrix.\n", + "- [Arviz](https://python.arviz.org/en/stable/) : a package for exploratory analysis of Bayesian models\n", + "\n", + "In addition, the [Lightkurve](https://docs.lightkurve.org/) may be required to download real datasets from the Kepler and TESS missions.\n", + "\n", + "For reference, here is the version of the packages used to generate these tutorials:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "import arviz\n", + "import numpy\n", + "import corner\n", + "import numpyro\n", + "import jaxoplanet\n", + "import numpyro_ext\n", + "\n", + "print(f\"jaxoplanet.__version__ = {jaxoplanet.__version__}\")\n", + "print(f\"numpy.__version__ = {numpy.__version__}\")\n", + "print(f\"numpyro.__version__ = {numpyro.__version__}\")\n", + "print(f\"numpyro_ext.__version__ = {numpyro_ext.__version__}\")\n", + "print(f\"jax.__version__ = {jax.__version__}\")\n", + "print(f\"corner.__version__ = {corner.__version__}\")\n", + "print(f\"arviz.__version__ = {arviz.__version__}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jaxoplanet", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorials/getting-started.ipynb b/docs/tutorials/getting-started.ipynb index 3a519c67..2904c070 100644 --- a/docs/tutorials/getting-started.ipynb +++ b/docs/tutorials/getting-started.ipynb @@ -11,128 +11,182 @@ ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "id": "1", + "metadata": { + "tags": [ + "hide-input" + ] + }, + "outputs": [], + "source": [ + "import jax\n", + "\n", + "jax.config.update(\"jax_enable_x64\", True)" + ] + }, + { + "cell_type": "markdown", + "id": "2", "metadata": {}, "source": [ - "To set up a Keplerian orbital system in `jaxoplanet` we can define initialize a `Central` object (e.g. a star) and an orbiting `Body` object (e.g. a planet)." + "## Keplerian system\n", + "\n", + "In jaxoplanet, a Keplerian system can be instantiated with a [Central](jaxoplanet.orbits.keplerian.Central) object" ] }, { "cell_type": "code", "execution_count": null, - "id": "2", + "id": "3", "metadata": {}, "outputs": [], "source": [ - "from jax import config\n", + "from jaxoplanet.orbits.keplerian import System, Central\n", "\n", - "config.update(\"jax_enable_x64\", True)\n", - "\n", - "from jaxoplanet.orbits.keplerian import Central, Body, System\n", - "from jaxoplanet.units import unit_registry as ureg\n", - "from jaxoplanet import units\n", - "import jax.numpy as jnp" + "system = System(Central()) # a central object with some default parameters" ] }, { "cell_type": "markdown", - "id": "3", + "id": "4", "metadata": {}, "source": [ - "We can initialize the `Central` object using two of radius, mass and/or density, otherwise `jaxoplanet` will populate these parameters with the default values and units of a Solar analogue star:" + "and add an orbiting [Body](jaxoplanet.orbits.keplerian.Body) " ] }, { "cell_type": "code", "execution_count": null, - "id": "4", + "id": "5", "metadata": {}, "outputs": [], "source": [ - "Central()" + "system = system.add_body(period=0.1)" ] }, { "cell_type": "markdown", - "id": "5", + "id": "6", "metadata": {}, "source": [ - "We can instead choose to create a `Central` object using orbital parameters, for example for the Sun-Earth system:" + "As many arguments are optional, it's always a good idea to check the parameters of the system." ] }, { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "7", "metadata": {}, "outputs": [], "source": [ - "Sun = Central.from_orbital_properties(\n", - " period=1.0 * ureg.yr,\n", - " semimajor=1.0 * ureg.au,\n", - " radius=1.0 * ureg.R_sun,\n", - " body_mass=1.0 * ureg.M_earth,\n", - ")\n", - "Sun" + "system" ] }, { "cell_type": "markdown", - "id": "7", + "id": "8", "metadata": {}, "source": [ - "To create a Keplerian `Body` we must define either the orbital period or semi-major axis. There are also a number of optional orbital parameters we can set at this point:" + "For the reminder of this notebook, let's define a system consisting of an Earth-like planet orbiting a Sun-like star." ] }, { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "9", "metadata": {}, "outputs": [], "source": [ - "Earth = System(Sun).add_body(semimajor=1 * ureg.au).bodies[0]\n", - "Earth" + "from jaxoplanet.units import unit_registry as ureg\n", + "\n", + "sun = Central(\n", + " radius=1.0 * ureg.R_sun,\n", + " mass=1.0 * ureg.M_sun,\n", + ")\n", + "\n", + "system = System(sun).add_body(\n", + " semimajor=1.0 * ureg.au,\n", + " radius=1.0 * ureg.R_earth,\n", + " mass=1.0 * ureg.M_earth,\n", + ")\n", + "\n", + "earth = system.bodies[0]\n", + "\n", + "# checking the parameters of the system\n", + "system" ] }, { "cell_type": "markdown", - "id": "9", + "id": "10", "metadata": {}, "source": [ - "Note: The `eccentricity` by default is None (=circular orbit). This is not (entirely) equivalent to setting `eccentricity`=0. If we set the `eccentricity`=0 then we will have to explicitly define the argument of periastron (`omega_peri`) too!" + "```{note}\n", + "\n", + "Notice the use of the [jaxoplanet.units](jaxoplanet.units) module to handle physical units. Check *TODO* for an introduction to the unit system used by jaxoplanet.\n", + "\n", + "```" ] }, { "cell_type": "markdown", - "id": "10", + "id": "11", + "metadata": {}, + "source": [ + "# Radial velocity" + ] + }, + { + "cell_type": "markdown", + "id": "12", "metadata": {}, "source": [ - "Users familiar with `KeplarianOrbit`s within `exoplanet` (see [tutorial](https://docs.exoplanet.codes/en/latest/tutorials/data-and-models/)) can access the relative positions and velocities of the bodies in a similar way:" + "Then, one can access the relative position and velocity of the planet relative to the sun. " ] }, { "cell_type": "code", "execution_count": null, - "id": "11", + "id": "13", "metadata": {}, "outputs": [], "source": [ + "import jax.numpy as jnp\n", "from matplotlib import pyplot as plt\n", "\n", "# Get the position of the planet and velocity of the star as a function of time\n", "t = jnp.linspace(0, 730, 5000)\n", - "x, y, z = Earth.relative_position(t)\n", - "vx, vy, vz = Earth.central_velocity(t)\n", + "x, y, z = earth.relative_position(t)\n", + "vx, vy, vz = earth.central_velocity(t)" + ] + }, + { + "cell_type": "markdown", + "id": "14", + "metadata": {}, + "source": [ + "```{note}\n", + "Axes and orbital parameters conventions follow that of the [*exoplanet* package](https://docs.exoplanet.codes/en/latest/tutorials/data-and-models/).\n", + "```\n", "\n", - "# Plot the coordinates\n", - "fig, axes = plt.subplots(2, 1, figsize=(8, 8), sharex=True)\n", + "And plot the results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15", + "metadata": {}, + "outputs": [], + "source": [ + "fig, axes = plt.subplots(2, 1, sharex=True)\n", "ax = axes[0]\n", "ax.plot(t, x.magnitude, label=\"x\")\n", "ax.plot(t, y.magnitude, label=\"y\")\n", "ax.plot(t, z.magnitude, label=\"z\")\n", - "ax.set_ylabel(\"position of orbiting body [$R_*$]\")\n", + "ax.set_ylabel(\"earth position [$R_*$]\")\n", "ax.legend(fontsize=10, loc=1)\n", "\n", "ax = axes[1]\n", @@ -141,14 +195,61 @@ "ax.plot(t, vz.magnitude, label=\"$v_z$\")\n", "ax.set_xlim(t.min(), t.max())\n", "ax.set_xlabel(\"time [days]\")\n", - "ax.set_ylabel(\"velocity of central [$R_*$/day]\")\n", + "ax.set_ylabel(\"central velocity [$R_*$/day]\")\n", "_ = ax.legend(fontsize=10, loc=1)" ] }, + { + "cell_type": "markdown", + "id": "16", + "metadata": {}, + "source": [ + "## Light curve\n", + "\n", + "jaxoplanet contains module to compute occultation light curves of stars given different photosphere properties. For example, we can define a limb-darkened [light_curve](jaxoplanet.light_curves.limb_dark.light_curve) to compute the flux of a star with a polynomial limb darkening, allowing to express linear, quadratic and more complex laws.\n", + "\n", + "Using the limb-darkening coefficients from [Hestroffer and Magnan](https://www.physics.hmc.edu/faculty/esin/a101/limbdarkening.pdf) we compute the flux" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "17", + "metadata": {}, + "outputs": [], + "source": [ + "from jaxoplanet.light_curves.limb_dark import light_curve\n", + "\n", + "u = (0.30505, 1.13123, -0.78604, 0.40560, 0.02297, -0.07880)\n", + "time = jnp.linspace(-0.5, 0.5, 1000)\n", + "\n", + "flux = 1.0 + light_curve(system, u)(time)" + ] + }, + { + "cell_type": "markdown", + "id": "18", + "metadata": {}, + "source": [ + "and plot the resulting light curve" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19", + "metadata": {}, + "outputs": [], + "source": [ + "plt.plot(time, flux)\n", + "plt.xlabel(\"time (days)\")\n", + "_ = plt.ylabel(\"relative flux\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20", "metadata": {}, "outputs": [], "source": [] @@ -170,7 +271,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/docs/tutorials/rv.ipynb b/docs/tutorials/rv.ipynb index 56298454..e82e9b16 100644 --- a/docs/tutorials/rv.ipynb +++ b/docs/tutorials/rv.ipynb @@ -6,17 +6,24 @@ "source": [ "(rv)=\n", "\n", - "# Radial Velocities Fitting" + "# Radial Velocities Fitting\n", + "\n", + "\n", + "In this tutorial we will learn how to use `jaxoplanet` to compute the radial velocities of a star hosting a single exoplanet, and how to fit this dataset using `numpyro`.\n", + "\n", + "```{note}\n", + "This tutorial requires some [extra packages](about.ipynb) that are not included in the `jaxoplanet` dependencies.\n", + "```\n", + "\n", + "## Setup\n", + "\n", + "We first setup the number of CPUs to use and enable the use of double-precision numbers with jax." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [ - "hide-input" - ] - }, + "metadata": {}, "outputs": [], "source": [ "import jax\n", @@ -26,24 +33,6 @@ "jax.config.update(\"jax_enable_x64\", True)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this tutorial we will learn how to use `jaxoplanet` to compute the radial velocities of a star hosting a single exoplanet, and how to fit this dataset using `numpyro`." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "```{note}\n", - "This tutorial requires the installation of the following packages:\n", - "- [`numpyro`](https://num.pyro.ai)\n", - "- [`corner`](https://corner.readthedocs.io)\n", - "```" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/docs/tutorials/transit.ipynb b/docs/tutorials/transit.ipynb index 853738f6..09e559d1 100644 --- a/docs/tutorials/transit.ipynb +++ b/docs/tutorials/transit.ipynb @@ -11,13 +11,13 @@ "\n", "Like `exoplanet`, `jaxoplanet` includes methods for computing the light curves of transiting exoplanets. In this tutorial, we introduce these methods and use it alongside the `NumPyro` probabilistic programming library to do some transit fitting. Parts of this tutorial will follow the [Transit Fitting tutorial](https://gallery.exoplanet.codes/tutorials/transit/) for the `exoplanet` package.\n", "\n", - "In addition to `jaxoplanet` (and [`NumPy`](https://numpy.org/), [`Matplotlib`](https://matplotlib.org/stable/)), you'll need to install the following packages to run this tutorial:\n", - "- [`NumPyro`](https://num.pyro.ai/en/stable/getting_started.html)\n", - "- [`NumPyro-ext`](https://github.com/dfm/numpyro-ext)\n", - "- [`corner`](https://corner.readthedocs.io/en/latest/)\n", - "- [`Arviz`](https://python.arviz.org/en/stable/)\n", + "```{note}\n", + "This tutorial requires some [extra packages](about.ipynb) that are not included in the `jaxoplanet` dependencies.\n", + "```\n", "\n", - "Let's import the necessary packages and configure the setup." + "## Setup\n", + "\n", + "We first setup the number of CPUs to use and enable the use of double-precision numbers with jax. We also import the required packages." ] }, { @@ -45,17 +45,7 @@ "numpyro.set_platform(\"cpu\") # For CPU (use \"gpu\" for GPU)\n", "jax.config.update(\n", " \"jax_enable_x64\", True\n", - ") # For 64-bit precision since JAX defaults to 32-bit\n", - "\n", - "\n", - "print(f\"jaxoplanet.__version__ = {jaxoplanet.__version__}\")\n", - "print(f\"numpy.__version__ = {np.__version__}\")\n", - "print(f\"matplotlib.__version__ = {plt.matplotlib.__version__}\")\n", - "print(f\"numpyro.__version__ = {numpyro.__version__}\")\n", - "print(f\"numpyro_ext.__version__ = {numpyro_ext.__version__}\")\n", - "print(f\"jax.__version__ = {jax.__version__}\")\n", - "print(f\"corner.__version__ = {corner.__version__}\")\n", - "print(f\"arviz.__version__ = {az.__version__}\")" + ") # For 64-bit precision since JAX defaults to 32-bit" ] }, { @@ -63,7 +53,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's first compute a simple light curve." + "## Generating the data\n", + "\n", + "Let's first compute a simple light curve.\n", + "\n", + "The light curve calculation requires an orbit object. We'll use [TransitOrbit](jaxoplanet.orbits.transit.TransitOrbit) (similar to [SimpleTransitOrbit](https://docs.exoplanet.codes/en/latest/user/api/#exoplanet.orbits.SimpleTransitOrbit) in the exoplanet package), which is an orbit parameterized by the observables of a transiting system: period, speed/duration, time of transit, impact parameter, and radius ratio." ] }, { @@ -72,26 +66,21 @@ "metadata": {}, "outputs": [], "source": [ - "# The light curve calculation requires an orbit object.\n", - "# We'll use TransitOrbit (similar to SimpleTransitOrbit in the exoplanet package),\n", - "# which is an orbit parameterized by the observables of a transiting system:\n", - "# period, speed/duration, time of transit, impact parameter, and radius ratio.\n", "orbit = TransitOrbit(\n", " period=3.456, duration=0.12, time_transit=0.0, impact_param=0.0, radius_ratio=0.1\n", ")\n", "\n", - "\n", "# Compute a limb-darkened light curve for this orbit\n", "t = np.linspace(-0.1, 0.1, 1000)\n", "u = [0.1, 0.06] # Quadratic limb-darkening coefficients\n", "light_curve = limb_dark_light_curve(orbit, u)(t)\n", "\n", "# Plot the light curve\n", - "plt.figure(dpi=150)\n", "plt.plot(t, light_curve, lw=2)\n", "plt.xlabel(\"time [days]\")\n", "plt.ylabel(\"relative flux\")\n", - "plt.xlim(t.min(), t.max());" + "plt.xlim(t.min(), t.max())\n", + "plt.tight_layout()" ] }, { @@ -133,13 +122,12 @@ "y = y_true + yerr * random.normal(size=len(t))\n", "\n", "# Let's see what the light curve looks like\n", - "plt.figure(dpi=150)\n", "plt.plot(t, y_true, \"-k\", lw=1.0, label=\"truth\")\n", "plt.plot(t, y, \".k\", ms=2, label=\"data\")\n", "plt.xlabel(\"time [days]\")\n", "plt.ylabel(\"relative flux\")\n", "plt.xlim(t.min(), t.max())\n", - "plt.legend(loc=4);" + "_ = plt.legend(loc=4)" ] }, { @@ -159,6 +147,17 @@ "metadata": {}, "outputs": [], "source": [ + "def light_curve_model(time, params):\n", + " orbit = TransitOrbit(\n", + " period=params[\"period\"],\n", + " duration=params[\"duration\"],\n", + " time_transit=params[\"t0\"],\n", + " impact_param=params[\"b\"],\n", + " radius_ratio=params[\"r\"],\n", + " )\n", + " return limb_dark_light_curve(orbit, params[\"u\"])(time)\n", + "\n", + "\n", "def model(t, yerr, y=None):\n", " # Priors for the parameters we're fitting for\n", "\n", @@ -187,14 +186,9 @@ " u = numpyro.sample(\"u\", numpyro_ext.distributions.QuadLDParams())\n", "\n", " # The orbit and light curve\n", - " orbit = TransitOrbit(\n", - " period=period,\n", - " duration=duration,\n", - " time_transit=t0,\n", - " impact_param=b,\n", - " radius_ratio=r,\n", + " y_pred = light_curve_model(\n", + " t, {\"period\": period, \"duration\": duration, \"t0\": t0, \"b\": b, \"r\": r, \"u\": u}\n", " )\n", - " y_pred = limb_dark_light_curve(orbit, u)(t)\n", "\n", " # Let's track the light curve\n", " numpyro.deterministic(\"light_curve\", y_pred)\n", @@ -247,7 +241,7 @@ " truths=[T0, PERIOD, DURATION, ROR, B, U[0], U[1]],\n", " show_titles=True,\n", " title_kwargs={\"fontsize\": 10},\n", - " label_kwargs={\"fontsize\": 12},\n", + " label_kwargs={\"fontsize\": 10},\n", ")" ] }, @@ -324,14 +318,14 @@ "metadata": {}, "outputs": [], "source": [ - "plt.figure(dpi=150)\n", "plt.plot(t, y, \".k\", ms=2, label=\"data\")\n", "plt.plot(t, y_true, \"-k\", lw=1.0, label=\"truth\")\n", "plt.plot(t, opt_params[\"light_curve\"], \"--C0\", lw=1.0, label=\"MAP model\")\n", "plt.xlabel(\"time [days]\")\n", "plt.ylabel(\"relative flux\")\n", "plt.legend(fontsize=10, loc=4)\n", - "plt.xlim(t.min(), t.max());" + "plt.xlim(t.min(), t.max())\n", + "plt.tight_layout()" ] }, { @@ -437,11 +431,11 @@ "metadata": {}, "outputs": [], "source": [ - "az.plot_trace(\n", + "_ = az.plot_trace(\n", " inf_data,\n", " var_names=[\"t0\", \"period\", \"duration\", \"r\", \"b\", \"u\"],\n", " backend_kwargs={\"constrained_layout\": True},\n", - ");" + ")" ] }, { @@ -460,16 +454,18 @@ "metadata": {}, "outputs": [], "source": [ - "corner.corner(\n", + "fig = plt.figure(figsize=(12, 12))\n", + "_ = corner.corner(\n", " inf_data,\n", " var_names=[\"t0\", \"period\", \"duration\", \"r\", \"b\", \"u\"],\n", " truths=[T0, PERIOD, DURATION, ROR, B, U[0], U[1]],\n", " show_titles=True,\n", " quantiles=[0.16, 0.5, 0.84],\n", - " title_kwargs={\"fontsize\": 12},\n", - " label_kwargs={\"fontsize\": 15},\n", + " title_kwargs={\"fontsize\": 10},\n", + " label_kwargs={\"fontsize\": 10},\n", " title_fmt=\".4f\",\n", - ");" + " fig=fig,\n", + ")" ] }, { @@ -495,28 +491,24 @@ "metadata": {}, "outputs": [], "source": [ - "inferred_t0 = np.median(samples[\"t0\"])\n", - "inferred_period = np.median(samples[\"period\"])\n", - "inferred_duration = np.median(samples[\"duration\"])\n", - "inferred_r = np.median(samples[\"r\"])\n", - "inferred_b = np.median(samples[\"b\"])\n", - "inferred_u = np.median(samples[\"u\"], axis=0)\n", + "inferred_params = {\n", + " \"t0\": np.median(samples[\"t0\"]),\n", + " \"period\": np.median(samples[\"period\"]),\n", + " \"duration\": np.median(samples[\"duration\"]),\n", + " \"r\": np.median(samples[\"r\"]),\n", + " \"b\": np.median(samples[\"b\"]),\n", + " \"u\": np.median(samples[\"u\"], axis=0),\n", + "}\n", "\n", - "orbit = TransitOrbit(\n", - " period=inferred_period,\n", - " duration=inferred_duration,\n", - " time_transit=inferred_t0,\n", - " impact_param=inferred_b,\n", - " radius_ratio=inferred_r,\n", - ")\n", - "y_model = limb_dark_light_curve(orbit, inferred_u)(t)\n", "\n", - "fig, ax = plt.subplots(dpi=150)\n", + "y_model = light_curve_model(t, inferred_params)\n", + "\n", + "fig, ax = plt.subplots()\n", "\n", "# Plot the folded data\n", "t_fold = (\n", - " t - inferred_t0 + 0.5 * inferred_period\n", - ") % inferred_period - 0.5 * inferred_period\n", + " t - inferred_params[\"t0\"] + 0.5 * inferred_params[\"period\"]\n", + ") % inferred_params[\"period\"] - 0.5 * inferred_params[\"period\"]\n", "ax.errorbar(\n", " t_fold,\n", " y,\n", @@ -538,7 +530,8 @@ "ax.set_xlabel(\"time since transit [days]\")\n", "ax.set_ylabel(\"relative flux\")\n", "ax.legend(fontsize=10, loc=4)\n", - "ax.set_xlim(-inferred_duration, inferred_duration);" + "ax.set_xlim(-inferred_params[\"duration\"], inferred_params[\"duration\"])\n", + "plt.tight_layout()" ] }, { @@ -565,7 +558,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.10" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/src/jaxoplanet/experimental/starry/light_curves.py b/src/jaxoplanet/experimental/starry/light_curves.py index d8ae06e9..4a7c11ae 100644 --- a/src/jaxoplanet/experimental/starry/light_curves.py +++ b/src/jaxoplanet/experimental/starry/light_curves.py @@ -87,7 +87,7 @@ def map_light_curve( """Light curve of an occulted map. Args: - map (Map): map object + map (Map): Surface object r (float or None): radius of the occulting body, relative to the current map body xo (float or None): x position of the occulting body, relative to the current diff --git a/src/jaxoplanet/experimental/starry/pijk.py b/src/jaxoplanet/experimental/starry/pijk.py index ddb67d82..5ed8d483 100644 --- a/src/jaxoplanet/experimental/starry/pijk.py +++ b/src/jaxoplanet/experimental/starry/pijk.py @@ -12,19 +12,26 @@ class Pijk(eqx.Module): - """A class to represent and manipulate spherical harmonics in the polynomial basis. - Several indices are used throughout the class: - - Indices (i, j, k) represent the order of the polynomials of (x, y, z), for example - (1, 0, 2) represents x * z^2. - - Indices (l, m) represent the orders of the spherical harmonics. - - Index n represent the index of the polynomial in the flattened array. + r"""A class to represent and manipulate spherical harmonics in the + polynomial basis. Several indices are used throughout the class: - Flattened array `to_dense` and `from_dense` follow the convention from Luger et al. - (2019). More specifically: + * Indices :math:`(i, j, k)` represent the order of the polynomials of + :math:`(x, y, z)`, for example :math:`(1, 0, 2)` represents :math:`x\,z^2`. + + * Indices :math:`(l, m)` represent the orders of the spherical harmonics. + + * Index n represent the index of the polynomial in the flattened array. + + Flattened array ``todense`` and ``from_dense`` follow the convention from + Luger et al. (2019). More specifically: .. math:: - \tilde{p} = (1, x, y, z, x^2, xz, xy, yz, y^2, ...)^T + \tilde{p} = + \begin{pmatrix} + 1 & x & y & z & x^2 & xz & xy & yz & y^2 & + \cdot\cdot\cdot + \end{pmatrix}^\mathsf{T} """ diff --git a/src/jaxoplanet/experimental/starry/surface.py b/src/jaxoplanet/experimental/starry/surface.py index 135d2ebd..71111482 100644 --- a/src/jaxoplanet/experimental/starry/surface.py +++ b/src/jaxoplanet/experimental/starry/surface.py @@ -51,26 +51,26 @@ class Surface(eqx.Module): show_map(m) """ - # Ylm object representing the spherical harmonic expansion of the map. y: Ylm + """Ylm object representing the spherical harmonic expansion of the map""" - # Inclination of the map in radians. inc: Array + """Inclination of the map in radians""" - # Obliquity of the map in radians. obl: Array + """Obliquity of the map in radians.""" - # Tuple of limb darkening coefficients. u: tuple[Array, ...] + """Tuple of limb darkening coefficients.""" - # Rotation period of the map in days (attribute subject to change) period: Array | None + """Rotation period of the map in days (attribute subject to change)""" - # Amplitude of the map, a quantity proportional to map luminosity. amplitude: Array + """Amplitude of the map, a quantity proportional to map luminosity.""" - # Boolean to specify whether the Ylm coefficients should be normalized normalize: bool + """Boolean to specify whether the Ylm coefficients should be normalized""" def __init__( self, @@ -100,23 +100,26 @@ def __init__( self.normalize = normalize @property - def poly_basis(self): + def _poly_basis(self): return jax.jit(poly_basis(self.deg)) @property def udeg(self): + """Order of the polynomial limb darkening.""" return len(self.u) @property def ydeg(self): + """Degree of the spherical harmonic expansion.""" return self.y.ell_max @property def deg(self): + """Total degree of the spherical harmonic expansion (`udeg + ydeg`).""" return self.ydeg + self.udeg def _intensity(self, x, y, z, theta=0.0): - pT = self.poly_basis(x, y, z) + pT = self._poly_basis(x, y, z) Ry = left_project(self.ydeg, self.inc, self.obl, theta, 0.0, self.y.todense()) A1Ry = A1(self.ydeg).todense() @ Ry p_y = Pijk.from_dense(A1Ry, degree=self.ydeg) @@ -165,6 +168,14 @@ def intensity(self, lat: float, lon: float): return self._intensity(x, y, z) def rotational_phase(self, time: Array) -> Array: + """Returns the rotational phase of the map at a given time. + + Args: + time (ArrayLike): time in same units as the period + + Returns: + ArrayLike: rotational phase of the map at the given time + """ if self.period is None: return jnp.zeros_like(time) else: diff --git a/src/jaxoplanet/experimental/starry/ylm.py b/src/jaxoplanet/experimental/starry/ylm.py index e1ff631f..ab0c2780 100644 --- a/src/jaxoplanet/experimental/starry/ylm.py +++ b/src/jaxoplanet/experimental/starry/ylm.py @@ -1,3 +1,40 @@ +r"""A module to manipulate vectors in the spherical harmonic basis. + +The spherical harmonics basis is a set of orthogonal functions defined on the +unit sphere. In jaxoplanet, this basis is used to represent the intensity at the surface +of a spherical body, such as a star or a planet. We say that :math:`y` represents the +intensity of a surface in the spherical harmonics basis if the specific intensity at the +:math:`(x,y)` on the surface can be written as: + +.. math:: + + I(x, y) = \mathbf{\tilde{y}_n^\mathsf{T}} (x, y) \, \mathbf{y} + \quad, + +where :math:`\tilde{y}_n` is the **spherical harmonic basis**, +arranged in increasing degree and order: + +.. math:: + + \mathbf{\tilde{y}_n} = + \begin{pmatrix} + Y_{0, 0} & + Y_{1, -1} & Y_{1, 0} & Y_{1, 1} & + Y_{2, -2} & Y_{2, -1} & Y_{2, 0} & Y_{2, 1} & Y_{2, 2} & + \cdot\cdot\cdot + \end{pmatrix}^\mathsf{T} + \quad, + +where :math:`Y_{l, m} = Y_{l, m}(x, y)` is the spherical harmonic of degree :math:`l` +and order :math:`m`. For reference, in this basis the coefficient of the spherical +harmonic :math:`Y_{l, m}` is located at the index + +.. math:: + + n = l^2 + l + m + +""" + import math from collections import defaultdict from collections.abc import Mapping @@ -23,15 +60,16 @@ class Ylm(eqx.Module): spherical harmonic coefficients. Defaults to {(0, 0): 1.0}. """ - # coefficients of the spherical harmonic expansion of the map in the form - # {(l, m): coefficient} data: dict[tuple[int, int], Array] + """coefficients of the spherical harmonic expansion of the map in the form + `{(l, m): coefficient}`""" - # maximum degree of the spherical harmonic expansion ell_max: int = eqx.field(static=True) + """The maximum degree of the spherical harmonic coefficients.""" - # whether the spherical harmonic expansion is diagonal (all m=0) diagonal: bool = eqx.field(static=True) + """Whether are orders m of the spherical harmonic coefficients are zero. + Diagonal if only the degrees "l" are non-zero.""" def __init__( self, @@ -46,21 +84,22 @@ def __init__( @property def shape(self) -> tuple[int, ...]: - """The number of coefficients in the expansion. This sets the shape of + """The number of coefficients in the basis. This sets the shape of the output of `todense`.""" return (self.ell_max**2 + 2 * self.ell_max + 1,) @property def indices(self) -> list[tuple[int, int]]: + """List of (l,m) indices of the spherical harmonic coefficients.""" return list(self.data.keys()) def index(self, ell: Array, m: Array) -> Array: """Convert the degree and order of the spherical harmonic to the - corresponding index in the coefficient array.""" + corresponding index in the coefficients array.""" return ell * (ell + 1) + m def normalize(self) -> "Ylm": - """Return a new Ylm instance with normalized coefficients. + """Return a new Ylm instance with coefficients normalized to :math:`Y_{0,0}`. Returns: Ylm instance with normalized coefficients. @@ -76,15 +115,26 @@ def normalize(self) -> "Ylm": return Ylm(data=data) def tosparse(self) -> BCOO: + """Return a sparse (jax.experimental.sparse.BCOO) spherical harmonic + coefficients vector where the spherical harmonic :math:`Y_{l, m}` is located at + the index :math:`n = l^2 + l + m`. + """ indices, values = zip(*self.data.items(), strict=False) idx = jnp.array([self.index(ell, m) for ell, m in indices])[:, None] return BCOO((jnp.asarray(values), idx), shape=self.shape) def todense(self) -> Array: + """Return a dense spherical harmonic coefficients vector where the spherical + harmonic :math:`Y_{l, m}` is located at the index :math:`n = l^2 + l + m`. + """ return self.tosparse().todense() @classmethod def from_dense(cls, y: Array, normalize: bool = True) -> "Ylm": + """Create a Ylm object from a dense array of spherical harmonic coefficients + where the spherical harmonic :math:`Y_{l, m}` is located at the index + :math:`n = l^2 + l + m`. + """ data = {} for i, ylm in enumerate(y): ell = int(np.floor(np.sqrt(i))) @@ -115,7 +165,8 @@ def _mul(f: Ylm, g: Ylm) -> Ylm: """ Based closely on the implementation from the MIT-licensed spherical package: - https://github.com/moble/spherical/blob/0aa81c309cac70b90f8dfb743ce35d2cc9ae6dee/spherical/multiplication.py + https://github.com/moble/spherical/blob/0aa81c309cac70b90f8dfb743ce35d2cc9ae6dee/ + spherical/multiplication.py """ ellmax_f = f.ell_max ellmax_g = g.ell_max @@ -208,8 +259,8 @@ def func(contrast: float, r: float, lat: float = 0.0, lon: float = 0.0): """spot expansion in the spherical harmonics basis. Args: - contrast (float): spot contrast, defined as (1-c) where c is the intensity of - the center of the spot relative to the unspotted surface. A contrast of 1. + contrast (float): spot contrast, defined as (1-c) where c is the intensity + of the center of the spot relative to the unspotted surface. A contrast of 1. means that the spot intensity drops to zero at the center, 0. means that the intensity at the center of the spot is the same as the intensity of the unspotted surface. diff --git a/src/jaxoplanet/light_curves/transforms.py b/src/jaxoplanet/light_curves/transforms.py index eca8ecbb..6c2a0e9d 100644 --- a/src/jaxoplanet/light_curves/transforms.py +++ b/src/jaxoplanet/light_curves/transforms.py @@ -1,3 +1,6 @@ +"""A module providing decorators to transform light curve functions +""" + __all__ = ["integrate", "interpolate"] from functools import wraps diff --git a/src/jaxoplanet/orbits/keplerian.py b/src/jaxoplanet/orbits/keplerian.py index 29dcf14c..a4f0c375 100644 --- a/src/jaxoplanet/orbits/keplerian.py +++ b/src/jaxoplanet/orbits/keplerian.py @@ -1,3 +1,6 @@ +"""A module to define Keplerian systems of bodies. +""" + from collections.abc import Callable, Iterable, Sequence from typing import Any @@ -717,6 +720,15 @@ def add_body( central: Central | None = None, **kwargs: Any, ) -> "System": + """Add a body to the system and return a new system + + Args: + body (Body | None, optional): body to add. Defaults to None. + central (Central | None, optional): TODO. Defaults to None. + + Returns: + System: :py:class:`~jaxoplanet.orbits.keplerian.System` with the added body + """ body_: Body | OrbitalBody | None = body if body_ is None: body_ = Body(**kwargs)