Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into ebm
Browse files Browse the repository at this point in the history
* upstream/main:
  [DOC] extend documentation for add_channels (mne-tools#13051)
  Add `combine_tfr` to API (mne-tools#13054)
  Add `combine_spectrum()` function and allow `grand_average()` to support `Spectrum` data (mne-tools#13058)
  BUG: Fix bug with helium anon (mne-tools#13056)
  [ENH] Add option to store and return TFR taper weights (mne-tools#12910)
  • Loading branch information
larsoner committed Jan 14, 2025
2 parents ebda34e + 5fec4e0 commit 72facf0
Show file tree
Hide file tree
Showing 20 changed files with 744 additions and 209 deletions.
2 changes: 2 additions & 0 deletions doc/api/time_frequency.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ Functions that operate on mne-python objects:
.. autosummary::
:toctree: ../generated/

combine_spectrum
combine_tfr
csd_tfr
csd_fourier
csd_multitaper
Expand Down
1 change: 1 addition & 0 deletions doc/changes/devel/12910.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added the option to return taper weights from :func:`mne.time_frequency.tfr_array_multitaper`, and taper weights are now stored in the :class:`mne.time_frequency.BaseTFR` objects, by `Thomas Binns`_.
1 change: 1 addition & 0 deletions doc/changes/devel/13054.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added :func:`mne.time_frequency.combine_tfr` to allow combining TFRs across tapers, by `Thomas Binns`_.
1 change: 1 addition & 0 deletions doc/changes/devel/13056.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/13058.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add the function :func:`mne.time_frequency.combine_spectrum` for combining data across :class:`mne.time_frequency.Spectrum` objects, and allow :func:`mne.grand_average` to operate on :class:`mne.time_frequency.Spectrum` objects, by `Thomas Binns`_.
15 changes: 11 additions & 4 deletions mne/_fiff/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
hi["meas_date"] = _ensure_meas_date_none_or_dt(
tuple(int(t) for t in tag.data),
)
if "meas_date" not in hi:
hi["meas_date"] = None
info["helium_info"] = hi
del hi

Expand Down Expand Up @@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"])
if hi.get("orig_file_guid") is not None:
write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"])
write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"]))
if hi["meas_date"] is not None:
write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"]))
end_block(fid, FIFF.FIFFB_HELIUM)
del hi

Expand Down Expand Up @@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True):
_write_proc_history(fid, info)


@fill_doc
def write_info(fname, info, data_type=None, reset_range=True):
@verbose
def write_info(
fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None
):
"""Write measurement info in fif file.
Parameters
Expand All @@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True):
raw data.
reset_range : bool
If True, info['chs'][k]['range'] will be set to unity.
%(overwrite)s
%(verbose)s
"""
with start_and_end_file(fname) as fid:
with start_and_end_file(fname, overwrite=overwrite) as fid:
start_block(fid, FIFF.FIFFB_MEAS)
write_meas_info(fid, info, data_type, reset_range)
end_block(fid, FIFF.FIFFB_MEAS)
Expand Down
48 changes: 30 additions & 18 deletions mne/_fiff/tests/test_meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,9 @@ def test_read_write_info(tmp_path):
gantry_angle = info["gantry_angle"]

meas_id = info["meas_id"]
write_info(temp_file, info)
with pytest.raises(FileExistsError, match="Destination file exists"):
write_info(temp_file, info)
write_info(temp_file, info, overwrite=True)
info = read_info(temp_file)
assert info["proc_history"][0]["creator"] == creator
assert info["hpi_meas"][0]["creator"] == creator
Expand Down Expand Up @@ -348,7 +350,7 @@ def test_read_write_info(tmp_path):
info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
fname = tmp_path / "test.fif"
with pytest.raises(RuntimeError, match="must be between "):
write_info(fname, info)
write_info(fname, info, overwrite=True)


@testing.requires_testing_data
Expand Down Expand Up @@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path):
for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"):
info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type])
info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03]
write_info(fname, info)
write_info(fname, info, overwrite=True)
info2 = read_info(fname)
assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD

