diff --git a/doc/changes/devel/13067.bugfix.rst b/doc/changes/devel/13067.bugfix.rst new file mode 100644 index 00000000000..237df7623d5 --- /dev/null +++ b/doc/changes/devel/13067.bugfix.rst @@ -0,0 +1 @@ +Fix bug where taper weights were not correctly applied when computing multitaper power with :meth:`mne.Epochs.compute_tfr` and :func:`mne.time_frequency.tfr_array_multitaper`, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/export/_export.py b/mne/export/_export.py index 490bf986895..4b93fda917e 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -25,6 +25,14 @@ def export_raw( %(export_warning)s + .. warning:: + When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the + same as ``raw.annotations.orig_time``. This guarantees that the annotations are + in the same reference frame as the samples. When + :attr:`Raw.first_time ` is not zero (e.g., after + cropping), the onsets are automatically corrected so that onsets are always + relative to the first sample. + Parameters ---------- %(fname_export_params)s @@ -216,7 +224,6 @@ def _infer_check_export_fmt(fmt, fname, supported_formats): supported_str = ", ".join(supported) raise ValueError( - f"Format '{fmt}' is not supported. " - f"Supported formats are {supported_str}." + f"Format '{fmt}' is not supported. Supported formats are {supported_str}." ) return fmt diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index cd3a97ab90a..26645370aec 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -255,22 +255,6 @@ def test_tfr_morlet(): # computed within the method. assert_allclose(epochs_amplitude_2.data**2, epochs_power_picks.data) - # test that averaging power across tapers when multitaper with - # output='complex' gives the same as output='power' - epoch_data = epochs.get_data() - multitaper_power = tfr_array_multitaper( - epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="power" - ) - multitaper_complex = tfr_array_multitaper( - epoch_data, epochs.info["sfreq"], freqs, n_cycles, output="complex" - ) - - taper_dim = 2 - power_from_complex = (multitaper_complex * multitaper_complex.conj()).real.mean( - axis=taper_dim - ) - assert_allclose(power_from_complex, multitaper_power) - print(itc) # test repr print(itc.ch_names) # test property itc += power # test add diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index eaf173092bb..51839e23226 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -266,6 +266,7 @@ def _make_dpss( The wavelets time series. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +282,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +304,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -529,15 +534,18 @@ def _compute_tfr( if method == "morlet": W = morlet(sfreq, freqs, n_cycles=n_cycles, zero_mean=zero_mean) Ws = [W] # to have same dimensionality as the 'multitaper' case + weights = None # no tapers for Morlet estimates elif method == "multitaper": - Ws = _make_dpss( + Ws, weights = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=True, # required for converting complex → power ) + weights = np.asarray(weights) # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -560,7 +568,7 @@ def _compute_tfr( if ("avg_" in output) or ("itc" in output): out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": - out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + out = np.empty((n_chans, n_epochs, n_tapers, n_freqs, n_times), dtype) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -571,7 +579,7 @@ def _compute_tfr( # Parallelization is applied across channels. tfrs = parallel( - my_cwt(channel, Ws, output, use_fft, "same", decim, method) + my_cwt(channel, Ws, output, use_fft, "same", decim, weights) for channel in epoch_data.transpose(1, 0, 2) ) @@ -581,10 +589,8 @@ def _compute_tfr( if ("avg_" not in output) and ("itc" not in output): # This is to enforce that the first dimension is for epochs - if output in ["complex", "phase"] and method == "multitaper": - out = out.transpose(2, 0, 1, 3, 4) - else: - out = out.transpose(1, 0, 2, 3) + out = np.moveaxis(out, 1, 0) + return out @@ -658,7 +664,7 @@ def _check_tfr_param( return freqs, sfreq, zero_mean, n_cycles, time_bandwidth, decim -def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): +def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, weights=None): """Aux. function to _compute_tfr. Loops time-frequency transform across wavelets and epochs. @@ -685,9 +691,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): See numpy.convolve. decim : slice The decimation slice: e.g. power[:, decim] - method : str | None - Used only for multitapering to create tapers dimension in the output - if ``output in ['complex', 'phase']``. + weights : array, shape (n_tapers, n_wavelets) | None + Concentration weights for each taper in the wavelets, if present. """ # Set output type dtype = np.float64 @@ -701,10 +706,12 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): n_freqs = len(Ws[0]) if ("avg_" in output) or ("itc" in output): tfrs = np.zeros((n_freqs, n_times), dtype=dtype) - elif output in ["complex", "phase"] and method == "multitaper": - tfrs = np.zeros((n_tapers, n_epochs, n_freqs, n_times), dtype=dtype) + elif output in ["complex", "phase"] and weights is not None: + tfrs = np.zeros((n_epochs, n_tapers, n_freqs, n_times), dtype=dtype) else: tfrs = np.zeros((n_epochs, n_freqs, n_times), dtype=dtype) + if weights is not None: + weights = np.expand_dims(weights, axis=-1) # add singleton time dimension # Loops across tapers. for taper_idx, W in enumerate(Ws): @@ -719,6 +726,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Loop across epochs for epoch_idx, tfr in enumerate(coefs): # Transform complex values + if output not in ["complex", "phase"] and weights is not None: + tfr = weights[taper_idx] * tfr # weight each taper estimate if output in ["power", "avg_power"]: tfr = (tfr * tfr.conj()).real # power elif output == "phase": @@ -734,8 +743,8 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): # Stack or add if ("avg_" in output) or ("itc" in output): tfrs += tfr - elif output in ["complex", "phase"] and method == "multitaper": - tfrs[taper_idx, epoch_idx] += tfr + elif output in ["complex", "phase"] and weights is not None: + tfrs[epoch_idx, taper_idx] += tfr else: tfrs[epoch_idx] += tfr @@ -749,9 +758,14 @@ def _time_frequency_loop(X, Ws, output, use_fft, mode, decim, method=None): if ("avg_" in output) or ("itc" in output): tfrs /= n_epochs - # Normalization by number of taper - if n_tapers > 1 and output not in ["complex", "phase"]: - tfrs /= n_tapers + # Normalization by taper weights + if n_tapers > 1 and output not in ["complex", "phase", "itc"]: + if "avg_" not in output: # add singleton epochs dimension to weights + weights = np.expand_dims(weights, axis=0) + tfrs.real *= 2 / (weights * weights.conj()).real.sum(axis=-3) + if output == "avg_power_itc": # weight itc by the number of tapers + tfrs.imag = tfrs.imag / n_tapers + return tfrs diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 60e02432c7b..dd6925bd1b1 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["export_fmt_support_epochs"] = """\ Supported formats: - - EEGLAB (``.set``, uses :mod:`eeglabio`) + +- EEGLAB (``.set``, uses :mod:`eeglabio`) """ docdict["export_fmt_support_evoked"] = """\ Supported formats: - - MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) + +- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) """ docdict["export_fmt_support_raw"] = """\ Supported formats: - - BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) - - EEGLAB (``.set``, uses :mod:`eeglabio`) - - EDF (``.edf``, uses `edfio `_) + +- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) +- EEGLAB (``.set``, uses :mod:`eeglabio`) +- EDF (``.edf``, uses `edfio `_) """ # noqa: E501 docdict["export_warning"] = """\