From f400f6d74d5018340bb71767012ff1b687167c76 Mon Sep 17 00:00:00 2001 From: Eric Larson Date: Fri, 10 Jan 2025 12:50:13 -0500 Subject: [PATCH] BUG: Fix bug with helium anon --- doc/changes/devel/bugfix.rst | 1 + mne/_fiff/meas_info.py | 15 +++++++--- mne/_fiff/tests/test_meas_info.py | 48 +++++++++++++++++++------------ mne/_fiff/write.py | 9 +++--- mne/utils/_testing.py | 4 +-- 5 files changed, 49 insertions(+), 28 deletions(-) create mode 100644 doc/changes/devel/bugfix.rst diff --git a/doc/changes/devel/bugfix.rst b/doc/changes/devel/bugfix.rst new file mode 100644 index 00000000000..2a7919de289 --- /dev/null +++ b/doc/changes/devel/bugfix.rst @@ -0,0 +1 @@ +Fix bug with saving of anonymized data when helium info is present in measurement info, by `Eric Larson`_. diff --git a/mne/_fiff/meas_info.py b/mne/_fiff/meas_info.py index 629d9a4b0ce..ecc93591a05 100644 --- a/mne/_fiff/meas_info.py +++ b/mne/_fiff/meas_info.py @@ -2493,6 +2493,8 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None): hi["meas_date"] = _ensure_meas_date_none_or_dt( tuple(int(t) for t in tag.data), ) + if "meas_date" not in hi: + hi["meas_date"] = None info["helium_info"] = hi del hi @@ -2879,7 +2881,8 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): write_float(fid, FIFF.FIFF_HELIUM_LEVEL, hi["helium_level"]) if hi.get("orig_file_guid") is not None: write_string(fid, FIFF.FIFF_ORIG_FILE_GUID, hi["orig_file_guid"]) - write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) + if hi["meas_date"] is not None: + write_int(fid, FIFF.FIFF_MEAS_DATE, _dt_to_stamp(hi["meas_date"])) end_block(fid, FIFF.FIFFB_HELIUM) del hi @@ -2916,8 +2919,10 @@ def write_meas_info(fid, info, data_type=None, reset_range=True): _write_proc_history(fid, info) -@fill_doc -def write_info(fname, info, data_type=None, reset_range=True): +@verbose +def write_info( + fname, info, *, data_type=None, reset_range=True, overwrite=False, verbose=None +): """Write measurement info in fif file. Parameters @@ -2931,8 +2936,10 @@ def write_info(fname, info, data_type=None, reset_range=True): raw data. reset_range : bool If True, info['chs'][k]['range'] will be set to unity. + %(overwrite)s + %(verbose)s """ - with start_and_end_file(fname) as fid: + with start_and_end_file(fname, overwrite=overwrite) as fid: start_block(fid, FIFF.FIFFB_MEAS) write_meas_info(fid, info, data_type, reset_range) end_block(fid, FIFF.FIFFB_MEAS) diff --git a/mne/_fiff/tests/test_meas_info.py b/mne/_fiff/tests/test_meas_info.py index 3e3c150573f..d088da2a4a2 100644 --- a/mne/_fiff/tests/test_meas_info.py +++ b/mne/_fiff/tests/test_meas_info.py @@ -306,7 +306,9 @@ def test_read_write_info(tmp_path): gantry_angle = info["gantry_angle"] meas_id = info["meas_id"] - write_info(temp_file, info) + with pytest.raises(FileExistsError, match="Destination file exists"): + write_info(temp_file, info) + write_info(temp_file, info, overwrite=True) info = read_info(temp_file) assert info["proc_history"][0]["creator"] == creator assert info["hpi_meas"][0]["creator"] == creator @@ -348,7 +350,7 @@ def test_read_write_info(tmp_path): info["meas_date"] = datetime(1800, 1, 1, 0, 0, 0, tzinfo=timezone.utc) fname = tmp_path / "test.fif" with pytest.raises(RuntimeError, match="must be between "): - write_info(fname, info) + write_info(fname, info, overwrite=True) @testing.requires_testing_data @@ -377,7 +379,7 @@ def test_io_coord_frame(tmp_path): for ch_type in ("eeg", "seeg", "ecog", "dbs", "hbo", "hbr"): info = create_info(ch_names=["Test Ch"], sfreq=1000.0, ch_types=[ch_type]) info["chs"][0]["loc"][:3] = [0.05, 0.01, -0.03] - write_info(fname, info) + write_info(fname, info, overwrite=True) info2 = read_info(fname) assert info2["chs"][0]["coord_frame"] == FIFF.FIFFV_COORD_HEAD @@ -585,7 +587,7 @@ def test_check_consistency(): info2["subject_info"] = {"height": "bad"} -def _test_anonymize_info(base_info): +def _test_anonymize_info(base_info, tmp_path): """Test that sensitive information can be anonymized.""" pytest.raises(TypeError, anonymize_info, "foo") assert isinstance(base_info, Info) @@ -692,14 +694,25 @@ def _adjust_back(e_i, dt): # exp 4 tests is a supplied daysback delta_t_3 = timedelta(days=223 + 364 * 500) + def _check_equiv(got, want, err_msg): + __tracebackhide__ = True + fname_temp = tmp_path / "test.fif" + assert_object_equal(got, want, err_msg=err_msg) + write_info(fname_temp, got, reset_range=False, overwrite=True) + got = read_info(fname_temp) + # this gets changed on write but that's expected + with got._unlock(): + got["file_id"] = want["file_id"] + assert_object_equal(got, want, err_msg=f"{err_msg} (on I/O round trip)") + new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info, err_msg="anon mismatch") + _check_equiv(new_info, exp_info, err_msg="anon mismatch") new_info = anonymize_info(base_info.copy(), keep_his=True) - assert_object_equal(new_info, exp_info_2, err_msg="anon keep_his mismatch") + _check_equiv(new_info, exp_info_2, err_msg="anon keep_his mismatch") new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal(new_info, exp_info_3, err_msg="anon daysback mismatch") + _check_equiv(new_info, exp_info_3, err_msg="anon daysback mismatch") with pytest.raises(RuntimeError, match="anonymize_info generated"): anonymize_info(base_info.copy(), daysback=delta_t_3.days) @@ -726,7 +739,7 @@ def _adjust_back(e_i, dt): new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) else: new_info = anonymize_info(base_info.copy(), daysback=delta_t_2.days) - assert_object_equal( + _check_equiv( new_info, exp_info_3, err_msg="meas_date=None daysback mismatch", @@ -734,7 +747,7 @@ def _adjust_back(e_i, dt): with _record_warnings(): # meas_date is None new_info = anonymize_info(base_info.copy()) - assert_object_equal(new_info, exp_info_3, err_msg="meas_date=None mismatch") + _check_equiv(new_info, exp_info_3, err_msg="meas_date=None mismatch") @pytest.mark.parametrize( @@ -777,8 +790,8 @@ def _complete_info(info): height=2.0, ) info["helium_info"] = dict( - he_level_raw=12.34, - helium_level=45.67, + he_level_raw=np.float32(12.34), + helium_level=np.float32(45.67), meas_date=datetime(2024, 11, 14, 14, 8, 2, tzinfo=timezone.utc), orig_file_guid="e", ) @@ -796,14 +809,13 @@ def _complete_info(info): machid=np.ones(2, int), secs=d[0], usecs=d[1], - date=d, ), experimenter="j", max_info=dict( - max_st=[], - sss_ctc=[], - sss_cal=[], - sss_info=dict(head_pos=None, in_order=8), + max_st=dict(), + sss_ctc=dict(), + sss_cal=dict(), + sss_info=dict(in_order=8), ), date=d, ), @@ -830,8 +842,8 @@ def test_anonymize(tmp_path): # test mne.anonymize_info() events = read_events(event_name) epochs = Epochs(raw, events[:1], 2, 0.0, 0.1, baseline=None) - _test_anonymize_info(raw.info) - _test_anonymize_info(epochs.info) + _test_anonymize_info(raw.info, tmp_path) + _test_anonymize_info(epochs.info, tmp_path) # test instance methods & I/O roundtrip for inst, keep_his in zip((raw, epochs), (True, False)): diff --git a/mne/_fiff/write.py b/mne/_fiff/write.py index 1fc32f0163e..8486ca13121 100644 --- a/mne/_fiff/write.py +++ b/mne/_fiff/write.py @@ -13,7 +13,7 @@ import numpy as np from scipy.sparse import csc_array, csr_array -from ..utils import _file_like, _validate_type, logger +from ..utils import _check_fname, _file_like, _validate_type, logger from ..utils.numerics import _date_to_julian from .constants import FIFF @@ -277,7 +277,7 @@ def end_block(fid, kind): write_int(fid, FIFF.FIFF_BLOCK_END, kind) -def start_file(fname, id_=None): +def start_file(fname, id_=None, *, overwrite=True): """Open a fif file for writing and writes the compulsory header tags. Parameters @@ -294,6 +294,7 @@ def start_file(fname, id_=None): fid = fname fid.seek(0) else: + fname = _check_fname(fname, overwrite=overwrite) fname = str(fname) if op.splitext(fname)[1].lower() == ".gz": logger.debug("Writing using gzip") @@ -311,9 +312,9 @@ def start_file(fname, id_=None): @contextmanager -def start_and_end_file(fname, id_=None): +def start_and_end_file(fname, id_=None, *, overwrite=True): """Start and (if successfully written) close the file.""" - with start_file(fname, id_=id_) as fid: + with start_file(fname, id_=id_, overwrite=overwrite) as fid: yield fid end_file(fid) # we only hit this line if the yield does not err diff --git a/mne/utils/_testing.py b/mne/utils/_testing.py index 323b530a641..63e0d1036b9 100644 --- a/mne/utils/_testing.py +++ b/mne/utils/_testing.py @@ -179,9 +179,9 @@ def assert_and_remove_boundary_annot(annotations, n=1): annotations.delete(idx) -def assert_object_equal(a, b, *, err_msg="Object mismatch"): +def assert_object_equal(a, b, *, err_msg="Object mismatch", allclose=False): """Assert two objects are equal.""" - d = object_diff(a, b) + d = object_diff(a, b, allclose=allclose) assert d == "", f"{err_msg}\n{d}"