Skip to content

Commit

Permalink
Backport PR #13067 on branch maint/1.9 ([BUG] Fix taper weighting in …
Browse files Browse the repository at this point in the history
…computation of TFR multitaper power) (#13072)

Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
tsbinns and larsoner authored Jan 21, 2025
1 parent 672bdf4 commit 96d22f8
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 42 deletions.
1 change: 1 addition & 0 deletions doc/changes/devel/13067.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_.
11 changes: 9 additions & 2 deletions mne/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ def export_raw(
%(export_warning)s
.. warning::
When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the
same as ``raw.annotations.orig_time``. This guarantees that the annotations are
in the same reference frame as the samples. When
:attr:`Raw.first_time <mne.io.Raw.first_time>` is not zero (e.g., after
cropping), the onsets are automatically corrected so that onsets are always
relative to the first sample.
Parameters
----------
%(fname_export_params)s
Expand Down Expand Up @@ -216,7 +224,6 @@ def _infer_check_export_fmt(fmt, fname, supported_formats):

supported_str = ", ".join(supported)
raise ValueError(
f"Format '{fmt}' is not supported. "
f"Supported formats are {supported_str}."
f"Format '{fmt}' is not supported. Supported formats are {supported_str}."
)
return fmt
16 changes: 0 additions & 16 deletions mne/time_frequency/tests/test_tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,22 +255,6 @@ def test_tfr_morlet():
# computed within the method.
assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data)

# test that averaging power across tapers when multitaper with
# output='complex' gives the same as output='power'
epoch_data = epochs.get_data()
multitaper_power = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power"
)
multitaper_complex = tfr_array_multitaper(
epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex"
)

taper_dim = 2
power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean(
axis=taper_dim
)
assert_allclose(power_from_complex, multitaper_power)

