From 5377cb5196f33b77956e09700a7c43adccfa8695 Mon Sep 17 00:00:00 2001 From: Pravir Kumar Date: Mon, 16 Sep 2024 02:39:39 +0300 Subject: [PATCH 1/2] match filter improved. template normalisation include removing mean. docs improved with linked code --- docs/api/core.md | 9 + docs/conf.py | 66 ++++-- pyproject.toml | 7 +- sigpyproc/block.py | 8 +- sigpyproc/core/filters.py | 433 ++++++++++++++++++++++++++------------ sigpyproc/core/kernels.py | 78 +++++++ sigpyproc/core/rfi.py | 73 ++++--- sigpyproc/core/stats.py | 269 ++++++++++++++--------- tests/test_filters.py | 75 +++++-- tests/test_stats.py | 22 +- 10 files changed, 714 insertions(+), 326 deletions(-) diff --git a/docs/api/core.md b/docs/api/core.md index 8003872..896c839 100644 --- a/docs/api/core.md +++ b/docs/api/core.md @@ -9,6 +9,15 @@ :show-inheritance: ``` +## sigpyproc.core.filters + +```{eval-rst} +.. automodule:: sigpyproc.core.filters + :members: + :undoc-members: + :show-inheritance: +``` + ## sigpyproc.core.rfi ```{eval-rst} diff --git a/docs/conf.py b/docs/conf.py index 1c981c4..ee464eb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,8 +10,13 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. +from __future__ import annotations + import datetime +import inspect +import os import sys +from importlib import import_module from importlib.metadata import version as meta_version from pathlib import Path @@ -26,6 +31,7 @@ version = meta_version("sigpyproc") release = version master_doc = "index" +repo_url = "https://github.com/FRBs/sigpyproc3" # -- General configuration --------------------------------------------------- @@ -33,14 +39,15 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. 
extensions = [ "sphinx.ext.autodoc", - "numpydoc", - "sphinx_autodoc_typehints", "sphinx.ext.coverage", "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", "sphinx_click", "sphinx-prompt", "sphinx_copybutton", + "numpydoc", "myst_nb", + "jupyter_sphinx", ] # Add any paths that contain templates here, relative to this directory. @@ -62,36 +69,30 @@ # a list of builtin themes. html_theme = "sphinx_book_theme" +html_context = {"default_mode": "light"} html_title = project html_theme_options = { - "repository_url": "https://github.com/FRBs/sigpyproc3", + "repository_url": repo_url, "use_repository_button": True, "use_issues_button": True, "use_download_button": True, } -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -# \html_static_path = ["_static"] # -- Extension configuration ------------------------------------------------- autoclass_content = "class" # include both class docstring and __init__ autodoc_member_order = "bysource" autodoc_typehints = "none" -autodoc_inherit_docstrings = True - -typehints_document_rtype = False -numpydoc_use_plots = True -numpydoc_class_members_toctree = False +numpydoc_show_class_members = False numpydoc_show_inherited_class_members = False +numpydoc_class_members_toctree = False numpydoc_xref_param_type = True numpydoc_xref_aliases = { "ndarray": "numpy.ndarray", "dtype": "numpy.dtype", "ArrayLike": "numpy.typing.ArrayLike", - "plt": "matplotlib.pyplot", + "Figure": "matplotlib.figure.Figure", "scipy": "scipy", "astropy": "astropy", "attrs": "attrs", @@ -99,23 +100,27 @@ "Buffer": "typing_extensions.Buffer", "Iterator": "collections.abc.Iterator", "Callable": "collections.abc.Callable", + "Literal": "typing.Literal", } numpydoc_xref_ignore = { "of", + "or", "shape", "type", "optional", + "scalar", "default", } - coverage_show_missing_items = True 
-myst_enable_extensions = ["colon_fence"] +myst_enable_extensions = ["colon_fence", "deflist", "dollarmath", "amsmath"] nb_execution_mode = "auto" nb_execution_timeout = -1 +copybutton_prompt_text = ">>> " + # -- Options for intersphinx extension --------------------------------------- intersphinx_mapping = { @@ -127,3 +132,34 @@ "matplotlib": ("https://matplotlib.org/stable/", None), "typing_extensions": ("https://typing-extensions.readthedocs.io/en/stable/", None), } + +# -- Linkcode configuration -------------------------------------------------- + + +def linkcode_resolve(domain: str, info: dict) -> str | None: + """Point to the source code repository, file and line number.""" + if domain != "py" or not info["module"]: + return None + try: + mod = import_module(info["module"]) + if "." in info["fullname"]: + objname, attrname = info["fullname"].split(".") + obj = getattr(getattr(mod, objname), attrname) + else: + obj = getattr(mod, info["fullname"]) + + file = inspect.getsourcefile(obj) + lines, start_line = inspect.getsourcelines(obj) + except (TypeError, AttributeError, ImportError): + return None + + if not file or not lines: + return None + file_path = Path(file).resolve().relative_to(Path("..").resolve()) + end_line = start_line + len(lines) - 1 + + # Determine the branch based on RTD version + rtd_version = os.getenv("READTHEDOCS_VERSION", "latest") + github_branch = "develop" if rtd_version == "develop" else "main" + + return f"{repo_url}/blob/{github_branch}/{file_path}#L{start_line}-L{end_line}" diff --git a/pyproject.toml b/pyproject.toml index 0831ff8..aa03574 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "attrs", "click", "rich", + "rocket-fft", "bidict", "typing_extensions", ] @@ -58,9 +59,9 @@ docs = [ "sphinx-click", "sphinx-prompt", "sphinx-copybutton", - "sphinx-autodoc-typehints", - "myst-nb", "numpydoc", + "myst-nb", + "jupyter_sphinx", ] develop = ["ruff"] @@ -89,7 +90,7 @@ indent-style = "space" 
[tool.ruff.lint] select = ["ALL"] -ignore = ["D1", "ANN1", "PLR2004", "G004"] +ignore = ["D1", "D301", "ANN1", "PLR2004", "G004"] [tool.ruff.lint.pylint] max-args = 15 diff --git a/sigpyproc/block.py b/sigpyproc/block.py index 9155eea..da8050f 100644 --- a/sigpyproc/block.py +++ b/sigpyproc/block.py @@ -9,7 +9,8 @@ class FilterbankBlock: - """An array class to handle a discrete block of data in time-major order. + """ + An array class to handle a discrete block of data in time-major order. Parameters ---------- @@ -19,11 +20,6 @@ class FilterbankBlock: header object containing metadata dm : float, optional DM of the input_array, by default 0 - - Returns - ------- - :py:obj:`~numpy.ndarray` - 2 dimensional array of shape (nchans, nsamples) with header metadata """ def __init__(self, data: np.ndarray, hdr: Header, dm: float = 0) -> None: diff --git a/sigpyproc/core/filters.py b/sigpyproc/core/filters.py index 8996b96..6afa158 100644 --- a/sigpyproc/core/filters.py +++ b/sigpyproc/core/filters.py @@ -1,42 +1,92 @@ from __future__ import annotations +from typing import Literal + +import attrs import numpy as np -from astropy.convolution import convolve_fft -from astropy.convolution.kernels import Model1DKernel -from astropy.modeling.models import Box1D, Gaussian1D, Lorentz1D from astropy.stats import gaussian_fwhm_to_sigma from matplotlib import pyplot as plt +from numba import typed -from sigpyproc.core.stats import estimate_scale +from sigpyproc.core import kernels +from sigpyproc.core.stats import ( + LocMethodType, + ScaleMethodType, + ZScoreResult, + estimate_zscore, +) class MatchedFilter: - """Matched filter class for pulse detection. + """ + Matched filter class for pulse detection in 1D time series data. - This class implements a matched filter algorithm to detect pulses in 1D data. + This class implements a matched filter algorithm to detect pulses of varying + durations in 1D time series data. 
It uses a set of pulse templates with + varying widths and selects the template that produces the highest + signal-to-noise ratio (SNR) as the best match. Parameters ---------- - data : np.ndarray - Input data array - temp_kind : str, optional - Type of the pulse template, by default "boxcar" + data : ndarray + Input data array for matched filtering (1D). + loc_method : {"median", "mean", "norm"}, optional + Method to estimate location, by default "median". + scale_method : str, optional + Method to estimate scale, by default "iqr". + temp_kind : {"boxcar", "gaussian", "lorentzian"}, optional + Type of the pulse template, by default "boxcar". nbins_max : int, optional - Maximum number of bins for template width, by default 32 + Maximum number of bins for template width, by default 32. spacing_factor : float, optional - Factor for spacing between template widths, by default 1.5 + Factor for spacing between template widths, by default 1.5. Raises ------ ValueError - _description_ + If the input ``data`` dimension is not 1. + + See Also + -------- + sigpyproc.core.stats.estimate_zscore : Estimate Z-score of input data. + sigpyproc.core.stats.estimate_loc : Estimate location of input data. + sigpyproc.core.stats.estimate_scale : Estimate scale of input data. + + Notes + ----- + The matched filter is the optimal linear filter for maximizing the signal-to-noise + ratio (SNR) of a known pulse template in the presence of additive white noise. + + For input data :math:`x(t)` and a template :math:`h(t)`, the matched + filter output :math:`y(t)` is: + + .. math:: y(t) = (x \\star h)(t) = \\sum_{\\tau} x(\\tau) h(t - \\tau) + + This is computed efficiently using the FFT-based methods. As per the circular + convolution theorem: + + .. math:: Y(f) = X(f) H(f) + .. math:: y(t) = \\mathcal{F}^{-1}(Y(f)) + + where :math:`\\mathcal{F}^{-1}` is the inverse Fourier transform, :math:`X(f)` + and :math:`H(f)` are the Fourier transforms of :math:`x(t)` and :math:`h(t)` + respectively. 
+ + References + ---------- + .. [1] Wikipedia, "Matched filter", + https://en.wikipedia.org/wiki/Matched_filter + .. [2] Wikipedia, "Circular convolution", + https://en.wikipedia.org/wiki/Circular_convolution + """ def __init__( self, data: np.ndarray, - noise_method: str = "iqr", - temp_kind: str = "boxcar", + loc_method: LocMethodType | Literal["norm"] = "median", + scale_method: ScaleMethodType | Literal["norm"] = "iqr", + temp_kind: Literal["boxcar", "gaussian", "lorentzian"] = "boxcar", nbins_max: int = 32, spacing_factor: float = 1.5, ) -> None: @@ -44,200 +94,287 @@ def __init__( msg = f"Data dimension {data.ndim} is not supported." raise ValueError(msg) self._temp_kind = temp_kind - self._noise_method = noise_method - self._data = self._get_norm_data(data) - self._temp_widths = self.get_width_spacing(nbins_max, spacing_factor) - self._temp_bank = [ - getattr(Template, f"gen_{self.temp_kind}")(iwidth) - for iwidth in self.temp_widths - ] - - self._convs = np.array( - [ - convolve_fft(self.data, temp.kernel, normalize_kernel=False) - for temp in self.temp_bank - ], - ) - self._itemp, self._peak_bin = np.unravel_index( - self._convs.argmax(), - self._convs.shape, + self._data = np.asarray(data, dtype=np.float32) + self._zscores = estimate_zscore( + self.data, + loc_method=loc_method, + scale_method=scale_method, ) - self._best_temp = self.temp_bank[self._itemp] - self._best_snr = self._convs[self._itemp, self._peak_bin] + self._setup_templates(nbins_max, spacing_factor) + self._compute() @property def data(self) -> np.ndarray: + """:obj:`~numpy.ndarray`: Input data array for matched filtering.""" return self._data @property - def noise_method(self) -> str: - return self._noise_method + def zscores(self) -> ZScoreResult: + """:class:`~sigpyproc.core.stats.ZScoreResult`: Z-score of the input data.""" + return self._zscores @property def temp_kind(self) -> str: + """:obj:`str`: Type of the pulse template.""" return self._temp_kind @property def temp_widths(self) 
-> np.ndarray: + """:obj:`~numpy.ndarray`: Template widths used for matched filtering.""" return self._temp_widths @property def temp_bank(self) -> list[Template]: + """:obj:`list[Template]`: List of pulse templates used for matched filtering.""" return self._temp_bank @property def convs(self) -> np.ndarray: + """:obj:`~numpy.ndarray`: Convolution results for all templates.""" return self._convs @property def peak_bin(self) -> int: - """Best match template peak bin (`int`, read-only).""" + """:obj:`int`: Best match template peak bin.""" return int(self._peak_bin) @property def best_temp(self) -> Template: + """:class:`~sigpyproc.core.filters.Template`: Best match template.""" return self._best_temp @property def snr(self) -> float: - """Signal-to-noise ratio based on best match template on pulse.""" + """:obj:`float`: Signal-to-noise ratio based on best match template.""" return self._best_snr @property def best_model(self) -> np.ndarray: - """Best match template fit (`np.ndarray`, read-only).""" - return self.snr * np.roll( - self.best_temp.get_padded(self.data.size), - self.peak_bin - self.best_temp.ref_bin, + """:obj:`~numpy.ndarray`: Best match template fit.""" + return ( + self.best_temp.get_model(self.peak_bin, self.data.size) + * self.snr + * self.zscores.scale + + self.zscores.loc ) @property def on_pulse(self) -> tuple[int, int]: - """Best match template pulse region (`Tuple[int, int]`, read-only).""" - start = max(0, self.peak_bin - round(self.best_temp.width)) - end = min(self.data.size, self.peak_bin + round(self.best_temp.width)) - return (start, end) - - def plot(self) -> plt.Figure: - fig, ax = plt.subplots(figsize=(12, 6)) - ax.plot(self.data, label="Data") - ax.plot(self.best_model, label="Best Model") - ax.axvline(self.peak_bin, color="r", linestyle="--", label="Peak") + """:obj:`tuple[int, int]`: Best match template pulse region.""" + return self.best_temp.get_on_pulse(self.peak_bin, self.data.size) + + def plot( + self, + figsize: tuple[float, 
float] = (12, 6), + dpi: int = 100, + ) -> plt.Figure: + """ + Plot the matched filter result. + + Parameters + ---------- + figsize : tuple[float, float], optional + Figure size in inches, by default (12, 6) + dpi : int, optional + Dots per inch, by default 100 + + Returns + ------- + Figure + Matplotlib figure object. + """ + title = ( + f"Matched Filter Result (Temp Kind: {self.temp_kind}, " + f"Best width: {self.best_temp.width:.2f}, " + f"SNR: {self.snr:.2f})" + ) + stats_box = f"loc: {self.zscores.loc:.2f}, scale: {self.zscores.scale:.2f}" + fig, ax = plt.subplots(figsize=figsize, dpi=dpi) + ax.plot(self.data, label="Data", lw=2) + ax.plot(self.best_model, label="Best Model", lw=2) + ax.axvline(self.peak_bin, color="r", linestyle="--", label="Peak", lw=2) ax.axvspan(*self.on_pulse, alpha=0.2, color="g", label="On Pulse") - ax.set( - xlabel="Bin", - ylabel="Amplitude", - title=f"Matched Filter Result (SNR: {self.snr:.2f})", + ax.text( + 0.05, + 0.95, + stats_box, + transform=ax.transAxes, + verticalalignment="top", + bbox={ + "fc": "white", + "ec": "gray", + "alpha": 0.8, + "boxstyle": "round, pad=0.5", + }, ) + + ax.set(xlabel="Bin", ylabel="Amplitude", title=title, xlim=(0, len(self.data))) ax.legend() fig.tight_layout() return fig - def _get_norm_data(self, data: np.ndarray) -> np.ndarray: - data = np.asarray(data, dtype=np.float32) - median = np.median(data) - noise_std = estimate_scale(data, self.noise_method) - return (data - median) / noise_std + def _setup_templates(self, nbins_max: int, spacing_factor: float) -> None: + if self.temp_kind == "boxcar": + self._temp_widths = self.get_box_width_spacing(nbins_max, spacing_factor) + else: + if spacing_factor <= 1: + msg = "Spacing factor must be greater than 1 for non-boxcar templates." 
+ raise ValueError(msg) + npoints = int(np.ceil(np.log(nbins_max) / np.log(spacing_factor))) + 1 + self._temp_widths = np.geomspace(1, nbins_max, npoints) + temp_bank = [] + for width in self.temp_widths: + temp = getattr(Template, f"gen_{self.temp_kind}")(width) + if temp.data.size > self.data.size: + msg = ( + f"Template size ({temp.data.size}) is larger than the data size" + f"({self.data.size})." + ) + raise ValueError(msg) + temp_bank.append(temp) + self._temp_bank = temp_bank + + def _compute(self) -> None: + temp_kernels = typed.List([temp.data for temp in self.temp_bank]) + ref_bins = typed.List([temp.ref_bin for temp in self.temp_bank]) + self._convs = kernels.convolve_fft(self.zscores.data, temp_kernels, ref_bins) + self._itemp, self._peak_bin = np.unravel_index( + self._convs.argmax(), + self._convs.shape, + ) + self._best_temp = self.temp_bank[self._itemp] + self._best_snr = self._convs[self._itemp, self._peak_bin] @staticmethod - def get_width_spacing( - nbins_max: int, + def get_box_width_spacing( + size_max: int, spacing_factor: float = 1.5, ) -> np.ndarray: - """Get width spacing for matched filtering. + """ + Get box width spacing for matched filtering. Parameters ---------- - nbins_max : int - Maximum number of bins. + size_max : int + Maximum number of bins for box template width. spacing_factor : float, optional Spacing factor for width, by default 1.5 Returns ------- - np.ndarray + ndarray Width spacing for matched filtering. """ widths = [1] - while widths[-1] < nbins_max: + while widths[-1] < size_max: next_width = int(max(widths[-1] + 1, spacing_factor * widths[-1])) - if next_width > nbins_max: + if next_width > size_max: break widths.append(next_width) return np.array(widths, dtype=np.float32) +@attrs.define(auto_attribs=True, slots=True, frozen=True) class Template: - """1D pulse template class for matched filtering. + """ + 1D pulse template class for matched filtering. 
This class represents various pulse shapes as templates for matched filtering and provides methods to generate and visualize them. Parameters ---------- - kernel : Model1DKernel - Astropy 1D model kernel. + data : ndarray + Pulse template data array (1D). width : float Width of the pulse template in bins. + ref_bin : int, optional + Reference bin for the pulse template, by default 0 + ref : {"start", "peak"}, optional + Reference type for the pulse template, by default "start" kind : str, optional Type of the pulse template, by default "custom" """ - def __init__( - self, - kernel: Model1DKernel, - width: float, - kind: str = "custom", - ) -> None: - self._kernel = kernel - self._width = width - self._kind = kind - - @property - def kernel(self) -> Model1DKernel: - """Astropy 1D model kernel (`Model1DKernel`, read-only).""" - return self._kernel - - @property - def width(self) -> float: - """Width of the pulse template in bins (`float`, read-only).""" - return self._width + data: np.ndarray + width: float + ref_bin: int = attrs.field( + default=0, + validator=[attrs.validators.instance_of(int), attrs.validators.ge(0)], + ) + ref: str = attrs.field( + default="start", + validator=attrs.validators.in_({"start", "peak"}), + ) + kind: str = attrs.field( + default="custom", + validator=attrs.validators.instance_of(str), + ) + + def __attrs_post_init__(self) -> None: + if not self.data.size: + msg = "Empty data array is not supported." + raise ValueError(msg) + if self.data.ndim != 1: + msg = f"Only 1D data is supported, got {self.data.ndim}." + raise ValueError(msg) + if self.ref_bin >= self.data.size: + msg = f"Reference bin {self.ref_bin} is out of bounds." + raise ValueError(msg) - @property - def kind(self) -> str: - """Type of the pulse template (`str`, read-only).""" - return self._kind + def get_model(self, peak_bin: int, nbins: int) -> np.ndarray: + """ + Get profile model for the pulse template. 
- @property - def ref_bin(self) -> int: - """Reference bin of the pulse template (`int`, read-only).""" - return self.kernel.center[0] + Parameters + ---------- + peak_bin : int + Peak bin in the profile + nbins : int + Profile size - @property - def size(self) -> int: - """Size of the pulse template (`int`, read-only).""" - return self.kernel.shape[0] + Returns + ------- + ndarray + Profile model for the pulse template + """ + padded = np.pad(self.data, (0, nbins - self.data.size)) + padded_norm = kernels.normalize_template(padded) + return np.roll(padded_norm, peak_bin - self.ref_bin) - def get_padded(self, size: int) -> np.ndarray: + def get_on_pulse(self, peak_bin: int, nbins: int) -> tuple[int, int]: """ - Pad template to desired size. + Get on pulse region in the profile model for the pulse template. Parameters ---------- - size: int - Size of the padded pulse template. + peak_bin : int + Peak bin in the model + nbins : int + Profile size + + Returns + ------- + tuple[int, int] + Start and end bin of the on pulse region """ - if self.size >= size: - msg = f"Template size {self.size} is larger than {size}." - raise ValueError(msg) - return np.pad(self.kernel.array, (0, size - self.size)) + if self.ref == "start": + pulse_left = peak_bin + pulse_right = peak_bin + self.width + else: + pulse_left = peak_bin - round(self.width) + pulse_right = peak_bin + round(self.width) + start = max(0, pulse_left) + end = min(nbins, pulse_right) + return (start, int(end)) def plot( self, figsize: tuple[float, float] = (10, 5), dpi: int = 100, ) -> plt.Figure: - """Plot the pulse template. + """ + Plot the pulse template. Parameters ---------- @@ -248,15 +385,15 @@ def plot( Returns ------- - plt.Figure + Figure Matplotlib figure object. 
""" fig, ax = plt.subplots(figsize=figsize, dpi=dpi) - ax.bar(range(self.size), self.kernel.array, ec="k", fc="#a6cee3") + ax.bar(range(self.data.size), self.data, ec="k", fc="#a6cee3") ax.axvline(self.ref_bin, ls="--", lw=2, color="k", label="Ref Bin") ax.legend() ax.set( - xlim=(-0.5, self.size - 0.5), + xlim=(-0.5, self.data.size - 0.5), xlabel="Bin", ylabel="Amplitude", title=str(self), @@ -271,12 +408,20 @@ def gen_boxcar(cls, width: int) -> Template: Parameters ---------- - width: int + width : int Width of the box in bins. + + Returns + ------- + Template + Boxcar pulse template with the reference bin at the start. """ - norm = 1 / np.sqrt(width) - temp = Model1DKernel(Box1D(norm, 0, width), x_size=width) - return cls(temp, width, kind="boxcar") + width = int(width) + if width <= 0: + msg = f"Width {width} must be greater than 0." + raise ValueError(msg) + arr = np.ones(width, dtype=np.float32) + return cls(arr, width, ref_bin=0, ref="start", kind="boxcar") @classmethod def gen_gaussian(cls, width: float, extent: float = 3.5) -> Template: @@ -285,39 +430,57 @@ def gen_gaussian(cls, width: float, extent: float = 3.5) -> Template: Parameters ---------- - width: float + width : float FWHM of the Gaussian pulse in bins. - - extent: float + extent : float, optional Extent of the Gaussian pulse in sigma units, by default 3.5. + + Returns + ------- + Template + Gaussian pulse template with the reference bin at the peak. """ + if width <= 0: + msg = f"Width {width} must be greater than 0." 
+ raise ValueError(msg) stddev = gaussian_fwhm_to_sigma * width - norm = 1 / (np.sqrt(np.sqrt(np.pi) * stddev)) - size = int(np.ceil(extent * stddev) * 2 + 1) - temp = Model1DKernel(Gaussian1D(norm, 0, stddev), x_size=size) - return cls(temp, width, kind="gaussian") + size = int(np.ceil(extent * stddev)) + x = np.arange(-size, size + 1) + ref_bin = len(x) // 2 + arr = np.exp(-0.5 * x**2 / stddev**2) + return cls(arr, width, ref_bin=ref_bin, ref="peak", kind="gaussian") @classmethod def gen_lorentzian(cls, width: float, extent: float = 3.5) -> Template: """ - Generate a Lorentzian pulse template for given pulse FWHM (bins). + Generate a Lorentzian pulse template. Parameters ---------- - width: float + width : float FWHM of the Lorentzian pulse in bins. - - extent: float + extent : float, optional Extent of the Lorentzian pulse in sigma units, by default 3.5. + + Returns + ------- + Template + Lorentzian pulse template. """ + if width <= 0: + msg = f"Width {width} must be greater than 0." + raise ValueError(msg) stddev = gaussian_fwhm_to_sigma * width - norm = 1 / (np.sqrt((np.pi * width) / 4)) - size = int(np.ceil(extent * stddev) * 2 + 1) - temp = Model1DKernel(Lorentz1D(norm, 0, width), x_size=size) - return cls(temp, width, kind="lorentzian") + size = int(np.ceil(extent * stddev)) + x = np.arange(-size, size + 1) + ref_bin = len(x) // 2 + arr = 1 / (1 + (x / stddev) ** 2) + return cls(arr, width, ref_bin=ref_bin, ref="peak", kind="lorentzian") def __str__(self) -> str: - return f"Template(size={self.size}, kind={self.kind}, width={self.width:.3f})" + return ( + f"Template(size={self.data.size}, kind={self.kind}, width={self.width:.3f})" + ) def __repr__(self) -> str: return str(self) diff --git a/sigpyproc/core/kernels.py b/sigpyproc/core/kernels.py index d8d4a96..ac8696c 100644 --- a/sigpyproc/core/kernels.py +++ b/sigpyproc/core/kernels.py @@ -492,6 +492,7 @@ def fs_running_median( ) return out_arr + @njit("f4[:,:](f4[:], i8)", cache=True, fastmath=True) def 
sum_harmonics(pow_spec: np.ndarray, nfolds: int) -> np.ndarray: nfreqs = len(pow_spec) @@ -708,3 +709,80 @@ def detrend_1d(arr: np.ndarray) -> np.ndarray: trend = slope * np.arange(m, dtype=arr.dtype) + intercept return arr - trend.astype(arr.dtype) + + +@njit(cache=True, fastmath=True) +def convolve_fft( + data: np.ndarray, + temp_bank: types.List[types.Array], + ref_bin: types.List[int], +) -> np.ndarray: + """ + Convolve the data with the templates in the template bank. + + Parameters + ---------- + data : np.ndarray + Input data array. + temp_bank : list[np.ndarray] + List of template arrays. + ref_bin : list[int] + List of reference bin indices. + + Returns + ------- + np.ndarray + Convolved array. + + Notes + ----- + The reference bin is aligned to the index 0 and the template is time-reversed + (to perform convolution rather than correlation). The template is then + normalised to zero mean and unity power. + """ + nbins = len(data) + ntemps = len(temp_bank) + convs = np.empty((ntemps, nbins), dtype=data.dtype) + data_pad = circular_pad_pow2(data) + data_fft = np.fft.rfft(data_pad) + for itemp in range(ntemps): + temp_kernel = temp_bank[itemp] + temp_pad = np.zeros(data_pad.size, dtype=data.dtype) + temp_pad[: len(temp_kernel)] = temp_kernel + # Align the reference bin to the index 0 + temp_pad_roll = np.roll(temp_pad, -ref_bin[itemp]) + # Time reverse the template (for convolution) + temp_pad_aligned = np.roll(temp_pad_roll[::-1], 1) + temp_norm = normalize_template(temp_pad_aligned) + conv = np.fft.irfft(data_fft * np.fft.rfft(temp_norm)) + convs[itemp] = conv[:nbins] + return convs + + +@njit(cache=True, fastmath=True) +def circular_pad_pow2(arr: np.ndarray) -> np.ndarray: + nbins = len(arr) + nbins_pow2 = 2 ** int(np.ceil(np.log2(nbins))) + result = np.empty(nbins_pow2, dtype=arr.dtype) + for i in range(nbins_pow2): + result[i] = arr[i % nbins] + return result + + +@njit(cache=True, fastmath=True) +def normalize_template(arr: np.ndarray) -> np.ndarray: + 
""" + Normalize the template to have zero mean and unit power. + + Parameters + ---------- + arr : np.ndarray + Template array. + + Returns + ------- + np.ndarray + Normalized template array. + """ + arr_norm = arr - np.mean(arr) + return arr_norm / (np.dot(arr_norm, arr_norm) ** 0.5) diff --git a/sigpyproc/core/rfi.py b/sigpyproc/core/rfi.py index 706d086..0cadb86 100644 --- a/sigpyproc/core/rfi.py +++ b/sigpyproc/core/rfi.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import attrs import h5py @@ -14,53 +14,55 @@ def double_mad_mask(array: np.ndarray, threshold: float = 3) -> np.ndarray: - """Calculate the mask of an array using the double MAD (Modified z-score). + """ + Calculate the mask of an array using the double MAD (Modified z-score). Parameters ---------- - array : :py:obj:`~numpy.ndarray` - The array to calculate the mask of. + array : ndarray + The input array to calculate the mask of. threshold : float, optional - Threshold in sigmas, by default 3.0 + Threshold in sigmas, by default 3.0. Returns ------- - :py:obj:`~numpy.ndarray` + ndarray The mask for the array. Raises ------ ValueError - If the threshold is not positive. + If the ``threshold`` is not positive. """ if threshold <= 0: msg = f"threshold must be positive, got {threshold}" raise ValueError(msg) - zscore_re = stats.zscore(array, scale_method="doublemad") - return np.abs(zscore_re.zscores) > threshold + zscore = stats.estimate_zscore(array, scale_method="doublemad") + return np.abs(zscore.data) > threshold def iqrm_mask(array: np.ndarray, threshold: float = 3, radius: int = 5) -> np.ndarray: - """Calculate the mask of an array using the IQRM (Interquartile Range Method). + """ + Calculate the mask of an array using the IQRM (Interquartile Range Method). Parameters ---------- - array : :py:obj:`~numpy.ndarray` - The array to calculate the mask of. + array : ndarray + The input array to calculate the mask of. 
threshold : float, optional - Threshold in sigmas, by default 3.0 + Threshold in sigmas, by default 3.0. radius : int, optional - Radius to calculate the IQRM, by default 5 + Radius to calculate the IQRM, by default 5. Returns ------- - :py:obj:`~numpy.ndarray` + ndarray The mask for the array. Raises ------ ValueError - If the threshold is not positive. + If the ``threshold`` is not positive. """ if threshold <= 0: msg = f"threshold must be positive, got {threshold}" @@ -75,8 +77,8 @@ def iqrm_mask(array: np.ndarray, threshold: float = 3, radius: int = 5) -> np.nd lagged_diffs = array[:, np.newaxis] - shifted_x[:, lags + radius] lagged_diffs = lagged_diffs.T for lagged_diff in lagged_diffs: - zscore_re = stats.zscore(lagged_diff, scale_method="iqr") - mask = np.logical_or(mask, np.abs(zscore_re.zscores) > threshold) + zscore = stats.estimate_zscore(lagged_diff, scale_method="iqr") + mask = np.logical_or(mask, np.abs(zscore.data) > threshold) return mask @@ -99,26 +101,27 @@ def _set_chan_mask(self) -> np.ndarray: @property def num_masked(self) -> int: - """int: Number of masked channels.""" + """:obj:`int`: Number of masked channels.""" return np.sum(self.chan_mask) @property def masked_fraction(self) -> float: - """float: Fraction of channels masked.""" + """:obj:`float`: Fraction of channels masked.""" return self.num_masked * 100 / self.header.nchans def apply_mask(self, chanmask: np.ndarray) -> None: - """Apply a channel mask to the current mask. + """ + Apply a channel mask to the current mask. Parameters ---------- - chanmask : :py:obj:`~numpy.typing.ArrayLike` + chanmask : ndarray User channel mask to apply. Raises ------ ValueError - If the channel mask is not the same size as the current mask. + If the ``chanmask`` is not the same size as the current mask. 
""" chanmask = np.asarray(chanmask, dtype="bool") if chanmask.size != self.header.nchans: @@ -126,18 +129,19 @@ def apply_mask(self, chanmask: np.ndarray) -> None: raise ValueError(msg) self.chan_mask = np.logical_or(self.chan_mask, chanmask) - def apply_method(self, method: str = "mad") -> None: - """Apply a mask method using channel statistics. + def apply_method(self, method: Literal["iqrm", "mad"] = "mad") -> None: + """ + Apply a mask method using channel statistics. Parameters ---------- - method : str - Mask method to apply (`iqrm`, `mad`). + method : {'iqrm', 'mad'}, optional + Method to apply, by default 'mad'. Raises ------ ValueError - If the method is not supported. + If the ``method`` is not supported. """ if method == "mad": method_funcn = double_mad_mask @@ -153,17 +157,18 @@ def apply_method(self, method: str = "mad") -> None: self.chan_mask = np.logical_or(self.chan_mask, mask_stats) def apply_funcn(self, custom_funcn: Callable[[np.ndarray], np.ndarray]) -> None: - """Apply a custom function to the channel mask. + """ + Apply a custom function to the channel mask. Parameters ---------- - custom_funcn : :py:obj:`~typing.Callable` + custom_funcn : Callable[[ndarray], ndarray] Custom function to apply to the mask. Raises ------ ValueError - If the custom_funcn is not callable. + If the ``custom_funcn`` is not callable. """ if not callable(custom_funcn): msg = f"{custom_funcn} is not callable" @@ -171,7 +176,8 @@ def apply_funcn(self, custom_funcn: Callable[[np.ndarray], np.ndarray]) -> None: self.chan_mask = custom_funcn(self.chan_mask) def to_file(self, filename: str | None = None) -> str: - """Write the mask to a HDF5 file. + """ + Write the mask to a HDF5 file. Parameters ---------- @@ -197,7 +203,8 @@ def to_file(self, filename: str | None = None) -> str: @classmethod def from_file(cls, filename: str) -> RFIMask: - """Load a mask from a HDF5 file. + """ + Load a mask from a HDF5 file. 
Parameters ---------- diff --git a/sigpyproc/core/stats.py b/sigpyproc/core/stats.py index 0d8d449..d64356b 100644 --- a/sigpyproc/core/stats.py +++ b/sigpyproc/core/stats.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Literal import attrs import bottleneck as bn @@ -9,22 +9,36 @@ from sigpyproc.core import kernels +LocMethodType = Literal["median", "mean"] +ScaleMethodType = Literal[ + "iqr", + "mad", + "doublemad", + "diffcov", + "biweight", + "qn", + "sn", + "gapper", +] + @attrs.define(auto_attribs=True, slots=True, kw_only=True) class ZScoreResult: - """Result of a Z-score calculation. + """ + Container for Z-score calculation results. - Attributes + Parameters ---------- - zscores: numpy.ndarray - Robust Z-scores of the array. + data: ndarray + Robust Z-scores of the input array (normalized data). loc: float - Estimated location used for the Z-score calculation. - scale: float | numpy.ndarray - Estimated scale used for the Z-score calculation. + Estimated location (central tendency) used for the Z-score calculation. + scale: float | ndarray + Estimated scale (variability) used for the Z-score calculation. + Can be a scalar or an array matching the shape of `data`. """ - zscores: np.ndarray + data: np.ndarray loc: float scale: float | np.ndarray @@ -32,33 +46,35 @@ class ZScoreResult: def running_filter( array: np.ndarray, window: int, - method: str = "mean", + method: Literal["mean", "median"] = "mean", ) -> np.ndarray: """ Calculate the running filter of an array. + Applies a sliding window filter to the input array using the specified method. + Parameters ---------- - array : numpy.ndarray - The array to calculate the running filter of. + array : ndarray + The input array to filter. window : int - The window size of the filter. - method : str, optional - The method to use for the filter, by default "mean". + The size of the sliding window. 
+ method : {"mean", "median"}, optional + The filtering method to use, by default "mean". Returns ------- - numpy.ndarray - The running filter of the array. + ndarray + The filtered array with the same shape as the input array. Raises ------ ValueError - If the filter function is not "mean" or "median". + If the ``method`` is not supported. Notes ----- - Window edges are handled by reflecting about the edges. + Window edges are handled by reflecting about the edges of the input array. """ filter_methods: dict[str, Callable[[np.ndarray, int], np.ndarray]] = { @@ -78,14 +94,18 @@ def running_filter( return filtered_ar[window - 1 :] -def estimate_loc(array: np.ndarray, method: str = "median") -> float: - """Estimate the location of an array. +def estimate_loc( + array: np.ndarray, + method: LocMethodType = "median", +) -> float: + """ + Estimate the location (central tendency) of an array. Parameters ---------- - array : numpy.ndarray - The array to estimate the location of. - method : str, optional + array : ndarray + The input array to estimate the location of. + method : {"median", "mean"}, optional The method to use for estimating the location, by default "median". Returns @@ -96,7 +116,7 @@ def estimate_loc(array: np.ndarray, method: str = "median") -> float: Raises ------ ValueError - If the method is not supported + If the ``array`` is empty or if the ``method`` is not supported. """ loc_methods: dict[str, Callable[[np.ndarray], float]] = { "median": np.median, @@ -114,38 +134,47 @@ def estimate_loc(array: np.ndarray, method: str = "median") -> float: return loc_func(array) -def estimate_scale(array: np.ndarray, method: str = "mad") -> float | np.ndarray: - """Estimate the scale or standard deviation of an array. +def estimate_scale( + array: np.ndarray, + method: ScaleMethodType = "mad", +) -> float | np.ndarray: + """ + Estimate the scale (variability) or standard deviation of an array. 
Parameters ---------- - array : numpy.ndarray - The array to estimate the scale of. - method : str, optional + array : ndarray + The input array to estimate the scale of. + method : {"iqr", "mad", "doublemad", "diffcov", "biweight", "qn", + "sn", "gapper"}, optional + The method to use for estimating the scale, by default "mad". + - `iqr`: Normalized Inter-quartile Range. + - `mad`: Median Absolute Deviation. + - `doublemad`: Double MAD. + - `diffcov`: Difference Covariance + - `biweight`: Biweight Midvariance + - `qn`: Normalized Qn scale + - `sn`: Normalized Sn scale + - `gapper`: Gapper Estimator + Returns ------- - float | numpy.ndarray - The estimated scale of the array. + float | ndarray + The estimated scale of the array. If the method is "doublemad", the + output is an array of the same shape as the input array. Raises ------ ValueError - If the method is not supported or if the array is empty. + If the ``array`` is empty or if the ``method`` is not supported. + + References + ---------- + .. [1] Wikipedia, "Robust measures of scale", + https://en.wikipedia.org/wiki/Robust_measures_of_scale - Notes - ----- - https://en.wikipedia.org/wiki/Robust_measures_of_scale - - Following methods are supported: - - "iqr": Normalized Inter-quartile Range - - "mad": Median Absolute Deviation - - "diffcov": Difference Covariance - - "biweight": Biweight Midvariance - - "qn": Normalized Qn scale - - "sn": Normalized Sn scale - - "gapper": Gapper Estimator """ scale_methods: dict[str, Callable[[np.ndarray], float | np.ndarray]] = { "iqr": _scale_iqr, @@ -169,41 +198,53 @@ def estimate_scale(array: np.ndarray, method: str = "mad") -> float | np.ndarray return scale_func(array) -def zscore( +def estimate_zscore( array: np.ndarray, - loc_method: str = "median", - scale_method: str = "mad", + loc_method: LocMethodType | Literal["norm"] = "median", + scale_method: ScaleMethodType | Literal["norm"] = "mad", ) -> ZScoreResult: - """Calculate robust Z-scores of an array. 
+ """ + Calculate robust Z-scores of an array. Parameters ---------- - array : numpy.ndarray - The array to calculate the Z-score of. - loc_method : str, optional + array : ndarray + The input array to calculate the Z-score of. + loc_method : {"median", "mean", "norm"}, optional The method to use for estimating the location, by default "median". - scale_method : str, optional + + Use "norm" to set the location to 0. + scale_method : {"mad", "iqr", "doublemad", "diffcov", "biweight", "qn", "sn", + "gapper", "norm"}, optional + The method to use for estimating the scale, by default "mad". + Use "norm" to set the scale to 1. + Returns ------- ZScoreResult - The robust Z-scores of the array. + A container with the Z-scores, estimated location, and scale. Raises ------ ValueError - If the location or scale method is not supported. + If the ``loc_method`` or ``scale_method`` is not supported. + + See Also + -------- + estimate_loc, estimate_scale """ - loc = estimate_loc(array, loc_method) - scale = estimate_scale(array, scale_method) + loc = 0 if loc_method == "norm" else estimate_loc(array, loc_method) + scale = 1 if scale_method == "norm" else estimate_scale(array, scale_method) diff = array - loc zscores = np.divide(diff, scale, out=np.zeros_like(diff), where=scale != 0) - return ZScoreResult(zscores=zscores, loc=loc, scale=scale) + return ZScoreResult(data=zscores, loc=loc, scale=scale) def _scale_iqr(array: np.ndarray) -> float: - """Calculate the normalized Inter-quartile Range (IQR) scale of an array. + """ + Calculate the normalized Inter-quartile Range (IQR) scale of an array. Parameters ---------- @@ -222,7 +263,8 @@ def _scale_iqr(array: np.ndarray) -> float: def _scale_mad(array: np.ndarray) -> float: - """Calculate the Median Absolute Deviation (MAD) scale of an array. + """ + Calculate the Median Absolute Deviation (MAD) scale of an array. 
Parameters ---------- @@ -234,9 +276,10 @@ def _scale_mad(array: np.ndarray) -> float: float The MAD scale of the array. - Notes - ----- - https://www.ibm.com/docs/en/cognos-analytics/12.0.0?topic=terms-modified-z-score + References + ---------- + .. [1] IBM, "Modified Z-score", + https://www.ibm.com/docs/en/cognos-analytics/12.0.0?topic=terms-modified-z-score """ norm = 0.6744897501960817 # scipy.stats.norm.ppf(0.75) norm_aad = np.sqrt(2 / np.pi) @@ -248,7 +291,8 @@ def _scale_mad(array: np.ndarray) -> float: def _scale_doublemad(array: np.ndarray) -> np.ndarray: - """Calculate the Double MAD scale of an array. + """ + Calculate the Double MAD scale of an array. Parameters ---------- @@ -260,11 +304,12 @@ def _scale_doublemad(array: np.ndarray) -> np.ndarray: np.ndarray The Double MAD scale of the array. - Notes - ----- - The Double MAD is defined as: - https://eurekastatistics.com/using-the-median-absolute-deviation-to-find-outliers/ - https://aakinshin.net/posts/harrell-davis-double-mad-outlier-detector/ + References + ---------- + .. [1] Eureka Statistics, "Using the Median Absolute Deviation to Find Outliers", + https://eurekastatistics.com/using-the-median-absolute-deviation-to-find-outliers/ + .. [2] A. Akinshin, "Harrell-Davis Double MAD Outlier Detector", + https://aakinshin.net/posts/harrell-davis-double-mad-outlier-detector/ """ norm = 0.6744897501960817 # scipy.stats.norm.ppf(0.75) @@ -287,7 +332,8 @@ def _scale_doublemad(array: np.ndarray) -> np.ndarray: def _scale_diffcov(array: np.ndarray) -> float: - """Calculate the Difference Covariance scale of an array. + """ + Calculate the Difference Covariance scale of an array. Parameters ---------- @@ -304,7 +350,8 @@ def _scale_diffcov(array: np.ndarray) -> float: def _scale_biweight(array: np.ndarray) -> float: - """Calculate the Biweight Midvariance scale of an array. + """ + Calculate the Biweight Midvariance scale of an array. 
Parameters ---------- @@ -320,7 +367,8 @@ def _scale_biweight(array: np.ndarray) -> float: def _scale_qn(array: np.ndarray) -> float: - """Calculate the Normalized Qn scale of an array. + """ + Calculate the Normalized Qn scale of an array. Parameters ---------- @@ -342,7 +390,8 @@ def _scale_qn(array: np.ndarray) -> float: def _scale_sn(array: np.ndarray) -> float: - """Calculate the Normalized Sn scale of an array. + """ + Calculate the Normalized Sn scale of an array. Parameters ---------- @@ -359,7 +408,8 @@ def _scale_sn(array: np.ndarray) -> float: def _scale_gapper(array: np.ndarray) -> float: - """Calculate the Gapper Estimator scale of an array. + """ + Calculate the Gapper Estimator scale of an array. Parameters ---------- @@ -378,71 +428,74 @@ def _scale_gapper(array: np.ndarray) -> float: class ChannelStats: - def __init__(self, nchans: int, nsamps: int) -> None: - """Central central moments for filterbank channels in one pass. + """ + A class to compute the central moments of filterbank data in one pass. - Parameters - ---------- - nchans : int - Number of channels in the data. - nsamps : int - Number of samples in the data. - - Notes - ----- - The algorithm is numerically stable and accurate: + Parameters + ---------- + nchans : int + Number of channels in the data. + nsamps : int + Number of samples in the data. + + References + ---------- + .. [1] Wikipedia, "Algorithms for calculating variance", https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm + .. [2] John D. Cook, "Skewness and kurtosis formulas for normal distributions", https://www.johndcook.com/blog/skewness_kurtosis/ + .. 
[3] Pebay, Philippe P., "One-Pass covariances and Statistical Moments", https://doi.org/10.2172/1028931 - """ + """ + + def __init__(self, nchans: int, nsamps: int) -> None: self._nchans = nchans self._nsamps = nsamps - self._moments = np.zeros(nchans, dtype=kernels.moments_dtype) @property def moments(self) -> np.ndarray: - """:class:`~numpy.ndarray`: Central moments of the data.""" + """:class:`~numpy.ndarray`: Get the central moments of the data.""" return self._moments @property def nchans(self) -> int: - """int: Get the number of channels.""" + """:obj:`int`: Get the number of channels.""" return self._nchans @property def nsamps(self) -> int: - """int: Get the number of samples.""" + """:obj:`int`: Get the number of samples.""" return self._nsamps @property def maxima(self) -> np.ndarray: - """numpy.ndarray: Get the maximum value of each channel.""" + """:class:`~numpy.ndarray`: Get the maximum value of each channel.""" return self._moments["max"] @property def minima(self) -> np.ndarray: - """numpy.ndarray: Get the minimum value of each channel.""" + """:class:`~numpy.ndarray`: Get the minimum value of each channel.""" return self._moments["min"] @property def mean(self) -> np.ndarray: - """numpy.ndarray: Get the mean of each channel.""" + """:class:`~numpy.ndarray`: Get the mean of each channel.""" return self._moments["m1"] @property def var(self) -> np.ndarray: - """numpy.ndarray: Get the variance of each channel.""" + """:class:`~numpy.ndarray`: Get the variance of each channel.""" return self._moments["m2"] / self.nsamps @property def std(self) -> np.ndarray: - """numpy.ndarray: Get the standard deviation of each channel.""" + """:class:`~numpy.ndarray`: Get the standard deviation of each channel.""" return np.sqrt(self.var) @property def skew(self) -> np.ndarray: - """numpy.ndarray: Get the skewness of each channel.""" + """:class:`~numpy.ndarray`: Get the skewness of each channel.""" return np.divide( self._moments["m3"], np.power(self._moments["m2"], 
1.5), @@ -452,7 +505,7 @@ def skew(self) -> np.ndarray: @property def kurtosis(self) -> np.ndarray: - """numpy.ndarray: Get the kurtosis of each channel.""" + """:class:`~numpy.ndarray`: Get the kurtosis of each channel.""" return ( np.divide( self._moments["m4"], @@ -468,8 +521,23 @@ def push_data( self, array: np.ndarray, start_index: int, - mode: str = "basic", + mode: Literal["basic", "full"] = "basic", ) -> None: + """ + Update the central moments of the data with new samples. + + Parameters + ---------- + array : ndarray + The input array to update the moments with. + start_index : int + The starting time (sample) index of the data. + mode : {"basic", "full"}, optional + The mode to use for computing the moments, by default "basic". + + - "basic": Compute the moments up to 2nd order (variance). + - "full": Compute the moments up to 4th order (kurtosis). + """ if mode == "basic": kernels.compute_online_moments_basic( array, @@ -480,7 +548,8 @@ def push_data( kernels.compute_online_moments(array, self._moments, start_index) def __add__(self, other: ChannelStats) -> ChannelStats: - """Add two ChannelStats objects together as if all the data belonged to one. + """ + Add two ChannelStats objects together as if all the data belonged to one.
Parameters ---------- diff --git a/tests/test_filters.py b/tests/test_filters.py index 31ac1d8..4782d70 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -3,32 +3,39 @@ from matplotlib import pyplot as plt from sigpyproc.core.filters import MatchedFilter, Template +from sigpyproc.core.stats import ZScoreResult @pytest.fixture(scope="module", autouse=True) def random_data() -> np.ndarray: rng = np.random.default_rng(42) - return rng.normal(loc=0, scale=1, size=1000) + return rng.normal(loc=0, scale=1, size=1000).astype(np.float32) @pytest.fixture(scope="module", autouse=True) def pulse_data() -> np.ndarray: rng = np.random.default_rng(42) - x = np.linspace(-10, 10, 1000) + x = np.linspace(-10, 10, 1000, dtype=np.float32) pulse = np.exp(-0.5 * (x**2)) - noise = rng.normal(0, 0.1, 1000) + noise = rng.normal(0, 0.1, 1000).astype(np.float32) return pulse + noise class TestMatchedFilter: def test_initialization(self, pulse_data: np.ndarray) -> None: - mf = MatchedFilter(pulse_data) + mf = MatchedFilter(pulse_data, temp_kind="boxcar") assert isinstance(mf, MatchedFilter) assert mf.data.shape == pulse_data.shape assert mf.temp_kind == "boxcar" - assert mf.noise_method == "iqr" + assert isinstance(mf.zscores, ZScoreResult) with pytest.raises(ValueError): - MatchedFilter(np.zeros((2, 2))) + MatchedFilter(np.zeros((2, 2), dtype=np.float32)) + + def test_fails(self, pulse_data: np.ndarray) -> None: + with pytest.raises(ValueError): + MatchedFilter(pulse_data, temp_kind="gaussian", spacing_factor=1) + with pytest.raises(ValueError): + MatchedFilter(np.ones(10), nbins_max=20) def test_convolution(self, pulse_data: np.ndarray) -> None: mf = MatchedFilter(pulse_data) @@ -46,13 +53,13 @@ def test_plot(self, pulse_data: np.ndarray) -> None: assert isinstance(fig, plt.Figure) plt.close(fig) - @pytest.mark.parametrize(("nbins_max", "spacing_factor"), [(16, 1.2), (64, 2.0)]) - def test_width_spacing(self, nbins_max: int, spacing_factor: float) -> None: - widths = 
MatchedFilter.get_width_spacing(nbins_max, spacing_factor) + @pytest.mark.parametrize(("size_max", "spacing_factor"), [(16, 1.2), (64, 2.0)]) + def test_get_box_width_spacing(self, size_max: int, spacing_factor: float) -> None: + widths = MatchedFilter.get_box_width_spacing(size_max, spacing_factor) assert isinstance(widths, np.ndarray) assert widths[0] == 1 assert np.all(np.diff(widths) > 0) - assert widths[-1] <= nbins_max + assert widths[-1] <= size_max class TestTemplate: @@ -62,7 +69,7 @@ def test_boxcar(self, width: int) -> None: assert isinstance(temp, Template) assert temp.kind == "boxcar" assert temp.width == width - assert temp.size == width + assert temp.data.size == width assert str(temp) == repr(temp) assert "Template" in str(temp) @@ -71,24 +78,46 @@ def test_gaussian(self, width: float, extent: float) -> None: temp = Template.gen_gaussian(width, extent) assert isinstance(temp, Template) assert temp.kind == "gaussian" - assert temp.width == width - assert temp.size == int(np.ceil(extent * width / 2.355) * 2 + 1) + np.testing.assert_equal(temp.width, width) + expected_size = int(np.ceil(extent * width / 2.355) * 2 + 1) + np.testing.assert_equal(temp.data.size, expected_size) @pytest.mark.parametrize(("width", "extent"), [(1.0, 3.5), (5.0, 4.0)]) def test_lorentzian(self, width: float, extent: float) -> None: temp = Template.gen_lorentzian(width, extent) assert isinstance(temp, Template) assert temp.kind == "lorentzian" - assert temp.width == width - assert temp.size == int(np.ceil(extent * width / 2.355) * 2 + 1) + np.testing.assert_equal(temp.width, width) + expected_size = int(np.ceil(extent * width / 2.355) * 2 + 1) + np.testing.assert_equal(temp.data.size, expected_size) - def test_get_padded(self) -> None: - temp = Template.gen_boxcar(5) - padded = temp.get_padded(10) - assert len(padded) == 10 - assert np.all(padded[5:] == 0) + def test_fails(self) -> None: + with pytest.raises(ValueError): + Template.gen_boxcar(-1) + with 
pytest.raises(ValueError): + Template.gen_gaussian(-1) with pytest.raises(ValueError): - temp.get_padded(3) + Template.gen_lorentzian(-1) + with pytest.raises(ValueError): + Template(np.array([]), 5, 2) + with pytest.raises(ValueError): + Template(np.zeros((10, 10)), 5, 2) + with pytest.raises(ValueError): + Template(np.zeros(10), 5, 20) + + def test_get_model(self) -> None: + temp = Template.gen_boxcar(5) + model = temp.get_model(peak_bin=10, nbins=20) + np.testing.assert_equal(model.size, 20) + assert np.all(model[10:15] > 0) + + def test_get_on_pulse(self) -> None: + temp = Template.gen_boxcar(5) + on_pulse = temp.get_on_pulse(peak_bin=10, nbins=20) + np.testing.assert_equal(on_pulse, (10, 15)) + temp = Template.gen_gaussian(5) + on_pulse = temp.get_on_pulse(peak_bin=10, nbins=20) + np.testing.assert_equal(on_pulse, (5, 15)) def test_plot(self) -> None: temp = Template.gen_gaussian(5) @@ -99,7 +128,7 @@ def test_plot(self) -> None: @pytest.mark.parametrize("temp_kind", ["boxcar", "gaussian", "lorentzian"]) def test_end_to_end(pulse_data: np.ndarray, temp_kind: str) -> None: - mf = MatchedFilter(pulse_data, temp_kind=temp_kind) + mf = MatchedFilter(pulse_data, temp_kind=temp_kind) # type: ignore[arg-type] assert mf.snr > 5 # Assuming a strong pulse assert 450 < mf.peak_bin < 550 # Assuming pulse is roughly in the middle - + assert isinstance(mf.best_model, np.ndarray) diff --git a/tests/test_stats.py b/tests/test_stats.py index 659e10a..dcb43d8 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -126,7 +126,7 @@ def test_zscore_methods( loc_method: str, scale_method: str, ) -> None: - zscore_re = stats.zscore( + zscore_re = stats.estimate_zscore( random_normal, loc_method=loc_method, scale_method=scale_method, @@ -134,30 +134,30 @@ def test_zscore_methods( assert isinstance(zscore_re, stats.ZScoreResult) np.testing.assert_almost_equal(5, zscore_re.loc, decimal=1) np.testing.assert_allclose(2, zscore_re.scale, atol=0.3) - assert zscore_re.zscores.shape 
== random_normal.shape - np.testing.assert_almost_equal(0, zscore_re.zscores.mean(), decimal=1) - np.testing.assert_almost_equal(1, zscore_re.zscores.std(), decimal=1) + assert zscore_re.data.shape == random_normal.shape + np.testing.assert_almost_equal(0, zscore_re.data.mean(), decimal=1) + np.testing.assert_almost_equal(1, zscore_re.data.std(), decimal=1) def test_zscore_double_mad(self, random_normal: np.ndarray) -> None: - zscore_re = stats.zscore(random_normal, scale_method="doublemad") + zscore_re = stats.estimate_zscore(random_normal, scale_method="doublemad") assert isinstance(zscore_re, stats.ZScoreResult) np.testing.assert_almost_equal(5, zscore_re.loc, decimal=1) assert isinstance(zscore_re.scale, np.ndarray) np.testing.assert_allclose(2, zscore_re.scale.mean(), atol=0.3) assert zscore_re.scale.shape == random_normal.shape - assert zscore_re.zscores.shape == random_normal.shape - np.testing.assert_almost_equal(0, zscore_re.zscores.mean(), decimal=1) - np.testing.assert_almost_equal(1, zscore_re.zscores.std(), decimal=1) + assert zscore_re.data.shape == random_normal.shape + np.testing.assert_almost_equal(0, zscore_re.data.mean(), decimal=1) + np.testing.assert_almost_equal(1, zscore_re.data.std(), decimal=1) def test_zscore_empty(self) -> None: with pytest.raises(ValueError): - stats.zscore(np.array([])) + stats.estimate_zscore(np.array([])) def test_zscore_invalid(self, random_normal: np.ndarray) -> None: with pytest.raises(ValueError): - stats.zscore(random_normal, loc_method="invalid") + stats.estimate_zscore(random_normal, loc_method="invalid") with pytest.raises(ValueError): - stats.zscore(random_normal, scale_method="invalid") + stats.estimate_zscore(random_normal, scale_method="invalid") class TestRunningFilter: From 0f893766fc724b3ddf0f98f4cf6f8acd45f67d76 Mon Sep 17 00:00:00 2001 From: Pravir Kumar Date: Mon, 16 Sep 2024 03:14:26 +0300 Subject: [PATCH 2/2] added tests for new njit functions --- tests/test_kernels.py | 29 
+++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/test_kernels.py b/tests/test_kernels.py index 61f237f..531ed62 100644 --- a/tests/test_kernels.py +++ b/tests/test_kernels.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from numba import typed from scipy import signal, stats from sigpyproc.core import kernels @@ -347,6 +348,34 @@ def test_detrend_1d_fail(self) -> None: with pytest.raises(ValueError): kernels.detrend_1d.py_func(np.array([])) + def test_normalize_template(self) -> None: + arr = np.ones(10, dtype=np.float32) + temp = np.pad(arr, (0, 100 - arr.size), mode="constant") + temp_norm = kernels.normalize_template(temp) + np.testing.assert_equal(temp_norm.size, temp.size) + np.testing.assert_array_almost_equal(np.mean(temp_norm), 0.0, decimal=5) + np.testing.assert_array_almost_equal( + np.sqrt(np.sum(temp_norm**2)), + 1.0, + decimal=5, + ) + kernels.normalize_template.py_func(temp) + + def test_circular_pad_pow2(self, random_normal_1d: np.ndarray) -> None: + padded = kernels.circular_pad_pow2(random_normal_1d) + np.testing.assert_equal(padded.size, 1024) + np.testing.assert_array_equal(padded[:1000], random_normal_1d) + np.testing.assert_array_equal(padded[1000:], random_normal_1d[:24]) + kernels.circular_pad_pow2.py_func(random_normal_1d) + + def test_convolve_fft(self, random_normal_1d: np.ndarray) -> None: + samp_temps = typed.List([np.array([0.5, 1.0, 0.5]), np.array([1.0, -1.0])]) + ref_bins = typed.List([1, 0]) + result = kernels.convolve_fft(random_normal_1d, samp_temps, ref_bins) + assert isinstance(result, np.ndarray) + assert result.shape == (len(samp_temps), len(random_normal_1d)) + kernels.convolve_fft.py_func(random_normal_1d, samp_temps, ref_bins) + class TestFourierKernels: def test_form_mspec(self, random_normal_1d_complex: np.ndarray) -> None: