diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 463879a8860..98705e838c2 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -503,7 +503,6 @@ def tfr_array_multitaper( * ``'itc'`` : inter-trial coherence. * ``'avg_power_itc'`` : average of single trial power and inter-trial coherence across trials. - %(n_jobs)s The parallelization is implemented across channels. return_weights : bool, default False diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 3f099989eb3..f34905980cc 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -486,7 +486,6 @@ def _compute_tfr( * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. - return_weights : bool, default False Whether to return the taper weights. Only applies if method='multitaper' and output='complex' or 'phase'. @@ -2039,7 +2038,7 @@ def plot( want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping want_shape = [ - n for i, n in enumerate(want_shape) if self._dims[i] != "taper" + n for dim, n in zip(self._dims, want_shape) dim != "taper" ] # tapers must be aggregated over by now want_shape = tuple(want_shape) # combine @@ -4273,6 +4272,10 @@ def _prep_data_for_plot( if np.iscomplexobj(data): # complex coefficients → power data = _tfr_from_mt(data, taper_weights) else: # tapered phase data → weighted phase data + # channels, tapers, freqs, time + assert data.ndim == 4 + # weights as a function of (tapers, freqs) + assert taper_weights.ndim == 2 data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1) # handle remaining complex amplitude → real power if np.iscomplexobj(data):