Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
tsbinns and larsoner authored Jan 10, 2025
1 parent 0b64b43 commit d4b2c7b
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
1 change: 0 additions & 1 deletion mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,6 @@ def tfr_array_multitaper(
* ``'itc'`` : inter-trial coherence.
* ``'avg_power_itc'`` : average of single trial power and inter-trial
coherence across trials.
%(n_jobs)s
The parallelization is implemented across channels.
return_weights : bool, default False
Expand Down
7 changes: 5 additions & 2 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,6 @@ def _compute_tfr(
* 'itc' : inter-trial coherence.
* 'avg_power_itc' : average of single trial power and inter-trial
coherence across trials.
return_weights : bool, default False
Whether to return the taper weights. Only applies if method='multitaper' and
output='complex' or 'phase'.
Expand Down Expand Up @@ -2039,7 +2038,7 @@ def plot(
want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping
want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping
want_shape = [
n for i, n in enumerate(want_shape) if self._dims[i] != "taper"
n for dim, n in zip(self._dims, want_shape) dim != "taper"
] # tapers must be aggregated over by now
want_shape = tuple(want_shape)
# combine
Expand Down Expand Up @@ -4273,6 +4272,10 @@ def _prep_data_for_plot(
if np.iscomplexobj(data): # complex coefficients → power
data = _tfr_from_mt(data, taper_weights)
else: # tapered phase data → weighted phase data
# channels, tapers, freqs, time
assert data.ndim == 4
# weights as a function of (tapers, freqs)
assert taper_weights.ndim == 2
data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1)
# handle remaining complex amplitude → real power
if np.iscomplexobj(data):
Expand Down

0 comments on commit d4b2c7b

Please sign in to comment.