diff --git a/Makefile b/Makefile index be7d4afa..5132ca7d 100644 --- a/Makefile +++ b/Makefile @@ -60,23 +60,32 @@ pydocstyle: @pydocstyle pep: - @$(MAKE) -k flake pydocstyle codespell-error - -flake: - @if command -v flake8 > /dev/null; then \ - echo "Running flake8"; \ - flake8 --count meegkit examples; \ - else \ - echo "flake8 not found, please install it!"; \ - exit 1; \ - fi; - @echo "flake8 passed" + @$(MAKE) -k ruff codespell -# Tests +ruff: + @ruff check $(CODESPELL_DIRS) + +ruff-fix: + @ruff check $(CODESPELL_DIRS) --fix + +# Build and install # ============================================================================= -# test: -# py.test tests +install-requirements: + @echo "Checking/Installing requirements..." + @pip install -q -r requirements.in + +install: + @echo "Installing package..." + @pip install -q --no-deps . + @echo "\x1b[1m\x1b[32m * Package successfully installed! \x1b[0m" + +install-dev: + @echo "Installing package in editable mode..." + @pip install -q -e ".[docs, tests]" --config-settings editable_mode=compat + @echo "\x1b[1m\x1b[32m * Package successfully installed! \x1b[0m" +# Tests +# ============================================================================= test: in rm -f .coverage $(PYTESTS) -m 'not ultraslowtest' meegkit @@ -94,3 +103,6 @@ test-full: in $(PYTESTS) meegkit .PHONY: init test + + + diff --git a/README.md b/README.md index d5dc75bd..809e023e 100644 --- a/README.md +++ b/README.md @@ -5,19 +5,23 @@ [![DOI](https://zenodo.org/badge/117451752.svg)](https://zenodo.org/badge/latestdoi/117451752) [![twitter](https://img.shields.io/twitter/follow/lebababa?style=flat&logo=Twitter)](https://twitter.com/intent/follow?screen_name=lebababa) -# MEEGkit +# `MEEGkit` -Denoising tools for M/EEG processing in Python 3.7+. +Denoising tools for M/EEG processing in Python 3.8+. ![meegkit-ERP](https://user-images.githubusercontent.com/10333715/176754293-eaa35071-94f8-40dd-a487-9f8103c92571.png) -> **Disclaimer:** The project mostly consists of development code, although some modules and functions are already working. Bugs and performance problems are to be expected, so use at your own risk. More tests and improvements will be added in the future. Comments and suggestions are welcome. +> **Disclaimer:** The project mostly consists of development code, although some modules +and functions are already working. Bugs and performance problems are to be expected, so +use at your own risk. More tests and improvements will be added in the future. Comments +and suggestions are welcome. ## Documentation Automatic documentation is [available online](https://nbara.github.io/python-meegkit/). -This code can also be tested directly from your browser using [Binder](https://mybinder.org), by clicking on the binder badge above. +This code can also be tested directly from your browser using +[Binder](https://mybinder.org), by clicking on the binder badge above. ## Installation @@ -27,18 +31,21 @@ This package can be installed easily using `pip`: pip install meegkit ``` -Or you can clone this repository and run the following commands inside the `python-meegkit` directory: +Or you can clone this repository and run the following commands inside the +`python-meegkit` directory: ```bash pip install -r requirements.txt pip install . ``` -*Note* : Use developer mode with the `-e` flag (`pip install -e .`) to be able to modify the sources even after install. +*Note* : Use developer mode with the `-e` flag (`pip install -e .`) to be able to modify +the sources even after install. 
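Either install flavor can be smoke-tested from Python. A minimal sketch (illustrative, not part of the repository); the expected version string matches the `0.1.4` bump in `meegkit/__init__.py` later in this changeset:

```python
# Post-install smoke test (illustrative snippet, not shipped with the package).
import meegkit

print(meegkit.__version__)  # "0.1.4" for this release
print(meegkit.__all__)      # available submodules: asr, cca, detrend, dss, lof, ...
```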
 ### Advanced installation instructions

-Some ASR variants require additional dependencies such as `pymanopt`. To install meegkit with these optional packages, use:
+Some ASR variants require additional dependencies such as `pymanopt`. To install meegkit
+with these optional packages, use:

 ```bash
 pip install -e '.[extra]'
@@ -50,90 +57,95 @@ or:
 pip install meegkit[extra]
 ```

-Other available options are `[docs]` (which installs dependencies required to build the documentation), or `[tests]` (which install dependencies to run unit tests).
+Other available options are `[docs]` (which installs dependencies required to build the
+documentation), or `[tests]` (which installs dependencies required to run unit tests).

 ## References

-### 1. CCA, STAR, SNS, DSS, ZapLine, and Robust Detrending
+If you use this code, you should cite the relevant methods from the original articles.

-This is mostly a translation of Matlab code from the [NoiseTools toolbox](http://audition.ens.fr/adc/NoiseTools/) by Alain de Cheveigné. It builds on an initial python implementation by [Pedro Alcocer](https://github.com/pealco).
+### 1. CCA, STAR, SNS, DSS, ZapLine, and Robust Detrending

-Only CCA, SNS, DSS, STAR, ZapLine and robust detrending have been properly tested so far. TSCPA may give inaccurate results due to insufficient testing (contributions welcome!)
+This is mostly a translation of Matlab code from the
+[NoiseTools toolbox](http://audition.ens.fr/adc/NoiseTools/) by Alain de Cheveigné.
+It builds on an initial Python implementation by
+[Pedro Alcocer](https://github.com/pealco).

-If you use this code, you should cite the relevant methods from the original articles:
+Only CCA, SNS, DSS, STAR, ZapLine and robust detrending have been properly tested so far.
+TSPCA may give inaccurate results due to insufficient testing (contributions welcome!)

 ```sql
-[1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to remove power line artifacts.
-    NeuroImage, 116356. https://doi.org/10.1016/j.neuroimage.2019.116356
-[2] de Cheveigné, A. et al. (2019). Multiway canonical correlation analysis of brain data.
-    NeuroImage, 186, 728–740. https://doi.org/10.1016/j.neuroimage.2018.11.026
-[3] de Cheveigné, A. et al. (2018). Decoding the auditory brain with canonical component analysis.
-    NeuroImage, 172, 206–216. https://doi.org/10.1016/j.neuroimage.2018.01.033
-[4] de Cheveigné, A. (2016). Sparse time artifact removal.
-    Journal of Neuroscience Methods, 262, 14–20. https://doi.org/10.1016/j.jneumeth.2016.01.005
-[5] de Cheveigné, A., & Parra, L. C. (2014). Joint decorrelation, a versatile tool for multichannel
-    data analysis. NeuroImage, 98, 487–505. https://doi.org/10.1016/j.neuroimage.2014.05.068
-[6] de Cheveigné, A. (2012). Quadratic component analysis.
-    NeuroImage, 59(4), 3838–3844. https://doi.org/10.1016/j.neuroimage.2011.10.084
-[7] de Cheveigné, A. (2010). Time-shift denoising source separation.
-    Journal of Neuroscience Methods, 189(1), 113–120. https://doi.org/10.1016/j.jneumeth.2010.03.002
+[1] de Cheveigné, A. (2019). ZapLine: A simple and effective method to remove power line
+    artifacts. NeuroImage, 116356. https://doi.org/10.1016/j.neuroimage.2019.116356
+[2] de Cheveigné, A. et al. (2019). Multiway canonical correlation analysis of brain
+    data. NeuroImage, 186, 728–740. https://doi.org/10.1016/j.neuroimage.2018.11.026
+[3] de Cheveigné, A. et al. (2018). Decoding the auditory brain with canonical component
+    analysis. NeuroImage, 172, 206–216.
+    https://doi.org/10.1016/j.neuroimage.2018.01.033
+[4] de Cheveigné, A. (2016). Sparse time artifact removal. Journal of Neuroscience
+    Methods, 262, 14–20. https://doi.org/10.1016/j.jneumeth.2016.01.005
+[5] de Cheveigné, A., & Parra, L. C. (2014). Joint decorrelation, a versatile tool for
+    multichannel data analysis. NeuroImage, 98, 487–505.
+    https://doi.org/10.1016/j.neuroimage.2014.05.068
+[6] de Cheveigné, A. (2012). Quadratic component analysis. NeuroImage, 59(4), 3838–3844.
+    https://doi.org/10.1016/j.neuroimage.2011.10.084
+[7] de Cheveigné, A. (2010). Time-shift denoising source separation. Journal of
+    Neuroscience Methods, 189(1), 113–120. https://doi.org/10.1016/j.jneumeth.2010.03.002
 [8] de Cheveigné, A., & Simon, J. Z. (2008a). Denoising based on spatial filtering.
-    Journal of Neuroscience Methods, 171(2), 331–339. https://doi.org/10.1016/j.jneumeth.2008.03.015
-[9] de Cheveigné, A., & Simon, J. Z. (2008b). Sensor noise suppression.
-    Journal of Neuroscience Methods, 168(1), 195–202. https://doi.org/10.1016/j.jneumeth.2007.09.012
+    Journal of Neuroscience Methods, 171(2), 331–339.
+    https://doi.org/10.1016/j.jneumeth.2008.03.015
+[9] de Cheveigné, A., & Simon, J. Z. (2008b). Sensor noise suppression. Journal of
+    Neuroscience Methods, 168(1), 195–202. https://doi.org/10.1016/j.jneumeth.2007.09.012
 [10] de Cheveigné, A., & Simon, J. Z. (2007). Denoising based on time-shift PCA.
-    Journal of Neuroscience Methods, 165(2), 297–305. https://doi.org/10.1016/j.jneumeth.2007.06.003
+    Journal of Neuroscience Methods, 165(2), 297–305.
+    https://doi.org/10.1016/j.jneumeth.2007.06.003
 ```

 ### 2. Artifact Subspace Reconstruction (ASR)

-The base code is inspired from the original [EEGLAB inplementation](https://github.com/sccn/clean_rawdata) [1], while the riemannian variant [2] was adapted from the [rASR toolbox](https://github.com/s4rify/rASRMatlab) by Sarah Blum.
-
-If you use this code, you should cite the relevant methods from the original articles:
+The base code is inspired by the original
+[EEGLAB implementation](https://github.com/sccn/clean_rawdata) [1], while the Riemannian
+variant [2] was adapted from the [rASR toolbox](https://github.com/s4rify/rASRMatlab) by
+Sarah Blum.

 ```sql
-[1] Mullen, T. R., Kothe, C. A. E., Chi, Y. M., Ojeda, A., Kerth, T., Makeig, S., et al. (2015).
-    Real-time neuroimaging and cognitive monitoring using wearable dry EEG. IEEE Trans. Bio-Med.
-    Eng. 62, 2553–2567. https://doi.org/10.1109/TBME.2015.2481482
-[2] Blum, S., Jacobsen, N., Bleichner, M. G., & Debener, S. (2019). A Riemannian modification of
-    artifact subspace reconstruction for EEG artifact handling. Frontiers in human neuroscience,
-    13, 141.
+[1] Mullen, T. R., Kothe, C. A. E., Chi, Y. M., Ojeda, A., Kerth, T., Makeig, S.,
+    et al. (2015). Real-time neuroimaging and cognitive monitoring using wearable dry
+    EEG. IEEE Trans. Bio-Med. Eng. 62, 2553–2567.
+    https://doi.org/10.1109/TBME.2015.2481482
+[2] Blum, S., Jacobsen, N., Bleichner, M. G., & Debener, S. (2019). A Riemannian
+    modification of artifact subspace reconstruction for EEG artifact handling. Frontiers
+    in human neuroscience, 13, 141.
 ```

 ### 3. Rhythmic Entrainment Source Separation (RESS)

 The code is based on [Matlab code from Mike X. Cohen](https://mikexcohen.com/data/) [1]

-If you use this, you should cite the following article:
-
 ```sql
-[1] Cohen, M. X., & Gulbinaite, R. (2017). Rhythmic entrainment source separation: Optimizing analyses
-    of neural responses to rhythmic sensory stimulation. Neuroimage, 147, 43-56.
+[1] Cohen, M. X., & Gulbinaite, R. (2017). Rhythmic entrainment source separation:
+    Optimizing analyses of neural responses to rhythmic sensory stimulation. Neuroimage,
+    147, 43-56.
 ```

 ### 4. Task-Related Component Analysis (TRCA)

-This code is based on the [Matlab implementation from Masaki Nakanishi](https://github.com/mnakanishi/TRCA-SSVEP), and was adapted to python by [Giuseppe Ferraro](mailto:giuseppe.ferraro@isae-supaero.fr)
-
-If you use this, you should cite the following articles:
+This code is based on the [Matlab implementation from Masaki Nakanishi](https://github.com/mnakanishi/TRCA-SSVEP),
+and was adapted to Python by [Giuseppe Ferraro](mailto:giuseppe.ferraro@isae-supaero.fr).

 ```sql
 [1] M. Nakanishi, Y. Wang, X. Chen, Y.-T. Wang, X. Gao, and T.-P. Jung,
-    "Enhancing detection of SSVEPs for a high-speed brain speller using
-    task-related component analysis", IEEE Trans. Biomed. Eng, 65(1): 104-112,
-    2018.
-[2] X. Chen, Y. Wang, S. Gao, T. -P. Jung and X. Gao, "Filter bank
-    canonical correlation analysis for implementing a high-speed SSVEP-based
-    brain-computer interface", J. Neural Eng., 12: 046008, 2015.
-[3] X. Chen, Y. Wang, M. Nakanishi, X. Gao, T. -P. Jung, S. Gao,
-    "High-speed spelling with a noninvasive brain-computer interface",
-    Proc. Int. Natl. Acad. Sci. U. S. A, 112(44): E6058-6067, 2015.
+    "Enhancing detection of SSVEPs for a high-speed brain speller using task-related
+    component analysis", IEEE Trans. Biomed. Eng, 65(1): 104-112, 2018.
+[2] X. Chen, Y. Wang, S. Gao, T. -P. Jung and X. Gao, "Filter bank canonical correlation
+    analysis for implementing a high-speed SSVEP-based brain-computer interface",
+    J. Neural Eng., 12: 046008, 2015.
+[3] X. Chen, Y. Wang, M. Nakanishi, X. Gao, T. -P. Jung, S. Gao, "High-speed spelling
+    with a noninvasive brain-computer interface", Proc. Int. Natl. Acad. Sci. U.S.A,
+    112(44): E6058-6067, 2015.
 ```

 ### 5. Local Outlier Factor (LOF)

-If you use this, you should cite the following article:
-
 ```sql
 [1] Breunig M, Kriegel HP, Ng RT, Sander J. 2000. LOF: identifying density-based local
     outliers. SIGMOD Rec. 29, 2, 93-104.
https://doi.org/10.1145/335191.335388 diff --git a/citation.cff b/citation.cff index c2dec4cb..373edd39 100644 --- a/citation.cff +++ b/citation.cff @@ -5,7 +5,7 @@ authors: given-names: "Nicolas" orcid: "https://orcid.org/0000-0003-1495-561X" title: "MEEGkit" -version: 0.1.3 +version: 0.1.4 doi: 10.5281/zenodo.5643659 date-released: 2021-10-15 url: "https://github.com/nbara/python-meegkit" diff --git a/doc/_static/logo-dark.png b/doc/_static/logo-dark.png new file mode 100644 index 00000000..5d4f2070 Binary files /dev/null and b/doc/_static/logo-dark.png differ diff --git a/doc/_static/logo.png b/doc/_static/logo.png new file mode 100644 index 00000000..86b3a542 Binary files /dev/null and b/doc/_static/logo.png differ diff --git a/doc/conf.py b/doc/conf.py index e7d4314c..37568fef 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -12,20 +12,22 @@ # import os import sys -import matplotlib -matplotlib.use('agg') + +import matplotlib as mpl + +mpl.use("agg") curdir = os.path.dirname(__file__) -sys.path.append(os.path.abspath(os.path.join(curdir, '..'))) -sys.path.append(os.path.abspath(os.path.join(curdir, '..', 'meegkit'))) +sys.path.append(os.path.abspath(os.path.join(curdir, ".."))) +sys.path.append(os.path.abspath(os.path.join(curdir, "..", "meegkit"))) import meegkit # noqa # -- Project information ----------------------------------------------------- -project = 'MEEGkit' -copyright = '2022, Nicolas Barascud' -author = 'Nicolas Barascud' +project = "MEEGkit" +copyright = "2023, Nicolas Barascud" +author = "Nicolas Barascud" release = meegkit.__version__ version = meegkit.__version__ @@ -35,31 +37,31 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.napoleon', - 'numpydoc', - 'jupyter_sphinx', - 'sphinx_gallery.gen_gallery', - 'sphinxemoji.sphinxemoji', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "numpydoc", + "jupyter_sphinx", + "sphinx_gallery.gen_gallery", + "sphinxemoji.sphinxemoji", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'config.py'] +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "config.py"] # generate autosummary even if no references # autosummary_generate = True autodoc_default_options = { - 'members': True, - 'special-members': '__init__', - 'undoc-members': True, - 'show-inheritance': True, - 'exclude-members': '__weakref__' + "members": True, + "special-members": "__init__", + "undoc-members": True, + "show-inheritance": True, + "exclude-members": "__weakref__" } numpydoc_show_class_members = True @@ -75,15 +77,19 @@ # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages -html_theme = 'pydata_sphinx_theme' +html_theme = "pydata_sphinx_theme" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] - +html_static_path = ["_static"] html_theme_options = { + "logo": { + "image_light": "_static/logo.png", + "image_dark": "_static/logo-dark.png", + "text": "meegkit", + }, "show_toc_level": 1, "external_links": [ { @@ -95,12 +101,12 @@ { "name": "GitHub", "url": "https://github.com/nbara/python-meegkit", - "icon": "fab fa-github-square", + "icon": "fa-brands fa-github", }, { "name": "Twitter", "url": "https://twitter.com/lebababa", - "icon": "fab fa-twitter-square", + "icon": "fa-brands fa-twitter", }, ], "use_edit_page_button": True, @@ -116,10 +122,10 @@ # -- Options for Sphinx-gallery HTML ------------------------------------------ sphinx_gallery_conf = { - 'doc_module': ('meegkit',), - 'examples_dirs': '../examples', # path to your example scripts - 'gallery_dirs': 'auto_examples', # path to where to save gallery generated output - 'filename_pattern': '/example_', - 'ignore_pattern': 'config.py', - 'run_stale_examples': False, + "doc_module": ("meegkit",), + "examples_dirs": "../examples", # path to your example scripts + "gallery_dirs": "auto_examples", # path to where to save gallery generated output + "filename_pattern": "/example_", + "ignore_pattern": "config.py", + "run_stale_examples": False, } diff --git a/doc/index.rst b/doc/index.rst index 23f60816..a88c893e 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -10,7 +10,7 @@ Introduction ------------ ``meegkit`` is a collection of EEG and MEG denoising techniques for -**Python 3.6+**. Please feel free to contribute, or suggest new analyses. Keep +**Python 3.8+**. Please feel free to contribute, or suggest new analyses. Keep in mind that this is mostly development code, and as such is likely to change without any notice. Also, while most of the methods have been fairly robustly tested, bugs can (and should!) be expected. diff --git a/examples/config.py b/examples/config.py index fd1a52a2..1c374858 100644 --- a/examples/config.py +++ b/examples/config.py @@ -1,7 +1,8 @@ """Configuration file for examples.""" import matplotlib.pyplot as plt from IPython.display import set_matplotlib_formats -set_matplotlib_formats('pdf', 'png') -plt.rcParams['savefig.dpi'] = 75 -plt.rcParams['figure.autolayout'] = False -plt.rcParams['figure.figsize'] = 10, 6 + +set_matplotlib_formats("pdf", "png") +plt.rcParams["savefig.dpi"] = 75 +plt.rcParams["figure.autolayout"] = False +plt.rcParams["figure.figsize"] = 10, 6 diff --git a/examples/example_asr.py b/examples/example_asr.py index d8634d6c..9d59a562 100644 --- a/examples/example_asr.py +++ b/examples/example_asr.py @@ -7,14 +7,15 @@ Uses meegkit.ASR(). 
""" import os -import numpy as np + import matplotlib.pyplot as plt +import numpy as np from meegkit.asr import ASR from meegkit.utils.matrix import sliding_window # THIS_FOLDER = os.path.dirname(os.path.abspath(__file__)) -raw = np.load(os.path.join('..', 'tests', 'data', 'eeg_raw.npy')) +raw = np.load(os.path.join("..", "tests", "data", "eeg_raw.npy")) sfreq = 250 ############################################################################### @@ -22,7 +23,7 @@ # ----------------------------------------------------------------------------- # Train on a clean portion of data -asr = ASR(method='euclid') +asr = ASR(method="euclid") train_idx = np.arange(0 * sfreq, 30 * sfreq, dtype=int) _, sample_mask = asr.fit(raw[:, train_idx]) @@ -46,20 +47,20 @@ times = np.arange(raw.shape[-1]) / sfreq f, ax = plt.subplots(8, sharex=True, figsize=(8, 5)) for i in range(8): - ax[i].fill_between(train_idx / sfreq, 0, 1, color='grey', alpha=.3, + ax[i].fill_between(train_idx / sfreq, 0, 1, color="grey", alpha=.3, transform=ax[i].get_xaxis_transform(), - label='calibration window') + label="calibration window") ax[i].fill_between(train_idx / sfreq, 0, 1, where=sample_mask.flat, transform=ax[i].get_xaxis_transform(), - facecolor='none', hatch='...', edgecolor='k', - label='selected window') - ax[i].plot(times, raw[i], lw=.5, label='before ASR') - ax[i].plot(times, clean[i], label='after ASR', lw=.5) + facecolor="none", hatch="...", edgecolor="k", + label="selected window") + ax[i].plot(times, raw[i], lw=.5, label="before ASR") + ax[i].plot(times, clean[i], label="after ASR", lw=.5) ax[i].set_ylim([-50, 50]) - ax[i].set_ylabel(f'ch{i}') + ax[i].set_ylabel(f"ch{i}") ax[i].set_yticks([]) -ax[i].set_xlabel('Time (s)') -ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), borderaxespad=0) +ax[i].set_xlabel("Time (s)") +ax[0].legend(fontsize="small", bbox_to_anchor=(1.04, 1), borderaxespad=0) plt.subplots_adjust(hspace=0, right=0.75) -plt.suptitle('Before/after ASR') +plt.suptitle("Before/after ASR") plt.show() diff --git a/examples/example_dering.py b/examples/example_dering.py index 2a84f3ad..a5fee882 100644 --- a/examples/example_dering.py +++ b/examples/example_dering.py @@ -8,14 +8,11 @@ """ import matplotlib.pyplot as plt import numpy as np - from scipy.signal import butter, lfilter from meegkit.detrend import reduce_ringing -# import config # plotting utils - -np.random.seed(9) +rng = np.random.default_rng(9) ############################################################################### # Detrending @@ -32,11 +29,11 @@ [b, a] = butter(6, 0.2) # Butterworth filter design x = lfilter(b, a, x) * 50 # Filter data using above filter x = np.roll(x, 500) -x = x[:, None] + np.random.randn(1000, 2) +x = x[:, None] + rng.standard_normal((1000, 2)) y = reduce_ringing(x, samples=np.array([500])) plt.figure() -plt.plot(x + np.array([-10, 10]), 'C0', label='before') -plt.plot(y + np.array([-10, 10]), 'C1:', label='after') +plt.plot(x + np.array([-10, 10]), "C0", label="before") +plt.plot(y + np.array([-10, 10]), "C1:", label="after") plt.legend() plt.show() diff --git a/examples/example_detrend.py b/examples/example_detrend.py index feadb091..d1a377cb 100644 --- a/examples/example_detrend.py +++ b/examples/example_detrend.py @@ -17,11 +17,11 @@ import numpy as np from matplotlib.gridspec import GridSpec -from meegkit.detrend import regress, detrend +from meegkit.detrend import detrend, regress # import config # plotting utils -np.random.seed(9) +rng = np.random.default_rng(9) 
############################################################################### # Regression @@ -31,15 +31,15 @@ # Simple regression example, no weights # ----------------------------------------------------------------------------- # We first try to fit a simple random walk process. -x = np.cumsum(np.random.randn(1000, 1), axis=0) +x = np.cumsum(rng.standard_normal((1000, 1)), axis=0) r = np.arange(1000.)[:, None] r = np.hstack([r, r ** 2, r ** 3]) b, y = regress(x, r) plt.figure(1) -plt.plot(x, label='data') -plt.plot(y, label='fit') -plt.title('No weights') +plt.plot(x, label="data") +plt.plot(y, label="fit") +plt.title("No weights") plt.legend() plt.show() @@ -49,7 +49,7 @@ # We can also use weights for each time sample. Here we explicitly restrict the # fit to the second half of the data by setting weights to zero for the first # 500 samples. -x = np.cumsum(np.random.randn(1000, 1), axis=0) + 1000 +x = np.cumsum(rng.standard_normal((1000, 1)), axis=0) + 1000 w = np.ones(y.shape[0]) w[:500] = 0 b, y = regress(x, r, w) @@ -57,27 +57,27 @@ f = plt.figure(3) gs = GridSpec(4, 1, figure=f) ax1 = f.add_subplot(gs[:3, 0]) -ax1.plot(x, label='data') -ax1.plot(y, label='fit') -ax1.set_xticklabels('') -ax1.set_title('Split-wise regression') +ax1.plot(x, label="data") +ax1.plot(y, label="fit") +ax1.set_xticklabels("") +ax1.set_title("Split-wise regression") ax1.legend() ax2 = f.add_subplot(gs[3, 0]) -l, = ax2.plot(np.arange(1000), np.zeros(1000)) -ax2.stackplot(np.arange(1000), w, labels=['weights'], color=l.get_color()) +ll, = ax2.plot(np.arange(1000), np.zeros(1000)) +ax2.stackplot(np.arange(1000), w, labels=["weights"], color=ll.get_color()) ax2.legend(loc=2) ############################################################################### # Multichannel regression # ----------------------------------------------------------------------------- -x = np.cumsum(np.random.randn(1000, 2), axis=0) +x = np.cumsum(rng.standard_normal((1000, 2)), axis=0) w = np.ones(y.shape[0]) b, y = regress(x, r, w) plt.figure(4) -plt.plot(x, label='data', color='C0') -plt.plot(y, ls=':', label='fit', color='C1') -plt.title('Channel-wise regression') +plt.plot(x, label="data", color="C0") +plt.plot(y, ls=":", label="fit", color="C1") +plt.title("Channel-wise regression") plt.legend() @@ -89,23 +89,23 @@ # Basic example with a linear trend # ----------------------------------------------------------------------------- x = np.arange(100)[:, None] -x = x + np.random.randn(*x.shape) +x = x + rng.standard_normal(x.shape) y, _, _ = detrend(x, 1) plt.figure(5) -plt.plot(x, label='original') -plt.plot(y, label='detrended') +plt.plot(x, label="original") +plt.plot(y, label="detrended") plt.legend() ############################################################################### # Detrend biased random walk with a third-order polynomial # ----------------------------------------------------------------------------- -x = np.cumsum(np.random.randn(1000, 1) + 0.1) +x = np.cumsum(rng.standard_normal((1000, 1)) + 0.1) y, _, _ = detrend(x, 3) plt.figure(6) -plt.plot(x, label='original') -plt.plot(y, label='detrended') +plt.plot(x, label="original") +plt.plot(y, label="detrended") plt.legend() ############################################################################### @@ -119,7 +119,7 @@ # glitch, leading to a mediocre fit. When downweightining this artifactual # period, the fit is much improved (green trace). 
x = np.linspace(0, 100, 1000)[:, None] -x = x + 3 * np.random.randn(*x.shape) +x = x + 3 * rng.standard_normal(x.shape) # introduce some strong artifact on the first 100 samples x[:100, :] = 100 @@ -133,8 +133,8 @@ z, _, _ = detrend(x, 3, w) plt.figure(7) -plt.plot(x, label='original') -plt.plot(y, label='detrended - no weights') -plt.plot(z, label='detrended - weights') +plt.plot(x, label="original") +plt.plot(y, label="detrended - no weights") +plt.plot(z, label="detrended - weights") plt.legend() plt.show() diff --git a/examples/example_dss.py b/examples/example_dss.py index bb5de9e3..eb4a3b87 100644 --- a/examples/example_dss.py +++ b/examples/example_dss.py @@ -11,9 +11,9 @@ import numpy as np from meegkit import dss -from meegkit.utils import unfold, rms, fold, tscov +from meegkit.utils import fold, rms, tscov, unfold -# import config +rng = np.random.default_rng(5) ############################################################################### # Create simulated data @@ -30,14 +30,14 @@ np.zeros((n_samples // 3,)), np.sin(2 * np.pi * np.arange(n_samples // 3) / (n_samples / 3)).T, np.zeros((n_samples // 3,))))[np.newaxis].T -s = source * np.random.randn(1, n_chans) # 300 * 30 +s = source * rng.standard_normal((1, n_chans)) # 300 * 30 s = s[:, :, np.newaxis] s = np.tile(s, (1, 1, 100)) # Noise noise = np.dot( - unfold(np.random.randn(n_samples, noise_dim, n_trials)), - np.random.randn(noise_dim, n_chans)) + unfold(rng.standard_normal((n_samples, noise_dim, n_trials))), + rng.standard_normal((noise_dim, n_chans))) noise = fold(noise, n_samples) # Mix signal and noise @@ -66,8 +66,8 @@ # Plot results # ----------------------------------------------------------------------------- f, (ax1, ax2, ax3) = plt.subplots(3, 1) -ax1.plot(source, label='source') -ax2.plot(np.mean(data, 2), label='data') -ax3.plot(best_comp, label='recovered') +ax1.plot(source, label="source") +ax2.plot(np.mean(data, 2), label="data") +ax3.plot(best_comp, label="recovered") plt.legend() plt.show() diff --git a/examples/example_dss_line.py b/examples/example_dss_line.py index cc5f6ab9..b014b0d7 100644 --- a/examples/example_dss_line.py +++ b/examples/example_dss_line.py @@ -18,9 +18,10 @@ import matplotlib.pyplot as plt import numpy as np +from scipy import signal + from meegkit import dss from meegkit.utils import create_line_data, unfold -from scipy import signal ############################################################################### # Line noise removal @@ -48,11 +49,11 @@ ax[0].semilogy(f, Pxx) f, Pxx = signal.welch(out, sfreq, nperseg=500, axis=0, return_onesided=True) ax[1].semilogy(f, Pxx) -ax[0].set_xlabel('frequency [Hz]') -ax[1].set_xlabel('frequency [Hz]') -ax[0].set_ylabel('PSD [V**2/Hz]') -ax[0].set_title('before') -ax[1].set_title('after') +ax[0].set_xlabel("frequency [Hz]") +ax[1].set_xlabel("frequency [Hz]") +ax[0].set_ylabel("PSD [V**2/Hz]") +ax[0].set_title("before") +ax[1].set_title("after") plt.show() @@ -60,7 +61,7 @@ # Remove line noise with dss_line_iter() # ----------------------------------------------------------------------------- # We first load some noisy data to work with -data = np.load(os.path.join('..', 'tests', 'data', 'dss_line_data.npy')) +data = np.load(os.path.join("..", "tests", "data", "dss_line_data.npy")) fline = 50 sfreq = 200 print(data.shape) # n_samples, n_chans, n_trials @@ -72,7 +73,7 @@ # Now try dss_line_iter(). 
This applies dss_line() repeatedly until the # artifact is gone out2, iterations = dss.dss_line_iter(data, fline, sfreq, nfft=400) -print(f'Removed {iterations} components') +print(f"Removed {iterations} components") ############################################################################### # Plot results with dss_line() vs. dss_line_iter() @@ -83,10 +84,10 @@ f, Pxx = signal.welch(unfold(out2), sfreq, nperseg=200, axis=0, return_onesided=True) ax[1].semilogy(f, Pxx, lw=.5) -ax[0].set_xlabel('frequency [Hz]') -ax[1].set_xlabel('frequency [Hz]') -ax[0].set_ylabel('PSD [V**2/Hz]') -ax[0].set_title('dss_line') -ax[1].set_title('dss_line_iter') +ax[0].set_xlabel("frequency [Hz]") +ax[1].set_xlabel("frequency [Hz]") +ax[0].set_ylabel("PSD [V**2/Hz]") +ax[0].set_title("dss_line") +ax[1].set_title("dss_line_iter") plt.tight_layout() plt.show() diff --git a/examples/example_mcca.py b/examples/example_mcca.py index ab5c8c76..d2f13400 100644 --- a/examples/example_mcca.py +++ b/examples/example_mcca.py @@ -12,7 +12,7 @@ from meegkit import cca -# import config +rng = np.random.default_rng(5) ############################################################################### # First example @@ -22,12 +22,12 @@ ############################################################################### # Build data -x1 = np.random.randn(10000, 10) -x2 = np.random.randn(10000, 10) -x3 = np.random.randn(10000, 10) +x1 = rng.standard_normal((10000, 10)) +x2 = rng.standard_normal((10000, 10)) +x3 = rng.standard_normal((10000, 10)) x = np.hstack((x1, x2, x3)) C = np.dot(x.T, x) -print('Aggregated data covariance shape: {}'.format(C.shape)) +print(f"Aggregated data covariance shape: {C.shape}") ############################################################################### # Apply CCA @@ -37,14 +37,14 @@ ############################################################################### # Plot results f, axes = plt.subplots(1, 3, figsize=(12, 4)) -axes[0].imshow(A, aspect='auto') -axes[0].set_title('mCCA transform matrix') -axes[1].imshow(A.T.dot(C.dot(A)), aspect='auto') -axes[1].set_title('Covariance of\ntransformed data') -axes[2].imshow(x.T.dot((x.dot(A))), aspect='auto') -axes[2].set_title('Cross-correlation between\nraw & transformed data') -axes[2].set_xlabel('transformed') -axes[2].set_ylabel('raw') +axes[0].imshow(A, aspect="auto") +axes[0].set_title("mCCA transform matrix") +axes[1].imshow(A.T.dot(C.dot(A)), aspect="auto") +axes[1].set_title("Covariance of\ntransformed data") +axes[2].imshow(x.T.dot(x.dot(A)), aspect="auto") +axes[2].set_title("Cross-correlation between\nraw & transformed data") +axes[2].set_xlabel("transformed") +axes[2].set_ylabel("raw") plt.plot(np.mean(z ** 2, axis=0)) plt.show() @@ -55,13 +55,13 @@ ############################################################################### # Build data -x1 = np.random.randn(10000, 5) -x2 = np.random.randn(10000, 5) -x3 = np.random.randn(10000, 5) -x4 = np.random.randn(10000, 5) +x1 = rng.standard_normal((10000, 5)) +x2 = rng.standard_normal((10000, 5)) +x3 = rng.standard_normal((10000, 5)) +x4 = rng.standard_normal((10000, 5)) x = np.hstack((x2, x1, x3, x1, x4, x1)) C = np.dot(x.T, x) -print('Aggregated data covariance shape: {}'.format(C.shape)) +print(f"Aggregated data covariance shape: {C.shape}") ############################################################################### # Apply mCCA @@ -70,14 +70,14 @@ ############################################################################### # Plot results f, axes = plt.subplots(1, 3, figsize=(12, 4)) 
-axes[0].imshow(A, aspect='auto') -axes[0].set_title('mCCA transform matrix') -axes[1].imshow(A.T.dot(C.dot(A)), aspect='auto') -axes[1].set_title('Covariance of\ntransformed data') -axes[2].imshow(x.T.dot((x.dot(A))), aspect='auto') -axes[2].set_title('Cross-correlation between\nraw & transformed data') -axes[2].set_xlabel('transformed') -axes[2].set_ylabel('raw') +axes[0].imshow(A, aspect="auto") +axes[0].set_title("mCCA transform matrix") +axes[1].imshow(A.T.dot(C.dot(A)), aspect="auto") +axes[1].set_title("Covariance of\ntransformed data") +axes[2].imshow(x.T.dot(x.dot(A)), aspect="auto") +axes[2].set_title("Cross-correlation between\nraw & transformed data") +axes[2].set_xlabel("transformed") +axes[2].set_ylabel("raw") plt.show() ############################################################################### @@ -90,10 +90,10 @@ ############################################################################### # Build data -x1 = np.random.randn(10000, 10) +x1 = rng.standard_normal((10000, 10)) x = np.hstack((x1, x1, x1)) C = np.dot(x.T, x) -print('Aggregated data covariance shape: {}'.format(C.shape)) +print(f"Aggregated data covariance shape: {C.shape}") ############################################################################### # Compute mCCA @@ -102,12 +102,12 @@ ############################################################################### # Plot results f, axes = plt.subplots(1, 3, figsize=(12, 4)) -axes[0].imshow(A, aspect='auto') -axes[0].set_title('mCCA transform matrix') -axes[1].imshow(A.T.dot(C.dot(A)), aspect='auto') -axes[1].set_title('Covariance of\ntransformed data') -axes[2].imshow(x.T.dot((x.dot(A))), aspect='auto') -axes[2].set_title('Cross-correlation between\nraw & transformed data') -axes[2].set_xlabel('transformed') -axes[2].set_ylabel('raw') +axes[0].imshow(A, aspect="auto") +axes[0].set_title("mCCA transform matrix") +axes[1].imshow(A.T.dot(C.dot(A)), aspect="auto") +axes[1].set_title("Covariance of\ntransformed data") +axes[2].imshow(x.T.dot(x.dot(A)), aspect="auto") +axes[2].set_title("Cross-correlation between\nraw & transformed data") +axes[2].set_xlabel("transformed") +axes[2].set_ylabel("raw") plt.show() diff --git a/examples/example_ress.py b/examples/example_ress.py index 903de7e5..b328c4c7 100644 --- a/examples/example_ress.py +++ b/examples/example_ress.py @@ -12,12 +12,13 @@ import matplotlib.pyplot as plt import numpy as np import scipy.signal as ss + from meegkit import ress from meegkit.utils import fold, matmul3d, rms, snr_spectrum, unfold # import config -np.random.seed(1) +rng = np.random.default_rng(9) ############################################################################### # Create synthetic data @@ -34,7 +35,7 @@ # source source = np.sin(2 * np.pi * target * np.arange(n_times - t0) / sfreq)[None].T -s = source * np.random.randn(1, n_chans) +s = source * rng.standard_normal((1, n_chans)) s = s[:, :, np.newaxis] s = np.tile(s, (1, 1, n_trials)) signal = np.zeros((n_times, n_chans, n_trials)) @@ -42,8 +43,8 @@ # noise noise = np.dot( - unfold(np.random.randn(n_times, noise_dim, n_trials)), - np.random.randn(noise_dim, n_chans)) + unfold(rng.standard_normal((n_times, noise_dim, n_trials))), + rng.standard_normal((noise_dim, n_chans))) noise = fold(noise, n_times) # mix signal and noise @@ -53,9 +54,9 @@ # Plot f, ax = plt.subplots(3) -ax[0].plot(signal[:, 0, 0], c='C0', label='source') -ax[1].plot(noise[:, 1, 0], c='C1', label='noise') -ax[2].plot(data[:, 1, 0], c='C2', label='mixture') +ax[0].plot(signal[:, 0, 0], c="C0", 
label="source") +ax[1].plot(noise[:, 1, 0], c="C1", label="noise") +ax[2].plot(data[:, 1, 0], c="C2", label="mixture") ax[0].legend() ax[1].legend() ax[2].legend() @@ -76,12 +77,12 @@ snr = snr_spectrum(psd, bins, skipbins=2, n_avg=2) f, ax = plt.subplots(1) -ax.plot(bins, snr, 'o', label='SNR') -ax.plot(bins[bins == target], snr[bins == target], 'ro', label='Target SNR') -ax.axhline(1, ls=':', c='grey', zorder=0) -ax.axvline(target, ls=':', c='grey', zorder=0) -ax.set_ylabel('SNR (a.u.)') -ax.set_xlabel('Frequency (Hz)') +ax.plot(bins, snr, "o", label="SNR") +ax.plot(bins[bins == target], snr[bins == target], "ro", label="Target SNR") +ax.axhline(1, ls=":", c="grey", zorder=0) +ax.axvline(target, ls=":", c="grey", zorder=0) +ax.set_ylabel("SNR (a.u.)") +ax.set_xlabel("Frequency (Hz)") ax.set_xlim([0, 40]) ############################################################################### @@ -89,15 +90,15 @@ # average SSVEP. proj = matmul3d(out, maps) -f, ax = plt.subplots(n_chans, 2, sharey='col') +f, ax = plt.subplots(n_chans, 2, sharey="col") for c in range(n_chans): ax[c, 0].plot(data[:, c].mean(-1), lw=.5) ax[c, 1].plot(proj[:, c].mean(-1), lw=.5) - ax[c, 0].set_ylabel(f'ch{c}') + ax[c, 0].set_ylabel(f"ch{c}") if c < n_chans: ax[c, 0].set_xticks([]) ax[c, 1].set_xticks([]) -ax[0, 0].set_title('Trial average (before)') -ax[0, 1].set_title('Trial average (after)') +ax[0, 0].set_title("Trial average (before)") +ax[0, 1].set_title("Trial average (after)") plt.show() diff --git a/examples/example_star.py b/examples/example_star.py index 694a79c8..9255630a 100644 --- a/examples/example_star.py +++ b/examples/example_star.py @@ -12,7 +12,7 @@ from meegkit import star from meegkit.utils import demean, normcol -# import config +rng = np.random.default_rng(9) ############################################################################### # Create simulated data @@ -26,12 +26,12 @@ f = 2 target = np.sin(np.arange(n_samples) / n_samples * 2 * np.pi * f) target = target[:, np.newaxis] -noise = np.random.randn(n_samples, nchans - 3) +noise = rng.standard_normal((n_samples, nchans - 3)) # Create artifact signal SNR = np.sqrt(1) -x0 = (normcol(np.dot(noise, np.random.randn(noise.shape[1], nchans))) + - SNR * target * np.random.randn(1, nchans)) +x0 = normcol(np.dot(noise, rng.standard_normal((noise.shape[1], nchans)))) + \ + SNR * target * rng.standard_normal((1, nchans)) x0 = demean(x0) artifact = np.zeros(x0.shape) for k in np.arange(nchans): @@ -54,10 +54,10 @@ # ----------------------------------------------------------------------------- f, (ax1, ax2, ax3) = plt.subplots(3, 1) ax1.plot(x, lw=.5) -ax1.set_title('Signal + Artifacts (SNR = {})'.format(SNR)) +ax1.set_title(f"Signal + Artifacts (SNR = {SNR})") ax2.plot(y, lw=.5) -ax2.set_title('Denoised') +ax2.set_title("Denoised") ax3.plot(demean(y) - x0, lw=.5) -ax3.set_title('Residual') +ax3.set_title("Residual") f.set_tight_layout(True) plt.show() diff --git a/examples/example_star_dss.py b/examples/example_star_dss.py index c4af48a5..66eb7956 100644 --- a/examples/example_star_dss.py +++ b/examples/example_star_dss.py @@ -16,15 +16,14 @@ """ import matplotlib.pyplot as plt import numpy as np - from scipy.optimize import leastsq -from meegkit import star, dss +from meegkit import dss, star from meegkit.utils import demean, normcol, tscov # import config # noqa -np.random.seed(9) +rng = np.random.default_rng(9) ############################################################################### # Create simulated data @@ -37,12 +36,12 @@ f = 2 target = 
np.sin(np.arange(n_samples) / n_samples * 2 * np.pi * f) target = target[:, np.newaxis] -noise = np.random.randn(n_samples, n_chans - 3) +noise = rng.standard_normal((n_samples, n_chans - 3)) # Create artifact signal SNR = np.sqrt(1) -x0 = (normcol(np.dot(noise, np.random.randn(noise.shape[1], n_chans))) + - SNR * target * np.random.randn(1, n_chans)) +x0 = normcol(np.dot(noise, rng.standard_normal((noise.shape[1], n_chans)))) + \ + SNR * target * rng.standard_normal((1, n_chans)) x0 = demean(x0) artifact = np.zeros(x0.shape) for k in np.arange(n_chans): @@ -95,18 +94,18 @@ def func(y): # ----------------------------------------------------------------------------- f, (ax0, ax1, ax2, ax3) = plt.subplots(4, 1, figsize=(7, 9)) ax0.plot(target, lw=.5) -ax0.set_title('Target') +ax0.set_title("Target") ax1.plot(x, lw=.5) -ax1.set_title('Signal + Artifacts (SNR = {})'.format(SNR)) +ax1.set_title(f"Signal + Artifacts (SNR = {SNR})") -ax2.plot(z1[:, 0], lw=.5, label='Best DSS component') -ax2.set_title('DSS') -ax2.legend(loc='lower right') +ax2.plot(z1[:, 0], lw=.5, label="Best DSS component") +ax2.set_title("DSS") +ax2.legend(loc="lower right") -ax3.plot(z2[:, 0], lw=.5, label='Best DSS component') -ax3.set_title('STAR + DSS') -ax3.legend(loc='lower right') +ax3.plot(z2[:, 0], lw=.5, label="Best DSS component") +ax3.set_title("STAR + DSS") +ax3.legend(loc="lower right") f.set_tight_layout(True) plt.show() diff --git a/examples/example_trca.py b/examples/example_trca.py index 0d227b07..b36f2bb6 100644 --- a/examples/example_trca.py +++ b/examples/example_trca.py @@ -11,8 +11,8 @@ Uses `meegkit.trca.TRCA()`. -References: - +References +---------- .. [1] M. Nakanishi, Y. Wang, X. Chen, Y.-T. Wang, X. Gao, and T.-P. Jung, "Enhancing detection of SSVEPs for a high-speed brain speller using task-related component analysis", IEEE Trans. Biomed. 
Eng, 65(1): 104-112, @@ -33,6 +33,7 @@ import matplotlib.pyplot as plt import numpy as np import scipy.io + from meegkit.trca import TRCA from meegkit.utils.trca import itr, normfit, round_half_up @@ -65,7 +66,7 @@ ############################################################################### # Load data # ----------------------------------------------------------------------------- -path = os.path.join('..', 'tests', 'data', 'trcadata.mat') +path = os.path.join("..", "tests", "data", "trcadata.mat") eeg = scipy.io.loadmat(path)["eeg"] n_trials, n_chans, n_samples, n_blocks = eeg.shape @@ -98,27 +99,27 @@ [(54, 90), (48, 100)]] f, ax = plt.subplots(1, figsize=(7, 4)) -for i, band in enumerate(filterbank): +for i, _band in enumerate(filterbank): ax.axvspan(ymin=i / len(filterbank) + .02, ymax=(i + 1) / len(filterbank) - .02, xmin=filterbank[i][1][0], xmax=filterbank[i][1][1], - alpha=0.2, facecolor=f'C{i}') + alpha=0.2, facecolor=f"C{i}") ax.axvspan(ymin=i / len(filterbank) + .02, ymax=(i + 1) / len(filterbank) - .02, xmin=filterbank[i][0][0], xmax=filterbank[i][0][1], - alpha=0.5, label=f'sub-band{i}', facecolor=f'C{i}') + alpha=0.5, label=f"sub-band{i}", facecolor=f"C{i}") for f in list_freqs.flat: colors = np.ones((9, 4)) colors[:, :3] = np.linspace(0, .5, 9)[:, None] ax.scatter(f * np.arange(1, 10), [f] * 9, c=colors, s=8, zorder=100) -ax.set_ylabel('Stimulus frequency (Hz)') -ax.set_xlabel('EEG response frequency (Hz)') +ax.set_ylabel("Stimulus frequency (Hz)") +ax.set_xlabel("EEG response frequency (Hz)") ax.set_xlim([0, 102]) ax.set_xticks(np.arange(0, 100, 10)) -ax.grid(True, ls=':', axis='x') -ax.legend(bbox_to_anchor=(1.05, .5), fontsize='small') +ax.grid(True, ls=":", axis="x") +ax.legend(bbox_to_anchor=(1.05, .5), fontsize="small") plt.tight_layout() plt.show() @@ -126,7 +127,7 @@ # Now perform the TRCA-based SSVEP detection algorithm trca = TRCA(sfreq, filterbank, is_ensemble) -print('Results of the ensemble TRCA-based method:\n') +print("Results of the ensemble TRCA-based method:\n") accs = np.zeros(n_blocks) itrs = np.zeros(n_blocks) for i in range(n_blocks): @@ -159,8 +160,8 @@ mu, _, muci, _ = normfit(itrs, alpha_ci) print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f})") if is_ensemble: - ensemble = 'ensemble TRCA-based method' + ensemble = "ensemble TRCA-based method" else: - ensemble = 'TRCA-based method' + ensemble = "TRCA-based method" print(f"\nElapsed time: {time.time()-t:.1f} seconds") diff --git a/meegkit/__init__.py b/meegkit/__init__.py index 0aec3dee..a68935b6 100644 --- a/meegkit/__init__.py +++ b/meegkit/__init__.py @@ -1,7 +1,7 @@ """M/EEG denoising utilities in python.""" -__version__ = '0.1.3' +__version__ = "0.1.4" -from . import asr, cca, detrend, dss, lof, sns, star, ress, trca, tspca, utils +from . 
import asr, cca, detrend, dss, lof, ress, sns, star, trca, tspca, utils -__all__ = ['asr', 'cca', 'detrend', 'dss', 'lof', 'ress', 'sns', 'star', 'trca', - 'tspca', 'utils'] +__all__ = ["asr", "cca", "detrend", "dss", "lof", "ress", "sns", "star", "trca", + "tspca", "utils"] diff --git a/meegkit/asr.py b/meegkit/asr.py index 2674e376..b22295fc 100755 --- a/meegkit/asr.py +++ b/meegkit/asr.py @@ -6,8 +6,7 @@ from statsmodels.robust.scale import mad from .utils import block_covariance, nonlinear_eigenspace -from .utils.asr import (geometric_median, fit_eeg_distribution, yulewalk, - yulewalk_filter) +from .utils.asr import fit_eeg_distribution, geometric_median, yulewalk, yulewalk_filter try: import pyriemann @@ -101,12 +100,12 @@ class ASR(): def __init__(self, sfreq=250, cutoff=5, blocksize=100, win_len=0.5, win_overlap=0.66, max_dropout_fraction=0.1, - min_clean_fraction=0.25, name='asrfilter', method='euclid', - estimator='scm', **kwargs): + min_clean_fraction=0.25, name="asrfilter", method="euclid", + estimator="scm", **kwargs): - if pyriemann is None and method == 'riemann': - logging.warning('Need pyriemann to use riemannian ASR flavor.') - method = 'euclid' + if pyriemann is None and method == "riemann": + logging.warning("Need pyriemann to use riemannian ASR flavor.") + method = "euclid" self.cutoff = cutoff self.blocksize = blocksize @@ -224,10 +223,10 @@ def transform(self, X, y=None, **kwargs): X, sfreq=self.sfreq, ab=self.ab_, zi=self.zi_) if not self._fitted: - logging.warning('ASR is not fitted ! Returning unfiltered data.') + logging.warning("ASR is not fitted ! Returning unfiltered data.") return X - if self.estimator == 'scm': + if self.estimator == "scm": cov = 1 / X.shape[-1] * X_filt @ X_filt.T else: cov = pyriemann.estimation.covariances(X_filt[None, ...], @@ -331,7 +330,7 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5], N = int(win_len * sfreq) offsets = np.round(np.arange(0, ns - N, (N * (1 - win_overlap)))) offsets = offsets.astype(int) - logging.debug('[ASR] Determining channel-wise rejection thresholds') + logging.debug("[ASR] Determining channel-wise rejection thresholds") wz = np.zeros((nc, len(offsets))) for ichan in range(nc): @@ -391,18 +390,18 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5], for i in range(nc): ax[i].fill_between(times, 0, 1, where=sample_mask.flat, transform=ax[i].get_xaxis_transform(), - facecolor='none', hatch='...', edgecolor='k', - label='selected window') - ax[i].plot(times, X[i], lw=.5, label='EEG') + facecolor="none", hatch="...", edgecolor="k", + label="selected window") + ax[i].plot(times, X[i], lw=.5, label="EEG") ax[i].set_ylim([-50, 50]) # ax[i].set_ylabel(raw.ch_names[i]) ax[i].set_yticks([]) - ax[i].set_xlabel('Time (s)') - ax[i].set_ylabel(f'ch{i}') - ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), + ax[i].set_xlabel("Time (s)") + ax[i].set_ylabel(f"ch{i}") + ax[0].legend(fontsize="small", bbox_to_anchor=(1.04, 1), borderaxespad=0) plt.subplots_adjust(hspace=0, right=0.75) - plt.suptitle('Clean windows') + plt.suptitle("Clean windows") plt.show() return clean, sample_mask @@ -410,7 +409,7 @@ def clean_windows(X, sfreq, max_bad_chans=0.2, zthresholds=[-3.5, 5], def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5, win_overlap=0.66, max_dropout_fraction=0.1, - min_clean_fraction=0.25, method='euclid', estimator='scm'): + min_clean_fraction=0.25, method="euclid", estimator="scm"): """Calibration function for the Artifact Subspace Reconstruction method. 
The input to this data is a multi-channel time series of calibration data. @@ -478,7 +477,7 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5, Threshold matrix. """ - logging.debug('[ASR] Calibrating...') + logging.debug("[ASR] Calibrating...") # set number of channels and number of samples [nc, ns] = X.shape @@ -491,11 +490,11 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5, U = block_covariance(X, window=blocksize, overlap=win_overlap, estimator=estimator) - if method == 'euclid': + if method == "euclid": Uavg = geometric_median(U.reshape((-1, nc * nc))) Uavg = Uavg.reshape((nc, nc)) else: # method == 'riemann' - Uavg = pyriemann.utils.mean.mean_covariance(U, metric='riemann') + Uavg = pyriemann.utils.mean.mean_covariance(U, metric="riemann") # get the mixing matrix M M = linalg.sqrtm(np.real(Uavg)) @@ -520,11 +519,11 @@ def asr_calibrate(X, sfreq, cutoff=5, blocksize=100, win_len=0.5, Y, min_clean_fraction, max_dropout_fraction) T = np.dot(np.diag(mu + cutoff * sig), V.T) - logging.debug('[ASR] Calibration done.') + logging.debug("[ASR] Calibration done.") return M, T -def asr_process(X, X_filt, state, cov=None, detrend=False, method='riemann', +def asr_process(X, X_filt, state, cov=None, detrend=False, method="riemann", sample_weight=None): """Apply Artifact Subspace Reconstruction method. @@ -567,14 +566,14 @@ def asr_process(X, X_filt, state, cov=None, detrend=False, method='riemann', if cov is None: if detrend: - X_filt = signal.detrend(X_filt, axis=1, type='constant') + X_filt = signal.detrend(X_filt, axis=1, type="constant") cov = block_covariance(X_filt, window=nc ** 2) cov = cov.squeeze() if cov.ndim == 3: - if method == 'riemann': + if method == "riemann": cov = pyriemann.utils.mean.mean_covariance( - cov, metric='riemann', sample_weight=sample_weight) + cov, metric="riemann", sample_weight=sample_weight) else: cov = geometric_median(cov.reshape((-1, nc * nc))) cov = cov.reshape((nc, nc)) @@ -582,7 +581,7 @@ def asr_process(X, X_filt, state, cov=None, detrend=False, method='riemann', maxdims = int(np.fix(0.66 * nc)) # constant TODO make param # do a PCA to find potential artifacts - if method == 'riemann': + if method == "riemann": D, Vtmp = nonlinear_eigenspace(cov, nc) # TODO else: D, Vtmp = linalg.eigh(cov) @@ -604,14 +603,14 @@ def asr_process(X, X_filt, state, cov=None, detrend=False, method='riemann', demux = VT * keep[:, None] R = np.dot(np.dot(M, linalg.pinv(demux)), V.T) - if state['R'] is not None: + if state["R"] is not None: # apply the reconstruction to intermediate samples (using raised-cosine # blending) blend = (1 - np.cos(np.pi * np.arange(ns) / ns)) / 2 - clean = blend * R.dot(X) + (1 - blend) * state['R'].dot(X) + clean = blend * R.dot(X) + (1 - blend) * state["R"].dot(X) else: clean = R.dot(X) - state['R'] = R + state["R"] = R return clean, state diff --git a/meegkit/cca.py b/meegkit/cca.py index ebba7d32..efdcb604 100644 --- a/meegkit/cca.py +++ b/meegkit/cca.py @@ -3,7 +3,7 @@ from scipy import linalg from .utils import cov_lags, pca -from .utils.matrix import _check_shifts, normcol, relshift, _times_to_delays +from .utils.matrix import _check_shifts, _times_to_delays, normcol, relshift try: from tqdm import tqdm @@ -11,7 +11,7 @@ def tqdm(*args, **kwargs): # noqa if args: return args[0] - return kwargs.get('iterable', None) + return kwargs.get("iterable", None) def mcca(C, n_channels, n_keep=[]): @@ -45,9 +45,9 @@ def mcca(C, n_channels, n_keep=[]): """ if C.shape[0] != C.shape[1]: - raise ValueError('Covariance must 
be square !')
+        raise ValueError("Covariance must be square!")
     if np.mod(C.shape[0], n_channels) != 0:
-        raise ValueError('!')
+        raise ValueError("!")

     # Whiten covariance by blocks
     n_blocks = C.shape[0] // n_channels
@@ -78,8 +78,7 @@ def mcca(C, n_channels, n_keep=[]):
     return A, scores, AA


-def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
-                      plot=False):
+def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False, plot=False):
     """CCA with cross-validation.

     Parameters
@@ -119,8 +118,8 @@ def cca_crossvalidate(xx, yy, shifts=None, sfreq=1, surrogate=False,
         xx = [xx[..., t] for t in np.arange(xx.shape[-1])]
         yy = [yy[..., t] for t in np.arange(yy.shape[-1])]
     else:
-        raise AttributeError('xx and yy both must be lists of same length, '
-                             'or arrays os same n_trials.')
+        raise AttributeError("xx and yy both must be lists of same length, "
+                             "or arrays of same n_trials.")

     shifts = _times_to_delays(shifts, sfreq)
     shifts, n_shifts = _check_shifts(shifts)
@@ -128,13 +127,13 @@
     n_feats = xx[0].shape[1] + yy[0].shape[1]  # sum of channels

     # Calculate covariance matrices
-    print('Calculate all covariances...')
+    print("Calculate all covariances...")
     C = np.zeros((n_feats, n_feats, n_shifts, n_trials)).squeeze()
     for t in tqdm(np.arange(n_trials)):
         C[..., t], _, _ = cov_lags(xx[t], yy[t], shifts)

     # Calculate leave-one-out CCAs
-    print('Calculate CCAs...')
+    print("Calculate CCAs...")
     AA = []
     BB = []
     for t in tqdm(np.arange(n_trials)):
@@ -152,7 +151,7 @@
     del C, CC

     # Calculate leave-one-out correlation coefficients
-    print('Calculate cross-correlations...')
+    print("Calculate cross-correlations...")
     n_comps = AA[0].shape[1]
     r = np.zeros((n_comps, n_shifts))
     RR = np.zeros((n_comps, n_shifts, n_trials))
@@ -191,23 +190,23 @@
         import matplotlib.pyplot as plt
         f, (ax1) = plt.subplots(1, 1)
         for k in range(RR.shape[0]):
-            ax1.plot(shifts, np.mean(RR[k, :, :], 1).T, label='CC{}'.format(k))
-            ax1.set_title('correlation for each CC')
-            ax1.set_xlabel('shift')
-            ax1.set_ylabel('correlation')
+            ax1.plot(shifts, np.mean(RR[k, :, :], 1).T, label=f"CC{k}")
+            ax1.set_title("correlation for each CC")
+            ax1.set_xlabel("shift")
+            ax1.set_ylabel("correlation")
             ax1.legend()
             # if surrogate:
             #     ax1.plot(SD.T, ':')

         f2, axes = plt.subplots(min(4, RR.shape[0]), 1)
-        for k, ax in zip(np.arange(min(4, RR.shape[0])), axes):
+        for k, ax in zip(np.arange(min(4, RR.shape[0])), axes, strict=True):
             idx = np.argmax(np.mean(RR[k, :, :], 1))
             [x, y] = relshift(xx[0], yy[0], shifts[idx])
-            ax.plot(np.dot(x, AA[0][:, k, idx]).T, label='CC{}'.format(k))
-            ax.plot(np.dot(y, BB[0][:, k, idx]).T, ':')
+            ax.plot(np.dot(x, AA[0][:, k, idx]).T, label=f"CC{k}")
+            ax.plot(np.dot(y, BB[0][:, k, idx]).T, ":")
             ax.legend()
-            ax.set_xlabel('sample')
+            ax.set_xlabel("sample")

         f2.set_tight_layout(True)
         plt.show()
@@ -268,7 +267,7 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):

     """
     if (X is None and Y is not None) or (Y is None and X is not None):
-        raise AttributeError('Either *both* X and Y should be defined, or C!')
+        raise AttributeError("Either *both* X and Y should be defined, or C!")

     if X is not None:
         lags = _times_to_delays(lags, sfreq)
@@ -278,15 +277,15 @@ def nt_cca(X=None, Y=None, lags=None, C=None, m=None, thresh=1e-12, sfreq=1):
         return A, B, R

     if C is None:
-        raise 
RuntimeError('covariance matrix should be defined') + raise RuntimeError("covariance matrix should be defined") if m is None: - raise RuntimeError('m should be defined') + raise RuntimeError("m should be defined") if C.shape[0] != C.shape[1]: - raise RuntimeError('covariance matrix should be square') + raise RuntimeError("covariance matrix should be square") if any((X, Y, lags)): - raise RuntimeError('only covariance should be defined at this point') + raise RuntimeError("only covariance should be defined at this point") if C.ndim > 3: - raise RuntimeError('covariance should be 3D at most') + raise RuntimeError("covariance should be 3D at most") if C.ndim == 3: # covariance is 3D: do a separate CCA for each page n_chans, _, n_lags = C.shape diff --git a/meegkit/detrend.py b/meegkit/detrend.py index 34de9f74..43cde6ae 100644 --- a/meegkit/detrend.py +++ b/meegkit/detrend.py @@ -1,6 +1,5 @@ """Robust detrending.""" import numpy as np - from scipy.signal import lfilter from .utils import demean, mrdivide, pca, unfold @@ -8,7 +7,7 @@ from .utils.sig import stmcb -def detrend(x, order, w=None, basis='polynomials', threshold=3, n_iter=4, +def detrend(x, order, w=None, basis="polynomials", threshold=3, n_iter=4, show=False): """Robustly remove trend. @@ -59,7 +58,7 @@ def detrend(x, order, w=None, basis='polynomials', threshold=3, n_iter=4, """ if threshold == 0: - raise ValueError('thresh=0 is not what you want...') + raise ValueError("thresh=0 is not what you want...") # check/fix sizes dims = x.shape @@ -73,22 +72,22 @@ def detrend(x, order, w=None, basis='polynomials', threshold=3, n_iter=4, r = basis else: lin = np.linspace(-1, 1, n_times) - if basis == 'polynomials' or basis is None: + if basis == "polynomials" or basis is None: r = np.zeros((n_times, order)) for i, o in enumerate(range(1, order + 1)): r[:, i] = lin ** o - elif basis == 'sinusoids': + elif basis == "sinusoids": r = np.zeros((n_times, order * 2)) for i, o in enumerate(range(1, order + 1)): r[:, 2 * i] = np.sin(2 * np.pi * o * lin / 2) r[:, 2 * i + 1] = np.cos(2 * np.pi * o * lin / 2) else: - raise ValueError('!') + raise ValueError("!") # iteratively remove trends # the tricky bit is to ensure that weighted means are removed before # calculating the regression (see regress()). 
- for i in range(n_iter): + for _ in range(n_iter): # weighted regression on basis _, y = regress(x, r, w) @@ -150,7 +149,7 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): r = unfold(r) x = unfold(x) if r.shape[0] != x.shape[0]: - raise ValueError('r and x have incompatible shapes!') + raise ValueError("r and x have incompatible shapes!") # save weighted mean mn = x - demean(x, w) @@ -171,11 +170,11 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): else: # weighted regression if w.shape[0] != n_times: - raise ValueError('!') + raise ValueError("w should have the same number of rows (n_times) as x!") if w.shape[1] == 1: # same weight for all channels if sum(w.flatten()) == 0: - print('weights all zero') + print("weights all zero") b = 0 else: yy = demean(x, w) * w @@ -189,12 +188,12 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False): else: # each channel has own weight if w.shape[1] != x.shape[1]: - raise ValueError('!') + raise ValueError("w should have either 1 column or as many columns as x!") z = np.zeros(x.shape) b = np.zeros((n_chans, n_regs)) for i in range(n_chans): if not np.any(w[:, i]): - print(f'weights are all zero for channel {i}') + print(f"weights are all zero for channel {i}") else: wc = w[:, i][:, None] # channel-specific weight xx = demean(x[:, i], wc) * wc @@ -245,7 +244,7 @@ def reduce_ringing(X, samples, order=10, n_samples=100, extra=50, threshold=3, samples = samples[samples < X.shape[0] - n_samples] y = X.copy() - for i, s in enumerate(samples): + for s in samples: for c in range(X.shape[1]): # select portion to fit filter response, remove polynomial trend response = X[s - extra:s + n_samples, c] @@ -282,22 +281,22 @@ def _plot_detrend(x, y, w): f = plt.figure() gs = GridSpec(4, 1, figure=f) ax1 = f.add_subplot(gs[:3, 0]) - lines = ax1.plot(x, label='original', color='C0') + lines = ax1.plot(x, label="original", color="C0") plt.setp(lines[1:], label="_") - lines = ax1.plot(y, label='detrended', color='C1') + lines = ax1.plot(y, label="detrended", color="C1") plt.setp(lines[1:], label="_") ax1.set_xlim(0, n_times) - ax1.set_xticklabels('') - ax1.set_title('Robust detrending') - ax1.legend(fontsize='smaller') + ax1.set_xticklabels("") + ax1.set_title("Robust detrending") + ax1.legend(fontsize="smaller") ax2 = f.add_subplot(gs[3, 0]) - ax2.pcolormesh(w.T, cmap='Greys') + ax2.pcolormesh(w.T, cmap="Greys") ax2.set_yticks(np.arange(0, n_chans) + 0.5) - ax2.set_yticklabels(['ch{}'.format(i) for i in np.arange(n_chans)]) + ax2.set_yticklabels([f"ch{i}" for i in np.arange(n_chans)]) ax2.set_xlim(0, n_times) - ax2.set_ylabel('ch. weights') - ax2.set_xlabel('samples') + ax2.set_ylabel("ch. 
weights") + ax2.set_xlabel("samples") plt.show() @@ -334,7 +333,7 @@ def create_masked_weight(x, events, tmin, tmax, sfreq): """ if x.ndim != 2: - raise ValueError('The shape of x must be (n_times, n_channels)') + raise ValueError("The shape of x must be (n_times, n_channels)") weights = np.ones(x.shape) for e in events: diff --git a/meegkit/dss.py b/meegkit/dss.py index f4841ba3..f2017cb9 100644 --- a/meegkit/dss.py +++ b/meegkit/dss.py @@ -2,14 +2,22 @@ # Authors: Nicolas Barascud # Maciej Szul import numpy as np +from numpy.lib.stride_tricks import sliding_window_view from scipy import linalg from scipy.signal import welch from .tspca import tsr -from .utils import (demean, gaussfilt, matmul3d, mean_over_trials, pca, smooth, - theshapeof, tscov, wpwr) - -from numpy.lib.stride_tricks import sliding_window_view +from .utils import ( + demean, + gaussfilt, + matmul3d, + mean_over_trials, + pca, + smooth, + theshapeof, + tscov, + wpwr, +) def dss1(X, weights=None, keep1=None, keep2=1e-12): @@ -94,15 +102,15 @@ def dss0(c0, c1, keep1=None, keep2=1e-9): """ if c0 is None or c1 is None: - raise AttributeError('dss0 needs at least two arguments') + raise AttributeError("dss0 needs at least two arguments") if c0.shape != c1.shape: - raise AttributeError('c0 and c1 should have same size') + raise AttributeError("c0 and c1 should have same size") if c0.shape[0] != c0.shape[1]: - raise AttributeError('c0 should be square') + raise AttributeError("c0 should be square") if np.any(np.isnan(c0)) or np.any(np.isinf(c0)): - raise ValueError('NaN or INF in c0') + raise ValueError("NaN or INF in c0") if np.any(np.isnan(c1)) or np.any(np.isinf(c1)): - raise ValueError('NaN or INF in c1') + raise ValueError("NaN or INF in c1") # derive PCA and whitening matrix from unbiased covariance eigvec0, eigval0 = pca(c0, max_comps=keep1, thresh=keep2) @@ -190,7 +198,7 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None, """ if X.shape[0] < nfft: - print('Reducing nfft to {}'.format(X.shape[0])) + print(f"Reducing nfft to {X.shape[0]}") nfft = X.shape[0] n_samples, n_chans, _ = theshapeof(X) if blocksize is None: @@ -233,10 +241,10 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None, if show: import matplotlib.pyplot as plt - plt.plot(pwr1 / pwr0, '.-') - plt.xlabel('component') - plt.ylabel('score') - plt.title('DSS to enhance line frequencies') + plt.plot(pwr1 / pwr0, ".-") + plt.xlabel("component") + plt.ylabel("score") + plt.title("DSS to enhance line frequencies") plt.show() # Remove line components from X_noise @@ -249,7 +257,7 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None, # Power of components p = wpwr(X - y)[0] / wpwr(X)[0] - print('Power of components removed by DSS: {:.2f}'.format(p)) + print(f"Power of components removed by DSS: {p:.2f}") # return the reconstructed clean signal, and the artifact return y, X - y @@ -334,7 +342,7 @@ def nan_basic_interp(array): mean_score = np.mean(residuals[freq_sp_ix]) aggr_resid.append(mean_score) - print("Iteration {} score: {}".format(iterations, mean_score)) + print(f"Iteration {iterations} score: {mean_score}") if show: import matplotlib.pyplot as plt @@ -362,7 +370,7 @@ def nan_basic_interp(array): ax.flat[2].scatter(residuals[tf_ix], freq_used[tf_ix], c=color) ax.flat[2].set_title("Residuals") - ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker='o') + ax.flat[3].plot(np.arange(iterations + 1), aggr_resid, marker="o") ax.flat[3].set_title("Iterations") 
f.set_tight_layout(True) @@ -375,7 +383,7 @@ def nan_basic_interp(array): iterations += 1 if iterations == n_iter_max: - raise RuntimeError('Could not converge. Consider increasing the ' - 'maximum number of iterations') + raise RuntimeError("Could not converge. Consider increasing the " + "maximum number of iterations") return data, iterations diff --git a/meegkit/lof.py b/meegkit/lof.py index 9f9db51c..3974bf7c 100644 --- a/meegkit/lof.py +++ b/meegkit/lof.py @@ -3,6 +3,7 @@ # License: BSD-3-Clause import logging + from sklearn.neighbors import LocalOutlierFactor @@ -42,7 +43,7 @@ class LOF(): """ - def __init__(self, n_neighbors=20, metric='euclidean', + def __init__(self, n_neighbors=20, metric="euclidean", threshold=1.5, **kwargs): self.n_neighbors = n_neighbors @@ -63,21 +64,21 @@ def predict(self, X): """ if X.ndim == 3: # in case the input data is epoched - logging.warning('Expected input data with shape ' - '(n_channels, n_samples)') + logging.warning("Expected input data with shape " + "(n_channels, n_samples)") return [] if self.n_neighbors >= X.shape[0]: - logging.warning('Number of neighbours cannot be greater than the ' - 'number of channels') + logging.warning("Number of neighbors must be smaller than the " + "number of channels") return [] if self.threshold < 1.0: - logging.warning('Invalid threshold. Try a positive integer >= 1.0') + logging.warning("Invalid threshold. Try a value >= 1.0") return [] clf = LocalOutlierFactor(self.n_neighbors) - logging.debug('[LOF] Predicting bad channels') + logging.debug("[LOF] Predicting bad channels") clf.fit_predict(X) lof_scores = clf.negative_outlier_factor_ bad_channel_indices = -lof_scores >= self.threshold diff --git a/meegkit/ress.py b/meegkit/ress.py index dbb33a84..b47fd7b5 100644 --- a/meegkit/ress.py +++ b/meegkit/ress.py @@ -2,7 +2,7 @@ import numpy as np from scipy import linalg -from .utils import demean, gaussfilt, theshapeof, tscov, mrdivide +from .utils import demean, gaussfilt, mrdivide, theshapeof, tscov def RESS(X, sfreq: int, peak_freq: float, neig_freq: float = 1, diff --git a/meegkit/sns.py b/meegkit/sns.py index 9ad0fe38..cd068c87 100644 --- a/meegkit/sns.py +++ b/meegkit/sns.py @@ -5,8 +5,10 @@ from .utils import demean, fold, pca, theshapeof, tscov, unfold from .utils.matrix import _check_weights +DEFAULT_WEIGHTS = np.array([]) -def sns(X, n_neighbors=0, skip=0, weights=np.array([])): + +def sns(X, n_neighbors=0, skip=0, weights=DEFAULT_WEIGHTS): """Sensor Noise Suppression. This algorithm will replace the data from each channel by its regression on @@ -61,7 +63,7 @@ def sns(X, n_neighbors=0, skip=0, weights=np.array([])): return y, r -def sns0(c, n_neighbors=0, skip=0, wc=np.array([])): +def sns0(c, n_neighbors=0, skip=0, wc=DEFAULT_WEIGHTS): """Sensor Noise Suppression from data covariance. 
Parameters @@ -124,7 +126,7 @@ def sns0(c, n_neighbors=0, skip=0, wc=np.array([])): r[k, idx] = np.dot(eigvec, np.dot(wc[k][idx], eigvec)) if r[k, k] != 0: - raise RuntimeError('SNS operator should be zero along diagonal') + raise RuntimeError("SNS operator should be zero along diagonal") return r.T diff --git a/meegkit/star.py b/meegkit/star.py index f4b178d4..462b0f53 100644 --- a/meegkit/star.py +++ b/meegkit/star.py @@ -2,8 +2,7 @@ import numpy as np from scipy.signal import filtfilt -from .utils import (demean, fold, mrdivide, normcol, pca, theshapeof, tscov, - unfold, wpwr) +from .utils import demean, fold, mrdivide, normcol, pca, theshapeof, tscov, unfold, wpwr def star(X, thresh=1, closest=[], depth=1, pca_thresh=1e-15, n_smooth=10, @@ -48,7 +47,9 @@ def star(X, thresh=1, closest=[], depth=1, pca_thresh=1e-15, n_smooth=10, if thresh is None: thresh = 1 if len(closest) > 0 and closest.shape[0] != X.shape[1]: - raise ValueError('`closest` should have as many rows as n_chans') + raise ValueError("`closest` should have as many rows as n_chans") + + rng = np.random.default_rng() ndims = X.ndim n_samples, n_chans, n_trials = theshapeof(X) @@ -64,9 +65,9 @@ def star(X, thresh=1, closest=[], depth=1, pca_thresh=1e-15, n_smooth=10, idx_nan = np.all(np.isnan(X), axis=0) idx_zero = np.all(X == 0, axis=0) if idx_nan.any(): - X[:, idx_nan] = np.random.randn(X.shape[0], np.sum(idx_nan)) + X[:, idx_nan] = rng.standard_normal((X.shape[0], np.sum(idx_nan))) if idx_zero.any(): - X[:, idx_zero] = np.random.randn(X.shape[0], np.sum(idx_zero)) + X[:, idx_zero] = rng.standard_normal((X.shape[0], np.sum(idx_zero))) # initial covariance estimate X = demean(X) @@ -98,12 +99,12 @@ def star(X, thresh=1, closest=[], depth=1, pca_thresh=1e-15, n_smooth=10, artifact_free = np.mean(w, axis=0) if verbose: - print('proportion artifact free: {:.2f}'.format(artifact_free)) + print(f"proportion artifact free: {artifact_free:.2f}") if iter == n_iter and artifact_free < min_prop: thresh = thresh * 1.1 if verbose: - print('Warning: increasing threshold to {:.2f}'.format(thresh)) + print(f"Warning: increasing threshold to {thresh:.2f}") w = np.ones(w.shape) else: iter = iter - 1 @@ -153,10 +154,10 @@ def star(X, thresh=1, closest=[], depth=1, pca_thresh=1e-15, n_smooth=10, y[bad_samples, ch] = z.squeeze() # fix if verbose: - print('depth: {}'.format(i_depth + 1)) - print('fixed channels: {}'.format(i_fixed)) - print('fixed samples: {}'.format(n_fixed)) - print('ratio: {:.2f}'.format(wpwr(X)[0] / p00)) + print(f"depth: {i_depth + 1}") + print(f"fixed channels: {i_fixed}") + print(f"fixed samples: {n_fixed}") + print(f"ratio: {wpwr(X)[0] / p00:.2f}") y = demean(y) y *= norm @@ -169,9 +170,9 @@ def star(X, thresh=1, closest=[], depth=1, pca_thresh=1e-15, n_smooth=10, y[:, idx_zero] = 0 if verbose: - print('power ratio: {:.2f}'.format(wpwr(y)[0] / p0)) + print(f"power ratio: {wpwr(y)[0] / p0:.2f}") - if verbose == 'debug': + if verbose == "debug": _diagnostics(X * norm + intercept, y, d, thresh) if ndims == 3: # fold back into trials @@ -250,24 +251,24 @@ def _diagnostics(X, y, d, thresh): f, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1) ax1.plot(X, lw=.5) - ax1.set_title('Signal + Artifacts') + ax1.set_title("Signal + Artifacts") ax1.set_xticklabels([]) ax2.plot(demean(y), lw=.5) - ax2.set_title('Denoised') + ax2.set_title("Denoised") ax2.set_xticklabels([]) ax3.plot(X - demean(y), lw=.5) - ax3.set_title('Residual') + ax3.set_title("Residual") ax3.set_xticklabels([]) ax4.plot(d, lw=.5, alpha=.3) d[d < thresh] = None ax4.plot(d, 
lw=1) - ax4.axhline(thresh, lw=2, color='k', ls=':') + ax4.axhline(thresh, lw=2, color="k", ls=":") ax4.set_ylim([0, thresh + 1]) - ax4.set_title('Eccentricity') - ax4.set_xlabel('samples') + ax4.set_title("Eccentricity") + ax4.set_xlabel("samples") - f.set_tight_layout(True) + f.set_layout_engine("tight") plt.show() diff --git a/meegkit/trca.py b/meegkit/trca.py index 9c410805..892948dc 100644 --- a/meegkit/trca.py +++ b/meegkit/trca.py @@ -3,11 +3,11 @@ # Ludovic Darmet import numpy as np import scipy.linalg as linalg -from pyriemann.utils.mean import mean_covariance from pyriemann.estimation import Covariances +from pyriemann.utils.mean import mean_covariance -from .utils.trca import bandpass, schaefer_strimmer_cov from .utils import theshapeof +from .utils.trca import bandpass, schaefer_strimmer_cov class TRCA: @@ -66,15 +66,15 @@ class TRCA: """ - def __init__(self, sfreq, filterbank, ensemble=False, method='original', - estimator='scm'): + def __init__(self, sfreq, filterbank, ensemble=False, method="original", + estimator="scm"): self.sfreq = sfreq self.ensemble = ensemble self.filterbank = filterbank self.n_bands = len(self.filterbank) self.coef_ = None self.method = method - if estimator == 'schaefer': + if estimator == "schaefer": self.estimator = schaefer_strimmer_cov else: self.estimator = estimator @@ -112,12 +112,12 @@ def fit(self, X, y): trains[class_i, fb_i] = eeg_tmp # Find the spatial filter for the corresponding filtered signal # and label - if self.method == 'original': + if self.method == "original": w_best = trca(eeg_tmp) - elif self.method == 'riemann': + elif self.method == "riemann": w_best = trca_regul(eeg_tmp, self.estimator) else: - raise ValueError('Invalid `method` option.') + raise ValueError("Invalid `method` option.") W[fb_i, class_i, :] = w_best # Store the spatial filter @@ -144,14 +144,14 @@ def predict(self, X): """ if self.coef_ is None: - raise RuntimeError('TRCA is not fitted') + raise RuntimeError("TRCA is not fitted") # Alpha coefficients for the fusion of filterbank analysis fb_coefs = [(x + 1)**(-1.25) + 0.25 for x in range(self.n_bands)] _, _, n_trials = theshapeof(X) r = np.zeros((self.n_bands, len(self.classes))) - pred = np.zeros((n_trials), 'int') # To store predictions + pred = np.zeros((n_trials), "int") # To store predictions for trial in range(n_trials): test_tmp = X[..., trial] # pick a trial to be analysed @@ -318,9 +318,9 @@ def trca_regul(X, method): # If the number of samples is too big, we compute an approximate of # riemannian mean to speed up the computation if n_trials < 30: - S = mean_covariance(S, metric='riemann') + S = mean_covariance(S, metric="riemann") else: - S = mean_covariance(S, metric='logeuclid') + S = mean_covariance(S, metric="logeuclid") # 3. 
Compute eigenvalues and vectors # ------------------------------------------------------------------------- diff --git a/meegkit/tspca.py b/meegkit/tspca.py index 0efbac21..bb0622c3 100644 --- a/meegkit/tspca.py +++ b/meegkit/tspca.py @@ -1,8 +1,19 @@ """Time-shift PCA.""" import numpy as np -from .utils import (demean, fold, multishift, normcol, pca, regcov, tscov, - tsxcov, unfold, theshapeof, unsqueeze) +from .utils import ( + demean, + fold, + multishift, + normcol, + pca, + regcov, + theshapeof, + tscov, + tsxcov, + unfold, + unsqueeze, +) from .utils.matrix import _check_shifts, _check_weights diff --git a/meegkit/utils/__init__.py b/meegkit/utils/__init__.py index 766cc140..28b1cc6b 100644 --- a/meegkit/utils/__init__.py +++ b/meegkit/utils/__init__.py @@ -1,16 +1,53 @@ """Utility functions.""" -from .auditory import (AuditoryFilterbank, GammatoneFilterbank, erb2hz, - erbspace, hz2erb) +from .auditory import AuditoryFilterbank, GammatoneFilterbank, erb2hz, erbspace, hz2erb from .base import mldivide, mrdivide -from .covariances import (block_covariance, convmtx, cov_lags, - nonlinear_eigenspace, pca, regcov, tscov, tsxcov) -from .denoise import (demean, find_outlier_samples, find_outlier_trials, - mean_over_trials, wpwr) -from .matrix import (fold, matmul3d, multishift, multismooth, normcol, - relshift, shift, shiftnd, sliding_window, theshapeof, - unfold, unsqueeze, widen_mask) -from .sig import (gaussfilt, hilbert_envelope, slope_sum, smooth, - spectral_envelope, teager_kaiser) -from .stats import (bootstrap_ci, bootstrap_snr, cronbach, rms, robust_mean, - rolling_corr, snr_spectrum) +from .covariances import ( + block_covariance, + convmtx, + cov_lags, + nonlinear_eigenspace, + pca, + regcov, + tscov, + tsxcov, +) +from .denoise import ( + demean, + find_outlier_samples, + find_outlier_trials, + mean_over_trials, + wpwr, +) +from .matrix import ( + fold, + matmul3d, + multishift, + multismooth, + normcol, + relshift, + shift, + shiftnd, + sliding_window, + theshapeof, + unfold, + unsqueeze, + widen_mask, +) +from .sig import ( + gaussfilt, + hilbert_envelope, + slope_sum, + smooth, + spectral_envelope, + teager_kaiser, +) +from .stats import ( + bootstrap_ci, + bootstrap_snr, + cronbach, + rms, + robust_mean, + rolling_corr, + snr_spectrum, +) from .testing import create_line_data diff --git a/meegkit/utils/asr.py b/meegkit/utils/asr.py index 3a15120d..dd0fa6b5 100755 --- a/meegkit/utils/asr.py +++ b/meegkit/utils/asr.py @@ -6,11 +6,12 @@ from scipy.spatial.distance import cdist, euclidean from scipy.special import gamma, gammaincinv +SHAPE_RANGE = np.linspace(1.7, 3.5, 13) + def fit_eeg_distribution(X, min_clean_fraction=0.25, max_dropout_fraction=0.1, - fit_quantiles=[0.022, 0.6], - step_sizes=[0.0220, 0.6000], - shape_range=np.linspace(1.7, 3.5, 13)): + fit_quantiles=[0.022, 0.6], step_sizes=[0.0220, 0.6000], + shape_range=SHAPE_RANGE): """Estimate the mean and SD of clean EEG from contaminated data. This function estimates the mean and standard deviation of clean EEG from a @@ -39,14 +40,14 @@ def fit_eeg_distribution(X, min_clean_fraction=0.25, max_dropout_fraction=0.1, ---------- X : array, shape=(n_channels, n_samples) EEG data, possibly containing artifacts. - max_dropout_fraction : float - Maximum fraction that can have dropouts. This is the maximum fraction - of time windows that may have arbitrarily low amplitude (e.g., due to - the sensors being unplugged) (default=0.25). min_clean_fraction : float Minimum fraction that needs to be clean. 
This is the minimum fraction of time windows that need to contain essentially uncontaminated EEG (default=0.25). + max_dropout_fraction : float + Maximum fraction that can have dropouts. This is the maximum fraction + of time windows that may have arbitrarily low amplitude (e.g., due to + the sensors being unplugged) (default=0.1). fit_quantiles : 2-tuple Quantile range [lower,upper] of the truncated generalized Gaussian distribution that shall be fit to the EEG contents (default=[0.022, 0.6]). step_sizes : 2-tuple Step size of the grid search; the first value is the stepping of the lower bound (which essentially steps over any dropout samples), and the second value is the stepping over possible scales (i.e., clean-data quantiles) (default=[0.022, 0.6]). - beta : array + shape_range : array Range that the clean EEG distribution's shape parameter beta may take. Returns diff --git a/meegkit/utils/auditory.py b/meegkit/utils/auditory.py index bc193190..49492770 100644 --- a/meegkit/utils/auditory.py +++ b/meegkit/utils/auditory.py @@ -224,5 +224,5 @@ def __init__(self, sfreq, b=1.019, order=1, q=9.26449, min_bw=24.7): 30, 800, 1000, 1250, 1600, 2000, 2500, 3150, 4000, 5000, 6300, 8000]) - super(AuditoryFilterbank, self).__init__( + super().__init__( sfreq=sfreq, cf=cf, b=b, order=order, q=q, min_bw=min_bw) diff --git a/meegkit/utils/base.py b/meegkit/utils/base.py index 46cdff6d..be4ff314 100644 --- a/meegkit/utils/base.py +++ b/meegkit/utils/base.py @@ -39,11 +39,12 @@ def mldivide(A, B): try: # Note: we must use overwrite_a=False in order to be able to # use the fall-back solution below in case a LinAlgError is raised - return linalg.solve(A, B, assume_a='pos', overwrite_a=False) + return linalg.solve(A, B, assume_a="pos", overwrite_a=False) except linalg.LinAlgError: # Singular matrix in solving dual problem. Using least-squares # solution instead. - return linalg.lstsq(A, B, lapack_driver='gelsy')[0] - except linalg.LinAlgError: - print('Solution not stable. Model not updated!') - return None + try: + return linalg.lstsq(A, B, lapack_driver="gelsy")[0] + except linalg.LinAlgError: + print("Solution not stable. Model not updated!") + return None diff --git a/meegkit/utils/covariances.py b/meegkit/utils/covariances.py index dd496a4a..383a5f28 100644 --- a/meegkit/utils/covariances.py +++ b/meegkit/utils/covariances.py @@ -3,12 +3,19 @@ from scipy import linalg from .base import mldivide -from .matrix import (_check_shifts, _check_weights, multishift, relshift, - theshapeof, unsqueeze) +from .matrix import ( + _check_shifts, + _check_weights, + multishift, + relshift, + theshapeof, + unsqueeze, +) +rng = np.random.default_rng() -def block_covariance(data, window=128, overlap=0.5, padding=True, - estimator='cov'): + +def block_covariance(data, window=128, overlap=0.5, padding=True, estimator="cov"): """Compute blockwise covariance. 
Parameters @@ -78,11 +85,11 @@ def cov_lags(X, Y, shifts=None): n_samples2, n_chans2, n_trials2 = theshapeof(Y) if n_samples != n_samples2: - raise AttributeError('X and Y must have same n_times') + raise AttributeError("X and Y must have same n_times") if n_trials != n_trials2: - raise AttributeError('X and Y must have same n_trials') + raise AttributeError("X and Y must have same n_trials") if n_samples <= max(shifts): - raise AttributeError('shifts should be no larger than n_samples') + raise AttributeError("shifts should be no larger than n_samples") n_cov = n_chans + n_chans2 # sum of channels of X and Y C = np.zeros((n_cov, n_cov, n_shifts)) @@ -142,7 +149,7 @@ def tsxcov(X, Y, shifts=None, weights=None, assume_centered=True): # Apply weights if any if weights.any(): - X = np.einsum('ijk,ilk->ijk', X, weights) # element-wise mult + X = np.einsum("ijk,ilk->ijk", X, weights) # element-wise mult weights = weights[:n_times2, :, :] # cross covariance @@ -208,7 +215,7 @@ def tscov(X, shifts=None, weights=None, assume_centered=True): X = X - X.mean(0, keepdims=1) if weights.any(): # weights - X = np.einsum('ijk,ilk->ijk', X, weights) # element-wise mult + X = np.einsum("ijk,ilk->ijk", X, weights) # element-wise mult tw = np.sum(weights[:]) else: # no weights N = 0 @@ -287,7 +294,7 @@ def convmtx(V, n): [nr, nc] = V.shape V = V.flatten() - c = np.hstack((V, np.zeros((n - 1)))) + c = np.hstack((V, np.zeros(n - 1))) r = np.zeros(n) m = len(c) x_left = r[n:0:-1] # reverse order from n to 2 in original code @@ -336,7 +343,7 @@ def pca(cov, max_comps=None, thresh=0): """ if thresh is not None and (thresh > 1 or thresh < 0): - raise ValueError('Threshold must be between 0 and 1 (or None).') + raise ValueError("Threshold must be between 0 and 1 (or None).") d, V = linalg.eigh(cov) d = d.real @@ -365,13 +372,13 @@ def pca(cov, max_comps=None, thresh=0): var = 100 * d.sum() / p0 if var < 99: - print('[PCA] Explained variance of selected components : {:.2f}%'. + print("[PCA] Explained variance of selected components : {:.2f}%". format(var)) return V, d -def regcov(Cxy, Cyy, keep=np.array([]), threshold=np.array([])): +def regcov(Cxy, Cyy, keep=None, threshold=0): """Compute regression matrix from cross covariance. Parameters @@ -381,9 +388,10 @@ def regcov(Cxy, Cyy, keep=np.array([]), threshold=np.array([])): Cyy : array Covariance matrix of regressor. keep : array - Number of regressor PCs to keep (default=all). + Number of regressor PCs to keep (default=None, which keeps all). threshold : float - Eigenvalue threshold for discarding regressor PCs (default=0). + Eigenvalue threshold for discarding regressor PCs (default=0, which keeps all + components). Returns ------- @@ -447,7 +455,7 @@ def nonlinear_eigenspace(L, k, alpha=1): from pymanopt.optimizers import TrustRegions n = L.shape[0] - assert L.shape[1] == n, 'L must be square.' + assert L.shape[1] == n, "L must be square." # Grassmann manifold description manifold = Grassmann(n, k) @@ -483,7 +491,7 @@ def ehess(X, U): # Initialization as suggested in above referenced paper. 
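# (note: the SVD step below orthonormalizes a random Gaussian matrix into
# x = U @ Vt, a random orthonormal frame used to seed the trust-region solver)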
# randomly generate starting point for svd - x = np.random.randn(n, k) + x = rng.standard_normal((n, k)) [U, S, V] = linalg.svd(x, full_matrices=False) x = U.dot(V.T) S0, U0 = linalg.eig( diff --git a/meegkit/utils/denoise.py b/meegkit/utils/denoise.py index 92ed6321..4bb91de2 100644 --- a/meegkit/utils/denoise.py +++ b/meegkit/utils/denoise.py @@ -1,10 +1,9 @@ """Denoising utilities.""" import matplotlib.pyplot as plt import numpy as np - from matplotlib import gridspec -from .matrix import fold, theshapeof, unfold, _check_weights +from .matrix import _check_weights, fold, theshapeof, unfold def demean(X, weights=None, return_mean=False, inplace=False): @@ -40,15 +39,15 @@ def demean(X, weights=None, return_mean=False, inplace=False): weights = unfold(weights) if weights.shape[0] != X.shape[0]: - raise ValueError('X and weights arrays should have same ' + - 'number of samples (rows).') + raise ValueError("X and weights arrays should have same " + + "number of samples (rows).") if weights.shape[1] == 1 or weights.shape[1] == n_chans: mn = (np.sum(X * weights, axis=0) / np.sum(weights, axis=0))[None, :] else: - raise ValueError('Weight array should have either the same ' + - 'number of columns as X array, or 1 column.') + raise ValueError("Weight array should have either the same " + + "number of columns as X array, or 1 column.") else: mn = np.mean(X, axis=0, keepdims=True) @@ -92,7 +91,7 @@ def mean_over_trials(X, weights=None): weights = fold(weights, n_samples) # Take weighted average - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): y = np.sum(X, -1) / np.sum(weights, -1) tw = np.mean(weights, -1) y = np.nan_to_num(y) @@ -179,7 +178,7 @@ def find_outlier_trials(X, thresh=None, show=True): thresh = [thresh] if X.ndim > 3: - raise ValueError('X should be 2D or 3D') + raise ValueError("X should be 2D or 3D") elif X.ndim == 3: n_samples, n_chans, n_trials = theshapeof(X) X = np.reshape(X, (n_samples * n_chans, n_trials)) @@ -196,27 +195,27 @@ def find_outlier_trials(X, thresh=None, show=True): if show: plt.figure(figsize=(7, 4)) gs = gridspec.GridSpec(1, 2) - plt.suptitle('Outlier trial detection') + plt.suptitle("Outlier trial detection") # Before ax1 = plt.subplot(gs[0, 0]) - ax1.plot(d, ls='-') + ax1.plot(d, ls="-") ax1.plot(np.setdiff1d(np.arange(n_trials), idx), - d[np.setdiff1d(np.arange(n_trials), idx)], color='r', ls=' ', - marker='.') - ax1.axhline(y=thresh[0], color='grey', linestyle=':') - ax1.set_xlabel('Trial #') - ax1.set_ylabel('Normalized deviation from mean') - ax1.set_title('Before, ' + str(len(d)), fontsize=10) + d[np.setdiff1d(np.arange(n_trials), idx)], color="r", ls=" ", + marker=".") + ax1.axhline(y=thresh[0], color="grey", linestyle=":") + ax1.set_xlabel("Trial #") + ax1.set_ylabel("Normalized deviation from mean") + ax1.set_title("Before, " + str(len(d)), fontsize=10) ax1.set_xlim(0, len(d) + 1) plt.draw() # After ax2 = plt.subplot(gs[0, 1]) _, dd = find_outlier_trials(X[:, idx], None, False) - ax2.plot(dd, ls='-') - ax2.set_xlabel('Trial #') - ax2.set_title('After, ' + str(len(idx)), fontsize=10) + ax2.plot(dd, ls="-") + ax2.set_xlabel("Trial #") + ax2.set_title("After, " + str(len(idx)), fontsize=10) ax2.yaxis.tick_right() ax2.set_xlim(0, len(idx) + 1) plt.show() diff --git a/meegkit/utils/matrix.py b/meegkit/utils/matrix.py index e186b095..4c4c733e 100644 --- a/meegkit/utils/matrix.py +++ b/meegkit/utils/matrix.py @@ -100,7 +100,7 @@ def widen_mask(mask, widen=4, axis=0): dtype = mask.dtype if axis > 
dims - 1: - raise AttributeError('Invalid `axis` value.') + raise AttributeError("Invalid `axis` value.") if widen < 0: # This places the desired axis at the front of the shape tuple, then @@ -188,7 +188,7 @@ def relshift(X, ref, shifts, fill_value=0, axis=0): ref = _check_data(ref) if X.shape[0] != ref.shape[0]: - raise AttributeError('X and ref must have same n_times') + raise AttributeError("X and ref must have same n_times") # First we delay X y = multishift(X, shifts=shifts, axis=axis, fill_value=fill_value) @@ -209,7 +209,7 @@ def relshift(X, ref, shifts, fill_value=0, axis=0): def multishift(X, shifts, fill_value=0, axis=0, keep_dims=False, - reshape=False, solution='full'): + reshape=False, solution="full"): """Apply several shifts along specified axis. If `shifts` has multiple values, the output will contain one shift per @@ -268,7 +268,7 @@ def multishift(X, shifts, fill_value=0, axis=0, keep_dims=False, if n_shifts == 1 and not keep_dims: y = np.squeeze(y, axis=-1) - if solution == 'valid': + if solution == "valid": max_neg_shift = np.abs(np.min(np.min(shifts), 0)) max_pos_shift = np.max((np.max(shifts), 0)) y = y[max_pos_shift:-max_neg_shift, ...] @@ -349,7 +349,7 @@ def shift(X, shift, fill_value=0, axis=0): """ if not np.equal(np.mod(shift, 1), 0): - raise AttributeError('shift must be a single int') + raise AttributeError("shift must be a single int") # reallocate empty array and assign slice. y = np.empty_like(X) @@ -382,7 +382,7 @@ def shift(X, shift, fill_value=0, axis=0): y[..., :shift] = X[..., -shift:] else: - raise NotImplementedError('Axis must be 0, 1 or -1.') + raise NotImplementedError("Axis must be 0, 1 or -1.") return y @@ -499,7 +499,7 @@ def fold(X, epoch_size): if X.ndim == 1: X = X[:, np.newaxis] if X.ndim > 2: - raise AttributeError('X must be 2D at most') + raise AttributeError("X must be 2D at most") nt = X.shape[0] // epoch_size nc = X.shape[1] @@ -579,7 +579,7 @@ def normcol(X, weights=None, return_norm=False): n_samples, n_chans, n_trials = theshapeof(X) weights = _check_weights(weights, X) if not weights.any(): - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): N = ((np.sum(X ** 2, axis=0) / n_samples) ** -0.5)[np.newaxis] N[np.isinf(N)] = 0 @@ -588,12 +588,12 @@ def normcol(X, weights=None, return_norm=False): else: if weights.shape[0] != X.shape[0]: - raise ValueError('Weight array should have same number of ' + - 'columns as X') + raise ValueError("Weight array should have same number of " + + "columns as X") if weights.shape[1] == 1: weights = np.tile(weights, (1, n_chans)) if weights.shape != X.shape: - raise ValueError('Weight array should have be same shape as X') + raise ValueError("Weight array should have be same shape as X") N = (np.sum(X ** 2 * weights, axis=0) / np.sum(weights, axis=0)) ** -0.5 @@ -623,14 +623,14 @@ def matmul3d(X, mixin): Projection. 
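Examples
--------
(illustrative sketch) Apply a 2D mixing matrix to every trial of a 3D array:
>> X = np.zeros((100, 4, 2))  # (n_samples, n_chans, n_trials)
>> A = np.eye(4)  # identity mixing matrix, so the output equals the input
>> matmul3d(X, A).shape
(100, 4, 2)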
""" - assert mixin.ndim == 2, 'mixing matrix must be 2D' + assert mixin.ndim == 2, "mixing matrix must be 2D" if X.ndim == 2: return X @ mixin elif X.ndim == 3: - return np.einsum('sct,ck->skt', X, mixin) + return np.einsum("sct,ck->skt", X, mixin) else: - raise RuntimeError('X must be (n_samples, n_chans, n_trials)') + raise RuntimeError("X must be (n_samples, n_chans, n_trials)") def _check_shifts(shifts, allow_floats=False): @@ -639,7 +639,7 @@ def _check_shifts(shifts, allow_floats=False): if allow_floats: types += (float, np.float_) if not isinstance(shifts, (np.ndarray, list, type(None)) + types): - raise AttributeError('shifts should be a list, an array or an int') + raise AttributeError("shifts should be a list, an array or an int") if isinstance(shifts, (list, ) + types): shifts = np.array(shifts).flatten() if shifts is None or len(shifts) == 0: @@ -653,12 +653,12 @@ def _check_shifts(shifts, allow_floats=False): def _check_data(X): """Check data is numpy array and has the proper dimensions.""" if not isinstance(X, (np.ndarray, list)): - raise AttributeError('data should be a list or a numpy array') + raise AttributeError("data should be a list or a numpy array") dtype = np.complex128 if np.any(np.iscomplex(X)) else np.float64 X = np.asanyarray(X, dtype=dtype) if X.ndim > 3: - raise ValueError('Data must be 3D at most') + raise ValueError("Data must be 3D at most") return X @@ -667,7 +667,7 @@ def _check_weights(weights, X): """Check weights dimensions against X.""" if not isinstance(weights, (np.ndarray, list)): if weights is not None: - warnings.warn('weights should be a list or a numpy array.') + warnings.warn("weights should be a list or a numpy array.") weights = np.array([]) weights = np.asanyarray(weights) @@ -675,7 +675,7 @@ def _check_weights(weights, X): dtype = np.complex128 if np.any(np.iscomplex(weights)) else np.float64 weights = np.asanyarray(weights, dtype=dtype) if weights.ndim > 3: - raise ValueError('Weights must be 3D at most') + raise ValueError("Weights must be 3D at most") if weights.shape[0] != X.shape[0]: raise ValueError("Weights should be the same n_times as X.") @@ -695,7 +695,7 @@ def _check_weights(weights, X): raise ValueError("Weights array should have a single column.") if np.any(np.abs(weights) > 1.): - warnings.warn('weights should be between 0 and 1.') + warnings.warn("weights should be between 0 and 1.") weights[np.abs(weights) > 1.] = 1. return weights @@ -706,18 +706,18 @@ def _times_to_delays(lags, sfreq): if lags is None: return np.array([0]) if not isinstance(sfreq, (int, float, np.int_)): - raise ValueError('`sfreq` must be an integer or float') + raise ValueError("`sfreq` must be an integer or float") sfreq = float(sfreq) if not all([isinstance(ii, (int, float, np.int_)) for ii in lags]): - raise ValueError('lags must be an integer or float') + raise ValueError("lags must be an integer or float") if len(lags) == 2 and sfreq != 1: tmin = lags[0] tmax = lags[1] if not tmin <= tmax: - raise ValueError('tmin must be <= tmax') + raise ValueError("tmin must be <= tmax") # Convert seconds to samples delays = np.arange(int(np.round(tmin * sfreq)), diff --git a/meegkit/utils/sig.py b/meegkit/utils/sig.py index ada58d5d..03824517 100644 --- a/meegkit/utils/sig.py +++ b/meegkit/utils/sig.py @@ -8,7 +8,7 @@ def modulation_index(phase, amp, n_bins=18): - u"""Compute the Modulation Index (MI) between two signals. + """Compute the Modulation Index (MI) between two signals. MI is a measure of the amount of phase-amplitude coupling. 
Phase angles are expected to be in radians [1]_. MI is derived from the Kullback-Leibler @@ -36,7 +36,7 @@ def modulation_index(phase, amp, n_bins=18): Examples -------- >> phas = np.random.rand(100, 1) * 2 * np.pi - np.pi - >> ampl = np.random.randn(100, 1) * 30 + 100 + >> rng = np.random.default_rng() + >> ampl = rng.standard_normal((100, 1)) * 30 + 100 >> MI, KL = modulation_index(phas, ampl) Notes @@ -78,7 +78,7 @@ def modulation_index(phase, amp, n_bins=18): phase = phase.squeeze() amp = amp.squeeze() if phase.shape != amp.shape or phase.ndim > 1 or amp.ndim > 1: - raise AttributeError('Inputs must be 1D vectors of same length.') + raise AttributeError("Inputs must be 1D vectors of same length.") # Convert phase to degrees phasedeg = np.degrees(phase) @@ -111,7 +111,7 @@ def modulation_index(phase, amp, n_bins=18): return MI, KL -def smooth(x, window_len, window='square', axis=0, align='left'): +def smooth(x, window_len, window="square", axis=0, align="left"): """Smooth a signal using a window with requested size along a given axis. This method is based on the convolution of a scaled window with the signal. @@ -160,35 +160,35 @@ def smooth(x, window_len, window='square', axis=0, align='left'): """ if x.shape[axis] < window_len: - raise ValueError('Input vector needs to be bigger than window size.') - if window not in ['square', 'hanning', 'hamming', 'bartlett', 'blackman']: - raise ValueError('Unknown window type.') + raise ValueError("Input vector needs to be bigger than window size.") + if window not in ["square", "hanning", "hamming", "bartlett", "blackman"]: + raise ValueError("Unknown window type.") if window_len == 0: - raise ValueError('Smoothing kernel must be at least 1 sample wide') + raise ValueError("Smoothing kernel must be at least 1 sample wide") if window_len == 1: return x - def _smooth1d(x, n, align='left'): + def _smooth1d(x, n, align="left"): if x.ndim != 1: - raise ValueError('Smooth only accepts 1D arrays') + raise ValueError("Smooth only accepts 1D arrays") frac, n = np.modf(n) n = int(n) - if window == 'square': # moving average - w = np.ones(n, 'd') + if window == "square": # moving average + w = np.ones(n, "d") w = np.r_[w, frac] else: - w = eval('np.' + window + '(n)') + w = eval("np." + window + "(n)") - if align == 'center': + if align == "center": a = x[n - 1:0:-1] b = x[-2:-n - 1:-1] s = np.r_[a, x, b] - out = np.convolve(w / w.sum(), s, mode='same') + out = np.convolve(w / w.sum(), s, mode="same") return out[len(a):-len(b)] - elif align == 'left': + elif align == "left": out = ss.lfilter(w / w.sum(), 1, x) return out @@ -219,7 +219,7 @@ def lowpass_env_filtering(x, cutoff=150., n=1, sfreq=22050): Low-pass filtered signal. """ - b, a = ss.butter(N=n, Wn=cutoff * 2. / sfreq, btype='lowpass') + b, a = ss.butter(N=n, Wn=cutoff * 2. 
/ sfreq, btype="lowpass") return ss.lfilter(b, a, x) @@ -261,7 +261,7 @@ def spectral_envelope(x, sfreq, lowpass=32): """ x = np.squeeze(x) if x.ndim > 1: - raise AttributeError('x must be 1D') + raise AttributeError("x must be 1D") if lowpass is None: lowpass = sfreq / 2 @@ -272,7 +272,7 @@ def spectral_envelope(x, sfreq, lowpass=32): s = np.r_[a, x, b] # Convolve squared signal with a square window and take cubic root - y = np.convolve(s ** 2, np.ones((win,)) / win, mode='same') ** (1 / 3) + y = np.convolve(s ** 2, np.ones((win,)) / win, mode="same") ** (1 / 3) return y[len(a):-len(b)] @@ -314,16 +314,16 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, """ # input check assert (data.shape[1] <= data.shape[0] - ), 'n_channels must be less than n_samples' - assert ((f - fwhm) >= 0), 'increase frequency or decrease FWHM' - assert (fwhm >= 0), 'FWHM must be greater than 0' + ), "n_channels must be less than n_samples" + assert ((f - fwhm) >= 0), "increase frequency or decrease FWHM" + assert (fwhm >= 0), "FWHM must be greater than 0" # frequencies hz = np.fft.fftfreq(data.shape[0], 1. / srate) empVals = np.zeros((2,)) # compute empirical frequency and standard deviation - idx_p = np.searchsorted(hz[hz >= 0], f, 'left') + idx_p = np.searchsorted(hz[hz >= 0], f, "left") # create Gaussian fx = np.zeros_like(hz) @@ -363,15 +363,15 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False, # inspect the Gaussian (turned off by default) import matplotlib.pyplot as plt plt.figure(1) - plt.plot(hz, fx, 'o-') + plt.plot(hz, fx, "o-") plt.xlim([0, None]) - title = 'Requested: {}, {} Hz\nEmpirical: {}, {} Hz'.format( + title = "Requested: {}, {} Hz\nEmpirical: {}, {} Hz".format( f, fwhm, empVals[0], empVals[1] ) plt.title(title) - plt.xlabel('Frequency (Hz)') - plt.ylabel('Amplitude gain') + plt.xlabel("Frequency (Hz)") + plt.ylabel("Amplitude gain") plt.show() if return_empvals: @@ -524,7 +524,7 @@ def stmcb(x, u_in=None, q=None, p=None, niter=5, a_in=None): a = a_in N = len(x) - for i in range(niter): + for _i in range(niter): u = lfilter([1], a, x) v = lfilter([1], a, u_in) C1 = convmtx(u, (p + 1)).T diff --git a/meegkit/utils/stats.py b/meegkit/utils/stats.py index ba9cdfad..830c6571 100644 --- a/meegkit/utils/stats.py +++ b/meegkit/utils/stats.py @@ -1,5 +1,4 @@ """Statistics utilities.""" -from __future__ import division, print_function import numpy as np @@ -10,6 +9,8 @@ except ImportError: mne = None +rng = np.random.default_rng() + def rms(X, axis=0): """Root-mean-square along given axis.""" @@ -62,11 +63,11 @@ def rolling_corr(X, y, window=None, sfreq=1, step=1, axis=0): if y.ndim == 3: y = np.squeeze(y) if X.ndim > 3: - raise AttributeError('Data must be 2D or 3D.') + raise AttributeError("Data must be 2D or 3D.") if y.shape[0] != X.shape[0]: - raise AttributeError('X and y must share the same time axis.') + raise AttributeError("X and y must share the same time axis.") if y.ndim > 2: - raise AttributeError('y must be at most 2D.') + raise AttributeError("y must be at most 2D.") n_times, n_chans, n_epochs = theshapeof(X) timebins = np.arange(n_times - window, 0, -step)[::-1] @@ -116,15 +117,13 @@ def bootstrap_ci(X, n_bootstrap=2000, ci=(5, 95), axis=-1): Confidence intervals. 
""" - n_samples, n_chans, n_trials = theshapeof(X) idx = np.arange(X.shape[axis], dtype=int) - shape = list(X.shape) shape.pop(axis) - bootstraps = np.nan * np.ones(((n_bootstrap,) + tuple(shape))) + bootstraps = np.nan * np.ones((n_bootstrap,) + tuple(shape)) for i in range(n_bootstrap): - temp_idx = np.random.choice(idx, replace=True, size=len(idx)) + temp_idx = rng.choice(idx, replace=True, size=len(idx)) bootstraps[i] = np.mean(np.take(X, temp_idx, axis=axis), axis=axis) ci_low, ci_up = np.percentile(bootstraps, ci, axis=0) @@ -179,14 +178,14 @@ def bootstrap_snr(epochs, n_bootstrap=2000, baseline=None, window=None): gfp_bs = np.empty((n_bootstrap, n_chans, len(epochs.times))) for i in range(n_bootstrap): - bs_indices = np.random.choice(indices, replace=True, size=len(indices)) + bs_indices = rng.choice(indices, replace=True, size=len(indices)) erp_bs[i] = np.mean(epochs._data[bs_indices, ...], 0) # Baseline correct mean waveform if baseline: erp_bs[i] = mne.baseline.rescale(erp_bs[i], epochs.times, baseline=baseline, - verbose='ERROR') + verbose="ERROR") # Rectify waveform gfp_bs[i] = np.sqrt(erp_bs[i] ** 2) @@ -271,7 +270,7 @@ def cronbach(epochs, K=None, n_bootstrap=2000, tmin=None, tmax=None): alpha = np.empty((n_bootstrap, n_chans)) for b in np.arange(n_bootstrap): # take K trials randomly - idx = np.random.choice(range(n_trials), K) + idx = rng.choice(range(n_trials), K) X = erp[idx, :, tmin:tmax] sigmaY = X.var(axis=2).sum(0) # var over time sigmaX = X.sum(axis=0).var(-1) # var of average @@ -329,8 +328,8 @@ def snr_spectrum(X, freqs, n_avg=1, n_harm=1, skipbins=1): n_freqs = X.shape[0] n_chans = X.shape[1] else: - raise ValueError('Data must have shape (n_freqs, n_chans, [n_trials,])' - f', got {X.shape}') + raise ValueError("Data must have shape (n_freqs, n_chans, [n_trials,])" + f", got {X.shape}") # Number of points to get desired resolution X = np.reshape(X, (n_freqs, n_chans * n_trials)) @@ -385,7 +384,7 @@ def snr_spectrum(X, freqs, n_avg=1, n_harm=1, skipbins=1): B[h] = np.mean(X[bin_noise[h], i_trial].flatten() ** 2) # Ratio - with np.errstate(divide='ignore', invalid='ignore'): + with np.errstate(divide="ignore", invalid="ignore"): SNR[i_bin, i_trial] = np.sqrt(A) / np.sqrt(B) del A diff --git a/meegkit/utils/testing.py b/meegkit/utils/testing.py index 008f616c..e715028e 100644 --- a/meegkit/utils/testing.py +++ b/meegkit/utils/testing.py @@ -1,8 +1,8 @@ """Synthetic test data.""" +import matplotlib.pyplot as plt import numpy as np -from meegkit.utils import fold, rms, unfold -import matplotlib.pyplot as plt +from meegkit.utils import fold, rms, unfold def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20, @@ -63,9 +63,9 @@ def create_line_data(n_samples=100 * 3, n_chans=30, n_trials=100, noise_dim=20, if show: f, ax = plt.subplots(3) - ax[0].plot(source.mean(-1), label='source') - ax[1].plot(noise[:, 1].mean(-1), label='noise (avg over trials)') - ax[2].plot(data[:, 1].mean(-1), label='mixture (avg over trials)') + ax[0].plot(source.mean(-1), label="source") + ax[1].plot(noise[:, 1].mean(-1), label="noise (avg over trials)") + ax[2].plot(data[:, 1].mean(-1), label="mixture (avg over trials)") ax[0].legend() ax[1].legend() ax[2].legend() diff --git a/meegkit/utils/trca.py b/meegkit/utils/trca.py index cb25adde..849a98ab 100644 --- a/meegkit/utils/trca.py +++ b/meegkit/utils/trca.py @@ -1,8 +1,7 @@ """TRCA utils.""" import numpy as np - -from scipy.signal import filtfilt, cheb1ord, cheby1 from scipy import stats +from scipy.signal import 
cheb1ord, cheby1, filtfilt def round_half_up(num, decimals=0): @@ -91,10 +90,10 @@ def itr(n, p, t): itr = 0 if (p < 0 or 1 < p): - raise ValueError('Accuracy need to be between 0 and 1.') + raise ValueError("Accuracy needs to be between 0 and 1.") elif (p < 1 / n): itr = 0 - raise ValueError('ITR might be incorrect because accuracy < chance') + raise ValueError("ITR might be incorrect because accuracy < chance") elif (p == 1): itr = np.log2(n) * 60 / t else: @@ -137,7 +136,7 @@ def bandpass(eeg, sfreq, Wp, Ws): # the arguments 'axis=0, padtype='odd', padlen=3*(max(len(B),len(A))-1)' # correspond to Matlab filtfilt : https://dsp.stackexchange.com/a/47945 - y = filtfilt(B, A, eeg, axis=0, padtype='odd', + y = filtfilt(B, A, eeg, axis=0, padtype="odd", padlen=3 * (max(len(B), len(A)) - 1)) return y diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..0301fb52 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,137 @@ +# As per https://github.com/pypa/setuptools/blob/main/docs/userguide/quickstart.rst +################################## +# Project # +################################## + +[project] +name = "meegkit" +authors = [ + {name = "Nicolas Barascud", email = "nicolas.barascud@gmail.com"}, +] +license = {text = "BSD (3-clause)"} +dynamic = ["version", "dependencies"] +description = "M/EEG denoising in Python" +readme = {file = "README.md", content-type = "text/markdown"} +requires-python = ">=3.8" + +[project.urls] +repository = "https://github.com/nbara/python-meegkit" +documentation = "https://nbara.github.io/python-meegkit/" +tracker = "https://github.com/nbara/python-meegkit/issues/" + +[project.optional-dependencies] +extra = ["pymanopt"] +docs = ["sphinx", "sphinx-gallery", "sphinx-bootstrap_theme", "sphinx-copybutton", + "sphinxemoji", "numpydoc", "pydata-sphinx-theme", "pillow", "jupyter-sphinx", + "ghp-import", "meegkit[extra]"] +tests = ["pytest", "pytest-cov", "codecov", "codespell", "ruff", "meegkit[extra]"] + +################################## +# Package building # +################################## + +[build-system] +requires = ["setuptools>=62.0.0", "wheel", "pybind11~=2.10.3"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +zip-safe = false +include-package-data = true + +[tool.setuptools.dynamic] +version = {attr = "meegkit.__version__"} +dependencies = {file = ["requirements.txt"]} + +[tool.setuptools.packages.find] +exclude = ["tests*", "examples*", "doc*"] + +################################## +# Codespell # +################################## + +[tool.codespell] +skip = """*.html,*.fif,*.eve,*.gz,*.tgz,*.zip,*.mat,*.stc, *.label,*.w,*.bz2,*.annot,*.sulc,*.log,*.local-copy, *.orig_avg,*.inflated_avg,*.gii,*.pyc,*.doctree,*.pickle, *.inv,*.png,*.edf,*.touch,*.thickness,*.nofix,*.volume, *.defect_borders,*.mgh,lh.*,rh.*,COR-*,*.examples,.xdebug_mris_calc ,bad.segments,BadChannels,*.hist,empty_file,*.orig,*.js,*.map,*.pdf, *.ipynb,searchindex.dat,*.c""" +ignore-words-list = "hist,dof,datas,mot" +quiet-level = 3 +interactive = 0 +write-changes = false + +################################## +# Linter configuration # +################################## +[tool.ruff] +select = ["D", "E", "F", "B", "Q", "NPY", "I", "ICN", "UP"] +line-length = 90 +target-version = "py310" +ignore-init-module-imports = true +ignore = ["E731", "B006", "B028", "UP038", "D100", "D105", "D212"] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] +"test_*.py" = ["D101", "D102", "D103", "D"] +"example_*.py" = ["D205", "D400", "D212"] + 
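# (annotation: the rule families selected in [tool.ruff] above are D = pydocstyle,
# E = pycodestyle errors, F = Pyflakes, B = flake8-bugbear, Q = flake8-quotes,
# NPY = NumPy-specific rules, I = isort, ICN = flake8-import-conventions,
# UP = pyupgrade)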
+[tool.ruff.pydocstyle] +convention = "numpy" + +################################## +# Pytest, Coverage # +################################## + +[tool.pytest.ini_options] +timeout = 1200 +testpaths = ["tests/"] +filterwarnings = [ + "ignore:Call to deprecated create function FieldDescriptor", + "ignore:Call to deprecated create function Descriptor", + "ignore:Call to deprecated create function EnumDescriptor", + "ignore:Call to deprecated create function EnumValueDescriptor", + "ignore:Call to deprecated create function FileDescriptor", + "ignore:Call to deprecated create function OneofDescriptor" +] +addopts = """ + --color=yes + --durations 10 + """ + +[tool.coverage.run] +branch = true +data_file = "build/reports/coverage/.meegkit" +omit = [ + "*/python?.?/*", + "*/site-packages/pytest/*", + "*/setup.py", + "*/tests/*", + "*/examples/*", +] + +[tool.coverage.report] +skip_covered = true +skip_empty = true +exclude_lines = [ + "pragma: no cover", + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:" +] + +[tool.coverage.html] +directory = "build/reports/coverage/html" + +[tool.coverage.xml] +output = "build/reports/coverage/meegkit.xml" + +################################## +# Documentation # +################################## +[tool.sphinx.build] +source-dir = "doc/" +build-dir = "doc/_build" +all_files = 1 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 748a351d..00000000 --- a/setup.cfg +++ /dev/null @@ -1,24 +0,0 @@ -[flake8] -exclude = __init__.py,*externals*,constants.py,fixes.py -ignore = E241,E305,W504 - -[pydocstyle] -convention = pep257 -match_dir = ^(?!\.|doc|tutorials|tests|examples).*$ -match = (?!test_|fixes).*\.py -add-ignore = D100,D107,D413 -add-select = D214,D215,D404,D405,D406,D407,D408,D409,D410,D411 -ignore-decorators = ^(copy_.*_doc_to_|on_trait_change|cached_property|deprecated|property|.*setter).* - -[build_sphinx] -source-dir = doc/ -build-dir = doc/_build -all_files = 1 - -[upload_sphinx] -upload-dir = doc/_build/html - -[options.extras_require] -extra = pymanopt -docs = sphinx;sphinx-gallery;sphinx-bootstrap_theme;sphinx-copybutton;sphinxemoji;numpydoc;pydata-sphinx-theme;pillow;jupyter-sphinx;ghp-import;meegkit[extra] -tests = pytest;pytest-cov;codecov;codespell;flake8;pydocstyle;meegkit[extra] diff --git a/setup.py b/setup.py index 66737f48..8782e952 100644 --- a/setup.py +++ b/setup.py @@ -1,28 +1,5 @@ #! 
/usr/bin/env python # -from setuptools import setup, find_packages +from setuptools import setup -with open("README.md", "r", encoding="utf8") as fid: - long_description = fid.read() - -setup( - name='meegkit', - description='M/EEG denoising in Python', - long_description=long_description, - long_description_content_type="text/markdown", - url='http://github.com/nbara/python-meegkit/', - author='Nicolas Barascud', - author_email='nicolas.barascud@gmail.com', - license='BSD (3-clause)', - version='0.1.3', - packages=find_packages(exclude=['doc', 'tests']), - project_urls={ - "Documentation": "https://nbara.github.io/python-meegkit/", - "Source": "https://github.com/nbara/python-meegkit/", - "Tracker": "https://github.com/nbara/python-meegkit/issues/", - }, - platforms="any", - python_requires=">=3.8", - install_requires=["numpy", "scipy", "scikit-learn", "joblib", "pandas", - "matplotlib", "tqdm", "pyriemann", "statsmodels"], - zip_safe=False) +setup() diff --git a/tests/conftest.py b/tests/conftest.py index 0b54abf0..0ab9b351 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ -import pytest - import matplotlib.pyplot as plt +import pytest def pytest_addoption(parser): @@ -22,7 +21,7 @@ def pytest_addoption(parser): """Do not skip slow test if option provided.""" if config.getoption("--noplots"): - plt.switch_backend('agg') + plt.switch_backend("agg") if config.getoption("--runslow"): # --runslow given in cli: do not skip slow tests diff --git a/tests/test_asr.py b/tests/test_asr.py index d698cad6..f213939d 100644 --- a/tests/test_asr.py +++ b/tests/test_asr.py @@ -4,10 +4,11 @@ import matplotlib.pyplot as plt import numpy as np import pytest +from scipy import signal + from meegkit.asr import ASR, asr_calibrate, asr_process, clean_windows from meegkit.utils.asr import yulewalk, yulewalk_filter from meegkit.utils.matrix import sliding_window -from scipy import signal # Data files THIS_FOLDER = os.path.dirname(os.path.abspath(__file__)) @@ -18,9 +19,10 @@ # raw.crop(0, 60) # keep 60s only # raw.pick_types(eeg=True, misc=False) # raw = raw._data +rng = np.random.default_rng(9) -@pytest.mark.parametrize(argnames='sfreq', argvalues=(125, 250, 256, 2048)) +@pytest.mark.parametrize(argnames="sfreq", argvalues=(125, 250, 256, 2048)) def test_yulewalk(sfreq, show=False): """Test that my version of yulewalk works just like MATLAB's.""" # Temp fix, values are computed in matlab using yulewalk.m @@ -53,7 +55,7 @@ def test_yulewalk(sfreq, show=False): -132.920238664871, 158.567177443427, -121.909488069062, 58.9853908881204, -16.4212688404351, 2.01391570212326] else: - raise AttributeError('Currently sfreq must be 250, 256 or 2048...') + raise AttributeError("Currently sfreq must be 125, 250, 256 or 2048...") # Theoretical values w0, h0 = signal.freqz(b, a, sfreq) @@ -68,12 +70,12 @@ def test_yulewalk(sfreq, show=False): if show: fig = plt.figure() ax = fig.add_subplot(111) - ax.plot(w0 / np.pi, np.abs(h0), label='matlab') - ax.plot(w1 / np.pi, np.abs(h1), ':', label='mine') - ax.set_title('Filter frequency response') - ax.set_xlabel('Frequency [radians / second]') - ax.set_ylabel('Amplitude [dB]') - ax.grid(which='both', axis='both') + ax.plot(w0 / np.pi, np.abs(h0), label="matlab") + ax.plot(w1 / np.pi, np.abs(h1), ":", label="mine") + ax.set_title("Filter frequency response") + ax.set_xlabel("Frequency [radians / second]") + ax.set_ylabel("Amplitude [dB]") + ax.grid(which="both", axis="both") ax.legend() # plt.show() @@ -86,48 +88,48 @@ 
def test_yulewalk(sfreq, show=False): if show: plt.figure() - plt.plot(f, m, label='ideal') - plt.plot(w / np.pi, np.abs(h), '--', label='yw designed') + plt.plot(f, m, label="ideal") + plt.plot(w / np.pi, np.abs(h), "--", label="yw designed") plt.legend() - plt.title('Comparison of Frequency Response Magnitudes') + plt.title("Comparison of Frequency Response Magnitudes") plt.legend() plt.show() -@pytest.mark.parametrize(argnames='n_chans', argvalues=(4, 8, 12)) +@pytest.mark.parametrize(argnames="n_chans", argvalues=(4, 8, 12)) def test_yulewalk_filter(n_chans, show=False): """Test yulewalk filter.""" - raw = np.load(os.path.join(THIS_FOLDER, 'data', 'eeg_raw.npy')) + raw = np.load(os.path.join(THIS_FOLDER, "data", "eeg_raw.npy")) sfreq = 250 n_chan_orig = raw.shape[0] - raw = np.random.randn(n_chans, n_chan_orig) @ raw + raw = rng.standard_normal((n_chans, n_chan_orig)) @ raw raw_filt, iirstate = yulewalk_filter(raw, sfreq) if show: f, ax = plt.subplots(n_chans, sharex=True, figsize=(8, 5)) for i in range(n_chans): - ax[i].plot(raw[i], lw=.5, label='before') - ax[i].plot(raw_filt[i], label='after', lw=.5) + ax[i].plot(raw[i], lw=.5, label="before") + ax[i].plot(raw_filt[i], label="after", lw=.5) ax[i].set_ylim([-50, 50]) if i < n_chans - 1: ax[i].set_yticks([]) - ax[i].set_xlabel('Time (s)') - ax[i].set_ylabel(f'ch{i}') - ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), + ax[i].set_xlabel("Time (s)") + ax[i].set_ylabel(f"ch{i}") + ax[0].legend(fontsize="small", bbox_to_anchor=(1.04, 1), borderaxespad=0) plt.subplots_adjust(hspace=0, right=0.75) - plt.suptitle('Before/after filter') + plt.suptitle("Before/after filter") plt.show() -def test_asr_functions(show=False, method='riemann'): +def test_asr_functions(show=False, method="riemann"): """Test ASR functions (offline use). Note: this will not be optimal since the filter parameters will be estimated only once and not updated online as is intended. 
""" - raw = np.load(os.path.join(THIS_FOLDER, 'data', 'eeg_raw.npy')) + raw = np.load(os.path.join(THIS_FOLDER, "data", "eeg_raw.npy")) sfreq = 250 raw_filt = raw.copy() raw_filt, iirstate = yulewalk_filter(raw_filt, sfreq) @@ -147,34 +149,33 @@ def test_asr_functions(show=False, method='riemann'): if show: f, ax = plt.subplots(8, sharex=True, figsize=(8, 5)) for i in range(8): - ax[i].fill_between(train_idx, 0, 1, color='grey', alpha=.3, + ax[i].fill_between(train_idx, 0, 1, color="grey", alpha=.3, transform=ax[i].get_xaxis_transform(), - label='calibration window') + label="calibration window") ax[i].fill_between(train_idx, 0, 1, where=sample_mask.flat, transform=ax[i].get_xaxis_transform(), - facecolor='none', hatch='...', edgecolor='k', - label='selected window') - ax[i].plot(raw[i], lw=.5, label='before ASR') - ax[i].plot(clean[i], label='after ASR', lw=.5) + facecolor="none", hatch="...", edgecolor="k", + label="selected window") + ax[i].plot(raw[i], lw=.5, label="before ASR") + ax[i].plot(clean[i], label="after ASR", lw=.5) # ax[i].set_xlim([10, 50]) ax[i].set_ylim([-50, 50]) # ax[i].set_ylabel(raw.ch_names[i]) if i < 7: ax[i].set_yticks([]) - ax[i].set_xlabel('Time (s)') - ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), + ax[i].set_xlabel("Time (s)") + ax[0].legend(fontsize="small", bbox_to_anchor=(1.04, 1), borderaxespad=0) plt.subplots_adjust(hspace=0, right=0.75) - plt.suptitle('Before/after ASR') + plt.suptitle("Before/after ASR") plt.show() -@pytest.mark.parametrize(argnames='method', argvalues=('riemann', 'euclid')) -@pytest.mark.parametrize(argnames='reref', argvalues=(False, True)) +@pytest.mark.parametrize(argnames="method", argvalues=("riemann", "euclid")) +@pytest.mark.parametrize(argnames="reref", argvalues=(False, True)) def test_asr_class(method, reref, show=False): """Test ASR class (simulate online use).""" - np.random.default_rng(9) - raw = np.load(os.path.join(THIS_FOLDER, 'data', 'eeg_raw.npy')) + raw = np.load(os.path.join(THIS_FOLDER, "data", "eeg_raw.npy")) sfreq = 250 # Train on a clean portion of data train_idx = np.arange(5 * sfreq, 45 * sfreq, dtype=int) @@ -187,15 +188,15 @@ def test_asr_class(method, reref, show=False): raw2 = raw.copy() if reref: - if method == 'riemann': - with pytest.raises(ValueError, match='Add regularization'): - blah = ASR(method=method, estimator='scm') + if method == "riemann": + with pytest.raises(ValueError, match="Add regularization"): + blah = ASR(method=method, estimator="scm") blah.fit(raw2[:, train_idx]) - asr = ASR(method=method, estimator='lwf') + asr = ASR(method=method, estimator="lwf") asr.fit(raw2[:, train_idx]) else: - asr = ASR(method=method, estimator='scm') + asr = ASR(method=method, estimator="scm") asr.fit(raw2[:, train_idx]) # Split into small windows @@ -219,27 +220,27 @@ def test_asr_class(method, reref, show=False): if show: f, ax = plt.subplots(8, sharex=True, figsize=(8, 5)) for i in range(8): - ax[i].plot(times, X[i], lw=.5, label='before ASR') - ax[i].plot(times, Y[i], label='after ASR', lw=.5) + ax[i].plot(times, X[i], lw=.5, label="before ASR") + ax[i].plot(times, Y[i], label="after ASR", lw=.5) ax[i].set_ylim([-50, 50]) - ax[i].set_ylabel(f'ch{i}') + ax[i].set_ylabel(f"ch{i}") if i < 7: ax[i].set_yticks([]) - ax[i].set_xlabel('Time (s)') - ax[0].legend(fontsize='small', bbox_to_anchor=(1.04, 1), + ax[i].set_xlabel("Time (s)") + ax[0].legend(fontsize="small", bbox_to_anchor=(1.04, 1), borderaxespad=0) plt.subplots_adjust(hspace=0, right=0.75) - plt.suptitle('Before/after ASR') + 
plt.suptitle("Before/after ASR") f, ax = plt.subplots(8, sharex=True, figsize=(8, 5)) for i in range(8): - ax[i].plot(times, Y[i], label='incremental', lw=.5) - ax[i].plot(times, Y2[i], label='bulk', lw=.5) - ax[i].plot(times, Y[i] - Y2[i], label='difference', lw=.5) + ax[i].plot(times, Y[i], label="incremental", lw=.5) + ax[i].plot(times, Y2[i], label="bulk", lw=.5) + ax[i].plot(times, Y[i] - Y2[i], label="difference", lw=.5) if i < 7: ax[i].set_yticks([]) - ax[i].set_xlabel('Time (s)') - plt.suptitle('incremental vs. bulk difference ') + ax[i].set_xlabel("Time (s)") + plt.suptitle("incremental vs. bulk difference ") plt.show() # TODO: the transform() process is stochastic, so Y and Y2 are not going to diff --git a/tests/test_cca.py b/tests/test_cca.py index 0c469f40..39c684c0 100644 --- a/tests/test_cca.py +++ b/tests/test_cca.py @@ -3,7 +3,7 @@ from scipy.io import loadmat from sklearn.cross_decomposition import CCA -from meegkit.cca import cca_crossvalidate, nt_cca, mcca +from meegkit.cca import cca_crossvalidate, mcca, nt_cca from meegkit.utils import multishift, tscov @@ -14,11 +14,11 @@ def test_cca(): # y = rng.randn(1000, 9) # x = demean(x).squeeze() # y = demean(y).squeeze() - mat = loadmat('./tests/data/ccadata.mat') - x = mat['x'] - y = mat['y'] - A2 = mat['A2'] - B2 = mat['B2'] + mat = loadmat("./tests/data/ccadata.mat") + x = mat["x"] + y = mat["y"] + A2 = mat["A2"] + B2 = mat["B2"] A1, B1, R = nt_cca(x, y) # if mean(A1(:).*A2(:))<0; A2=-A2; end X1 = np.dot(x, A1) @@ -80,9 +80,9 @@ def test_cca2(): def test_cca_scaling(): """Test CCA with MEG data.""" - data = np.load('./tests/data/ccadata_meg_2trials.npz') - raw = data['arr_0'] - env = data['arr_1'] + data = np.load("./tests/data/ccadata_meg_2trials.npz") + raw = data["arr_0"] + env = data["arr_1"] # Test with scaling (unit: fT) A0, B0, R0 = nt_cca(raw * 1e15, env) @@ -138,9 +138,9 @@ def test_correlated(): def test_cca_lags(): """Test multiple lags.""" - mat = loadmat('./tests/data/ccadata.mat') - x = mat['x'] - y = mat['y'] + mat = loadmat("./tests/data/ccadata.mat") + x = mat["x"] + y = mat["y"] y[:, :3] = x[:, :3] lags = np.arange(-10, 11, 1) A1, B1, R1 = nt_cca(x, y, lags) @@ -165,10 +165,10 @@ def test_cca_crossvalidate(): # xx = [x, x, x] # yy = [x[:, :9], y, y] - mat = loadmat('./tests/data/ccadata2.mat') - xx = mat['x'] - yy = mat['y'] - R1 = mat['R'] # no shifts + mat = loadmat("./tests/data/ccadata2.mat") + xx = mat["x"] + yy = mat["y"] + R1 = mat["R"] # no shifts # Test with no shifts A, B, R = cca_crossvalidate(xx, yy) @@ -201,8 +201,8 @@ def test_cca_crossvalidate_shifts(): # uncorrelated y[:, 6:8, :] = rng.randn(n_times, 2, n_trials) - xx = multishift(x, -np.arange(1, 4), reshape=True, solution='valid') - yy = multishift(y, -np.arange(1, 4), reshape=True, solution='valid') + xx = multishift(x, -np.arange(1, 4), reshape=True, solution="valid") + yy = multishift(y, -np.arange(1, 4), reshape=True, solution="valid") # Test with shifts A, B, R = cca_crossvalidate(xx, yy, shifts=[-3, -2, -1, 0, 1, 2, 3]) @@ -217,10 +217,10 @@ def test_cca_crossvalidate_shifts(): def test_cca_crossvalidate_shifts2(): """Test CCA crossvalidation with shifts.""" - mat = loadmat('./tests/data/ccacrossdata.mat') - xx = mat['xx2'] - yy = mat['yy2'] - R2 = mat['R'][:, ::-1, :] # shifts go in reverse direction in noisetools + mat = loadmat("./tests/data/ccacrossdata.mat") + xx = mat["xx2"] + yy = mat["yy2"] + R2 = mat["R"][:, ::-1, :] # shifts go in reverse direction in noisetools # Test with shifts A, B, R = cca_crossvalidate(xx, yy, 
shifts=[-3, -2, -1, 0, 1, 2, 3]) @@ -261,18 +261,18 @@ def test_mcca(show=False): if show: import matplotlib.pyplot as plt f, axes = plt.subplots(2, 3, figsize=(10, 6)) - axes[0, 0].imshow(A, aspect='auto') - axes[0, 0].set_title('mCCA transform matrix') - axes[0, 1].imshow(A.T @ C @ A, aspect='auto') - axes[0, 1].set_title('Covariance of\ntransformed data') - axes[0, 2].imshow(x.T @ x @ A, aspect='auto') - axes[0, 2].set_title('Cross-correlation between\nraw & transformed data') - axes[0, 2].set_xlabel('transformed') - axes[0, 2].set_ylabel('raw') + axes[0, 0].imshow(A, aspect="auto") + axes[0, 0].set_title("mCCA transform matrix") + axes[0, 1].imshow(A.T @ C @ A, aspect="auto") + axes[0, 1].set_title("Covariance of\ntransformed data") + axes[0, 2].imshow(x.T @ x @ A, aspect="auto") + axes[0, 2].set_title("Cross-correlation between\nraw & transformed data") + axes[0, 2].set_xlabel("transformed") + axes[0, 2].set_ylabel("raw") ax = plt.subplot2grid((2, 3), (1, 0), colspan=3) - ax.plot(np.mean(z ** 2, axis=0), ':o') - ax.set_ylabel('Power') - ax.set_xlabel('CC') + ax.plot(np.mean(z ** 2, axis=0), ":o") + ax.set_ylabel("Power") + ax.set_xlabel("CC") plt.tight_layout() plt.show() @@ -296,18 +296,18 @@ def test_mcca(show=False): if show: f, axes = plt.subplots(2, 3, figsize=(10, 6)) - axes[0, 0].imshow(A, aspect='auto') - axes[0, 0].set_title('mCCA transform matrix') - axes[0, 1].imshow(A.T.dot(C.dot(A)), aspect='auto') - axes[0, 1].set_title('Covariance of\ntransformed data') - axes[0, 2].imshow(x.T.dot((x.dot(A))), aspect='auto') - axes[0, 2].set_title('Cross-correlation between\nraw & transformed data') - axes[0, 2].set_xlabel('transformed') - axes[0, 2].set_ylabel('raw') + axes[0, 0].imshow(A, aspect="auto") + axes[0, 0].set_title("mCCA transform matrix") + axes[0, 1].imshow(A.T.dot(C.dot(A)), aspect="auto") + axes[0, 1].set_title("Covariance of\ntransformed data") + axes[0, 2].imshow(x.T.dot(x.dot(A)), aspect="auto") + axes[0, 2].set_title("Cross-correlation between\nraw & transformed data") + axes[0, 2].set_xlabel("transformed") + axes[0, 2].set_ylabel("raw") ax = plt.subplot2grid((2, 3), (1, 0), colspan=3) - ax.plot(np.mean(z ** 2, axis=0), ':o') - ax.set_ylabel('Power') - ax.set_xlabel('CC') + ax.plot(np.mean(z ** 2, axis=0), ":o") + ax.set_ylabel("Power") + ax.set_xlabel("CC") plt.tight_layout() plt.show() @@ -330,20 +330,20 @@ def test_mcca(show=False): # Plot results if show: f, axes = plt.subplots(2, 3, figsize=(10, 6)) - axes[0, 0].imshow(A, aspect='auto') - axes[0, 0].set_title('mCCA transform matrix') + axes[0, 0].imshow(A, aspect="auto") + axes[0, 0].set_title("mCCA transform matrix") - axes[0, 1].imshow(A.T @ C @ A, aspect='auto') - axes[0, 1].set_title('Covariance of\ntransformed data') + axes[0, 1].imshow(A.T @ C @ A, aspect="auto") + axes[0, 1].set_title("Covariance of\ntransformed data") - axes[0, 2].imshow(x.T @ x @ A, aspect='auto') - axes[0, 2].set_title('Cross-correlation between\nraw & transformed data') - axes[0, 2].set_xlabel('transformed') - axes[0, 2].set_ylabel('raw') + axes[0, 2].imshow(x.T @ x @ A, aspect="auto") + axes[0, 2].set_title("Cross-correlation between\nraw & transformed data") + axes[0, 2].set_xlabel("transformed") + axes[0, 2].set_ylabel("raw") ax = plt.subplot2grid((2, 3), (1, 0), colspan=3) - ax.plot(np.mean(z ** 2, axis=0), ':o') - ax.set_ylabel('Power') - ax.set_xlabel('CC') + ax.plot(np.mean(z ** 2, axis=0), ":o") + ax.set_ylabel("Power") + ax.set_xlabel("CC") plt.tight_layout() plt.show() @@ -352,7 +352,7 @@ def test_mcca(show=False): assert 
np.all(diagonal[:10] > 1), diagonal[:10] assert np.all(diagonal[10:] < .01) -if __name__ == '__main__': +if __name__ == "__main__": import pytest pytest.main([__file__]) # test_mcca(False) diff --git a/tests/test_cov.py b/tests/test_cov.py index f45c8c7e..ce69b470 100644 --- a/tests/test_cov.py +++ b/tests/test_cov.py @@ -1,12 +1,13 @@ import numpy as np from numpy.testing import assert_almost_equal -from meegkit.utils import tscov, tsxcov, convmtx +from meegkit.utils import convmtx, tscov, tsxcov +rng = np.random.default_rng(10) def test_tscov(): """Test time-shift covariance.""" - x = 2 * np.eye(3) + 0.1 * np.random.rand(3) + x = 2 * np.eye(3) + 0.1 * rng.random(3) x = x - np.mean(x, 0) # Compare 0-lag case with numpy.cov() @@ -87,7 +88,7 @@ def test_convmtx(): ]) ) -if __name__ == '__main__': +if __name__ == "__main__": # import pytest # pytest.main([__file__]) test_convmtx() diff --git a/tests/test_detrend.py b/tests/test_detrend.py index 5d115044..47eeed13 100644 --- a/tests/test_detrend.py +++ b/tests/test_detrend.py @@ -1,42 +1,42 @@ """Test robust detrending.""" import numpy as np - -from meegkit.detrend import regress, detrend, reduce_ringing, create_masked_weight - from scipy.signal import butter, lfilter +from meegkit.detrend import create_masked_weight, detrend, reduce_ringing, regress + +rng = np.random.default_rng(9) def test_regress(): """Test regression.""" # Simple regression example, no weights # fit random walk - y = np.cumsum(np.random.randn(1000, 1), axis=0) + y = np.cumsum(rng.standard_normal((1000, 1)), axis=0) x = np.arange(1000.)[:, None] x = np.hstack([x, x ** 2, x ** 3]) [b, z] = regress(y, x) # Simple regression example, with weights - y = np.cumsum(np.random.randn(1000, 1), axis=0) - w = np.random.rand(*y.shape) + y = np.cumsum(rng.standard_normal((1000, 1)), axis=0) + w = rng.random(y.shape) [b, z] = regress(y, x, w) # Downweight 1st half of the data - y = np.cumsum(np.random.randn(1000, 1), axis=0) + 1000 + y = np.cumsum(rng.standard_normal((1000, 1)), axis=0) + 1000 w = np.ones(y.shape[0]) w[:500] = 0 [b, z] = regress(y, x, w) # # Multichannel regression - y = np.cumsum(np.random.randn(1000, 2), axis=0) + y = np.cumsum(rng.standard_normal((1000, 2)), axis=0) w = np.ones(y.shape[0]) [b, z] = regress(y, x, w) assert z.shape == (1000, 2) assert b.shape == (2, 1) # Multichannel regression - y = np.cumsum(np.random.randn(1000, 2), axis=0) + y = np.cumsum(rng.standard_normal((1000, 2)), axis=0) w = np.ones(y.shape) - w[:, 1] == .8 + w[:, 1] = .8 [b, z] = regress(y, x, w) assert z.shape == (1000, 2) assert b.shape == (2, 3) @@ -46,21 +46,21 @@ def test_detrend(show=False): """Test detrending.""" # basic x = np.arange(100)[:, None] # trend - source = np.random.randn(*x.shape) + source = rng.standard_normal(x.shape) x = x + source y, _, _ = detrend(x, 1) assert y.shape == x.shape # detrend biased random walk - x = np.cumsum(np.random.randn(1000, 1) + 0.1) + x = np.cumsum(rng.standard_normal((1000, 1)) + 0.1) y, _, _ = detrend(x, 3) assert y.shape == x.shape # test weights trend = np.linspace(0, 100, 1000)[:, None] - data = 3 * np.random.randn(*trend.shape) + data = 3 * rng.standard_normal(trend.shape) data[:100, :] = 100 x = trend + data w = np.ones(x.shape) @@ -77,27 +77,27 @@ def test_detrend(show=False): assert np.all(np.abs(yy[100:] - data[100:]) < 1.) 
# detrend higher-dimensional data - x = np.cumsum(np.random.randn(1000, 16) + 0.1, axis=0) + x = np.cumsum(rng.standard_normal((1000, 16)) + 0.1, axis=0) y, _, _ = detrend(x, 1, show=False) # detrend higher-dimensional data with order 3 polynomial - x = np.cumsum(np.random.randn(1000, 16) + 0.1, axis=0) - y, _, _ = detrend(x, 3, basis='polynomials', show=True) + x = np.cumsum(rng.standard_normal((1000, 16)) + 0.1, axis=0) + y, _, _ = detrend(x, 3, basis="polynomials", show=True) # detrend with sinusoids - x = np.random.randn(1000, 2) + x = rng.standard_normal((1000, 2)) x += 2 * np.sin(2 * np.pi * np.arange(1000) / 200)[:, None] - y, _, _ = detrend(x, 5, basis='sinusoids', show=True) + y, _, _ = detrend(x, 5, basis="sinusoids", show=True) # trial-masked detrending trend = np.linspace(0, 100, 1000)[:, None] - data = 3 * np.random.randn(*trend.shape) + data = 3 * rng.standard_normal(trend.shape) data[:100, :] = 100 x = trend + data events = np.arange(30, 970, 40) tmin, tmax, sfreq = -0.2, 0.3, 20 w = create_masked_weight(x, events, tmin, tmax, sfreq) - y, _, _ = detrend(x, 1, w, basis='polynomials', show=show) + y, _, _ = detrend(x, 1, w, basis="polynomials", show=show) def test_ringing(): @@ -106,13 +106,13 @@ def test_ringing(): [b, a] = butter(6, 0.2) # Butterworth filter design x = lfilter(b, a, x) * 50 # Filter data using above filter x = np.roll(x, 500) - signal = np.random.randn(1000, 2) + signal = rng.standard_normal((1000, 2)) x = x[:, None] + signal - y = reduce_ringing(x, samples=np.array([500])) + reduce_ringing(x, samples=np.array([500])) # np.testing.assert_array_almost_equal(y, signal, 2) -if __name__ == '__main__': +if __name__ == "__main__": import pytest pytest.main([__file__]) # test_detrend(False) diff --git a/tests/test_dss.py b/tests/test_dss.py index be390355..5b354a66 100644 --- a/tests/test_dss.py +++ b/tests/test_dss.py @@ -5,13 +5,16 @@ import matplotlib.pyplot as plt import numpy as np import pytest -from meegkit import dss -from meegkit.utils import create_line_data, fold, tscov, unfold from numpy.testing import assert_allclose from scipy import signal +from meegkit import dss +from meegkit.utils import create_line_data, fold, tscov, unfold + +rng = np.random.default_rng(10) + -@pytest.mark.parametrize('n_bad_chans', [0, -1]) +@pytest.mark.parametrize("n_bad_chans", [0, -1]) def test_dss0(n_bad_chans): """Test dss0. 
@@ -64,9 +67,9 @@ def test_dss1(show=True): if show: f, (ax1, ax2, ax3) = plt.subplots(3, 1) - ax1.plot(source, label='source') - ax2.plot(np.mean(data, 2), label='data') - ax3.plot(best_comp, label='recovered') + ax1.plot(source, label="source") + ax2.plot(np.mean(data, 2), label="data") + ax3.plot(best_comp, label="recovered") plt.legend() plt.show() @@ -74,7 +77,7 @@ def test_dss1(show=True): atol=1e-6) # use abs as DSS component might be flipped -@pytest.mark.parametrize('nkeep', [None, 2]) +@pytest.mark.parametrize("nkeep", [None, 2]) def test_dss_line(nkeep): """Test line noise removal.""" sr = 200 @@ -92,11 +95,11 @@ def _plot(x): f, Pxx = signal.welch(s, sr, nperseg=1024, axis=0, return_onesided=True) ax[0].semilogy(f, Pxx) - ax[0].set_xlabel('frequency [Hz]') - ax[1].set_xlabel('frequency [Hz]') - ax[0].set_ylabel('PSD [V**2/Hz]') - ax[0].set_title('before') - ax[1].set_title('after') + ax[0].set_xlabel("frequency [Hz]") + ax[1].set_xlabel("frequency [Hz]") + ax[0].set_ylabel("PSD [V**2/Hz]") + ax[0].set_title("before") + ax[1].set_title("after") plt.show() # 2D case, n_outputs == 1 @@ -112,7 +115,7 @@ def _plot(x): # _plot(out) # Test n_trials > 1 - x = np.random.randn(nsamples, nchans, 4) + x = rng.standard_normal((nsamples, nchans, 4)) artifact = np.sin( np.arange(nsamples) / sr * 2 * np.pi * fline)[:, None, None] artifact[artifact < 0] = 0 @@ -144,7 +147,7 @@ def test_dss_line_iter(): with TemporaryDirectory() as tmpdir: out, _ = dss.dss_line_iter(x, fline + .5, sr, - prefix=os.path.join(tmpdir, 'dss_iter_'), + prefix=os.path.join(tmpdir, "dss_iter_"), show=True) def _plot(before, after): @@ -155,11 +158,11 @@ def _plot(before, after): f, Pxx = signal.welch(after[:, -1], sr, nperseg=1024, axis=0, return_onesided=True) ax[1].semilogy(f, Pxx) - ax[0].set_xlabel('frequency [Hz]') - ax[1].set_xlabel('frequency [Hz]') - ax[0].set_ylabel('PSD [V**2/Hz]') - ax[0].set_title('before') - ax[1].set_title('after') + ax[0].set_xlabel("frequency [Hz]") + ax[1].set_xlabel("frequency [Hz]") + ax[0].set_ylabel("PSD [V**2/Hz]") + ax[0].set_title("before") + ax[1].set_title("after") plt.show() _plot(x, out) @@ -193,7 +196,7 @@ def profile_dss_line(nkeep): ps.print_stats() print(s.getvalue()) -if __name__ == '__main__': +if __name__ == "__main__": pytest.main([__file__]) # create_data(SNR=5, show=True) # test_dss1(True) diff --git a/tests/test_lof.py b/tests/test_lof.py index b37d2382..7635067a 100644 --- a/tests/test_lof.py +++ b/tests/test_lof.py @@ -7,26 +7,26 @@ from meegkit.lof import LOF -np.random.seed(9) +rng = np.random.default_rng(10) # Data files THIS_FOLDER = os.path.dirname(os.path.abspath(__file__)) # data folder of MEEGKIT -@pytest.mark.parametrize(argnames='n_neighbors', argvalues=(8, 20, 40, 2048)) +@pytest.mark.parametrize(argnames="n_neighbors", argvalues=(8, 20, 40, 2048)) def test_lof(n_neighbors, show=False): - mat = sio.loadmat(os.path.join(THIS_FOLDER, 'data', 'lofdata.mat')) - X = mat['X'] + mat = sio.loadmat(os.path.join(THIS_FOLDER, "data", "lofdata.mat")) + X = mat["X"] lof = LOF(n_neighbors) bad_channel_indices = lof.predict(X) print(bad_channel_indices) -@pytest.mark.parametrize(argnames='metric', - argvalues=('euclidean', 'nan_euclidean', - 'cosine', 'cityblock', 'manhattan')) +@pytest.mark.parametrize(argnames="metric", + argvalues=("euclidean", "nan_euclidean", + "cosine", "cityblock", "manhattan")) def test_lof2(metric, show=False): - mat = sio.loadmat(os.path.join(THIS_FOLDER, 'data', 'lofdata.mat')) - X = mat['X'] + mat = sio.loadmat(os.path.join(THIS_FOLDER, 
"data", "lofdata.mat")) + X = mat["X"] lof = LOF(20, metric) bad_channel_indices = lof.predict(X) print(bad_channel_indices) diff --git a/tests/test_ress.py b/tests/test_ress.py index d5c39d41..8e99cec5 100644 --- a/tests/test_ress.py +++ b/tests/test_ress.py @@ -4,9 +4,12 @@ import pytest import scipy.signal as ss from scipy.linalg import pinv + from meegkit import ress from meegkit.utils import fold, matmul3d, rms, snr_spectrum, unfold +rng = np.random.default_rng(9) + def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250, noise_dim=8, SNR=.8, t0=100, show=False): @@ -20,7 +23,7 @@ def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250, """ # source source = np.sin(2 * np.pi * freq * np.arange(n_times - t0) / sfreq)[None].T - s = source * np.random.randn(1, n_chans) + s = source * rng.standard_normal((1, n_chans)) s = s[:, :, np.newaxis] s = np.tile(s, (1, 1, n_trials)) signal = np.zeros((n_times, n_chans, n_trials)) @@ -28,8 +31,8 @@ def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250, # noise noise = np.dot( - unfold(np.random.randn(n_times, noise_dim, n_trials)), - np.random.randn(noise_dim, n_chans)) + unfold(rng.standard_normal((n_times, noise_dim, n_trials))), + rng.standard_normal((noise_dim, n_chans))) noise = fold(noise, n_times) # mix signal and noise @@ -39,9 +42,9 @@ def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250, if show: f, ax = plt.subplots(3) - ax[0].plot(signal[:, 0, 0], label='source') - ax[1].plot(noise[:, 1, 0], label='noise') - ax[2].plot(noisy_data[:, 1, 0], label='mixture') + ax[0].plot(signal[:, 0, 0], label="source") + ax[1].plot(noise[:, 1, 0], label="noise") + ax[2].plot(noisy_data[:, 1, 0], label="mixture") ax[0].legend() ax[1].legend() ax[2].legend() @@ -50,11 +53,11 @@ def create_data(n_times, n_chans=10, n_trials=20, freq=12, sfreq=250, return noisy_data, signal -@pytest.mark.parametrize('target', [12, 15, 20]) -@pytest.mark.parametrize('n_trials', [16]) -@pytest.mark.parametrize('peak_width', [.5, 1]) -@pytest.mark.parametrize('neig_width', [1]) -@pytest.mark.parametrize('neig_freq', [1]) +@pytest.mark.parametrize("target", [12, 15, 20]) +@pytest.mark.parametrize("n_trials", [16]) +@pytest.mark.parametrize("peak_width", [.5, 1]) +@pytest.mark.parametrize("neig_width", [1]) +@pytest.mark.parametrize("neig_freq", [1]) def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): """Test RESS.""" sfreq = 250 @@ -73,7 +76,7 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): nfft = 500 bins, psd = ss.welch(out.squeeze(1), sfreq, window="boxcar", nperseg=nfft / (peak_width * 2), - noverlap=0, axis=0, average='mean') + noverlap=0, axis=0, average="mean") # psd = np.abs(np.fft.fft(out, nfft, axis=0)) # psd = psd[0:psd.shape[0] // 2 + 1] # bins = np.linspace(0, sfreq // 2, psd.shape[0]) @@ -85,17 +88,17 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): # snr = snr.mean(1) if show: f, ax = plt.subplots(2) - ax[0].plot(bins, snr, ':o') - ax[0].axhline(1, ls=':', c='grey', zorder=0) - ax[0].axvline(target, ls=':', c='grey', zorder=0) - ax[0].set_ylabel('SNR (a.u.)') - ax[0].set_xlabel('Frequency (Hz)') + ax[0].plot(bins, snr, ":o") + ax[0].axhline(1, ls=":", c="grey", zorder=0) + ax[0].axvline(target, ls=":", c="grey", zorder=0) + ax[0].set_ylabel("SNR (a.u.)") + ax[0].set_xlabel("Frequency (Hz)") ax[0].set_xlim([0, 40]) ax[0].set_ylim([0, 10]) ax[1].plot(bins, psd) - ax[1].axvline(target, ls=':', c='grey', zorder=0) - 
ax[1].set_ylabel('PSD') - ax[1].set_xlabel('Frequency (Hz)') + ax[1].axvline(target, ls=":", c="grey", zorder=0) + ax[1].set_ylabel("PSD") + ax[1].set_xlabel("Frequency (Hz)") ax[1].set_xlim([0, 40]) # plt.show() @@ -113,16 +116,16 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): assert proj.shape == (n_times, n_chans, n_trials) if show: - f, ax = plt.subplots(data.shape[1], 2, sharey='col') + f, ax = plt.subplots(data.shape[1], 2, sharey="col") for c in range(data.shape[1]): - ax[c, 0].plot(data[:, c].mean(-1), lw=.5, label='data') - ax[c, 1].plot(proj[:, c].mean(-1), lw=.5, label='projection') + ax[c, 0].plot(data[:, c].mean(-1), lw=.5, label="data") + ax[c, 1].plot(proj[:, c].mean(-1), lw=.5, label="projection") if c < data.shape[1]: ax[c, 0].set_xticks([]) ax[c, 1].set_xticks([]) - ax[0, 0].set_title('Before') - ax[0, 1].set_title('After') + ax[0, 0].set_title("Before") + ax[0, 1].set_title("After") plt.legend() # 2 comps @@ -141,19 +144,19 @@ def test_ress(target, n_trials, peak_width, neig_width, neig_freq, show=False): _max = np.amax(combined_data) f, ax = plt.subplots(3) - ax[0].imshow(toress, label='toRESS') - ax[0].set_title('toRESS') - ax[1].imshow(fromress, label='fromRESS', vmin=-_max, vmax=_max) - ax[1].set_title('fromRESS') + ax[0].imshow(toress, label="toRESS") + ax[0].set_title("toRESS") + ax[1].imshow(fromress, label="fromRESS", vmin=-_max, vmax=_max) + ax[1].set_title("fromRESS") ax[2].imshow(pinv(toress), vmin=-_max, vmax=_max) - ax[2].set_title('toRESS$^{-1}$') + ax[2].set_title("toRESS$^{-1}$") plt.tight_layout() plt.show() print(np.sum(np.abs(pinv(toress) - fromress) >= .1)) -if __name__ == '__main__': +if __name__ == "__main__": import pytest pytest.main([__file__]) # test_ress(20, 16, 1, 1, 1, show=False) diff --git a/tests/test_signal.py b/tests/test_signal.py index 78eb7361..9e9688ec 100644 --- a/tests/test_signal.py +++ b/tests/test_signal.py @@ -1,15 +1,17 @@ """Test signal utils.""" import numpy as np -from scipy.signal import lfilter, butter, freqz -from meegkit.utils.sig import teager_kaiser, stmcb +from scipy.signal import butter, freqz, lfilter +from meegkit.utils.sig import stmcb, teager_kaiser + +rng = np.random.default_rng(9) def test_teager_kaiser(show=False): """Test Teager-Kaiser Energy.""" - x = 2 * np.random.rand(1000, 2) - 1 + x = 2 * rng.random((1000, 2)) - 1 x[100, 0] = 5 x[200, 0] = 5 - x += np.cumsum(np.random.randn(1000, 1) + 0.1, axis=0) / 1000 + x += np.cumsum(rng.standard_normal((1000, 1)) + 0.1, axis=0) / 1000 for i in range(1, 5): print(i) y = teager_kaiser(x, M=i, m=1) @@ -20,8 +22,8 @@ def test_teager_kaiser(show=False): plt.figure() # plt.plot((x[1:, 0] - x[1:, 0].mean()) / np.nanstd(x[1:, 0])) # plt.plot((y[..., 0] - y[..., 0].mean()) / np.nanstd(y[..., 0])) - plt.plot(x[1:, 0], label='X') - plt.plot(y[..., 0], label='Y') + plt.plot(x[1:, 0], label="X") + plt.plot(y[..., 0], label="Y") plt.legend() plt.show() @@ -39,10 +41,10 @@ def test_stcmb(show=True): if show: import matplotlib.pyplot as plt f, ax = plt.subplots(2, 1) - ax[0].plot(x, label='step') - ax[0].plot(y, label='filt') - ax[1].plot(w, np.abs(h), label='real') - ax[1].plot(ww, np.abs(hh), label='stcmb') + ax[0].plot(x, label="step") + ax[0].plot(y, label="filt") + ax[1].plot(w, np.abs(h), label="real") + ax[1].plot(ww, np.abs(hh), label="stcmb") ax[0].legend() ax[1].legend() plt.show() @@ -50,4 +52,5 @@ def test_stcmb(show=True): np.testing.assert_allclose(h, hh, rtol=2) # equal to 2% if __name__ == "__main__": + test_teager_kaiser() 
test_stcmb() diff --git a/tests/test_sns.py b/tests/test_sns.py index 2d175058..1ffd6fda 100644 --- a/tests/test_sns.py +++ b/tests/test_sns.py @@ -7,11 +7,11 @@ def test_sns(): """Test against NoiseTools.""" - mat = loadmat('./tests/data/snsdata.mat') - x = mat['x'] - y_sns = mat['y_sns'] - r_sns0 = mat['y_sns0'] - cx = mat['cx'] + mat = loadmat("./tests/data/snsdata.mat") + x = mat["x"] + y_sns = mat["y_sns"] + r_sns0 = mat["y_sns0"] + cx = mat["cx"] r = sns.sns0(cx, n_neighbors=4) assert_allclose(r, r_sns0) # assert our results match Matlab's @@ -20,6 +20,6 @@ def test_sns(): assert_allclose(y, y_sns) # assert our results match Matlab's -if __name__ == '__main__': +if __name__ == "__main__": import pytest pytest.main([__file__]) diff --git a/tests/test_star.py b/tests/test_star.py index cad4584c..5f85eaa5 100644 --- a/tests/test_star.py +++ b/tests/test_star.py @@ -5,6 +5,7 @@ from meegkit.star import star from meegkit.utils import demean, normcol +rng = np.random.default_rng(9) def test_star1(): """Test STAR 1.""" @@ -17,11 +18,10 @@ def test_star1(): def sim_data(n_samples, n_chans, f, SNR): target = np.sin(np.arange(n_samples) / n_samples * 2 * np.pi * f) target = target[:, np.newaxis] - noise = np.random.randn(n_samples, n_chans - 3) + noise = rng.standard_normal((n_samples, n_chans - 3)) - x0 = (normcol(np.dot( - noise, np.random.randn(noise.shape[1], n_chans))) + - SNR * target * np.random.randn(1, n_chans)) + x0 = normcol(np.dot(noise, rng.standard_normal((noise.shape[1], n_chans)))) \ + + SNR * target * rng.standard_normal((1, n_chans)) x0 = demean(x0) artifact = np.zeros(x0.shape) for k in np.arange(n_chans): @@ -31,7 +31,7 @@ def sim_data(n_samples, n_chans, f, SNR): # Test SNR=1 x, x0 = sim_data(n_samples, n_chans, f, SNR=np.sqrt(1)) - y, w, _ = star(x, 2, verbose='debug') + y, w, _ = star(x, 2, verbose="debug") assert_allclose(demean(y), x0) # check that denoised signal ~ x0 # Test more unfavourable SNR @@ -45,7 +45,7 @@ def sim_data(n_samples, n_chans, f, SNR): assert_allclose(demean(y)[:, 0], x[:, 0]) -if __name__ == '__main__': +if __name__ == "__main__": import pytest pytest.main([__file__]) # test_star1() diff --git a/tests/test_trca.py b/tests/test_trca.py index c45bcd33..de17d61a 100644 --- a/tests/test_trca.py +++ b/tests/test_trca.py @@ -4,13 +4,14 @@ import numpy as np import pytest import scipy.io + from meegkit.trca import TRCA from meegkit.utils.trca import itr, normfit, round_half_up ########################################################################## # Load data # ----------------------------------------------------------------------------- -path = os.path.join('.', 'tests', 'data', 'trcadata.mat') +path = os.path.join(".", "tests", "data", "trcadata.mat") mat = scipy.io.loadmat(path) eeg = mat["eeg"] @@ -41,12 +42,12 @@ [[38, 90], [32, 100]]] -@pytest.mark.parametrize('ensemble', [True, False]) -@pytest.mark.parametrize('method', ['original', 'riemann']) -@pytest.mark.parametrize('regularization', ['schaefer', 'scm']) +@pytest.mark.parametrize("ensemble", [True, False]) +@pytest.mark.parametrize("method", ["original", "riemann"]) +@pytest.mark.parametrize("regularization", ["schaefer", "scm"]) def test_trca(ensemble, method, regularization): """Test TRCA.""" - if method == 'original' and regularization == 'schaefer': + if method == "original" and regularization == "schaefer": pytest.skip("regularization only used for riemann version") len_gaze_s = 0.5 # data length for target identification [s] @@ -101,15 +102,15 @@ def test_trca(ensemble, method, 
regularization): # Mean accuracy and ITR computation mu, _, muci, _ = normfit(accs, alpha_ci) print(f"Mean accuracy = {mu:.1f}%\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") # noqa - if method != 'riemann' or (regularization == 'scm' and ensemble): + if method != "riemann" or (regularization == "scm" and ensemble): assert mu > 95 mu, _, muci, _ = normfit(itrs, alpha_ci) print(f"Mean ITR = {mu:.1f}\t({ci:.0f}% CI: {muci[0]:.1f}-{muci[1]:.1f}%)") - if method != 'riemann' or (regularization == 'scm' and ensemble): + if method != "riemann" or (regularization == "scm" and ensemble): assert mu > 300 -if __name__ == '__main__': +if __name__ == "__main__": import pytest pytest.main([__file__]) # test_trcacode() diff --git a/tests/test_tspca.py b/tests/test_tspca.py index 382dad65..d1753182 100644 --- a/tests/test_tspca.py +++ b/tests/test_tspca.py @@ -1,9 +1,10 @@ +import matplotlib.pyplot as plt import numpy as np from meegkit import dss, sns, tspca from meegkit.utils import demean, fold, unfold -import matplotlib.pyplot as plt +rng = np.random.default_rng(9) def test_tspca_sns_dss(): # TODO @@ -16,12 +17,12 @@ def test_tspca_sns_dss(): # TODO Remove non-repeatable components with DSS. """ # Random data (time*chans*trials) - data = np.random.random((800, 102, 200)) - ref = np.random.random((800, 3, 200)) + data = rng.random((800, 102, 200)) + ref = rng.random((800, 3, 200)) # remove means noisy_data = demean(data) - noisy_ref = demean(ref) + demean(ref) # Apply TSPCA # ------------------------------------------------------------------------- @@ -34,18 +35,18 @@ def test_tspca_sns_dss(): # TODO # Apply SNS # ------------------------------------------------------------------------- nneighbors = 10 - print('SNS...') + print("SNS...") y_tspca_sns, r = sns.sns(y_tspca, nneighbors) - print('\b OK!') + print("\b OK!") # apply DSS # ------------------------------------------------------------------------- - print('DSS...') + print("DSS...") # Keep all PC components y_tspca_sns = demean(y_tspca_sns) print(y_tspca_sns.shape) todss, fromdss, _, _ = dss.dss1(y_tspca_sns) - print('\b OK!') + print("\b OK!") # c3 = DSS components y_tspca_sns_dss = fold( @@ -60,7 +61,7 @@ def test_tsr(show=True): sr = 200 nsamples = 10000 nchans = 10 - x = np.random.randn(nsamples, nchans) + x = rng.standard_normal((nsamples, nchans)) # artifact + harmonics artifact = np.sin(np.arange(nsamples) / sr * 2 * np.pi * 10)[:, None] @@ -78,10 +79,10 @@ def test_tsr(show=True): shifts=[0]) if show: - f, ax = plt.subplots(2, 1, num='without shifts') - ax[0].plot(y[:500, 0], 'grey', label='recovered signal') - ax[0].plot(x[:500, 0], ':', label='real signal') - ax[1].plot((y - x)[:500], label='residual') + f, ax = plt.subplots(2, 1, num="without shifts") + ax[0].plot(y[:500, 0], "grey", label="recovered signal") + ax[0].plot(x[:500, 0], ":", label="real signal") + ax[1].plot((y - x)[:500], label="residual") ax[0].legend() ax[1].legend() # plt.show() @@ -97,17 +98,17 @@ def test_tsr(show=True): shifts=[-1, 0, 1]) if show: - f, ax = plt.subplots(3, 1, num='with shifts') - ax[0].plot(signal[:500], label='signal + noise') - ax[1].plot(x[:500, 0], 'grey', label='real signal') - ax[1].plot(y[:500, 0], ':', label='recovered signal') - ax[2].plot((signal - y)[:500, 0], label='before - after') + f, ax = plt.subplots(3, 1, num="with shifts") + ax[0].plot(signal[:500], label="signal + noise") + ax[1].plot(x[:500, 0], "grey", label="real signal") + ax[1].plot(y[:500, 0], ":", label="recovered signal") + ax[2].plot((signal - y)[:500, 0], 
label="before - after") ax[0].legend() ax[1].legend() ax[2].legend() plt.show() -if __name__ == '__main__': +if __name__ == "__main__": import pytest pytest.main([__file__]) # test_tspca_sns_dss() diff --git a/tests/test_utils.py b/tests/test_utils.py index 01678516..2e2176fd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,16 +1,33 @@ import numpy as np -from meegkit.utils import (bootstrap_ci, demean, find_outlier_samples, - find_outlier_trials, fold, mean_over_trials, - multishift, multismooth, relshift, rms, shift, - shiftnd, unfold, widen_mask, cronbach, robust_mean) from numpy.testing import assert_almost_equal, assert_equal +from meegkit.utils import ( + bootstrap_ci, + cronbach, + demean, + find_outlier_samples, + find_outlier_trials, + fold, + mean_over_trials, + multishift, + multismooth, + relshift, + rms, + robust_mean, + shift, + shiftnd, + unfold, + widen_mask, +) + +rng = np.random.default_rng() + def _sim_data(n_times, n_chans, n_trials, noise_dim, SNR=1, t0=100): """Create synthetic data.""" # source source = np.sin(2 * np.pi * np.linspace(0, .5, n_times - t0))[np.newaxis].T - s = source * np.random.randn(1, n_chans) + s = source * rng.standard_normal((1, n_chans)) s = s[:, :, np.newaxis] s = np.tile(s, (1, 1, n_trials)) signal = np.zeros((n_times, n_chans, n_trials)) @@ -18,8 +35,8 @@ def _sim_data(n_times, n_chans, n_trials, noise_dim, SNR=1, t0=100): # noise noise = np.dot( - unfold(np.random.randn(n_times, noise_dim, n_trials)), - np.random.randn(noise_dim, n_chans)) + unfold(rng.standard_normal((n_times, noise_dim, n_trials))), + rng.standard_normal((noise_dim, n_chans))) noise = fold(noise, n_times) # mix signal and noise @@ -42,7 +59,7 @@ def test_multishift(): x = np.ones((4, 4, 3)) x[..., 1] *= 2 x[..., 2] *= 3 - xx = multishift(x, [-1, -2], reshape=True, solution='valid') + xx = multishift(x, [-1, -2], reshape=True, solution="valid") assert_equal(xx[..., 0], np.array([[1., 1., 1., 1., 1., 1., 1., 1.], [1., 1., 1., 1., 1., 1., 1., 1.]])) assert_equal(xx[..., 1], np.array([[1., 1., 1., 1., 1., 1., 1., 1.], @@ -125,7 +142,7 @@ def test_widen_mask(): def test_multismooth(): """Test smoothing.""" - x = (np.random.randn(1000, 1) / 2 + + x = (rng.standard_normal((1000, 1)) / 2 + np.cos(2 * np.pi * 3 * np.linspace(0, 20, 1000))[:, None]) for i in np.arange(1, 10, 1): @@ -141,7 +158,7 @@ def test_demean(show=False): n_trials = 100 n_chans = 8 n_times = 1000 - x = np.random.randn(n_times, n_chans, n_trials) + x = rng.standard_normal((n_times, n_chans, n_trials)) x, s = _sim_data(n_times, n_chans, n_trials, 8, SNR=10) # 1. 
demean and check trial average is almost zero
@@ -158,14 +175,14 @@ def test_demean(show=False):
     if show:
         import matplotlib.pyplot as plt
         f, ax = plt.subplots(3, 1)
-        ax[0].plot(times, x[:, 0].mean(-1), label='noisy_data')
-        ax[0].plot(times, s[:, 0].mean(-1), label='signal')
+        ax[0].plot(times, x[:, 0].mean(-1), label="noisy_data")
+        ax[0].plot(times, s[:, 0].mean(-1), label="signal")
         ax[0].legend()
-        ax[1].plot(times, x1[:, 0].mean(-1), label='mean over entire epoch')
+        ax[1].plot(times, x1[:, 0].mean(-1), label="mean over entire epoch")
         ax[1].legend()
-        ax[2].plot(x2[:, 0].mean(-1), label='weighted mean')
+        ax[2].plot(x2[:, 0].mean(-1), label="weighted mean")
         plt.gca().set_prop_cycle(None)
-        ax[2].plot(s[:, 0].mean(-1), 'k:')
+        ax[2].plot(s[:, 0].mean(-1), "k:")
         ax[2].legend()
         plt.show()
@@ -194,7 +211,7 @@ def test_demean(show=False):

 def test_computeci():
     """Compute CI."""
-    x = np.random.randn(1000, 8, 100)
+    x = rng.standard_normal((1000, 8, 100))
     ci_low, ci_high = bootstrap_ci(x)

     assert ci_low.shape == (1000, 8)
@@ -205,7 +222,7 @@
     # assert ci_low.shape == (1000,)
     # assert ci_high.shape == (1000,)

-    x = np.random.randn(1000, 100)
+    x = rng.standard_normal((1000, 100))
     ci_low, ci_high = bootstrap_ci(x)

     assert ci_low.shape == (1000,)
@@ -214,7 +231,7 @@

 def test_outliers(show=False):
     """Test outlier detection."""
-    x = np.random.randn(250, 8, 50)  # 50 trials, 8, channels
+    x = rng.standard_normal((250, 8, 50))  # 250 samples, 8 channels, 50 trials
     x[..., :5] *= 10  # 5 first trials are outliers

     # Pass standard threshold
@@ -248,13 +265,13 @@ def test_cronbach():
     m = robust_mean(X, axis=0)
     assert m.shape == (X.shape[1], X.shape[2])

-if __name__ == '__main__':
+if __name__ == "__main__":
     import pytest
     pytest.main([__file__])

     # test_outliers()
     # import matplotlib.pyplot as plt
-    # x = np.random.randn(1000,)
+    # x = rng.standard_normal((1000,))
     # y = multismooth(x, np.arange(1, 200, 4))
     # plt.imshow(y.T, aspect='auto')
     # plt.show()
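
Beyond the quote-style normalization, the one substantive change running through these test diffs is the migration from NumPy's legacy global RNG (`np.random.seed`, `np.random.randn`, `np.random.rand`) to a module-level `numpy.random.Generator`. A minimal sketch of the correspondence (plain NumPy, nothing meegkit-specific):

```python
import numpy as np

# Legacy API: one global, implicitly shared random state.
np.random.seed(9)
a = np.random.randn(1000, 2)  # shape passed as separate positional args
b = np.random.rand(1000, 2)   # uniform on [0, 1)

# Generator API, as used throughout these tests: explicit, local state.
rng = np.random.default_rng(9)
a = rng.standard_normal((1000, 2))  # shape passed as a single tuple
b = rng.random((1000, 2))
```

The two APIs draw from different streams even with the same seed, so identical seeds are not expected to reproduce the legacy draws; only the distributions and shapes carry over.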
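The `test_asr_class` hunks exercise ASR's calibrate-then-clean cycle on sliding windows. A hedged offline sketch of that workflow, with synthetic data standing in for the `eeg_raw.npy` fixture, and with the `meegkit.asr` import path assumed from the package layout:

```python
import numpy as np

from meegkit.asr import ASR  # import path assumed; the tests only show the ASR name

sfreq = 250
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 60 * sfreq))  # fake (n_chans, n_times) recording

# Calibrate on a presumed-clean stretch, as in test_asr_class()
asr = ASR(method="euclid", estimator="scm")
train_idx = np.arange(5 * sfreq, 45 * sfreq, dtype=int)
asr.fit(X[:, train_idx])

# Simulated online use: clean the recording one 1-s window at a time
clean = np.zeros_like(X)
for start in range(0, X.shape[1], sfreq):
    stop = min(start + sfreq, X.shape[1])
    clean[:, start:stop] = asr.transform(X[:, start:stop])
```

As the TODO in the test notes, `transform()` is stochastic, so incremental and bulk processing are only expected to agree approximately.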
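`test_dss_line` covers ZapLine-style power-line removal for both 2D and 3D inputs. A sketch of the basic 2D call, assuming the `dss_line(x, fline, sfreq, nremove=...)` signature and its `(clean, artifact)` return order:

```python
import numpy as np

from meegkit import dss

sfreq, fline = 200, 50
rng = np.random.default_rng(0)

# Broadband noise plus a 50 Hz "line" artifact shared across channels
x = rng.standard_normal((10000, 10))
x += np.sin(2 * np.pi * fline * np.arange(10000) / sfreq)[:, None]

# Remove one line-noise component; `artifact` holds what was subtracted
clean, artifact = dss.dss_line(x, fline, sfreq, nremove=1)
```

`dss_line_iter`, also tested above, repeats the removal until the spectral peak at `fline` is flattened; the test points its figure `prefix` at a temporary directory.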
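The CCA tests expose the `nt_cca` contract directly: two weight matrices plus the canonical correlations, with components recovered by projection. A self-contained sketch on synthetic data (only `nt_cca` comes from meegkit):

```python
import numpy as np

from meegkit.cca import nt_cca

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 10))
y = rng.standard_normal((1000, 9))
y[:, :3] = x[:, :3]  # plant three perfectly correlated components

A, B, R = nt_cca(x, y)     # weights for x, weights for y, correlations
cc_x, cc_y = x @ A, y @ B  # canonical components; leading R values should be ~1
```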
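The detrending tests call `detrend(x, order, w, basis=...)` with optional per-sample weights so that glitches do not bias the fit; the three-value return and the weighting convention below are taken from the hunks above:

```python
import numpy as np

from meegkit.detrend import detrend

rng = np.random.default_rng(0)
trend = np.linspace(0, 100, 1000)[:, None]      # slow linear drift
x = trend + 3 * rng.standard_normal((1000, 1))  # one noisy, drifting channel
x[:100] = 100                                   # large glitch at the start

w = np.ones(x.shape)
w[:100] = 0  # zero weight: keep the glitch out of the polynomial fit
y, w_out, _ = detrend(x, 1, w)  # robust 1st-order polynomial detrend
```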
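Finally, `bootstrap_ci` as exercised in `test_computeci`: it collapses the trial axis into a confidence band, one pair of bounds per time point and channel. A sketch mirroring the shapes asserted above:

```python
import numpy as np

from meegkit.utils import bootstrap_ci

rng = np.random.default_rng(0)
x = rng.standard_normal((1000, 8, 100))  # (n_times, n_chans, n_trials)
ci_low, ci_high = bootstrap_ci(x)        # CI of the trial mean, per time & channel
assert ci_low.shape == (1000, 8) and ci_high.shape == (1000, 8)
```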