Skip to content

Commit

Permalink
Add examples (#12)
Browse files Browse the repository at this point in the history
* Add notebook example

* Add docs

* Fix env
  • Loading branch information
AlecThomson authored Jan 31, 2025
1 parent e681ffd commit 3e0e0fb
Show file tree
Hide file tree
Showing 14 changed files with 180 additions and 62 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ env:
# Many color libraries just need this to be set to any value, but at least
# one distinguishes color depth, where "3" -> "256-bit color".
FORCE_COLOR: 3
# Jupyter is migrating its paths to use standard platformdirs
# given by the platformdirs library. To remove this warning and
# see the appropriate new directories, set the environment variable
# `JUPYTER_PLATFORM_DIRS=1` and then run `jupyter --paths`.
# The use of platformdirs will be the default in `jupyter_core` v6
JUPYTER_PLATFORM_DIRS: 1

jobs:
pre-commit:
Expand Down
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,8 @@ repos:
- id: check-dependabot
- id: check-github-workflows
- id: check-readthedocs

- repo: https://github.com/kynan/nbstripout
rev: 0.8.1
hooks:
- id: nbstripout
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"sphinx_autodoc_typehints",
"sphinx_copybutton",
"autoapi.extension",
"nbsphinx",
]

autoapi_type = "python"
Expand Down
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
```{toctree}
:maxdepth: 2
:hidden:
usage
```

Expand Down
12 changes: 12 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Usage and examples

The following notebooks should serve as a guide to using `rm-lite`. These
notebooks are run during the automated testing, so if you can read this the
examples should work as written.

```{toctree}
:maxdepth: 1
:caption: Example notebooks:
examples/rmsyth_1d.ipynb
```
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dependencies = [
"nestle",
"corner",
"polars",
"sigfig",
]

[project.optional-dependencies]
Expand All @@ -50,6 +51,7 @@ dev = [
"pytest",
"pytest-cov",
"nox",
"nbconvert",
]
docs = [
"sphinx>=7.0",
Expand Down
77 changes: 53 additions & 24 deletions rm_lite/tools_1d/rmsynth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,6 @@
rmsynth_nufft,
)

logger.setLevel("WARNING")


# class RMSynth1DArrays(NamedTuple):
# """Resulting arrays from RM-synthesis"""

# phi_arr_radm2: NDArray[np.float64]
# """ Array of Faraday depths """
# phi2_arr_radm2: NDArray[np.float64]
# """ Double length of Faraday depths """
# rmsf_arr: NDArray[np.float64]
# """ Rotation Measure Spread Function """
# freq_arr_hz: NDArray[np.float64]
# """ Frequency array """
# weight_arr: NDArray[np.float64]
# """ Weight array """
# fdf_dirty_arr: NDArray[np.float64]
# """ Dirty Faraday dispersion function """


class RMSynth1DResults(NamedTuple):
"""Resulting arrays from RM-synthesis"""
Expand All @@ -52,12 +33,14 @@ class RMSynth1DResults(NamedTuple):
""" RMSynth arrays """
rmsf_arrs: pl.DataFrame
""" RMSF arrays """
stokes_i_arrs: pl.DataFrame
""" Stokes I arrays """


def run_rmsynth(
freq_arr_hz: NDArray[np.float64],
complex_pol_arr: NDArray[np.complex128],
complex_pol_error: NDArray[np.float64],
complex_pol_error: NDArray[np.complex128],
stokes_i_arr: NDArray[np.float64] | None = None,
stokes_i_error_arr: NDArray[np.float64] | None = None,
stokes_i_model_arr: NDArray[np.float64] | None = None,
Expand All @@ -71,6 +54,31 @@ def run_rmsynth(
fit_function: Literal["log", "linear"] = "log",
fit_order: int = 2,
) -> RMSynth1DResults:
"""Run RM-synthesis on 1D data
Args:
freq_arr_hz (NDArray[np.float64]): Frequencies in Hz
complex_pol_arr (NDArray[np.complex128]): Complex polarisation values (Q + iU)
complex_pol_error (NDArray[np.float64]): Complex polarisation errors (dQ + idU)
stokes_i_arr (NDArray[np.float64] | None, optional): Total itensity values. Defaults to None.
stokes_i_error_arr (NDArray[np.float64] | None, optional): Total intensity errors. Defaults to None.
stokes_i_model_arr (NDArray[np.float64] | None, optional): Total intensity model array. Defaults to None.
stokes_i_model_error (NDArray[np.float64] | None, optional): Total intensity model error. Defaults to None.
phi_max_radm2 (float | None, optional): Maximum Faraday depth. Defaults to None.
d_phi_radm2 (float | None, optional): Spacing in Faraday depth. Defaults to None.
n_samples (float | None, optional): Number of samples across the RMSF. Defaults to 10.0.
weight_type ("variance", "uniform", optional): Type of weighting. Defaults to "variance".
do_fit_rmsf (bool, optional): Fit the RMSF main lobe. Defaults to False.
do_fit_rmsf_real (bool, optional): The the real part of the RMSF. Defaults to False.
fit_function ("log" | "linear", optional): _description_. Defaults to "log".
fit_order (int, optional): Polynomial fit order. Defaults to 2. Negative values will iterate until the fit is good.
Returns:
RMSynth1DResults:
fdf_parameters (pl.DataFrame): FDF parameters
fdf_arrs (pl.DataFrame): RMSynth arrays
rmsf_arrs (pl.DataFrame): RMSF arrays
"""
stokes_data = StokesData(
freq_arr_hz=freq_arr_hz,
complex_pol_arr=complex_pol_arr,
Expand Down Expand Up @@ -99,7 +107,20 @@ def _run_rmsynth(
fit_function: Literal["log", "linear"] = "log",
fit_order: int = 2,
) -> RMSynth1DResults:
"""Run RM-synthesis on 1D data"""
"""Run RM-synthesis on 1D data with packed data
Args:
stokes_data (StokesData): Frequency-dependent polarisation data
fdf_options (FDFOptions): RM-synthesis options
fit_function ("log", "linear", optional): Type of function to fit. Defaults to "log".
fit_order (int, optional): Polynomial fit order. Defaults to 2. Negative values will iterate until the fit is good.
Returns:
RMSynth1DResults:
fdf_parameters (pl.DataFrame): FDF parameters
fdf_arrs (pl.DataFrame): RMSynth arrays
rmsf_arrs (pl.DataFrame): RMSF arrays
"""

rmsynth_params = compute_rmsynth_params(
freq_arr_hz=stokes_data.freq_arr_hz,
Expand Down Expand Up @@ -182,15 +203,23 @@ def _run_rmsynth(
rmsyth_arrs = pl.DataFrame(
{
"phi_arr_radm2": rmsynth_params.phi_arr_radm2,
"fdf_dirty_arr": fdf_dirty_arr,
"fdf_dirty_complex_arr": fdf_dirty_arr,
}
)

rmsf_arrs = pl.DataFrame(
{
"phi2_arr_radm2": rmsf_result.phi_double_arr_radm2,
"rmsf_arr": rmsf_result.rmsf_cube,
"rmsf_complex_arr": rmsf_result.rmsf_cube,
}
)
stokes_i_arrs = pl.DataFrame(
{
"freq_arr_hz": stokes_data.freq_arr_hz,
"stokes_i_model_arr": fractional_spectra.stokes_i_model_arr,
"stokes_i_model_error": fractional_spectra.stokes_i_model_error,
"flag_arr": no_nan_idx,
}
)

return RMSynth1DResults(fdf_parameters, rmsyth_arrs, rmsf_arrs)
return RMSynth1DResults(fdf_parameters, rmsyth_arrs, rmsf_arrs, stokes_i_arrs)
25 changes: 13 additions & 12 deletions rm_lite/utils/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ class RMCleanResults(NamedTuple):
class CleanLoopResults(NamedTuple):
"""Results of the RM-CLEAN loop"""

clean_fdf_spectrum: NDArray[np.float64]
clean_fdf_spectrum: NDArray[np.complex128]
"""The cleaned Faraday dispersion function cube"""
resid_fdf_spectrum: NDArray[np.float64]
resid_fdf_spectrum: NDArray[np.complex128]
"""The residual Faraday dispersion function cube"""
model_fdf_spectrum: NDArray[np.float64]
model_fdf_spectrum: NDArray[np.complex128]
"""The clean components cube"""
iter_count: int
"""The number of iterations"""
Expand All @@ -53,25 +53,25 @@ class CleanLoopResults(NamedTuple):
class MinorLoopResults(NamedTuple):
"""Results of the RM-CLEAN minor loop"""

clean_fdf_spectrum: NDArray[np.float64]
clean_fdf_spectrum: NDArray[np.complex128]
"""The cleaned Faraday dispersion function cube"""
resid_fdf_spectrum: NDArray[np.float64]
resid_fdf_spectrum: NDArray[np.complex128]
"""The residual Faraday dispersion function cube"""
resid_fdf_spectrum_mask: np.ma.MaskedArray
"""The masked residual Faraday dispersion function cube"""
model_fdf_spectrum: NDArray[np.float64]
model_fdf_spectrum: NDArray[np.complex128]
"""The clean components cube"""
model_rmsf_spectrum: NDArray[np.float64]
model_rmsf_spectrum: NDArray[np.complex128]
""" Model * RMSF """
iter_count: int
"""The number of iterations"""


def restore_fdf(
model_fdf_spectrum: NDArray[np.float64],
model_fdf_spectrum: NDArray[np.complex128],
phi_double_arr_radm2: NDArray[np.float64],
fwhm_rmsf: float,
) -> NDArray[np.float64]:
) -> NDArray[np.complex128]:
clean_beam = unit_centred_gaussian(
x=phi_double_arr_radm2,
fwhm=fwhm_rmsf,
Expand All @@ -83,16 +83,16 @@ def restore_fdf(


def rmclean(
dirty_fdf_arr: NDArray[np.float64],
dirty_fdf_arr: NDArray[np.complex128],
phi_arr_radm2: NDArray[np.float64],
rmsf_arr: NDArray[np.float64],
rmsf_arr: NDArray[np.complex128],
phi_double_arr_radm2: NDArray[np.float64],
fwhm_rmsf_arr: NDArray[np.float64],
mask: float,
threshold: float,
max_iter: int = 1000,
gain: float = 0.1,
mask_arr: NDArray[np.float64] | None = None,
mask_arr: NDArray[np.bool_] | None = None,
) -> RMCleanResults:
_bad_result = RMCleanResults(
clean_fdf_arr=dirty_fdf_arr,
Expand Down Expand Up @@ -833,6 +833,7 @@ def mutliscale_rmclean(
mask_arr: NDArray[np.float64] | None = None,
kernel: Literal["tapered_quad", "gaussian"] = "gaussian",
) -> RMCleanResults:
raise NotImplementedError
_bad_result = RMCleanResults(
clean_fdf_arr=dirty_fdf_arr,
model_fdf_arr=np.zeros_like(dirty_fdf_arr),
Expand Down
18 changes: 13 additions & 5 deletions rm_lite/utils/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,19 @@
from typing import Literal, NamedTuple, Protocol

import numpy as np
import sigfig as sf
from astropy.modeling.models import Gaussian1D
from astropy.stats import akaike_info_criterion_lsq
from numpy.typing import ArrayLike, NDArray
from scipy import optimize

from rm_lite.utils.logging import logger

logger.setLevel("INFO")

GAUSSIAN_SIGMA_TO_FWHM = float(2.0 * np.sqrt(2.0 * np.log(2.0)))


class StokesIModel(Protocol):
def __call__(self, x: NDArray[np.float64]) -> NDArray[np.float64]: ...
def __call__(self, x: NDArray[np.float64], *params) -> NDArray[np.float64]: ...


class FitResult(NamedTuple):
Expand Down Expand Up @@ -146,6 +145,7 @@ def fit_fdf(
mask = np.zeros_like(phi_arr_radm2, dtype=bool)
mask[np.argmax(fdf_to_fit_arr)] = 1
fwhm_fdf_arr_pix = fwhm_fdf_radm2 / d_phi
fwhm_fdf_arr_pix /= 2 # fit within half the FWHM
for i in np.where(mask)[0]:
start = int(i - fwhm_fdf_arr_pix / 2)
end = int(i + fwhm_fdf_arr_pix / 2)
Expand Down Expand Up @@ -231,6 +231,8 @@ def static_fit(
fit_order: int = 2,
fit_type: Literal["log", "linear"] = "log",
) -> FitResult:
msg = f"Fitting Stokes I model of type {fit_type} with order {fit_order}."
logger.info(msg)
if fit_type == "linear":
fit_func = polynomial(fit_order)
elif fit_type == "log":
Expand Down Expand Up @@ -265,6 +267,10 @@ def static_fit(
ssr=ssr, n_params=fit_order + 1, n_samples=len(freq_arr_hz)
)

errors = np.sqrt(np.diag(pcov))
fit_vals = [sf.round(p, e) for p, e in zip(popt, errors)]
logger.info(f"Fit results: {fit_vals}")

return FitResult(
popt=popt,
pcov=pcov,
Expand All @@ -281,6 +287,8 @@ def dynamic_fit(
fit_order: int = 2,
fit_type: Literal["log", "linear"] = "log",
) -> FitResult:
msg = f"Iteratively fitting Stokes I model of type {fit_type} with max order {fit_order}."
logger.info(msg)
orders = np.arange(fit_order + 1)
n_parameters = orders + 1
fit_results: list[FitResult] = []
Expand All @@ -296,10 +304,10 @@ def dynamic_fit(
)
fit_results.append(fit_result)

logger.debug(f"Fit results for orders {orders}:")
logger.info(f"Fit results for orders {orders}:")
aics = np.array([fit_result.aic for fit_result in fit_results])
bestest_aic, bestest_n, bestest_aic_idx = best_aic_func(aics, n_parameters)
logger.debug(f"Best fit found with {bestest_n} parameters.")
logger.info(f"Best fit found with {bestest_n} parameters.")
logger.debug(f"Best fit found with AIC {bestest_aic}.")
logger.debug(f"Best fit found at index {bestest_aic_idx}.")
logger.debug(f"Best fit found with order {orders[bestest_aic_idx]}.")
Expand Down
4 changes: 2 additions & 2 deletions rm_lite/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,12 @@ def get_logger(
"""
logging.captureWarnings(True)
logger = logging.getLogger(name)
logger.setLevel(logging.WARNING)
logger.setLevel(logging.INFO)

if attach_handler:
# Create console handler and set level to debug
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setLevel(logging.INFO)

# Add formatter to ch
ch.setFormatter(CustomFormatter())
Expand Down
Loading

0 comments on commit 3e0e0fb

Please sign in to comment.