From 1d2635f84a55785c3531cfe4027eda3820a7fb31 Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Mon, 13 Jan 2025 20:00:00 +0000 Subject: [PATCH 1/5] [ENH] Add option to store and return TFR taper weights (#12910) Co-authored-by: Daniel McCloy Co-authored-by: Eric Larson --- doc/changes/devel/12910.newfeature.rst | 1 + mne/time_frequency/multitaper.py | 10 + mne/time_frequency/tests/test_tfr.py | 221 +++++++++++++-- mne/time_frequency/tfr.py | 362 +++++++++++++++++-------- mne/utils/docs.py | 12 + mne/utils/numerics.py | 3 + mne/viz/tests/test_topomap.py | 25 +- mne/viz/topomap.py | 14 +- 8 files changed, 507 insertions(+), 141 deletions(-) create mode 100644 doc/changes/devel/12910.newfeature.rst diff --git a/doc/changes/devel/12910.newfeature.rst b/doc/changes/devel/12910.newfeature.rst new file mode 100644 index 00000000000..95605c11017 --- /dev/null +++ b/doc/changes/devel/12910.newfeature.rst @@ -0,0 +1 @@ +Added the option to return taper weights from :func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 73a3308685d..98705e838c2 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -471,6 +471,7 @@ def tfr_array_multitaper( output="complex", n_jobs=None, *, + return_weights=False, verbose=None, ): """Compute Time-Frequency Representation (TFR) using DPSS tapers. @@ -504,6 +505,11 @@ def tfr_array_multitaper( coherence across trials. %(n_jobs)s The parallelization is implemented across channels. + return_weights : bool, default False + If True, return the taper weights. Only applies if ``output='complex'`` or + ``'phase'``. + + .. versionadded:: 1.10.0 %(verbose)s Returns @@ -520,6 +526,9 @@ def tfr_array_multitaper( If ``output`` is ``'avg_power_itc'``, the real values in ``out`` contain the average power and the imaginary values contain the inter-trial coherence: :math:`out = power_{avg} + i * ITC`. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and + ``return_weights=True``. 
See Also -------- @@ -550,6 +559,7 @@ def tfr_array_multitaper( use_fft=use_fft, decim=decim, output=output, + return_weights=return_weights, n_jobs=n_jobs, verbose=verbose, ) diff --git a/mne/time_frequency/tests/test_tfr.py b/mne/time_frequency/tests/test_tfr.py index e68ea9e6e18..6fa3a833be2 100644 --- a/mne/time_frequency/tests/test_tfr.py +++ b/mne/time_frequency/tests/test_tfr.py @@ -432,17 +432,21 @@ def test_tfr_morlet(): def test_dpsswavelet(): """Test DPSS tapers.""" freqs = np.arange(5, 25, 3) - Ws = _make_dpss( - 1000, freqs=freqs, n_cycles=freqs / 2.0, time_bandwidth=4.0, zero_mean=True + Ws, weights = _make_dpss( + 1000, + freqs=freqs, + n_cycles=freqs / 2.0, + time_bandwidth=4.0, + zero_mean=True, + return_weights=True, ) - assert len(Ws) == 3 # 3 tapers expected + assert np.shape(Ws)[:2] == (3, len(freqs)) # 3 tapers expected + assert np.shape(Ws)[:2] == np.shape(weights) # weights of shape (tapers, freqs) # Check that zero mean is true assert np.abs(np.mean(np.real(Ws[0][0]))) < 1e-5 - assert len(Ws[0]) == len(freqs) # As many wavelets as asked for - @pytest.mark.slowtest def test_tfr_multitaper(): @@ -664,6 +668,17 @@ def test_tfr_io(inst, average_tfr, request, tmp_path): with tfr.info._unlock(): tfr.info["meas_date"] = want assert tfr_loaded == tfr + # test with taper dimension and weights + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, tfr.shape[2])) # tapers x freqs + state = tfr.__getstate__() + state["data"] = np.repeat(np.expand_dims(tfr.data, 2), n_tapers, axis=2) # add dim + state["weights"] = weights # add weights + state["dims"] = ("epoch", "channel", "taper", "freq", "time") # update dims + tfr = EpochsTFR(inst=state) + tfr.save(fname, overwrite=True) + tfr_loaded = read_tfrs(fname) + assert tfr_loaded == tfr # test overwrite with pytest.raises(OSError, match="Destination file exists."): tfr.save(fname, overwrite=False) @@ -722,17 +737,31 @@ def test_average_tfr_init(full_evoked): AverageTFR(inst=full_evoked, method="stockwell", freqs=freqs_linspace) -def test_epochstfr_init_errors(epochs_tfr): - """Test __init__ for EpochsTFR.""" - state = epochs_tfr.__getstate__() - with pytest.raises(ValueError, match="EpochsTFR data should be 4D, got 3"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[..., 0])) +@pytest.mark.parametrize("inst", ("raw_tfr", "epochs_tfr", "average_tfr")) +def test_tfr_init_errors(inst, request, average_tfr): + """Test __init__ for {Raw,Epochs,Average}TFR.""" + # Load data + inst = _get_inst(inst, request, average_tfr=average_tfr) + state = inst.__getstate__() + # Prepare for TFRArray object instantiation + inst_name = inst.__class__.__name__ + class_mapping = dict(RawTFR=RawTFR, EpochsTFR=EpochsTFR, AverageTFR=AverageTFR) + ndims_mapping = dict( + RawTFR=("3D or 4D"), EpochsTFR=("4D or 5D"), AverageTFR=("3D or 4D") + ) + TFR = class_mapping[inst_name] + allowed_ndims = ndims_mapping[inst_name] + # Check errors caught + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=inst.data[..., 0])) + with pytest.raises(ValueError, match=f".*TFR data should be {allowed_ndims}"): + TFR(inst=state | dict(data=np.expand_dims(inst.data, axis=(0, 1)))) with pytest.raises(ValueError, match="Channel axis of data .* doesn't match info"): - EpochsTFR(inst=state | dict(data=epochs_tfr.data[:, :-1])) + TFR(inst=state | dict(data=inst.data[..., :-1, :, :])) with pytest.raises(ValueError, match="Time axis of data.*doesn't match times attr"): - EpochsTFR(inst=state | 
dict(times=epochs_tfr.times[:-1])) + TFR(inst=state | dict(times=inst.times[:-1])) with pytest.raises(ValueError, match="Frequency axis of.*doesn't match freqs attr"): - EpochsTFR(inst=state | dict(freqs=epochs_tfr.freqs[:-1])) + TFR(inst=state | dict(freqs=inst.freqs[:-1])) @pytest.mark.parametrize( @@ -830,6 +859,25 @@ def test_plot(): plt.close("all") +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_plot_multitaper_complex_phase(output): + """Test TFR plotting of data with a taper dimension.""" + # Create example data with a taper dimension + n_chans, n_tapers, n_freqs, n_times = (3, 4, 2, 3) + data = np.random.rand(n_chans, n_tapers, n_freqs, n_times) + if output == "complex": + data = data + np.random.rand(*data.shape) * 1j # add imaginary data + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + tfr = AverageTFRArray( + info=info, data=data, times=times, freqs=freqs, weights=weights + ) + # Check that plotting works + tfr.plot() + + @pytest.mark.parametrize( "timefreqs,title,combine", ( @@ -1154,6 +1202,15 @@ def test_averaging_epochsTFR(): ): power.average(method=np.mean) + # Check it doesn't run for taper spectra + tapered = epochs.compute_tfr( + method="multitaper", freqs=freqs, n_cycles=n_cycles, output="complex" + ) + with pytest.raises( + NotImplementedError, match=r"Averaging multitaper tapers .* is not supported." + ): + tapered.average() + def test_averaging_freqsandtimes_epochsTFR(): """Test that EpochsTFR averaging freqs methods work.""" @@ -1258,12 +1315,15 @@ def test_to_data_frame(): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) srate = 1000.0 - freqs = np.arange(5) + freqs = np.arange(n_freqs) + tapers = np.arange(n_tapers) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 5 + n_epos) @@ -1276,6 +1336,7 @@ def test_to_data_frame(): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) # test index checking with pytest.raises(ValueError, match="options. 
Valid index options are"): @@ -1287,10 +1348,21 @@ def test_to_data_frame(): # test wide format df_wide = tfr.to_data_frame() assert all(np.isin(tfr.ch_names, df_wide.columns)) - assert all(np.isin(["time", "condition", "freq", "epoch"], df_wide.columns)) + assert all( + np.isin(["time", "condition", "freq", "epoch", "taper"], df_wide.columns) + ) # test long format df_long = tfr.to_data_frame(long_format=True) - expected = ("condition", "epoch", "freq", "time", "channel", "ch_type", "value") + expected = ( + "condition", + "epoch", + "freq", + "time", + "channel", + "ch_type", + "value", + "taper", + ) assert set(expected) == set(df_long.columns) assert set(tfr.ch_names) == set(df_long["channel"]) assert len(df_long) == tfr.data.size @@ -1298,21 +1370,29 @@ def test_to_data_frame(): df_long = tfr.to_data_frame(long_format=True, index=["freq"]) del df_wide, df_long # test whether data is in correct shape - df = tfr.to_data_frame(index=["condition", "epoch", "freq", "time"]) + df = tfr.to_data_frame(index=["condition", "epoch", "taper", "freq", "time"]) data = tfr.data assert_array_equal(df.values[:, 0], data[:, 0, :, :].reshape(1, -1).squeeze()) # compare arbitrary observation: assert ( - df.loc[("he", slice(None), freqs[1], times[2]), ch_names[3]].iat[0] - == data[1, 3, 1, 2] + df.loc[("he", slice(None), tapers[1], freqs[1], times[2]), ch_names[3]].iat[0] + == data[1, 3, 1, 1, 2] ) # Check also for AverageTFR: + # (remove taper dimension before averaging) + state = tfr.__getstate__() + state["data"] = state["data"][:, :, 0] + state["dims"] = ("epoch", "channel", "freq", "time") + state["weights"] = None + tfr = EpochsTFR(inst=state) tfr = tfr.average() with pytest.raises(ValueError, match="options. Valid index options are"): tfr.to_data_frame(index=["epoch", "condition"]) with pytest.raises(ValueError, match='"epoch" is not a valid option'): tfr.to_data_frame(index="epoch") + with pytest.raises(ValueError, match='"taper" is not a valid option'): + tfr.to_data_frame(index="taper") with pytest.raises(TypeError, match="index must be `None` or a string "): tfr.to_data_frame(index=np.arange(400)) # test wide format @@ -1348,11 +1428,13 @@ def test_to_data_frame_index(index): ch_names = ["EEG 001", "EEG 002", "EEG 003", "EEG 004"] n_picks = len(ch_names) ch_types = ["eeg"] * n_picks + n_tapers = 2 n_freqs = 5 n_times = 6 - data = np.random.rand(n_epos, n_picks, n_freqs, n_times) - times = np.arange(6) - freqs = np.arange(5) + data = np.random.rand(n_epos, n_picks, n_tapers, n_freqs, n_times) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.ones((n_tapers, n_freqs)) events = np.zeros((n_epos, 3), dtype=int) events[:, 0] = np.arange(n_epos) events[:, 2] = np.arange(5, 8) @@ -1365,6 +1447,7 @@ def test_to_data_frame_index(index): freqs=freqs, events=events, event_id=event_id, + weights=weights, ) df = tfr.to_data_frame(picks=[0, 2, 3], index=index) # test index order/hierarchy preservation @@ -1372,7 +1455,7 @@ def test_to_data_frame_index(index): index = [index] assert list(df.index.names) == index # test that non-indexed data were present as columns - non_index = list(set(["condition", "time", "freq", "epoch"]) - set(index)) + non_index = list(set(["condition", "time", "freq", "taper", "epoch"]) - set(index)) if len(non_index): assert all(np.isin(non_index, df.columns)) @@ -1538,7 +1621,8 @@ def test_epochs_compute_tfr_stockwell(epochs, freqs, return_itc): def test_epochs_compute_tfr_multitaper_complex_phase(epochs, output): """Test 
Epochs.compute_tfr(output="complex"/"phase").""" tfr = epochs.compute_tfr("multitaper", freqs_linspace, output=output) - assert len(tfr.shape) == 5 + assert len(tfr.shape) == 5 # epoch x channel x taper x freq x time + assert tfr.weights.shape == tfr.shape[2:4] # check weights and coeffs shapes match @pytest.mark.parametrize("copy", (False, True)) @@ -1550,6 +1634,42 @@ def test_epochstfr_iter_evoked(epochs_tfr, copy): assert avgs[0].comment == str(epochs_tfr.events[0, -1]) +@pytest.mark.parametrize("obj_type", ("raw", "epochs", "evoked")) +def test_tfrarray_tapered_spectra(obj_type): + """Test {Raw,Epochs,Average}TFRArray instantiation with tapered spectra.""" + # Create example data with a taper dimension + n_epochs, n_chans, n_tapers, n_freqs, n_times = (5, 3, 4, 2, 6) + data_shape = (n_chans, n_tapers, n_freqs, n_times) + if obj_type == "epochs": + data_shape = (n_epochs,) + data_shape + data = np.random.rand(*data_shape) + times = np.arange(n_times) + freqs = np.arange(n_freqs) + weights = np.random.rand(n_tapers, n_freqs) + info = mne.create_info(n_chans, 1000.0, "eeg") + # Prepare for TFRArray object instantiation + defaults = dict(info=info, data=data, times=times, freqs=freqs) + class_mapping = dict(raw=RawTFRArray, epochs=EpochsTFRArray, evoked=AverageTFRArray) + TFRArray = class_mapping[obj_type] + # Check TFRArray instantiation runs with good data + TFRArray(**defaults, weights=weights) + # Check taper dimension but no weights caught + with pytest.raises( + ValueError, match="Taper dimension in data, but no weights found." + ): + TFRArray(**defaults) + # Check mismatching n_taper in weights caught + with pytest.raises( + ValueError, match=r"Taper axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:-1]) + # Check mismatching n_freq in weights caught + with pytest.raises( + ValueError, match=r"Frequency axis .* doesn't match weights attribute" + ): + TFRArray(**defaults, weights=weights[:, :-1]) + + def test_tfr_proj(epochs): """Test `compute_tfr(proj=True)`.""" epochs.compute_tfr(method="morlet", freqs=freqs_linspace, proj=True) @@ -1731,3 +1851,52 @@ def test_tfr_plot_topomap(inst, ch_type, full_average_tfr, request): assert re.match( rf"Average over \d{{1,3}} {ch_type} channels\.", popup_fig.axes[0].get_title() ) + + +@pytest.mark.parametrize("output", ("complex", "phase")) +def test_tfr_topo_plotting_multitaper_complex_phase(output, evoked): + """Test plot_joint/topo/topomap() for data with a taper dimension.""" + # Compute TFR with taper dimension + tfr = evoked.compute_tfr( + method="multitaper", freqs=freqs_linspace, n_cycles=4, output=output + ) + # Check that plotting works + tfr.plot_joint(topomap_args=dict(res=8, contours=0, sensors=False)) # for speed + tfr.plot_topo() + tfr.plot_topomap() + + +def test_combine_tfr_error_catch(average_tfr): + """Test combine_tfr() catches errors.""" + # check unrecognised weights string caught + with pytest.raises(ValueError, match='Weights must be .* "nave" or "equal"'): + combine_tfr([average_tfr, average_tfr], weights="foo") + # check bad weights size caught + with pytest.raises(ValueError, match="Weights must be the same size as all_tfr"): + combine_tfr([average_tfr, average_tfr], weights=[1, 1, 1]) + # check different channel names caught + state = average_tfr.__getstate__() + new_info = average_tfr.info.copy() + average_tfr_bad = AverageTFR( + inst=state | dict(info=new_info.rename_channels({new_info.ch_names[0]: "foo"})) + ) + with pytest.raises(AssertionError, match=".* do not contain the same 
channels"): + combine_tfr([average_tfr, average_tfr_bad]) + # check different times caught + average_tfr_bad = AverageTFR(inst=state | dict(times=average_tfr.times + 1)) + with pytest.raises( + AssertionError, match=".* do not contain the same time instants" + ): + combine_tfr([average_tfr, average_tfr_bad]) + # check taper dim caught + n_tapers = 3 # anything >= 1 should do + weights = np.ones((n_tapers, average_tfr.shape[1])) # tapers x freqs + state["data"] = np.repeat(np.expand_dims(average_tfr.data, 1), n_tapers, axis=1) + state["weights"] = weights + state["dims"] = ("channel", "taper", "freq", "time") + average_tfr_taper = AverageTFR(inst=state) + with pytest.raises( + NotImplementedError, + match="Aggregating multitaper tapers across TFR datasets is not supported.", + ): + combine_tfr([average_tfr_taper, average_tfr_taper]) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 12d45d5d572..918fea1a33f 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -264,8 +264,11 @@ def _make_dpss( ------- Ws : list of array The wavelets time series. + Cs : list of array + The concentration weights. Only returned if return_weights=True. """ Ws = list() + Cs = list() freqs = np.array(freqs) if np.any(freqs <= 0): @@ -281,6 +284,7 @@ def _make_dpss( for m in range(n_taps): Wm = list() + Cm = list() for k, f in enumerate(freqs): if len(n_cycles) != 1: this_n_cycles = n_cycles[k] @@ -302,12 +306,15 @@ def _make_dpss( real_offset = Wk.mean() Wk -= real_offset Wk /= np.sqrt(0.5) * np.linalg.norm(Wk.ravel()) + Ck = np.sqrt(conc[m]) Wm.append(Wk) + Cm.append(Ck) Ws.append(Wm) + Cs.append(Cm) if return_weights: - return Ws, conc + return Ws, Cs return Ws @@ -428,6 +435,7 @@ def _compute_tfr( use_fft=True, decim=1, output="complex", + return_weights=False, n_jobs=None, *, verbose=None, @@ -478,7 +486,9 @@ def _compute_tfr( * 'itc' : inter-trial coherence. * 'avg_power_itc' : average of single trial power and inter-trial coherence across trials. - + return_weights : bool, default False + Whether to return the taper weights. Only applies if method='multitaper' and + output='complex' or 'phase'. %(n_jobs)s The number of epochs to process at the same time. The parallelization is implemented across channels. @@ -495,6 +505,9 @@ def _compute_tfr( n_tapers, n_freqs, n_times)``. If output is ``'avg_power_itc'``, the real values in the ``output`` contain average power' and the imaginary values contain the ITC: ``out = avg_power + i * itc``. + weights : array of shape (n_tapers, n_freqs) + The taper weights. Only returned if method='multitaper', output='complex' or + 'phase', and return_weights=True. 
""" # Check data epoch_data = np.asarray(epoch_data) @@ -516,6 +529,9 @@ def _compute_tfr( decim, output, ) + return_weights = ( + return_weights and method == "multitaper" and output in ["complex", "phase"] + ) decim = _ensure_slice(decim) if (freqs > sfreq / 2.0).any(): @@ -531,13 +547,18 @@ def _compute_tfr( Ws = [W] # to have same dimensionality as the 'multitaper' case elif method == "multitaper": - Ws = _make_dpss( + out = _make_dpss( sfreq, freqs, n_cycles=n_cycles, time_bandwidth=time_bandwidth, zero_mean=zero_mean, + return_weights=return_weights, ) + if return_weights: + Ws, weights = out + else: + Ws = out # Check wavelets if len(Ws[0][0]) > epoch_data.shape[2]: @@ -561,6 +582,8 @@ def _compute_tfr( out = np.empty((n_chans, n_freqs, n_times), dtype) elif output in ["complex", "phase"] and method == "multitaper": out = np.empty((n_chans, n_tapers, n_epochs, n_freqs, n_times), dtype) + if return_weights: + weights = np.array(weights) else: out = np.empty((n_chans, n_epochs, n_freqs, n_times), dtype) @@ -585,6 +608,9 @@ def _compute_tfr( out = out.transpose(2, 0, 1, 3, 4) else: out = out.transpose(1, 0, 2, 3) + + if return_weights: + return out, weights return out @@ -1200,6 +1226,9 @@ def __init__( method_kw.setdefault("output", "power") self._freqs = np.asarray(freqs, dtype=np.float64) del freqs + # always store weights for per-taper outputs + if method == "multitaper" and method_kw.get("output") in ["complex", "phase"]: + method_kw["return_weights"] = True # check validity of kwargs manually to save compute time if any are invalid tfr_funcs = dict( morlet=tfr_array_morlet, @@ -1221,6 +1250,7 @@ def __init__( self._method = method self._inst_type = type(inst) self._baseline = None + self._weights = None self.preload = True # needed for __getitem__, never False for TFRs # self._dims may also get updated by child classes self._dims = ["channel", "freq", "time"] @@ -1379,6 +1409,7 @@ def __getstate__(self): info=self.info, baseline=self._baseline, decim=self._decim, + weights=self._weights, ) def __setstate__(self, state): @@ -1389,7 +1420,6 @@ def __setstate__(self, state): defaults = dict( method="unknown", - dims=("epoch", "channel", "freq", "time")[-state["data"].ndim :], baseline=None, decim=1, data_type="TFR", @@ -1407,12 +1437,13 @@ def __setstate__(self, state): self._decim = defaults["decim"] self.preload = True self._set_times(self._raw_times) + self._weights = state.get("weights") # objs saved before #12910 won't have # Handle instance type. 
Prior to gh-11282, Raw was not a possibility so if # `inst_type_str` is missing it must be Epochs or Evoked unknown_class = Epochs if "epoch" in self._dims else Evoked inst_types = dict(Raw=Raw, Epochs=Epochs, Evoked=Evoked, Unknown=unknown_class) self._inst_type = inst_types[defaults["inst_type_str"]] - # sanity check data/freqs/times/info agreement + # sanity check data/freqs/times/info/weights agreement self._check_state() def __repr__(self): @@ -1465,18 +1496,29 @@ def _check_compatibility(self, other): raise RuntimeError(msg.format(problem, extra)) def _check_state(self): - """Check data/freqs/times/info agreement during __setstate__.""" + """Check data/freqs/times/info/weights agreement during __setstate__.""" msg = "{} axis of data ({}) doesn't match {} attribute ({})" n_chan_info = len(self.info["chs"]) n_chan = self._data.shape[self._dims.index("channel")] n_freq = self._data.shape[self._dims.index("freq")] n_time = self._data.shape[self._dims.index("time")] + n_taper = ( + self._data.shape[self._dims.index("taper")] + if "taper" in self._dims + else None + ) + if n_taper is not None and self._weights is None: + raise ValueError("Taper dimension in data, but no weights found.") if n_chan_info != n_chan: msg = msg.format("Channel", n_chan, "info", n_chan_info) elif n_freq != len(self.freqs): msg = msg.format("Frequency", n_freq, "freqs", self.freqs.size) elif n_time != len(self.times): msg = msg.format("Time", n_time, "times", self.times.size) + elif n_taper is not None and n_taper != self._weights.shape[0]: + msg = msg.format("Taper", n_taper, "weights", self._weights.shape[0]) + elif n_taper is not None and n_freq != self._weights.shape[1]: + msg = msg.format("Frequency", n_freq, "weights", self._weights.shape[1]) else: return raise ValueError(msg) @@ -1513,6 +1555,10 @@ def _compute_tfr(self, data, n_jobs, verbose): if self.method == "stockwell": self._data, self._itc, freqs = result assert np.array_equal(self._freqs, freqs) + elif self.method == "multitaper" and self._tfr_func.keywords.get( + "output", "" + ) in ["complex", "phase"]: + self._data, self._weights = result elif self._tfr_func.keywords.get("output", "").endswith("_itc"): self._data, self._itc = result.real, result.imag else: @@ -1613,6 +1659,7 @@ def _onselect( fmax=fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) # average over times and freqs @@ -1691,6 +1738,11 @@ def times(self): """The time points present in the data (in seconds).""" return self._times_readonly + @property + def weights(self): + """The weights used for each taper in the time-frequency estimates.""" + return self._weights + @fill_doc def crop(self, tmin=None, tmax=None, fmin=None, fmax=None, include_tmax=True): """Crop data to a given time interval in place. @@ -1785,6 +1837,7 @@ def get_data( tmax=None, return_times=False, return_freqs=False, + return_tapers=False, ): """Get time-frequency data in NumPy array format. @@ -1800,6 +1853,10 @@ def get_data( return_freqs : bool Whether to return the frequency bin values for the requested frequency range. Default is ``False``. + return_tapers : bool + Whether to return the taper numbers. Default is ``False``. + + .. versionadded:: 1.10.0 Returns ------- @@ -1811,6 +1868,9 @@ def get_data( freqs : array The frequency values for the requested data range. Only returned if ``return_freqs`` is ``True``. + tapers : array | None + The taper numbers. Only returned if ``return_tapers`` is ``True``. Will be + ``None`` if a taper dimension is not present in the data. 
Notes ----- @@ -1848,7 +1908,13 @@ def get_data( if return_freqs: freqs = self._freqs[fmin_idx:fmax_idx] out.append(freqs) - if not return_times and not return_freqs: + if return_tapers: + if "taper" in self._dims: + tapers = np.arange(self.shape[self._dims.index("taper")]) + else: + tapers = None + out.append(tapers) + if not return_times and not return_freqs and not return_tapers: return out[0] return tuple(out) @@ -1960,6 +2026,7 @@ def plot( baseline=baseline, mode=mode, dB=dB, + taper_weights=self.weights, verbose=verbose, ) # shape @@ -1970,6 +2037,9 @@ def plot( want_shape[ch_axis] = len(idx_picks) if combine is None else 1 want_shape[freq_axis] = len(freqs) # in case there was fmin/fmax cropping want_shape[time_axis] = len(times) # in case there was tmin/tmax cropping + want_shape = [ + n for dim, n in zip(self._dims, want_shape) if dim != "taper" + ] # tapers must be aggregated over by now want_shape = tuple(want_shape) # combine combine_was_none = combine is None @@ -2313,6 +2383,7 @@ def plot_joint( fmax=_fmax, baseline=baseline, mode=mode, + taper_weights=self.weights, verbose=verbose, ) _data = _data.mean(axis=(-1, -2)) # avg over times and freqs @@ -2461,23 +2532,23 @@ def plot_topo( info, data = _prepare_picks(info, data, picks, axis=0) del picks - # TODO this is the only remaining call to _preproc_tfr; should be refactored - # (to use _prep_data_for_plot?) - data, times, freqs, vmin, vmax = _preproc_tfr( + # baseline, crop, convert complex to power, aggregate tapers, and dB scaling + data, times, freqs = _prep_data_for_plot( data, times, freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - info["sfreq"], + tmin=tmin, + tmax=tmax, + fmin=fmin, + fmax=fmax, + baseline=baseline, + mode=mode, + dB=dB, + taper_weights=self.weights, + verbose=verbose, ) + # get vlims + vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) if layout is None: from mne import find_layout @@ -2624,21 +2695,21 @@ def to_data_frame( ): """Export data in tabular structure as a pandas DataFrame. - Channels are converted to columns in the DataFrame. By default, - additional columns ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` (epoch event description) are added, unless ``index`` - is not ``None`` (in which case the columns specified in ``index`` will - be used to form the DataFrame's index instead). ``'epoch'``, and - ``'condition'`` are not supported for ``AverageTFR``. + Channels are converted to columns in the DataFrame. By default, additional + columns ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, and ``'condition'`` + (epoch event description) are added, unless ``index`` is not ``None`` (in which + case the columns specified in ``index`` will be used to form the DataFrame's + index instead). ``'epoch'``, and ``'condition'`` are not supported for + ``AverageTFR``. ``'taper'`` is only supported when a taper dimensions is + present, such as for complex or phase multitaper data. Parameters ---------- %(picks_all)s %(index_df_epo)s - Valid string values are ``'time'``, ``'freq'``, ``'epoch'``, and - ``'condition'`` for ``EpochsTFR`` and ``'time'`` and ``'freq'`` - for ``AverageTFR``. - Defaults to ``None``. + Valid string values are ``'time'``, ``'freq'``, ``'taper'``, ``'epoch'``, + and ``'condition'`` for ``EpochsTFR`` and ``'time'``, ``'freq'``, and + ``'taper'`` for ``AverageTFR``. Defaults to ``None``. 
%(long_format_df_epo)s %(time_format_df)s @@ -2651,42 +2722,58 @@ def to_data_frame( """ # check pandas once here, instead of in each private utils function pd = _check_pandas_installed() # noqa + # triage for Epoch-derived or unaggregated spectra + from_epo = isinstance(self, EpochsTFR) + unagg_mt = "taper" in self._dims # arg checking valid_index_args = ["time", "freq"] - if isinstance(self, EpochsTFR): + if from_epo: valid_index_args.extend(["epoch", "condition"]) + if unagg_mt: + valid_index_args.append("taper") valid_time_formats = ["ms", "timedelta"] index = _check_pandas_index_arguments(index, valid_index_args) time_format = _check_time_format(time_format, valid_time_formats) # get data picks = _picks_to_idx(self.info, picks, "all", exclude=()) - data, times, freqs = self.get_data(picks, return_times=True, return_freqs=True) - axis = self._dims.index("channel") - if not isinstance(self, EpochsTFR): + data, times, freqs, tapers = self.get_data( + picks, return_times=True, return_freqs=True, return_tapers=True + ) + ch_axis = self._dims.index("channel") + if not from_epo: data = data[np.newaxis] # add singleton "epochs" axis - axis += 1 - n_epochs, n_picks, n_freqs, n_times = data.shape - # reshape to (epochs*freqs*times) x signals - data = np.moveaxis(data, axis, -1) - data = data.reshape(n_epochs * n_freqs * n_times, n_picks) + ch_axis += 1 + if not unagg_mt: + data = np.expand_dims(data, -3) # add singleton "tapers" axis + n_epochs, n_picks, n_tapers, n_freqs, n_times = data.shape + # reshape to (epochs*tapers*freqs*times) x signals + data = np.moveaxis(data, ch_axis, -1) + data = data.reshape(n_epochs * n_tapers * n_freqs * n_times, n_picks) # prepare extra columns / multiindex mindex = list() + default_index = list() times = _convert_times(times, time_format, self.info["meas_date"]) - times = np.tile(times, n_epochs * n_freqs) - freqs = np.tile(np.repeat(freqs, n_times), n_epochs) + times = np.tile(times, n_epochs * n_freqs * n_tapers) + freqs = np.tile(np.repeat(freqs, n_times), n_epochs * n_tapers) mindex.append(("time", times)) mindex.append(("freq", freqs)) - if isinstance(self, EpochsTFR): - mindex.append(("epoch", np.repeat(self.selection, n_times * n_freqs))) + if from_epo: + mindex.append( + ("epoch", np.repeat(self.selection, n_times * n_freqs * n_tapers)) + ) rev_event_id = {v: k for k, v in self.event_id.items()} conditions = [rev_event_id[k] for k in self.events[:, 2]] - mindex.append(("condition", np.repeat(conditions, n_times * n_freqs))) + mindex.append( + ("condition", np.repeat(conditions, n_times * n_freqs * n_tapers)) + ) + default_index.extend(["condition", "epoch"]) + if unagg_mt: + tapers = np.repeat(np.tile(tapers, n_epochs), n_freqs * n_times) + mindex.append(("taper", tapers)) + default_index.append("taper") + default_index.extend(["freq", "time"]) assert all(len(mdx) == len(mindex[0]) for mdx in mindex[1:]) # build DataFrame - if isinstance(self, EpochsTFR): - default_index = ["condition", "epoch", "freq", "time"] - else: - default_index = ["freq", "time"] df = _build_data_frame( self, data, picks, long_format, mindex, index, default_index=default_index ) @@ -2733,6 +2820,7 @@ class AverageTFR(BaseTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2849,6 +2937,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack AverageTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." 
+ ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._comment = state.get("comment", "") self._nave = state.get("nave", 1) @@ -2892,6 +2989,7 @@ class AverageTFRArray(AverageTFR): The number of averaged TFRs. %(comment_averagetfr_attr)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -2904,6 +3002,7 @@ class AverageTFRArray(AverageTFR): %(nave_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -2914,12 +3013,22 @@ class AverageTFRArray(AverageTFR): """ def __init__( - self, info, data, times, freqs, *, nave=None, comment=None, method=None + self, + info, + data, + times, + freqs, + *, + nave=None, + comment=None, + method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - for name, optional in dict(nave=nave, comment=comment, method=method).items(): - if optional is not None: - state[name] = optional + optional = dict(nave=nave, comment=comment, method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -2962,6 +3071,7 @@ class EpochsTFR(BaseTFR, GetEpochsMixin): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3041,8 +3151,15 @@ def __getstate__(self): def __setstate__(self, state): """Unpack EpochsTFR from serialized format.""" - if state["data"].ndim != 4: - raise ValueError(f"EpochsTFR data should be 4D, got {state['data'].ndim}.") + if state["data"].ndim not in [4, 5]: + raise ValueError( + f"EpochsTFR data should be 4D or 5D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("epoch", "channel") + if state["data"].ndim == 5: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") super().__setstate__(state) self._metadata = state.get("metadata", None) n_epochs = self.shape[0] @@ -3152,7 +3269,16 @@ def average(self, method="mean", *, dim="epochs", copy=False): See discussion here: https://github.com/scipy/scipy/pull/12676#issuecomment-783370228 + + Averaging is not supported for data containing a taper dimension. """ + if "taper" in self._dims: + raise NotImplementedError( + "Averaging multitaper tapers across epochs, frequencies, or times is " + "not supported. If averaging across epochs, consider averaging the " + "epochs before computing the complex/phase spectrum." 
+ ) + _check_option("dim", dim, ("epochs", "freqs", "times")) axis = self._dims.index(dim[:-1]) # self._dims entries aren't plural @@ -3524,6 +3650,7 @@ class EpochsTFRArray(EpochsTFR): %(selection)s %(drop_log)s %(metadata_epochstfr)s + %(weights_tfr_array)s Attributes ---------- @@ -3540,6 +3667,7 @@ class EpochsTFRArray(EpochsTFR): %(selection_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3562,6 +3690,7 @@ def __init__( selection=None, drop_log=None, metadata=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) optional = dict( @@ -3572,6 +3701,7 @@ def __init__( selection=selection, drop_log=drop_log, metadata=metadata, + weights=weights, ) for name, value in optional.items(): if value is not None: @@ -3614,6 +3744,7 @@ class RawTFR(BaseTFR): method : str The method used to compute the spectra (``'morlet'``, ``'multitaper'`` or ``'stockwell'``). + %(weights_tfr_attr)s See Also -------- @@ -3663,6 +3794,19 @@ def __init__( **method_kw, ) + def __setstate__(self, state): + """Unpack RawTFR from serialized format.""" + if state["data"].ndim not in [3, 4]: + raise ValueError( + f"RawTFR data should be 3D or 4D, got {state['data'].ndim}." + ) + # Set dims now since optional tapers makes it difficult to disentangle later + state["dims"] = ("channel",) + if state["data"].ndim == 4: + state["dims"] += ("taper",) + state["dims"] += ("freq", "time") + super().__setstate__(state) + def __getitem__(self, item): """Get RawTFR data. @@ -3728,6 +3872,7 @@ class RawTFRArray(RawTFR): %(times)s %(freqs_tfr_array)s %(method_tfr_array)s + %(weights_tfr_array)s Attributes ---------- @@ -3738,6 +3883,7 @@ class RawTFRArray(RawTFR): %(method_tfr_attr)s %(sfreq_tfr_attr)s %(shape_tfr_attr)s + %(weights_tfr_attr)s See Also -------- @@ -3755,10 +3901,13 @@ def __init__( freqs, *, method=None, + weights=None, ): state = dict(info=info, data=data, times=times, freqs=freqs) - if method is not None: - state["method"] = method + optional = dict(method=method, weights=weights) + for name, value in optional.items(): + if value is not None: + state[name] = value self.__setstate__(state) @@ -3786,8 +3935,16 @@ def combine_tfr(all_tfr, weights="nave"): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ + if any("taper" in tfr._dims for tfr in all_tfr): + raise NotImplementedError( + "Aggregating multitaper tapers across TFR datasets is not supported." 
+ ) + tfr = all_tfr[0].copy() if isinstance(weights, str): if weights not in ("nave", "equal"): @@ -3861,62 +4018,6 @@ def _centered(arr, newsize): return arr[tuple(myslice)] -def _preproc_tfr( - data, - times, - freqs, - tmin, - tmax, - fmin, - fmax, - mode, - baseline, - vmin, - vmax, - dB, - sfreq, - copy=None, -): - """Aux Function to prepare tfr computation.""" - if copy is None: - copy = baseline is not None - data = rescale(data, times, baseline, mode, copy=copy) - - if np.iscomplexobj(data): - # complex amplitude → real power (for plotting); if data are - # real-valued they should already be power - data = (data * data.conj()).real - - # crop time - itmin, itmax = None, None - idx = np.where(_time_mask(times, tmin, tmax, sfreq=sfreq))[0] - if tmin is not None: - itmin = idx[0] - if tmax is not None: - itmax = idx[-1] + 1 - - times = times[itmin:itmax] - - # crop freqs - ifmin, ifmax = None, None - idx = np.where(_time_mask(freqs, fmin, fmax, sfreq=sfreq))[0] - if fmin is not None: - ifmin = idx[0] - if fmax is not None: - ifmax = idx[-1] + 1 - - freqs = freqs[ifmin:ifmax] - - # crop data - data = data[:, ifmin:ifmax, itmin:itmax] - - if dB: - data = 10 * np.log10(data) - - vmin, vmax = _setup_vmin_vmax(data, vmin, vmax) - return data, times, freqs, vmin, vmax - - def _ensure_slice(decim): """Aux function checking the decim parameter.""" _validate_type(decim, ("int-like", slice), "decim") @@ -4151,6 +4252,7 @@ def _prep_data_for_plot( baseline=None, mode=None, dB=False, + taper_weights=None, verbose=None, ): # baseline @@ -4164,9 +4266,43 @@ def _prep_data_for_plot( freqs = freqs[freq_mask] # crop data data = data[..., freq_mask, :][..., time_mask] - # complex amplitude → real power; real-valued data is already power (or ITC) + # handle unaggregated multitaper (complex or phase multitaper data) + if taper_weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + if np.iscomplexobj(data): # complex coefficients → power + data = _tfr_from_mt(data, taper_weights) + else: # tapered phase data → weighted phase data + # channels, tapers, freqs, time + assert data.ndim == 4 + # weights as a function of (tapers, freqs) + assert taper_weights.ndim == 2 + data = (data * taper_weights[np.newaxis, :, :, np.newaxis]).mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data = (data * data.conj()).real if dB: data = 10 * np.log10(data) return data, times, freqs + + +def _tfr_from_mt(x_mt, weights): + """Aggregate complex multitaper coefficients over tapers and convert to power. + + Parameters + ---------- + x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times) + The complex-valued multitaper coefficients. + weights : array, shape (n_tapers, n_freqs) + The weights to use to combine the tapered estimates. + + Returns + ------- + tfr : array, shape (n_channels, n_freqs, n_times) + The time-frequency power estimates. + """ + weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims + tfr = weights * x_mt + tfr *= tfr.conj() + tfr = tfr.real.sum(axis=1) + tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) + return tfr diff --git a/mne/utils/docs.py b/mne/utils/docs.py index aea0a17fd32..683704c4bc6 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -5014,6 +5014,18 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): solution. 
""" +docdict["weights_tfr_array"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights for each taper. Must be provided if ``data`` has a taper dimension, such + as for complex or phase multitaper data. + + .. versionadded:: 1.10.0 +""" +docdict["weights_tfr_attr"] = """ +weights : array, shape (n_tapers, n_freqs) | None + The weights used for each taper in the time-frequency estimates. +""" + docdict["window_psd"] = """\ window : str | float | tuple Windowing function to use. See :func:`scipy.signal.get_window`. diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index c287fb42305..4bf8d094f81 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -550,6 +550,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): Notes ----- + Aggregating multitaper TFR datasets with a taper dimension such as for complex or + phase data is not supported. + .. versionadded:: 0.11.0 """ # check if all elements in the given list are evoked data diff --git a/mne/viz/tests/test_topomap.py b/mne/viz/tests/test_topomap.py index afa9341c00e..b87d0d39f89 100644 --- a/mne/viz/tests/test_topomap.py +++ b/mne/viz/tests/test_topomap.py @@ -44,7 +44,7 @@ compute_bridged_electrodes, compute_current_source_density, ) -from mne.time_frequency.tfr import AverageTFRArray +from mne.time_frequency.tfr import AverageTFR, AverageTFRArray from mne.viz import plot_evoked_topomap, plot_projs_topomap, topomap from mne.viz.tests.test_raw import _proj_status from mne.viz.topomap import ( @@ -610,6 +610,29 @@ def test_plot_tfr_topomap(): ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 ) + # test data with taper dimension (real) + data = np.expand_dims(data, axis=1) + weights = np.random.rand(1, n_freqs) + tfr = AverageTFRArray( + info=info, + data=data, + times=times, + freqs=np.arange(n_freqs), + nave=nave, + weights=weights, + ) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # test data with taper dimension (complex) + state = tfr.__getstate__() + tfr = AverageTFR(inst=state | dict(data=data * (1 + 1j))) + tfr.plot_topomap( + ch_type="mag", tmin=0.05, tmax=0.150, fmin=0, fmax=10, res=res, contours=0 + ) + # remove taper dim before proceeding + data = data[:, 0] + # test real numbers tfr = AverageTFRArray( info=info, data=data, times=times, freqs=np.arange(n_freqs), nave=nave diff --git a/mne/viz/topomap.py b/mne/viz/topomap.py index dd63a626683..d83698acbb1 100644 --- a/mne/viz/topomap.py +++ b/mne/viz/topomap.py @@ -1882,7 +1882,7 @@ def plot_tfr_topomap( tfr, ch_type, sphere=sphere ) outlines = _make_head_outlines(sphere, pos, outlines, clip_origin) - data = tfr.data[picks, :, :] + data = tfr.data[picks] # merging grads before rescaling makes ERDs visible if merge_channels: @@ -1890,6 +1890,18 @@ def plot_tfr_topomap( data = rescale(data, tfr.times, baseline, mode, copy=True) + # handle unaggregated multitaper (complex or phase multitaper data) + if tfr.weights is not None: # assumes a taper dimension + logger.info("Aggregating multitaper estimates before plotting...") + weights = tfr.weights[np.newaxis, :, :, np.newaxis] # add channel & time dims + data = weights * data + if np.iscomplexobj(data): # complex coefficients → power + data *= data.conj() + data = data.real.sum(axis=1) + data *= 2 / (weights * weights.conj()).real.sum(axis=1) + else: # tapered phase data → weighted phase data + data = data.mean(axis=1) + # handle remaining complex amplitude → real power if np.iscomplexobj(data): data 
= np.sqrt((data * data.conj()).real) From 2abb7b220ed2580e141158499919300cfa1f6a3b Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Mon, 13 Jan 2025 17:37:42 -0500 Subject: [PATCH 2/5] BUG: Fix bug with helium anon (#13056) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- doc/changes/devel/13056.bugfix.rst | 1 + mne/_fiff/meas_info.py | 15 +++++++--- mne/_fiff/tests/test_meas_info.py | 48 +++++++++++++++++++----------- mne/_fiff/write.py | 9 +++--- mne/utils/_testing.py | 4 +-- 5 files changed, 49 insertions(+), 28 deletions(-) create mode 100644 doc/changes/devel/13056.bugfix.rst diff --git a/doc/changes/devel/13056.bugfix.rst b/doc/changes/devel/13056.bugfix.rst new file mode 100644 index 00000000000..2a7919de289 --- /dev/null +++ b/doc/changes/devel/13056.bugfix.rst @@ -0,0 +1 @@ +Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_. diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 629d9a4b0ce..ecc93591a05 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): hi["meas_date"] = _ensure_meas_date_none_or_dt( tuple(int(t) for t in tag.data), ) + if "meas_date" not in hi: + hi["meas_date"] = None info["helium_info"] = hi del hi @@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"]) if hi.get("orig_file_guid") is not None: write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"]) - write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) + if hi["meas_date"] is not None: + write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) end_block(fid, FIFF.FIFFB_HELIUM) del hi @@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): _write_proc_history(fid, info) -@fill_doc -def write_info(fname, info, data_type=None, reset_range=True): +@verbose +def write_info( + fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None +): """Write measurement info in fif file. Parameters @@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True): raw data. reset_range : bool If True, info['chs'][k]['range'] will be set to unity. 
+ %(overwrite)s + %(verbose)s """ - with start_and_end_file(fname) as fid: + with start_and_end_file(fname, overwrite=overwrite) as fid: start_block(fid, FIFF.FIFFB_MEAS) write_meas_info(fid, info, data_type, reset_range) end_block(fid, FIFF.FIFFB_MEAS) diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 3e3c150573f..d088da2a4a2 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -306,7 +306,9 @@ def test_read_write_info(tmp_path): gantry_angle = info["gantry_angle"] meas_id = info["meas_id"] - write_info(temp_file, info) + with pytest.raises(FileExistsError, match="Destination file exists"): + write_info(temp_file, info) + write_info(temp_file, info, overwrite=True) info = read_info(temp_file) assert info["proc_history"][0]["creator"] == creator assert info["hpi_meas"][0]["creator"] == creator @@ -348,7 +350,7 @@ def test_read_write_info(tmp_path): info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc) fname = tmp_path / "test.fif" with pytest.raises(RuntimeError, match="must be between "): - write_info(fname, info) + write_info(fname, info, overwrite=True) @testing.requires_testing_data @@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path): for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"): info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type]) info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03] - write_info(fname, info) + write_info(fname, info, overwrite=True) info2 = read_info(fname) assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD @@ -585,7 +587,7 @@ def test_check_consistency(): info2["subject_info"] = {"height": "bad"} -def _test_anonymize_info(base_info): +def _test_anonymize_info(base_info, tmp_path): """Test that sensitive information can be anonymized.""" pytest.raises(TypeError, anonymize_info, "foo") assert isinstance(base_info, Info) @@ -692,14 +694,25 @@ def _adjust_back(e_i, dt): # exp 4 tests is a supplied daysback delta_t_3 = timedelta(days=223 + 364 * 500) + def _check_equiv(got, want, err_msg): + __tracebackhide__ = True + fname_temp = tmp_path / "test.fif" + assert_object_equal(got, want, err_msg=err_msg) + write_info(fname_temp, got, reset_range=False, overwrite=True) + got = read_info(fname_temp) + # this gets changed on write but that's expected + with got._unlock(): + got["file_id"] = want["file_id"] + assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)") + new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info, err_msg="anon mismatch") + _check_equiv(new_info, exp_info, err_msg="anon mismatch") new_info = anonymize_info(base_info.copy(), keep_his=True) - assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch") + _check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch") new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch") + _check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch") with pytest.raises(RuntimeError, match="anonymize_info generated"): anonymize_info(base_info.copy(), daysback=delta_t_3.days) @@ -726,7 +739,7 @@ def _adjust_back(e_i, dt): new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) else: new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal( + _check_equiv( new_info, exp_info_3, err_msg="meas_date=None daysback mismatch", @@ -734,7 +747,7 @@ def _adjust_back(e_i, dt): with 
_record_warnings(): # meas_date is None new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch") + _check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch") @pytest.mark.parametrize( @@ -777,8 +790,8 @@ def _complete_info(info): height=2.0, ) info["helium_info"] = dict( - he_level_raw=12.34, - helium_level=45.67, + he_level_raw=np.float32(12.34), + helium_level=np.float32(45.67), meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc), orig_file_guid="e", ) @@ -796,14 +809,13 @@ def _complete_info(info): machid=np.ones(2, int), secs=d[0], usecs=d[1], - date=d, ), experimenter="j", max_info=dict( - max_st=[], - sss_ctc=[], - sss_cal=[], - sss_info=dict(head_pos=None, in_order=8), + max_st=dict(), + sss_ctc=dict(), + sss_cal=dict(), + sss_info=dict(in_order=8), ), date=d, ), @@ -830,8 +842,8 @@ def test_anonymize(tmp_path): # test mne.anonymize_info() events = read_events(event_name) epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None) - _test_anonymize_info(raw.info) - _test_anonymize_info(epochs.info) + _test_anonymize_info(raw.info, tmp_path) + _test_anonymize_info(epochs.info, tmp_path) # test instance methods & I/O roundtrip for inst, keep_his in zip((raw, epochs), (True, False)): diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index 1fc32f0163e..8486ca13121 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -13,7 +13,7 @@ import numpy as np from scipy.sparse import csc_array, csr_array -from ..utils import _file_like, _validate_type, logger +from ..utils import _check_fname, _file_like, _validate_type, logger from ..utils.numerics import _date_to_julian from .constants import FIFF @@ -277,7 +277,7 @@ def end_block(fid, kind): write_int(fid, FIFF.FIFF_BLOCK_END, kind) -def start_file(fname, id_=None): +def start_file(fname, id_=None, *, overwrite=True): """Open a fif file for writing and writes the compulsory header tags. Parameters @@ -294,6 +294,7 @@ def start_file(fname, id_=None): fid = fname fid.seek(0) else: + fname = _check_fname(fname, overwrite=overwrite) fname = str(fname) if op.splitext(fname)[1].lower() == ".gz": logger.debug("Writing using gzip") @@ -311,9 +312,9 @@ def start_file(fname, id_=None): @contextmanager -def start_and_end_file(fname, id_=None): +def start_and_end_file(fname, id_=None, *, overwrite=True): """Start and (if successfully written) close the file.""" - with start_file(fname, id_=id_) as fid: + with start_file(fname, id_=id_, overwrite=overwrite) as fid: yield fid end_file(fid) # we only hit this line if the yield does not err diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 323b530a641..63e0d1036b9 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -179,9 +179,9 @@ def assert_and_remove_boundary_annot(annotations, n=1): annotations.delete(idx) -def assert_object_equal(a, b, *, err_msg="Object mismatch"): +def assert_object_equal(a, b, *, err_msg="Object mismatch", allclose=False): """Assert two objects are equal.""" - d = object_diff(a, b) + d = object_diff(a, b, allclose=allclose) assert d == "", f"{err_msg}\n{d}" From f82d3993617d2a34744eb955385448c67672d6ec Mon Sep 17 00:00:00 2001 From: "Thomas S. 
Binns" Date: Mon, 13 Jan 2025 22:43:44 +0000 Subject: [PATCH 3/5] Add `combine_spectrum()` function and allow `grand_average()` to support `Spectrum` data (#13058) Co-authored-by: Daniel McCloy --- doc/api/time_frequency.rst | 1 + doc/changes/devel/13058.newfeature.rst | 1 + mne/time_frequency/__init__.pyi | 2 + mne/time_frequency/spectrum.py | 68 +++++++++++++++++++++++ mne/time_frequency/tests/test_spectrum.py | 55 +++++++++++++++++- mne/time_frequency/tfr.py | 12 ++-- mne/utils/numerics.py | 57 +++++++++++-------- 7 files changed, 165 insertions(+), 31 deletions(-) create mode 100644 doc/changes/devel/13058.newfeature.rst diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index 8923920bdba..b66b1b6ca64 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -31,6 +31,7 @@ Functions that operate on mne-python objects: .. autosummary:: :toctree: ../generated/ + combine_spectrum csd_tfr csd_fourier csd_multitaper diff --git a/doc/changes/devel/13058.newfeature.rst b/doc/changes/devel/13058.newfeature.rst new file mode 100644 index 00000000000..bbd01fa4552 --- /dev/null +++ b/doc/changes/devel/13058.newfeature.rst @@ -0,0 +1 @@ +Add the function :func:`mne.time_frequency.combine_spectrum` for combining data across :class:`mne.time_frequency.Spectrum` objects, and allow :func:`mne.grand_average` to operate on :class:`mne.time_frequency.Spectrum` objects, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index 0faeb7263d8..a612c2a850a 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -11,6 +11,7 @@ __all__ = [ "RawTFRArray", "Spectrum", "SpectrumArray", + "combine_spectrum", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -61,6 +62,7 @@ from .spectrum import ( EpochsSpectrumArray, Spectrum, SpectrumArray, + combine_spectrum, read_spectrum, ) from .tfr import ( diff --git a/mne/time_frequency/spectrum.py b/mne/time_frequency/spectrum.py index a70697fd57c..b1de7f11c0f 100644 --- a/mne/time_frequency/spectrum.py +++ b/mne/time_frequency/spectrum.py @@ -1643,6 +1643,74 @@ def __init__( ) +def combine_spectrum(all_spectrum, weights="nave"): + """Merge spectral data by weighted addition. + + Create a new :class:`mne.time_frequency.Spectrum` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., ``[1, -1]``). + Data must have the same channels and the same frequencies. + + Parameters + ---------- + all_spectrum : list of Spectrum + The Spectrum objects. + weights : list of float | str + The weights to apply to the data of each :class:`~mne.time_frequency.Spectrum` + instance, or a string describing the weighting strategy to apply: 'nave' + computes sum-to-one weights proportional to each object’s nave attribute; + 'equal' weights each :class:`~mne.time_frequency.Spectrum` by + ``1 / len(all_spectrum)``. + + Returns + ------- + spectrum : Spectrum + The new spectral data. + + Notes + ----- + .. 
versionadded:: 1.10.0 + """ + spectrum = all_spectrum[0].copy() + if isinstance(weights, str): + if weights not in ("nave", "equal"): + raise ValueError('Weights must be a list of float, or "nave" or "equal"') + if weights == "nave": + for s_ in all_spectrum: + if s_.nave is None: + raise ValueError(f"The 'nave' attribute is not specified for {s_}") + weights = np.array([e.nave for e in all_spectrum], float) + weights /= weights.sum() + else: # == 'equal' + weights = [1.0 / len(all_spectrum)] * len(all_spectrum) + weights = np.array(weights, float) + if weights.ndim != 1 or weights.size != len(all_spectrum): + raise ValueError("Weights must be the same size as all_spectrum") + + ch_names = spectrum.ch_names + for s_ in all_spectrum[1:]: + assert ( + s_.ch_names == ch_names + ), f"{spectrum} and {s_} do not contain the same channels" + assert ( + np.max(np.abs(s_.freqs - spectrum.freqs)) < 1e-7 + ), f"{spectrum} and {s_} do not contain the same frequencies" + + # use union of bad channels + bads = list( + set(spectrum.info["bads"]).union(*(s_.info["bads"] for s_ in all_spectrum[1:])) + ) + spectrum.info["bads"] = bads + + # combine spectral data + spectrum._data = sum(w * s_.data for w, s_ in zip(weights, all_spectrum)) + if spectrum.nave is not None: + spectrum._nave = max( + int(1.0 / sum(w**2 / s_.nave for w, s_ in zip(weights, all_spectrum))), 1 + ) + return spectrum + + def read_spectrum(fname): """Load a :class:`mne.time_frequency.Spectrum` object from disk. diff --git a/mne/time_frequency/tests/test_spectrum.py b/mne/time_frequency/tests/test_spectrum.py index 162d89b1c25..927c22360c5 100644 --- a/mne/time_frequency/tests/test_spectrum.py +++ b/mne/time_frequency/tests/test_spectrum.py @@ -14,7 +14,11 @@ from mne.io import RawArray from mne.time_frequency import read_spectrum from mne.time_frequency.multitaper import _psd_from_mt -from mne.time_frequency.spectrum import EpochsSpectrumArray, SpectrumArray +from mne.time_frequency.spectrum import ( + EpochsSpectrumArray, + SpectrumArray, + combine_spectrum, +) from mne.utils import _record_warnings @@ -190,6 +194,55 @@ def test_spectrum_copy(raw_spectrum): assert raw_spectrum.freqs is not None +@pytest.mark.parametrize("weights", ["nave", "equal", [1, -1]]) +def test_combine_spectrum(raw_spectrum, weights): + """Test `combine_spectrum()` works.""" + spectrum1 = raw_spectrum.copy() + spectrum2 = raw_spectrum.copy() + if weights == "nave": + spectrum1._nave = 1 + spectrum2._nave = 2 + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * (5 / 3)) + elif weights == "equal": + spectrum2._data *= 2 + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, spectrum1.data * 1.5) + else: + new_spectrum = combine_spectrum([spectrum1, spectrum2], weights=weights) + assert_allclose(new_spectrum.data, 0) + + +def test_combine_spectrum_error_catch(raw_spectrum): + """Test `combine_spectrum()` catches errors.""" + # Test bad weights + with pytest.raises( + ValueError, match='Weights must be a list of float, or "nave" or "equal"' + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights="foo") + with pytest.raises( + ValueError, match="Weights must be the same size as all_spectrum" + ): + combine_spectrum([raw_spectrum, raw_spectrum], weights=[1, 1, 1]) + + # Test bad nave + with pytest.raises(ValueError, match="The 'nave' attribute is not specified"): + combine_spectrum([raw_spectrum, 
raw_spectrum], weights="nave") + + # Test inconsistent channels + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2.drop_channels(raw_spectrum2.ch_names[0]) + with pytest.raises(AssertionError, match=".* do not contain the same channels"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + # Test inconsistent frequencies + raw_spectrum2 = raw_spectrum.copy() + raw_spectrum2._freqs = raw_spectrum2._freqs + 1 + with pytest.raises(AssertionError, match=".* do not contain the same frequencies"): + combine_spectrum([raw_spectrum, raw_spectrum2], weights="equal") + + def test_spectrum_reject_by_annot(raw): """Test rejecting by annotation. diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index 918fea1a33f..b1736f151d2 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -3960,12 +3960,12 @@ def combine_tfr(all_tfr, weights="nave"): ch_names = tfr.ch_names for t_ in all_tfr[1:]: - assert t_.ch_names == ch_names, ValueError( - f"{tfr} and {t_} do not contain the same channels" - ) - assert np.max(np.abs(t_.times - tfr.times)) < 1e-7, ValueError( - f"{tfr} and {t_} do not contain the same time instants" - ) + assert ( + t_.ch_names == ch_names + ), f"{tfr} and {t_} do not contain the same channels" + assert ( + np.max(np.abs(t_.times - tfr.times)) < 1e-7 + ), f"{tfr} and {t_} do not contain the same time instants" # use union of bad channels bads = list(set(tfr.info["bads"]).union(*(t_.info["bads"] for t_ in all_tfr[1:]))) diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 4bf8d094f81..eed23998774 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -515,37 +515,42 @@ def _freq_mask(freqs, sfreq, fmin=None, fmax=None, raise_error=True): def grand_average(all_inst, interpolate_bads=True, drop_bads=True): - """Make grand average of a list of Evoked or AverageTFR data. + """Make grand average of a list of Evoked, AverageTFR, or Spectrum data. - For :class:`mne.Evoked` data, the function interpolates bad channels based - on the ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, - the grand average file will contain good channels and the bad channels - interpolated from the good MEG/EEG channels. - For :class:`mne.time_frequency.AverageTFR` data, the function takes the - subset of channels not marked as bad in any of the instances. + For :class:`mne.Evoked` data, the function interpolates bad channels based on the + ``interpolate_bads`` parameter. If ``interpolate_bads`` is True, the grand average + file will contain good channels and the bad channels interpolated from the good + MEG/EEG channels. + For :class:`mne.time_frequency.AverageTFR` and :class:`mne.time_frequency.Spectrum` + data, the function takes the subset of channels not marked as bad in any of the + instances. - The ``grand_average.nave`` attribute will be equal to the number - of evoked datasets used to calculate the grand average. + The ``grand_average.nave`` attribute will be equal to the number of datasets used to + calculate the grand average. - .. note:: A grand average evoked should not be used for source - localization. + .. note:: A grand average evoked should not be used for source localization. Parameters ---------- - all_inst : list of Evoked or AverageTFR - The evoked datasets. + all_inst : list of Evoked, AverageTFR or Spectrum + The datasets. + + .. versionchanged:: 1.10.0 + Added support for :class:`~mne.time_frequency.Spectrum` objects. + interpolate_bads : bool If True, bad MEG and EEG channels are interpolated. 
Ignored for - AverageTFR. + :class:`~mne.time_frequency.AverageTFR` and + :class:`~mne.time_frequency.Spectrum` data. drop_bads : bool - If True, drop all bad channels marked as bad in any data set. - If neither interpolate_bads nor drop_bads is True, in the output file, - every channel marked as bad in at least one of the input files will be - marked as bad, but no interpolation or dropping will be performed. + If True, drop all bad channels marked as bad in any data set. If neither + ``interpolate_bads`` nor ``drop_bads`` is `True`, in the output file, every + channel marked as bad in at least one of the input files will be marked as bad, + but no interpolation or dropping will be performed. Returns ------- - grand_average : Evoked | AverageTFR + grand_average : Evoked | AverageTFR | Spectrum The grand average data. Same type as input. Notes @@ -558,15 +563,17 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): # check if all elements in the given list are evoked data from ..channels.channels import equalize_channels from ..evoked import Evoked - from ..time_frequency import AverageTFR + from ..time_frequency import AverageTFR, Spectrum if not all_inst: - raise ValueError("Please pass a list of Evoked or AverageTFR objects.") + raise ValueError( + "Please pass a list of Evoked, AverageTFR, or Spectrum objects." + ) elif len(all_inst) == 1: warn("Only a single dataset was passed to mne.grand_average().") inst_type = type(all_inst[0]) - _validate_type(all_inst[0], (Evoked, AverageTFR), "All elements") + _validate_type(all_inst[0], (Evoked, AverageTFR, Spectrum), "All elements") for inst in all_inst: _validate_type(inst, inst_type, "All elements", "of the same type") @@ -581,6 +588,8 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): for inst in all_inst ] from ..evoked import combine_evoked as combine + elif isinstance(all_inst[0], Spectrum): + from ..time_frequency.spectrum import combine_spectrum as combine else: # isinstance(all_inst[0], AverageTFR): from ..time_frequency.tfr import combine_tfr as combine @@ -591,9 +600,9 @@ def grand_average(all_inst, interpolate_bads=True, drop_bads=True): inst.drop_channels(bads) equalize_channels(all_inst, copy=False) - # make grand_average object using combine_[evoked/tfr] + # make grand_average object using combine_[evoked/tfr/spectrum] grand_average = combine(all_inst, weights="equal") - # change the grand_average.nave to the number of Evokeds + # change the grand_average.nave to the number of datasets grand_average.nave = len(all_inst) # change comment field grand_average.comment = f"Grand average (n = {grand_average.nave})" From 2ae61edccb2af5b5f9f3a89a3131499b5c229c27 Mon Sep 17 00:00:00 2001 From: "Thomas S. 
Binns" Date: Tue, 14 Jan 2025 00:08:54 +0000 Subject: [PATCH 4/5] Add `combine_tfr` to API (#13054) --- doc/api/time_frequency.rst | 1 + doc/changes/devel/13054.newfeature.rst | 1 + mne/time_frequency/__init__.pyi | 2 ++ mne/time_frequency/tfr.py | 8 ++++---- 4 files changed, 8 insertions(+), 4 deletions(-) create mode 100644 doc/changes/devel/13054.newfeature.rst diff --git a/doc/api/time_frequency.rst b/doc/api/time_frequency.rst index b66b1b6ca64..a9ab2c34268 100644 --- a/doc/api/time_frequency.rst +++ b/doc/api/time_frequency.rst @@ -32,6 +32,7 @@ Functions that operate on mne-python objects: :toctree: ../generated/ combine_spectrum + combine_tfr csd_tfr csd_fourier csd_multitaper diff --git a/doc/changes/devel/13054.newfeature.rst b/doc/changes/devel/13054.newfeature.rst new file mode 100644 index 00000000000..3c89290e7fe --- /dev/null +++ b/doc/changes/devel/13054.newfeature.rst @@ -0,0 +1 @@ +Added :func:`mne.time_frequency.combine_tfr` to allow combining TFRs across tapers, by `Thomas Binns`_. \ No newline at end of file diff --git a/mne/time_frequency/__init__.pyi b/mne/time_frequency/__init__.pyi index a612c2a850a..6b53c39a98b 100644 --- a/mne/time_frequency/__init__.pyi +++ b/mne/time_frequency/__init__.pyi @@ -12,6 +12,7 @@ __all__ = [ "Spectrum", "SpectrumArray", "combine_spectrum", + "combine_tfr", "csd_array_fourier", "csd_array_morlet", "csd_array_multitaper", @@ -73,6 +74,7 @@ from .tfr import ( EpochsTFRArray, RawTFR, RawTFRArray, + combine_tfr, fwhm, morlet, read_tfrs, diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index b1736f151d2..71dabce6d31 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -3914,10 +3914,10 @@ def __init__( def combine_tfr(all_tfr, weights="nave"): """Merge AverageTFR data by weighted addition. - Create a new AverageTFR instance, using a combination of the supplied - instances as its data. By default, the mean (weighted by trials) is used. - Subtraction can be performed by passing negative weights (e.g., [1, -1]). - Data must have the same channels and the same time instants. + Create a new :class:`mne.time_frequency.AverageTFR` instance, using a combination of + the supplied instances as its data. By default, the mean (weighted by trials) is + used. Subtraction can be performed by passing negative weights (e.g., [1, -1]). Data + must have the same channels and the same time instants. Parameters ---------- From 5fec4e024a963c3f628693ab172d5b77cbafe6db Mon Sep 17 00:00:00 2001 From: Simon Kern <14980558+skjerns@users.noreply.github.com> Date: Tue, 14 Jan 2025 12:46:03 +0100 Subject: [PATCH 5/5] [DOC] extend documentation for add_channels (#13051) --- mne/channels/channels.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mne/channels/channels.py b/mne/channels/channels.py index 8fbff33c13e..ed6dd8508cc 100644 --- a/mne/channels/channels.py +++ b/mne/channels/channels.py @@ -661,17 +661,21 @@ def _pick_projs(self): return self def add_channels(self, add_list, force_update_info=False): - """Append new channels to the instance. + """Append new channels from other MNE objects to the instance. Parameters ---------- add_list : list - A list of objects to append to self. Must contain all the same - type as the current object. + A list of MNE objects to append to the current instance. + The channels contained in the other instances are appended to the + channels of the current instance. Therefore, all other instances + must be of the same type as the current object. 
+ See the Notes section on how to add data coming from a NumPy array. force_update_info : bool If True, force the info for objects to be appended to match the - values in ``self``. This should generally only be used when adding - stim channels for which important metadata won't be overwritten. + values of the current instance. This should generally only be + used when adding stim channels for which important metadata won't + be overwritten. .. versionadded:: 0.12 @@ -688,6 +692,12 @@ def add_channels(self, add_list, force_update_info=False): ----- If ``self`` is a Raw instance that has been preloaded into a :obj:`numpy.memmap` instance, the memmap will be resized. + + This function expects an MNE object to be appended (e.g. :class:`~mne.io.Raw`, + :class:`~mne.Epochs`, :class:`~mne.Evoked`). If you simply want to add a + channel from a :class:`numpy.ndarray`, you need to create a + :class:`~mne.io.RawArray` first. + See `_ """ # avoid circular imports from ..epochs import BaseEpochs
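For reference, a minimal sketch of the workflow the updated Notes text describes: wrapping array data in a :class:`~mne.io.RawArray` before appending it with ``add_channels``. The channel names, sampling rate, and data values below are illustrative only and are not part of the patch:

import numpy as np
import mne

sfreq = 100.0  # both objects must share the sampling rate and number of samples
info_eeg = mne.create_info(["EEG 001", "EEG 002"], sfreq, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(2, 1000), info_eeg)

# A plain ndarray cannot be passed to add_channels(); wrap it in a RawArray first.
stim_data = np.zeros((1, 1000))
info_stim = mne.create_info(["STI 014"], sfreq, ch_types="stim")
stim_raw = mne.io.RawArray(stim_data, info_stim)

# force_update_info=True makes the appended object's measurement info match ``raw``,
# the documented use case for stim channels whose metadata is unimportant.
raw.add_channels([stim_raw], force_update_info=True)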
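Similarly, a short usage sketch for the :func:`~mne.time_frequency.combine_spectrum` and :func:`mne.grand_average` support added earlier in this series. Here ``raw_a`` and ``raw_b`` are hypothetical preloaded :class:`~mne.io.Raw` recordings with identical channels and PSD settings (so the resulting spectra share channels and frequencies), and the behaviour assumes the ``Spectrum.nave`` handling introduced by this series:

import mne
from mne.time_frequency import combine_spectrum

psd_a = raw_a.compute_psd(fmax=40.0)
psd_b = raw_b.compute_psd(fmax=40.0)

# Equal-weight grand average across recordings; channels marked bad in any
# input are dropped by default.
ga_psd = mne.grand_average([psd_a, psd_b])

# Explicit weights allow arbitrary combinations, e.g. a subtraction.
diff_psd = combine_spectrum([psd_a, psd_b], weights=[1, -1])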