From 919257388f05364ef6334331a02bdc7cf6af96fb Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 6 Sep 2024 10:45:48 -0500 Subject: [PATCH 01/21] WIP: remove AnalysisNwbfileLog --- CHANGELOG.md | 7 ++ src/spyglass/common/common_ephys.py | 8 +- src/spyglass/common/common_nwbfile.py | 108 +----------------- src/spyglass/common/common_position.py | 6 +- src/spyglass/decoding/v0/clusterless.py | 3 +- src/spyglass/decoding/v1/waveform_features.py | 3 - src/spyglass/lfp/analysis/v1/lfp_band.py | 5 +- src/spyglass/lfp/v1/lfp.py | 3 +- src/spyglass/linearization/v0/main.py | 4 +- src/spyglass/linearization/v1/main.py | 4 +- .../position/v1/position_dlc_orient.py | 6 +- .../v1/position_dlc_pose_estimation.py | 1 - .../position/v1/position_dlc_position.py | 1 - .../position/v1/position_dlc_selection.py | 3 - .../position/v1/position_trodes_position.py | 5 +- .../spikesorting/v0/spikesorting_curation.py | 10 +- src/spyglass/spikesorting/v1/curation.py | 4 - .../spikesorting/v1/metric_curation.py | 4 - src/spyglass/spikesorting/v1/recording.py | 23 ++-- src/spyglass/spikesorting/v1/sorting.py | 2 - 20 files changed, 30 insertions(+), 180 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bc6c50c7..2cdf9be4d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,17 @@ +```python +import datajoint as dj + +dj.FreeTable(dj.conn(), "common_nwbfile.analysis_nwbfile_log").drop() +``` + ### Infrastructure - Disable populate transaction protection for long-populating tables #1066 - Add docstrings to all public methods #1076 +- Remove `AnalysisNwbfileLog` #10XX ### Pipelines diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 37c4361c5..4cc5918fe 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -463,7 +463,7 @@ def make(self, key): """ # get the NWB object with the data; FIX: change to fetch with # additional infrastructure - lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) # logged + lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) rawdata = Raw().nwb_object(key) sampling_rate, interval_list_name = (Raw() & key).fetch1( @@ -553,7 +553,6 @@ def make(self, key): }, replace=True, ) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) def nwb_object(self, key): @@ -748,9 +747,7 @@ def make(self, key): 6. Adds resulting interval list to IntervalList table. """ # create the analysis nwb file to store the results. - lfp_band_file_name = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + lfp_band_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) # get the NWB object with the lfp data; # FIX: change to fetch with additional infrastructure @@ -946,7 +943,6 @@ def make(self, key): "previously saved lfp band times do not match current times" ) - AnalysisNwbfile().log(lfp_band_file_name, table=self.full_table_name) self.insert1(key) def fetch1_dataframe(self, *attrs, **kwargs) -> pd.DataFrame: diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 2166b2c04..ad0a2e0f4 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -3,7 +3,6 @@ import stat import string from pathlib import Path -from time import time from uuid import uuid4 import datajoint as dj @@ -172,8 +171,6 @@ class AnalysisNwbfile(SpyglassMixin, dj.Manual): # See #630, #664. Excessive key length. 
- _creation_times = {} - def create(self, nwb_file_name: str) -> str: """Open the NWB file, create copy, write to disk and return new name. @@ -190,9 +187,6 @@ def create(self, nwb_file_name: str) -> str: analysis_file_name : str The name of the new NWB file. """ - # To allow some times to occur before create - # creation_time = self._creation_times.pop("pre_create_time", time()) - nwb_file_abspath = Nwbfile.get_abs_path(nwb_file_name) alter_source_script = False with pynwb.NWBHDF5IO( @@ -214,16 +208,19 @@ def create(self, nwb_file_name: str) -> str: alter_source_script = True analysis_file_name = self.__get_new_file_name(nwb_file_name) + # write the new file logger.info(f"Writing new NWB file {analysis_file_name}") analysis_file_abs_path = AnalysisNwbfile.get_abs_path( analysis_file_name ) + # export the new NWB file with pynwb.NWBHDF5IO( path=analysis_file_abs_path, mode="w", manager=io.manager ) as export_io: export_io.export(io, nwbf) + if alter_source_script: self._alter_spyglass_version(analysis_file_abs_path) @@ -235,8 +232,6 @@ def create(self, nwb_file_name: str) -> str: permissions = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH os.chmod(analysis_file_abs_path, permissions) - # self._creation_times[analysis_file_name] = creation_time - return analysis_file_name @staticmethod @@ -699,100 +694,3 @@ def nightly_cleanup(): # a separate external files clean up required - this is to be done # during times when no other transactions are in progress. AnalysisNwbfile.cleanup(True) - - def log(self, *args, **kwargs): - """Null log method. Revert to _disabled_log to turn back on.""" - logger.debug("Logging disabled.") - - def _disabled_log(self, analysis_file_name, table=None): - """Passthrough to the AnalysisNwbfileLog table. Avoid new imports.""" - if isinstance(analysis_file_name, dict): - analysis_file_name = analysis_file_name["analysis_file_name"] - time_delta = time() - self._creation_times[analysis_file_name] - file_size = Path(self.get_abs_path(analysis_file_name)).stat().st_size - - AnalysisNwbfileLog().log( - analysis_file_name=analysis_file_name, - time_delta=time_delta, - file_size=file_size, - table=table, - ) - - def increment_access(self, *args, **kwargs): - """Null method. Revert to _disabled_increment_access to turn back on.""" - logger.debug("Incrementing access disabled.") - - def _disabled_increment_access(self, keys, table=None): - """Passthrough to the AnalysisNwbfileLog table. Avoid new imports.""" - if not isinstance(keys, list): - key = [keys] - - for key in keys: - AnalysisNwbfileLog().increment_access(key, table=table) - - -@schema -class AnalysisNwbfileLog(dj.Manual): - definition = """ - id: int auto_increment - --- - -> AnalysisNwbfile - dj_user : varchar(64) # user who created the file - timestamp = CURRENT_TIMESTAMP : timestamp # when the file was created - table = null : varchar(64) # creating table - time_delta = null : float # how long it took to create - file_size = null : float # size of the file in bytes - accessed = 0 : int # n times accessed - unique index (analysis_file_name) - """ - - def log( - self, - analysis_file_name=None, - time_delta=None, - file_size=None, - table=None, - ): - """Log the creation of an analysis NWB file. - - Parameters - ---------- - analysis_file_name : str - The name of the analysis NWB file. 
- """ - - self.insert1( - { - "dj_user": dj.config["database.user"], - "analysis_file_name": analysis_file_name, - "time_delta": time_delta, - "file_size": file_size, - "table": table[:64], - } - ) - - def increment_access(self, key, table=None): - """Increment the accessed field for the given analysis file name. - - Parameters - ---------- - key : Union[str, dict] - The name of the analysis NWB file, or a key to the table. - table : str, optional - The table that created the file. - """ - if isinstance(key, str): - key = {"analysis_file_name": key} - - if not (query := self & key): - self.log(**key, table=table) - entries = query.fetch(as_dict=True) - - inserts = [] - for entry in entries: - entry["accessed"] += 1 - if table and not entry.get("table"): - entry["table"] = table - inserts.append(entry) - - self.insert(inserts, replace=True) diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 9026eae7f..02b30cc88 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -88,9 +88,7 @@ def make(self, key): """Insert smoothed head position, orientation and velocity.""" logger.info(f"Computing position for: {key}") - analysis_file_name = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) raw_position = RawPosition.PosObject & key spatial_series = raw_position.fetch_nwb()[0]["raw_position"] @@ -117,8 +115,6 @@ def make(self, key): AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name) - AnalysisNwbfile().log(key, table=self.full_table_name) - self.insert1(key) @staticmethod diff --git a/src/spyglass/decoding/v0/clusterless.py b/src/spyglass/decoding/v0/clusterless.py index cc73fb5eb..1f0184776 100644 --- a/src/spyglass/decoding/v0/clusterless.py +++ b/src/spyglass/decoding/v0/clusterless.py @@ -159,7 +159,7 @@ def make(self, key): 4. Saves the marks as a TimeSeries object in a new AnalysisNwbfile. 
""" # create a new AnalysisNwbfile and a timeseries for the marks and save - key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["analysis_file_name"] = AnalysisNwbfile().create( key["nwb_file_name"] ) # get the list of mark parameters @@ -246,7 +246,6 @@ def make(self, key): key["analysis_file_name"], nwb_object ) AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) def fetch1_dataframe(self) -> pd.DataFrame: diff --git a/src/spyglass/decoding/v1/waveform_features.py b/src/spyglass/decoding/v1/waveform_features.py index 818d9bf43..077600b9e 100644 --- a/src/spyglass/decoding/v1/waveform_features.py +++ b/src/spyglass/decoding/v1/waveform_features.py @@ -1,6 +1,5 @@ import os from itertools import chain -from time import time import datajoint as dj import numpy as np @@ -108,7 +107,6 @@ class UnitWaveformFeatures(SpyglassMixin, dj.Computed): def make(self, key): """Populate UnitWaveformFeatures table.""" - AnalysisNwbfile()._creation_times["pre_create_time"] = time() # get the list of feature parameters params = (WaveformFeaturesParams & key).fetch1("params") @@ -175,7 +173,6 @@ def make(self, key): nwb_file_name, key["analysis_file_name"], ) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) diff --git a/src/spyglass/lfp/analysis/v1/lfp_band.py b/src/spyglass/lfp/analysis/v1/lfp_band.py index 074da4b38..57df622dc 100644 --- a/src/spyglass/lfp/analysis/v1/lfp_band.py +++ b/src/spyglass/lfp/analysis/v1/lfp_band.py @@ -175,9 +175,7 @@ class LFPBandV1(SpyglassMixin, dj.Computed): def make(self, key): """Populate LFPBandV1""" # create the analysis nwb file to store the results. - lfp_band_file_name = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + lfp_band_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) # get the NWB object with the lfp data; # FIX: change to fetch with additional infrastructure lfp_key = {"merge_id": key["lfp_merge_id"]} @@ -368,7 +366,6 @@ def make(self, key): "previously saved lfp band times do not match current times" ) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) def fetch1_dataframe(self, *attrs, **kwargs): diff --git a/src/spyglass/lfp/v1/lfp.py b/src/spyglass/lfp/v1/lfp.py index 9c529e24a..9fc704bca 100644 --- a/src/spyglass/lfp/v1/lfp.py +++ b/src/spyglass/lfp/v1/lfp.py @@ -66,7 +66,7 @@ def make(self, key): the AnalysisNwbfile table. The valid times for the filtered data are stored in the IntervalList table. 
""" - lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) # logged + lfp_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) # get the NWB object with the data nwbf_key = {"nwb_file_name": key["nwb_file_name"]} rawdata = (Raw & nwbf_key).fetch_nwb()[0]["raw"] @@ -201,7 +201,6 @@ def make(self, key): orig_key["analysis_file_name"] = lfp_file_name orig_key["lfp_object_id"] = lfp_object_id LFPOutput.insert1(orig_key) - AnalysisNwbfile().log(key, table=self.full_table_name) def fetch1_dataframe(self, *attrs, **kwargs) -> pd.DataFrame: """Fetch a single dataframe.""" diff --git a/src/spyglass/linearization/v0/main.py b/src/spyglass/linearization/v0/main.py index c802b583c..f406f7322 100644 --- a/src/spyglass/linearization/v0/main.py +++ b/src/spyglass/linearization/v0/main.py @@ -126,7 +126,7 @@ def make(self, key): """Compute linearized position for a given key.""" logger.info(f"Computing linear position for: {key}") - key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["analysis_file_name"] = AnalysisNwbfile().create( key["nwb_file_name"] ) @@ -189,8 +189,6 @@ def make(self, key): self.insert1(key) - AnalysisNwbfile().log(key, table=self.full_table_name) - def fetch1_dataframe(self) -> DataFrame: """Fetch a single dataframe""" return self.fetch_nwb()[0]["linearized_position"].set_index("time") diff --git a/src/spyglass/linearization/v1/main.py b/src/spyglass/linearization/v1/main.py index 76ec85aaf..52cc0338f 100644 --- a/src/spyglass/linearization/v1/main.py +++ b/src/spyglass/linearization/v1/main.py @@ -134,7 +134,7 @@ def make(self, key): position_nwb = PositionOutput().fetch_nwb( {"merge_id": key["pos_merge_id"]} )[0] - key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["analysis_file_name"] = AnalysisNwbfile().create( position_nwb["nwb_file_name"] ) position = np.asarray( @@ -195,8 +195,6 @@ def make(self, key): [orig_key], part_name=part_name, skip_duplicates=True ) - AnalysisNwbfile().log(key, table=self.full_table_name) - def fetch1_dataframe(self) -> DataFrame: """Fetch a single dataframe.""" return self.fetch_nwb()[0]["linearized_position"].set_index("time") diff --git a/src/spyglass/position/v1/position_dlc_orient.py b/src/spyglass/position/v1/position_dlc_orient.py index 118ddffc8..9134440ff 100644 --- a/src/spyglass/position/v1/position_dlc_orient.py +++ b/src/spyglass/position/v1/position_dlc_orient.py @@ -1,5 +1,3 @@ -from time import time - import datajoint as dj import numpy as np import pandas as pd @@ -123,7 +121,6 @@ def make(self, key): 4. Insert the key into the DLCOrientation table. 
""" # Get labels to smooth from Parameters table - AnalysisNwbfile()._creation_times["pre_create_time"] = time() pos_df = self._get_pos_df(key) params = (DLCOrientationParams() & key).fetch1("params") @@ -162,7 +159,7 @@ def make(self, key): final_df = pd.DataFrame( orientation, columns=["orientation"], index=pos_df.index ) - key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["analysis_file_name"] = AnalysisNwbfile().create( key["nwb_file_name"] ) # if spatial series exists, get metadata from there @@ -192,7 +189,6 @@ def make(self, key): ) self.insert1(key) - AnalysisNwbfile().log(key, table=self.full_table_name) def fetch1_dataframe(self) -> pd.DataFrame: """Fetch a single dataframe""" diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 2ff376837..fc0a2ee8a 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -347,7 +347,6 @@ def _logged_make(self, key): analysis_file_name=key["analysis_file_name"], ) self.BodyPart.insert1(key) - AnalysisNwbfile().log(key, table=self.full_table_name) def fetch_dataframe(self, *attrs, **kwargs) -> pd.DataFrame: """Fetch a concatenated dataframe of all bodyparts.""" diff --git a/src/spyglass/position/v1/position_dlc_position.py b/src/spyglass/position/v1/position_dlc_position.py index 58dce0a89..a4e1cdba9 100644 --- a/src/spyglass/position/v1/position_dlc_position.py +++ b/src/spyglass/position/v1/position_dlc_position.py @@ -287,7 +287,6 @@ def _logged_make(self, key): analysis_file_name=key["analysis_file_name"], ) self.insert1(key) - AnalysisNwbfile().log(key, table=self.full_table_name) def fetch1_dataframe(self) -> pd.DataFrame: """Fetch a single dataframe.""" diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index e0bd0359e..4298b1b5e 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -1,6 +1,5 @@ import copy from pathlib import Path -from time import time import datajoint as dj import numpy as np @@ -67,7 +66,6 @@ def make(self, key): """ orig_key = copy.deepcopy(key) # Add to Analysis NWB file - AnalysisNwbfile()._creation_times["pre_create_time"] = time() key["pose_eval_result"] = self.evaluate_pose_estimation(key) pos_nwb = (DLCCentroid & key).fetch_nwb()[0] @@ -155,7 +153,6 @@ def make(self, key): part_name=to_camel_case(self.table_name.split("__")[-1]), skip_duplicates=True, ) - AnalysisNwbfile().log(key, table=self.full_table_name) def fetch1_dataframe(self) -> pd.DataFrame: """Return the position data as a DataFrame.""" diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index 72adada46..18f5f3bcc 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -174,9 +174,7 @@ def make(self, key): logger.info(f"Computing position for: {key}") orig_key = copy.deepcopy(key) - analysis_file_name = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) raw_position = RawPosition.PosObject & key spatial_series = raw_position.fetch_nwb()[0]["raw_position"] @@ -218,7 +216,6 @@ def make(self, key): PositionOutput._merge_insert( [orig_key], part_name=part_name, skip_duplicates=True ) - AnalysisNwbfile().log(key, 
table=self.full_table_name) @staticmethod def generate_pos_components(*args, **kwargs): diff --git a/src/spyglass/spikesorting/v0/spikesorting_curation.py b/src/spyglass/spikesorting/v0/spikesorting_curation.py index 52a2dff73..bbe167790 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_curation.py +++ b/src/spyglass/spikesorting/v0/spikesorting_curation.py @@ -341,7 +341,7 @@ def make(self, key): 3. Generates an analysis NWB file with the waveforms 4. Inserts the key into Waveforms table """ - key["analysis_file_name"] = AnalysisNwbfile().create( # logged + key["analysis_file_name"] = AnalysisNwbfile().create( key["nwb_file_name"] ) recording = Curation.get_recording(key) @@ -375,7 +375,6 @@ def make(self, key): key["waveforms_object_id"] = object_id AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) def load_waveforms(self, key: dict): @@ -541,9 +540,7 @@ def make(self, key): 3. Generates an analysis NWB file with the metrics. 4. Inserts the key into QualityMetrics table """ - analysis_file_name = AnalysisNwbfile().create( # logged - key["nwb_file_name"] - ) + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) waveform_extractor = Waveforms().load_waveforms(key) key["analysis_file_name"] = ( analysis_file_name # add to key here to prevent fetch errors @@ -567,7 +564,6 @@ def make(self, key): key["analysis_file_name"], metrics=qm ) AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) @@ -980,7 +976,6 @@ def make(self, key): 2. Saves the sorting in an analysis NWB file 3. Inserts key into CuratedSpikeSorting table and units into part table. """ - AnalysisNwbfile()._creation_times["pre_create_time"] = time.time() unit_labels_to_remove = ["reject"] # check that the Curation has metrics metrics = (Curation & key).fetch1("quality_metrics") @@ -1051,7 +1046,6 @@ def make(self, key): labels=labels, ) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) # now add the units diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py index 00b1ef81e..7996bcd61 100644 --- a/src/spyglass/spikesorting/v1/curation.py +++ b/src/spyglass/spikesorting/v1/curation.py @@ -1,4 +1,3 @@ -from time import time from typing import Dict, List, Union import datajoint as dj @@ -80,8 +79,6 @@ def insert_curation( ------- curation_key : dict """ - AnalysisNwbfile()._creation_times["pre_create_time"] = time() - sort_query = cls & {"sorting_id": sorting_id} parent_curation_id = max(parent_curation_id, -1) if parent_curation_id == -1: @@ -124,7 +121,6 @@ def insert_curation( "description": description, } cls.insert1(key, skip_duplicates=True) - AnalysisNwbfile().log(analysis_file_name, table=cls.full_table_name) return key diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py index b9d1fb66f..f1af29259 100644 --- a/src/spyglass/spikesorting/v1/metric_curation.py +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -1,6 +1,5 @@ import os import uuid -from time import time from typing import Any, Dict, List, Union import datajoint as dj @@ -226,8 +225,6 @@ def make(self, key): 7. Saves the waveforms, metrics, labels, and merge groups to an analysis NWB file and inserts into MetricCuration table. 
""" - - AnalysisNwbfile()._creation_times["pre_create_time"] = time() # FETCH nwb_file_name = ( SpikeSortingSelection * MetricCurationSelection & key @@ -301,7 +298,6 @@ def make(self, key): nwb_file_name, key["analysis_file_name"], ) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key) @classmethod diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index f4e150837..66c07d86c 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -1,5 +1,4 @@ import uuid -from time import time from typing import Iterable, List, Optional, Tuple, Union import datajoint as dj @@ -182,7 +181,10 @@ def make(self, key): - NWB file to AnalysisNwbfile - Recording ids to SpikeSortingRecording """ - AnalysisNwbfile()._creation_times["pre_create_time"] = time() + nwb_file_name = (SpikeSortingRecordingSelection & key).fetch1( + "nwb_file_name" + ) + # DO: # - get valid times for sort interval # - proprocess recording @@ -190,9 +192,7 @@ def make(self, key): sort_interval_valid_times = self._get_sort_interval_valid_times(key) recording, timestamps = self._get_preprocessed_recording(key) recording_nwb_file_name, recording_object_id = _write_recording_to_nwb( - recording, - timestamps, - (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), + recording, timestamps, nwb_file_name ) key["analysis_file_name"] = recording_nwb_file_name key["object_id"] = recording_object_id @@ -203,21 +203,13 @@ def make(self, key): # - entry into SpikeSortingRecording IntervalList.insert1( { - "nwb_file_name": (SpikeSortingRecordingSelection & key).fetch1( - "nwb_file_name" - ), + "nwb_file_name": nwb_file_name, "interval_list_name": key["recording_id"], "valid_times": sort_interval_valid_times, "pipeline": "spikesorting_recording_v1", } ) - AnalysisNwbfile().add( - (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), - key["analysis_file_name"], - ) - AnalysisNwbfile().log( - recording_nwb_file_name, table=self.full_table_name - ) + AnalysisNwbfile().add(nwb_file_name, key["analysis_file_name"]) self.insert1(key) @classmethod @@ -538,6 +530,7 @@ def _write_recording_to_nwb( analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( path=analysis_nwb_file_abs_path, mode="a", diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py index 06c7c6ede..1ad352943 100644 --- a/src/spyglass/spikesorting/v1/sorting.py +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -155,7 +155,6 @@ def make(self, key: dict): # - information about the recording # - artifact free intervals # - spike sorter and sorter params - AnalysisNwbfile()._creation_times["pre_create_time"] = time.time() recording_key = ( SpikeSortingRecording * SpikeSortingSelection & key @@ -301,7 +300,6 @@ def make(self, key: dict): (SpikeSortingSelection & key).fetch1("nwb_file_name"), key["analysis_file_name"], ) - AnalysisNwbfile().log(key, table=self.full_table_name) self.insert1(key, skip_duplicates=True) @classmethod From 27f0004b5db1321a1aac645a0466423c298c3edf Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 6 Sep 2024 15:04:38 -0500 Subject: [PATCH 02/21] WIP: recompute --- src/spyglass/common/common_dandi.py | 14 +++- src/spyglass/common/common_nwbfile.py | 10 ++- src/spyglass/spikesorting/v1/recording.py | 68 +++++++++++++---- src/spyglass/utils/dj_helper_fn.py | 7 +- src/spyglass/utils/nwb_helper_fn.py | 
91 +++++++++++++---------- 5 files changed, 128 insertions(+), 62 deletions(-) diff --git a/src/spyglass/common/common_dandi.py b/src/spyglass/common/common_dandi.py index 6dfbb56e3..c802b39e6 100644 --- a/src/spyglass/common/common_dandi.py +++ b/src/spyglass/common/common_dandi.py @@ -50,8 +50,20 @@ class DandiPath(SpyglassMixin, dj.Manual): dandi_instance = "dandi": varchar(32) """ - def fetch_file_from_dandi(self, key: dict): + def key_from_path(self, file_path) -> dict: + return {"filename": os.path.basename(file_path)} + + def has_file_path(self, file_path: str) -> bool: + return bool(self & self.key_from_path(file_path)) + + def fetch_file_from_dandi( + self, key: dict = None, nwb_file_path: str = None + ): """Fetch the file from Dandi and return the NWB file object.""" + if key is None and nwb_file_path is None: + raise ValueError("Must provide either key or nwb_file_path") + key = key or self.key_from_path(nwb_file_path) + dandiset_id, dandi_path, dandi_instance = (self & key).fetch1( "dandiset_id", "dandi_path", "dandi_instance" ) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index ad0a2e0f4..d2fb34a6d 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -171,7 +171,9 @@ class AnalysisNwbfile(SpyglassMixin, dj.Manual): # See #630, #664. Excessive key length. - def create(self, nwb_file_name: str) -> str: + def create( + self, nwb_file_name: str, recompute_file_name: str = None + ) -> str: """Open the NWB file, create copy, write to disk and return new name. Note that this does NOT add the file to the schema; that needs to be @@ -181,6 +183,8 @@ def create(self, nwb_file_name: str) -> str: ---------- nwb_file_name : str The name of an NWB file to be copied. + recompute_file_name : str, optional + The name of the file to be regenerated. Defaults to None. 
Returns ------- @@ -207,7 +211,9 @@ def create(self, nwb_file_name: str) -> str: else: alter_source_script = True - analysis_file_name = self.__get_new_file_name(nwb_file_name) + analysis_file_name = ( + recompute_file_name or self.__get_new_file_name(nwb_file_name) + ) # write the new file logger.info(f"Writing new NWB file {analysis_file_name}") diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 66c07d86c..5c0a46102 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -1,4 +1,5 @@ import uuid +from pathlib import Path from typing import Iterable, List, Optional, Tuple, Union import datajoint as dj @@ -185,17 +186,7 @@ def make(self, key): "nwb_file_name" ) - # DO: - # - get valid times for sort interval - # - proprocess recording - # - write recording to NWB file - sort_interval_valid_times = self._get_sort_interval_valid_times(key) - recording, timestamps = self._get_preprocessed_recording(key) - recording_nwb_file_name, recording_object_id = _write_recording_to_nwb( - recording, timestamps, nwb_file_name - ) - key["analysis_file_name"] = recording_nwb_file_name - key["object_id"] = recording_object_id + key.update(self._make_file(key)) # INSERT: # - valid times into IntervalList @@ -205,13 +196,50 @@ def make(self, key): { "nwb_file_name": nwb_file_name, "interval_list_name": key["recording_id"], - "valid_times": sort_interval_valid_times, + "valid_times": self._get_sort_interval_valid_times(key), "pipeline": "spikesorting_recording_v1", } ) AnalysisNwbfile().add(nwb_file_name, key["analysis_file_name"]) self.insert1(key) + @classmethod + def _make_file(cls, key: dict = None, recompute_file_name: str = None): + """Preprocess recording and write to NWB file. + + - Get valid times for sort interval from IntervalList + - Preprocess recording + - Write processed recording to NWB file + + Parameters + ---------- + key : dict + primary key of SpikeSortingRecordingSelection table + recompute_file_name : str, Optional + If specified, recompute this file. Use as resulting file name. + If none, generate a new file name. + """ + if not key and not recompute_file_name: + raise ValueError( + "Either key or recompute_file_name must be specified." 
+ ) + + key = key or (cls & {"analysis_file_name": recompute_file_name}).fetch1( + "KEY" + ) + + parent = SpikeSortingRecordingSelection & key + recording_nwb_file_name, recording_object_id = _write_recording_to_nwb( + **cls()._get_preprocessed_recording(key), + nwb_file_name=parent.fetch1("nwb_file_name"), + recompute_file_name=recompute_file_name, + ) + + return dict( + analysis_file_name=recording_nwb_file_name, + object_id=recording_object_id, + ) + @classmethod def get_recording(cls, key: dict) -> si.BaseRecording: """Get recording related to this curation as spikeinterface BaseRecording @@ -221,11 +249,14 @@ def get_recording(cls, key: dict) -> si.BaseRecording: key : dict primary key of SpikeSorting table """ - analysis_file_name = (cls & key).fetch1("analysis_file_name") analysis_file_abs_path = AnalysisNwbfile.get_abs_path( analysis_file_name ) + + if not Path(analysis_file_abs_path).exists(): + cls._make_file(key, recompute_file_name=analysis_file_name) + recording = se.read_nwb_recording( analysis_file_abs_path, load_time_vector=True ) @@ -303,6 +334,8 @@ def _get_preprocessed_recording(self, key: dict): # - the reference channel # - probe type # - filter parameters + + # TODO: Reduce number of fetches nwb_file_name = (SpikeSortingRecordingSelection & key).fetch1( "nwb_file_name" ) @@ -365,7 +398,7 @@ def _get_preprocessed_recording(self, key: dict): ) all_timestamps = recording.get_times() - # TODO: make sure the following works for recordings that don't have explicit timestamps + # TODO: verify for recordings that don't have explicit timestamps valid_sort_times = self._get_sort_interval_valid_times(key) valid_sort_times_indices = _consolidate_intervals( valid_sort_times, all_timestamps @@ -451,7 +484,7 @@ def _get_preprocessed_recording(self, key: dict): tetrode.set_device_channel_indices(np.arange(4)) recording = recording.set_probe(tetrode, in_place=True) - return recording, np.asarray(timestamps) + return dict(recording=recording, timestamps=np.asarray(timestamps)) def _consolidate_intervals(intervals, timestamps): @@ -512,6 +545,7 @@ def _write_recording_to_nwb( recording: si.BaseRecording, timestamps: Iterable, nwb_file_name: str, + recompute_file_name: Optional[str] = None, ): """Write a recording in NWB format @@ -528,7 +562,9 @@ def _write_recording_to_nwb( name of analysis NWB file containing the preprocessed recording """ - analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file = AnalysisNwbfile().create( + nwb_file_name=nwb_file_name, recompute_file_name=recompute_file_name + ) analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) with pynwb.NWBHDF5IO( diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 889e64294..c24d9301f 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -287,8 +287,11 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): for file_name in nwb_files: file_path = file_path_fn(file_name) - if not os.path.exists(file_path): # retrieve the file from kachery. 
- # This also opens the file and stores the file object + if not os.path.exists(file_path): + if hasattr(query_expression, "_make_file"): + # Attempt to recompute the file + query_expression._make_file(recompute_file_name=file_name) + # get from kachery/dandi, store in cache get_nwb_file(file_path) query_table = query_expression * tbl.proj(nwb2load_filepath=attr_name) diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index af25ec987..175e8b8fb 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -21,6 +21,23 @@ invalid_electrode_index = 99999999 +def _open_nwb_file(nwb_file_path, source="local"): + """Open an NWB file, add to cache, return contents. Does not close file.""" + if source == "local": + io = pynwb.NWBHDF5IO(path=nwb_file_path, mode="r", load_namespaces=True) + nwbfile = io.read() + elif source == "dandi": + from ..common.common_dandi import DandiPath + + io, nwbfile = DandiPath().fetch_file_from_dandi( + nwb_file_path=nwb_file_path + ) + else: + raise ValueError(f"Invalid open_nwb source: {source}") + __open_nwb_files[nwb_file_path] = (io, nwbfile) + return nwbfile + + def get_nwb_file(nwb_file_path): """Return an NWBFile object with the given file path in read mode. @@ -40,54 +57,46 @@ def get_nwb_file(nwb_file_path): NWB file object for the given path opened in read mode. """ if not nwb_file_path.startswith("/"): - from ..common import Nwbfile + from spyglass.common import Nwbfile nwb_file_path = Nwbfile.get_abs_path(nwb_file_path) _, nwbfile = __open_nwb_files.get(nwb_file_path, (None, None)) - if nwbfile is None: - # check to see if the file exists - if not os.path.exists(nwb_file_path): - logger.info( - "NWB file not found locally; checking kachery for " - + f"{nwb_file_path}" - ) - # first try the analysis files - from ..sharing.sharing_kachery import AnalysisNwbfileKachery + if nwbfile is not None: + return nwbfile - # the download functions assume just the filename, so we need to - # get that from the path - if not AnalysisNwbfileKachery.download_file( - os.path.basename(nwb_file_path), permit_fail=True - ): - logger.info( - "NWB file not found in kachery; checking Dandi for " - + f"{nwb_file_path}" - ) - # Dandi fallback SB 2024-04-03 - from ..common.common_dandi import DandiPath - - dandi_key = {"filename": os.path.basename(nwb_file_path)} - if not DandiPath & dandi_key: - # If not in Dandi, then we can't find the file - raise FileNotFoundError( - f"NWB file not found in kachery or Dandi: {os.path.basename(nwb_file_path)}." 
- ) - io, nwbfile = DandiPath().fetch_file_from_dandi( - dandi_key - ) # TODO: consider case where file in multiple dandisets - __open_nwb_files[nwb_file_path] = (io, nwbfile) - return nwbfile - - # now open the file - io = pynwb.NWBHDF5IO( - path=nwb_file_path, mode="r", load_namespaces=True - ) # keep file open - nwbfile = io.read() - __open_nwb_files[nwb_file_path] = (io, nwbfile) + if os.path.exists(nwb_file_path): + return _open_nwb_file(nwb_file_path) - return nwbfile + logger.info( + f"NWB file not found locally; checking kachery for {nwb_file_path}" + ) + + from ..sharing.sharing_kachery import AnalysisNwbfileKachery + + kachery_success = AnalysisNwbfileKachery.download_file( + os.path.basename(nwb_file_path), permit_fail=True + ) + if kachery_success: + return _open_nwb_file(nwb_file_path) + + logger.info( + "NWB file not found in kachery; checking Dandi for " + + f"{nwb_file_path}" + ) + + # Dandi fallback SB 2024-04-03 + from ..common.common_dandi import DandiPath + + if DandiPath().has_file_path(file_path=nwb_file_path): + return _open_nwb_file(nwb_file_path, source="dandi") + + # If not in Dandi, then we can't find the file + raise FileNotFoundError( + "NWB file not found in kachery or Dandi: " + + f"{os.path.basename(nwb_file_path)}." + ) def file_from_dandi(filepath): From 743502db07a27896bf017fb8c9c822f1b4da994b Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 11 Sep 2024 16:25:29 -0500 Subject: [PATCH 03/21] WIP: recompute 2 --- pyproject.toml | 4 ++-- src/spyglass/common/common_nwbfile.py | 20 +++++++++++++++++--- src/spyglass/spikesorting/v1/recording.py | 2 ++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8db231c8d..093c123fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,9 +127,9 @@ minversion = "7.0" addopts = [ # "-sv", # no capture, verbose output "--sw", # stepwise: resume with next test after failure - # "--pdb", # drop into debugger on failure + "--pdb", # drop into debugger on failure "-p no:warnings", - # "--no-teardown", # don't teardown the database after tests + "--no-teardown", # don't teardown the database after tests # "--quiet-spy", # don't show logging from spyglass # "--no-dlc", # don't run DLC tests "--show-capture=no", diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index d2fb34a6d..82c9e57b6 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -94,7 +94,9 @@ def get_file_key(cls, nwb_file_name: str) -> dict: return {"nwb_file_name": cls._get_file_name(nwb_file_name)} @classmethod - def get_abs_path(cls, nwb_file_name: str, new_file: bool = False) -> str: + def get_abs_path( + cls, nwb_file_name: str, new_file: bool = False, **kwargs + ) -> str: """Return absolute path for a stored raw NWB file given file name. The SPYGLASS_BASE_DIR must be set, either as an environment or part of @@ -216,7 +218,8 @@ def create( ) # write the new file - logger.info(f"Writing new NWB file {analysis_file_name}") + if not recompute_file_name: + logger.info(f"Writing new NWB file {analysis_file_name}") analysis_file_abs_path = AnalysisNwbfile.get_abs_path( analysis_file_name ) @@ -330,7 +333,9 @@ def add(self, nwb_file_name: str, analysis_file_name: str) -> None: self.insert1(key) @classmethod - def get_abs_path(cls, analysis_nwb_file_name: str) -> str: + def get_abs_path( + cls, analysis_nwb_file_name: str, from_schema: bool = False + ) -> str: """Return the absolute path for an analysis NWB file given the name. 
The spyglass config from settings.py must be set. @@ -339,12 +344,21 @@ def get_abs_path(cls, analysis_nwb_file_name: str) -> str: ---------- analysis_nwb_file_name : str The name of the NWB file in AnalysisNwbfile. + from schema : bool, optional + If true, get the file path from the schema externals table, skipping + checksum and file existence checks. Defaults to False. Returns ------- analysis_nwb_file_abspath : str The absolute path for the given file name. """ + if from_schema: + substring = analysis_nwb_file_name.split("_")[0] + return f"{analysis_dir}/" + ( + schema.external["analysis"] & f'filepath LIKE "%{substring}%"' + ).fetch1("filepath") + # If an entry exists in the database get the stored datajoint filepath file_key = {"analysis_file_name": analysis_nwb_file_name} if cls & file_key: diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 5c0a46102..13d532cc9 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -223,6 +223,8 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): raise ValueError( "Either key or recompute_file_name must be specified." ) + if recompute_file_name and not key: + logger.info(f"Recomputing {recompute_file_name}.") key = key or (cls & {"analysis_file_name": recompute_file_name}).fetch1( "KEY" From 9d23949d9f8b4cd213a6ae8f34e58c4e64a22b58 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 12 Sep 2024 11:10:36 -0500 Subject: [PATCH 04/21] WIP: recompute 3 --- src/spyglass/common/common_nwbfile.py | 4 ++-- src/spyglass/spikesorting/v1/recording.py | 29 ++++++++++++++++------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 82c9e57b6..b48142e65 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -354,9 +354,9 @@ def get_abs_path( The absolute path for the given file name. 
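A minimal usage sketch (the file name below is hypothetical): with `from_schema=True`, the path is resolved from the `analysis` external store, which is what recompute relies on when the file is registered in the database but absent from disk.

```python
from spyglass.common.common_nwbfile import AnalysisNwbfile

# Resolve the expected on-disk path, skipping checksum/existence checks.
path = AnalysisNwbfile.get_abs_path(
    "example_ABC123.nwb",  # hypothetical analysis file name
    from_schema=True,
)
```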
""" if from_schema: - substring = analysis_nwb_file_name.split("_")[0] return f"{analysis_dir}/" + ( - schema.external["analysis"] & f'filepath LIKE "%{substring}%"' + schema.external["analysis"] + & f'filepath LIKE "%{analysis_nwb_file_name}"' ).fetch1("filepath") # If an entry exists in the database get the stored datajoint filepath diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 13d532cc9..6c9faa696 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -225,16 +225,17 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): ) if recompute_file_name and not key: logger.info(f"Recomputing {recompute_file_name}.") - - key = key or (cls & {"analysis_file_name": recompute_file_name}).fetch1( - "KEY" - ) + query = cls & {"analysis_file_name": recompute_file_name} + key, recompute_object_id = query.fetch1("KEY", "object_id") + else: + recompute_object_id = None parent = SpikeSortingRecordingSelection & key recording_nwb_file_name, recording_object_id = _write_recording_to_nwb( **cls()._get_preprocessed_recording(key), nwb_file_name=parent.fetch1("nwb_file_name"), recompute_file_name=recompute_file_name, + recompute_object_id=recompute_object_id, ) return dict( @@ -548,6 +549,7 @@ def _write_recording_to_nwb( timestamps: Iterable, nwb_file_name: str, recompute_file_name: Optional[str] = None, + recompute_object_id: Optional[str] = None, ): """Write a recording in NWB format @@ -567,7 +569,9 @@ def _write_recording_to_nwb( analysis_nwb_file = AnalysisNwbfile().create( nwb_file_name=nwb_file_name, recompute_file_name=recompute_file_name ) - analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_nwb_file, from_schema=bool(recompute_file_name) + ) with pynwb.NWBHDF5IO( path=analysis_nwb_file_abs_path, @@ -595,10 +599,19 @@ def _write_recording_to_nwb( conversion=np.unique(recording.get_channel_gains())[0] * 1e-6, ) nwbfile.add_acquisition(processed_electrical_series) - recording_object_id = nwbfile.acquisition[ - "ProcessedElectricalSeries" - ].object_id + + if recompute_object_id: + nwbfile.acquisition["ProcessedElectricalSeries"].object_id = ( + recompute_object_id # AttributeError: can't set attribute + ) + recompute_object_id = recompute_object_id + else: + recording_object_id = nwbfile.acquisition[ + "ProcessedElectricalSeries" + ].object_id + io.write(nwbfile) + return analysis_nwb_file, recording_object_id From 39f07bf329cddbe3b59d73b6e352af626a679d4b Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 12 Sep 2024 16:46:24 -0500 Subject: [PATCH 05/21] WIP: recompute 4 --- src/spyglass/spikesorting/v1/recording.py | 28 ++++++++++++++--------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 6c9faa696..a8a8b58e8 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -198,7 +198,8 @@ def make(self, key): "interval_list_name": key["recording_id"], "valid_times": self._get_sort_interval_valid_times(key), "pipeline": "spikesorting_recording_v1", - } + }, + skip_duplicates=True, # for recompute ) AnalysisNwbfile().add(nwb_file_name, key["analysis_file_name"]) self.insert1(key) @@ -595,23 +596,28 @@ def _write_recording_to_nwb( electrodes=table_region, timestamps=timestamps_iterator, filtering="Bandpass filtered for spike band", - 
description=f"Referenced and filtered recording from {nwb_file_name} for spike sorting", + description="Referenced and filtered recording from " + + f"{nwb_file_name} for spike sorting", conversion=np.unique(recording.get_channel_gains())[0] * 1e-6, ) + nwbfile.add_acquisition(processed_electrical_series) - if recompute_object_id: - nwbfile.acquisition["ProcessedElectricalSeries"].object_id = ( - recompute_object_id # AttributeError: can't set attribute - ) - recompute_object_id = recompute_object_id - else: - recording_object_id = nwbfile.acquisition[ - "ProcessedElectricalSeries" - ].object_id + recording_object_id = nwbfile.acquisition[ + "ProcessedElectricalSeries" + ].object_id io.write(nwbfile) + if recompute_object_id: + import h5py + + with h5py.File(analysis_nwb_file_abs_path, "a") as f: + f["acquisition/ProcessedElectricalSeries"].attrs[ + "object_id" + ] = recompute_object_id + recording_object_id = recompute_object_id + return analysis_nwb_file, recording_object_id From 1b38818f9d7a39c0c0b72653550e4aefffeed2b9 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 18 Sep 2024 14:14:15 -0500 Subject: [PATCH 06/21] WIP: recompute 5, electrodes object --- CHANGELOG.md | 3 +++ src/spyglass/spikesorting/v1/recording.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2cdf9be4d..1ab9a0fea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,11 @@ ```python import datajoint as dj +from spyglass.spikesorting.v1.recording import SpikeSortingRecording dj.FreeTable(dj.conn(), "common_nwbfile.analysis_nwbfile_log").drop() +SpikeSortingRecording().alter() +SpikeSortingRecording().update_ids() ``` ### Infrastructure diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index a8a8b58e8..15150c52c 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -169,8 +169,11 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): --- -> AnalysisNwbfile object_id: varchar(40) # Object ID for the processed recording in NWB file + electrodes_id='': varchar(40) # Object ID for the processed electrodes """ + # Note: electrodes id is used for recomputing + def make(self, key): """Populate SpikeSortingRecording. 
@@ -490,6 +493,20 @@ def _get_preprocessed_recording(self, key: dict): return dict(recording=recording, timestamps=np.asarray(timestamps)) + def update_ids(self): + """Update object_id and electrodes_id in SpikeSortingRecording table.""" + for key in self.fetch(as_dict=True): + self.update_id(key) + raise NotImplementedError + + # import h5py + # + # with h5py.File(analysis_nwb_file_abs_path, "a") as f: + # f["acquisition/ProcessedElectricalSeries/electrodes"].attrs[ + # "object_id" + # ] = recompute_object_id + # recording_object_id = recompute_object_id + def _consolidate_intervals(intervals, timestamps): """Convert a list of intervals (start_time, stop_time) From 282d553d7bf8ff7cd80e9268bd68db1663396337 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Thu, 19 Sep 2024 17:39:14 -0500 Subject: [PATCH 07/21] WIP: recompute 6, add file hash --- CHANGELOG.md | 4 +- src/spyglass/common/common_nwbfile.py | 18 ++++ src/spyglass/spikesorting/v1/recording.py | 120 ++++++++++++++++------ src/spyglass/utils/nwb_hash.py | 78 ++++++++++++++ 4 files changed, 188 insertions(+), 32 deletions(-) create mode 100644 src/spyglass/utils/nwb_hash.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ab9a0fea..57fdcd943 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,9 +6,11 @@ + + ```python import datajoint as dj -from spyglass.spikesorting.v1.recording import SpikeSortingRecording +from spyglass.spikesorting.v1.recording import * # noqa dj.FreeTable(dj.conn(), "common_nwbfile.analysis_nwbfile_log").drop() SpikeSortingRecording().alter() diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index b48142e65..04ca01e9f 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -426,6 +426,24 @@ def add_nwb_object( io.write(nwbf) return nwb_object.object_id + def _update_external(self, analysis_file_name: str): + """Update the external contents checksum for an analysis file. + + USE WITH CAUTION. This should only be run after the file has been + verified to be correct by another method such as hashing. + + Parameters + ---------- + analysis_file_name : str + The name of the analysis NWB file. 
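A hedged sketch of the intended call, assuming the recomputed file has already passed the NWB-level hash check; the regenerated file need not be byte-identical to the original, so the external store's `contents_hash` is refreshed here (the file name is hypothetical).

```python
from spyglass.common.common_nwbfile import AnalysisNwbfile

# Only run after NwbfileHasher confirms the recomputed file matches.
AnalysisNwbfile()._update_external("example_ABC123.nwb")  # hypothetical name
```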
+ """ + external_tbl = schema.external["analysis"] + file_path = analysis_dir + "/" + analysis_file_name + key = (external_tbl & f"filepath = '{file_path}'").fetch1(as_dict=True) + key["contents_hash"] = dj.hash.uuid_from_file(file_path) + + self.update1(key) + def add_units( self, analysis_file_name: str, diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 15150c52c..972595dbe 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -8,6 +8,7 @@ import pynwb import spikeinterface as si import spikeinterface.extractors as se +from h5py import File as H5File from hdmf.data_utils import GenericDataChunkIterator from spyglass.common import Session # noqa: F401 @@ -19,12 +20,14 @@ ) from spyglass.common.common_lab import LabTeam from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile +from spyglass.common.common_nwbfile import schema as nwb_schema from spyglass.settings import test_mode from spyglass.spikesorting.utils import ( _get_recording_timestamps, get_group_by_shank, ) from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.nwb_hash import NwbfileHasher schema = dj.schema("spikesorting_v1_recording") @@ -170,9 +173,14 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): -> AnalysisNwbfile object_id: varchar(40) # Object ID for the processed recording in NWB file electrodes_id='': varchar(40) # Object ID for the processed electrodes + file_hash='': varchar(32) # Hash of the NWB file """ - # Note: electrodes id is used for recomputing + # QUESTION: Should this file_hash be a (hidden) attr of AnalysisNwbfile? + # Hidden would require the pre-release datajoint version. + # Adding it there would centralize recompute abilities, but maybe + # that's excessive if we only plan recompute for a handful of + # tables. def make(self, key): """Populate SpikeSortingRecording. @@ -223,6 +231,7 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): If specified, recompute this file. Use as resulting file name. If none, generate a new file name. """ + file_hash = None if not key and not recompute_file_name: raise ValueError( "Either key or recompute_file_name must be specified." 
@@ -230,21 +239,40 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): if recompute_file_name and not key: logger.info(f"Recomputing {recompute_file_name}.") query = cls & {"analysis_file_name": recompute_file_name} - key, recompute_object_id = query.fetch1("KEY", "object_id") + key, recompute_object_id, recompute_electrodes_id, file_hash = ( + query.fetch1("KEY", "object_id", "electrodes_id") + ) else: - recompute_object_id = None + recompute_object_id, recompute_electrodes_id = None, None parent = SpikeSortingRecordingSelection & key - recording_nwb_file_name, recording_object_id = _write_recording_to_nwb( - **cls()._get_preprocessed_recording(key), - nwb_file_name=parent.fetch1("nwb_file_name"), - recompute_file_name=recompute_file_name, - recompute_object_id=recompute_object_id, + recording_nwb_file_name, recording_object_id, electrodes_id = ( + _write_recording_to_nwb( + **cls()._get_preprocessed_recording(key), + nwb_file_name=parent.fetch1("nwb_file_name"), + recompute_file_name=recompute_file_name, + recompute_object_id=recompute_object_id, + recompute_electrodes_id=recompute_electrodes_id, + ) ) + # check hash + if file_hash is not None: + file_path = AnalysisNwbfile.get_abs_path(recompute_file_name) + new_hash = NwbfileHasher(file_path).hash + if not file_hash == new_hash: + Path(file_path).unlink() # remove mismatched file + ( + AnalysisNwbfile + & {"analysis_file_name": recompute_file_name} + ).super_delete(safemode=False) + raise ValueError(f"Failed to recompute {recompute_file_name}.") + AnalysisNwbfile._update_external(recompute_file_name) + return dict( analysis_file_name=recording_nwb_file_name, object_id=recording_object_id, + electrodes_id=electrodes_id, ) @classmethod @@ -494,18 +522,22 @@ def _get_preprocessed_recording(self, key: dict): return dict(recording=recording, timestamps=np.asarray(timestamps)) def update_ids(self): - """Update object_id and electrodes_id in SpikeSortingRecording table.""" - for key in self.fetch(as_dict=True): - self.update_id(key) - raise NotImplementedError + """Update electrodes_id, and file_hash in SpikeSortingRecording table. + + Only used for transitioning to recompute NWB files, see #1093. + """ + elect_attr = "acquisition/ProcessedElectricalSeries/electrodes" + for key in (self & "electrodes_id=''").fetch(as_dict=True): + analysis_file_path = AnalysisNwbfile.get_abs_path( + key["analysis_file_name"] + ) + with H5File(analysis_file_path, "r") as f: + elect_id = f[elect_attr].attrs["object_id"] + key["electrodes_id"] = elect_id - # import h5py - # - # with h5py.File(analysis_nwb_file_abs_path, "a") as f: - # f["acquisition/ProcessedElectricalSeries/electrodes"].attrs[ - # "object_id" - # ] = recompute_object_id - # recording_object_id = recompute_object_id + key["file_hash"] = NwbfileHasher(analysis_file_path).hash + + self.update1(key) def _consolidate_intervals(intervals, timestamps): @@ -568,6 +600,7 @@ def _write_recording_to_nwb( nwb_file_name: str, recompute_file_name: Optional[str] = None, recompute_object_id: Optional[str] = None, + recompute_electrodes_id: Optional[str] = None, ): """Write a recording in NWB format @@ -577,18 +610,43 @@ def _write_recording_to_nwb( timestamps : iterable nwb_file_name : str name of NWB file the recording originates + recompute_file_name : str, optional + name of the NWB file to recompute + recompute_object_id : str, optional + object ID for recomputed processed electrical series object, + acquisition/ProcessedElectricalSeries. 
+ recompute_electrodes_id : str, optional + object ID for recomputed electrodes sub-object, + acquisition/ProcessedElectricalSeries/electrodes. Returns ------- analysis_nwb_file : str name of analysis NWB file containing the preprocessed recording """ + recompute_args = ( + recompute_file_name, + recompute_object_id, + recompute_electrodes_id, + ) + recompute = any(recompute_args) + if recompute and not all(recompute_args): + raise ValueError( + "If recomputing, must specify all of recompute_file_name, " + "recompute_object_id, and recompute_electrodes_id." + ) + + series_name = "ProcessedElectricalSeries" + series_attr = "acquisition/" + series_name + elect_attr = series_attr + "/electrodes" analysis_nwb_file = AnalysisNwbfile().create( - nwb_file_name=nwb_file_name, recompute_file_name=recompute_file_name + nwb_file_name=nwb_file_name, + recompute_file_name=recompute_file_name, ) analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( - analysis_nwb_file, from_schema=bool(recompute_file_name) + analysis_nwb_file, + from_schema=recompute, ) with pynwb.NWBHDF5IO( @@ -620,22 +678,22 @@ def _write_recording_to_nwb( nwbfile.add_acquisition(processed_electrical_series) - recording_object_id = nwbfile.acquisition[ - "ProcessedElectricalSeries" - ].object_id + recording_object_id = nwbfile.acquisition[series_name].object_id + electrodes_id = nwbfile.acquisition[series_name].electrodes.object_id + # how get elect id? + __import__("pdb").set_trace() + # how get elect id? io.write(nwbfile) if recompute_object_id: - import h5py - - with h5py.File(analysis_nwb_file_abs_path, "a") as f: - f["acquisition/ProcessedElectricalSeries"].attrs[ - "object_id" - ] = recompute_object_id + with H5File(analysis_nwb_file_abs_path, "a") as f: + f[series_attr].attrs["object_id"] = recompute_object_id + f[elect_attr].attrs["object_id"] = recompute_electrodes_id recording_object_id = recompute_object_id + electrodes_id = recompute_electrodes_id - return analysis_nwb_file, recording_object_id + return analysis_nwb_file, recording_object_id, electrodes_id # For writing recording to NWB file diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py new file mode 100644 index 000000000..d45d290e4 --- /dev/null +++ b/src/spyglass/utils/nwb_hash.py @@ -0,0 +1,78 @@ +from functools import cached_property +from hashlib import md5 +from pathlib import Path +from typing import Union + +import h5py +import numpy as np + + +class NwbfileHasher: + def __init__(self, path: Union[str, Path], data_limit: int = 4095): + """Hashes the contents of an NWB file, limiting to partial data. + + In testing, chunking the data for large datasets caused false positives + in the hash comparison, and some datasets may be too large to store in + memory. This method limits the data to the first N elements to avoid + this issue, and may not be suitable for all datasets. + + Parameters + ---------- + path : Union[str, Path] + Path to the NWB file. + data_limit : int, optional + Limit of data to hash for large datasets, by default 4095. + """ + self.file = h5py.File(path, "r") + self.data_limit = data_limit + + def collect_names(self, file): + """Collects all object names in the file.""" + + def collect_items(name, obj): + items_to_process.append((name, obj)) + + items_to_process = [] + file.visititems(collect_items) + items_to_process.sort(key=lambda x: x[0]) + return items_to_process + + def serialize_attr_value(self, value): + """Serializes an attribute value into bytes for hashing. 
+ + Setting all numpy array types to string avoids false positives. + + Parameters + ---------- + value : Any + Attribute value. + + Returns + ------- + bytes + Serialized bytes of the attribute value. + """ + if isinstance(value, np.ndarray): + return value.astype(str).tobytes() + elif isinstance(value, (str, int, float)): + return str(value).encode() + return repr(value).encode() # For other data types, use repr + + @cached_property + def hash(self) -> str: + """Hashes the NWB file contents, limiting to partal data where large.""" + hashed = md5("".encode()) + for name, obj in self.collect_names(self.file): + if isinstance(obj, h5py.Dataset): # hash the dataset name and shape + hashed.update(str(obj.shape).encode()) + hashed.update(str(obj.dtype).encode()) + partial_data = ( # full if scalar dataset, else use data_limit + obj[()] if obj.shape == () else obj[: self.data_limit] + ) + hashed.update(self.serialize_attr_value(partial_data)) + for attr_key in sorted(obj.attrs): + attr_value = obj.attrs[attr_key] + hashed.update(attr_key.encode()) + hashed.update(self.serialize_attr_value(attr_value)) + self.file.close() + return hashed.hexdigest() From 94168de1ffb00f52619d558efb25e4ec7983f007 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Sep 2024 12:04:02 -0500 Subject: [PATCH 08/21] WIP: recompute 7 --- src/spyglass/common/common_nwbfile.py | 2 +- src/spyglass/spikesorting/v1/recording.py | 13 +++++++------ tests/spikesorting/test_recording.py | 23 +++++++++++++++++++++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 04ca01e9f..554f80008 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -353,7 +353,7 @@ def get_abs_path( analysis_nwb_file_abspath : str The absolute path for the given file name. """ - if from_schema: + if from_schema: # Skips checksum and file existence checks return f"{analysis_dir}/" + ( schema.external["analysis"] & f'filepath LIKE "%{analysis_nwb_file_name}"' diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 972595dbe..001577ba2 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -240,7 +240,7 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): logger.info(f"Recomputing {recompute_file_name}.") query = cls & {"analysis_file_name": recompute_file_name} key, recompute_object_id, recompute_electrodes_id, file_hash = ( - query.fetch1("KEY", "object_id", "electrodes_id") + query.fetch1("KEY", "object_id", "electrodes_id", "file_hash") ) else: recompute_object_id, recompute_electrodes_id = None, None @@ -258,10 +258,13 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): # check hash if file_hash is not None: - file_path = AnalysisNwbfile.get_abs_path(recompute_file_name) + file_path = AnalysisNwbfile.get_abs_path( + recompute_file_name, from_schema=True + ) new_hash = NwbfileHasher(file_path).hash if not file_hash == new_hash: Path(file_path).unlink() # remove mismatched file + # force delete, including all downstream ( AnalysisNwbfile & {"analysis_file_name": recompute_file_name} @@ -527,7 +530,8 @@ def update_ids(self): Only used for transitioning to recompute NWB files, see #1093. 
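For labs with pre-existing entries, this transition step is intended as a one-time call after upgrading. A minimal sketch of the invocation (assuming a configured Spyglass database and that the listed analysis files are still on disk):

```python
# Hypothetical one-time backfill after upgrading, per the update_ids change above.
from spyglass.spikesorting.v1.recording import SpikeSortingRecording

# Reads each analysis NWB file with an empty electrodes_id, pulls the object_id
# of acquisition/ProcessedElectricalSeries/electrodes, hashes the file, and
# writes both values back to the row.
SpikeSortingRecording().update_ids()
```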
""" elect_attr = "acquisition/ProcessedElectricalSeries/electrodes" - for key in (self & "electrodes_id=''").fetch(as_dict=True): + needs_update = self & ["electrodes_id=''", "file_hash=''"] + for key in needs_update.fetch(as_dict=True): analysis_file_path = AnalysisNwbfile.get_abs_path( key["analysis_file_name"] ) @@ -680,9 +684,6 @@ def _write_recording_to_nwb( recording_object_id = nwbfile.acquisition[series_name].object_id electrodes_id = nwbfile.acquisition[series_name].electrodes.object_id - # how get elect id? - __import__("pdb").set_trace() - # how get elect id? io.write(nwbfile) diff --git a/tests/spikesorting/test_recording.py b/tests/spikesorting/test_recording.py index 780cbc46c..dc4f98c64 100644 --- a/tests/spikesorting/test_recording.py +++ b/tests/spikesorting/test_recording.py @@ -1,3 +1,6 @@ +from pathlib import Path + + def test_sort_group(spike_v1, pop_rec): max_id = max(spike_v1.SortGroup.fetch("sort_group_id")) assert ( @@ -8,3 +11,23 @@ def test_sort_group(spike_v1, pop_rec): def test_spike_sorting(spike_v1, pop_rec): n_records = len(spike_v1.SpikeSortingRecording()) assert n_records == 1, "SpikeSortingRecording failed to insert a record" + + +def test_recompute(spike_v1, pop_rec, common): + key = spike_v1.SpikeSortingRecording().fetch( + "analysis_file_name", as_dict=True + )[0] + restr_tbl = spike_v1.SpikeSortingRecording() & key + pre = restr_tbl.fetch_nwb()[0]["object_id"] + + file_path = common.AnalysisNwbfile.get_abs_path(key["analysis_file_name"]) + Path(file_path).unlink() # delete the file to force recompute + + post = restr_tbl.fetch_nwb()[0]["object_id"] # trigger recompute + + assert ( + pre.object_id == post.object_id + and pre.electrodes.object_id == post.electrodes.object_id + ), "Recompute failed to preserve object_ids" + + __import__("pdb").set_trace() From a5947865ff272a93e1996b1f0b0cdcc85f8d13e4 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Fri, 20 Sep 2024 16:30:17 -0500 Subject: [PATCH 09/21] =?UTF-8?q?=20=E2=9C=85=20:=20recompute?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/spyglass/common/common_nwbfile.py | 64 ++++++++++++++++++++--- src/spyglass/spikesorting/v1/recording.py | 33 ++++++------ src/spyglass/utils/database_settings.py | 2 +- src/spyglass/utils/dj_helper_fn.py | 2 +- tests/spikesorting/test_recording.py | 2 - 5 files changed, 76 insertions(+), 27 deletions(-) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 554f80008..10e2ac9e6 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -17,6 +17,7 @@ from spyglass.settings import analysis_dir, raw_dir from spyglass.utils import SpyglassMixin, logger from spyglass.utils.dj_helper_fn import get_child_tables +from spyglass.utils.nwb_hash import NwbfileHasher from spyglass.utils.nwb_helper_fn import get_electrode_indices, get_nwb_file schema = dj.schema("common_nwbfile") @@ -426,23 +427,72 @@ def add_nwb_object( io.write(nwbf) return nwb_object.object_id - def _update_external(self, analysis_file_name: str): + def get_hash( + self, analysis_file_name: str, from_schema: bool = False + ) -> str: + """Return the hash of the file contents. + + Parameters + ---------- + analysis_file_name : str + The name of the analysis NWB file. + from_schema : bool, Optional + If true, get the file path from the schema externals table, skipping + checksum and file existence checks. Defaults to False. 
+ + + Returns + ------- + file_hash : str + The hash of the file contents. + """ + return NwbfileHasher( + self.get_abs_path(analysis_file_name, from_schema=from_schema) + ).hash + + def _update_external(self, analysis_file_name: str, file_hash: str): """Update the external contents checksum for an analysis file. - USE WITH CAUTION. This should only be run after the file has been - verified to be correct by another method such as hashing. + USE WITH CAUTION. If the hash does not match the file contents, the file + and downstream entries are deleted. Parameters ---------- analysis_file_name : str The name of the analysis NWB file. + file_hash : str + The hash of the file contents as calculated by NwbfileHasher. + If the hash does not match the file contents, the file and + downstream entries are deleted. """ + file_path = self.get_abs_path(analysis_file_name, from_schema=True) + new_hash = self.get_hash(analysis_file_name, from_schema=True) + + if file_hash != new_hash: + Path(file_path).unlink() # remove mismatched file + # force delete, including all downstream, forcing permissions + del_kwargs = dict(force_permission=True, safemode=False) + if self._has_updated_dj_version: + del_kwargs["force_masters"] = True + query = self & {"analysis_file_name": analysis_file_name} + query.delete(**del_kwargs) + raise ValueError( + f"Failed to recompute {analysis_file_name}.", + "Could not exactly replicate file content.", + "Please re-populate from parent table.", + ) + external_tbl = schema.external["analysis"] - file_path = analysis_dir + "/" + analysis_file_name - key = (external_tbl & f"filepath = '{file_path}'").fetch1(as_dict=True) - key["contents_hash"] = dj.hash.uuid_from_file(file_path) + file_path = ( + self.__get_analysis_file_dir(analysis_file_name) + + f"/{analysis_file_name}" + ) + key = (external_tbl & f"filepath = '{file_path}'").fetch1() + abs_path = Path(analysis_dir) / file_path + key["contents_hash"] = dj.hash.uuid_from_file(abs_path) + key["size"] = abs_path.stat().st_size - self.update1(key) + external_tbl.update1(key) def add_units( self, diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 001577ba2..94bf9cf1a 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -219,6 +219,8 @@ def make(self, key): def _make_file(cls, key: dict = None, recompute_file_name: str = None): """Preprocess recording and write to NWB file. + All `_make_file` methods should exit early if the file already exists. + - Get valid times for sort interval from IntervalList - Preprocess recording - Write processed recording to NWB file @@ -236,7 +238,12 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): raise ValueError( "Either key or recompute_file_name must be specified." 
) - if recompute_file_name and not key: + elif recompute := bool(recompute_file_name and not key): + file_path = AnalysisNwbfile.get_abs_path( + recompute_file_name, from_schema=True + ) + if Path(file_path).exists(): + return logger.info(f"Recomputing {recompute_file_name}.") query = cls & {"analysis_file_name": recompute_file_name} key, recompute_object_id, recompute_electrodes_id, file_hash = ( @@ -246,7 +253,7 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): recompute_object_id, recompute_electrodes_id = None, None parent = SpikeSortingRecordingSelection & key - recording_nwb_file_name, recording_object_id, electrodes_id = ( + (recording_nwb_file_name, recording_object_id, electrodes_id) = ( _write_recording_to_nwb( **cls()._get_preprocessed_recording(key), nwb_file_name=parent.fetch1("nwb_file_name"), @@ -256,26 +263,18 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): ) ) - # check hash - if file_hash is not None: - file_path = AnalysisNwbfile.get_abs_path( - recompute_file_name, from_schema=True + if recompute: + AnalysisNwbfile()._update_external(recompute_file_name, file_hash) + else: + file_hash = AnalysisNwbfile().get_hash( + recording_nwb_file_name, from_schema=False ) - new_hash = NwbfileHasher(file_path).hash - if not file_hash == new_hash: - Path(file_path).unlink() # remove mismatched file - # force delete, including all downstream - ( - AnalysisNwbfile - & {"analysis_file_name": recompute_file_name} - ).super_delete(safemode=False) - raise ValueError(f"Failed to recompute {recompute_file_name}.") - AnalysisNwbfile._update_external(recompute_file_name) return dict( analysis_file_name=recording_nwb_file_name, object_id=recording_object_id, electrodes_id=electrodes_id, + file_hash=file_hash, ) @classmethod @@ -688,11 +687,13 @@ def _write_recording_to_nwb( io.write(nwbfile) if recompute_object_id: + logger.info(f"Recomputed {recompute_file_name}, fixing object IDs.") with H5File(analysis_nwb_file_abs_path, "a") as f: f[series_attr].attrs["object_id"] = recompute_object_id f[elect_attr].attrs["object_id"] = recompute_electrodes_id recording_object_id = recompute_object_id electrodes_id = recompute_electrodes_id + analysis_nwb_file = recompute_file_name return analysis_nwb_file, recording_object_id, electrodes_id diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py index e7f36479e..56e2aa28f 100755 --- a/src/spyglass/utils/database_settings.py +++ b/src/spyglass/utils/database_settings.py @@ -15,7 +15,7 @@ "spikesorting", "decoding", "position", - "position_linearization", + "linearization", "ripple", "lfp", "waveform", diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index c24d9301f..835e27edd 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -286,7 +286,7 @@ def fetch_nwb(query_expression, nwb_master, *attrs, **kwargs): ) for file_name in nwb_files: - file_path = file_path_fn(file_name) + file_path = file_path_fn(file_name, from_schema=True) if not os.path.exists(file_path): if hasattr(query_expression, "_make_file"): # Attempt to recompute the file diff --git a/tests/spikesorting/test_recording.py b/tests/spikesorting/test_recording.py index dc4f98c64..89e905ee1 100644 --- a/tests/spikesorting/test_recording.py +++ b/tests/spikesorting/test_recording.py @@ -29,5 +29,3 @@ def test_recompute(spike_v1, pop_rec, common): pre.object_id == post.object_id and pre.electrodes.object_id == post.electrodes.object_id 
), "Recompute failed to preserve object_ids" - - __import__("pdb").set_trace() From 6d0df075d23bec5726fd755db000683d1fbb46f6 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Mon, 21 Oct 2024 15:55:47 -0500 Subject: [PATCH 10/21] Handle groups and links --- src/spyglass/utils/nwb_hash.py | 107 ++++++++++++++++++++++++++------- 1 file changed, 84 insertions(+), 23 deletions(-) diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index d45d290e4..ce3ffb619 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -1,14 +1,20 @@ -from functools import cached_property +import atexit from hashlib import md5 from pathlib import Path -from typing import Union +from typing import Any, Union import h5py import numpy as np +from tqdm import tqdm class NwbfileHasher: - def __init__(self, path: Union[str, Path], data_limit: int = 4095): + def __init__( + self, + path: Union[str, Path], + batch_size: int = 4095, + verbose: bool = True, + ): """Hashes the contents of an NWB file, limiting to partial data. In testing, chunking the data for large datasets caused false positives @@ -20,24 +26,42 @@ def __init__(self, path: Union[str, Path], data_limit: int = 4095): ---------- path : Union[str, Path] Path to the NWB file. - data_limit : int, optional + batch_size : int, optional Limit of data to hash for large datasets, by default 4095. + verbose : bool, optional + Display progress bar, by default True. """ self.file = h5py.File(path, "r") - self.data_limit = data_limit + atexit.register(self.cleanup) + + self.batch_size = batch_size + self.verbose = verbose + self.hashed = md5("".encode()) + self.hash = self.compute_hash() + + self.cleanup() + atexit.unregister(self.cleanup) + + def cleanup(self): + self.file.close() def collect_names(self, file): """Collects all object names in the file.""" def collect_items(name, obj): - items_to_process.append((name, obj)) + if isinstance(file.get(name, getclass=True), h5py.SoftLink): + print("SoftLink:", name) + items_to_process.append((name, file.get(name, getclass=True))) + __import__("pdb").set_trace() + else: + items_to_process.append((name, obj)) items_to_process = [] file.visititems(collect_items) items_to_process.sort(key=lambda x: x[0]) return items_to_process - def serialize_attr_value(self, value): + def serialize_attr_value(self, value: Any): """Serializes an attribute value into bytes for hashing. Setting all numpy array types to string avoids false positives. @@ -53,26 +77,63 @@ def serialize_attr_value(self, value): Serialized bytes of the attribute value. 
""" if isinstance(value, np.ndarray): - return value.astype(str).tobytes() + return value.astype(str).tobytes() # Try with and without `str` elif isinstance(value, (str, int, float)): return str(value).encode() return repr(value).encode() # For other data types, use repr - @cached_property - def hash(self) -> str: + def hash_dataset(self, dataset: h5py.Dataset): + _ = self.hash_shape_dtype(dataset) + + if dataset.shape == (): + self.hashed.update(self.serialize_attr_value(dataset[()])) + return + + size = dataset.shape[0] + start = 0 + + while start < size: + end = min(start + self.batch_size, size) + self.hashed.update(self.serialize_attr_value(dataset[start:end])) + start = end + + def hash_shape_dtype(self, obj: [h5py.Dataset, np.ndarray]) -> str: + if not hasattr(obj, "shape") or not hasattr(obj, "dtype"): + return + self.hashed.update(str(obj.shape).encode() + str(obj.dtype).encode()) + + def compute_hash(self) -> str: """Hashes the NWB file contents, limiting to partal data where large.""" - hashed = md5("".encode()) - for name, obj in self.collect_names(self.file): - if isinstance(obj, h5py.Dataset): # hash the dataset name and shape - hashed.update(str(obj.shape).encode()) - hashed.update(str(obj.dtype).encode()) - partial_data = ( # full if scalar dataset, else use data_limit - obj[()] if obj.shape == () else obj[: self.data_limit] - ) - hashed.update(self.serialize_attr_value(partial_data)) + # Dev note: fallbacks if slow: 1) read_direct_chunk, 2) read from offset + + for name, obj in tqdm( + self.collect_names(self.file), + desc=self.file.filename.split("/")[-1].split(".")[0], + disable=not self.verbose, + ): + if "basic" in name: + __import__("pdb").set_trace() + self.hashed.update(name.encode()) for attr_key in sorted(obj.attrs): attr_value = obj.attrs[attr_key] - hashed.update(attr_key.encode()) - hashed.update(self.serialize_attr_value(attr_value)) - self.file.close() - return hashed.hexdigest() + _ = self.hash_shape_dtype(attr_value) + self.hashed.update(attr_key.encode()) + self.hashed.update(self.serialize_attr_value(attr_value)) + + if isinstance(obj, h5py.Dataset): + _ = self.hash_dataset(obj) + elif isinstance(obj, h5py.SoftLink): + # TODO: Check that this works + self.hashed.update(obj.path.encode()) + print("SoftLink:", obj.path) + elif isinstance(obj, h5py.Group): + for k, v in obj.items(): + self.hashed.update(k.encode()) + self.hashed.update(self.serialize_attr_value(v)) + else: + raise TypeError( + f"Unknown object type: {type(obj)}\n" + + "Please report this an issue on GitHub." 
+ ) + + return self.hashed.hexdigest() From 15879977dc07c68f484602412ad0b077824a7a0a Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 22 Oct 2024 10:03:58 -0500 Subject: [PATCH 11/21] Remove debug --- src/spyglass/utils/nwb_hash.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index ce3ffb619..c5c5081d4 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -49,12 +49,7 @@ def collect_names(self, file): """Collects all object names in the file.""" def collect_items(name, obj): - if isinstance(file.get(name, getclass=True), h5py.SoftLink): - print("SoftLink:", name) - items_to_process.append((name, file.get(name, getclass=True))) - __import__("pdb").set_trace() - else: - items_to_process.append((name, obj)) + items_to_process.append((name, obj)) items_to_process = [] file.visititems(collect_items) @@ -111,8 +106,6 @@ def compute_hash(self) -> str: desc=self.file.filename.split("/")[-1].split(".")[0], disable=not self.verbose, ): - if "basic" in name: - __import__("pdb").set_trace() self.hashed.update(name.encode()) for attr_key in sorted(obj.attrs): attr_value = obj.attrs[attr_key] @@ -125,7 +118,6 @@ def compute_hash(self) -> str: elif isinstance(obj, h5py.SoftLink): # TODO: Check that this works self.hashed.update(obj.path.encode()) - print("SoftLink:", obj.path) elif isinstance(obj, h5py.Group): for k, v in obj.items(): self.hashed.update(k.encode()) From 1ed831e2f00b401f1937c9b4fbd7843a2ee12616 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Tue, 12 Nov 2024 12:46:19 -0800 Subject: [PATCH 12/21] Add directory hasher --- CHANGELOG.md | 1 + src/spyglass/utils/nwb_hash.py | 35 +++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 25ead1249..1ad170f18 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ SpikeSortingRecording().update_ids() - Update DataJoint install and password instructions #1131 - Fix dandi upload process for nwb's with video or linked objects #1095, #1151 - Minor docs fixes #1145 +- Add Nwb hashing tool #1093 - Test fixes - Remove stored hashes from pytests #1152 - Remove mambaforge from tests #1153 diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index ce3ffb619..a8eadc414 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -7,12 +7,45 @@ import numpy as np from tqdm import tqdm +DEFAULT_BATCH_SIZE = 4095 + + +def hash_directory(directory_path: str, batch_size: int = DEFAULT_BATCH_SIZE): + """Generate a hash of the contents of a directory, recursively. + + Searches though all files in the directory and subdirectories, hashing + the contents of files. nwb files are hashed with the NwbfileHasher class. + + Parameters + ---------- + directory_path : str + Path to the directory to hash. + batch_size : int, optional + Limit of data to hash for large files, by default 4095. 
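The body that follows also folds each file's path relative to the root into the digest, so two directories with identical bytes but different layouts hash differently. A usage sketch (hypothetical directory path; assumes the module location used throughout this series):

```python
# Hypothetical call to the directory hasher added in this patch.
from spyglass.utils.nwb_hash import hash_directory

# .nwb files are delegated to NwbfileHasher; everything else is read in
# batch_size-byte chunks, with each file's relative path mixed in afterward.
digest = hash_directory("/path/to/recording_output", batch_size=4095)
print(digest)  # md5 hex digest of the directory contents
```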
+ """ + hash_obj = md5() + + for file_path in sorted(Path(directory_path).rglob("*")): + if not file_path.is_file(): # Only hash files, not directories + continue + if file_path.suffix == ".nwb": + hasher = NwbfileHasher(file_path, batch_size=batch_size) + hash_obj.update(hasher.hash.encode()) + continue + with file_path.open("rb") as f: + while chunk := f.read(batch_size): + hash_obj.update(chunk) + # update with the rel path to for same file in diff dirs + hash_obj.update(str(file_path.relative_to(directory_path)).encode()) + + return hash_obj.hexdigest() # Return the hex digest of the hash + class NwbfileHasher: def __init__( self, path: Union[str, Path], - batch_size: int = 4095, + batch_size: int = DEFAULT_BATCH_SIZE, verbose: bool = True, ): """Hashes the contents of an NWB file, limiting to partial data. From d0011bf80f8f6d900ea78a720cd32175104da84d Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Wed, 13 Nov 2024 11:54:50 -0800 Subject: [PATCH 13/21] Update directory hasher --- src/spyglass/utils/nwb_hash.py | 118 +++++++++++++++++++++++---------- 1 file changed, 84 insertions(+), 34 deletions(-) diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index f23b018b0..1c9c644e3 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -1,4 +1,5 @@ import atexit +import json from hashlib import md5 from pathlib import Path from typing import Any, Union @@ -8,37 +9,91 @@ from tqdm import tqdm DEFAULT_BATCH_SIZE = 4095 +IGNORED_KEYS = ["version"] -def hash_directory(directory_path: str, batch_size: int = DEFAULT_BATCH_SIZE): - """Generate a hash of the contents of a directory, recursively. - - Searches though all files in the directory and subdirectories, hashing - the contents of files. nwb files are hashed with the NwbfileHasher class. - - Parameters - ---------- - directory_path : str - Path to the directory to hash. - batch_size : int, optional - Limit of data to hash for large files, by default 4095. - """ - hash_obj = md5() - - for file_path in sorted(Path(directory_path).rglob("*")): - if not file_path.is_file(): # Only hash files, not directories - continue - if file_path.suffix == ".nwb": - hasher = NwbfileHasher(file_path, batch_size=batch_size) - hash_obj.update(hasher.hash.encode()) - continue +class DirectoryHasher: + def __init__( + self, + directory_path: Union[str, Path], + batch_size: int = DEFAULT_BATCH_SIZE, + verbose: bool = False, + ): + """Generate a hash of the contents of a directory, recursively. + + Searches though all files in the directory and subdirectories, hashing + the contents of files. nwb files are hashed with the NwbfileHasher + class. JSON files are hashed by encoding the contents, ignoring + specific keys, like 'version'. All other files are hashed by reading + the file in chunks. + + If the contents of a json file is otherwise the same, but the 'version' + value is different, we assume that the dependency change had no effect + on the data and ignore the difference. + + Parameters + ---------- + directory_path : str + Path to the directory to hash. + batch_size : int, optional + Limit of data to hash for large files, by default 4095. 
+ """ + + self.dir_path = Path(directory_path) + self.batch_size = batch_size + self.verbose = verbose + self.hashed = md5("".encode()) + self.hash = self.compute_hash() + + def compute_hash(self) -> str: + """Hashes the contents of the directory, recursively.""" + all_files = [f for f in sorted(self.dir_path.rglob("*")) if f.is_file()] + + for file_path in tqdm(all_files, disable=not self.verbose): + if file_path.suffix == ".nwb": + hasher = NwbfileHasher(file_path, batch_size=batch_size) + self.hashed.update(hasher.hash.encode()) + elif file_path.suffix == ".json": + self.hashed.update(self.json_encode(file_path)) + else: + self.chunk_encode(file_path) + + # update with the rel path to for same file in diff dirs + rel_path = str(file_path.relative_to(self.dir_path)) + self.hashed.update(rel_path.encode()) + + if self.verbose: + print(f"{file_path.name}: {self.hased.hexdigest()}") + + return self.hashed.hexdigest() # Return the hex digest of the hash + + def chunk_encode(self, file_path: Path) -> str: + """Encode the contents of a file in chunks for hashing.""" with file_path.open("rb") as f: - while chunk := f.read(batch_size): - hash_obj.update(chunk) - # update with the rel path to for same file in diff dirs - hash_obj.update(str(file_path.relative_to(directory_path)).encode()) + while chunk := f.read(self.batch_size): + self.hashed.update(chunk) + + def json_encode(self, file_path: Path) -> str: + """Encode the contents of a json file for hashing. - return hash_obj.hexdigest() # Return the hex digest of the hash + Ignores the 'version' key(s) in the json file. + """ + with file_path.open("r") as f: + file_data = json.load(f, object_hook=self.pop_version) + return json.dumps(file_data, sort_keys=True).encode() + + def pop_version(self, data: Union[dict, list]) -> Union[dict, list]: + """Recursively remove banned keys from any nested dicts/lists.""" + if isinstance(data, dict): + return { + k: self.pop_version(v) + for k, v in data.items() + if k not in IGNORED_KEYS + } + elif isinstance(data, list): + return [self.pop_version(item) for item in data] + else: + return data class NwbfileHasher: @@ -50,11 +105,6 @@ def __init__( ): """Hashes the contents of an NWB file, limiting to partial data. - In testing, chunking the data for large datasets caused false positives - in the hash comparison, and some datasets may be too large to store in - memory. This method limits the data to the first N elements to avoid - this issue, and may not be suitable for all datasets. - Parameters ---------- path : Union[str, Path] @@ -105,7 +155,7 @@ def serialize_attr_value(self, value: Any): Serialized bytes of the attribute value. 
""" if isinstance(value, np.ndarray): - return value.astype(str).tobytes() # Try with and without `str` + return value.astype(str).tobytes() # must be 'astype(str)' elif isinstance(value, (str, int, float)): return str(value).encode() return repr(value).encode() # For other data types, use repr @@ -140,6 +190,7 @@ def compute_hash(self) -> str: disable=not self.verbose, ): self.hashed.update(name.encode()) + for attr_key in sorted(obj.attrs): attr_value = obj.attrs[attr_key] _ = self.hash_shape_dtype(attr_value) @@ -149,7 +200,6 @@ def compute_hash(self) -> str: if isinstance(obj, h5py.Dataset): _ = self.hash_dataset(obj) elif isinstance(obj, h5py.SoftLink): - # TODO: Check that this works self.hashed.update(obj.path.encode()) elif isinstance(obj, h5py.Group): for k, v in obj.items(): From ad7c74a80ef8778c9b7dc7f6d31431f8c71d42ad Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Wed, 8 Jan 2025 11:39:08 -0800 Subject: [PATCH 14/21] WIP: update hasher --- src/spyglass/common/common_nwbfile.py | 6 +- src/spyglass/common/common_usage.py | 37 +++++---- .../spikesorting/v0/spikesorting_recording.py | 3 + src/spyglass/spikesorting/v1/recording.py | 44 ++++++---- src/spyglass/utils/nwb_hash.py | 83 +++++++++++++++++-- 5 files changed, 135 insertions(+), 38 deletions(-) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 39e3ac2fe..e250ecec9 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -221,11 +221,15 @@ def create( # write the new file if not recompute_file_name: logger.info(f"Writing new NWB file {analysis_file_name}") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( - analysis_file_name + analysis_file_name, from_schema=bool(recompute_file_name) ) # export the new NWB file + parent_path = Path(analysis_file_abs_path).parent + if not parent_path.exists(): + parent_path.mkdir(parents=True) with pynwb.NWBHDF5IO( path=analysis_file_abs_path, mode="w", manager=io.manager ) as export_io: diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index ccbf7c909..5d5161f28 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -210,23 +210,26 @@ def _add_externals_to_restr_graph( restr_graph : RestrGraph The updated RestrGraph """ - raw_tbl = self._externals["raw"] - raw_name = raw_tbl.full_table_name - raw_restr = ( - "filepath in ('" + "','".join(self._list_raw_files(key)) + "')" - ) - restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr) - - analysis_tbl = self._externals["analysis"] - analysis_name = analysis_tbl.full_table_name - analysis_restr = ( # filepaths have analysis subdir. 
regexp substrings - "filepath REGEXP '" + "|".join(self._list_analysis_files(key)) + "'" - ) # regexp is slow, but we're only doing this once, and future-proof - restr_graph.graph.add_node( - analysis_name, ft=analysis_tbl, restr=analysis_restr - ) - - restr_graph.visited.update({raw_name, analysis_name}) + # only add items if found respective file types + if raw_files := self._list_raw_files(key): + raw_tbl = self._externals["raw"] + raw_name = raw_tbl.full_table_name + raw_restr = "filepath in ('" + "','".join(raw_files) + "')" + restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr) + restr_graph.visited.add(raw_name) + + if analysis_files := self._list_analysis_files(key): + analysis_tbl = self._externals["analysis"] + analysis_name = analysis_tbl.full_table_name + # to avoid issues with analysis subdir, we use REGEXP + # this is slow, but we're only doing this once, and future-proof + analysis_restr = ( + "filepath REGEXP '" + "|".join(analysis_files) + "'" + ) + restr_graph.graph.add_node( + analysis_name, ft=analysis_tbl, restr=analysis_restr + ) + restr_graph.visited.add(analysis_name) return restr_graph diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index a6a356a33..f3ff15f97 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -315,6 +315,8 @@ def make(self, key): recording = self._get_filtered_recording(key) recording_name = self._get_recording_name(key) + # recording_dir = Path("/home/cbroz/wrk/temp_ssr0/") + # Path to files that will hold the recording extractors recording_path = str(recording_dir / Path(recording_name)) if os.path.exists(recording_path): @@ -351,6 +353,7 @@ def _get_recording_name(key): key["sort_interval_name"], str(key["sort_group_id"]), key["preproc_params_name"], + # key["team_name"], # TODO: add team name, reflect PK structure ] ) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 546884b96..f7582b6ce 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -176,12 +176,6 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): file_hash='': varchar(32) # Hash of the NWB file """ - # QUESTION: Should this file_hash be a (hidden) attr of AnalysisNwbfile? - # Hidden would require the pre-release datajoint version. - # Adding it there would centralize recompute abilities, but maybe - # that's excessive if we only plan recompute for a handful of - # tables. - def make(self, key): """Populate SpikeSortingRecording. @@ -216,7 +210,12 @@ def make(self, key): self.insert1(key) @classmethod - def _make_file(cls, key: dict = None, recompute_file_name: str = None): + def _make_file( + cls, + key: dict = None, + recompute_file_name: str = None, + force: bool = False, + ): """Preprocess recording and write to NWB file. All `_make_file` methods should exit early if the file already exists. 
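This early-exit contract is what lets a missing file be regenerated transparently: the `dj_helper_fn` change earlier in this series calls `_make_file` from `fetch_nwb` when a path no longer exists, and a second caller hitting the same file becomes a no-op. A sketch of the manual recompute entry point (placeholder file name; assumes a configured Spyglass database):

```python
# Hypothetical manual recompute of a deleted analysis file, mirroring the
# fallback that fetch_nwb performs automatically when a path is missing.
from spyglass.spikesorting.v1.recording import SpikeSortingRecording

SpikeSortingRecording._make_file(
    recompute_file_name="example20240101_ABCDEF_0.nwb"  # placeholder name
)
```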
@@ -249,6 +248,20 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): key, recompute_object_id, recompute_electrodes_id, file_hash = ( query.fetch1("KEY", "object_id", "electrodes_id", "file_hash") ) + elif force: + # print("force") + elect_attr = "acquisition/ProcessedElectricalSeries/electrodes" + analysis_path = ( + Path("/stelmo/nwb/analysis") + / key["analysis_file_name"].split("_")[0] + / key["analysis_file_name"] + ) + if not analysis_path.exists(): + raise FileNotFoundError(f"File {analysis_path} not found.") + with H5File(analysis_path, "r") as f: + elect_id = f[elect_attr].attrs["object_id"] + obj_id = (cls & key).fetch1("object_id") + recompute_object_id, recompute_electrodes_id = obj_id, elect_id else: recompute_object_id, recompute_electrodes_id = None, None @@ -263,12 +276,13 @@ def _make_file(cls, key: dict = None, recompute_file_name: str = None): ) ) - if recompute: - AnalysisNwbfile()._update_external(recompute_file_name, file_hash) - else: - file_hash = AnalysisNwbfile().get_hash( - recording_nwb_file_name, from_schema=False - ) + # if recompute: + # pass + # # AnalysisNwbfile()._update_external(recompute_file_name, file_hash) + # else: + file_hash = AnalysisNwbfile().get_hash( + recording_nwb_file_name, from_schema=True # REVERT TO FALSE? + ) return dict( analysis_file_name=recording_nwb_file_name, @@ -523,6 +537,8 @@ def _get_preprocessed_recording(self, key: dict): return dict(recording=recording, timestamps=np.asarray(timestamps)) + return obj_id, elect_id + def update_ids(self): """Update electrodes_id, and file_hash in SpikeSortingRecording table. @@ -687,7 +703,7 @@ def _write_recording_to_nwb( io.write(nwbfile) if recompute_object_id: - logger.info(f"Recomputed {recompute_file_name}, fixing object IDs.") + # logger.info(f"Recomputed {recompute_file_name}, fixing object IDs.") with H5File(analysis_nwb_file_abs_path, "a") as f: f[series_attr].attrs["object_id"] = recompute_object_id f[elect_attr].attrs["object_id"] = recompute_electrodes_id diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index 1c9c644e3..fd11f3c87 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -1,15 +1,24 @@ import atexit import json +import re from hashlib import md5 from pathlib import Path -from typing import Any, Union +from typing import Any, Dict, Union import h5py import numpy as np +import pynwb +from hdmf.build import TypeMap +from hdmf.spec import NamespaceCatalog +from pynwb.spec import NWBDatasetSpec, NWBGroupSpec, NWBNamespace from tqdm import tqdm DEFAULT_BATCH_SIZE = 4095 -IGNORED_KEYS = ["version"] +IGNORED_KEYS = [ + "version", + "object_id", # TODO: remove +] +PRECISION_LOOKUP = dict(ProcessedElectricalSeries=8) class DirectoryHasher: @@ -101,9 +110,22 @@ def __init__( self, path: Union[str, Path], batch_size: int = DEFAULT_BATCH_SIZE, + precision_lookup: Dict[str, int] = PRECISION_LOOKUP, + source_script_version: bool = False, verbose: bool = True, ): - """Hashes the contents of an NWB file, limiting to partial data. + """Hashes the contents of an NWB file. + + Iterates through all objects in the NWB file, hashing the names, attrs, + and data of each object. Ignores NWB specifications, and only considers + NWB version. + + Uses a batch size to limit the amount of data hashed at once for large + datasets. Rounds data to n decimal places for specific dataset names, + as provided in the data_rounding dict. 
+ + Version numbers stored in '/general/source_script' are ignored by + default per the source_script_version flag. Parameters ---------- @@ -111,12 +133,23 @@ def __init__( Path to the NWB file. batch_size : int, optional Limit of data to hash for large datasets, by default 4095. + data_rounding : Dict[str, int], optional + Round data to n decimal places for specific datasets (i.e., + {dataset_name: n}). Default is to round ProcessedElectricalSeries + to 10 significant digits via np.round(chunk, n). + source_script_version : bool, optional + Include version numbers from the source_script in the hash, by + default False. If false, uses regex pattern to censor version + numbers from this field. verbose : bool, optional Display progress bar, by default True. """ + self.path = Path(path) self.file = h5py.File(path, "r") atexit.register(self.cleanup) + self.source_ver = source_script_version + self.precision = precision_lookup self.batch_size = batch_size self.verbose = verbose self.hashed = md5("".encode()) @@ -128,15 +161,28 @@ def __init__( def cleanup(self): self.file.close() + def remove_version(self, key: str) -> bool: + version_pattern = ( + r"\d+\.\d+\.\d+" # Major.Minor.Patch + + r"(?:-alpha|-beta|a\d+)?" # Optional alpha or beta, -alpha + + r"(?:\.dev\d+)?" # Optional dev build, .dev01 + + r"(?:\+[a-z0-9]{9})?" # Optional commit hash, +abcdefghi + + r"(?:\.d\d{8})?" # Optional date, dYYYYMMDD + ) + return re.sub(version_pattern, "VERSION", key) + def collect_names(self, file): """Collects all object names in the file.""" def collect_items(name, obj): + if "specifications" in name: + return # Ignore specifications, because we hash namespaces items_to_process.append((name, obj)) items_to_process = [] file.visititems(collect_items) items_to_process.sort(key=lambda x: x[0]) + return items_to_process def serialize_attr_value(self, value: Any): @@ -158,21 +204,30 @@ def serialize_attr_value(self, value: Any): return value.astype(str).tobytes() # must be 'astype(str)' elif isinstance(value, (str, int, float)): return str(value).encode() - return repr(value).encode() # For other data types, use repr + return repr(value).encode() # For other, use repr def hash_dataset(self, dataset: h5py.Dataset): _ = self.hash_shape_dtype(dataset) if dataset.shape == (): - self.hashed.update(self.serialize_attr_value(dataset[()])) + raw_scalar = str(dataset[()]) + if "source_script" in dataset.name and not self.source_ver: + raw_scalar = self.remove_version(raw_scalar) + self.hashed.update(self.serialize_attr_value(raw_scalar)) return + dataset_name = dataset.parent.name.split("/")[-1] + precision = self.precision.get(dataset_name, None) + size = dataset.shape[0] start = 0 while start < size: end = min(start + self.batch_size, size) - self.hashed.update(self.serialize_attr_value(dataset[start:end])) + data = dataset[start:end] + if precision: + data = np.round(data, precision) + self.hashed.update(self.serialize_attr_value(data)) start = end def hash_shape_dtype(self, obj: [h5py.Dataset, np.ndarray]) -> str: @@ -180,10 +235,24 @@ def hash_shape_dtype(self, obj: [h5py.Dataset, np.ndarray]) -> str: return self.hashed.update(str(obj.shape).encode() + str(obj.dtype).encode()) + @property + def all_namespaces(self) -> bytes: + """Encoded string of all NWB namespace specs.""" + catalog = NamespaceCatalog(NWBGroupSpec, NWBDatasetSpec, NWBNamespace) + pynwb.NWBHDF5IO.load_namespaces(catalog, self.path) + name_cat = TypeMap(catalog).namespace_catalog + ret = "" + for ns_name in name_cat.namespaces: + ret += ns_name + 
ret += name_cat.get_namespace(ns_name)["version"] + return ret.encode() + def compute_hash(self) -> str: """Hashes the NWB file contents, limiting to partal data where large.""" # Dev note: fallbacks if slow: 1) read_direct_chunk, 2) read from offset + self.hashed.update(self.all_namespaces) + for name, obj in tqdm( self.collect_names(self.file), desc=self.file.filename.split("/")[-1].split(".")[0], @@ -192,6 +261,8 @@ def compute_hash(self) -> str: self.hashed.update(name.encode()) for attr_key in sorted(obj.attrs): + if attr_key in IGNORED_KEYS: + continue attr_value = obj.attrs[attr_key] _ = self.hash_shape_dtype(attr_value) self.hashed.update(attr_key.encode()) From 54a3ca13f61534525b10e80a3cc0bd6e00cee155 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Thu, 9 Jan 2025 08:02:31 -0800 Subject: [PATCH 15/21] WIP: error specificity --- src/spyglass/common/common_nwbfile.py | 13 +++++++++---- src/spyglass/common/common_usage.py | 7 +++++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index e250ecec9..fad76c77b 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -358,11 +358,16 @@ def get_abs_path( analysis_nwb_file_abspath : str The absolute path for the given file name. """ - if from_schema: # Skips checksum and file existence checks - return f"{analysis_dir}/" + ( + if from_schema: # Skips checksum check + query = ( schema.external["analysis"] - & f'filepath LIKE "%{analysis_nwb_file_name}"' - ).fetch1("filepath") + & f"filepath LIKE '%{analysis_nwb_file_name}'" + ) + if len(query) != 1: + raise FileNotFoundError( + f"Found {len(query)} files for: {analysis_nwb_file_name}" + ) + return f"{analysis_dir}/" + query.fetch1("filepath") # If an entry exists in the database get the stored datajoint filepath file_key = {"analysis_file_name": analysis_nwb_file_name} diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py index 421f16946..bd76f7df3 100644 --- a/src/spyglass/common/common_usage.py +++ b/src/spyglass/common/common_usage.py @@ -15,8 +15,11 @@ from spyglass.settings import test_mode from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger from spyglass.utils.dj_graph import RestrGraph -from spyglass.utils.dj_helper_fn import (make_file_obj_id_unique, unique_dicts, - update_analysis_for_dandi_standard) +from spyglass.utils.dj_helper_fn import ( + make_file_obj_id_unique, + unique_dicts, + update_analysis_for_dandi_standard, +) from spyglass.utils.nwb_helper_fn import get_linked_nwbs from spyglass.utils.sql_helper_fn import SQLDumpHelper From 1e416985c80476d534629de5608dddc18542e668 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Tue, 4 Feb 2025 09:21:04 -0800 Subject: [PATCH 16/21] Add tables for recompute processing --- src/spyglass/common/common_nwbfile.py | 15 +- src/spyglass/spikesorting/v1/recording.py | 104 +++++---- src/spyglass/spikesorting/v1/usage.py | 257 ++++++++++++++++++++++ src/spyglass/utils/nwb_hash.py | 127 +++++++---- 4 files changed, 427 insertions(+), 76 deletions(-) create mode 100644 src/spyglass/spikesorting/v1/usage.py diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index fad76c77b..f987952c7 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -3,6 +3,7 @@ import stat import string from pathlib import Path +from typing import Union from uuid import uuid4 import datajoint as dj @@ -175,7 +176,10 @@ 
class AnalysisNwbfile(SpyglassMixin, dj.Manual): # See #630, #664. Excessive key length. def create( - self, nwb_file_name: str, recompute_file_name: str = None + self, + nwb_file_name: str, + recompute_file_name: str = None, + alternate_dir: Union[str, Path] = None, ) -> str: """Open the NWB file, create copy, write to disk and return new name. @@ -188,6 +192,8 @@ def create( The name of an NWB file to be copied. recompute_file_name : str, optional The name of the file to be regenerated. Defaults to None. + alternate_dir : Union[str, Path], Optional + An alternate directory to store the file. Defaults to analysis_dir. Returns ------- @@ -226,6 +232,12 @@ def create( analysis_file_name, from_schema=bool(recompute_file_name) ) + if alternate_dir: # override the default analysis_dir for recompute + relative = Path(analysis_file_abs_path).relative_to( + analysis_dir + ) + analysis_file_abs_path = Path(alternate_dir) / relative + # export the new NWB file parent_path = Path(analysis_file_abs_path).parent if not parent_path.exists(): @@ -296,6 +308,7 @@ def copy(cls, nwb_file_name: str): The name of the new NWB file. """ nwb_file_abspath = AnalysisNwbfile.get_abs_path(nwb_file_name) + with pynwb.NWBHDF5IO( path=nwb_file_abspath, mode="r", load_namespaces=True ) as io: diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index f7582b6ce..8d8a42e2b 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -3,6 +3,7 @@ from typing import Iterable, List, Optional, Tuple, Union import datajoint as dj +import hdmf import numpy as np import probeinterface as pi import pynwb @@ -21,12 +22,12 @@ from spyglass.common.common_lab import LabTeam from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile from spyglass.common.common_nwbfile import schema as nwb_schema -from spyglass.settings import test_mode +from spyglass.settings import temp_dir, test_mode from spyglass.spikesorting.utils import ( _get_recording_timestamps, get_group_by_shank, ) -from spyglass.utils import SpyglassMixin, logger +from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger from spyglass.utils.nwb_hash import NwbfileHasher schema = dj.schema("spikesorting_v1_recording") @@ -174,6 +175,7 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): object_id: varchar(40) # Object ID for the processed recording in NWB file electrodes_id='': varchar(40) # Object ID for the processed electrodes file_hash='': varchar(32) # Hash of the NWB file + dependencies=null: blob # dict of dependencies (pynwb, hdmf, spikeinterface) """ def make(self, key): @@ -207,6 +209,7 @@ def make(self, key): skip_duplicates=True, # for recompute ) AnalysisNwbfile().add(nwb_file_name, key["analysis_file_name"]) + self.insert1(key) @classmethod @@ -214,7 +217,7 @@ def _make_file( cls, key: dict = None, recompute_file_name: str = None, - force: bool = False, + save_to: Union[str, Path] = None, ): """Preprocess recording and write to NWB file. @@ -230,35 +233,43 @@ def _make_file( primary key of SpikeSortingRecordingSelection table recompute_file_name : str, Optional If specified, recompute this file. Use as resulting file name. - If none, generate a new file name. + If none, generate a new file name. Used for recomputation after + typical deletion. + save_to : Union[str,Path], Optional + Default None, save to analysis directory. If provided, save to + specified path. Used for recomputation prior to deletion. 
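A sketch of the pre-deletion check this parameter is meant for: rebuild the file into scratch space and compare hashes before trusting that the original can be regenerated (placeholder path, file name, and key; assumes a configured Spyglass database):

```python
# Hypothetical dry-run recompute into a scratch directory, as used by the
# RecordingRecompute table later in this patch.
from spyglass.spikesorting.v1.recording import SpikeSortingRecording

scratch = "/tmp/spikesort_v1_recompute"  # placeholder scratch location
entry = (
    SpikeSortingRecording
    & {"analysis_file_name": "example20240101_ABCDEF_0.nwb"}  # placeholder
).fetch1()

new_vals = SpikeSortingRecording._make_file(
    entry,
    recompute_file_name=entry["analysis_file_name"],
    save_to=scratch,
)
if new_vals["file_hash"] != entry["file_hash"]:
    print("not replicable in this environment; keep the original file")
```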
""" - file_hash = None if not key and not recompute_file_name: raise ValueError( "Either key or recompute_file_name must be specified." ) - elif recompute := bool(recompute_file_name and not key): - file_path = AnalysisNwbfile.get_abs_path( - recompute_file_name, from_schema=True - ) - if Path(file_path).exists(): + + file_hash = None + recompute = recompute_file_name and not key and not save_to + file_path = AnalysisNwbfile.get_abs_path( + recompute_file_name, from_schema=True + ) + if recompute: + if Path(file_path).exists(): # No need to recompute return logger.info(f"Recomputing {recompute_file_name}.") query = cls & {"analysis_file_name": recompute_file_name} + # Use deleted file's ids and hash for recompute key, recompute_object_id, recompute_electrodes_id, file_hash = ( query.fetch1("KEY", "object_id", "electrodes_id", "file_hash") ) - elif force: - # print("force") + elif save_to: # recompute prior to deletion elect_attr = "acquisition/ProcessedElectricalSeries/electrodes" - analysis_path = ( - Path("/stelmo/nwb/analysis") - / key["analysis_file_name"].split("_")[0] - / key["analysis_file_name"] - ) - if not analysis_path.exists(): + if not Path(file_path).exists(): raise FileNotFoundError(f"File {analysis_path} not found.") - with H5File(analysis_path, "r") as f: + + # ensure partial objects are present in existing file + with H5File(file_path, "r") as f: + elect_parts = elect_attr.split("/") + for i in range(len(elect_parts)): + mid_path = "/".join(elect_parts[: i + 1]) + if mid_path not in f.keys(): + raise KeyError(f"{mid_path} MISSING {analysis_path}") elect_id = f[elect_attr].attrs["object_id"] obj_id = (cls & key).fetch1("object_id") recompute_object_id, recompute_electrodes_id = obj_id, elect_id @@ -273,13 +284,13 @@ def _make_file( recompute_file_name=recompute_file_name, recompute_object_id=recompute_object_id, recompute_electrodes_id=recompute_electrodes_id, + save_to=save_to, ) ) + # TODO: uncomment after review. Commented to avoid impacting database # if recompute: - # pass # # AnalysisNwbfile()._update_external(recompute_file_name, file_hash) - # else: file_hash = AnalysisNwbfile().get_hash( recording_nwb_file_name, from_schema=True # REVERT TO FALSE? ) @@ -289,6 +300,11 @@ def _make_file( object_id=recording_object_id, electrodes_id=electrodes_id, file_hash=file_hash, + dependencies=dict( + pynwb=pynwb.__version__, + hdmf=hdmf.__version__, + spikeinterface=si.__version__, + ), ) @classmethod @@ -411,27 +427,21 @@ def _get_preprocessed_recording(self, key: dict): recording_channel_ids = np.setdiff1d(channel_ids, ref_channel_id) all_channel_ids = np.unique(np.append(channel_ids, ref_channel_id)) + # Electrode's fk to probe is nullable, so we need to check if present + query = Electrode * Probe & {"nwb_file_name": nwb_file_name} + if len(query) == 0: + raise ValueError(f"No probe info found for {nwb_file_name}.") + probe_type_by_channel = [] electrode_group_by_channel = [] for channel_id in channel_ids: - probe_type_by_channel.append( - ( - Electrode * Probe - & { - "nwb_file_name": nwb_file_name, - "electrode_id": channel_id, - } - ).fetch1("probe_type") - ) + # TODO: limit to one unique fetch. 
use set + c_query = query & {"electrode_id": channel_id} + probe_type_by_channel.append(c_query.fetch1("probe_type")) electrode_group_by_channel.append( - ( - Electrode - & { - "nwb_file_name": nwb_file_name, - "electrode_id": channel_id, - } - ).fetch1("electrode_group_name") + c_query.fetch1("electrode_group_name") ) + probe_type = np.unique(probe_type_by_channel) filter_params = ( SpikeSortingPreprocessingParameters * SpikeSortingRecordingSelection @@ -558,6 +568,16 @@ def update_ids(self): self.update1(key) + def recompute(self, key: dict): + """Recompute the processed recording. + + Parameters + ---------- + key : dict + primary key of SpikeSortingRecording table + """ + raise NotImplementedError("Recompute not implemented.") + def _consolidate_intervals(intervals, timestamps): """Convert a list of intervals (start_time, stop_time) @@ -620,6 +640,7 @@ def _write_recording_to_nwb( recompute_file_name: Optional[str] = None, recompute_object_id: Optional[str] = None, recompute_electrodes_id: Optional[str] = None, + save_to: Union[str, Path] = None, ): """Write a recording in NWB format @@ -637,6 +658,9 @@ def _write_recording_to_nwb( recompute_electrodes_id : str, optional object ID for recomputed electrodes sub-object, acquisition/ProcessedElectricalSeries/electrodes. + save_to : Union[str, Path], optional + Default None, save to analysis directory. If provided, save to specified + path. For use in recompute prior to deletion. Returns ------- @@ -662,12 +686,18 @@ def _write_recording_to_nwb( analysis_nwb_file = AnalysisNwbfile().create( nwb_file_name=nwb_file_name, recompute_file_name=recompute_file_name, + alternate_dir=save_to, ) analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( analysis_nwb_file, from_schema=recompute, ) + abs_obj = Path(analysis_nwb_file_abs_path) + if save_to and abs_obj.is_relative_to(analysis_dir): + relative = abs_obj.relative_to(analysis_dir) + analysis_nwb_file_abs_path = Path(save_to) / relative + with pynwb.NWBHDF5IO( path=analysis_nwb_file_abs_path, mode="a", diff --git a/src/spyglass/spikesorting/v1/usage.py b/src/spyglass/spikesorting/v1/usage.py new file mode 100644 index 000000000..bbd836011 --- /dev/null +++ b/src/spyglass/spikesorting/v1/usage.py @@ -0,0 +1,257 @@ +"""This schema is used to transition manage files for recompute. + +Tables +------ +RecordingVersions: What versions are present in an existing analysis file? + Allows restrict of recompute attempts to pynwb environments that are + compatible with a pre-existing file. For pip dependencies, see + SpikeSortingRecording.dependencies field +RecordingRecompute: Attempt recompute of an analysis file. 
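Taken together, the tables below are meant to be driven in three steps; a sketch of that workflow (the attempt label is arbitrary, and a configured Spyglass database is assumed):

```python
# Hypothetical end-to-end driver for the recompute bookkeeping tables below.
from spyglass.spikesorting.v1.usage import (
    RecordingRecompute,
    RecordingRecomputeSelection,
    RecordingVersions,
)

RecordingVersions.populate()  # 1. record namespace versions per analysis file

# 2. queue attempts for files readable in the current environment, labeled
#    with an identifier for this software environment
RecordingRecomputeSelection().attempt_all(attempt_id="env_2025_01")

RecordingRecompute.populate()  # 3. recompute, compare hashes, log mismatches
```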
+""" + +from pathlib import Path +from typing import Union + +import datajoint as dj +import pynwb +from h5py import File as h5py_File +from hdmf import __version__ as hdmf_version +from hdmf.build import TypeMap +from hdmf.spec import NamespaceCatalog +from pynwb.spec import NWBDatasetSpec, NWBGroupSpec, NWBNamespace +from spikeinterface import __version__ as si_version + +from spyglass.common import AnalysisNwbfile +from spyglass.settings import analysis_dir, temp_dir +from spyglass.spikesorting.v1.recording import SpikeSortingRecording +from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.nwb_hash import NwbfileHasher, get_file_namespaces + +schema = dj.schema("cbroz_temp") # TODO: spikesorting_v1_usage or _recompute + + +@schema +class RecordingVersions(SpyglassMixin, dj.Computed): + definition = """ + -> SpikeSortingRecording + --- + core = '' : varchar(32) + hdmf_common = '' : varchar(32) + hdmf_experimental = '' : varchar(32) + ndx_franklab_novela = '' : varchar(32) + spyglass = '' : varchar(64) + """ + + @property + def valid_ver_restr(self): + """Return a restriction of self for the current environment.""" + return self.namespace_dict(pynwb.get_manager().type_map) + + @property + def this_env(self): + """Return restricted version of self for the current environment.""" + return self & self.valid_ver_restr + + def namespace_dict(self, type_map: TypeMap): + """Remap namespace names to hyphenated field names for DJ compatibility.""" + hyph_fields = [f.replace("_", "-") for f in self.heading.names] + name_cat = type_map.namespace_catalog + return { + field.replace("-", "_"): name_cat.get_namespace(field).get( + "version", None + ) + for field in name_cat.namespaces + if field in hyph_fields + } + + def make(self, key): + parent = (SpikeSortingRecording & key).fetch1() + path = AnalysisNwbfile().get_abs_path(parent["analysis_file_name"]) + + insert = key.copy() + insert.update(get_file_namespaces(path)) + + with h5py_File(path, "r") as f: + script = f.get("general/source_script") + if script is not None: + script = str(script[()]).split("=")[1].strip().replace("'", "") + insert["spyglass"] = script + + self.insert1(insert) + + +@schema +class RecordingRecomputeSelection(SpyglassMixin, dj.Manual): + definition = """ + -> RecordingVersions + attempt_id: varchar(32) # name for environment used to attempt recompute + --- + dependencies: blob # dict of pip dependencies + """ + + @property + def pip_deps(self): + return dict( + pynwb=pynwb.__version__, + hdmf=hdmf_version, + spikeinterface=si_version, + # TODO: add othes? 
+ ) + + def attempt_all(self, attempt_id): + inserts = [ + {**key, "attempt_id": attempt_id, "dependencies": self.pip_deps} + for key in RecordingVersions().this_env.fetch("KEY", as_dict=True) + ] + self.insert(inserts, skip_duplicates=True) + + +@schema +class RecordingRecompute(dj.Computed): + definition = """ + -> RecordingRecomputeSelection + --- + matched: bool + """ + + class Name(dj.Part): + definition = """ + -> master + name : varchar(255) + --- + missing_from: enum('old', 'new') + """ + + class Hash(dj.Part): + definition = """ + -> master + name : varchar(255) + --- + old=null: longblob + new=null: longblob + """ + + key_source = RecordingVersions().this_env * RecordingRecomputeSelection + old_dir = Path(analysis_dir) + # see /stelmo/cbroz/temp_rcp/ for existing recomputed + new_dir = Path(temp_dir) / "spikesort_v1_recompute" + ignore_files = set() + + def _get_subdir(self, key): + file = key["analysis_file_name"] if isinstance(key, dict) else key + parts = file.split("_") + subdir = "_".join(parts[:-1]) + return subdir + "/" + file + + def _get_paths(self, key): + old = self.old_dir / self._get_subdir(key) + new = self.new_dir / self._get_subdir(key) + return old, new + + def _hash_existing(self): + for nwb in self.new_dir.rglob("*.nwb"): + if nwb.with_suffix(".hash").exists(): + continue + try: + logger.info(f"Hashing {nwb}") + _ = self._hash_one(nwb) + except (OSError, ValueError, RuntimeError) as e: + logger.warning(f"Error: {e.__class__.__name__}: {nwb.name}") + continue + + def _hash_one(self, path): + hasher = NwbfileHasher(path, verbose=False) + with open(path.with_suffix(".hash"), "w") as f: + f.write(hasher.hash) + return hasher.hash + + def make(self, key): + if self & key: + return + + parent = ( + SpikeSortingRecording + * RecordingVersions + * RecordingRecomputeSelection + & key + ).fetch1() + + # Ensure dependencies unchanged since selection insert + parent_deps = parent["dependencies"] + for dep, ver in RecordingRecomputeSelection().pip_deps.items(): + if parent_deps.get(dep, None) != ver: + raise ValueError(f"Please run this key with {parent_deps}") + + old, new = self._get_paths(parent) + + if not new.exists(): # if the file has yet to be recomputed + try: + new_vals = SpikeSortingRecording()._make_file( + parent, + recompute_file_name=parent["analysis_file_name"], + save_to=new, + ) + except RuntimeError as e: + print(f"{e}: {new.name}") + return + except ValueError as e: + e_info = e.args[0] + if "probe info" in e_info: # make failed bc missing probe info + self.insert1(dict(key, matched=False, diffs=e_info)) + else: + print(f"ValueError: {e}: {new.name}") + return + except KeyError as err: + e_info = err.args[0] + if "MISSING" in e_info: # make failed bc missing parent file + e = e_info.split("MISSING")[0].strip() + self.insert1(dict(key, matched=False, diffs=e)) + self.Name().insert1( + dict( + key, name=f"Parent missing {e}", missing_from="old" + ) + ) + else: + logger.warning(f"KeyError: {err}: {new.name}") + return + with open(new.with_suffix(".hash"), "w") as f: + f.write(new_vals["file_hash"]) + + # TODO: how to check the env used to make the file when reading? 
+ elif new.with_suffix(".hash").exists(): + print(f"\nReading hash {new}") + with open(new.with_suffix(".hash"), "r") as f: + new_vals = dict(file_hash=f.read()) + else: + new_vals = dict(file_hash=self._hash_one(new)) + + old_hasher = parent["file_hash"] or NwbfileHasher( + old, verbose=False, keep_obj_hash=True + ) + + if new_vals["file_hash"] == old_hasher.hash: + self.insert1(dict(key, match=True)) + new.unlink(missing_ok=True) + return + + names = [] + hashes = [] + + logger.info(f"Compparing mismatched {new}") + new_hasher = NwbfileHasher(new, verbose=False, keep_obj_hash=True) + all_objs = set(old_hasher.objs.keys()) | set(new_hasher.objs.keys()) + + for obj in all_objs: + old_obj, old_hash = old_hasher.objs.get(obj, (None, None)) + new_obj, new_hash = new_hasher.objs.get(obj, (None, None)) + + if old_hash is None: + names.append(dict(key, name=obj, missing_from="old")) + elif new_hash is None: + names.append(dict(key, name=obj, missing_from="new")) + elif old_hash != new_hash: + hashes.append(dict(key, name=obj, old=old_obj, new=new_obj)) + + self.insert1(dict(key, matched=False)) + self.Name().insert(names) + self.Hash().insert(hashes) diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index fd11f3c87..376c9ed3f 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -1,6 +1,7 @@ import atexit import json import re +from functools import cached_property from hashlib import md5 from pathlib import Path from typing import Any, Dict, Union @@ -14,11 +15,38 @@ from tqdm import tqdm DEFAULT_BATCH_SIZE = 4095 -IGNORED_KEYS = [ - "version", - "object_id", # TODO: remove -] -PRECISION_LOOKUP = dict(ProcessedElectricalSeries=8) +IGNORED_KEYS = ["version"] +PRECISION_LOOKUP = dict(ProcessedElectricalSeries=4) + + +def get_file_namespaces( + file_path: Union[str, Path], replace_hypens: bool = True +) -> dict: + """Get all namespace versions from an NWB file. + + WARNING: This function falsely reports core <= 2.6.0 as 2.6.0-alpha + + Parameters + ---------- + file_path : Union[str, Path] + Path to the NWB file. + replace_hypens : bool, optional + Replace hyphens with underscores for DJ compatibility, by default True. + """ + catalog = NamespaceCatalog(NWBGroupSpec, NWBDatasetSpec, NWBNamespace) + pynwb.NWBHDF5IO.load_namespaces(catalog, file_path) + name_cat = TypeMap(catalog).namespace_catalog + + ret = { + ns_name: name_cat.get_namespace(ns_name).get("version", None) + for ns_name in name_cat.namespaces + } + + return ( + {k.replace("-", "_"): v for k, v in ret.items()} + if replace_hypens + else ret + ) class DirectoryHasher: @@ -111,7 +139,7 @@ def __init__( path: Union[str, Path], batch_size: int = DEFAULT_BATCH_SIZE, precision_lookup: Dict[str, int] = PRECISION_LOOKUP, - source_script_version: bool = False, + keep_obj_hash: bool = False, verbose: bool = True, ): """Hashes the contents of an NWB file. @@ -124,8 +152,10 @@ def __init__( datasets. Rounds data to n decimal places for specific dataset names, as provided in the data_rounding dict. - Version numbers stored in '/general/source_script' are ignored by - default per the source_script_version flag. + Version numbers stored in '/general/source_script' are ignored. + + Keeps each object hash as a dictionary, if keep_obj_hash is True. This + is useful for debugging, but not recommended for large files. Parameters ---------- @@ -137,10 +167,8 @@ def __init__( Round data to n decimal places for specific datasets (i.e., {dataset_name: n}). 
Default is to round ProcessedElectricalSeries to 10 significant digits via np.round(chunk, n). - source_script_version : bool, optional - Include version numbers from the source_script in the hash, by - default False. If false, uses regex pattern to censor version - numbers from this field. + keep_obj_hash : bool, optional + Keep the hash of each object in the NWB file, by default False. verbose : bool, optional Display progress bar, by default True. """ @@ -148,10 +176,14 @@ def __init__( self.file = h5py.File(path, "r") atexit.register(self.cleanup) - self.source_ver = source_script_version + if isinstance(precision_lookup, int): + precision_lookup = dict(ProcessedElectricalSeries=precision_lookup) + self.precision = precision_lookup self.batch_size = batch_size self.verbose = verbose + self.keep_obj_hash = keep_obj_hash + self.objs = {} self.hashed = md5("".encode()) self.hash = self.compute_hash() @@ -207,13 +239,13 @@ def serialize_attr_value(self, value: Any): return repr(value).encode() # For other, use repr def hash_dataset(self, dataset: h5py.Dataset): - _ = self.hash_shape_dtype(dataset) + this_hash = md5(self.hash_shape_dtype(dataset)) if dataset.shape == (): raw_scalar = str(dataset[()]) - if "source_script" in dataset.name and not self.source_ver: + if "source_script" in dataset.name: raw_scalar = self.remove_version(raw_scalar) - self.hashed.update(self.serialize_attr_value(raw_scalar)) + this_hash.update(self.serialize_attr_value(raw_scalar)) return dataset_name = dataset.parent.name.split("/")[-1] @@ -227,59 +259,78 @@ def hash_dataset(self, dataset: h5py.Dataset): data = dataset[start:end] if precision: data = np.round(data, precision) - self.hashed.update(self.serialize_attr_value(data)) + this_hash.update(self.serialize_attr_value(data)) start = end + return this_hash.hexdigest() + def hash_shape_dtype(self, obj: [h5py.Dataset, np.ndarray]) -> str: if not hasattr(obj, "shape") or not hasattr(obj, "dtype"): - return - self.hashed.update(str(obj.shape).encode() + str(obj.dtype).encode()) + return "".encode() + return str(obj.shape).encode() + str(obj.dtype).encode() - @property - def all_namespaces(self) -> bytes: + @cached_property + def namespaces(self) -> dict: """Encoded string of all NWB namespace specs.""" - catalog = NamespaceCatalog(NWBGroupSpec, NWBDatasetSpec, NWBNamespace) - pynwb.NWBHDF5IO.load_namespaces(catalog, self.path) - name_cat = TypeMap(catalog).namespace_catalog - ret = "" - for ns_name in name_cat.namespaces: - ret += ns_name - ret += name_cat.get_namespace(ns_name)["version"] - return ret.encode() + return get_file_namespaces(self.path) + + @cached_property + def namespaces_str(self) -> str: + """String representation of all NWB namespace specs.""" + return json.dumps(self.namespaces, sort_keys=True).encode() + + def add_to_cache(self, name: str, obj: Any, digest: str = None): + """Add object to the cache. + + Centralizes conditional logic for adding objects to the cache. 
+ """ + if self.keep_obj_hash: + self.objs[name] = (obj, digest) def compute_hash(self) -> str: - """Hashes the NWB file contents, limiting to partal data where large.""" + """Hashes the NWB file contents.""" # Dev note: fallbacks if slow: 1) read_direct_chunk, 2) read from offset - self.hashed.update(self.all_namespaces) + self.hashed.update(self.namespaces_str) + + self.add_to_cache("namespaces", self.namespaces, None) for name, obj in tqdm( self.collect_names(self.file), desc=self.file.filename.split("/")[-1].split(".")[0], disable=not self.verbose, ): - self.hashed.update(name.encode()) + this_hash = md5(name.encode()) for attr_key in sorted(obj.attrs): if attr_key in IGNORED_KEYS: continue attr_value = obj.attrs[attr_key] - _ = self.hash_shape_dtype(attr_value) - self.hashed.update(attr_key.encode()) - self.hashed.update(self.serialize_attr_value(attr_value)) + this_hash.update(self.hash_shape_dtype(attr_value)) + this_hash.update(attr_key.encode()) + this_hash.update(self.serialize_attr_value(attr_value)) if isinstance(obj, h5py.Dataset): _ = self.hash_dataset(obj) elif isinstance(obj, h5py.SoftLink): - self.hashed.update(obj.path.encode()) + this_hash.update(obj.path.encode()) elif isinstance(obj, h5py.Group): for k, v in obj.items(): - self.hashed.update(k.encode()) - self.hashed.update(self.serialize_attr_value(v)) + this_hash.update(k.encode()) + obj_value = self.serialize_attr_value(v) + this_hash.update(obj_value) + self.add_to_cache( + f"{name}/k", v, md5(obj_value).hexdigest() + ) else: raise TypeError( f"Unknown object type: {type(obj)}\n" + "Please report this an issue on GitHub." ) + this_digest = this_hash.hexdigest() + self.hashed.update(this_digest.encode()) + + self.add_to_cache(name, obj, this_digest) + return self.hashed.hexdigest() From ae52aedb038fb802e88ad5a8c8db63ea346eecc0 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Fri, 21 Feb 2025 08:38:59 -0800 Subject: [PATCH 17/21] WIP: incorporate feedback --- src/spyglass/common/common_nwbfile.py | 6 +++--- src/spyglass/spikesorting/v0/spikesorting_recording.py | 2 -- src/spyglass/spikesorting/v1/{usage.py => recompute.py} | 4 ++-- src/spyglass/utils/nwb_hash.py | 5 +++-- 4 files changed, 8 insertions(+), 9 deletions(-) rename src/spyglass/spikesorting/v1/{usage.py => recompute.py} (98%) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index f987952c7..d983df09f 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -380,7 +380,7 @@ def get_abs_path( raise FileNotFoundError( f"Found {len(query)} files for: {analysis_nwb_file_name}" ) - return f"{analysis_dir}/" + query.fetch1("filepath") + return Path(analysis_dir) / query.fetch1("filepath") # If an entry exists in the database get the stored datajoint filepath file_key = {"analysis_file_name": analysis_nwb_file_name} @@ -508,8 +508,8 @@ def _update_external(self, analysis_file_name: str, file_hash: str): external_tbl = schema.external["analysis"] file_path = ( - self.__get_analysis_file_dir(analysis_file_name) - + f"/{analysis_file_name}" + Path(self.__get_analysis_file_dir(analysis_file_name)) + / analysis_file_name ) key = (external_tbl & f"filepath = '{file_path}'").fetch1() abs_path = Path(analysis_dir) / file_path diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py index 99a220b5d..a4382bf32 100644 --- a/src/spyglass/spikesorting/v0/spikesorting_recording.py +++ 
b/src/spyglass/spikesorting/v0/spikesorting_recording.py @@ -317,8 +317,6 @@ def make(self, key): recording = self._get_filtered_recording(key) recording_name = self._get_recording_name(key) - # recording_dir = Path("/home/cbroz/wrk/temp_ssr0/") - # Path to files that will hold the recording extractors recording_path = str(recording_dir / Path(recording_name)) if os.path.exists(recording_path): diff --git a/src/spyglass/spikesorting/v1/usage.py b/src/spyglass/spikesorting/v1/recompute.py similarity index 98% rename from src/spyglass/spikesorting/v1/usage.py rename to src/spyglass/spikesorting/v1/recompute.py index bbd836011..3877b9330 100644 --- a/src/spyglass/spikesorting/v1/usage.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -1,4 +1,4 @@ -"""This schema is used to transition manage files for recompute. +"""This schema is used to track recompute capabilities for existing files. Tables ------ @@ -27,7 +27,7 @@ from spyglass.utils import SpyglassMixin, logger from spyglass.utils.nwb_hash import NwbfileHasher, get_file_namespaces -schema = dj.schema("cbroz_temp") # TODO: spikesorting_v1_usage or _recompute +schema = dj.schema("cbroz_temp") # TODO: spikesorting_v1_recompute @schema diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index 376c9ed3f..b7777c732 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -193,7 +193,8 @@ def __init__( def cleanup(self): self.file.close() - def remove_version(self, key: str) -> bool: + def remove_version(self, input_string: str) -> str: + """Removes version numbers from the input.""" version_pattern = ( r"\d+\.\d+\.\d+" # Major.Minor.Patch + r"(?:-alpha|-beta|a\d+)?" # Optional alpha or beta, -alpha @@ -201,7 +202,7 @@ def remove_version(self, key: str) -> bool: + r"(?:\+[a-z0-9]{9})?" # Optional commit hash, +abcdefghi + r"(?:\.d\d{8})?" # Optional date, dYYYYMMDD ) - return re.sub(version_pattern, "VERSION", key) + return re.sub(version_pattern, "VERSION", input_string) def collect_names(self, file): """Collects all object names in the file.""" From 2e8907017a2854de149731a9d5291f831d783f77 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 3 Mar 2025 09:32:48 -0800 Subject: [PATCH 18/21] WIP: enforce environment restriction --- src/spyglass/spikesorting/v1/recompute.py | 383 ++++++++++++++++------ src/spyglass/spikesorting/v1/recording.py | 11 +- src/spyglass/utils/h5_helper_fn.py | 83 +++++ src/spyglass/utils/nwb_hash.py | 6 +- 4 files changed, 386 insertions(+), 97 deletions(-) create mode 100644 src/spyglass/utils/h5_helper_fn.py diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 3877b9330..ec967ca94 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -9,15 +9,22 @@ RecordingRecompute: Attempt recompute of an analysis file. 
""" +import atexit +from functools import cached_property +from json import loads as json_loads +from os import environ as os_environ from pathlib import Path -from typing import Union +from typing import Tuple, Union import datajoint as dj +import h5py import pynwb +from datajoint.hash import key_hash from h5py import File as h5py_File from hdmf import __version__ as hdmf_version from hdmf.build import TypeMap from hdmf.spec import NamespaceCatalog +from numpy import __version__ as np_version from pynwb.spec import NWBDatasetSpec, NWBGroupSpec, NWBNamespace from spikeinterface import __version__ as si_version @@ -25,6 +32,7 @@ from spyglass.settings import analysis_dir, temp_dir from spyglass.spikesorting.v1.recording import SpikeSortingRecording from spyglass.utils import SpyglassMixin, logger +from spyglass.utils.h5_helper_fn import H5pyComparator from spyglass.utils.nwb_hash import NwbfileHasher, get_file_namespaces schema = dj.schema("cbroz_temp") # TODO: spikesorting_v1_recompute @@ -42,16 +50,25 @@ class RecordingVersions(SpyglassMixin, dj.Computed): spyglass = '' : varchar(64) """ - @property + @cached_property def valid_ver_restr(self): """Return a restriction of self for the current environment.""" return self.namespace_dict(pynwb.get_manager().type_map) - @property + @cached_property def this_env(self): """Return restricted version of self for the current environment.""" return self & self.valid_ver_restr + def key_env(self, key): + """Return the pynwb environment for a given key.""" + if not self & key: + self.make(key) + query = self & key + if len(query) != 1: + raise ValueError(f"Key matches {len(query)} entries: {query}") + return (self & key).fetch(*self.valid_ver_restr.keys(), as_dict=True)[0] + def namespace_dict(self, type_map: TypeMap): """Remap namespace names to hyphenated field names for DJ compatibility.""" hyph_fields = [f.replace("_", "-") for f in self.heading.names] @@ -65,15 +82,15 @@ def namespace_dict(self, type_map: TypeMap): } def make(self, key): + """Inventory the namespaces present in an analysis file.""" parent = (SpikeSortingRecording & key).fetch1() path = AnalysisNwbfile().get_abs_path(parent["analysis_file_name"]) - insert = key.copy() - insert.update(get_file_namespaces(path)) + insert = {**key, **get_file_namespaces(path)} with h5py_File(path, "r") as f: script = f.get("general/source_script") - if script is not None: + if script is not None: # after `=`, remove quotes script = str(script[()]).split("=")[1].strip().replace("'", "") insert["spyglass"] = script @@ -85,26 +102,109 @@ class RecordingRecomputeSelection(SpyglassMixin, dj.Manual): definition = """ -> RecordingVersions attempt_id: varchar(32) # name for environment used to attempt recompute + rounding=8: int # rounding for float ElectricalSeries --- dependencies: blob # dict of pip dependencies """ - @property + # --- Insert helpers --- + + @cached_property + def default_attempt_id(self): + conda = os_environ.get("CONDA_DEFAULT_ENV", "base") + si_readable = si_version.replace(".", "-") + return f"{conda}_si{si_readable}" + + @cached_property def pip_deps(self): return dict( pynwb=pynwb.__version__, hdmf=hdmf_version, spikeinterface=si_version, - # TODO: add othes? 
+ numpy=np_version, ) - def attempt_all(self, attempt_id): + def insert(self, rows, **kwargs): + """Custom insert to ensure dependencies are added to each row.""" + if not isinstance(rows, list): + rows = [rows] + if not isinstance(rows[0], dict): + raise ValueError("Rows must be a list of dicts") + inserts = [] + for row in rows: + if not self._has_matching_pynwb(row): + continue + if not row.get("attempt_id", None): + row["attempt_id"] = self.default_attempt_id + row["dependencies"] = self.pip_deps + inserts.append(row) + super().insert(inserts, **kwargs) + + def attempt_all(self, attempt_id=None): + if not attempt_id: + attempt_id = self.default_attempt_id inserts = [ {**key, "attempt_id": attempt_id, "dependencies": self.pip_deps} for key in RecordingVersions().this_env.fetch("KEY", as_dict=True) ] self.insert(inserts, skip_duplicates=True) + # --- Gatekeep recompute attempts --- + + @cached_property # Ok to cache b/c static to python runtime + def this_env(self): + """Restricted table matching pynwb env and pip env.""" + restr = [] + for key in RecordingVersions().this_env * self: + if key["dependencies"] != self.pip_deps: + continue + pk = {k: v for k, v in key.items() if k in self.primary_key} + restr.append(pk) + return self & restr + + def _sort_dict(self, d): + return dict(sorted(d.items())) + + def _has_matching_pynwb(self, key, show_err=True) -> bool: + """Check current env for matching pynwb versions.""" + this_rec = {"recording_id": key["recording_id"]} + ret = RecordingVersions().this_env & key + if not ret and show_err: + need = self._sort_dict(RecordingVersions().key_env(key)) + have = self._sort_dict(RecordingVersions().valid_ver_restr) + logger.warning( + f"PyNWB version mismatch. Skipping key: {this_rec}" + + f"\n\tHave: {have}" + + f"\n\tNeed: {need}" + ) + return bool(ret) + + def _has_matching_pip(self, key, show_err=True) -> bool: + """Check current env for matching pip versions.""" + this_rec = {"recording_id": key["recording_id"]} + query = self.this_env & key + + if not len(query) == 1: + raise ValueError(f"Query returned {len(query)} entries: {query}") + + need = query.fetch1("dependencies") + ret = need == self.pip_deps + + if not ret and show_err: + logger.error( + f"Pip version mismatch. 
Skipping key: {this_rec}" + + f"\n\tHave: {self.pip_deps}" + + f"\n\tNeed: {need}" + ) + + return ret + + def _has_matching_env(self, key, show_err=True) -> bool: + """Check current env for matching pynwb and pip versions.""" + return self._has_matching_pynwb( + key, show_err=show_err + ) and self._has_matching_pip(key, show_err=show_err) + @schema class RecordingRecompute(dj.Computed): @@ -118,7 +218,6 @@ class Name(dj.Part): definition = """ -> master name : varchar(255) - --- missing_from: enum('old', 'new') """ @@ -131,115 +230,213 @@ class Hash(dj.Part): new=null: longblob """ - key_source = RecordingVersions().this_env * RecordingRecomputeSelection - old_dir = Path(analysis_dir) - # see /stelmo/cbroz/temp_rcp/ for existing recomputed - new_dir = Path(temp_dir) / "spikesort_v1_recompute" - ignore_files = set() + def get_objs(self, key, obj_name=None): + old, new = (self & key).fetch1("old", "new") + if old is not None and new is not None: + return old, new + old, new = RecordingRecompute()._open_files(key) + this_obj = obj_name or key["name"] + return old.get(this_obj, None), new.get(this_obj, None) + + def compare(self, key, obj_name=None): + old, new = self.get_objs(key, obj_name=obj_name) + return H5pyComparator(old=old, new=new) + + key_source = RecordingRecomputeSelection().this_env + _key_cache = {} + _hasher_cache = {} + _files_cache = {} + _cleanup_registered = False + + @property + def with_names(self) -> dj.expression.QueryExpression: + """Return tables joined with analysis file names.""" + return self * SpikeSortingRecording.proj("analysis_file_name") + + # --- Cache management --- + + def _cleanup(self) -> None: + """Close all open files.""" + for file in self._file_cache.values(): + file.close() + self._file_cache = {} + for hasher in self._hasher_cache.values(): + hasher.cleanup() + if self._cleanup_registered: + atexit.unregister(self._cleanup) + self._cleanup_registered = False + + def _open_files(self, key) -> Tuple[h5py_File, h5py_File]: + """Open old and new files for comparison.""" + if not self._cleanup_registered: + atexit.register(self._cleanup) + self._cleanup_registered = True + + old, new = self._get_paths(key, as_str=True) + if old not in self._file_cache: + self._file_cache[old] = h5py_File(old, "r") + if new not in self._file_cache: + self._file_cache[new] = h5py_File(new, "r") + + return self._file_cache[old], self._file_cache[new] + + def _hash_one(self, path, precision): + """Return the hasher for a given path. 
Store in cache.""" + cache_val = f"{path}_{precision}" + if cache_val in self._hasher_cache: + return self._hasher_cache[cache_val] + hasher = NwbfileHasher( + path, + verbose=False, + keep_obj_hash=True, + keep_file_open=True, + precision_lookup=precision, + ) + self._hasher_cache[cache_val] = hasher + return hasher + + # --- Path management --- - def _get_subdir(self, key): + def _get_subdir(self, key) -> Path: + """Return the analysis file's subdirectory.""" file = key["analysis_file_name"] if isinstance(key, dict) else key parts = file.split("_") subdir = "_".join(parts[:-1]) return subdir + "/" + file - def _get_paths(self, key): - old = self.old_dir / self._get_subdir(key) - new = self.new_dir / self._get_subdir(key) - return old, new + def _get_paths(self, key, as_str=False) -> Tuple[Path, Path]: + """Return the old and new file paths.""" + key = self.get_parent_key(key) - def _hash_existing(self): - for nwb in self.new_dir.rglob("*.nwb"): - if nwb.with_suffix(".hash").exists(): - continue - try: - logger.info(f"Hashing {nwb}") - _ = self._hash_one(nwb) - except (OSError, ValueError, RuntimeError) as e: - logger.warning(f"Error: {e.__class__.__name__}: {nwb.name}") - continue + old = Path(analysis_dir) / self._get_subdir(key) + new = ( + Path(temp_dir) + / "spikesort_v1_recompute" + / key.get("attempt_id", "") + / self._get_subdir(key) + ) - def _hash_one(self, path): - hasher = NwbfileHasher(path, verbose=False) - with open(path.with_suffix(".hash"), "w") as f: - f.write(hasher.hash) - return hasher.hash + return (str(old), str(new)) if as_str else (old, new) - def make(self, key): - if self & key: - return + # --- Database checks --- + def get_parent_key(self, key) -> dict: + """Return the parent key for a given recompute key.""" + key = { + k: v + for k, v in key.items() + if k in RecordingRecomputeSelection.primary_key + } + hashed = key_hash(key) + if hashed in self._key_cache: + return self._key_cache[hashed] parent = ( SpikeSortingRecording * RecordingVersions * RecordingRecomputeSelection & key ).fetch1() + self._key_cache[hashed] = parent + return parent + + def _other_roundings(self, key, less_than=False): + """Return other planned precision recompute attempts. + + Parameters + ---------- + key : dict + Key for the current recompute attempt. + less_than : bool + Default False. + If True, return attempts with lower precision than key. + If False, return attempts with higher precision. + """ + operator = "<" if less_than else "!=" + return ( + RecordingRecomputeSelection() + & {k: v for k, v in key.items() if k != "rounding"} + & f'rounding {operator} "{key["rounding"]}"' + ).proj() - self + + def _has_other_roundings(self, key, less_than=False): + """Check if other planned precision recompute attempts exist.""" + return bool(self._other_roundings(key, less_than=less_than)) + + # --- Recompute --- + + def _recompute(self, key) -> Union[None, dict]: + """Attempt to recompute the analysis file. 
Catch common errors.""" + + _, new = self._get_paths(key) + parent = self.get_parent_key(key) + + try: + new_vals = SpikeSortingRecording()._make_file( + parent, + recompute_file_name=parent["analysis_file_name"], + save_to=new.parent.parent, + rounding=key.get("rounding", 8), + ) + except RuntimeError as e: + logger.warning(f"{e}: {new.name}") + except ValueError as e: + e_info = e.args[0] + if "probe info" in e_info: # make failed bc missing probe info + self.insert1(dict(key, matched=False, diffs=e_info)) + else: + logger.warning(f"ValueError: {e}: {new.name}") + except KeyError as err: + e_info = err.args[0] + if "MISSING" in e_info: # make failed bc missing parent file + e = e_info.split("MISSING")[0].strip() + self.insert1(dict(key, matched=False, diffs=e)) + self.Name().insert1( + dict(key, name=f"Parent missing {e}", missing_from="old") + ) + else: + logger.warning(f"KeyError: {err}: {new.name}") + else: + return new_vals + + def make(self, key): + parent = self.get_parent_key(key) + rounding = key.get("rounding") # Ensure dependencies unchanged since selection insert - parent_deps = parent["dependencies"] - for dep, ver in RecordingRecomputeSelection().pip_deps.items(): - if parent_deps.get(dep, None) != ver: - raise ValueError(f"Please run this key with {parent_deps}") + if not RecordingRecomputeSelection()._has_matching_env(key): + return old, new = self._get_paths(parent) if not new.exists(): # if the file has yet to be recomputed - try: - new_vals = SpikeSortingRecording()._make_file( - parent, - recompute_file_name=parent["analysis_file_name"], - save_to=new, - ) - except RuntimeError as e: - print(f"{e}: {new.name}") - return - except ValueError as e: - e_info = e.args[0] - if "probe info" in e_info: # make failed bc missing probe info - self.insert1(dict(key, matched=False, diffs=e_info)) - else: - print(f"ValueError: {e}: {new.name}") + new_hash = self._recompute(key)["file_hash"] + if new_vals is None: # Error occurred return - except KeyError as err: - e_info = err.args[0] - if "MISSING" in e_info: # make failed bc missing parent file - e = e_info.split("MISSING")[0].strip() - self.insert1(dict(key, matched=False, diffs=e)) - self.Name().insert1( - dict( - key, name=f"Parent missing {e}", missing_from="old" - ) - ) - else: - logger.warning(f"KeyError: {err}: {new.name}") - return - with open(new.with_suffix(".hash"), "w") as f: - f.write(new_vals["file_hash"]) - - # TODO: how to check the env used to make the file when reading? 
- elif new.with_suffix(".hash").exists(): - print(f"\nReading hash {new}") - with open(new.with_suffix(".hash"), "r") as f: - new_vals = dict(file_hash=f.read()) else: - new_vals = dict(file_hash=self._hash_one(new)) - - old_hasher = parent["file_hash"] or NwbfileHasher( - old, verbose=False, keep_obj_hash=True - ) - - if new_vals["file_hash"] == old_hasher.hash: - self.insert1(dict(key, match=True)) - new.unlink(missing_ok=True) + new_hasher = self._hash_one(new, rounding) + new_hash = new_hasher.hash + + old_hasher = self._hash_one(old, rounding) + + if new_hash == old_hasher.hash: + self.insert1(dict(key, matched=True)) + if not self._has_other_roundings(key, less_than=False): + # if no other recompute attempts + new.unlink(missing_ok=True) + elif query := self._other_roundings(key, less_than=True): + logger.info( + f"Matched at {rounding} precision: {new.name}\n" + + "Deleting lesser pecision attempts" + ) + query.delete() return names = [] hashes = [] - logger.info(f"Compparing mismatched {new}") - new_hasher = NwbfileHasher(new, verbose=False, keep_obj_hash=True) - all_objs = set(old_hasher.objs.keys()) | set(new_hasher.objs.keys()) + logger.info(f"Comparing mismatched {new.name}") + new_hasher = self._hash_one(new, rounding) + all_objs = set({**old_hasher.objs, **new_hasher.objs}) for obj in all_objs: old_obj, old_hash = old_hasher.objs.get(obj, (None, None)) @@ -247,10 +444,10 @@ def make(self, key): if old_hash is None: names.append(dict(key, name=obj, missing_from="old")) - elif new_hash is None: + if new_hash is None: names.append(dict(key, name=obj, missing_from="new")) - elif old_hash != new_hash: - hashes.append(dict(key, name=obj, old=old_obj, new=new_obj)) + if old_hash != new_hash: + hashes.append(dict(key, name=obj)) self.insert1(dict(key, matched=False)) self.Name().insert(names) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 8d8a42e2b..7ed98ceba 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -22,7 +22,7 @@ from spyglass.common.common_lab import LabTeam from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile from spyglass.common.common_nwbfile import schema as nwb_schema -from spyglass.settings import temp_dir, test_mode +from spyglass.settings import analysis_dir, temp_dir, test_mode from spyglass.spikesorting.utils import ( _get_recording_timestamps, get_group_by_shank, @@ -218,6 +218,7 @@ def _make_file( key: dict = None, recompute_file_name: str = None, save_to: Union[str, Path] = None, + rounding: int = 4, ): """Preprocess recording and write to NWB file. @@ -243,6 +244,8 @@ def _make_file( raise ValueError( "Either key or recompute_file_name must be specified." ) + if isinstance(key, dict): + key = {k: v for k, v in key.items() if k in cls.primary_key} file_hash = None recompute = recompute_file_name and not key and not save_to @@ -291,8 +294,11 @@ def _make_file( # TODO: uncomment after review. Commented to avoid impacting database # if recompute: # # AnalysisNwbfile()._update_external(recompute_file_name, file_hash) + precision_lookup = dict(ProcessedElectricalSeries=rounding) file_hash = AnalysisNwbfile().get_hash( - recording_nwb_file_name, from_schema=True # REVERT TO FALSE? + recording_nwb_file_name, + from_schema=True, # REVERT TO FALSE? 
+ precision_lookup=precision_lookup, ) return dict( @@ -667,6 +673,7 @@ def _write_recording_to_nwb( analysis_nwb_file : str name of analysis NWB file containing the preprocessed recording """ + recompute_args = ( recompute_file_name, recompute_object_id, diff --git a/src/spyglass/utils/h5_helper_fn.py b/src/spyglass/utils/h5_helper_fn.py new file mode 100644 index 000000000..e7e2653e1 --- /dev/null +++ b/src/spyglass/utils/h5_helper_fn.py @@ -0,0 +1,83 @@ +"""Helper methods for comparing pynwb objects.""" +from json import loads as json_loads +import h5py + + +class H5pyComparator: + def __init__(self, old, new, line_limit=80): + self.old = self.obj_to_dict(old) + self.new = self.obj_to_dict(new) + self.line_limit = line_limit + self.compare_dicts(self.old, self.new) + + def unpack_scalar(self, obj): + """Unpack a scalar from an h5py dataset.""" + if isinstance(obj, (int, float, str)): + return dict(scalar=obj) + str_obj = str(obj[()]) + if "{" not in str_obj: + return dict(scalar=str_obj) + return json_loads(str_obj) + + def assemble_dict(self, obj): + """Assemble a dictionary from an h5py group.""" + ret = dict() + for k, v in obj.items(): + if isinstance(v, h5py.Dataset): + ret[k] = self.unpack_scalar(v) + elif isinstance(v, h5py.Group): + ret[k] = self.assemble_dict(v) + else: + ret[k] = v + return ret + + def obj_to_dict(self, obj): + """Convert an h5py object to a dictionary.""" + if isinstance(obj, dict): + return {k: self.obj_to_dict(v) for k, v in obj.items()} + if isinstance(obj, (float, str, int, h5py.Dataset)): + return self.unpack_scalar(obj) + if isinstance(obj, h5py.Group): + return self.assemble_dict(obj) + return json_loads(obj) + + def sort_list_of_dicts(self, obj): + """Sort a list of dictionaries.""" + return sorted( + obj, + key=lambda x: sorted(x.keys() if isinstance(x, dict) else str(x)), + ) + + def compare_dict_values(self, key, oval, nval, level, iteration): + """Compare values of a specific key in two dictionaries.""" + if oval != nval: + print(f"{level} {iteration}: dict val differ for {key}") + if isinstance(oval, dict): + self.compare_dicts(oval, nval, f"{level} {key}", iteration + 1) + elif isinstance(oval, list): + self.compare_lists(oval, nval, f"{level} {key}", iteration) + + def compare_lists(self, old_list, new_list, level, iteration): + """Compare two lists of dictionaries.""" + old_sorted = self.sort_list_of_dicts(old_list) + new_sorted = self.sort_list_of_dicts(new_list) + for o, n in zip(old_sorted, new_sorted): + iteration += 1 + if isinstance(o, dict): + self.compare_dicts(o, n, level, iteration) + elif o != n: + print(f"{level} {iteration}: list val differ") + print(f"\t{str(o)[:self.line_limit]}") + print(f"\t{str(n)[:self.line_limit]}") + + def compare_dicts(self, old, new, level="", iteration=0): + """Compare two dictionaries.""" + all_keys = set(old.keys()) | set(new.keys()) + for key in all_keys: + if key not in old: + print(f"{level} {iteration}: old missing key: {key}") + continue + if key not in new: + print(f"{level} {iteration}: new missing key: {key}") + continue + self.compare_dict_values(key, old[key], new[key], level, iteration) diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index b7777c732..872cabc83 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -140,6 +140,7 @@ def __init__( batch_size: int = DEFAULT_BATCH_SIZE, precision_lookup: Dict[str, int] = PRECISION_LOOKUP, keep_obj_hash: bool = False, + keep_file_open: bool = False, verbose: bool = True, ): """Hashes the 
contents of an NWB file. @@ -187,8 +188,9 @@ def __init__( self.hashed = md5("".encode()) self.hash = self.compute_hash() - self.cleanup() - atexit.unregister(self.cleanup) + if not keep_file_open: + self.cleanup() + atexit.unregister(self.cleanup) def cleanup(self): self.file.close() From bfe49d1981871d47edaf556c435810ce798ab773 Mon Sep 17 00:00:00 2001 From: cbroz1 Date: Mon, 3 Mar 2025 09:43:04 -0800 Subject: [PATCH 19/21] WIP: typo --- src/spyglass/spikesorting/v1/recompute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index ec967ca94..28e161ad8 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -410,7 +410,7 @@ def make(self, key): if not new.exists(): # if the file has yet to be recomputed new_hash = self._recompute(key)["file_hash"] - if new_vals is None: # Error occurred + if new_hash is None: # Error occurred return else: new_hasher = self._hash_one(new, rounding) From 72f8a25595084338b3996ee8c35f4f14c1031e31 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 4 Mar 2025 12:53:07 -0600 Subject: [PATCH 20/21] WIP: add tests --- src/spyglass/common/common_nwbfile.py | 26 ++-- src/spyglass/spikesorting/v1/recompute.py | 139 ++++++++++++++-------- src/spyglass/spikesorting/v1/recording.py | 91 +++++++++----- src/spyglass/utils/h5_helper_fn.py | 1 + src/spyglass/utils/nwb_hash.py | 4 +- src/spyglass/utils/nwb_helper_fn.py | 2 +- tests/spikesorting/test_recording.py | 12 ++ 7 files changed, 181 insertions(+), 94 deletions(-) diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py index 4eb37b397..64a676831 100644 --- a/src/spyglass/common/common_nwbfile.py +++ b/src/spyglass/common/common_nwbfile.py @@ -381,11 +381,11 @@ def get_abs_path( schema.external["analysis"] & f"filepath LIKE '%{analysis_nwb_file_name}'" ) - if len(query) != 1: - raise FileNotFoundError( - f"Found {len(query)} files for: {analysis_nwb_file_name}" - ) - return Path(analysis_dir) / query.fetch1("filepath") + if len(query) == 1: # Else try the standard way + return Path(analysis_dir) / query.fetch1("filepath") + logger.warning( + f"Found {len(query)} files for: {analysis_nwb_file_name}" + ) # If an entry exists in the database get the stored datajoint filepath file_key = {"analysis_file_name": analysis_nwb_file_name} @@ -461,7 +461,8 @@ def get_hash( analysis_file_name: str, from_schema: bool = False, precision_lookup: dict = None, - ) -> str: + return_hasher: bool = False, + ) -> Union[str, NwbfileHasher]: """Return the hash of the file contents. Parameters @@ -471,17 +472,20 @@ def get_hash( from_schema : bool, Optional If true, get the file path from the schema externals table, skipping checksum and file existence checks. Defaults to False. - + return_hasher: bool, Optional + If true, return the hasher object instead of the hash. Defaults to + False. Returns ------- - file_hash : str - The hash of the file contents. + file_hash : [str, NwbfileHasher] + The hash of the file contents or the hasher object itself. """ - return NwbfileHasher( + hasher = NwbfileHasher( self.get_abs_path(analysis_file_name, from_schema=from_schema), precision_lookup=precision_lookup, - ).hash + ) + return hasher if return_hasher else hasher.hash def _update_external(self, analysis_file_name: str, file_hash: str): """Update the external contents checksum for an analysis file. 
diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 28e161ad8..3108a16e1 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -11,21 +11,17 @@ import atexit from functools import cached_property -from json import loads as json_loads from os import environ as os_environ from pathlib import Path from typing import Tuple, Union import datajoint as dj -import h5py import pynwb from datajoint.hash import key_hash from h5py import File as h5py_File from hdmf import __version__ as hdmf_version from hdmf.build import TypeMap -from hdmf.spec import NamespaceCatalog from numpy import __version__ as np_version -from pynwb.spec import NWBDatasetSpec, NWBGroupSpec, NWBNamespace from spikeinterface import __version__ as si_version from spyglass.common import AnalysisNwbfile @@ -83,7 +79,14 @@ def namespace_dict(self, type_map: TypeMap): def make(self, key): """Inventory the namespaces present in an analysis file.""" - parent = (SpikeSortingRecording & key).fetch1() + query = SpikeSortingRecording() & key + if not len(query) == 1: + raise ValueError( + f"SpikeSortingRecording & {key} has {len(query)} " + + f"matching entries: {query}" + ) + + parent = query.fetch1() path = AnalysisNwbfile().get_abs_path(parent["analysis_file_name"]) insert = {**key, **get_file_namespaces(path)} @@ -94,7 +97,7 @@ def make(self, key): script = str(script[()]).split("=")[1].strip().replace("'", "") insert["spyglass"] = script - self.insert1(insert) + self.insert1(insert, allow_direct_insert=True) @schema @@ -104,10 +107,15 @@ class RecordingRecomputeSelection(SpyglassMixin, dj.Manual): attempt_id: varchar(32) # name for environment used to attempt recompute rounding=8: int # rounding for float ElectricalSeries --- - dependencies: blob # dict of pip dependencies + logged_at_creation=0: bool # whether the attempt was logged at creation + pip_deps: blob # dict of pip dependencies + nwb_deps: blob # dict of pynwb dependencies """ # --- Insert helpers --- + @cached_property + def default_rounding(self): + return int(self.heading.attributes["rounding"].default) @cached_property def default_attempt_id(self): @@ -124,20 +132,34 @@ def pip_deps(self): numpy=np_version, ) - def insert(self, rows, **kwargs): + def key_pk(self, key): + """Return the current recording_id.""" + return {"recording_id": key["recording_id"]} + + def insert(self, rows, at_creation=False, **kwargs): """Custom insert to ensure dependencies are added to each row.""" + # rows = rows.copy() if not isinstance(rows, list): rows = [rows] if not isinstance(rows[0], dict): raise ValueError("Rows must be a list of dicts") + inserts = [] for row in rows: - if not self._has_matching_pynwb(row): + key_pk = self.key_pk(row) + if not RecordingVersions & key_pk: # ensure in parent table + RecordingVersions().make(key_pk) + if not self._has_matching_pynwb(key_pk): continue - if not row.get("attempt_id", None): - row["attempt_id"] = self.default_attempt_id - row["dependencies"] = self.pip_deps - inserts.append(row) + inserts.append( + dict( + **key_pk, + rounding=row.get("rounding", self.default_rounding), + attempt_id=row.get("attempt_id", self.default_attempt_id), + dependencies=self.pip_deps, + logged_at_creation=at_creation, + ) + ) super().insert(inserts, **kwargs) def attempt_all(self, attempt_id=None): @@ -151,11 +173,19 @@ def attempt_all(self, attempt_id=None): # --- Gatekeep recompute attempts --- - @cached_property # Ok to cache b/c static to python runtime + 
@cached_property def this_env(self): - """Restricted table matching pynwb env and pip env.""" + """Restricted table matching pynwb env and pip env. + + Serves as key_source for RecordingRecompute. Ensures that recompute + attempts are only made when the pynwb and pip environments match the + records. Also skips files whose environment was logged on creation. + """ + restr = [] - for key in RecordingVersions().this_env * self: + for key in RecordingVersions().this_env * ( + self & "logged_at_creation=0" + ): if key["dependencies"] != self.pip_deps: continue pk = {k: v for k, v in key.items() if k in self.primary_key} @@ -167,13 +197,13 @@ def _sort_dict(self, d): def _has_matching_pynwb(self, key, show_err=True) -> bool: """Check current env for matching pynwb versions.""" - this_rec = {"recording_id": key["recording_id"]} + key_pk = self.key_pk(key) ret = RecordingVersions().this_env & key if not ret and show_err: need = self._sort_dict(RecordingVersions().key_env(key)) have = self._sort_dict(RecordingVersions().valid_ver_restr) logger.warning( - f"PyNWB version mismatch. Skipping key: {this_rec}" + f"PyNWB version mismatch. Skipping key: {key_pk}" + f"\n\tHave: {have}" + f"\n\tNeed: {need}" ) @@ -242,7 +272,9 @@ def compare(self, key, obj_name=None): old, new = self.get_objs(key, obj_name=obj_name) return H5pyComparator(old=old, new=new) - key_source = RecordingRecomputeSelection().this_env + # TODO: debug key source issues + key_source = RecordingRecomputeSelection().this_env.proj() + # key_source = RecordingRecomputeSelection() & "logged_at_creation=0" _key_cache = {} _hasher_cache = {} _files_cache = {} @@ -333,34 +365,41 @@ def get_parent_key(self, key) -> dict: parent = ( SpikeSortingRecording * RecordingVersions - * RecordingRecomputeSelection + * RecordingRecomputeSelection.proj() & key ).fetch1() self._key_cache[hashed] = parent return parent - def _other_roundings(self, key, less_than=False): + def _other_roundings( + self, key, operator="<" + ) -> dj.expression.QueryExpression: """Return other planned precision recompute attempts. Parameters ---------- key : dict Key for the current recompute attempt. - less_than : bool - Default False. - If True, return attempts with lower precision than key. - If False, return attempts with higher precision. + operator : str, optional + Comparator for rounding field. + Default 'less than', return attempts with lower precision than key. + Also accepts '!=' or '>'. 
""" - operator = "<" if less_than else "!=" return ( RecordingRecomputeSelection() & {k: v for k, v in key.items() if k != "rounding"} & f'rounding {operator} "{key["rounding"]}"' ).proj() - self - def _has_other_roundings(self, key, less_than=False): - """Check if other planned precision recompute attempts exist.""" - return bool(self._other_roundings(key, less_than=less_than)) + def _is_lower_rounding(self, key) -> bool: + """Check for lesser precision recompute attempts after match.""" + this_key = {k: v for k, v in key.items() if k != "rounding"} + has_match = bool(self & this_key & "matched=1") + return ( + False + if not has_match # Only if match, report True of lower precision + else bool(self._other_roundings(key) & key) + ) # --- Recompute --- @@ -398,47 +437,43 @@ def _recompute(self, key) -> Union[None, dict]: else: return new_vals - def make(self, key): + def make(self, key, force_check=False): parent = self.get_parent_key(key) rounding = key.get("rounding") # Ensure dependencies unchanged since selection insert if not RecordingRecomputeSelection()._has_matching_env(key): return + # Ensure not duplicate work for lesser precision + if self._is_lower_rounding(key) and not force_check: + logger.warning( + f"Match at higher precision. Assuming match for {key}\n\t" + + "Run with force_check=True to recompute." + ) old, new = self._get_paths(parent) - if not new.exists(): # if the file has yet to be recomputed - new_hash = self._recompute(key)["file_hash"] - if new_hash is None: # Error occurred - return - else: - new_hasher = self._hash_one(new, rounding) - new_hash = new_hasher.hash + new_hasher = ( + self._hash_one(new, rounding) + if new.exists() + else self._recompute(key)["file_hash"] + ) + + if new_hasher is None: # Error occurred during recompute + return old_hasher = self._hash_one(old, rounding) - if new_hash == old_hasher.hash: + if new_hasher.hash == old_hasher.hash: self.insert1(dict(key, matched=True)) - if not self._has_other_roundings(key, less_than=False): + if not self._has_other_roundings(key, operator="!="): # if no other recompute attempts new.unlink(missing_ok=True) - elif query := self._other_roundings(key, less_than=True): - logger.info( - f"Matched at {rounding} precision: {new.name}\n" - + "Deleting lesser pecision attempts" - ) - query.delete() - return - - names = [] - hashes = [] logger.info(f"Comparing mismatched {new.name}") - new_hasher = self._hash_one(new, rounding) - all_objs = set({**old_hasher.objs, **new_hasher.objs}) - for obj in all_objs: + names, hashes = [], [] + for obj in set({**old_hasher.objs, **new_hasher.objs}): old_obj, old_hash = old_hasher.objs.get(obj, (None, None)) new_obj, new_hash = new_hasher.objs.get(obj, (None, None)) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index 7ed98ceba..b276fb54d 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -21,13 +21,12 @@ ) from spyglass.common.common_lab import LabTeam from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile -from spyglass.common.common_nwbfile import schema as nwb_schema -from spyglass.settings import analysis_dir, temp_dir, test_mode +from spyglass.settings import analysis_dir, test_mode from spyglass.spikesorting.utils import ( _get_recording_timestamps, get_group_by_shank, ) -from spyglass.utils import SpyglassMixin, SpyglassMixinPart, logger +from spyglass.utils import SpyglassMixin, logger from spyglass.utils.nwb_hash import NwbfileHasher schema = 
dj.schema("spikesorting_v1_recording") @@ -173,8 +172,8 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): --- -> AnalysisNwbfile object_id: varchar(40) # Object ID for the processed recording in NWB file - electrodes_id='': varchar(40) # Object ID for the processed electrodes - file_hash='': varchar(32) # Hash of the NWB file + electrodes_id=null: varchar(40) # Object ID for the processed electrodes + file_hash=null: varchar(32) # Hash of the NWB file dependencies=null: blob # dict of dependencies (pynwb, hdmf, spikeinterface) """ @@ -211,6 +210,7 @@ def make(self, key): AnalysisNwbfile().add(nwb_file_name, key["analysis_file_name"]) self.insert1(key) + self._record_environment(key) @classmethod def _make_file( @@ -249,10 +249,12 @@ def _make_file( file_hash = None recompute = recompute_file_name and not key and not save_to - file_path = AnalysisNwbfile.get_abs_path( - recompute_file_name, from_schema=True - ) - if recompute: + + if recompute or save_to: # if we expect file to exist + file_path = AnalysisNwbfile.get_abs_path( + recompute_file_name, from_schema=True + ) + if recompute: # If recompute, check if file exists if Path(file_path).exists(): # No need to recompute return logger.info(f"Recomputing {recompute_file_name}.") @@ -261,19 +263,8 @@ def _make_file( key, recompute_object_id, recompute_electrodes_id, file_hash = ( query.fetch1("KEY", "object_id", "electrodes_id", "file_hash") ) - elif save_to: # recompute prior to deletion - elect_attr = "acquisition/ProcessedElectricalSeries/electrodes" - if not Path(file_path).exists(): - raise FileNotFoundError(f"File {analysis_path} not found.") - - # ensure partial objects are present in existing file - with H5File(file_path, "r") as f: - elect_parts = elect_attr.split("/") - for i in range(len(elect_parts)): - mid_path = "/".join(elect_parts[: i + 1]) - if mid_path not in f.keys(): - raise KeyError(f"{mid_path} MISSING {analysis_path}") - elect_id = f[elect_attr].attrs["object_id"] + elif save_to: # recompute prior to deletion, save copy to temp_dir + elect_id = cls._validate_file(file_path) obj_id = (cls & key).fetch1("object_id") recompute_object_id, recompute_electrodes_id = obj_id, elect_id else: @@ -291,16 +282,17 @@ def _make_file( ) ) - # TODO: uncomment after review. Commented to avoid impacting database - # if recompute: - # # AnalysisNwbfile()._update_external(recompute_file_name, file_hash) - precision_lookup = dict(ProcessedElectricalSeries=rounding) file_hash = AnalysisNwbfile().get_hash( recording_nwb_file_name, from_schema=True, # REVERT TO FALSE? - precision_lookup=precision_lookup, + precision_lookup=rounding, + return_hasher=bool(save_to), ) + # NOTE: Conditional to avoid impacting database. NO MERGE! + if recompute and test_mode: + AnalysisNwbfile()._update_external(recompute_file_name, file_hash) + return dict( analysis_file_name=recording_nwb_file_name, object_id=recording_object_id, @@ -313,6 +305,49 @@ def _make_file( ), ) + @classmethod + def _validate_file(self, file_path: str) -> str: + """Validate the NWB file exists and contains required upstream data. 
+ + Parameters + ---------- + file_path : str + path to the NWB file to validate + + Returns + ------- + elect_id : str + ProcessedElectricalSeries/electrodes object ID + + Raises + ------ + FileNotFoundError + If the file does not exist + KeyError + If the file does not contain electrodes or upstream objects + """ + elect_attr = "acquisition/ProcessedElectricalSeries/electrodes" + if not Path(file_path).exists(): + raise FileNotFoundError(f"File {file_path} not found.") + + # ensure partial objects are present in existing file + with H5File(file_path, "r") as f: + elect_parts = elect_attr.split("/") + for i in range(len(elect_parts)): + mid_path = "/".join(elect_parts[: i + 1]) + if mid_path in f.keys(): + continue + raise KeyError(f"{mid_path} MISSING {file_path}") + elect_id = f[elect_attr].attrs["object_id"] + + return elect_id + + def _record_environment(self, key): + """Record environment details for this recording.""" + from spyglass.spikesorting.v1 import recompute as rcp + + rcp.RecordingRecomputeSelection().insert(key) + @classmethod def get_recording(cls, key: dict) -> si.BaseRecording: """Get recording related to this curation as spikeinterface BaseRecording @@ -553,8 +588,6 @@ def _get_preprocessed_recording(self, key: dict): return dict(recording=recording, timestamps=np.asarray(timestamps)) - return obj_id, elect_id - def update_ids(self): """Update electrodes_id, and file_hash in SpikeSortingRecording table. diff --git a/src/spyglass/utils/h5_helper_fn.py b/src/spyglass/utils/h5_helper_fn.py index e7e2653e1..9a9acbabb 100644 --- a/src/spyglass/utils/h5_helper_fn.py +++ b/src/spyglass/utils/h5_helper_fn.py @@ -1,4 +1,5 @@ """Helper methods for comparing pynwb objects.""" + from json import loads as json_loads import h5py diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index 872cabc83..7602a1f92 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -88,7 +88,7 @@ def compute_hash(self) -> str: for file_path in tqdm(all_files, disable=not self.verbose): if file_path.suffix == ".nwb": - hasher = NwbfileHasher(file_path, batch_size=batch_size) + hasher = NwbfileHasher(file_path, batch_size=self.batch_size) self.hashed.update(hasher.hash.encode()) elif file_path.suffix == ".json": self.hashed.update(self.json_encode(file_path)) @@ -177,6 +177,8 @@ def __init__( self.file = h5py.File(path, "r") atexit.register(self.cleanup) + if precision_lookup is None: + precision_lookup = PRECISION_LOOKUP if isinstance(precision_lookup, int): precision_lookup = dict(ProcessedElectricalSeries=precision_lookup) diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index 0c5b3406f..5682b477b 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -57,7 +57,7 @@ def get_nwb_file(nwb_file_path): nwbfile : pynwb.NWBFile NWB file object for the given path opened in read mode. 
""" - if not nwb_file_path.startswith("/"): + if not str(nwb_file_path).startswith("/"): from spyglass.common import Nwbfile nwb_file_path = Nwbfile.get_abs_path(nwb_file_path) diff --git a/tests/spikesorting/test_recording.py b/tests/spikesorting/test_recording.py index 89e905ee1..120997a22 100644 --- a/tests/spikesorting/test_recording.py +++ b/tests/spikesorting/test_recording.py @@ -29,3 +29,15 @@ def test_recompute(spike_v1, pop_rec, common): pre.object_id == post.object_id and pre.electrodes.object_id == post.electrodes.object_id ), "Recompute failed to preserve object_ids" + + +def test_recompute_env(spike_v1, pop_rec): + """Test recompute to temp_dir""" + from spyglass.spikesorting.v1 import recompute + + key = spike_v1.SpikeSortingRecording().fetch("KEY", as_dict=True)[0] + + recompute.RecordingRecompute().populate(key) + + ret = (recompute.RecordingRecompute() & key).fetch1("matched") + assert ret, "Recompute failed" From 9c27d87c967954dc7f6e48e6c33f179d237c96b8 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 5 Mar 2025 13:08:00 -0600 Subject: [PATCH 21/21] WIP: start add V0 hasher --- CHANGELOG.md | 13 +- docs/mkdocs.yml | 1 + docs/src/Features/Recompute.md | 55 +++++++ src/spyglass/common/common_nwbfile.py | 8 +- src/spyglass/spikesorting/v0/__init__.py | 39 +++++ .../spikesorting/v0/spikesorting_recompute.py | 148 ++++++++++++++++++ .../spikesorting/v0/spikesorting_recording.py | 83 +++++++--- src/spyglass/spikesorting/v1/__init__.py | 30 ++++ src/spyglass/spikesorting/v1/recompute.py | 43 +++-- src/spyglass/spikesorting/v1/recording.py | 32 ++-- src/spyglass/utils/nwb_hash.py | 38 +++-- 11 files changed, 415 insertions(+), 75 deletions(-) create mode 100644 docs/src/Features/Recompute.md create mode 100644 src/spyglass/spikesorting/v0/spikesorting_recompute.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c0bfdb587..14219a397 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # Change Log -## [0.5.5] (Unreleased) +## \[0.5.5\] (Unreleased) ### Release Notes @@ -10,14 +10,17 @@ ```python import datajoint as dj -from spyglass.spikesorting.v1.recording import * # noqa +from spyglass.spikesorting.v1 import recording as v1rec # noqa +from spyglass.spikesorting.v0 import spikesorting_recording as v0rec # noqa from spyglass.linearization.v1.main import * # noqa dj.FreeTable(dj.conn(), "common_nwbfile.analysis_nwbfile_log").drop() dj.FreeTable(dj.conn(), "common_session.session_group").drop() TrackGraph.alter() # Add edge map parameter -SpikeSortingRecording().alter() -SpikeSortingRecording().update_ids() +v0rec.SpikeSortingRecording().alter() +v0rec.SpikeSortingRecording().update_ids() +v1rec.SpikeSortingRecording().alter() +v1rec.SpikeSortingRecording().update_ids() ``` ### Infrastructure @@ -28,7 +31,7 @@ SpikeSortingRecording().update_ids() - Improve cron job documentation and script #1226, #1241 - Update export process to include `~external` tables #1239 - Only add merge parts to `source_class_dict` if present in codebase #1237 - - Add recompute ability for `SpikeSortingRecording` #1093 +- Add recompute ability for `SpikeSortingRecording` #1093 ### Pipelines diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 3cda1b344..127377b68 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -75,6 +75,7 @@ nav: - Merge Tables: Features/Merge.md - Export: Features/Export.md - Centralized Code: Features/Mixin.md + - Recompute: Features/Recompute.md - For Developers: - Overview: ForDevelopers/index.md - How to Contribute: ForDevelopers/Contribute.md diff --git 
a/docs/src/Features/Recompute.md b/docs/src/Features/Recompute.md
new file mode 100644
index 000000000..3aa855859
--- /dev/null
+++ b/docs/src/Features/Recompute.md
@@ -0,0 +1,55 @@
+# Recompute
+
+## Why
+
+Some analysis files generated by Spyglass are very unlikely to be accessed
+again. Files generated by `SpikeSortingRecording` tables were identified as
+taking up tens of terabytes of space while being very seldom accessed after
+their first generation. By recomputing these files on demand, we can save
+significant server space at the cost of roughly ten minutes of recompute time
+per file in the unlikely event a file is needed again.
+
+Spyglass 0.5.5 introduces the ability to delete and later recompute both newly
+generated files (after this release) and old files that were generated before
+this release.
+
+## How
+
+`SpikeSortingRecording` has a new `_make_file` method that is called when a
+file is accessed but not found on disk. This method regenerates the file and
+compares its hash to the hash of the expected file. If the hashes match, the
+file is saved and returned. If they do not match, the file is deleted and an
+error is raised. To avoid such errors, follow the steps below.
+
+### New files
+
+Newly generated files will automatically record information about their
+dependencies and the code that generated them in `RecomputeSelection` tables.
+To see the dependencies of a file, you can access `RecordingRecomputeSelection`:
+
+```python
+from spyglass.spikesorting.v1 import recompute as v1_recompute
+
+v1_recompute.RecordingRecomputeSelection()
+```
+
+### Old files
+
+To ensure that old files can be replicated prior to deletion, we'll need to:
+
+1. Update the tables for new fields.
+2. Attempt file recompute, and record dependency info for successful attempts.
+
+
+
+```python
+from spyglass.spikesorting.v0 import spikesorting_recording as v0_recording
+from spyglass.spikesorting.v1 import recording as v1_recording
+
+# Alter tables to include new fields, updating values
+v0_recording.SpikeSortingRecording().alter()
+v0_recording.SpikeSortingRecording().update_ids()
+v1_recording.SpikeSortingRecording().alter()
+v1_recording.SpikeSortingRecording().update_ids()
+```
diff --git a/src/spyglass/common/common_nwbfile.py b/src/spyglass/common/common_nwbfile.py
index 64a676831..ab4dff9b1 100644
--- a/src/spyglass/common/common_nwbfile.py
+++ b/src/spyglass/common/common_nwbfile.py
@@ -478,7 +478,7 @@ def get_hash(
 
         Returns
         -------
-        file_hash : [str, NwbfileHasher]
+        hash : [str, NwbfileHasher]
             The hash of the file contents or the hasher object itself.
         """
         hasher = NwbfileHasher(
@@ -487,7 +487,7 @@ def get_hash(
         )
         return hasher if return_hasher else hasher.hash
 
-    def _update_external(self, analysis_file_name: str, file_hash: str):
+    def _update_external(self, analysis_file_name: str, hash: str):
         """Update the external contents checksum for an analysis file.
 
         USE WITH CAUTION. If the hash does not match the file contents, the file
@@ -497,7 +497,7 @@ def _update_external(self, analysis_file_name: str, hash: str):
         ----------
         analysis_file_name : str
             The name of the analysis NWB file.
-        file_hash : str
+        hash : str
             The hash of the file contents as calculated by NwbfileHasher. If
             the hash does not match the file contents, the file and downstream
             entries are deleted.
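The `return_hasher` pathway in the hunks above is easiest to see in use. Below
is a minimal sketch, assuming the patched `AnalysisNwbfile.get_hash` signature
shown here; the file name is a hypothetical placeholder, not from the diff.

```python
from spyglass.common import AnalysisNwbfile

# Fetch the hasher object itself rather than just the digest string
hasher = AnalysisNwbfile().get_hash(
    "example20250101_ABC123.nwb",  # hypothetical analysis file name
    from_schema=True,  # resolve the path via the schema externals table
    precision_lookup=4,  # round ProcessedElectricalSeries to 4 decimals
    return_hasher=True,
)
print(hasher.hash)  # md5 digest of the (rounded) file contents
```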
@@ -505,7 +505,7 @@ def _update_external(self, analysis_file_name: str, file_hash: str): file_path = self.get_abs_path(analysis_file_name, from_schema=True) new_hash = self.get_hash(analysis_file_name, from_schema=True) - if file_hash != new_hash: + if hash != new_hash: Path(file_path).unlink() # remove mismatched file # force delete, including all downstream, forcing permissions del_kwargs = dict(force_permission=True, safemode=False) diff --git a/src/spyglass/spikesorting/v0/__init__.py b/src/spyglass/spikesorting/v0/__init__.py index 8b6035023..dbf857d3a 100644 --- a/src/spyglass/spikesorting/v0/__init__.py +++ b/src/spyglass/spikesorting/v0/__init__.py @@ -30,6 +30,10 @@ SpikeSortingPipelineParameters, spikesorting_pipeline_populator, ) +from spyglass.spikesorting.v0.spikesorting_recompute import ( # noqa: F401 + RecordingRecompute, + RecordingRecomputeSelection, +) from spyglass.spikesorting.v0.spikesorting_recording import ( # noqa: F401 SortGroup, SortInterval, @@ -42,3 +46,38 @@ SpikeSorting, SpikeSortingSelection, ) + +__all__ = [ + "ArtifactDetection", + "ArtifactDetectionParameters", + "ArtifactDetectionSelection", + "ArtifactRemovedIntervalList", + "AutomaticCuration", + "AutomaticCurationParameters", + "AutomaticCurationSelection", + "CuratedSpikeSorting", + "CuratedSpikeSortingSelection", + "Curation", + "CurationFigurl", + "CurationFigurlSelection", + "MetricParameters", + "MetricSelection", + "QualityMetrics", + "RecordingRecompute", + "RecordingRecomputeSelection", + "SortGroup", + "SortInterval", + "SortingviewWorkspace", + "SortingviewWorkspaceSelection", + "SpikeSorterParameters", + "SpikeSorting", + "SpikeSortingPipelineParameters", + "SpikeSortingPreprocessingParameters", + "SpikeSortingRecording", + "SpikeSortingRecordingSelection", + "SpikeSortingSelection", + "WaveformParameters", + "WaveformSelection", + "Waveforms", + "spikesorting_pipeline_populator", +] diff --git a/src/spyglass/spikesorting/v0/spikesorting_recompute.py b/src/spyglass/spikesorting/v0/spikesorting_recompute.py new file mode 100644 index 000000000..bda5eb546 --- /dev/null +++ b/src/spyglass/spikesorting/v0/spikesorting_recompute.py @@ -0,0 +1,148 @@ +"""This schema is used to track recompute capabilities for existing files.""" + +from functools import cached_property +from os import environ as os_environ + +import datajoint as dj +from numpy import __version__ as np_version +from probeinterface import __version__ as pi_version +from spikeinterface import __version__ as si_version + +from spyglass.spikesorting.v0.spikesorting_recording import ( + SpikeSortingRecording, +) # noqa F401 +from spyglass.utils import logger +from spyglass.utils.nwb_hash import DirectoryHasher + +schema = dj.schema("cbroz_temp_v0") + + +@schema +class RecordingRecomputeSelection(dj.Manual): + definition = """ + -> SpikeSortingRecording + --- + logged_at_creation=0: bool + pip_deps: blob # dict of pip dependencies + """ + + @cached_property + def default_attempt_id(self): + user = dj.config["database.user"] + conda = os_environ.get("CONDA_DEFAULT_ENV", "base") + return f"{user}_{conda}" + + @cached_property + def pip_deps(self): + return dict( + spikeinterface=si_version, + probeinterface=pi_version, + numpy=np_version, + ) + + def key_pk(self, key): + return {k: key[k] for k in self.primary_key} + + def insert(self, rows, at_creation=False, **kwargs): + """Custom insert to ensure dependencies are added to each row.""" + if not isinstance(rows, list): + rows = [rows] + if not isinstance(rows[0], dict): + raise 
ValueError("Rows must be a list of dicts")
+
+        inserts = []
+        for row in rows:
+            key_pk = self.key_pk(row)
+            inserts.append(
+                dict(
+                    **key_pk,
+                    attempt_id=row.get("attempt_id", self.default_attempt_id),
+                    dependencies=self.pip_deps,
+                    logged_at_creation=at_creation,
+                )
+            )
+        super().insert(inserts, **kwargs)
+
+    # --- Gatekeep recompute attempts ---
+
+    @cached_property
+    def this_env(self):
+        """Restricted table matching pynwb env and pip env.
+
+        Serves as key_source for RecordingRecompute. Ensures that recompute
+        attempts are only made when the pynwb and pip environments match the
+        records. Also skips files whose environment was logged on creation.
+        """
+
+        restr = []
+        for key in self & "logged_at_creation=0":
+            if key["dependencies"] != self.pip_deps:
+                continue
+            restr.append(self.key_pk(key))
+        return self & restr
+
+    def _has_matching_pip(self, key, show_err=True) -> bool:
+        """Check current env for matching pip versions."""
+        query = self.this_env & key
+
+        if not len(query) == 1:
+            raise ValueError(f"Query returned {len(query)} entries: {query}")
+
+        need = query.fetch1("dependencies")
+        ret = need == self.pip_deps
+
+        if not ret and show_err:
+            logger.error(
+                f"Pip version mismatch. Skipping key: {self.key_pk(key)}"
+                + f"\n\tHave: {self.pip_deps}"
+                + f"\n\tNeed: {need}"
+            )
+
+        return ret
+
+
+@schema
+class RecordingRecompute(dj.Computed):
+    definition = """
+    -> RecordingRecomputeSelection
+    ---
+    matched:bool
+    """
+
+    _hasher_cache = dict()
+
+    class Name(dj.Part):
+        definition = """ # File names missing from old or new versions
+        -> master
+        name: varchar(255)
+        missing_from: enum('old', 'new')
+        """
+
+    class Hash(dj.Part):
+        definition = """ # File hashes that differ between old and new versions
+        -> master
+        name : varchar(255)
+        """
+
+    def _parent_key(self, key):
+        ret = SpikeSortingRecording * RecordingRecomputeSelection & key
+        if len(ret) != 1:
+            raise ValueError(f"Query returned {len(ret)} entries: {ret}")
+        return ret.fetch(as_dict=True)[0]
+
+    def _hash_one(self, key):
+        key_hash = dj.hash.key_hash(key)
+        if key_hash in self._hasher_cache:
+            return self._hasher_cache[key_hash]
+        hasher = DirectoryHasher(
+            directory_path=self._parent_key(key)["recording_path"],
+            keep_file_hash=True,
+        )
+        self._hasher_cache[key_hash] = hasher
+        return hasher
+
+    def make(self, key):
+        pass
+
+    def delete_file(self, key):
+        pass  # TODO: Add means of deleting replicated files
diff --git a/src/spyglass/spikesorting/v0/spikesorting_recording.py b/src/spyglass/spikesorting/v0/spikesorting_recording.py
index a4382bf32..9711c5b1a 100644
--- a/src/spyglass/spikesorting/v0/spikesorting_recording.py
+++ b/src/spyglass/spikesorting/v0/spikesorting_recording.py
@@ -1,4 +1,3 @@
-import os
 import shutil
 from functools import reduce
 from pathlib import Path
@@ -8,6 +7,7 @@
 import probeinterface as pi
 import spikeinterface as si
 import spikeinterface.extractors as se
+from tqdm import tqdm
 
 from spyglass.common.common_device import Probe, ProbeType  # noqa: F401
 from spyglass.common.common_ephys import Electrode, ElectrodeGroup
@@ -27,6 +27,7 @@
 )
 from spyglass.utils import SpyglassMixin
 from spyglass.utils.dj_helper_fn import dj_replace
+from spyglass.utils.nwb_hash import DirectoryHasher
 
 schema = dj.schema("spikesorting_recording")
 
@@ -291,16 +292,16 @@ class SpikeSortingRecordingSelection(SpyglassMixin, dj.Manual):
 
 @schema
 class SpikeSortingRecording(SpyglassMixin, dj.Computed):
-    use_transaction, _allow_insert = False, True
-
     definition = """
     -> SpikeSortingRecordingSelection
     ---
recording_path: varchar(1000)
     -> IntervalList.proj(sort_interval_list_name='interval_list_name')
+    hash=null: char(32)  # hash of the directory
     """
 
     _parallel_make = True
+    use_transaction, _allow_insert = False, True
 
     def make(self, key):
         """Populates the SpikeSortingRecording table with the recording data.
@@ -313,37 +314,71 @@ def make(self, key):
         2. Saves the recording data to the recording directory
         3. Inserts the path to the recording data into SpikeSortingRecording
         """
-        sort_interval_valid_times = self._get_sort_interval_valid_times(key)
-        recording = self._get_filtered_recording(key)
-        recording_name = self._get_recording_name(key)
-
-        # Path to files that will hold the recording extractors
-        recording_path = str(recording_dir / Path(recording_name))
-        if os.path.exists(recording_path):
-            shutil.rmtree(recording_path)
-
-        recording.save(
-            folder=recording_path, chunk_duration="10000ms", n_jobs=8
-        )
+        rec_info = self._make_file(key)
 
         IntervalList.insert1(
             {
                 "nwb_file_name": key["nwb_file_name"],
-                "interval_list_name": recording_name,
-                "valid_times": sort_interval_valid_times,
+                "interval_list_name": rec_info["name"],
+                "valid_times": self._get_sort_interval_valid_times(key),
                 "pipeline": "spikesorting_recording_v0",
             },
             replace=True,
         )
 
-        self.insert1(
-            {
-                **key,
-                # store the list of valid times for the sort
-                "sort_interval_list_name": recording_name,
-                "recording_path": recording_path,
-            }
+        self_insert = dict(
+            key,
+            sort_interval_list_name=rec_info["name"],
+            recording_path=rec_info["path"],
         )
+        self.insert1(self_insert)
+
+        from spyglass.spikesorting.v0.spikesorting_recompute import (
+            RecordingRecomputeSelection,
+        )
+
+        RecordingRecomputeSelection.insert(self_insert, at_creation=True)
+
+    def _make_file(self, key):
+        """Run only operations required to save the recording data to disk."""
+        has_entry = bool(self & key)  # table entry exists, so recompute files
+        ret = {
+            "name": self._get_recording_name(key),
+            "path": str(
+                Path(recording_dir) / Path(self._get_recording_name(key))
+            ),
+        }
+
+        # Path to files that will hold the recording extractors
+        recording_name = self._get_recording_name(key)
+        recording_path = str(recording_dir / Path(recording_name))
+        if Path(ret["path"]).exists():
+            if has_entry:  # if table entry for existing file, use it
+                return {**ret, "hash": self._dir_hash(ret["path"])}
+            else:  # if no table entry, assume existing is outdated and delete
+                shutil.rmtree(recording_path)
+
+        recording = self._get_filtered_recording(key)
+        recording.save(folder=ret["path"], chunk_duration="10000ms", n_jobs=8)
+
+        return {**ret, "hash": self._dir_hash(recording_path)}
+
+    def _dir_hash(self, path):
+        """Return the hash of the directory."""
+        path = Path(path)
+        if not path.exists():
+            raise FileNotFoundError(f"Directory does not exist: {path}")
+        if not path.is_dir():
+            raise NotADirectoryError(f"Path is not a directory: {path}")
+        return DirectoryHasher(directory_path=path).hash
+
+    def update_ids(self):
+        """Update file hashes for all entries in the table.
+ + Only used for transitioning to recompute NWB files, see #1093.""" + for key in tqdm(self & 'hash=""', desc="Updating hashes"): + key["hash"] = self._dir_hash(key["recording_path"]) + self.update1(key) @staticmethod def _get_recording_name(key): diff --git a/src/spyglass/spikesorting/v1/__init__.py b/src/spyglass/spikesorting/v1/__init__.py index a2f3ca592..9079ac813 100644 --- a/src/spyglass/spikesorting/v1/__init__.py +++ b/src/spyglass/spikesorting/v1/__init__.py @@ -16,6 +16,11 @@ MetricParameters, WaveformParameters, ) +from spyglass.spikesorting.v1.recompute import ( + RecordingRecompute, + RecordingRecomputeSelection, + RecordingVersions, +) from spyglass.spikesorting.v1.recording import ( SortGroup, SpikeSortingPreprocessingParameters, @@ -27,3 +32,28 @@ SpikeSorting, SpikeSortingSelection, ) + +__all__ = [ + "ArtifactDetection", + "ArtifactDetectionParameters", + "ArtifactDetectionSelection", + "CurationV1", + "FigURLCuration", + "FigURLCurationSelection", + "ImportedSpikeSorting", + "MetricCuration", + "MetricCurationParameters", + "MetricCurationSelection", + "MetricParameters", + "RecordingRecompute", + "RecordingRecomputeSelection", + "RecordingVersions", + "SortGroup", + "SpikeSorterParameters", + "SpikeSorting", + "SpikeSortingPreprocessingParameters", + "SpikeSortingRecording", + "SpikeSortingRecordingSelection", + "SpikeSortingSelection", + "WaveformParameters", +] diff --git a/src/spyglass/spikesorting/v1/recompute.py b/src/spyglass/spikesorting/v1/recompute.py index 3108a16e1..30f94121d 100644 --- a/src/spyglass/spikesorting/v1/recompute.py +++ b/src/spyglass/spikesorting/v1/recompute.py @@ -4,9 +4,15 @@ ------ RecordingVersions: What versions are present in an existing analysis file? Allows restrict of recompute attempts to pynwb environments that are - compatible with a pre-existing file. For pip dependencies, see - SpikeSortingRecording.dependencies field -RecordingRecompute: Attempt recompute of an analysis file. + compatible with a pre-existing file. +RecordingRecomputeSelection: Plan a recompute attempt. Capture a list of + pip dependencies under an attempt label, 'attempt_id', and set the desired + level of precision for the recompute (i.e., rounding for ElectricalSeries + data). +RecordingRecompute: Attempt to recompute an analysis file, saving a new file + to a temporary directory. If the new file matches the old, the new file is + deleted. If the new file does not match, the differences are logged in + the Hash table. 
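+
+Example
+-------
+A minimal sketch of the intended workflow (assuming selection entries exist
+for the current environment; calls here are illustrative, not prescriptive)::
+
+    from spyglass.spikesorting.v1 import recompute as v1_recompute
+
+    v1_recompute.RecordingRecomputeSelection().attempt_all()  # plan attempts
+    v1_recompute.RecordingRecompute().populate()  # attempt recompute
+    # mismatches are logged in the Name and Hash part tables
+    (v1_recompute.RecordingRecompute() & "matched=0").fetch("KEY")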
""" import atexit @@ -119,9 +125,9 @@ def default_rounding(self): @cached_property def default_attempt_id(self): + user = dj.config["database.user"] conda = os_environ.get("CONDA_DEFAULT_ENV", "base") - si_readable = si_version.replace(".", "-") - return f"{conda}_si{si_readable}" + return f"{user}_{conda}" @cached_property def pip_deps(self): @@ -138,7 +144,6 @@ def key_pk(self, key): def insert(self, rows, at_creation=False, **kwargs): """Custom insert to ensure dependencies are added to each row.""" - # rows = rows.copy() if not isinstance(rows, list): rows = [rows] if not isinstance(rows[0], dict): @@ -163,10 +168,12 @@ def insert(self, rows, at_creation=False, **kwargs): super().insert(inserts, **kwargs) def attempt_all(self, attempt_id=None): - if not attempt_id: - attempt_id = self.default_attempt_id inserts = [ - {**key, "attempt_id": attempt_id, "dependencies": self.pip_deps} + { + **key, + "attempt_id": attempt_id or self.default_attempt_id, + "dependencies": self.pip_deps, + } for key in RecordingVersions().this_env.fetch("KEY", as_dict=True) ] self.insert(inserts, skip_duplicates=True) @@ -245,14 +252,14 @@ class RecordingRecompute(dj.Computed): """ class Name(dj.Part): - definition = """ + definition = """ # Object names missing from old or new versions -> master name : varchar(255) missing_from: enum('old', 'new') """ class Hash(dj.Part): - definition = """ + definition = """ # Object hashes that differ between old and new -> master name : varchar(255) --- @@ -272,12 +279,10 @@ def compare(self, key, obj_name=None): old, new = self.get_objs(key, obj_name=obj_name) return H5pyComparator(old=old, new=new) - # TODO: debug key source issues key_source = RecordingRecomputeSelection().this_env.proj() - # key_source = RecordingRecomputeSelection() & "logged_at_creation=0" - _key_cache = {} - _hasher_cache = {} - _files_cache = {} + _key_cache = dict() + _hasher_cache = dict() + _files_cache = dict() _cleanup_registered = False @property @@ -456,7 +461,7 @@ def make(self, key, force_check=False): new_hasher = ( self._hash_one(new, rounding) if new.exists() - else self._recompute(key)["file_hash"] + else self._recompute(key)["hash"] ) if new_hasher is None: # Error occurred during recompute @@ -487,3 +492,7 @@ def make(self, key, force_check=False): self.insert1(dict(key, matched=False)) self.Name().insert(names) self.Hash().insert(hashes) + + def delete_files(self, key): + """If successfully recomputed, delete files for a given restriction.""" + pass # TODO: add means of deleting replicated files diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index b276fb54d..4d7ebe0ab 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -11,6 +11,7 @@ import spikeinterface.extractors as se from h5py import File as H5File from hdmf.data_utils import GenericDataChunkIterator +from tqdm import tqdm from spyglass.common import Session # noqa: F401 from spyglass.common.common_device import Probe @@ -173,8 +174,7 @@ class SpikeSortingRecording(SpyglassMixin, dj.Computed): -> AnalysisNwbfile object_id: varchar(40) # Object ID for the processed recording in NWB file electrodes_id=null: varchar(40) # Object ID for the processed electrodes - file_hash=null: varchar(32) # Hash of the NWB file - dependencies=null: blob # dict of dependencies (pynwb, hdmf, spikeinterface) + hash=null: varchar(32) # Hash of the NWB file """ def make(self, key): @@ -247,7 +247,7 @@ def _make_file( if isinstance(key, dict): key 
= {k: v for k, v in key.items() if k in cls.primary_key} - file_hash = None + hash = None recompute = recompute_file_name and not key and not save_to if recompute or save_to: # if we expect file to exist @@ -260,8 +260,8 @@ def _make_file( logger.info(f"Recomputing {recompute_file_name}.") query = cls & {"analysis_file_name": recompute_file_name} # Use deleted file's ids and hash for recompute - key, recompute_object_id, recompute_electrodes_id, file_hash = ( - query.fetch1("KEY", "object_id", "electrodes_id", "file_hash") + key, recompute_object_id, recompute_electrodes_id, hash = ( + query.fetch1("KEY", "object_id", "electrodes_id", "hash") ) elif save_to: # recompute prior to deletion, save copy to temp_dir elect_id = cls._validate_file(file_path) @@ -282,7 +282,7 @@ def _make_file( ) ) - file_hash = AnalysisNwbfile().get_hash( + hash = AnalysisNwbfile().get_hash( recording_nwb_file_name, from_schema=True, # REVERT TO FALSE? precision_lookup=rounding, @@ -291,13 +291,13 @@ def _make_file( # NOTE: Conditional to avoid impacting database. NO MERGE! if recompute and test_mode: - AnalysisNwbfile()._update_external(recompute_file_name, file_hash) + AnalysisNwbfile()._update_external(recompute_file_name, hash) return dict( analysis_file_name=recording_nwb_file_name, object_id=recording_object_id, electrodes_id=electrodes_id, - file_hash=file_hash, + hash=hash, dependencies=dict( pynwb=pynwb.__version__, hdmf=hdmf.__version__, @@ -589,23 +589,27 @@ def _get_preprocessed_recording(self, key: dict): return dict(recording=recording, timestamps=np.asarray(timestamps)) def update_ids(self): - """Update electrodes_id, and file_hash in SpikeSortingRecording table. + """Update electrodes_id, and hash in SpikeSortingRecording table. Only used for transitioning to recompute NWB files, see #1093. """ elect_attr = "acquisition/ProcessedElectricalSeries/electrodes" - needs_update = self & ["electrodes_id=''", "file_hash=''"] - for key in needs_update.fetch(as_dict=True): + needs_update = self & ["electrodes_id=''", "hash=''"] + + for key in tqdm(needs_update): analysis_file_path = AnalysisNwbfile.get_abs_path( key["analysis_file_name"] ) with H5File(analysis_file_path, "r") as f: elect_id = f[elect_attr].attrs["object_id"] - key["electrodes_id"] = elect_id - key["file_hash"] = NwbfileHasher(analysis_file_path).hash + updated = dict( + key, + electrodes_id=elect_id, + hash=NwbfileHasher(analysis_file_path).hash, + ) - self.update1(key) + self.update1(updated) def recompute(self, key: dict): """Recompute the processed recording. diff --git a/src/spyglass/utils/nwb_hash.py b/src/spyglass/utils/nwb_hash.py index 7602a1f92..cf80f1b0d 100644 --- a/src/spyglass/utils/nwb_hash.py +++ b/src/spyglass/utils/nwb_hash.py @@ -54,6 +54,7 @@ def __init__( self, directory_path: Union[str, Path], batch_size: int = DEFAULT_BATCH_SIZE, + keep_file_hash: bool = False, verbose: bool = False, ): """Generate a hash of the contents of a directory, recursively. @@ -74,11 +75,21 @@ def __init__( Path to the directory to hash. batch_size : int, optional Limit of data to hash for large files, by default 4095. + keep_file_hash : bool, optional + Default false. If true, keep cache the hash of each file. + verbose : bool, optional + Display progress bar, by default False. 
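+
+        Example
+        -------
+        A rough sketch; the directory path below is hypothetical::
+
+            hasher = DirectoryHasher(
+                "/path/to/recording_dir", keep_file_hash=True
+            )
+            hasher.hash        # hash of the full directory contents
+            hasher.hash_cache  # per-file hashes, keyed by relative path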
""" - self.dir_path = Path(directory_path) - self.batch_size = batch_size - self.verbose = verbose + if not self.dir_path.exists(): + raise FileNotFoundError(f"Dir does not exist: {self.dir_path}") + if not self.dir_path.is_dir(): + raise NotADirectoryError(f"Path is not a dir: {self.dir_path}") + + self.batch_size = int(batch_size) + self.keep_file_hash = bool(keep_file_hash) + self.hash_cache = {} + self.verbose = bool(verbose) self.hashed = md5("".encode()) self.hash = self.compute_hash() @@ -88,32 +99,37 @@ def compute_hash(self) -> str: for file_path in tqdm(all_files, disable=not self.verbose): if file_path.suffix == ".nwb": - hasher = NwbfileHasher(file_path, batch_size=self.batch_size) - self.hashed.update(hasher.hash.encode()) + this_hash = NwbfileHasher( + file_path, batch_size=self.batch_size + ).hash.encode() elif file_path.suffix == ".json": - self.hashed.update(self.json_encode(file_path)) + this_hash = self.json_encode(file_path) else: - self.chunk_encode(file_path) + this_hash = self.chunk_encode(file_path) + + self.hashed.update(this_hash) # update with the rel path to for same file in diff dirs rel_path = str(file_path.relative_to(self.dir_path)) self.hashed.update(rel_path.encode()) - if self.verbose: - print(f"{file_path.name}: {self.hased.hexdigest()}") + if self.keep_file_hash: + self.hash_cache[rel_path] = this_hash return self.hashed.hexdigest() # Return the hex digest of the hash def chunk_encode(self, file_path: Path) -> str: """Encode the contents of a file in chunks for hashing.""" + this_hash = md5("".encode()) with file_path.open("rb") as f: while chunk := f.read(self.batch_size): - self.hashed.update(chunk) + this_hash.update(chunk) + return this_hash.hexdigest() def json_encode(self, file_path: Path) -> str: """Encode the contents of a json file for hashing. - Ignores the 'version' key(s) in the json file. + Ignores the predetermined keys in the IGNORED_KEYS list. """ with file_path.open("r") as f: file_data = json.load(f, object_hook=self.pop_version)