diff --git a/.circleci/config.yml b/.circleci/config.yml index 4175da6..118a5ed 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,26 +1,142 @@ -# Use the latest 2.1 version of CircleCI pipeline process engine. -# See: https://circleci.com/docs/configuration-reference version: 2.1 - -# Define a job to be invoked later in a workflow. -# See: https://circleci.com/docs/configuration-reference/#jobs jobs: - say-hello: - # Specify the execution environment. You can specify an image from Docker Hub or use one of our convenience images from CircleCI's Developer Hub. - # See: https://circleci.com/docs/configuration-reference/#executor-job - docker: - - image: cimg/base:stable - # Add steps to the job - # See: https://circleci.com/docs/configuration-reference/#steps - steps: - - checkout - - run: - name: "Say hello" - command: "echo Hello, World!" + build_docs: + parameters: + scheduled: + type: string + default: "false" + docker: + - image: cimg/base:current-22.04 + steps: + - checkout + - run: + name: Check-skip + command: | + set -e + export COMMIT_MESSAGE=$(git log --format=oneline -n 1); + if [[ -v CIRCLE_PULL_REQUEST ]] && ([[ "$COMMIT_MESSAGE" == *"[skip circle]"* ]] || [[ "$COMMIT_MESSAGE" == *"[circle skip]"* ]]); then + echo "Skip detected, exiting job ${CIRCLE_JOB} for PR ${CIRCLE_PULL_REQUEST}." + circleci-agent step halt; + fi + - run: + name: Set BASH_ENV + command: | + set -e + set -o pipefail + git clone --single-branch --branch main git@github.com:/mne-tools/mne-python.git + ./mne-python/tools/setup_xvfb.sh + sudo apt install -qq graphviz optipng python3.10-venv python3-venv libxft2 ffmpeg + python3.10 -m venv ~/python_env + echo "set -e" >> $BASH_ENV + echo "export OPENBLAS_NUM_THREADS=4" >> $BASH_ENV + echo "export XDG_RUNTIME_DIR=/tmp/runtime-circleci" >> $BASH_ENV + echo "export PATH=~/.local/bin/:$PATH" >> $BASH_ENV + echo "export DISPLAY=:99" >> $BASH_ENV + echo "source ~/python_env/bin/activate" >> $BASH_ENV + mkdir -p ~/.local/bin + ln -s ~/python_env/bin/python ~/.local/bin/python + echo "BASH_ENV:" + cat $BASH_ENV + mkdir -p ~/mne_data + - run: + name: Get Python running + command: | + pip install --upgrade PyQt6 sphinx-gallery pydata-sphinx-theme numpydoc scikit-learn git+https://github.com/pyvista/pyvista@main memory_profiler + pip install -ve ./mne-python . + - run: + name: Check Qt + command: | + ./mne-python/tools/check_qt_import.sh PyQt6 + - run: + name: Check installation + command: | + which python + QT_DEBUG_PLUGINS=1 mne sys_info -pd + python -c "import numpy; numpy.show_config()" + LIBGL_DEBUG=verbose python -c "import pyvistaqt; pyvistaqt.BackgroundPlotter(show=True)" + - run: + name: List packages + command: python -m pip list + - restore_cache: + keys: + - data-cache-somato + - run: + name: Get data + command: | + python -c "import mne; mne.datasets.somato.data_path(update_path=True, verbose=True)" + ls -al ~/mne_data; + - run: + name: make html + command: | + make -C doc html + - store_test_results: + path: doc/_build/test-results + - store_artifacts: + path: doc/_build/test-results + destination: test-results + - store_artifacts: + path: doc/_build/html/ + destination: dev + - persist_to_workspace: + root: doc/_build + paths: + - html + - save_cache: + key: data-cache-somato + paths: + - ~/mne_data/MNE-somato-data + + deploy: + machine: + image: ubuntu-2004:202111-01 + steps: + - attach_workspace: + at: /tmp/build + - restore_cache: + keys: + - website-cache + - add_ssh_keys: + fingerprints: + - "18:ca:dc:af:9a:80:ee:23:91:fd:84:ae:93:7a:7b:4f" + - run: + name: Deploy docs + command: | + set -eo pipefail + mkdir -p ~/.ssh + echo -e "Host *\nStrictHostKeyChecking no" > ~/.ssh/config + chmod og= ~/.ssh/config + git config --global user.email "circle@mne.tools" + git config --global user.name "Circle CI" + if [ ! -d ~/mne-tools.github.io ]; then + git clone git@github.com:/mne-tools/mne-tools.github.io.git ~/mne-tools.github.io + fi + cd ~/mne-tools.github.io + git checkout main + git fetch origin + git reset --hard origin/main + git clean -xdf + echo "Deploying dev docs for ${CIRCLE_BRANCH}." + mkdir -p mne-gui-addons + rm -Rf mne-gui-addons/dev + cp -a /tmp/build/html mne-gui-addons/dev + git add -A + git commit -m "CircleCI update of mne-gui-addons docs (${CIRCLE_BUILD_NUM})." + git push origin main + - save_cache: + key: website-cache + paths: + - ~/mne-tools.githbub.io -# Orchestrate jobs using workflows -# See: https://circleci.com/docs/configuration-reference/#workflows workflows: - say-hello-workflow: + default: jobs: - - say-hello + - build_docs: + name: build_docs + - deploy: + name: deploy + requires: + - build_docs + filters: + branches: + only: + - main diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fa588df..0e9d607 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -20,15 +20,20 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.11"] # Oldest and newest supported versions + os: [ubuntu-latest, windows-latest] + python-version: ["3.11"] # Newest supported version (dipy not out for 3.11 yet!) mne-version: [mne-main] qt: [PyQt6] include: + # macOS (can be moved above once Dipy releases 3.11 wheels) + - os: macos-latest + python-version: "3.10" + mne-version: mne-main # TODO: Set back to mne-stable once 1.4 is out (we need its pytest fixtures) + qt: PyQt6 # Old (and PyQt5) - os: ubuntu-latest python-version: "3.8" - mne-version: mne-stable + mne-version: mne-main # TODO: Set back to mne-stable once 1.4 is out (we need its pytest fixtures) qt: PyQt5 # PySide2 - os: ubuntu-latest diff --git a/.gitignore b/.gitignore index 77e0684..b6008e4 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # Other junit-results.xml +doc/_build/ +doc/auto_examples/ +doc/generated/ diff --git a/doc/Makefile b/doc/Makefile new file mode 100644 index 0000000..20bf1ca --- /dev/null +++ b/doc/Makefile @@ -0,0 +1,43 @@ +SPHINXOPTS = -nWT --keep-going +SPHINXBUILD = sphinx-build +PAPER = +MPROF = SG_STAMP_STARTS=true mprof run -E --python sphinx + +# Internal variables. +PAPEROPT_a4 = -D latex_paper_size=a4 +PAPEROPT_letter = -D latex_paper_size=letter +ALLSPHINXOPTS = -d _build/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . + +.PHONY: help clean html dirhtml pickle json htmlhelp qthelp latex changes linkcheck doctest + +# make with no arguments will build the first target by default, i.e., build standalone HTML files +first_target: html-noplot + +help: + @echo "Please use \`make ' where is one of" + @echo " html to make standalone HTML files (dev version)" + @echo " html-pattern to make standalone HTML files for one example dir (dev version)" + @echo " *-noplot to make standalone HTML files without plotting" + +clean: + -rm -rf _build auto_examples auto_tutorials generated *.stc *.fif *.nii.gz + +html: + $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) _build/html + @echo + @echo "Build finished. The HTML pages are in _build/html." + +html-pattern: + $(SPHINXBUILD) -D sphinx_gallery_conf.filename_pattern=$(PATTERN) -D sphinx_gallery_conf.run_stale_examples=True -b html $(ALLSPHINXOPTS) _build/html + @echo + @echo "Build finished. The HTML pages are in _build/html" + +html-noplot: + $(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) _build/html + @echo + @echo "Build finished. The HTML pages are in _build/html." + +view: + @python -c "import webbrowser; webbrowser.open_new_tab('file://$(PWD)/_build/html/index.html')" + +show: view diff --git a/doc/_templates/autosummary/class.rst b/doc/_templates/autosummary/class.rst new file mode 100644 index 0000000..fe47440 --- /dev/null +++ b/doc/_templates/autosummary/class.rst @@ -0,0 +1,12 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__ + :members: + +.. _sphx_glr_backreferences_{{ fullname }}: + +.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/_templates/autosummary/function.rst b/doc/_templates/autosummary/function.rst new file mode 100644 index 0000000..bd78b8e --- /dev/null +++ b/doc/_templates/autosummary/function.rst @@ -0,0 +1,10 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autofunction:: {{ objname }} + +.. _sphx_glr_backreferences_{{ fullname }}: + +.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/api.rst b/doc/api.rst new file mode 100644 index 0000000..4ef02be --- /dev/null +++ b/doc/api.rst @@ -0,0 +1,17 @@ +.. _api_reference: + +==================== +Python API Reference +==================== + +:py:mod:`mne_gui_addons`: + +.. automodule:: mne_gui_addons + :no-members: + :no-inherited-members: + +.. autosummary:: + :toctree: generated/ + + locate_ieeg + view_vol_stc diff --git a/doc/conf.py b/doc/conf.py new file mode 100644 index 0000000..d39b612 --- /dev/null +++ b/doc/conf.py @@ -0,0 +1,158 @@ +import faulthandler +import os +import sys + +import pyvista +import mne +import mne_gui_addons + +faulthandler.enable() +os.environ["_MNE_BROWSER_NO_BLOCK"] = "true" +os.environ["MNE_BROWSER_OVERVIEW_MODE"] = "hidden" +os.environ["MNE_BROWSER_THEME"] = "light" +os.environ["MNE_3D_OPTION_THEME"] = "light" + +project = "MNE-GUI-Addons" +release = mne_gui_addons.__version__ +version = ".".join(release.split(".")[:2]) +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "numpydoc", + "sphinx_gallery.gen_gallery", +] +templates_path = ["_templates"] +source_suffix = ".rst" +master_doc = "index" +exclude_trees = ["_build"] +default_role = "py:obj" +modindex_common_prefix = ["mne_gui_addons."] +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "scipy": ("https://docs.scipy.org/doc/scipy", None), + "matplotlib": ("https://matplotlib.org/stable", None), + "nibabel": ("https://nipy.org/nibabel", None), + "dipy": ( + "https://dipy.org/documentation/latest/", + "https://dipy.org/documentation/latest/objects.inv/", + ), + "mne": ("https://mne.tools/stable", None), +} +numpydoc_class_members_toctree = False +numpydoc_attributes_as_param_list = True +numpydoc_xref_param_type = True +numpydoc_xref_aliases = { + # MNE + "SourceSpaces": "mne.SourceSpaces", + "Info": "mne.Info", + "Epochs": "mne.Epochs", + "AverageTFR": "mne.time_frequency.AverageTFR", + "EpochsTFR": "mne.time_frequency.EpochsTFR", + "Transform": "mne.transforms.Transform", + # MNE-GUI-Addons + # 'IntracranialElectrodeLocator': 'mne_gui_addons.IntracranialElectrodeLocator', # Many doc errors! +} +numpydoc_xref_ignore = { + # words + "instance", + "instances", + "of", + "default", + "shape", + "or", + "with", + "length", + "pair", + "matplotlib", + "optional", + "kwargs", + "in", + "dtype", + "object", + # not documented + "IntracranialElectrodeLocator", + "VolSourceEstimateViewer", +} +numpydoc_validate = True +numpydoc_validation_checks = { + "all", + # These we do not live by: + "GL01", # Docstring should start in the line immediately after the quotes + "EX01", + "EX02", # examples failed (we test them separately) + "ES01", # no extended summary + "SA01", # no see also + "YD01", # no yields section + "SA04", # no description in See Also + "PR04", # Parameter "shape (n_channels" has no type + "RT02", # The first line of the Returns section should contain only the type, unless multiple values are being returned # noqa +} +numpydoc_validation_exclude = { # set of regex + r"mne\.utils\.deprecated", +} +pyvista.OFF_SCREEN = False +pyvista.BUILDING_GALLERY = True +sphinx_gallery_conf = { + "doc_module": ("mne_gui_addons",), + "reference_url": dict(mne_gui_addons=None), + "examples_dirs": ["../examples"], + "gallery_dirs": ["auto_examples"], + "backreferences_dir": "generated", + "plot_gallery": "True", # Avoid annoying Unicode/bool default warning + "thumbnail_size": (160, 112), + "remove_config_comments": True, + "min_reported_time": 1.0, + "abort_on_example_error": False, + "image_scrapers": ("matplotlib", mne.gui._GUIScraper(), "pyvista"), + "show_memory": not sys.platform.startswith(("win", "darwin")), + "line_numbers": False, # messes with style + "capture_repr": ("_repr_html_",), + "junit": os.path.join("..", "test-results", "sphinx-gallery", "junit.xml"), + "matplotlib_animations": True, + "compress_images": ("images", "thumbnails"), + "filename_pattern": "^((?!sgskip).)*$", +} + +autosummary_generate = True +autodoc_default_options = {"inherited-members": None} +nitpicky = True +nitpick_ignore = [] +nitpick_ignore_regex = [] +html_theme = "pydata_sphinx_theme" +html_theme_options = { + "icon_links": [ + dict( + name="GitHub", + url="https://github.com/mne-tools/mne-gui-addons", + icon="fa-brands fa-square-github", + ), + dict( + name="Mastodon", + url="https://fosstodon.org/@mne", + icon="fa-brands fa-mastodon", + attributes=dict(rel="me"), + ), + dict( + name="Twitter", + url="https://twitter.com/mne_python", + icon="fa-brands fa-square-twitter", + ), + dict( + name="Discourse", + url="https://mne.discourse.group/", + icon="fa-brands fa-discourse", + ), + dict( + name="Discord", + url="https://discord.gg/rKfvxTuATa", + icon="fa-brands fa-discord", + ), + ], + "use_edit_page_button": False, +} +html_show_sourcelink = False +html_copy_source = False +html_show_sphinx = False +htmlhelp_basename = "mne-gui-addons-doc" diff --git a/doc/index.rst b/doc/index.rst new file mode 100644 index 0000000..a23d5db --- /dev/null +++ b/doc/index.rst @@ -0,0 +1,10 @@ +MNE-GUI-Addons +-------------- + +MNE-Python GUI addons. + +.. toctree:: + :maxdepth: 1 + + api.rst + auto_examples/index diff --git a/examples/README.txt b/examples/README.txt new file mode 100644 index 0000000..91187c9 --- /dev/null +++ b/examples/README.txt @@ -0,0 +1,4 @@ +Examples +-------- + +These examples show how to use MNE-GUI-Addons. diff --git a/examples/evoked_ers_source_power.py b/examples/evoked_ers_source_power.py new file mode 100644 index 0000000..1f8c2b9 --- /dev/null +++ b/examples/evoked_ers_source_power.py @@ -0,0 +1,282 @@ +# -*- coding: utf-8 -*- +""" +.. _ex-source-loc-methods: + +===================================================================== +Compute evoked ERS source power using DICS, LCMV beamformer, and dSPM +===================================================================== + +Here we examine 3 ways of localizing event-related synchronization (ERS) of +beta band activity in this dataset: :ref:`somato-dataset` using +:term:`DICS`, :term:`LCMV beamformer`, and :term:`dSPM` applied to active and +baseline covariance matrices. +""" +# Authors: Luke Bloy +# Eric Larson +# Alex Rockhill +# +# License: BSD-3-Clause + +# %% + +import numpy as np + +import mne_gui_addons + +import mne +from mne.cov import compute_covariance +from mne.datasets import somato +from mne.time_frequency import csd_tfr +from mne.beamformer import make_dics, apply_dics_csd, make_lcmv, apply_lcmv_cov +from mne.minimum_norm import make_inverse_operator, apply_inverse_cov + +print(__doc__) + +# %% +# Reading the raw data and creating epochs: + +data_path = somato.data_path() +subject = "01" +subjects_dir = data_path / "derivatives" / "freesurfer" / "subjects" +task = "somato" +raw_fname = ( + data_path + / "sub-{}".format(subject) + / "meg" + / "sub-{}_task-{}_meg.fif".format(subject, task) +) + +# crop to 5 minutes to save memory +raw = mne.io.read_raw_fif(raw_fname).crop(0, 300) + +# We are interested in the beta band (12-30 Hz) +raw.load_data().filter(12, 30) + +# The DICS beamformer currently only supports a single sensor type. +# We'll use the gradiometers in this example. +picks = mne.pick_types(raw.info, meg="grad", exclude="bads") + +# Read epochs +events = mne.find_events(raw) +epochs = mne.Epochs( + raw, events, event_id=1, tmin=-1.5, tmax=2, picks=picks, preload=True, decim=3 +) + +# Read forward operator and point to freesurfer subject directory +fwd_fname = ( + data_path + / "derivatives" + / "sub-{}".format(subject) + / "sub-{}_task-{}-fwd.fif".format(subject, task) +) +fwd = mne.read_forward_solution(fwd_fname) + +# %% +# Compute covariances and cross-spectral density +# ---------------------------------------------- +# ERS activity starts at 0.5 seconds after stimulus onset. Because these +# data have been processed by MaxFilter directly (rather than MNE-Python's +# version), we have to be careful to compute the rank with a more conservative +# threshold in order to get the correct data rank (64). Once this is used in +# combination with an advanced covariance estimator like "shrunk", the rank +# will be correctly preserved. + +rank = mne.compute_rank(epochs, tol=1e-6, tol_kind="relative") +win_active = (0.5, 1.5) +win_baseline = (-1, 0) +cov_baseline = compute_covariance( + epochs, + tmin=win_baseline[0], + tmax=win_baseline[1], + method="shrunk", + rank=rank, + verbose=True, +) +cov_active = compute_covariance( + epochs, + tmin=win_active[0], + tmax=win_active[1], + method="shrunk", + rank=rank, + verbose=True, +) + +# when the covariance objects are added together, they are scaled by the size +# of the window used to create them so that the average is properly weighted +cov_common = cov_baseline + cov_active +cov_baseline.plot(epochs.info) + +freqs = np.logspace(np.log10(12), np.log10(30), 9) + +# time-frequency decomposition +epochs_tfr = mne.time_frequency.tfr_morlet( + epochs, + freqs=freqs, + n_cycles=freqs / 2, + return_itc=False, + average=False, + output="complex", +) +epochs_tfr.decimate(20) # decimate for speed + +# compute cross-spectral density matrices +csd = csd_tfr(epochs_tfr, tmin=-1, tmax=1.5) +csd_baseline = csd_tfr(epochs_tfr, tmin=win_baseline[0], tmax=win_baseline[1]) +csd_ers = csd_tfr(epochs_tfr, tmin=win_active[0], tmax=win_active[1]) + +csd_baseline.plot() + +# %% +# Compute some source estimates +# ----------------------------- +# Here we will use DICS, LCMV beamformer, and dSPM. +# +# See :ref:`ex-inverse-source-power` for more information about DICS. + + +def _gen_dics(csd, ers_csd, csd_baseline, fwd): + filters = make_dics( + epochs.info, + fwd, + csd.mean(), + pick_ori="max-power", + reduce_rank=True, + real_filter=True, + rank=rank, + ) + stc_base, freqs = apply_dics_csd(csd_baseline.mean(), filters) + stc_act, freqs = apply_dics_csd(csd_ers.mean(), filters) + stc_act /= stc_base + return stc_act + + +# generate lcmv source estimate +def _gen_lcmv(active_cov, cov_baseline, common_cov, fwd): + filters = make_lcmv( + epochs.info, fwd, common_cov, reg=0.05, noise_cov=None, pick_ori="max-power" + ) + stc_base = apply_lcmv_cov(cov_baseline, filters) + stc_act = apply_lcmv_cov(cov_active, filters) + stc_act /= stc_base + return stc_act + + +# generate mne/dSPM source estimate +def _gen_mne(cov_active, cov_baseline, cov_common, fwd, info, method="dSPM"): + inverse_operator = make_inverse_operator(info, fwd, cov_common) + stc_act = apply_inverse_cov( + cov_active, info, inverse_operator, method=method, verbose=True + ) + stc_base = apply_inverse_cov( + cov_baseline, info, inverse_operator, method=method, verbose=True + ) + stc_act /= stc_base + return stc_act + + +# Compute source estimates +stc_dics = _gen_dics(csd, csd_ers, csd_baseline, fwd) +stc_lcmv = _gen_lcmv(cov_active, cov_baseline, cov_common, fwd) +stc_dspm = _gen_mne(cov_active, cov_baseline, cov_common, fwd, epochs.info) + +# %% +# Plot source estimates +# --------------------- +# DICS: + +brain_dics = stc_dics.plot( + hemi="rh", + subjects_dir=subjects_dir, + subject=subject, + time_label="DICS source power in the 12-30 Hz frequency band", +) + +# %% +# LCMV: + +brain_lcmv = stc_lcmv.plot( + hemi="rh", + subjects_dir=subjects_dir, + subject=subject, + time_label="LCMV source power in the 12-30 Hz frequency band", +) + +# %% +# dSPM: + +brain_dspm = stc_dspm.plot( + hemi="rh", + subjects_dir=subjects_dir, + subject=subject, + time_label="dSPM source power in the 12-30 Hz frequency band", +) + +# %% +# Use volume source estimate with time-frequency resolution +# --------------------------------------------------------- + +# make a volume source space +surface = subjects_dir / subject / "bem" / "inner_skull.surf" +vol_src = mne.setup_volume_source_space( + subject=subject, + subjects_dir=subjects_dir, + surface=surface, + pos=10, + add_interpolator=False, +) # just for speed! + +conductivity = (0.3,) # one layer for MEG +model = mne.make_bem_model( + subject=subject, + ico=3, # just for speed + conductivity=conductivity, + subjects_dir=subjects_dir, +) +bem = mne.make_bem_solution(model) + +trans = fwd["info"]["mri_head_t"] +vol_fwd = mne.make_forward_solution( + raw.info, + trans=trans, + src=vol_src, + bem=bem, + meg=True, + eeg=True, + mindist=5.0, + n_jobs=1, + verbose=True, +) + +# Compute source estimate using MNE solver +snr = 3.0 +lambda2 = 1.0 / snr**2 +method = "MNE" # use MNE method (could also be dSPM or sLORETA) + +# make a different inverse operator for each frequency so as to properly +# whiten the sensor data +inverse_operator = list() +for freq_idx in range(epochs_tfr.freqs.size): + # for each frequency, compute a separate covariance matrix + cov_baseline = csd_baseline.get_data(index=freq_idx, as_cov=True) + cov_baseline["data"] = cov_baseline["data"].real # only normalize by real + # then use that covariance matrix as normalization for the inverse + # operator + inverse_operator.append( + mne.minimum_norm.make_inverse_operator(epochs.info, vol_fwd, cov_baseline) + ) + +# finally, compute the stcs for each epoch and frequency +stcs = mne.minimum_norm.apply_inverse_tfr_epochs( + epochs_tfr, inverse_operator, lambda2, method=method, pick_ori="vector" +) + +# %% +# Plot volume source estimates +# ---------------------------- + +viewer = mne_gui_addons.view_vol_stc( + stcs, subject=subject, subjects_dir=subjects_dir, src=vol_src, inst=epochs_tfr +) +viewer.go_to_extreme() # show the maximum intensity source vertex +viewer.set_cmap(vmin=0.25, vmid=0.8) +viewer.set_3d_view(azimuth=40, elevation=35, distance=350) diff --git a/mne_gui_addons/__init__.py b/mne_gui_addons/__init__.py index 1da7a9c..42d69ad 100644 --- a/mne_gui_addons/__init__.py +++ b/mne_gui_addons/__init__.py @@ -1,7 +1,322 @@ +"""Convenience functions for opening GUIs.""" + +# Authors: Alex Rockhill +# +# License: BSD-3-Clause + from importlib.metadata import version, PackageNotFoundError +import numpy as np +from mne.utils import verbose as _verbose, _check_option + +from ._utils import _fill_doc + try: __version__ = version("mne_gui_addons") except PackageNotFoundError: # package is not installed __version__ = "0.0.0" # pragma: no cover + + +@_verbose +@_fill_doc +def locate_ieeg( + info, + trans, + base_image, + subject=None, + subjects_dir=None, + groups=None, + show=True, + block=False, + verbose=None, +): + """Locate intracranial electrode contacts. + + Parameters + ---------- + %(info_not_none)s + %(trans_not_none)s + base_image : path-like | nibabel.spatialimages.SpatialImage + The CT or MR image on which the electrode contacts can located. It + must be aligned to the Freesurfer T1 if ``subject`` and + ``subjects_dir`` are provided. Path-like inputs and nibabel image + objects are supported. + %(subject)s + %(subjects_dir)s + groups : dict | None + A dictionary with channels as keys and their group index as values. + If None, the groups will be inferred by the channel names. Channel + names must have a format like ``LAMY 7`` where a string prefix + like ``LAMY`` precedes a numeric index like ``7``. If the channels + are formatted improperly, group plotting will work incorrectly. + Group assignments can be adjusted in the GUI. + show : bool + Show the GUI if True. + block : bool + Whether to halt program execution until the figure is closed. + %(verbose)s + + Returns + ------- + gui : instance of IntracranialElectrodeLocator + The graphical user interface (GUI) window. + """ + from mne.viz.backends._utils import _init_mne_qtapp, _qt_app_exec + from ._ieeg_locate import IntracranialElectrodeLocator + + app = _init_mne_qtapp() + + gui = IntracranialElectrodeLocator( + info, + trans, + base_image, + subject=subject, + subjects_dir=subjects_dir, + groups=groups, + show=show, + verbose=verbose, + ) + if block: + _qt_app_exec(app) + return gui + + +@_verbose +@_fill_doc +def view_vol_stc( + stcs, + freq_first=True, + group=False, + subject=None, + subjects_dir=None, + src=None, + inst=None, + use_int=True, + show_topomap=True, + tmin=None, + tmax=None, + show=True, + block=False, + verbose=None, +): + """View a volume time and/or frequency source time course estimate. + + Parameters + ---------- + stcs : list of list | generator + The source estimates, the options are: 1) List of lists or + generators for epochs and frequencies (i.e. using + :func:`mne.minimum_norm.apply_inverse_tfr_epochs` or + :func:`mne.beamformer.apply_dics_tfr_epochs`-- in this case + use ``freq_first=False``), or 2) List of source estimates across + frequencies (e.g. :func:`mne.beamformer.apply_dics_csd`), + or 3) List of source estimates across epochs + (e.g. :func:`mne.minimum_norm.apply_inverse_epochs` and + :func:`mne.beamformer.apply_dics_epochs`--in this + case use ``freq_first=False``), or 4) Single + source estimates (e.g. :func:`mne.minimum_norm.apply_inverse` + and :func:`mne.beamformer.apply_dics`, note ``freq_first`` + will not be used in this case), or 5) List of list of lists or + generators for subjects and frequencies and epochs (e.g. + :func:`mne.minimum_norm.apply_inverse_tfr_epochs` for each subject in + a list; use ``group=True``), or 6) List or generator for subjects + with ``stcs`` from evoked data (e.g. + :func:`mne.minimum_norm.apply_inverse` or + :func:`mne.beamformer.apply_dics_csd` for each subject in a + list; use ``group=True``). + freq_first : bool + If frequencies are the outer list of ``stcs`` use ``True``. + group : bool | str + If data is from different subjects is, group should be ``True``. + If data is in time-frequency, group should be ``'ITC'`` to show + inter-trial coherence (power is shown by default). + %(subject)s + %(subjects_dir)s + src : instance of SourceSpaces + The volume source space for the ``stc``. + inst : EpochsTFR | AverageTFR | None | list + The time-frequency or data instances to use to plot topography. + If group-level results are given (``group=True``), a list of + instances should be provided. + use_int : bool + If ``True``, cast the data to integers to reduce memory use. + show_topomap : bool + Whether to show the sensor topomap in the GUI. + %(tmin)s + %(tmax)s + show : bool + Show the GUI if True. + block : bool + Whether to halt program execution until the figure is closed. + %(verbose)s + + Returns + ------- + gui : instance of VolSourceEstimateViewer + The graphical user interface (GUI) window. + """ + from mne.viz.backends._utils import _init_mne_qtapp, _qt_app_exec + from ._vol_stc import ( + VolSourceEstimateViewer, + BASE_INT_DTYPE, + COMPLEX_DTYPE, + RANGE_VALUE, + ) + + _check_option("group", group, (True, False, "itc", "power")) + + app = _init_mne_qtapp() + + def itc(data): + data = np.array(data) + return (np.abs((data / np.abs(data)).mean(axis=0)) * (RANGE_VALUE - 1)).astype( + BASE_INT_DTYPE + ) + + # cast to integers to lower memory usage, use custom complex data + # type if necessary + data = list() + for group_stcs in stcs if group else [stcs]: + # can be generator, compute using first stc object, just a general + # rescaling of data, does not need to be precise + scalar = None # rescale per subject for better comparison + outer_data = list() + for inner_stcs in group_stcs if np.iterable(group_stcs) else [group_stcs]: + inner_data = list() + for stc in inner_stcs if np.iterable(inner_stcs) else [inner_stcs]: + stc.crop(tmin=tmin, tmax=tmax) + if use_int: + if np.iscomplexobj(stc.data) and not group: + if scalar is None: + # this is an order of magnitude approximation, + # if another stc is 10x larger than the first one, + # it will have some clipping + scalar = (RANGE_VALUE - 1) / stc.data.real.max() / 10 + stc_data = np.zeros(stc.data.shape, COMPLEX_DTYPE) + stc_data["re"] = np.clip( + stc.data.real * scalar, -RANGE_VALUE, RANGE_VALUE - 1 + ) + stc_data["im"] = np.clip( + stc.data.imag * scalar, -RANGE_VALUE, RANGE_VALUE - 1 + ) + inner_data.append(stc_data) + else: + if group in (True, "power") and np.iscomplexobj(stc.data): + stc_data = (stc.data * stc.data.conj()).real + else: + stc_data = stc.data.copy() + if scalar is None: + scalar = (RANGE_VALUE - 1) / stc_data.max() / 5 + # ignore group == 'itc' if not complex + use_itc = group == "itc" and np.iscomplexobj(stc.data) + inner_data.append( + stc_data + if use_itc + else np.clip( + stc_data * scalar, -RANGE_VALUE, RANGE_VALUE - 1 + ).astype(BASE_INT_DTYPE) + ) + else: + inner_data.append(stc.data) + # compute ITC here, need epochs + if group == "itc" and np.iscomplexobj(stc.data) and freq_first: + outer_data.append(itc(inner_data)) + else: + outer_data.append( + np.mean(inner_data, axis=0).round().astype(BASE_INT_DTYPE) + if group and freq_first + else inner_data + ) + + # compute ITC here, need epochs + if group == "itc" and np.iscomplexobj(stc.data) and not freq_first: + data.append(itc(outer_data)) + else: + data.append( + np.mean(outer_data, axis=0).round().astype(BASE_INT_DTYPE) + if group and not freq_first + else outer_data + ) + + data = np.array(data) + + if not group: + data = data[0] # flatten group dimension + + if data.ndim == 4: # scalar solution, add dimension at the end + data = data[..., None] + + # move frequencies to penultimate + data = data.transpose( + (1, 2, 3, 0, 4) if freq_first and not group else (0, 2, 3, 1, 4) + ) + + # crop inst(s) to tmin and tmax + for this_inst in inst if isinstance(inst, (list, tuple)) else [inst]: + this_inst.crop(tmin=tmin, tmax=tmax) + + gui = VolSourceEstimateViewer( + data, + subject=subject, + subjects_dir=subjects_dir, + src=src, + inst=inst, + show_topomap=show_topomap, + group=group, + show=show, + verbose=verbose, + ) + if block: + _qt_app_exec(app) + return gui + + +class _GUIScraper(object): + """Scrape GUI outputs.""" + + def __repr__(self): + return "" + + def __call__(self, block, block_vars, gallery_conf): + from ._ieeg_locate import IntracranialElectrodeLocator + from ._vol_stc import VolSourceEstimateViewer + from sphinx_gallery.scrapers import figure_rst + from qtpy import QtGui + + for gui in block_vars["example_globals"].values(): + if ( + isinstance(gui, (IntracranialElectrodeLocator, VolSourceEstimateViewer)) + and not getattr(gui, "_scraped", False) + and gallery_conf["builder_name"] == "html" + ): + gui._scraped = True # monkey-patch but it's easy enough + img_fname = next(block_vars["image_path_iterator"]) + # TODO fix in window refactor + window = gui if hasattr(gui, "grab") else gui._renderer._window + # window is QWindow + # https://doc.qt.io/qt-5/qwidget.html#grab + pixmap = window.grab() + if hasattr(gui, "_renderer"): # if no renderer, no need + # Now the tricky part: we need to get the 3D renderer, + # extract the image from it, and put it in the correct + # place in the pixmap. The easiest way to do this is + # actually to save the 3D image first, then load it + # using QPixmap and Qt geometry. + plotter = gui._renderer.plotter + plotter.screenshot(img_fname) + sub_pixmap = QtGui.QPixmap(img_fname) + # https://doc.qt.io/qt-5/qwidget.html#mapTo + # https://doc.qt.io/qt-5/qpainter.html#drawPixmap-1 + QtGui.QPainter(pixmap).drawPixmap( + plotter.mapTo(window, plotter.rect().topLeft()), sub_pixmap + ) + # https://doc.qt.io/qt-5/qpixmap.html#save + pixmap.save(img_fname) + try: # for compatibility with both GUIs, will be refactored + gui._renderer.close() # TODO should be triggered by close + except Exception: + pass + gui.close() + return figure_rst([img_fname], gallery_conf["src_dir"], "GUI") + return "" diff --git a/mne_gui_addons/_core.py b/mne_gui_addons/_core.py new file mode 100644 index 0000000..f960c84 --- /dev/null +++ b/mne_gui_addons/_core.py @@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- +"""Shared GUI classes and functions.""" + +# Authors: Alex Rockhill +# +# License: BSD (3-clause) + +import os +import os.path as op +import numpy as np +from functools import partial + +from qtpy import QtCore +from qtpy.QtCore import Slot, Qt +from qtpy.QtWidgets import ( + QMainWindow, + QGridLayout, + QVBoxLayout, + QHBoxLayout, + QLabel, + QMessageBox, + QWidget, + QLineEdit, +) + +from matplotlib import patheffects +from matplotlib.backends.backend_qt5agg import FigureCanvas +from matplotlib.figure import Figure +from matplotlib.patches import Rectangle + +from mne.viz.backends.renderer import _get_renderer +from mne.viz.utils import safe_event +from mne.surface import _read_mri_surface, _marching_cubes +from mne.transforms import apply_trans, _frame_to_str +from mne.utils import ( + logger, + _check_fname, + verbose, + warn, + get_subjects_dir, + _import_nibabel, +) +from mne.viz.backends._utils import _qt_safe_window + +_IMG_LABELS = [["I", "P"], ["I", "L"], ["P", "L"]] +_ZOOM_STEP_SIZE = 5 + + +@verbose +def _load_image(img, verbose=None): + """Load data from a 3D image file (e.g. CT, MR).""" + nib = _import_nibabel("use GUI") + if not isinstance(img, nib.spatialimages.SpatialImage): + logger.debug(f"Loading {img}") + _check_fname(img, overwrite="read", must_exist=True) + img = nib.load(img) + # get data + orig_data = np.array(img.dataobj).astype(np.float32) + # reorient data to RAS + ornt = nib.orientations.axcodes2ornt( + nib.orientations.aff2axcodes(img.affine) + ).astype(int) + ras_ornt = nib.orientations.axcodes2ornt("RAS") + ornt_trans = nib.orientations.ornt_transform(ornt, ras_ornt) + img_data = nib.orientations.apply_orientation(orig_data, ornt_trans) + orig_mgh = nib.MGHImage(orig_data, img.affine) + aff_trans = nib.orientations.inv_ornt_aff(ornt_trans, img.shape) + vox_ras_t = np.dot(orig_mgh.header.get_vox2ras_tkr(), aff_trans) + vox_scan_ras_t = np.dot(orig_mgh.header.get_vox2ras(), aff_trans) + return img_data, vox_ras_t, vox_scan_ras_t + + +def _make_mpl_plot( + width=4, + height=4, + dpi=300, + tight=True, + hide_axes=True, + facecolor="black", + invert=True, +): + fig = Figure(figsize=(width, height), dpi=dpi) + canvas = FigureCanvas(fig) + ax = fig.subplots() + if tight: + fig.subplots_adjust(bottom=0, left=0, right=1, top=1, wspace=0, hspace=0) + ax.set_facecolor(facecolor) + # clean up excess plot text, invert + if invert: + ax.invert_yaxis() + if hide_axes: + ax.set_xticks([]) + ax.set_yticks([]) + return canvas, fig + + +class SliceBrowser(QMainWindow): + """Navigate between slices of an MRI, CT, etc. image.""" + + _xy_idx = ( + (1, 2), + (0, 2), + (0, 1), + ) + + @_qt_safe_window(splash="_renderer.figure.splash", window="") + def __init__(self, base_image=None, subject=None, subjects_dir=None, verbose=None): + """GUI for browsing slices of anatomical images.""" + # initialize QMainWindow class + super(SliceBrowser, self).__init__() + self.setAttribute(Qt.WA_DeleteOnClose, True) + + self._verbose = verbose + # if bad/None subject, will raise an informative error when loading MRI + subject = os.environ.get("SUBJECT") if subject is None else subject + subjects_dir = str(get_subjects_dir(subjects_dir, raise_error=False)) + self._subject_dir = ( + op.join(subjects_dir, subject) if subject and subjects_dir else None + ) + self._load_image_data(base_image=base_image) + + # GUI design + + # Main plots: make one plot for each view; sagittal, coronal, axial + self._plt_grid = QGridLayout() + self._figs = list() + for i in range(3): + canvas, fig = _make_mpl_plot() + self._plt_grid.addWidget(canvas, i // 2, i % 2) + self._figs.append(fig) + self._renderer = _get_renderer( + name="Slice Browser", size=(400, 400), bgcolor="w" + ) + self._plt_grid.addWidget(self._renderer.plotter, 1, 1) + + self._set_ras([0.0, 0.0, 0.0], update_plots=False) + + self._plot_images() + + self._configure_ui() + + def _configure_ui(self): + bottom_hbox = self._configure_status_bar() + + # Put everything together + plot_ch_hbox = QHBoxLayout() + plot_ch_hbox.addLayout(self._plt_grid) + + main_vbox = QVBoxLayout() + main_vbox.addLayout(plot_ch_hbox) + main_vbox.addLayout(bottom_hbox) + + central_widget = QWidget() + central_widget.setLayout(main_vbox) + self.setCentralWidget(central_widget) + + def _load_image_data(self, base_image=None): + """Get image data to display and transforms to/from vox/RAS.""" + if self._subject_dir is None: + # if the recon-all is not finished or the CT is not + # downsampled to the MRI, the MRI can not be used + self._mri_data = None + self._head = None + self._lh = self._rh = None + else: + mri_img = ( + "brain" + if op.isfile(op.join(self._subject_dir, "mri", "brain.mgz")) + else "T1" + ) + self._mri_data, vox_ras_t, vox_scan_ras_t = _load_image( + op.join(self._subject_dir, "mri", f"{mri_img}.mgz") + ) + + # ready alternate base image if provided, otherwise use brain/T1 + if base_image is None: + assert self._mri_data is not None + self._base_data = self._mri_data + self._vox_ras_t = vox_ras_t + self._vox_scan_ras_t = vox_scan_ras_t + else: + self._base_data, self._vox_ras_t, self._vox_scan_ras_t = _load_image( + base_image + ) + if self._mri_data is not None: + if self._mri_data.shape != self._base_data.shape or not np.allclose( + self._vox_ras_t, vox_ras_t, rtol=1e-6 + ): + raise ValueError( + "Base image is not aligned to MRI, got " + f"Base shape={self._base_data.shape}, " + f"MRI shape={self._mri_data.shape}, " + f"Base affine={vox_ras_t} and " + f"MRI affine={self._vox_ras_t}, " + "please provide an aligned image or do not use the " + "``subject`` and ``subjects_dir`` arguments" + ) + + self._ras_vox_t = np.linalg.inv(self._vox_ras_t) + self._scan_ras_vox_t = np.linalg.inv(self._vox_scan_ras_t) + self._voxel_sizes = np.array(self._base_data.shape) + self._voxel_ratios = self._voxel_sizes / self._voxel_sizes.min() + + # We need our extents to land the centers of each pixel on the voxel + # number. This code assumes 1mm isotropic... + img_delta = 0.5 + self._img_extents = list( + [ + -img_delta, + self._voxel_sizes[idx[0]] - img_delta, + -img_delta, + self._voxel_sizes[idx[1]] - img_delta, + ] + for idx in self._xy_idx + ) + + if self._subject_dir is not None: + if op.exists(op.join(self._subject_dir, "surf", "lh.seghead")): + self._head = _read_mri_surface( + op.join(self._subject_dir, "surf", "lh.seghead") + ) + assert _frame_to_str[self._head["coord_frame"]] == "mri" + else: + warn( + "`seghead` not found, using marching cubes on base image " + "for head plot, use :ref:`mne.bem.make_scalp_surfaces` " + "to add the scalp surface instead" + ) + self._head = None + + if self._subject_dir is not None: + # allow ?h.pial.T1 if ?h.pial doesn't exist + # end with '' for better file not found error + for img in ("", ".T1", ".T2", ""): + surf_fname = op.join( + self._subject_dir, "surf", "{hemi}" + f".pial{img}" + ) + if op.isfile(surf_fname.format(hemi="lh")): + break + if op.exists(surf_fname.format(hemi="lh")): + self._lh = _read_mri_surface(surf_fname.format(hemi="lh")) + assert _frame_to_str[self._lh["coord_frame"]] == "mri" + self._rh = _read_mri_surface(surf_fname.format(hemi="rh")) + assert _frame_to_str[self._rh["coord_frame"]] == "mri" + else: + warn( + "`pial` surface not found, skipping adding to 3D " + "plot. This indicates the Freesurfer recon-all " + "has not finished or has been modified and " + "these files have been deleted." + ) + self._lh = self._rh = None + + def _plot_images(self): + """Use the MRI or CT to make plots.""" + # Plot sagittal (0), coronal (1) or axial (2) view + self._images = dict( + base=list(), cursor_v=list(), cursor_h=list(), bounds=list() + ) + img_min = np.nanmin(self._base_data) + img_max = np.nanmax(self._base_data) + text_kwargs = dict( + fontsize="medium", + weight="bold", + color="#66CCEE", + family="monospace", + ha="center", + va="center", + path_effects=[ + patheffects.withStroke(linewidth=4, foreground="k", alpha=0.75) + ], + ) + xyz = apply_trans(self._ras_vox_t, self._ras) + for axis in range(3): + plot_x_idx, plot_y_idx = self._xy_idx[axis] + fig = self._figs[axis] + ax = fig.axes[0] + img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T + self._images["base"].append( + ax.imshow( + img_data, + cmap="gray", + aspect="auto", + zorder=1, + vmin=img_min, + vmax=img_max, + ) + ) + img_extent = self._img_extents[axis] # x0, x1, y0, y1 + w, h = np.diff(np.array(img_extent).reshape(2, 2), axis=1)[:, 0] + self._images["bounds"].append( + Rectangle( + img_extent[::2], + w, + h, + edgecolor="w", + facecolor="none", + alpha=0.25, + lw=0.5, + zorder=1.5, + ) + ) + ax.add_patch(self._images["bounds"][-1]) + v_x = (xyz[plot_x_idx],) * 2 + v_y = img_extent[2:4] + self._images["cursor_v"].append( + ax.plot(v_x, v_y, color="lime", linewidth=0.5, alpha=0.5, zorder=8)[0] + ) + h_y = (xyz[plot_y_idx],) * 2 + h_x = img_extent[0:2] + self._images["cursor_h"].append( + ax.plot(h_x, h_y, color="lime", linewidth=0.5, alpha=0.5, zorder=8)[0] + ) + # label axes + self._figs[axis].text(0.5, 0.075, _IMG_LABELS[axis][0], **text_kwargs) + self._figs[axis].text(0.075, 0.5, _IMG_LABELS[axis][1], **text_kwargs) + self._figs[axis].axes[0].axis(img_extent) + self._figs[axis].canvas.mpl_connect("scroll_event", self._on_scroll) + self._figs[axis].canvas.mpl_connect( + "button_release_event", partial(self._on_click, axis=axis) + ) + # add head and brain in mm (convert from m) + if self._head is None: + logger.debug( + "Using marching cubes on the base image for the " + "3D visualization panel" + ) + # in this case, leave in voxel coordinates + rr, tris = _marching_cubes( + np.where(self._base_data < np.quantile(self._base_data, 0.95), 0, 1), + [1], + )[0] + # marching cubes transposes dimensions so flip + rr = apply_trans(self._vox_ras_t, rr[:, ::-1]) + self._renderer.mesh( + *rr.T, + triangles=tris, + color="gray", + opacity=0.2, + reset_camera=False, + render=False, + ) + self._renderer.set_camera(focalpoint=rr.mean(axis=0)) + else: + self._renderer.mesh( + *self._head["rr"].T * 1000, + triangles=self._head["tris"], + color="gray", + opacity=0.2, + reset_camera=False, + render=False, + ) + if self._lh is not None and self._rh is not None: + self._renderer.mesh( + *self._lh["rr"].T * 1000, + triangles=self._lh["tris"], + color="white", + opacity=0.2, + reset_camera=False, + render=False, + ) + self._renderer.mesh( + *self._rh["rr"].T * 1000, + triangles=self._rh["tris"], + color="white", + opacity=0.2, + reset_camera=False, + render=False, + ) + self._renderer.set_camera( + azimuth=90, elevation=90, distance=300, focalpoint=tuple(self._ras) + ) + # update plots + self._draw() + self._renderer._update() + + def _configure_status_bar(self, hbox=None): + """Make a bar at the bottom with information in it.""" + hbox = QHBoxLayout() if hbox is None else hbox + + self._intensity_label = QLabel("") # update later + hbox.addWidget(self._intensity_label) + + VOX_label = QLabel("VOX =") + self._VOX_textbox = QLineEdit("") # update later + self._VOX_textbox.setMaximumHeight(25) + self._VOX_textbox.setMinimumWidth(75) + self._VOX_textbox.focusOutEvent = self._update_VOX + hbox.addWidget(VOX_label) + hbox.addWidget(self._VOX_textbox) + + RAS_label = QLabel("RAS =") + self._RAS_textbox = QLineEdit("") # update later + self._RAS_textbox.setMaximumHeight(25) + self._RAS_textbox.setMinimumWidth(150) + self._RAS_textbox.focusOutEvent = self._update_RAS + hbox.addWidget(RAS_label) + hbox.addWidget(self._RAS_textbox) + self._update_moved() # update text now + return hbox + + def _update_camera(self, render=False): + """Update the camera position.""" + self._renderer.set_camera( + # needs fix, distance moves when focal point updates + distance=self._renderer.plotter.camera.distance * 0.9, + focalpoint=tuple(self._ras), + reset_camera=False, + ) + + def _on_scroll(self, event): + """Process mouse scroll wheel event to zoom.""" + self._zoom(np.sign(event.step), draw=True) + + def _zoom(self, sign=1, draw=False): + """Zoom in on the image.""" + delta = _ZOOM_STEP_SIZE * sign + for axis, fig in enumerate(self._figs): + xcur = self._images["cursor_v"][axis].get_xdata()[0] + ycur = self._images["cursor_h"][axis].get_ydata()[0] + rx, ry = [self._voxel_ratios[idx] for idx in self._xy_idx[axis]] + xmin, xmax = fig.axes[0].get_xlim() + ymin, ymax = fig.axes[0].get_ylim() + xmid = (xmin + xmax) / 2 + ymid = (ymin + ymax) / 2 + if sign == 1: # may need to shift if zooming in + if abs(xmid - xcur) > delta / 2 * rx: + xmid += delta * np.sign(xcur - xmid) * rx + if abs(ymid - ycur) > delta / 2 * ry: + ymid += delta * np.sign(ycur - ymid) * ry + xwidth = (xmax - xmin) / 2 - delta * rx + ywidth = (ymax - ymin) / 2 - delta * ry + if xwidth <= 0 or ywidth <= 0: + return + fig.axes[0].set_xlim(xmid - xwidth, xmid + xwidth) + fig.axes[0].set_ylim(ymid - ywidth, ymid + ywidth) + if draw: + fig.canvas.draw() + + @Slot() + def _update_RAS(self, event): + """Interpret user input to the RAS textbox.""" + ras = self._convert_text(self._RAS_textbox.text(), "ras") + if ras is not None: + self._set_ras(ras) + + @Slot() + def _update_VOX(self, event): + """Interpret user input to the RAS textbox.""" + ras = self._convert_text(self._VOX_textbox.text(), "vox") + if ras is not None: + self._set_ras(ras) + + def _convert_text(self, text, text_kind): + text = text.replace("\n", "") + vals = text.split(",") + if len(vals) != 3: + vals = text.split(" ") # spaces also okay as in freesurfer + vals = [var.lstrip().rstrip() for var in vals] + try: + vals = np.array([float(var) for var in vals]).reshape(3) + except Exception: + self._update_moved() # resets RAS label + return + if text_kind == "vox": + vox = vals + ras = apply_trans(self._vox_ras_t, vox) + else: + assert text_kind == "ras" + ras = vals + vox = apply_trans(self._ras_vox_t, ras) + wrong_size = any( + var < 0 or var > n - 1 for var, n in zip(vox, self._voxel_sizes) + ) + if wrong_size: + self._update_moved() # resets RAS label + return + return ras + + @property + def _ras(self): + return self._ras_safe + + def set_RAS(self, ras): + """Set the crosshairs to a given RAS. + + Parameters + ---------- + ras : array-like + The right-anterior-superior scanner RAS coordinate. + """ + self._set_ras(ras) + + def _set_ras(self, ras, update_plots=True): + ras = np.asarray(ras, dtype=float) + assert ras.shape == (3,) + msg = ", ".join(f"{x:0.2f}" for x in ras) + logger.debug(f"Trying RAS: ({msg}) mm") + # clip to valid + vox = apply_trans(self._ras_vox_t, ras) + vox = np.array( + [np.clip(d, 0, self._voxel_sizes[ii] - 1) for ii, d in enumerate(vox)] + ) + # transform back, make write-only + self._ras_safe = apply_trans(self._vox_ras_t, vox) + self._ras_safe.flags["WRITEABLE"] = False + msg = ", ".join(f"{x:0.2f}" for x in self._ras_safe) + logger.debug(f"Setting RAS: ({msg}) mm") + if update_plots: + self._move_cursors_to_pos() + + def set_vox(self, vox): + """Set the crosshairs to a given voxel coordinate. + + Parameters + ---------- + vox : array-like + The voxel coordinate. + """ + self._set_ras(apply_trans(self._vox_ras_t, vox)) + + @property + def _vox(self): + return apply_trans(self._ras_vox_t, self._ras) + + @property + def _current_slice(self): + return self._vox.round().astype(int) + + def _draw(self, axis=None): + """Update the figures with a draw call.""" + for axis in range(3) if axis is None else [axis]: + self._figs[axis].canvas.draw() + + def _update_base_images(self, axis=None, draw=False): + """Update the base images.""" + for axis in range(3) if axis is None else [axis]: + img_data = np.take(self._base_data, self._current_slice[axis], axis=axis).T + self._images["base"][axis].set_data(img_data) + if draw: + self._draw(axis) + + def _update_images(self, axis=None, draw=True): + """Update CT and channel images when general changes happen.""" + self._update_base_images(axis=axis) + if draw: + self._draw(axis) + + def _move_cursors_to_pos(self): + """Move the cursors to a position.""" + for axis in range(3): + x, y = self._vox[list(self._xy_idx[axis])] + self._images["cursor_v"][axis].set_xdata([x, x]) + self._images["cursor_h"][axis].set_ydata([y, y]) + self._update_images(draw=True) + self._update_moved() + + def _show_help(self): + """Show the help menu.""" + QMessageBox.information( + self, + "Help", + "Help:\n" + "'+'/'-': zoom\nleft/right arrow: left/right\n" + "up/down arrow: superior/inferior\n" + "left angle bracket/right angle bracket: anterior/posterior", + ) + + def keyPressEvent(self, event): + """Execute functions when the user presses a key.""" + if event.key() == "escape": + self.close() + + elif event.key() == QtCore.Qt.Key_Return: + for widget in (self._RAS_textbox, self._VOX_textbox): + if widget.hasFocus(): + widget.clearFocus() + self.setFocus() # removing focus calls focus out event + + elif event.text() == "h": + self._show_help() + + elif event.text() in ("=", "+", "-"): + self._zoom(sign=-2 * (event.text() == "-") + 1, draw=True) + + # Changing slices + elif event.key() in ( + QtCore.Qt.Key_Up, + QtCore.Qt.Key_Down, + QtCore.Qt.Key_Left, + QtCore.Qt.Key_Right, + QtCore.Qt.Key_Comma, + QtCore.Qt.Key_Period, + QtCore.Qt.Key_PageUp, + QtCore.Qt.Key_PageDown, + ): + ras = np.array(self._ras) + if event.key() in (QtCore.Qt.Key_Up, QtCore.Qt.Key_Down): + ras[2] += 2 * (event.key() == QtCore.Qt.Key_Up) - 1 + elif event.key() in (QtCore.Qt.Key_Left, QtCore.Qt.Key_Right): + ras[0] += 2 * (event.key() == QtCore.Qt.Key_Right) - 1 + else: + ras[1] += ( + 2 + * ( + event.key() == QtCore.Qt.Key_PageUp + or event.key() == QtCore.Qt.Key_Period + ) + - 1 + ) + self._set_ras(ras) + + def _on_click(self, event, axis): + """Move to view on MRI and CT on click.""" + if event.inaxes is self._figs[axis].axes[0]: + # Data coordinates are voxel coordinates + pos = (event.xdata, event.ydata) + logger.debug(f'Clicked {"XYZ"[axis]} ({axis}) axis at pos {pos}') + xyz = self._vox + xyz[list(self._xy_idx[axis])] = pos + logger.debug(f"Using voxel {list(xyz)}") + ras = apply_trans(self._vox_ras_t, xyz) + self._set_ras(ras) + + def _update_moved(self): + """Update when cursor position changes.""" + self._RAS_textbox.setText("{:.2f}, {:.2f}, {:.2f}".format(*self._ras)) + self._VOX_textbox.setText("{:3d}, {:3d}, {:3d}".format(*self._current_slice)) + self._intensity_label.setText( + "intensity = {:.2f}".format(self._base_data[tuple(self._current_slice)]) + ) + + @safe_event + def closeEvent(self, event): + """Clean up upon closing the window.""" + try: + self._renderer.plotter.close() + except AttributeError: + pass + self.close() diff --git a/mne_gui_addons/_ieeg_locate.py b/mne_gui_addons/_ieeg_locate.py new file mode 100644 index 0000000..d9ec82a --- /dev/null +++ b/mne_gui_addons/_ieeg_locate.py @@ -0,0 +1,885 @@ +# -*- coding: utf-8 -*- +"""Intracranial elecrode localization GUI for finding contact locations.""" + +# Authors: Alex Rockhill +# +# License: BSD (3-clause) + +import numpy as np +import platform + +from scipy.ndimage import maximum_filter + +from qtpy import QtCore, QtGui +from qtpy.QtCore import Slot, Signal +from qtpy.QtWidgets import ( + QVBoxLayout, + QHBoxLayout, + QLabel, + QMessageBox, + QWidget, + QAbstractItemView, + QListView, + QSlider, + QPushButton, + QComboBox, +) + +from matplotlib.colors import LinearSegmentedColormap + +from ._core import SliceBrowser +from mne.channels import make_dig_montage +from mne.surface import _voxel_neighbors +from mne.transforms import apply_trans, _get_trans, invert_transform +from mne.utils import logger, _validate_type, verbose +from mne import pick_types + +_CH_PLOT_SIZE = 1024 +_RADIUS_SCALAR = 0.4 +_TUBE_SCALAR = 0.1 +_BOLT_SCALAR = 30 # mm +_CH_MENU_WIDTH = 30 if platform.system() == "Windows" else 10 + +# 20 colors generated to be evenly spaced in a cube, worked better than +# matplotlib color cycle +_UNIQUE_COLORS = [ + (0.1, 0.42, 0.43), + (0.9, 0.34, 0.62), + (0.47, 0.51, 0.3), + (0.47, 0.55, 0.99), + (0.79, 0.68, 0.06), + (0.34, 0.74, 0.05), + (0.58, 0.87, 0.13), + (0.86, 0.98, 0.4), + (0.92, 0.91, 0.66), + (0.77, 0.38, 0.34), + (0.9, 0.37, 0.1), + (0.2, 0.62, 0.9), + (0.22, 0.65, 0.64), + (0.14, 0.94, 0.8), + (0.34, 0.31, 0.68), + (0.59, 0.28, 0.74), + (0.46, 0.19, 0.94), + (0.37, 0.93, 0.7), + (0.56, 0.86, 0.55), + (0.67, 0.69, 0.44), +] +_N_COLORS = len(_UNIQUE_COLORS) +_CMAP = LinearSegmentedColormap.from_list("ch_colors", _UNIQUE_COLORS, N=_N_COLORS) + + +class ComboBox(QComboBox): + """Dropdown menu that emits a click when popped up.""" + + clicked = Signal() + + def showPopup(self): + """Override show popup method to emit click.""" + self.clicked.emit() + super(ComboBox, self).showPopup() + + +class IntracranialElectrodeLocator(SliceBrowser): + """Locate electrode contacts using a coregistered MRI and CT.""" + + def __init__( + self, + info, + trans, + base_image, + subject=None, + subjects_dir=None, + groups=None, + show=True, + verbose=None, + ): + """GUI for locating intracranial electrodes. + + .. note:: Images will be displayed using orientation information + obtained from the image header. Images will be resampled to + dimensions [256, 256, 256] for display. + """ + if not info.ch_names: + raise ValueError("No channels found in `info` to locate") + + # store info for modification + self._info = info + self._seeg_idx = pick_types(self._info, meg=False, seeg=True) + self._verbose = verbose + + # channel plotting default parameters + self._ch_alpha = 0.5 + self._radius = int(_CH_PLOT_SIZE // 100) # starting 1/100 of image + + # initialize channel data + self._ch_index = 0 + # load data, apply trans + self._head_mri_t = _get_trans(trans, "head", "mri")[0] + self._mri_head_t = invert_transform(self._head_mri_t) + + # ensure channel positions in head + montage = info.get_montage() + if montage and montage.get_positions()["coord_frame"] != "head": + raise RuntimeError( + "Channel positions in the ``info`` object must " + 'be in the "head" coordinate frame.' + ) + + # load channels, convert from m to mm + self._chs = { + name: apply_trans(self._head_mri_t, ch["loc"][:3]) * 1000 + for name, ch in zip(info.ch_names, info["chs"]) + } + self._ch_names = list(self._chs.keys()) + self._group_channels(groups) + + # Initialize GUI + super(IntracranialElectrodeLocator, self).__init__( + base_image=base_image, subject=subject, subjects_dir=subjects_dir + ) + + # set current position as current contact location if exists + if not np.isnan(self._chs[self._ch_names[self._ch_index]]).any(): + self._set_ras(self._chs[self._ch_names[self._ch_index]], update_plots=False) + + # add plots of contacts on top + self._plot_ch_images() + + # Add lines + self._lines = dict() + self._lines_2D = dict() + for group in set(self._groups.values()): + self._update_lines(group) + + # ready for user + self._move_cursors_to_pos() + self._ch_list.setFocus() # always focus on list + + if show: + self.show() + + def _configure_ui(self): + # data is loaded for an abstract base image, associate with ct + self._ct_data = self._base_data + self._images["ct"] = self._images["base"] + self._ct_maxima = None # don't compute until turned on + + toolbar = self._configure_toolbar() + slider_bar = self._configure_sliders() + status_bar = self._configure_status_bar() + self._ch_list = self._configure_channel_sidebar() # need for updating + + plot_layout = QHBoxLayout() + plot_layout.addLayout(self._plt_grid) + plot_layout.addWidget(self._ch_list) + + main_vbox = QVBoxLayout() + main_vbox.addLayout(toolbar) + main_vbox.addLayout(slider_bar) + main_vbox.addLayout(plot_layout) + main_vbox.addLayout(status_bar) + + central_widget = QWidget() + central_widget.setLayout(main_vbox) + self.setCentralWidget(central_widget) + + def _configure_channel_sidebar(self): + """Configure the sidebar to select channels/contacts.""" + ch_list = QListView() + ch_list.setSelectionMode(QAbstractItemView.SingleSelection) + max_ch_name_len = max([len(name) for name in self._chs]) + ch_list.setMinimumWidth(max_ch_name_len * _CH_MENU_WIDTH) + ch_list.setMaximumWidth(max_ch_name_len * _CH_MENU_WIDTH) + self._ch_list_model = QtGui.QStandardItemModel(ch_list) + for name in self._ch_names: + self._ch_list_model.appendRow(QtGui.QStandardItem(name)) + self._color_list_item(name=name) + ch_list.setModel(self._ch_list_model) + ch_list.clicked.connect(self._go_to_ch) + ch_list.setCurrentIndex(self._ch_list_model.index(self._ch_index, 0)) + ch_list.keyPressEvent = self.keyPressEvent + return ch_list + + def _make_ch_image(self, axis, proj=False): + """Make a plot to display the channel locations.""" + # Make channel data higher resolution so it looks better. + ch_image = np.zeros((_CH_PLOT_SIZE, _CH_PLOT_SIZE)) * np.nan + vxyz = self._voxel_sizes + + def color_ch_radius(ch_image, xf, yf, group, radius): + # Take the fraction across each dimension of the RAS + # coordinates converted to xyz and put a circle in that + # position in this larger resolution image + ex, ey = np.round(np.array([xf, yf]) * _CH_PLOT_SIZE).astype(int) + ii = np.arange(-radius, radius + 1) + ii_sq = ii * ii + idx = np.where(ii_sq + ii_sq[:, np.newaxis] < radius * radius) + # negative y because y axis is inverted + ch_image[-(ey + ii[idx[1]]), ex + ii[idx[0]]] = group + return ch_image + + for name, ras in self._chs.items(): + # move from middle-centered (half coords positive, half negative) + # to bottom-left corner centered (all coords positive). + if np.isnan(ras).any(): + continue + xyz = apply_trans(self._ras_vox_t, ras) + # check if closest to that voxel + dist = np.linalg.norm(xyz - self._current_slice) + if proj or dist < self._radius: + group = self._groups[name] + r = ( + self._radius + if proj + else self._radius - np.round(abs(dist)).astype(int) + ) + xf, yf = (xyz / vxyz)[list(self._xy_idx[axis])] + ch_image = color_ch_radius(ch_image, xf, yf, group, r) + return ch_image + + @verbose + def _save_ch_coords(self, info=None, verbose=None): + """Save the location of the electrode contacts.""" + logger.info("Saving channel positions to `info`") + if info is None: + info = self._info + montage = info.get_montage() + montage_kwargs = ( + montage.get_positions() + if montage + else dict(ch_pos=dict(), coord_frame="head") + ) + for ch in info["chs"]: + # surface RAS-> head and mm->m + montage_kwargs["ch_pos"][ch["ch_name"]] = apply_trans( + self._mri_head_t, self._chs[ch["ch_name"]].copy() / 1000 + ) + info.set_montage(make_dig_montage(**montage_kwargs)) + + def _plot_ch_images(self): + img_delta = 0.5 + ch_deltas = list( + img_delta * (self._voxel_sizes[ii] / _CH_PLOT_SIZE) for ii in range(3) + ) + self._ch_extents = list( + [ + -ch_delta, + self._voxel_sizes[idx[0]] - ch_delta, + -ch_delta, + self._voxel_sizes[idx[1]] - ch_delta, + ] + for idx, ch_delta in zip(self._xy_idx, ch_deltas) + ) + self._images["chs"] = list() + for axis in range(3): + fig = self._figs[axis] + ax = fig.axes[0] + self._images["chs"].append( + ax.imshow( + self._make_ch_image(axis), + aspect="auto", + extent=self._ch_extents[axis], + zorder=3, + cmap=_CMAP, + alpha=self._ch_alpha, + vmin=0, + vmax=_N_COLORS, + ) + ) + self._3d_chs = dict() + for name in self._chs: + self._plot_3d_ch(name) + + def _plot_3d_ch(self, name, render=False): + """Plot a single 3D channel.""" + if name in self._3d_chs: + self._renderer.plotter.remove_actor(self._3d_chs.pop(name), render=False) + if not any(np.isnan(self._chs[name])): + self._3d_chs[name] = self._renderer.sphere( + tuple(self._chs[name]), + scale=1, + color=_CMAP(self._groups[name])[:3], + opacity=self._ch_alpha, + )[0] + # The actor scale is managed differently than the glyph scale + # in order not to recreate objects, we use the actor scale + self._3d_chs[name].SetOrigin(self._chs[name]) + self._3d_chs[name].SetScale(self._radius * _RADIUS_SCALAR) + if render: + self._renderer._update() + + def _configure_toolbar(self): + """Make a bar with buttons for user interactions.""" + hbox = QHBoxLayout() + + help_button = QPushButton("Help") + help_button.released.connect(self._show_help) + hbox.addWidget(help_button) + + hbox.addStretch(8) + + hbox.addWidget(QLabel("Snap to Center")) + self._snap_button = QPushButton("Off") + self._snap_button.setMaximumWidth(25) # not too big + hbox.addWidget(self._snap_button) + self._snap_button.released.connect(self._toggle_snap) + self._toggle_snap() # turn on to start + + hbox.addStretch(1) + + self._toggle_brain_button = QPushButton("Show Brain") + self._toggle_brain_button.released.connect(self._toggle_show_brain) + hbox.addWidget(self._toggle_brain_button) + + hbox.addStretch(1) + + mark_button = QPushButton("Mark") + hbox.addWidget(mark_button) + mark_button.released.connect(self.mark_channel) + + remove_button = QPushButton("Remove") + hbox.addWidget(remove_button) + remove_button.released.connect(self.remove_channel) + + self._group_selector = ComboBox() + group_model = self._group_selector.model() + + for i in range(_N_COLORS): + self._group_selector.addItem(" ") + color = QtGui.QColor() + color.setRgb(*(255 * np.array(_CMAP(i))).round().astype(int)) + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + group_model.setData( + group_model.index(i, 0), brush, QtCore.Qt.BackgroundRole + ) + self._group_selector.clicked.connect(self._select_group) + self._group_selector.currentIndexChanged.connect(self._select_group) + hbox.addWidget(self._group_selector) + + # update background color for current selection + self._update_group() + + return hbox + + def _configure_sliders(self): + """Make a bar with sliders on it.""" + + def make_label(name): + label = QLabel(name) + label.setAlignment(QtCore.Qt.AlignCenter) + return label + + def make_slider(smin, smax, sval, sfun=None): + slider = QSlider(QtCore.Qt.Horizontal) + slider.setMinimum(int(round(smin))) + slider.setMaximum(int(round(smax))) + slider.setValue(int(round(sval))) + slider.setTracking(False) # only update on release + if sfun is not None: + slider.valueChanged.connect(sfun) + slider.keyPressEvent = self.keyPressEvent + return slider + + slider_hbox = QHBoxLayout() + + ch_vbox = QVBoxLayout() + ch_vbox.addWidget(make_label("ch alpha")) + ch_vbox.addWidget(make_label("ch radius")) + slider_hbox.addLayout(ch_vbox) + + ch_slider_vbox = QVBoxLayout() + self._alpha_slider = make_slider( + 0, 100, self._ch_alpha * 100, self._update_ch_alpha + ) + ch_plot_max = _CH_PLOT_SIZE // 50 # max 1 / 50 of plot size + ch_slider_vbox.addWidget(self._alpha_slider) + self._radius_slider = make_slider( + 0, ch_plot_max, self._radius, self._update_radius + ) + ch_slider_vbox.addWidget(self._radius_slider) + slider_hbox.addLayout(ch_slider_vbox) + + ct_vbox = QVBoxLayout() + ct_vbox.addWidget(make_label("CT min")) + ct_vbox.addWidget(make_label("CT max")) + slider_hbox.addLayout(ct_vbox) + + ct_slider_vbox = QVBoxLayout() + ct_min = int(round(np.nanmin(self._ct_data))) + ct_max = int(round(np.nanmax(self._ct_data))) + self._ct_min_slider = make_slider(ct_min, ct_max, ct_min, self._update_ct_scale) + ct_slider_vbox.addWidget(self._ct_min_slider) + self._ct_max_slider = make_slider(ct_min, ct_max, ct_max, self._update_ct_scale) + ct_slider_vbox.addWidget(self._ct_max_slider) + slider_hbox.addLayout(ct_slider_vbox) + return slider_hbox + + def _configure_status_bar(self, hbox=None): + hbox = QHBoxLayout() if hbox is None else hbox + + hbox.addStretch(3) + + self._toggle_show_mip_button = QPushButton("Show Max Intensity Proj") + self._toggle_show_mip_button.released.connect(self._toggle_show_mip) + hbox.addWidget(self._toggle_show_mip_button) + + self._toggle_show_max_button = QPushButton("Show Maxima") + self._toggle_show_max_button.released.connect(self._toggle_show_max) + hbox.addWidget(self._toggle_show_max_button) + + self._intensity_label = QLabel("") # update later + hbox.addWidget(self._intensity_label) + + # add SliceBrowser navigation items + super(IntracranialElectrodeLocator, self)._configure_status_bar(hbox=hbox) + return hbox + + def _move_cursors_to_pos(self): + super(IntracranialElectrodeLocator, self)._move_cursors_to_pos() + + self._ch_list.setFocus() # remove focus from text edit + + def _group_channels(self, groups): + """Automatically find a group based on the name of the channel.""" + if groups is not None: + for name in self._ch_names: + if name not in groups: + raise ValueError(f"{name} not found in ``groups``") + _validate_type(groups[name], (float, int), f"groups[{name}]") + self.groups = groups + else: + i = 0 + self._groups = dict() + base_names = dict() + for name in self._ch_names: + # strip all numbers from the name + base_name = "".join( + [ + letter + for letter in name + if not letter.isdigit() and letter != " " + ] + ) + if base_name in base_names: + # look up group number by base name + self._groups[name] = base_names[base_name] + else: + self._groups[name] = i + base_names[base_name] = i + i += 1 + + def _update_lines(self, group, only_2D=False): + """Draw lines that connect the points in a group.""" + if group in self._lines_2D: # remove existing 2D lines first + for line in self._lines_2D[group]: + line.remove() + self._lines_2D.pop(group) + if only_2D: # if not in projection, don't add 2D lines + if self._toggle_show_mip_button.text() == "Show Max Intensity Proj": + return + elif group in self._lines: # if updating 3D, remove first + self._renderer.plotter.remove_actor(self._lines[group], render=False) + pos = np.array( + [ + self._chs[ch] + for i, ch in enumerate(self._ch_names) + if self._groups[ch] == group + and i in self._seeg_idx + and not np.isnan(self._chs[ch]).any() + ] + ) + if len(pos) < 2: # not enough points for line + return + # first, the insertion will be the point farthest from the origin + # brains are a longer posterior-anterior, scale for this (80%) + insert_idx = np.argmax(np.linalg.norm(pos * np.array([1, 0.8, 1]), axis=1)) + # second, find the farthest point from the insertion + target_idx = np.argmax(np.linalg.norm(pos[insert_idx] - pos, axis=1)) + # third, make a unit vector and to add to the insertion for the bolt + elec_v = pos[insert_idx] - pos[target_idx] + elec_v /= np.linalg.norm(elec_v) + if not only_2D: + self._lines[group] = self._renderer.tube( + [pos[target_idx]], + [pos[insert_idx] + elec_v * _BOLT_SCALAR], + radius=self._radius * _TUBE_SCALAR, + color=_CMAP(group)[:3], + )[0] + if self._toggle_show_mip_button.text() == "Hide Max Intensity Proj": + # add 2D lines on each slice plot if in max intensity projection + target_vox = apply_trans(self._ras_vox_t, pos[target_idx]) + insert_vox = apply_trans( + self._ras_vox_t, pos[insert_idx] + elec_v * _BOLT_SCALAR + ) + lines_2D = list() + for axis in range(3): + x, y = self._xy_idx[axis] + lines_2D.append( + self._figs[axis] + .axes[0] + .plot( + [target_vox[x], insert_vox[x]], + [target_vox[y], insert_vox[y]], + color=_CMAP(group), + linewidth=0.25, + zorder=7, + )[0] + ) + self._lines_2D[group] = lines_2D + + def _select_group(self): + """Change the group label to the selection.""" + group = self._group_selector.currentIndex() + self._groups[self._ch_names[self._ch_index]] = group + # color differently if found already + self._color_list_item(self._ch_names[self._ch_index]) + self._update_group() + + def _update_group(self): + """Set background for closed group menu.""" + group = self._group_selector.currentIndex() + rgb = (255 * np.array(_CMAP(group))).round().astype(int) + self._group_selector.setStyleSheet( + "background-color: rgb({:d},{:d},{:d})".format(*rgb) + ) + self._group_selector.update() + + def _update_ch_selection(self): + """Update which channel is selected.""" + name = self._ch_names[self._ch_index] + self._ch_list.setCurrentIndex(self._ch_list_model.index(self._ch_index, 0)) + self._group_selector.setCurrentIndex(self._groups[name]) + self._update_group() + if not np.isnan(self._chs[name]).any(): + self._set_ras(self._chs[name]) + self._update_camera(render=True) + self._draw() + + def _go_to_ch(self, index): + """Change current channel to the item selected.""" + self._ch_index = index.row() + self._update_ch_selection() + + @Slot() + def _next_ch(self): + """Increment the current channel selection index.""" + self._ch_index = (self._ch_index + 1) % len(self._ch_names) + self._update_ch_selection() + + def _color_list_item(self, name=None): + """Color the item in the view list for easy id of marked channels.""" + name = self._ch_names[self._ch_index] if name is None else name + color = QtGui.QColor("white") + if not np.isnan(self._chs[name]).any(): + group = self._groups[name] + color.setRgb(*[int(c * 255) for c in _CMAP(group)]) + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + self._ch_list_model.setData( + self._ch_list_model.index(self._ch_names.index(name), 0), + brush, + QtCore.Qt.BackgroundRole, + ) + # color text black + color = QtGui.QColor("black") + brush = QtGui.QBrush(color) + brush.setStyle(QtCore.Qt.SolidPattern) + self._ch_list_model.setData( + self._ch_list_model.index(self._ch_names.index(name), 0), + brush, + QtCore.Qt.ForegroundRole, + ) + + @Slot() + def _toggle_snap(self): + """Toggle snapping the contact location to the center of mass.""" + if self._snap_button.text() == "Off": + self._snap_button.setText("On") + self._snap_button.setStyleSheet("background-color: green") + else: # text == 'On', turn off + self._snap_button.setText("Off") + self._snap_button.setStyleSheet("background-color: red") + + @Slot() + def mark_channel(self, ch=None): + """Mark a channel as being located at the crosshair. + + Parameters + ---------- + ch : str + The channel name. If ``None``, the current channel + is marked. + """ + if ch is not None and ch not in self._ch_names: + raise ValueError(f"Channel {ch} not found") + name = self._ch_names[ + self._ch_index if ch is None else self._ch_names.index(ch) + ] + if self._snap_button.text() == "Off": + self._chs[name][:] = self._ras + else: + shape = np.mean(self._voxel_sizes) # Freesurfer shape (256) + voxels_max = int( + 4 / 3 * np.pi * (shape * self._radius / _CH_PLOT_SIZE) ** 3 + ) + neighbors = _voxel_neighbors( + self._vox, + self._ct_data, + thresh=0.5, + voxels_max=voxels_max, + use_relative=True, + ) + self._chs[name][:] = apply_trans( # to surface RAS + self._vox_ras_t, np.array(list(neighbors)).mean(axis=0) + ) + self._color_list_item() + self._update_lines(self._groups[name]) + self._update_ch_images(draw=True) + self._plot_3d_ch(name, render=True) + self._save_ch_coords() + self._next_ch() + self._ch_list.setFocus() + + @Slot() + def remove_channel(self, ch=None): + """Remove the location data for the current channel. + + Parameters + ---------- + ch : str + The channel name. If ``None``, the current channel + is removed. + """ + if ch is not None and ch not in self._ch_names: + raise ValueError(f"Channel {ch} not found") + name = self._ch_names[ + self._ch_index if ch is None else self._ch_names.index(ch) + ] + self._chs[name] *= np.nan + self._color_list_item() + self._save_ch_coords() + self._update_lines(self._groups[name]) + self._update_ch_images(draw=True) + self._plot_3d_ch(name, render=True) + self._next_ch() + self._ch_list.setFocus() + + def _update_ch_images(self, axis=None, draw=False): + """Update the channel image(s).""" + for axis in range(3) if axis is None else [axis]: + self._images["chs"][axis].set_data(self._make_ch_image(axis)) + if self._toggle_show_mip_button.text() == "Hide Max Intensity Proj": + self._images["mip_chs"][axis].set_data( + self._make_ch_image(axis, proj=True) + ) + if draw: + self._draw(axis) + + def _update_ct_images(self, axis=None, draw=False): + """Update the CT image(s).""" + for axis in range(3) if axis is None else [axis]: + ct_data = np.take(self._ct_data, self._current_slice[axis], axis=axis).T + # Threshold the CT so only bright objects (electrodes) are visible + ct_data[ct_data < self._ct_min_slider.value()] = np.nan + ct_data[ct_data > self._ct_max_slider.value()] = np.nan + self._images["ct"][axis].set_data(ct_data) + if "local_max" in self._images: + ct_max_data = np.take( + self._ct_maxima, self._current_slice[axis], axis=axis + ).T + self._images["local_max"][axis].set_data(ct_max_data) + if draw: + self._draw(axis) + + def _update_mri_images(self, axis=None, draw=False): + """Update the CT image(s).""" + if "mri" in self._images: + for axis in range(3) if axis is None else [axis]: + self._images["mri"][axis].set_data( + np.take(self._mri_data, self._current_slice[axis], axis=axis).T + ) + if draw: + self._draw(axis) + + def _update_images(self, axis=None, draw=True): + """Update CT and channel images when general changes happen.""" + self._update_ch_images(axis=axis) + self._update_mri_images(axis=axis) + super()._update_images() + + def _update_ct_scale(self): + """Update CT min slider value.""" + new_min = self._ct_min_slider.value() + new_max = self._ct_max_slider.value() + # handle inversions + self._ct_min_slider.setValue(min([new_min, new_max])) + self._ct_max_slider.setValue(max([new_min, new_max])) + self._update_ct_images(draw=True) + + def _update_radius(self): + """Update channel plot radius.""" + self._radius = np.round(self._radius_slider.value()).astype(int) + if self._toggle_show_max_button.text() == "Hide Maxima": + self._update_ct_maxima() + self._update_ct_images() + else: + self._ct_maxima = None # signals ct max is out-of-date + self._update_ch_images(draw=True) + for name, actor in self._3d_chs.items(): + if not np.isnan(self._chs[name]).any(): + actor.SetOrigin(self._chs[name]) + actor.SetScale(self._radius * _RADIUS_SCALAR) + self._renderer._update() + self._ch_list.setFocus() # remove focus from 3d plotter + + def _update_ch_alpha(self): + """Update channel plot alpha.""" + self._ch_alpha = self._alpha_slider.value() / 100 + for axis in range(3): + self._images["chs"][axis].set_alpha(self._ch_alpha) + self._draw() + for actor in self._3d_chs.values(): + actor.GetProperty().SetOpacity(self._ch_alpha) + self._renderer._update() + self._ch_list.setFocus() # remove focus from 3d plotter + + def _show_help(self): + """Show the help menu.""" + QMessageBox.information( + self, + "Help", + "Help:\n'm': mark channel location\n" + "'r': remove channel location\n" + "'b': toggle viewing of brain in T1\n" + "'+'/'-': zoom\nleft/right arrow: left/right\n" + "up/down arrow: superior/inferior\n" + "left angle bracket/right angle bracket: anterior/posterior", + ) + + def _update_ct_maxima(self): + """Compute the maximum voxels based on the current radius.""" + self._ct_maxima = ( + maximum_filter(self._ct_data, (self._radius,) * 3) == self._ct_data + ) + self._ct_maxima[self._ct_data <= np.median(self._ct_data)] = False + self._ct_maxima = np.where(self._ct_maxima, 1, np.nan) # transparent + + def _toggle_show_mip(self): + """Toggle whether the maximum-intensity projection is shown.""" + if self._toggle_show_mip_button.text() == "Show Max Intensity Proj": + self._toggle_show_mip_button.setText("Hide Max Intensity Proj") + self._images["mip"] = list() + self._images["mip_chs"] = list() + ct_min, ct_max = np.nanmin(self._ct_data), np.nanmax(self._ct_data) + for axis in range(3): + ct_mip_data = np.max(self._ct_data, axis=axis).T + self._images["mip"].append( + self._figs[axis] + .axes[0] + .imshow( + ct_mip_data, + cmap="gray", + aspect="auto", + vmin=ct_min, + vmax=ct_max, + zorder=5, + ) + ) + # add circles for each channel + xs, ys, colors = list(), list(), list() + for name, ras in self._chs.items(): + xyz = self._vox + xs.append(xyz[self._xy_idx[axis][0]]) + ys.append(xyz[self._xy_idx[axis][1]]) + colors.append(_CMAP(self._groups[name])) + self._images["mip_chs"].append( + self._figs[axis] + .axes[0] + .imshow( + self._make_ch_image(axis, proj=True), + aspect="auto", + extent=self._ch_extents[axis], + zorder=6, + cmap=_CMAP, + alpha=1, + vmin=0, + vmax=_N_COLORS, + ) + ) + for group in set(self._groups.values()): + self._update_lines(group, only_2D=True) + else: + for img in self._images["mip"] + self._images["mip_chs"]: + img.remove() + self._images.pop("mip") + self._images.pop("mip_chs") + self._toggle_show_mip_button.setText("Show Max Intensity Proj") + for group in set(self._groups.values()): # remove lines + self._update_lines(group, only_2D=True) + self._draw() + + def _toggle_show_max(self): + """Toggle whether to color local maxima differently.""" + if self._toggle_show_max_button.text() == "Show Maxima": + self._toggle_show_max_button.setText("Hide Maxima") + # happens on initiation or if the radius is changed with it off + if self._ct_maxima is None: # otherwise don't recompute + self._update_ct_maxima() + self._images["local_max"] = list() + for axis in range(3): + ct_max_data = np.take( + self._ct_maxima, self._current_slice[axis], axis=axis + ).T + self._images["local_max"].append( + self._figs[axis] + .axes[0] + .imshow( + ct_max_data, + cmap="autumn", + aspect="auto", + vmin=0, + vmax=1, + zorder=4, + ) + ) + else: + for img in self._images["local_max"]: + img.remove() + self._images.pop("local_max") + self._toggle_show_max_button.setText("Show Maxima") + self._draw() + + def _toggle_show_brain(self): + """Toggle whether the brain/MRI is being shown.""" + if "mri" in self._images: + for img in self._images["mri"]: + img.remove() + self._images.pop("mri") + self._toggle_brain_button.setText("Show Brain") + else: + self._images["mri"] = list() + for axis in range(3): + mri_data = np.take( + self._mri_data, self._current_slice[axis], axis=axis + ).T + self._images["mri"].append( + self._figs[axis] + .axes[0] + .imshow(mri_data, cmap="hot", aspect="auto", alpha=0.25, zorder=2) + ) + self._toggle_brain_button.setText("Hide Brain") + self._draw() + + def keyPressEvent(self, event): + """Execute functions when the user presses a key.""" + super(IntracranialElectrodeLocator, self).keyPressEvent(event) + + if event.text() == "m": + self.mark_channel() + + if event.text() == "r": + self.remove_channel() + + if event.text() == "b": + self._toggle_show_brain() diff --git a/mne_gui_addons/_utils/__init__.py b/mne_gui_addons/_utils/__init__.py new file mode 100644 index 0000000..84161e9 --- /dev/null +++ b/mne_gui_addons/_utils/__init__.py @@ -0,0 +1 @@ +from .docs import _fill_doc # noqa: F401 diff --git a/mne_gui_addons/_utils/docs.py b/mne_gui_addons/_utils/docs.py new file mode 100644 index 0000000..6e31796 --- /dev/null +++ b/mne_gui_addons/_utils/docs.py @@ -0,0 +1,70 @@ +# -*- coding: utf-8 -*- +"""The documentation functions.""" +# Authors: Eric Larson +# +# License: BSD (3-clause) + +from mne.utils.docs import _indentcount_lines, docdict as _mne_docdict + +############################################################################## +# Define our standard documentation entries + +docdict = _mne_docdict.copy() +# Specific to this repo + +docdict[ + "tmax" +] = """ +tmax : scalar + Time point of the last sample in data. +""" + +docdict_indented = {} + + +def _fill_doc(f): + """Fill a docstring with docdict entries. + + Parameters + ---------- + f : callable + The function to fill the docstring of. Will be modified in place. + + Returns + ------- + f : callable + The function, potentially with an updated ``__doc__``. + """ + docstring = f.__doc__ + if not docstring: + return f + lines = docstring.splitlines() + # Find the minimum indent of the main docstring, after first line + if len(lines) < 2: + icount = 0 + else: + icount = _indentcount_lines(lines[1:]) + # Insert this indent to dictionary docstrings + try: + indented = docdict_indented[icount] + except KeyError: + indent = " " * icount + docdict_indented[icount] = indented = {} + for name, dstr in docdict.items(): + lines = dstr.splitlines() + try: + newlines = [lines[0]] + for line in lines[1:]: + newlines.append(indent + line) + indented[name] = "\n".join(newlines) + except IndexError: + indented[name] = dstr + try: + f.__doc__ = docstring % indented + except (TypeError, ValueError, KeyError) as exp: + funcname = f.__name__ + funcname = docstring.split("\n")[0] if funcname is None else funcname + raise RuntimeError( + "%s documenting %s:\n%s" % (exp.__class__.__name__, funcname, str(exp)) + ) + return f diff --git a/mne_gui_addons/_vol_stc.py b/mne_gui_addons/_vol_stc.py new file mode 100644 index 0000000..83b8100 --- /dev/null +++ b/mne_gui_addons/_vol_stc.py @@ -0,0 +1,1385 @@ +# -*- coding: utf-8 -*- +"""Source estimate viewing graphical user interfaces (GUIs).""" + +# Authors: Alex Rockhill +# +# License: BSD (3-clause) + +import os.path as op +import numpy as np + +from qtpy import QtCore +from qtpy.QtWidgets import ( + QVBoxLayout, + QHBoxLayout, + QLabel, + QMessageBox, + QWidget, + QSlider, + QPushButton, + QComboBox, + QLineEdit, + QFrame, +) +from matplotlib.colors import LinearSegmentedColormap + +from ._core import SliceBrowser +from mne import BaseEpochs +from mne.baseline import rescale, _check_baseline +from mne.defaults import DEFAULTS +from mne.evoked import EvokedArray +from mne.time_frequency import EpochsTFR +from mne.io.constants import FIFF +from mne.io.pick import _get_channel_types, _picks_to_idx, _pick_inst +from mne.transforms import apply_trans +from mne.utils import ( + _require_version, + _validate_type, + _check_range, + fill_doc, + _check_option, +) +from mne.viz.backends._utils import _qt_safe_window +from mne.viz.utils import _get_cmap + +BASE_INT_DTYPE = np.int16 +COMPLEX_DTYPE = np.dtype([("re", BASE_INT_DTYPE), ("im", BASE_INT_DTYPE)]) +RANGE_VALUE = 2**15 +# for taking the complex conjugate, we need to be able to +# temporarily store in a value where x**2 * 2 fits +OVERFLOW_DYPE = np.int32 + +VECTOR_SCALAR = 10 +SLIDER_WIDTH = 300 + + +def _check_consistent(items, name): + if not len(items): + return + for item in items[1:]: + if item != items[0]: + raise RuntimeError( + f"Inconsistent attribute {name}, " f"got {items[0]} and {item}" + ) + return items[0] + + +def _get_src_lut(src): + offset = 2 if src.kind == "mixed" else 0 + inuse = [s["inuse"] for s in src[offset:]] + rr = np.concatenate( + [s["rr"][this_inuse.astype(bool)] for s, this_inuse in zip(src[offset:], inuse)] + ) + shape = _check_consistent([this_src["shape"] for this_src in src], "src['shape']") + # order='F' so that F-order flattening is faster + lut = -1 * np.ones(np.prod(shape), dtype=np.int64, order="F") + n_vertices_seen = 0 + for this_inuse in inuse: + this_inuse = this_inuse.astype(bool) + n_vertices = np.sum(this_inuse) + lut[this_inuse] = np.arange(n_vertices_seen, n_vertices_seen + n_vertices) + n_vertices_seen += n_vertices + lut = np.reshape(lut, shape, order="F") + src_affine_ras = _check_consistent( + [this_src["mri_ras_t"]["trans"] for this_src in src], "src['mri_ras_t']" + ) + src_affine_src = _check_consistent( + [this_src["src_mri_t"]["trans"] for this_src in src], "src['src_mri_t']" + ) + affine = np.dot(src_affine_ras, src_affine_src) + affine[:3] *= 1e3 + return lut, affine, src_affine_src * 1000, rr * 1000 + + +def _make_vol(lut, stc_data): + vol = np.zeros(lut.shape, dtype=stc_data.dtype, order="F") * np.nan + vol[lut >= 0] = stc_data[lut[lut >= 0]] + return vol.reshape(lut.shape, order="F") + + +def _coord_to_coord(coord, vox_ras_t, ras_vox_t): + return apply_trans(ras_vox_t, apply_trans(vox_ras_t, coord)) + + +def _threshold_array(array, min_val, max_val): + array = array.astype(float) + array[array < min_val] = np.nan + array[array > max_val] = np.nan + return array + + +def _int_complex_conj(data): + # Since the mixed real * imaginary terms cancel out, the complex + # conjugate is the same as squaring and adding the real and imaginary. + # Case up the integer size temporarily to prevent overflow + conj = (data["re"].astype(OVERFLOW_DYPE)) ** 2 + ( + data["im"].astype(OVERFLOW_DYPE) + ) ** 2 + return (conj // (conj.max() // RANGE_VALUE + 1)).astype(BASE_INT_DTYPE) + + +class VolSourceEstimateViewer(SliceBrowser): + """View a source estimate time-course time-frequency visualization.""" + + @_qt_safe_window(splash="_renderer.figure.splash", window="") + def __init__( + self, + data, + subject=None, + subjects_dir=None, + src=None, + inst=None, + show_topomap=True, + group=False, + show=True, + verbose=None, + ): + """View a volume time and/or frequency source time course estimate. + + Parameters + ---------- + data : array-like + An array of shape (``n_epochs``, ``n_sources``, ``n_ori``, + ``n_freqs``, ``n_times``). ``n_epochs`` may be 1 for data + averaged across epochs and ``n_freqs`` may be 1 for data + that is in time only and is not time-frequency decomposed. For + faster performance, data can be cast to integers or a + custom complex data type that uses integers as done by + :func:`mne.gui.view_vol_stc`. + %(subject)s + %(subjects_dir)s + src : instance of SourceSpaces + The volume source space for the ``stc``. + inst : EpochsTFR | AverageTFR | None | list + The time-frequency or data instances to use to plot topography. + If group-level results are given (``group=True``), a list of + instances should be provided. + show_topomap : bool + Show the sensor topomap if ``True``. + group : bool + If the first dimension of the data is subjects, rather than + epochs use ``True``. + show : bool + Show the GUI if ``True``. + block : bool + Whether to halt program execution until the figure is closed. + %(verbose)s + """ + _require_version("dipy", "VolSourceEstimateViewer", "0.10.1") + if src is None: + raise NotImplementedError( + "`src` is required, surface source " + "estimate viewing is not yet supported" + ) + if inst is None: + raise NotImplementedError( + "`data` as a source estimate object is " + "not yet supported so `inst` is required" + ) + if not isinstance(data, np.ndarray) or data.ndim != 5: + raise ValueError( + "`data` must be an array of dimensions " + "(n_epochs, n_sources, n_ori, n_freqs, n_times)" + ) + if group: + if not isinstance(inst, (list, tuple)) and len(inst) != data.shape[0]: + raise ValueError( + "Group-level results (group=True) number of " + "`inst`s does not match `data`, expected " + f"a list of {data.shape[0]} `inst`s, " + f"got a {type(inst)} of length {len(inst)}" + ) + else: + if isinstance(inst, (BaseEpochs, EpochsTFR)) and data.shape[0] != len(inst): + raise ValueError( + "Number of epochs in `inst` does not match with `data`, " + f"expected {data.shape[0]}, got {len(inst)}" + ) + insts = inst if group else None + inst = inst[0] if group else inst + + n_src_verts = sum([this_src["nuse"] for this_src in src]) + if src is not None and data.shape[1] != n_src_verts: + raise RuntimeError( + "Source vertices in `data` do not match with " + "source space vertices in `src`, " + f"expected {n_src_verts}, got {data.shape[1]}" + ) + if any([this_src["type"] == "surf" for this_src in src]): + raise NotImplementedError( + "Surface and mixed source space " "viewing is not implemented yet." + ) + if not all([s["coord_frame"] == FIFF.FIFFV_COORD_MRI for s in src]): + raise RuntimeError( + "The source space must be in the `mri`" "coordinate frame" + ) + if hasattr(inst, "freqs") and data.shape[3] != inst.freqs.size: + raise ValueError( + "Frequencies in `inst` do not match with `data`, " + f"expected {data.shape[3]}, got {inst.freqs.size}" + ) + if ( + hasattr(inst, "freqs") + and not group + and not (np.iscomplexobj(data) or data.dtype == COMPLEX_DTYPE) + ): + raise ValueError( + "Complex data is required for time-frequency " "source estimates" + ) + if data.shape[4] != inst.times.size: + raise ValueError( + "Times in `inst` do not match with `data`, " + f"expected {data.shape[4]}, got {inst.times.size}" + ) + self._verbose = verbose # used for logging, unused here + self._data = data + self._src = src + self._inst = inst + self._insts = insts + self._group = group + self._show_topomap = show_topomap + self._selector_prefix = "Subject" if group else "Epoch" + ( + self._src_lut, + self._src_vox_scan_ras_t, + self._src_vox_ras_t, + self._src_rr, + ) = _get_src_lut(src) + self._src_scan_ras_vox_t = np.linalg.inv(self._src_vox_scan_ras_t) + self._is_complex = ( + np.iscomplexobj(self._data) or self._data.dtype == COMPLEX_DTYPE + ) + self._baseline = "none" + self._bl_tmin = self._inst.times[0] + self._bl_tmax = self._inst.times[-1] + self._update = True # can be set to False to prevent double updates + # for time and frequency + # check if only positive values will be used + self._pos_support = ( + self._is_complex or self._data.shape[2] > 1 or (self._data >= 0).all() + ) + self._cmap = _get_cmap("hot" if self._pos_support else "mne") + + # set default variables for plotting + self._t_idx = self._inst.times.size // 2 + self._f_idx = ( + self._inst.freqs.size // 2 if hasattr(self._inst, "freqs") else None + ) + self._alpha = 0.75 + self._epoch_idx = ( + "Subject 0" if self._group else "Average" + " Power" * self._is_complex + ) + + # initialize current 3D image for chosen time and frequency + stc_data = self._pick_epoch(self._data) + + # take the vector magnitude, if scalar, does nothing + self._stc_data_vol = np.linalg.norm(stc_data, axis=1) + + stc_max = np.nanmax(self._stc_data_vol) + self._stc_min = min([np.nanmin(self._stc_data_vol), stc_max]) + self._stc_range = max([stc_max, -self._stc_min]) - self._stc_min + + stc_data_vol = self._pick_stc_tfr(self._stc_data_vol) + self._stc_img = _make_vol(self._src_lut, stc_data_vol) + + super(VolSourceEstimateViewer, self).__init__( + subject=subject, subjects_dir=subjects_dir + ) + + if src._subject != op.basename(self._subject_dir): + raise RuntimeError( + f"Source space subject ({src._subject})-freesurfer subject" + f"({op.basename(self._subject_dir)}) mismatch" + ) + + # make source time course plots + self._images["stc"] = list() + src_shape = np.array(self._src_lut.shape) + corners = [ # center pixel on location + _coord_to_coord((0,) * 3, self._src_vox_scan_ras_t, self._scan_ras_vox_t), + _coord_to_coord( + src_shape - 1, self._src_vox_scan_ras_t, self._scan_ras_vox_t + ), + ] + src_coord = self._get_src_coord() + for axis in range(3): + stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T + x_idx, y_idx = self._xy_idx[axis] + extent = [ + corners[0][x_idx], + corners[1][x_idx], + corners[1][y_idx], + corners[0][y_idx], + ] + self._images["stc"].append( + self._figs[axis] + .axes[0] + .imshow( + stc_slice, + aspect="auto", + extent=extent, + cmap=self._cmap, + alpha=self._alpha, + zorder=2, + ) + ) + + self._data_max = abs(stc_data).max() + if self._data.shape[2] > 1 and not self._is_complex: + # also compute vectors for chosen time + self._stc_vectors = self._pick_stc_tfr(stc_data).astype(float) + self._stc_vectors /= self._data_max + self._stc_vectors_masked = self._stc_vectors.copy() + + assert self._data.shape[2] == 3 + self._vector_mapper, self._vector_data = self._renderer.quiver3d( + *self._src_rr.T, + *(VECTOR_SCALAR * self._stc_vectors_masked.T), + color=None, + mode="2darrow", + scale_mode="vector", + scale=1, + opacity=1, + ) + self._vector_actor = self._renderer._actor(self._vector_mapper) + self._vector_actor.GetProperty().SetLineWidth(2.0) + self._renderer.plotter.add_actor(self._vector_actor, render=False) + + # initialize 3D volumetric rendering + # TO DO: add surface source space viewing as elif + if any([this_src["type"] == "vol" for this_src in self._src]): + scalars = np.array(np.where(np.isnan(self._stc_img), 0, 1.0)) + spacing = np.diag(self._src_vox_ras_t)[:3] + origin = self._src_vox_ras_t[:3, 3] - spacing / 2.0 + center = 0.5 * self._stc_range - self._stc_min + ( + self._grid, + self._grid_mesh, + self._volume_pos, + self._volume_neg, + ) = self._renderer._volume( + dimensions=src_shape, + origin=origin, + spacing=spacing, + scalars=scalars.flatten(order="F"), + surface_alpha=self._alpha, + resolution=0.4, + blending="mip", + center=center, + ) + self._volume_pos_actor = self._renderer.plotter.add_actor( + self._volume_pos, render=False + )[0] + self._volume_neg_actor = self._renderer.plotter.add_actor( + self._volume_neg, render=False + )[0] + _, grid_prop = self._renderer.plotter.add_actor( + self._grid_mesh, render=False + ) + grid_prop.SetOpacity(0.1) + self._scalar_bar = self._renderer.scalarbar( + source=self._volume_pos_actor, + n_labels=8, + color="black", + bgcolor="white", + label_font_size=10, + ) + self._scalar_bar.SetOrientationToVertical() + self._scalar_bar.SetHeight(0.6) + self._scalar_bar.SetWidth(0.05) + self._scalar_bar.SetPosition(0.02, 0.2) + + self._update_cmap() # must be called for volume to render properly + # keep focus on main window so that keypress events work + self.setFocus() + if show: + self.show() + + def _get_min_max_val(self): + """Get the minimum and maximum non-transparent values.""" + return [ + self._cmap_sliders[i].value() / SLIDER_WIDTH * self._stc_range + + self._stc_min + for i in (0, 2) + ] + + def _get_src_coord(self): + """Get the current slice transformed to source space.""" + return tuple( + np.round( + _coord_to_coord( + self._current_slice, self._vox_scan_ras_t, self._src_scan_ras_vox_t + ) + ).astype(int) + ) + + def _update_stc_pick(self): + """Update the normalized data with the epoch picked.""" + stc_data = self._pick_epoch(self._data) + self._stc_data_vol = self._apply_vector_norm(stc_data) + self._stc_data_vol = self._apply_baseline_correction(self._stc_data_vol) + # deal with baseline infinite numbers + inf_mask = np.isinf(self._stc_data_vol) + if inf_mask.any(): + self._stc_data_vol[inf_mask] = np.nan + stc_max = np.nanmax(self._stc_data_vol) + self._stc_min = min([np.nanmin(self._stc_data_vol), -stc_max]) + self._stc_range = max([stc_max, -self._stc_min]) - self._stc_min + + def _update_vectors(self): + if self._data.shape[2] > 1 and not self._is_complex: + # pick vector as well + self._stc_vectors = self._pick_stc_tfr(self._data) + self._stc_vectors = self._pick_epoch(self._stc_vectors).astype(float) + self._stc_vectors /= self._data_max + self._update_vector_threshold() + self._plot_vectors() + + def _update_vector_threshold(self): + """Update the threshold for the vectors.""" + # apply threshold, use same mask as for stc_img + stc_data = self._pick_stc_tfr(self._stc_data_vol) + min_val, max_val = self._get_min_max_val() + self._stc_vectors_masked = self._stc_vectors.copy() + self._stc_vectors_masked[stc_data < min_val] = np.nan + self._stc_vectors_masked[stc_data > max_val] = np.nan + + def _update_stc_volume(self): + """Select volume based on the current time, frequency and vertex.""" + stc_data = self._pick_stc_tfr(self._stc_data_vol) + self._stc_img = _make_vol(self._src_lut, stc_data) + self._stc_img = _threshold_array(self._stc_img, *self._get_min_max_val()) + + def _update_stc_all(self): + """Update the data in both the slice plots and the data plot.""" + # pick new epochs + baseline correction combination + self._update_stc_pick() + self._update_stc_images() # and then make the new volume + self._update_intensity() + self._update_cmap() # note: this updates stc slice plots + self._plot_data() + if self._show_topomap and self._update: + self._plot_topomap() + + def _pick_stc_image(self): + """Select time-(frequency) image based on vertex.""" + return self._pick_stc_vertex(self._stc_data_vol) + + def _pick_epoch(self, stc_data): + """Select the source time course epoch based on the parameters.""" + if self._epoch_idx == "Average": + if stc_data.dtype == BASE_INT_DTYPE: + stc_data = stc_data.mean(axis=0).astype(BASE_INT_DTYPE) + else: + stc_data = stc_data.mean(axis=0) + elif self._epoch_idx == "Average Power": + if stc_data.dtype == COMPLEX_DTYPE: + stc_data = np.sum( + _int_complex_conj(stc_data) // stc_data.shape[0], + axis=0, + dtype=BASE_INT_DTYPE, + ) + else: + stc_data = (stc_data * stc_data.conj()).real.mean(axis=0) + elif self._epoch_idx == "ITC": + if stc_data.dtype == COMPLEX_DTYPE: + stc_data = stc_data["re"].astype(np.complex64) + 1j * stc_data["im"] + stc_data = np.abs((stc_data / np.abs(stc_data)).mean(axis=0)) + else: + stc_data = np.abs((stc_data / np.abs(stc_data)).mean(axis=0)) + else: + stc_data = stc_data[ + int(self._epoch_idx.replace(f"{self._selector_prefix} ", "")) + ] + if stc_data.dtype == COMPLEX_DTYPE: + stc_data = _int_complex_conj(stc_data) + elif self._is_complex: + stc_data = (stc_data * stc_data.conj()).real + return stc_data + + def _apply_vector_norm(self, stc_data, axis=1): + """Take the vector norm if source data is vector.""" + if self._epoch_idx == "ITC": + stc_data = np.max(stc_data, axis=axis) # take maximum ITC + elif stc_data.shape[axis] > 1: + stc_data = np.linalg.norm(stc_data, axis=axis) # take magnitude + # if self._data.dtype in (COMPLEX_DTYPE, BASE_INT_DTYPE): + # stc_data = stc_data.round().astype(BASE_INT_DTYPE) + else: + stc_data = np.take(stc_data, 0, axis=axis) + return stc_data + + def _apply_baseline_correction(self, stc_data): + """Apply the chosen baseline correction to the data.""" + if self._baseline != "none": # do baseline correction + stc_data = rescale( + stc_data.astype(float), + times=self._inst.times, + baseline=(float(self._bl_tmin), float(self._bl_tmax)), + mode=self._baseline, + copy=True, + ) + return stc_data + + def _pick_stc_vertex(self, stc_data): + """Select the vertex based on the cursor position.""" + src_coord = self._get_src_coord() + if ( + all( + [ + coord >= 0 and coord < dim + for coord, dim in zip(src_coord, self._src_lut.shape) + ] + ) + and self._src_lut[src_coord] >= 0 + ): + stc_data = stc_data[self._src_lut[src_coord]] + else: # out-of-bounds or unused vertex + stc_data = np.zeros(stc_data[:, 0].shape) * np.nan + return stc_data + + def _pick_stc_tfr(self, stc_data): + """Select the frequency and time based on GUI values.""" + stc_data = np.take(stc_data, self._t_idx, axis=-1) + f_idx = 0 if self._f_idx is None else self._f_idx + stc_data = np.take(stc_data, f_idx, axis=-1) + return stc_data + + def _configure_ui(self): + """Configure the main appearance of the user interface.""" + toolbar = self._configure_toolbar() + slider_bar = self._configure_sliders() + status_bar = self._configure_status_bar() + data_plot = self._configure_data_plot() + + plot_vbox = QVBoxLayout() + plot_vbox.addLayout(self._plt_grid) + + if self._show_topomap: + data_hbox = QHBoxLayout() + topo_plot = self._configure_topo_plot() + data_hbox.addWidget(topo_plot) + data_hbox.addWidget(data_plot) + plot_vbox.addLayout(data_hbox) + else: + plot_vbox.addWidget(data_plot) + + main_hbox = QHBoxLayout() + main_hbox.addLayout(slider_bar) + main_hbox.addLayout(plot_vbox) + + main_vbox = QVBoxLayout() + main_vbox.addLayout(toolbar) + main_vbox.addLayout(main_hbox) + main_vbox.addLayout(status_bar) + + central_widget = QWidget() + central_widget.setLayout(main_vbox) + self.setCentralWidget(central_widget) + + def _configure_toolbar(self): + """Make a bar with buttons for user interactions.""" + hbox = QHBoxLayout() + + help_button = QPushButton("Help") + help_button.released.connect(self._show_help) + hbox.addWidget(help_button) + + hbox.addStretch(8) + + if self._data.shape[0] > 1: + self._epoch_selector = QComboBox() + if not self._group: + if self._is_complex: + self._epoch_selector.addItems(["Average Power"]) + self._epoch_selector.addItems(["ITC"]) + else: + self._epoch_selector.addItems(["Average"]) + self._epoch_selector.addItems( + [f"{self._selector_prefix} {i}" for i in range(self._data.shape[0])] + ) + self._epoch_selector.setCurrentText(self._epoch_idx) + self._epoch_selector.currentTextChanged.connect(self._update_epoch) + self._epoch_selector.setSizeAdjustPolicy(QComboBox.AdjustToContents) + self._epoch_selector.keyPressEvent = self.keyPressEvent + hbox.addWidget(self._epoch_selector) + + return hbox + + def _show_help(self): + """Show the help menu.""" + QMessageBox.information( + self, + "Help", + "Help:\n" + "'+'/'-': zoom\nleft/right arrow: left/right\n" + "up/down arrow: superior/inferior\n" + "left angle bracket/right angle bracket: anterior/posterior", + ) + + def _configure_sliders(self): + """Make a bar with sliders on it.""" + + def make_label(name): + label = QLabel(name) + label.setAlignment(QtCore.Qt.AlignCenter) + return label + + # modified from: + # https://stackoverflow.com/questions/52689047/moving-qslider-to-mouse-click-position + class Slider(QSlider): + def mouseReleaseEvent(self, event): + if event.button() == QtCore.Qt.LeftButton: + event.accept() + value = ( + self.maximum() - self.minimum() + ) * event.pos().x() / self.width() + self.minimum() + value = np.clip(value, 0, SLIDER_WIDTH) + self.setValue(int(round(value))) + else: + super(Slider, self).mouseReleaseEvent(event) + + def make_slider(smin, smax, sval, sfun=None): + slider = Slider(QtCore.Qt.Horizontal) + slider.setMinimum(int(round(smin))) + slider.setMaximum(int(round(smax))) + slider.setValue(int(round(sval))) + slider.setTracking(False) # only update on release + if sfun is not None: + slider.valueChanged.connect(sfun) + slider.keyPressEvent = self.keyPressEvent + slider.setMinimumWidth(SLIDER_WIDTH) + return slider + + slider_layout = QVBoxLayout() + slider_layout.setContentsMargins(11, 11, 11, 11) # for aesthetics + + if hasattr(self._inst, "freqs"): + slider_layout.addWidget(make_label("Frequency (Hz)")) + self._freq_slider = make_slider( + 0, self._inst.freqs.size - 1, self._f_idx, self._update_freq + ) + slider_layout.addWidget(self._freq_slider) + freq_hbox = QHBoxLayout() + freq_hbox.addWidget(make_label(str(self._inst.freqs[0].round(2)))) + freq_hbox.addStretch(1) + freq_hbox.addWidget(make_label(str(self._inst.freqs[-1].round(2)))) + slider_layout.addLayout(freq_hbox) + self._freq_label = make_label( + f"Freq = {self._inst.freqs[self._f_idx].round(2)} Hz" + ) + slider_layout.addWidget(self._freq_label) + slider_layout.addStretch(1) + + slider_layout.addWidget(make_label("Time (s)")) + self._time_slider = make_slider( + 0, self._inst.times.size - 1, self._t_idx, self._update_time + ) + slider_layout.addWidget(self._time_slider) + time_hbox = QHBoxLayout() + time_hbox.addWidget(make_label(str(self._inst.times[0].round(2)))) + time_hbox.addStretch(1) + time_hbox.addWidget(make_label(str(self._inst.times[-1].round(2)))) + slider_layout.addLayout(time_hbox) + self._time_label = make_label( + f"Time = {self._inst.times[self._t_idx].round(2)} s" + ) + slider_layout.addWidget(self._time_label) + slider_layout.addStretch(1) + + slider_layout.addWidget(make_label("Alpha")) + self._alpha_slider = make_slider( + 0, SLIDER_WIDTH, int(self._alpha * SLIDER_WIDTH), self._update_alpha + ) + slider_layout.addWidget(self._alpha_slider) + self._alpha_label = make_label(f"Alpha = {self._alpha}") + slider_layout.addWidget(self._alpha_label) + slider_layout.addStretch(1) + + slider_layout.addWidget(make_label("min / mid / max")) + self._cmap_sliders = [ + make_slider(0, SLIDER_WIDTH, 0, self._update_cmap), + make_slider(0, SLIDER_WIDTH, SLIDER_WIDTH // 2, self._update_cmap), + make_slider(0, SLIDER_WIDTH, SLIDER_WIDTH, self._update_cmap), + ] + for slider in self._cmap_sliders: + slider_layout.addWidget(slider) + slider_layout.addStretch(1) + + return slider_layout + + def _configure_status_bar(self, hbox=None): + hbox = QHBoxLayout() if hbox is None else hbox + + hbox.addWidget(QLabel("Baseline")) + self._baseline_selector = QComboBox() + self._baseline_selector.addItems( + ["none", "mean", "ratio", "logratio", "percent", "zscore", "zlogratio"] + ) + self._baseline_selector.setCurrentText("none") + self._baseline_selector.currentTextChanged.connect(self._update_baseline) + self._baseline_selector.setSizeAdjustPolicy(QComboBox.AdjustToContents) + self._baseline_selector.keyPressEvent = self.keyPressEvent + hbox.addWidget(self._baseline_selector) + + hbox.addWidget(QLabel("tmin =")) + self._bl_tmin_textbox = QLineEdit(str(round(self._bl_tmin, 2))) + self._bl_tmin_textbox.setMaximumWidth(60) + self._bl_tmin_textbox.focusOutEvent = self._update_baseline_tmin + hbox.addWidget(self._bl_tmin_textbox) + + hbox.addWidget(QLabel("tmax =")) + self._bl_tmax_textbox = QLineEdit(str(round(self._bl_tmax, 2))) + self._bl_tmax_textbox.setMaximumWidth(60) + self._bl_tmax_textbox.focusOutEvent = self._update_baseline_tmax + hbox.addWidget(self._bl_tmax_textbox) + + # add separator for clarity + sep = QFrame() + sep.setFrameShape(QFrame.VLine) + sep.setFrameShadow(QFrame.Sunken) + hbox.addWidget(sep) + + hbox.addStretch(3 if self._f_idx is None else 2) + + if self._show_topomap: + hbox.addWidget(QLabel("Topo Data=")) + self._data_type_selector = QComboBox() + self._data_type_selector.addItems( + _get_channel_types(self._inst.info, picks="data", unique=True) + ) + self._data_type_selector.currentTextChanged.connect(self._update_data_type) + self._data_type_selector.setSizeAdjustPolicy(QComboBox.AdjustToContents) + self._data_type_selector.keyPressEvent = self.keyPressEvent + hbox.addWidget(self._data_type_selector) + hbox.addStretch(1) + + if self._f_idx is not None: + hbox.addWidget(QLabel("Interpolate")) + self._interp_button = QPushButton("On") + self._interp_button.setMaximumWidth(25) # not too big + self._interp_button.setStyleSheet("background-color: green") + hbox.addWidget(self._interp_button) + self._interp_button.released.connect(self._toggle_interp) + hbox.addStretch(1) + + self._go_to_extreme_button = QPushButton("Go to Max") + self._go_to_extreme_button.released.connect(self.go_to_extreme) + hbox.addWidget(self._go_to_extreme_button) + hbox.addStretch(2) + + self._intensity_label = QLabel("") # update later + hbox.addWidget(self._intensity_label) + + # add SliceBrowser navigation items + hbox = super(VolSourceEstimateViewer, self)._configure_status_bar(hbox=hbox) + return hbox + + def _configure_data_plot(self): + """Configure the plot that shows spectrograms/time-courses.""" + from ._core import _make_mpl_plot + + canvas, self._fig = _make_mpl_plot( + dpi=96, tight=False, hide_axes=False, invert=False, facecolor="white" + ) + self._fig.axes[0].set_position([0.12, 0.25, 0.73, 0.7]) + self._fig.axes[0].set_xlabel("Time (s)") + min_idx = np.argmin(abs(self._inst.times)) + self._fig.axes[0].set_xticks([0, min_idx, self._inst.times.size - 1]) + self._fig.axes[0].set_xticklabels(self._inst.times[[0, min_idx, -1]].round(2)) + stc_data = self._pick_stc_image() + if self._f_idx is None: + self._fig.axes[0].set_facecolor("black") + self._stc_plot = self._fig.axes[0].plot(stc_data[0], color="white")[0] + self._stc_vline = self._fig.axes[0].axvline(x=self._t_idx, color="lime") + self._fig.axes[0].set_ylabel("Activation (AU)") + self._cax = None + else: + self._stc_plot = self._fig.axes[0].imshow( + stc_data, aspect="auto", cmap=self._cmap, interpolation="bicubic" + ) + self._stc_vline = self._fig.axes[0].axvline( + x=self._t_idx, color="lime", linewidth=0.5 + ) + self._stc_hline = self._fig.axes[0].axhline( + y=self._f_idx, color="lime", linewidth=0.5 + ) + self._fig.axes[0].invert_yaxis() + self._fig.axes[0].set_ylabel("Frequency (Hz)") + self._fig.axes[0].set_yticks(range(self._inst.freqs.size)) + self._fig.axes[0].set_yticklabels(self._inst.freqs.round(2)) + self._cax = self._fig.add_axes([0.88, 0.25, 0.02, 0.6]) + self._cbar = self._fig.colorbar(self._stc_plot, cax=self._cax) + self._cax.set_ylabel("Power") + self._fig.canvas.mpl_connect("button_release_event", self._on_data_plot_click) + canvas.setMinimumHeight(int(self.size().height() * 0.4)) + canvas.keyPressEvent = self.keyPressEvent + return canvas + + def _plot_topomap(self): + self._topo_fig.axes[0].clear() + self._topo_cax.clear() + dtype = self._data_type_selector.currentText() + units = DEFAULTS["units"][dtype] + scaling = DEFAULTS["scalings"][dtype] + + inst = ( + self._insts[int(self._epoch_idx.replace(f"{self._selector_prefix} ", ""))] + if self._group + else self._inst + ) + if isinstance(inst, EpochsTFR): + inst_data = inst.data + scaling *= scaling # power is squared + units = f"({units})" + r"$^2$/Hz" + elif isinstance(inst, BaseEpochs): + inst_data = inst.get_data() + else: + inst_data = inst.data[None] # new axis for single epoch + + # convert to power or ITC for group + if self._group == "ITC" and np.iscomplexobj(inst_data): + inst_data = np.abs((inst_data / np.abs(inst_data))) + elif self._group and np.iscomplexobj(inst_data): # power + inst_data = (inst_data * inst_data.conj()).real + + if self._epoch_idx == "ITC": + units = "ITC" + scaling = 1 + + pick_idx = _picks_to_idx(inst.info, dtype) + inst_data = inst_data[:, pick_idx] + + evo_data = self._pick_epoch(inst_data) * scaling + + if self._f_idx is not None: + evo_data = evo_data[:, self._f_idx] + + if self._baseline != "none": + units = units if self._baseline == "mean" else "" + evo_data = rescale( + evo_data.astype(float), + times=self._inst.times, + baseline=(float(self._bl_tmin), float(self._bl_tmax)), + mode=self._baseline, + copy=False, + ) + + info = _pick_inst(inst, dtype, "bads").info + ave = EvokedArray(evo_data, info, tmin=self._inst.times[0]) + + ave_max = evo_data.max() + self._ave_min = min([evo_data.min(), -ave_max]) + self._ave_range = max([ave_max, -self._ave_min]) - self._ave_min + vmin, vmax = [ + val / SLIDER_WIDTH * self._ave_range + self._ave_min + for val in (self._cmap_sliders[i].value() for i in (0, 2)) + ] + cbar_fmt = "%3.1f" if abs(evo_data).max() < 1e3 else "%.1e" + ave.plot_topomap( + times=self._inst.times[self._t_idx], + scalings={dtype: 1}, + units=units, + axes=(self._topo_fig.axes[0], self._topo_cax), + cmap=self._cmap, + colorbar=True, + cbar_fmt=cbar_fmt, + vlim=(vmin, vmax), + show=False, + ) + + self._topo_fig.axes[0].set_title("") + self._topo_fig.subplots_adjust(top=1.1, bottom=0.05, right=0.75) + self._topo_fig.canvas.draw() + + def _configure_topo_plot(self): + """Configure the plot that shows topomap.""" + from ._core import _make_mpl_plot + + canvas, self._topo_fig = _make_mpl_plot( + dpi=96, hide_axes=False, facecolor="white" + ) + self._topo_cax = self._topo_fig.add_axes([0.77, 0.1, 0.02, 0.75]) + self._plot_topomap() + canvas.setMinimumHeight(int(self.size().height() * 0.4)) + canvas.setMaximumWidth(int(self.size().width() * 0.4)) + canvas.keyPressEvent = self.keyPressEvent + return canvas + + def keyPressEvent(self, event): + """Execute functions when the user presses a key.""" + super().keyPressEvent(event) + + # update if textbox done editing + if event.key() == QtCore.Qt.Key_Return: + for widget in (self._bl_tmin_textbox, self._bl_tmax_textbox): + if widget.hasFocus(): + widget.clearFocus() + self.setFocus() # removing focus calls focus out event + + def _on_data_plot_click(self, event): + """Update viewer when the data plot is clicked on.""" + if event.inaxes is self._fig.axes[0]: + if self._f_idx is not None: + self._update = False + self.set_freq(self._inst.freqs[int(round(event.ydata))]) + self._update = True + self.set_time(self._inst.times[int(round(event.xdata))]) + self._update_intensity() + + def set_baseline(self, baseline=None, mode=None): + """Set the baseline. + + Parameters + ---------- + baseline : array-like, shape (2,) | None + The time interval to apply rescaling / baseline correction. + If None do not apply it. If baseline is (a, b) + the interval is between "a (s)" and "b (s)". + If a is None the beginning of the data is used + and if b is None then b is set to the end of the interval. + If baseline is equal to (None, None) all the time + interval is used. + mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' + Perform baseline correction by + + - subtracting the mean of baseline values ('mean') + - dividing by the mean of baseline values ('ratio') + - dividing by the mean of baseline values and taking the log + ('logratio') + - subtracting the mean of baseline values followed by dividing by + the mean of baseline values ('percent') + - subtracting the mean of baseline values and dividing by the + standard deviation of baseline values ('zscore') + - dividing by the mean of baseline values, taking the log, and + dividing by the standard deviation of log baseline values + ('zlogratio') + tmin : float + The minimum baseline time + """ # noqa E501 + _check_option( + "mode", + mode, + ( + "mean", + "ratio", + "logratio", + "percent", + "zscore", + "zlogratio", + "none", + None, + ), + ) + self._update = False + self._baseline_selector.setCurrentText("none" if mode is None else mode) + if baseline is not None: + baseline = _check_baseline( + baseline, times=self._inst.times, sfreq=self._inst.info["sfreq"] + ) + tmin, tmax = baseline + self._bl_tmin_textbox.setText(str(tmin)) + self._bl_tmax_textbox.setText(str(tmax)) + self._update = True + self._update_stc_all() + + def _update_baseline(self, name): + """Update the chosen baseline normalization method.""" + self._baseline = name + pre_update = self._update + self._update = False + self._cmap_sliders[0].setValue(0) + self._cmap_sliders[1].setValue(SLIDER_WIDTH // 2) + self._update = pre_update + self._cmap_sliders[2].setValue(SLIDER_WIDTH) + # all baselines have negative support + self._cmap = _get_cmap("hot" if name == "none" and self._pos_support else "mne") + self._go_to_extreme_button.setText( + "Go to Max" if name == "none" and self._pos_support else "Go to Extreme" + ) + if self._update: # don't update if bl_tmin, bl_tmax are also changing + self._update_stc_all() + + def _update_baseline_tmin(self, event): + """Update tmin for the baseline.""" + try: + tmin = float(self._bl_tmin_textbox.text()) + except ValueError: + self._bl_tmin_textbox.setText(str(round(self._bl_tmin, 2))) + tmin = self._inst.times[ + np.clip( # find nearest time + self._inst.time_as_index(tmin, use_rounding=True)[0], + 0, + self._inst.times.size - 1, + ) + ] + if tmin == self._bl_tmin: + return + self._bl_tmin = tmin + if self._update: + self._update_stc_all() + + def _update_baseline_tmax(self, event): + """Update tmax for the baseline.""" + try: + tmax = float(self._bl_tmax_textbox.text()) + except ValueError: + self._bl_tmax_textbox.setText(str(round(self._bl_tmax, 2))) + return + tmax = self._inst.times[ + np.clip( # find nearest time + self._inst.time_as_index(tmax, use_rounding=True)[0], + 0, + self._inst.times.size - 1, + ) + ] + if tmax == self._bl_tmax: + return + self._bl_tmax = tmax + if self._update: + self._update_stc_all() + + def _update_data_type(self, dtype): + """Update which data type is shown in the topomap.""" + self._plot_topomap() + + def _update_data_plot_ylabel(self): + """Update the ylabel of the data plot.""" + if self._epoch_idx == "ITC": + self._cax.set_ylabel("ITC") + elif self._is_complex: + self._cax.set_ylabel("Power") + else: + self._fig.axes[0].set_ylabel("Activation (AU)") + + def _update_epoch(self, name): + """Change which epoch is viewed.""" + self._epoch_idx = name + # handle plot labels + self._update_data_plot_ylabel() + # reset sliders + if name == "ITC" and self._epoch_idx != "ITC": + self._cmap_sliders[0].setValue(0) + self._cmap_sliders[1].setValue(SLIDER_WIDTH // 2) + self._cmap_sliders[2].setValue(SLIDER_WIDTH) + self._baseline_selector.setCurrentText("none") + + if self._update: + self._update_stc_all() + self._update_vectors() + + def set_freq(self, freq): + """Set the frequency to display (in Hz). + + Parameters + ---------- + freq : float + The frequency to show, in Hz. + """ + if self._f_idx is None: + raise ValueError("Source estimate does not contain frequencies") + self._freq_slider.setValue(np.argmin(abs(self._inst.freqs - freq))) + + def _update_freq(self, event=None): + """Update freq slider values.""" + self._f_idx = self._freq_slider.value() + self._freq_label.setText(f"Freq = {self._inst.freqs[self._f_idx].round(2)} Hz") + if self._update: + self._update_stc_images() # just need volume updated here + self._stc_hline.set_ydata([self._f_idx]) + self._update_intensity() + if self._show_topomap and self._update: + self._plot_topomap() + self._fig.canvas.draw() + + def set_time(self, time): + """Set the time to display (in seconds). + + Parameters + ---------- + time : float + The time to show, in seconds. + """ + self._time_slider.setValue( + np.clip( + self._inst.time_as_index(time, use_rounding=True)[0], + 0, + self._inst.times.size - 1, + ) + ) + + def _update_time(self, event=None): + """Update time slider values.""" + self._t_idx = self._time_slider.value() + self._time_label.setText(f"Time = {self._inst.times[self._t_idx].round(2)} s") + if self._update: + self._update_stc_images() # just need volume updated here + self._stc_vline.set_xdata([self._t_idx]) + self._update_intensity() + if self._show_topomap and self._update: + self._plot_topomap() + self._update_vectors() + self._fig.canvas.draw() + + def set_alpha(self, alpha): + """Set the opacity of the display. + + Parameters + ---------- + alpha : float + The opacity to use. + """ + self._alpha_slider.setValue(np.clip(alpha, 0, 1)) + + def _update_alpha(self, event=None): + """Update stc plot alpha.""" + self._alpha = round(self._alpha_slider.value() / SLIDER_WIDTH, 2) + self._alpha_label.setText(f"Alpha = {self._alpha}") + for axis in range(3): + self._images["stc"][axis].set_alpha(self._alpha) + self._update_cmap() + + def set_cmap(self, vmin=None, vmid=None, vmax=None): + """Update the colormap. + + Parameters + ---------- + vmin : float + The minimum color value relative to the selected data in [0, 1]. + vmin : float + The middle color value relative to the selected data in [0, 1]. + vmin : float + The maximum color value relative to the selected data in [0, 1]. + """ + for val, name in zip((vmin, vmid, vmax), ("vmin", "vmid", "vmax")): + _validate_type(val, (int, float, None)) + + self._update = False + for i, val in enumerate((vmin, vmid, vmax)): + if val is not None: + _check_range(val, 0, 1, name) + self._cmap_sliders[i].setValue(int(round(val * SLIDER_WIDTH))) + self._update = True + self._update_cmap() + + def _update_cmap( + self, event=None, draw=True, update_slice_plots=True, update_3d=True + ): + """Update the colormap.""" + if not self._update: + return + + # no recursive updating + update_tmp = self._update + self._update = False + if self._cmap_sliders[0].value() > self._cmap_sliders[2].value(): + tmp = self._cmap_sliders[0].value() + self._cmap_sliders[0].setValue(self._cmap_sliders[2].value()) + self._cmap_sliders[2].setValue(tmp) + if self._cmap_sliders[1].value() > self._cmap_sliders[2].value(): + self._cmap_sliders[1].setValue(self._cmap_sliders[2].value()) + if self._cmap_sliders[1].value() < self._cmap_sliders[0].value(): + self._cmap_sliders[1].setValue(self._cmap_sliders[0].value()) + self._update = update_tmp + + vmin, vmid, vmax = [ + val / SLIDER_WIDTH * self._stc_range + self._stc_min + for val in (self._cmap_sliders[i].value() for i in range(3)) + ] + mid_pt = (vmid - vmin) / (vmax - vmin) + ctable = self._cmap( + np.concatenate([np.linspace(0, mid_pt, 128), np.linspace(mid_pt, 1, 128)]) + ) + cmap = LinearSegmentedColormap.from_list("stc", ctable.tolist(), N=256) + ctable = np.round(ctable * 255.0).astype(np.uint8) + if self._stc_min < 0: # make center values transparent + zero_pt = np.argmin(abs(np.linspace(vmin, vmax, 256))) + # 31 on either side of the zero point are made transparent + ctable[max([zero_pt - 31, 0]) : min([zero_pt + 32, 255]), 3] = 0 + else: # make low values transparent + ctable[:25, 3] = np.linspace(0, 255, 25) + + for axis in range(3): + self._images["stc"][axis].set_clim(vmin, vmax) + self._images["stc"][axis].set_cmap(cmap) + if draw and self._update: + self._figs[axis].canvas.draw() + + # update nans in slice plot image + if update_slice_plots and self._update: + self._update_stc_volume() + self._plot_stc_images(draw=draw) + + if self._f_idx is None: + self._fig.axes[0].set_ylim([self._stc_min, self._stc_min + self._stc_range]) + else: + self._stc_plot.set_clim(vmin, vmax) + self._stc_plot.set_cmap(cmap) + # update colorbar + self._cax.clear() + self._cbar = self._fig.colorbar(self._stc_plot, cax=self._cax) + self._update_data_plot_ylabel() + + if self._show_topomap: + topo_vmin, topo_vmax = [ + val / SLIDER_WIDTH * self._ave_range + self._ave_min + for val in (self._cmap_sliders[i].value() for i in (0, 2)) + ] + self._topo_fig.axes[0].get_images()[0].set_clim(topo_vmin, topo_vmax) + if draw and self._update: + self._topo_fig.canvas.draw() + + if draw and self._update: + self._fig.canvas.draw() + + if not update_3d: + return + + if self._data.shape[2] > 1 and not self._is_complex: + # update vector mask + self._update_vector_threshold() + self._plot_vectors(draw=False) + self._renderer._set_colormap_range( + actor=self._vector_actor, + ctable=ctable, + scalar_bar=None, + rng=[0, VECTOR_SCALAR], + ) + + # set alpha + ctable[ctable[:, 3] > self._alpha * 255, 3] = self._alpha * 255 + self._renderer._set_volume_range( + self._volume_pos, ctable, self._alpha, self._scalar_bar, [vmin, vmax] + ) + self._renderer._set_volume_range( + self._volume_neg, ctable, self._alpha, self._scalar_bar, [vmin, vmax] + ) + if draw and self._update: + self._renderer._update() + + def go_to_extreme(self): + """Go to the extreme intensity source vertex.""" + stc_idx, f_idx, t_idx = np.unravel_index( + np.nanargmax(abs(self._stc_data_vol)), self._stc_data_vol.shape + ) + if self._f_idx is not None: + self._freq_slider.setValue(f_idx) + self._time_slider.setValue(t_idx) + max_coord = np.array(np.where(self._src_lut == stc_idx)).flatten() + max_coord_mri = _coord_to_coord( + max_coord, self._src_vox_scan_ras_t, self._scan_ras_vox_t + ) + self._set_ras(apply_trans(self._vox_ras_t, max_coord_mri)) + + def _plot_data(self, draw=True): + """Update which coordinate's data is being shown.""" + stc_data = self._pick_stc_image() + if self._f_idx is None: # no freq data + self._stc_plot.set_ydata(stc_data[0]) + else: + self._stc_plot.set_data(stc_data) + if draw and self._update: + self._fig.canvas.draw() + + def _toggle_interp(self): + """Toggle interpolating the spectrogram data plot.""" + if self._interp_button.text() == "Off": + self._interp_button.setText("On") + self._interp_button.setStyleSheet("background-color: green") + else: # text == 'On', turn off + self._interp_button.setText("Off") + self._interp_button.setStyleSheet("background-color: red") + + self._stc_plot.set_interpolation( + "bicubic" if self._interp_button.text() == "On" else None + ) + if self._update: + self._fig.canvas.draw() + # draws data plot, fixes vmin, vmax + self._update_cmap(update_slice_plots=False, update_3d=False) + + def _update_intensity(self): + """Update the intensity label.""" + label_str = "{:.3f}" + if self._stc_range > 1e5: + label_str = "{:.3e}" + elif np.issubdtype(self._stc_img.dtype, np.integer): + label_str = "{:d}" + self._intensity_label.setText( + ("intensity = " + label_str).format( + self._stc_img[tuple(self._get_src_coord())] + ) + ) + + def _update_moved(self): + """Update when cursor position changes.""" + super()._update_moved() + self._update_intensity() + + @fill_doc + def set_3d_view( + self, roll=None, distance=None, azimuth=None, elevation=None, focalpoint=None + ): + """Orient camera to display view. + + Parameters + ---------- + %(roll)s + %(distance)s + %(azimuth)s + %(elevation)s + %(focalpoint)s + """ + self._renderer.set_camera( + roll=roll, + distance=distance, + azimuth=azimuth, + elevation=elevation, + focalpoint=focalpoint, + reset_camera=False, + ) + self._renderer._update() + + def _plot_vectors(self, draw=True): + """Update the vector plots.""" + if self._data.shape[2] > 1 and not self._is_complex: + self._vector_data.point_data["vec"] = ( + VECTOR_SCALAR * self._stc_vectors_masked + ) + if draw and self._update: + self._renderer._update() + + def _update_stc_images(self, draw=True): + """Update the stc image based on the time and frequency range.""" + self._update_stc_volume() + self._plot_stc_images(draw=draw) + self._plot_3d_stc(draw=draw) + + def _plot_3d_stc(self, draw=True): + """Update the 3D rendering.""" + self._plot_vectors(draw=False) + self._grid.cell_data["values"] = np.where( + np.isnan(self._stc_img), 0.0, self._stc_img + ).flatten(order="F") + if draw and self._update: + self._renderer._update() + + def _plot_stc_images(self, axis=None, draw=True): + """Update the stc image(s).""" + src_coord = self._get_src_coord() + for axis in range(3): + # ensure in bounds + if src_coord[axis] >= 0 and src_coord[axis] < self._stc_img.shape[axis]: + stc_slice = np.take(self._stc_img, src_coord[axis], axis=axis).T + else: + stc_slice = np.take(self._stc_img, 0, axis=axis).T * np.nan + self._images["stc"][axis].set_data(stc_slice) + if draw and self._update: + self._draw(axis) + + def _update_images(self, axis=None, draw=True): + """Update images when general changes happen.""" + self._plot_stc_images(axis=axis, draw=draw) + self._plot_data(draw=draw) + super()._update_images() diff --git a/mne_gui_addons/conftest.py b/mne_gui_addons/conftest.py new file mode 100644 index 0000000..25b7ed1 --- /dev/null +++ b/mne_gui_addons/conftest.py @@ -0,0 +1,2 @@ +# get all MNE fixtures and settings +from mne.conftest import * # noqa: F403 diff --git a/mne_gui_addons/tests/test_core.py b/mne_gui_addons/tests/test_core.py new file mode 100644 index 0000000..3e1c7ae --- /dev/null +++ b/mne_gui_addons/tests/test_core.py @@ -0,0 +1,72 @@ +# -*- coding: utf-8 -*- +# Authors: Alex Rockhill +# +# License: BSD-3-clause + +import numpy as np +from numpy.testing import assert_allclose + +import pytest + +from mne.datasets import testing +from mne.utils import catch_logging, use_log_level +from mne.viz.utils import _fake_click + +data_path = testing.data_path(download=False) +subject = "sample" +subjects_dir = data_path / "subjects" + + +@testing.requires_testing_data +def test_slice_browser_io(renderer_interactive_pyvistaqt): + """Test the input/output of the slice browser GUI.""" + nib = pytest.importorskip("nibabel") + from mne.gui._core import SliceBrowser + + with pytest.raises(ValueError, match="Base image is not aligned to MRI"): + SliceBrowser( + nib.MGHImage(np.ones((96, 96, 96), dtype=np.float32), np.eye(4)), + subject=subject, + subjects_dir=subjects_dir, + ) + + +# TODO: For some reason this leaves some stuff un-closed, we should fix it +@pytest.mark.allow_unclosed +@testing.requires_testing_data +def test_slice_browser_display(renderer_interactive_pyvistaqt): + """Test that the slice browser GUI displays properly.""" + pytest.importorskip("nibabel") + from mne.gui._core import SliceBrowser + + # test no seghead, fsaverage doesn't have seghead + with pytest.warns(RuntimeWarning, match="`seghead` not found"): + with catch_logging() as log: + gui = SliceBrowser( + subject="fsaverage", subjects_dir=subjects_dir, verbose=True + ) + log = log.getvalue() + assert "using marching cubes" in log + gui.close() + + # test functions + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + gui = SliceBrowser(subject=subject, subjects_dir=subjects_dir) + + # test RAS + gui._RAS_textbox.setText("10 10 10") + gui._RAS_textbox.focusOutEvent(event=None) + assert_allclose(gui._ras, [10, 10, 10]) + + # test vox + gui._VOX_textbox.setText("150, 150, 150") + gui._VOX_textbox.focusOutEvent(event=None) + assert_allclose(gui._ras, [23, 22, 23]) + + # test click + with use_log_level("debug"): + _fake_click( + gui._figs[2], gui._figs[2].axes[0], [137, 140], xform="data", kind="release" + ) + assert_allclose(gui._ras, [10, 12, 23]) + gui.close() diff --git a/mne_gui_addons/tests/test_ieeg_locate.py b/mne_gui_addons/tests/test_ieeg_locate.py new file mode 100644 index 0000000..6258b8f --- /dev/null +++ b/mne_gui_addons/tests/test_ieeg_locate.py @@ -0,0 +1,240 @@ +# -*- coding: utf-8 -*- +# Authors: Alex Rockhill +# +# License: BSD-3-clause + +import numpy as np +from numpy.testing import assert_allclose + +import pytest + +import mne +from mne.datasets import testing +from mne.transforms import apply_trans +from mne.utils import requires_version, use_log_level +from mne.viz.utils import _fake_click + +data_path = testing.data_path(download=False) +subject = "sample" +subjects_dir = data_path / "subjects" +sample_dir = data_path / "MEG" / subject +raw_path = sample_dir / "sample_audvis_trunc_raw.fif" +fname_trans = sample_dir / "sample_audvis_trunc-trans.fif" + + +@pytest.fixture +def _fake_CT_coords(skull_size=5, contact_size=2): + """Make somewhat realistic CT data with contacts.""" + nib = pytest.importorskip("nibabel") + brain = nib.load(subjects_dir / subject / "mri" / "brain.mgz") + verts = mne.read_surface(subjects_dir / subject / "bem" / "outer_skull.surf")[0] + verts = apply_trans(np.linalg.inv(brain.header.get_vox2ras_tkr()), verts) + x, y, z = np.array(brain.shape).astype(int) // 2 + coords = [ + (x, y - 14, z), + (x - 10, y - 15, z), + (x - 20, y - 16, z + 1), + (x - 30, y - 16, z + 1), + ] + center = np.array(brain.shape) / 2 + # make image + np.random.seed(99) + ct_data = np.random.random(brain.shape).astype(np.float32) * 100 + # make skull + for vert in verts: + x, y, z = np.round(vert).astype(int) + ct_data[ + slice(x - skull_size, x + skull_size + 1), + slice(y - skull_size, y + skull_size + 1), + slice(z - skull_size, z + skull_size + 1), + ] = 1000 + # add electrode with contacts + for x, y, z in coords: + # make sure not in skull + assert np.linalg.norm(center - np.array((x, y, z))) < 50 + ct_data[ + slice(x - contact_size, x + contact_size + 1), + slice(y - contact_size, y + contact_size + 1), + slice(z - contact_size, z + contact_size + 1), + ] = 1000 - np.linalg.norm( + np.array(np.meshgrid(*[range(-contact_size, contact_size + 1)] * 3)), axis=0 + ) + ct = nib.MGHImage(ct_data, brain.affine) + coords = apply_trans(ct.header.get_vox2ras_tkr(), np.array(coords)) + return ct, coords + + +def test_ieeg_elec_locate_io(renderer_interactive_pyvistaqt): + """Test the input/output of the intracranial location GUI.""" + nib = pytest.importorskip("nibabel") + import mne.gui + + info = mne.create_info([], 1000) + + # fake as T1 so that aligned + aligned_ct = nib.load(subjects_dir / subject / "mri" / "brain.mgz") + + trans = mne.transforms.Transform("head", "mri") + with pytest.raises(ValueError, match="No channels found in `info` to locate"): + mne.gui.locate_ieeg(info, trans, aligned_ct, subject, subjects_dir) + + info = mne.create_info(["test"], 1000, "seeg") + montage = mne.channels.make_dig_montage({"test": [0, 0, 0]}, coord_frame="mri") + with pytest.warns(RuntimeWarning, match="nasion not found"): + info.set_montage(montage) + with pytest.raises(RuntimeError, match='must be in the "head" coordinate frame'): + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + mne.gui.locate_ieeg(info, trans, aligned_ct, subject, subjects_dir) + + +@pytest.mark.allow_unclosed_pyside2 +@requires_version("sphinx_gallery") +@testing.requires_testing_data +def test_locate_scraper(renderer_interactive_pyvistaqt, _fake_CT_coords, tmp_path): + """Test sphinx-gallery scraping of the GUI.""" + import mne.gui + + raw = mne.io.read_raw_fif(raw_path) + raw.pick_types(eeg=True) + ch_dict = { + "EEG 001": "LAMY 1", + "EEG 002": "LAMY 2", + "EEG 003": "LSTN 1", + "EEG 004": "LSTN 2", + } + raw.pick_channels(list(ch_dict.keys())) + raw.rename_channels(ch_dict) + raw.set_montage(None) + aligned_ct, _ = _fake_CT_coords + trans = mne.read_trans(fname_trans) + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + gui = mne.gui.locate_ieeg( + raw.info, trans, aligned_ct, subject=subject, subjects_dir=subjects_dir + ) + (tmp_path / "_images").mkdir() + image_path = tmp_path / "_images" / "temp.png" + gallery_conf = dict(builder_name="html", src_dir=tmp_path) + block_vars = dict( + example_globals=dict(gui=gui), image_path_iterator=iter([str(image_path)]) + ) + assert not image_path.is_file() + assert not getattr(gui, "_scraped", False) + mne.gui._GUIScraper()(None, block_vars, gallery_conf) + assert image_path.is_file() + assert gui._scraped + # no need to call .close + + +@pytest.mark.allow_unclosed_pyside2 +@testing.requires_testing_data +def test_ieeg_elec_locate_display(renderer_interactive_pyvistaqt, _fake_CT_coords): + """Test that the intracranial location GUI displays properly.""" + raw = mne.io.read_raw_fif(raw_path, preload=True) + raw.pick_types(eeg=True) + ch_dict = { + "EEG 001": "LAMY 1", + "EEG 002": "LAMY 2", + "EEG 003": "LSTN 1", + "EEG 004": "LSTN 2", + } + raw.pick_channels(list(ch_dict.keys())) + raw.rename_channels(ch_dict) + raw.set_eeg_reference("average") + raw.set_channel_types({name: "seeg" for name in raw.ch_names}) + raw.set_montage(None) + aligned_ct, coords = _fake_CT_coords + trans = mne.read_trans(fname_trans) + + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + gui = mne.gui.locate_ieeg( + raw.info, + trans, + aligned_ct, + subject=subject, + subjects_dir=subjects_dir, + verbose=True, + ) + + with pytest.raises(ValueError, match="read-only"): + gui._ras[:] = coords[0] # start in the right position + gui.set_RAS(coords[0]) + gui.mark_channel() + + with pytest.raises(ValueError, match="not found"): + gui.mark_channel("foo") + + assert not gui._lines and not gui._lines_2D # no lines for one contact + for ci, coord in enumerate(coords[1:], 1): + coord_vox = apply_trans(gui._ras_vox_t, coord) + with use_log_level("debug"): + _fake_click( + gui._figs[2], + gui._figs[2].axes[0], + coord_vox[:-1], + xform="data", + kind="release", + ) + assert_allclose(coord[:2], gui._ras[:2], atol=0.1, err_msg=f"coords[{ci}][:2]") + assert_allclose(coord[2], gui._ras[2], atol=2, err_msg=f"coords[{ci}][2]") + gui.mark_channel() + + # ensure a 3D line was made for each group + assert len(gui._lines) == 2 + + # test snap to center + gui._ch_index = 0 + gui.set_RAS(coords[0]) # move to first position + gui.mark_channel() + assert_allclose(coords[0], gui._chs["LAMY 1"], atol=0.2) + gui._snap_button.click() + assert gui._snap_button.text() == "Off" + # now make sure no snap happens + gui._ch_index = 0 + gui.set_RAS(coords[1] + 1) + gui.mark_channel() + assert_allclose(coords[1] + 1, gui._chs["LAMY 1"], atol=0.01) + # check that it turns back on + gui._snap_button.click() + assert gui._snap_button.text() == "On" + + # test remove + gui.remove_channel("LAMY 2") + assert np.isnan(gui._chs["LAMY 2"]).all() + + with pytest.raises(ValueError, match="not found"): + gui.remove_channel("foo") + + # check that raw object saved + assert not np.isnan(raw.info["chs"][0]["loc"][:3]).any() # LAMY 1 + assert np.isnan(raw.info["chs"][1]["loc"][:3]).all() # LAMY 2 (removed) + + # move sliders + gui._alpha_slider.setValue(75) + assert gui._ch_alpha == 0.75 + gui._radius_slider.setValue(5) + assert gui._radius == 5 + ct_sum_before = np.nansum(gui._images["ct"][0].get_array().data) + gui._ct_min_slider.setValue(500) + assert np.nansum(gui._images["ct"][0].get_array().data) < ct_sum_before + + # test buttons + gui._toggle_show_brain() + assert "mri" in gui._images + assert "local_max" not in gui._images + gui._toggle_show_max() + assert "local_max" in gui._images + assert "mip" not in gui._images + gui._toggle_show_mip() + assert "mip" in gui._images + assert "mip_chs" in gui._images + assert len(gui._lines_2D) == 1 # LAMY only has one contact + + # check montage + montage = raw.get_montage() + assert montage is not None + assert_allclose( + montage.get_positions()["ch_pos"]["LAMY 1"], + [0.00726235, 0.01713514, 0.04167233], + atol=0.01, + ) + gui.close() diff --git a/mne_gui_addons/tests/test_vol_stc.py b/mne_gui_addons/tests/test_vol_stc.py new file mode 100644 index 0000000..54d1e5c --- /dev/null +++ b/mne_gui_addons/tests/test_vol_stc.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# Authors: Alex Rockhill +# +# License: BSD-3-clause + +import sys +import numpy as np +from numpy.testing import assert_allclose + +import pytest + +import mne +from mne.datasets import testing +from mne.io.constants import FIFF +from mne.viz.utils import _fake_click + +data_path = testing.data_path(download=False) +subject = "sample" +subjects_dir = data_path / "subjects" +fname_raw = data_path / "MEG" / "sample" / "sample_audvis_trunc_raw.fif" +fname_fwd_vol = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-vol-7-fwd.fif" +fname_fwd = data_path / "MEG" / "sample" / "sample_audvis_trunc-meg-eeg-oct-4-fwd.fif" + +# TO DO: remove when Azure fixed, causes +# 'Windows fatal exception: access violation' +# but fails to replicate locally on a Windows machine +if sys.platform == "win32": + pytest.skip("Azure CI problem on Windows", allow_module_level=True) + + +def _fake_stc(src_type="vol"): + """Fake a 5D source time estimate.""" + rng = np.random.default_rng(11) + n_epochs = 3 + info = mne.io.read_info(fname_raw) + info = mne.pick_info(info, mne.pick_types(info, meg="grad")) + if src_type == "vol": + src = mne.setup_volume_source_space( + subject="sample", + subjects_dir=subjects_dir, + mri="aseg.mgz", + volume_label="Left-Cerebellum-Cortex", + pos=20, + add_interpolator=False, + ) + else: + assert src_type == "surf" + forward = mne.read_forward_solution(fname_fwd) + src = forward["src"] + for this_src in src: + this_src["coord_frame"] = FIFF.FIFFV_COORD_MRI + this_src["subject_his_id"] = "sample" + freqs = np.arange(8, 10) + times = np.arange(0.1, 0.11, 1 / info["sfreq"]) + data = rng.integers( + -1000, 1000, size=(n_epochs, len(info.ch_names), freqs.size, times.size) + ) + 1j * rng.integers( + -1000, 1000, size=(n_epochs, len(info.ch_names), freqs.size, times.size) + ) + epochs_tfr = mne.time_frequency.EpochsTFR(info, data, times=times, freqs=freqs) + nuse = sum([this_src["nuse"] for this_src in src]) + stc_data = rng.integers( + -1000, 1000, size=(n_epochs, nuse, 3, freqs.size, times.size) + ) + 1j * rng.integers(-1000, 1000, size=(n_epochs, nuse, 3, freqs.size, times.size)) + return stc_data, src, epochs_tfr + + +@pytest.mark.allow_unclosed_pyside2 +def test_stc_viewer_io(renderer_interactive_pyvistaqt): + """Test the input/output of the stc viewer GUI.""" + pytest.importorskip("nibabel") + pytest.importorskip("dipy") + from mne_gui_addons._vol_stc import VolSourceEstimateViewer + + stc_data, src, epochs_tfr = _fake_stc() + with pytest.raises( + NotImplementedError, + match="surface source estimate " "viewing is not yet supported", + ): + VolSourceEstimateViewer(stc_data, inst=epochs_tfr) + with pytest.raises(NotImplementedError, match="source estimate object"): + VolSourceEstimateViewer(stc_data, src=src) + with pytest.raises(ValueError, match="`data` must be an array"): + VolSourceEstimateViewer( + "foo", subject="sample", subjects_dir=subjects_dir, src=src, inst=epochs_tfr + ) + with pytest.raises(ValueError, match="Number of epochs in `inst` does not match"): + VolSourceEstimateViewer(stc_data[1:], src=src, inst=epochs_tfr) + with pytest.raises(RuntimeError, match="ource vertices in `data` do not match "): + VolSourceEstimateViewer( + stc_data[:, :1], + subject="sample", + subjects_dir=subjects_dir, + src=src, + inst=epochs_tfr, + ) + src[0]["coord_frame"] = FIFF.FIFFV_COORD_HEAD + with pytest.raises(RuntimeError, match="must be in the `mri`"): + VolSourceEstimateViewer( + stc_data, + subject="sample", + subjects_dir=subjects_dir, + src=src, + inst=epochs_tfr, + ) + src[0]["coord_frame"] = FIFF.FIFFV_COORD_MRI + + src[0]["subject_his_id"] = "foo" + with pytest.raises(RuntimeError, match="Source space subject"): + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + VolSourceEstimateViewer( + stc_data, + subject="sample", + subjects_dir=subjects_dir, + src=src, + inst=epochs_tfr, + ) + + with pytest.raises(ValueError, match="Frequencies in `inst` do not match"): + VolSourceEstimateViewer(stc_data[:, :, :, 1:], src=src, inst=epochs_tfr) + + with pytest.raises(ValueError, match="Complex data is required"): + VolSourceEstimateViewer(stc_data.real, src=src, inst=epochs_tfr) + + with pytest.raises(ValueError, match="Times in `inst` do not match"): + VolSourceEstimateViewer(stc_data[:, :, :, :, 1:], src=src, inst=epochs_tfr) + + +@pytest.mark.allow_unclosed_pyside2 +@testing.requires_testing_data +def test_stc_viewer_display(renderer_interactive_pyvistaqt): + """Test that the stc viewer GUI displays properly.""" + pytest.importorskip("nibabel") + pytest.importorskip("dipy") + from mne_gui_addons._vol_stc import VolSourceEstimateViewer + + stc_data, src, epochs_tfr = _fake_stc() + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + viewer = VolSourceEstimateViewer( + stc_data, + subject="sample", + subjects_dir=subjects_dir, + src=src, + inst=epochs_tfr, + ) + # test go to max + viewer._go_to_extreme_button.click() + assert_allclose(viewer._ras, [-20, -60, -20], atol=0.01) + + src_coord = viewer._get_src_coord() + stc_idx = viewer._src_lut[src_coord] + + viewer._epoch_selector.setCurrentText("Epoch 0") + assert viewer._epoch_idx == "Epoch 0" + + viewer._freq_slider.setValue(1) + assert viewer._f_idx == 1 + + viewer._time_slider.setValue(2) + assert viewer._t_idx == 2 + + plot_data = np.linalg.norm((stc_data[0] * stc_data[0].conj()).real, axis=1)[stc_idx] + assert_allclose(plot_data, viewer._stc_plot.get_array()) + + # test clicking on stc plot + _fake_click(viewer._fig, viewer._fig.axes[0], (0, 0), xform="data", kind="release") + assert viewer._t_idx == 0 + assert viewer._f_idx == 0 + + # test baseline + for mode in ("zscore", "ratio"): + viewer.set_baseline((0.1, None), mode) + + # done with time-frequency, close + viewer.close() + + # test time only, not frequencies + epochs = mne.EpochsArray( + epochs_tfr.data[:, :, 0].real, epochs_tfr.info, tmin=epochs_tfr.tmin + ) + stc_time_data = stc_data[:, :, :, 0:1].real + with pytest.warns(RuntimeWarning, match="`pial` surface not found"): + viewer = VolSourceEstimateViewer( + stc_time_data, + subject="sample", + subjects_dir=subjects_dir, + src=src, + inst=epochs, + ) + + # test go to max + viewer._go_to_extreme_button.click() + assert_allclose(viewer._ras, [-20, -60, -20], atol=0.01) + + src_coord = viewer._get_src_coord() + stc_idx = viewer._src_lut[src_coord] + + viewer._epoch_selector.setCurrentText("Epoch 0") + assert viewer._epoch_idx == "Epoch 0" + + with pytest.raises( + ValueError, match="Source estimate does " "not contain frequencies" + ): + viewer.set_freq(10) + + viewer._time_slider.setValue(2) + assert viewer._t_idx == 2 + + assert_allclose( + np.linalg.norm(stc_time_data[0], axis=1)[stc_idx][0], + viewer._stc_plot.get_data()[1], + ) + viewer.close() + + +@testing.requires_testing_data +def test_stc_viewer_surface(renderer_interactive_pyvistaqt): + """Test the stc viewer with a surface source space.""" + pytest.importorskip("nibabel") + pytest.importorskip("dipy") + from mne_gui_addons._vol_stc import VolSourceEstimateViewer + + stc_data, src, epochs_tfr = _fake_stc(src_type="surf") + with pytest.raises(RuntimeError, match="not implemented yet"): + VolSourceEstimateViewer( + stc_data, + subject="sample", + subjects_dir=subjects_dir, + src=src, + inst=epochs_tfr, + ) diff --git a/pyproject.toml b/pyproject.toml index e2b2719..75679fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,9 @@ dependencies = [ "pyvista", "pyvistaqt", "mne", + "nibabel", + "dipy>=1.4", + "traitlets", "setuptools >=65", ] dynamic = ["version"] @@ -27,7 +30,6 @@ dynamic = ["version"] tests = [ "pytest", "pytest-cov", - "pytest-qt", "black", # function signature formatting ]