print(itc) # test repr
print(itc.ch_names) # test property
itc += power # test add
Expand Down
52 changes: 33 additions & 19 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def _make_dpss(
The wavelets time series.
"""
Ws = list()
Cs = list()

freqs = np.array(freqs)
if np.any(freqs <= 0):
Expand All @@ -281,6 +282,7 @@ def _make_dpss(

for m in range(n_taps):
Wm = list()
Cm = list()
for k, f in enumerate(freqs):
if len(n_cycles) != 1:
this_n_cycles = n_cycles[k]
Expand All @@ -302,12 +304,15 @@ def _make_dpss(
real_offset = Wk.mean()
Wk -= real_offset
Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel())
Ck = np.sqrt(conc[m])

Wm.append(Wk)
Cm.append(Ck)

Ws.append(Wm)
Cs.append(Cm)
if return_weights:
return Ws, conc
return Ws, Cs
return Ws


Expand Down Expand Up @@ -529,15 +534,18 @@ def _compute_tfr(
if method == "morlet":
W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean)
Ws = [W] # to have same dimensionality as the 'multitaper' case
weights = None # no tapers for Morlet estimates

elif method == "multitaper":
Ws = _make_dpss(
Ws, weights = _make_dpss(
sfreq,
freqs,
n_cycles=n_cycles,
time_bandwidth=time_bandwidth,
zero_mean=zero_mean,
return_weights=True, # required for converting complex → power
)
weights = np.asarray(weights)

# Check wavelets
if len(Ws[0][0]) > epoch_data.shape[2]:
Expand All @@ -560,7 +568,7 @@ def _compute_tfr(
if ("avg_" in output) or ("itc" in output):
out = np.empty((n_chans, n_freqs, n_times), dtype)
elif output in ["complex", "phase"] and method == "multitaper":
out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype)
out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype)
else:
out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype)

Expand All @@ -571,7 +579,7 @@ def _compute_tfr(

# Parallelization is applied across channels.
tfrs = parallel(
my_cwt(channel, Ws, output, use_fft, "same", decim, method)
my_cwt(channel, Ws, output, use_fft, "same", decim, weights)
for channel in epoch_data.transpose(1, 0, 2)
)

Expand All @@ -581,10 +589,8 @@ def _compute_tfr(

if ("avg_" not in output) and ("itc" not in output):
# This is to enforce that the first dimension is for epochs
if output in ["complex", "phase"] and method == "multitaper":
out = out.transpose(2, 0, 1, 3, 4)
else:
out = out.transpose(1, 0, 2, 3)
out = np.moveaxis(out, 1, 0)

return out


Expand Down Expand Up @@ -658,7 +664,7 @@ def _check_tfr_param(
return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim


def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None):
"""Aux. function to _compute_tfr.
Loops time-frequency transform across wavelets and epochs.
Expand All @@ -685,9 +691,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
See numpy.convolve.
decim : slice
The decimation slice: e.g. power[:, decim]
method : str | None
Used only for multitapering to create tapers dimension in the output
if ``output in ['complex', 'phase']``.
weights : array, shape (n_tapers, n_wavelets) | None
Concentration weights for each taper in the wavelets, if present.
"""
# Set output type
dtype = np.float64
Expand All @@ -701,10 +706,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
n_freqs = len(Ws[0])
if ("avg_" in output) or ("itc" in output):
tfrs = np.zeros((n_freqs, n_times), dtype=dtype)
elif output in ["complex", "phase"] and method == "multitaper":
tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype)
elif output in ["complex", "phase"] and weights is not None:
tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype)
else:
tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype)
if weights is not None:
weights = np.expand_dims(weights, axis=-1) # add singleton time dimension

# Loops across tapers.
for taper_idx, W in enumerate(Ws):
Expand All @@ -719,6 +726,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Loop across epochs
for epoch_idx, tfr in enumerate(coefs):
# Transform complex values
if output not in ["complex", "phase"] and weights is not None:
tfr = weights[taper_idx] * tfr # weight each taper estimate
if output in ["power", "avg_power"]:
tfr = (tfr * tfr.conj()).real # power
elif output == "phase":
Expand All @@ -734,8 +743,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
# Stack or add
if ("avg_" in output) or ("itc" in output):
tfrs += tfr
elif output in ["complex", "phase"] and method == "multitaper":
tfrs[taper_idx, epoch_idx] += tfr
elif output in ["complex", "phase"] and weights is not None:
tfrs[epoch_idx, taper_idx] += tfr
else:
tfrs[epoch_idx] += tfr

Expand All @@ -749,9 +758,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None):
if ("avg_" in output) or ("itc" in output):
tfrs /= n_epochs

# Normalization by number of taper
if n_tapers > 1 and output not in ["complex", "phase"]:
tfrs /= n_tapers
# Normalization by taper weights
if n_tapers > 1 and output not in ["complex", "phase", "itc"]:
if "avg_" not in output: # add singleton epochs dimension to weights
weights = np.expand_dims(weights, axis=0)
tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3)
if output == "avg_power_itc": # weight itc by the number of tapers
tfrs.imag = tfrs.imag / n_tapers

return tfrs


Expand Down
13 changes: 8 additions & 5 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75):

docdict["export_fmt_support_epochs"] = """\
Supported formats:
- EEGLAB (``.set``, uses :mod:`eeglabio`)
- EEGLAB (``.set``, uses :mod:`eeglabio`)
"""

docdict["export_fmt_support_evoked"] = """\
Supported formats:
- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`)
- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`)
"""

docdict["export_fmt_support_raw"] = """\
Supported formats:
- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv <https://github.com/bids-standard/pybv>`_)
- EEGLAB (``.set``, uses :mod:`eeglabio`)
- EDF (``.edf``, uses `edfio <https://github.com/the-siesta-group/edfio>`_)
- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv <https://github.com/bids-standard/pybv>`_)
- EEGLAB (``.set``, uses :mod:`eeglabio`)
- EDF (``.edf``, uses `edfio <https://github.com/the-siesta-group/edfio>`_)
""" # noqa: E501

docdict["export_warning"] = """\
Expand Down

0 comments on commit 96d22f8

Please sign in to comment.