Expand Down Expand Up @@ -585,7 +587,7 @@ def test_check_consistency():
info2["subject_info"] = {"height": "bad"}


def _test_anonymize_info(base_info):
def _test_anonymize_info(base_info, tmp_path):
"""Test that sensitive information can be anonymized."""
pytest.raises(TypeError, anonymize_info, "foo")
assert isinstance(base_info, Info)
Expand Down Expand Up @@ -692,14 +694,25 @@ def _adjust_back(e_i, dt):
# exp 4 tests is a supplied daysback
delta_t_3 = timedelta(days=223 + 364 * 500)

def _check_equiv(got, want, err_msg):
__tracebackhide__ = True
fname_temp = tmp_path / "test.fif"
assert_object_equal(got, want, err_msg=err_msg)
write_info(fname_temp, got, reset_range=False, overwrite=True)
got = read_info(fname_temp)
# this gets changed on write but that's expected
with got._unlock():
got["file_id"] = want["file_id"]
assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)")

new_info = anonymize_info(base_info.copy())
assert_object_equal(new_info, exp_info, err_msg="anon mismatch")
_check_equiv(new_info, exp_info, err_msg="anon mismatch")

new_info = anonymize_info(base_info.copy(), keep_his=True)
assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch")
_check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch")

new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch")
_check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch")

with pytest.raises(RuntimeError, match="anonymize_info generated"):
anonymize_info(base_info.copy(), daysback=delta_t_3.days)
Expand All @@ -726,15 +739,15 @@ def _adjust_back(e_i, dt):
new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
else:
new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days)
assert_object_equal(
_check_equiv(
new_info,
exp_info_3,
err_msg="meas_date=None daysback mismatch",
)

with _record_warnings(): # meas_date is None
new_info = anonymize_info(base_info.copy())
assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch")
_check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -777,8 +790,8 @@ def _complete_info(info):
height=2.0,
)
info["helium_info"] = dict(
he_level_raw=12.34,
helium_level=45.67,
he_level_raw=np.float32(12.34),
helium_level=np.float32(45.67),
meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc),
orig_file_guid="e",
)
Expand All @@ -796,14 +809,13 @@ def _complete_info(info):
machid=np.ones(2, int),
secs=d[0],
usecs=d[1],
date=d,
),
experimenter="j",
max_info=dict(
max_st=[],
sss_ctc=[],
sss_cal=[],
sss_info=dict(head_pos=None, in_order=8),
max_st=dict(),
sss_ctc=dict(),
sss_cal=dict(),
sss_info=dict(in_order=8),
),
date=d,
),
Expand All @@ -830,8 +842,8 @@ def test_anonymize(tmp_path):
# test mne.anonymize_info()
events = read_events(event_name)
epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None)
_test_anonymize_info(raw.info)
_test_anonymize_info(epochs.info)
_test_anonymize_info(raw.info, tmp_path)
_test_anonymize_info(epochs.info, tmp_path)

# test instance methods & I/O roundtrip
for inst, keep_his in zip((raw, epochs), (True, False)):
Expand Down
9 changes: 5 additions & 4 deletions mne/_fiff/write.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np
from scipy.sparse import csc_array, csr_array

from ..utils import _file_like, _validate_type, logger
from ..utils import _check_fname, _file_like, _validate_type, logger
from ..utils.numerics import _date_to_julian
from .constants import FIFF

Expand Down Expand Up @@ -277,7 +277,7 @@ def end_block(fid, kind):
write_int(fid, FIFF.FIFF_BLOCK_END, kind)


def start_file(fname, id_=None):
def start_file(fname, id_=None, *, overwrite=True):
"""Open a fif file for writing and writes the compulsory header tags.
Parameters
Expand All @@ -294,6 +294,7 @@ def start_file(fname, id_=None):
fid = fname
fid.seek(0)
else:
fname = _check_fname(fname, overwrite=overwrite)
fname = str(fname)
if op.splitext(fname)[1].lower() == ".gz":
logger.debug("Writing using gzip")
Expand All @@ -311,9 +312,9 @@ def start_file(fname, id_=None):


@contextmanager
def start_and_end_file(fname, id_=None):
def start_and_end_file(fname, id_=None, *, overwrite=True):
"""Start and (if successfully written) close the file."""
with start_file(fname, id_=id_) as fid:
with start_file(fname, id_=id_, overwrite=overwrite) as fid:
yield fid
end_file(fid) # we only hit this line if the yield does not err

Expand Down
20 changes: 15 additions & 5 deletions mne/channels/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,17 +661,21 @@ def _pick_projs(self):
return self

def add_channels(self, add_list, force_update_info=False):
"""Append new channels to the instance.
"""Append new channels from other MNE objects to the instance.
Parameters
----------
add_list : list
A list of objects to append to self. Must contain all the same
type as the current object.
A list of MNE objects to append to the current instance.
The channels contained in the other instances are appended to the
channels of the current instance. Therefore, all other instances
must be of the same type as the current object.
See notes on how to add data coming from an array.
force_update_info : bool
If True, force the info for objects to be appended to match the
values in ``self``. This should generally only be used when adding
stim channels for which important metadata won't be overwritten.
values of the current instance. This should generally only be
used when adding stim channels for which important metadata won't
be overwritten.
.. versionadded:: 0.12
Expand All @@ -688,6 +692,12 @@ def add_channels(self, add_list, force_update_info=False):
-----
If ``self`` is a Raw instance that has been preloaded into a
:obj:`numpy.memmap` instance, the memmap will be resized.
This function expects an MNE object to be appended (e.g. :class:`~mne.io.Raw`,
:class:`~mne.Epochs`, :class:`~mne.Evoked`). If you simply want to add a
channel based on values of an np.ndarray, you need to create a
:class:`~mne.io.RawArray`.
See <https://mne.tools/mne-project-template/auto_examples/plot_mne_objects_from_arrays.html>`_
"""
# avoid circular imports
from ..epochs import BaseEpochs
Expand Down
4 changes: 4 additions & 0 deletions mne/time_frequency/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ __all__ = [
"RawTFRArray",
"Spectrum",
"SpectrumArray",
"combine_spectrum",
"combine_tfr",
"csd_array_fourier",
"csd_array_morlet",
"csd_array_multitaper",
Expand Down Expand Up @@ -61,6 +63,7 @@ from .spectrum import (
EpochsSpectrumArray,
Spectrum,
SpectrumArray,
combine_spectrum,
read_spectrum,
)
from .tfr import (
Expand All @@ -71,6 +74,7 @@ from .tfr import (
EpochsTFRArray,
RawTFR,
RawTFRArray,
combine_tfr,
fwhm,
morlet,
read_tfrs,
Expand Down
10 changes: 10 additions & 0 deletions mne/time_frequency/multitaper.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,7 @@ def tfr_array_multitaper(
output="complex",
n_jobs=None,
*,
return_weights=False,
verbose=None,
):
"""Compute Time-Frequency Representation (TFR) using DPSS tapers.
Expand Down Expand Up @@ -504,6 +505,11 @@ def tfr_array_multitaper(
coherence across trials.
%(n_jobs)s
The parallelization is implemented across channels.
return_weights : bool, default False
If True, return the taper weights. Only applies if ``output='complex'`` or
``'phase'``.
.. versionadded:: 1.10.0
%(verbose)s
Returns
Expand All @@ -520,6 +526,9 @@ def tfr_array_multitaper(
If ``output`` is ``'avg_power_itc'``, the real values in ``out``
contain the average power and the imaginary values contain the
inter-trial coherence: :math:`out = power_{avg} + i * ITC`.
weights : array of shape (n_tapers, n_freqs)
The taper weights. Only returned if ``output='complex'`` or ``'phase'`` and
``return_weights=True``.
See Also
--------
Expand Down Expand Up @@ -550,6 +559,7 @@ def tfr_array_multitaper(
use_fft=use_fft,
decim=decim,
output=output,
return_weights=return_weights,
n_jobs=n_jobs,
verbose=verbose,
)
Loading

0 comments on commit 72facf0

Please sign in to comment.