diff --git a/CHANGELOG.md b/CHANGELOG.md
index dc991e655..c6a9fa232 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,27 +4,38 @@

 ### Infrastructure

-- Additional documentation. #690
-- Clean up following pre-commit checks. #688
-- Add Mixin class to centralize `fetch_nwb` functionality. #692, #734
-- Refactor restriction use in `delete_downstream_merge` #703
-- Add `cautious_delete` to Mixin class
-  - Initial implementation. #711, #762
-  - More robust caching of join to downstream tables. #806
-  - Overwrite datajoint `delete` method to use `cautious_delete`. #806
-  - Reverse join order for session summary. #821
-  - Add temporary logging of use to `common_usage`. #811, #821
-- Add `deprecation_factory` to facilitate table migration. #717
-- Add Spyglass logger. #730
-- IntervalList: Add secondary key `pipeline` #742
-- Increase pytest coverage for `common`, `lfp`, and `utils`. #743
-- Update docs to reflect new notebooks. #776
-- Add overview of Spyglass to docs. #779
-- Update linting for Black 24. #808
-- Steamline dependency management. #822
+- Docs:
+  - Additional documentation. #690
+  - Add overview of Spyglass to docs. #779
+  - Update docs to reflect new notebooks. #776
+- Mixin:
+  - Add Mixin class to centralize `fetch_nwb` functionality. #692, #734
+  - Refactor restriction use in `delete_downstream_merge` #703
+  - Add `cautious_delete` to Mixin class
+    - Initial implementation. #711, #762
+    - More robust caching of join to downstream tables. #806
+    - Overwrite datajoint `delete` method to use `cautious_delete`. #806
+    - Reverse join order for session summary. #821
+    - Add temporary logging of use to `common_usage`. #811, #821
+- Merge Tables:
+  - UUIDs: Revise Merge table uuid generation to include source. #824
+  - UUIDs: Remove mutual exclusivity logic due to new UUID generation. #824
+  - Add method for `merge_populate`. #824
+- Linting:
+  - Clean up following pre-commit checks. #688
+  - Update linting for Black 24. #808
+- Misc:
+  - Add `deprecation_factory` to facilitate table migration. #717
+  - Add Spyglass logger. #730
+  - Increase pytest coverage for `common`, `lfp`, and `utils`. #743
+  - Streamline dependency management. #822

 ### Pipelines

+- Common:
+  - `IntervalList`: Add secondary key `pipeline` #742
+  - Add `common_usage` table. #811, #821, #824
+  - Catch errors during `populate_all_common`. #824
 - Spike sorting:
   - Add SpikeSorting V1 pipeline. #651
   - Move modules into spikesorting.v0 #807
diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py
index d7d4759fb..f2a2226d8 100644
--- a/src/spyglass/common/common_behav.py
+++ b/src/spyglass/common/common_behav.py
@@ -42,6 +42,25 @@ class SpatialSeries(SpyglassMixin, dj.Part):
         name=null: varchar(32)  # name of spatial series
         """

+    def populate(self, keys=None):
+        """Insert position source data from NWB file.
+
+        WARNING: unlike other DataJoint tables, `populate` on this Manual
+        table is not protected by a transaction.
+        """
+        if not isinstance(keys, list):
+            keys = [keys]
+        if isinstance(keys[0], dj.Table):
+            keys = [k for tbl in keys for k in tbl.fetch("KEY", as_dict=True)]
+        for key in keys:
+            nwb_file_name = (key or {}).get("nwb_file_name")
+            if not nwb_file_name:
+                raise ValueError(
+                    "PositionSource.populate is an alias for a non-computed table "
+                    + "and must be passed a key with nwb_file_name"
+                )
+            self.insert_from_nwbfile(nwb_file_name)
+
     @classmethod
     def insert_from_nwbfile(cls, nwb_file_name):
         """Add intervals to ItervalList and PositionSource.
@@ -482,6 +501,7 @@ def _no_transaction_make(self, key):

         # Skip populating if no pos interval list names
         if len(pos_intervals) == 0:
+            # TODO: Now that populate_all accepts errors, raise here?
             logger.error(f"NO POS INTERVALS FOR {key}; {no_pop_msg}")
             return

@@ -519,6 +539,7 @@ def _no_transaction_make(self, key):

         # Check that each pos interval was matched to only one epoch
         if len(matching_pos_intervals) != 1:
+            # TODO: Now that populate_all accepts errors, raise here?
             logger.error(
                 f"Found {len(matching_pos_intervals)} pos intervals for {key}; "
                 + f"{no_pop_msg}\n{matching_pos_intervals}"
diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py
index ccd58091d..fdf7ae99d 100644
--- a/src/spyglass/common/common_usage.py
+++ b/src/spyglass/common/common_usage.py
@@ -1,6 +1,7 @@
 """A schema to store the usage of advanced Spyglass features.

-Records show usage of features such as table chains, which will be used to
+Records show usage of features such as cautious delete and fault-permitting
+insert, which will be used to
 determine which features are used, how often, and by whom. This will help
 plan future development of Spyglass.
 """
@@ -21,3 +22,18 @@ class CautiousDelete(dj.Manual):
     restriction: varchar(255)
     merge_deletes = null: blob
     """
+
+
+@schema
+class InsertError(dj.Manual):
+    definition = """
+    id: int auto_increment
+    ---
+    dj_user: varchar(64)
+    connection_id: int  # MySQL CONNECTION_ID()
+    nwb_file_name: varchar(64)
+    table: varchar(64)
+    error_type: varchar(64)
+    error_message: varchar(255)
+    error_raw = null: blob
+    """
diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py
index b2fa7d760..2972ed145 100644
--- a/src/spyglass/common/populate_all_common.py
+++ b/src/spyglass/common/populate_all_common.py
@@ -1,3 +1,5 @@
+import datajoint as dj
+
 from spyglass.common.common_behav import (
     PositionSource,
     RawPosition,
@@ -14,51 +16,57 @@
 from spyglass.common.common_nwbfile import Nwbfile
 from spyglass.common.common_session import Session
 from spyglass.common.common_task import TaskEpoch
+from spyglass.common.common_usage import InsertError
 from spyglass.utils import logger


 def populate_all_common(nwb_file_name):
-    # Insert session one by one
-    fp = [(Nwbfile & {"nwb_file_name": nwb_file_name}).proj()]
-    logger.info("Populate Session...")
-    Session.populate(fp)
-
-    # If we use Kachery for data sharing we can uncomment the following two lines. TBD
-    # logger.info('Populate NwbfileKachery...')
-    # NwbfileKachery.populate()
-
-    logger.info("Populate ElectrodeGroup...")
-    ElectrodeGroup.populate(fp)
-
-    logger.info("Populate Electrode...")
-    Electrode.populate(fp)
-
-    logger.info("Populate Raw...")
-    Raw.populate(fp)
-
-    logger.info("Populate SampleCount...")
-    SampleCount.populate(fp)
-
-    logger.info("Populate DIOEvents...")
-    DIOEvents.populate(fp)
-
-    # sensor data (from analog ProcessingModule) is temporarily removed from NWBFile
-    # to reduce file size while it is not being used. add it back in by commenting out
-    # the removal code in spyglass/data_import/insert_sessions.py when ready
-    # logger.info('Populate SensorData')
-    # SensorData.populate(fp)
-
-    logger.info("Populate TaskEpochs")
-    TaskEpoch.populate(fp)
-    logger.info("Populate StateScriptFile")
-    StateScriptFile.populate(fp)
-    logger.info("Populate VideoFile")
-    VideoFile.populate(fp)
-    logger.info("RawPosition...")
-    PositionSource.insert_from_nwbfile(nwb_file_name)
-    RawPosition.populate(fp)
-
-    logger.info("Populate ImportedSpikeSorting...")
+    """Insert all common tables for a given NWB file."""
     from spyglass.spikesorting.imported import ImportedSpikeSorting

-    ImportedSpikeSorting.populate(fp)
+    key = [(Nwbfile & f"nwb_file_name LIKE '{nwb_file_name}'").proj()]
+    tables = [
+        Session,
+        # NwbfileKachery, # Not used by default
+        ElectrodeGroup,
+        Electrode,
+        Raw,
+        SampleCount,
+        DIOEvents,
+        # SensorData, # Not used by default. Generates large files
+        TaskEpoch,
+        StateScriptFile,
+        VideoFile,
+        PositionSource,
+        RawPosition,
+        ImportedSpikeSorting,
+    ]
+    error_constants = dict(
+        dj_user=dj.config["database.user"],
+        connection_id=dj.conn().connection_id,
+        nwb_file_name=nwb_file_name,
+    )
+
+    for table in tables:
+        logger.info(f"Populating {table.__name__}...")
+        try:
+            table.populate(key)
+        except Exception as e:
+            InsertError.insert1(
+                dict(
+                    **error_constants,
+                    table=table.__name__,
+                    error_type=type(e).__name__,
+                    error_message=str(e)[:255],
+                    error_raw=str(e),
+                )
+            )
+    query = InsertError & error_constants
+    if query:
+        err_tables = query.fetch("table")
+        logger.error(
+            f"Errors occurred during population for {nwb_file_name}:\n\t"
+            + f"Failed tables {err_tables}\n\t"
+            + "See common_usage.InsertError for more details"
+        )
+        return query.fetch("KEY")
diff --git a/src/spyglass/linearization/__init__.py b/src/spyglass/linearization/__init__.py
index 681df507c..e6a9504ea 100644
--- a/src/spyglass/linearization/__init__.py
+++ b/src/spyglass/linearization/__init__.py
@@ -1 +1,3 @@
-from spyglass.linearization.merge import LinearizedPositionOutput
+# CB: Circular import if only importing PositionOutput
+
+# from spyglass.linearization.merge import LinearizedPositionOutput
diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py
index f15183dfe..14c8c20db 100644
--- a/src/spyglass/utils/dj_merge_tables.py
+++ b/src/spyglass/utils/dj_merge_tables.py
@@ -3,13 +3,14 @@
 from inspect import getmodule
 from itertools import chain as iter_chain
 from pprint import pprint
+from time import time
 from typing import Union

 import datajoint as dj
 from datajoint.condition import make_condition
 from datajoint.errors import DataJointError
 from datajoint.preview import repr_html
-from datajoint.utils import from_camel_case, get_master, to_camel_case
+from datajoint.utils import from_camel_case, to_camel_case
 from IPython.core.display import HTML

 from spyglass.utils.logging import logger
@@ -248,6 +249,9 @@ def _merge_repr(cls, restriction: str = True) -> dj.expression.Union:
                 return_empties=False,  # motivated by SpikeSortingOutput.Import
             )
         ]
+        if not parts:
+            logger.warning("No parts found. Try adjusting restriction.")
+            return

         attr_dict = {  # NULL for non-numeric, 0 for numeric
             attr.name: "0" if attr.numeric else "NULL"
@@ -274,10 +278,8 @@ def _proj_part(part):
         return query

     @classmethod
-    def _merge_insert(
-        cls, rows: list, part_name: str = None, mutual_exclusvity=True, **kwargs
-    ) -> None:
-        """Insert rows into merge, ensuring db integrity and mutual exclusivity
+    def _merge_insert(cls, rows: list, part_name: str = None, **kwargs) -> None:
+        """Insert rows into merge, ensuring data exists in part parent(s).

         Parameters
         ---------
@@ -291,18 +293,17 @@ def _merge_insert(
         TypeError
             If rows is not a list of dicts
         ValueError
-            If entry already exists, mutual exclusivity errors
             If data doesn't exist in part parents, integrity error
         """
         cls._ensure_dependencies_loaded()

+        type_err_msg = "Input `rows` must be a list of dictionaries"
         try:
             for r in iter(rows):
-                assert isinstance(
-                    r, dict
-                ), 'Input "rows" must be a list of dictionaries'
+                if not isinstance(r, dict):
+                    raise TypeError(type_err_msg)
         except TypeError:
-            raise TypeError('Input "rows" must be a list of dictionaries')
+            raise TypeError(type_err_msg)

         parts = cls._merge_restrict_parts(as_objects=True)
         if part_name:
@@ -315,30 +316,24 @@
         master_entries = []
         parts_entries = {p: [] for p in parts}
         for row in rows:
-            keys = []  # empty to-be-inserted key
+            keys = []  # empty to-be-inserted keys
             for part in parts:  # check each part
-                part_parent = part.parents(as_objects=True)[-1]
                 part_name = cls._part_name(part)
+                part_parent = part.parents(as_objects=True)[-1]
                 if part_parent & row:  # if row is in part parent
-                    if keys and mutual_exclusvity:  # if key from other part
-                        raise ValueError(
-                            "Mutual Exclusivity Error! Entry exists in more "
-                            + f"than one table - Entry: {row}"
-                        )
-
                     keys = (part_parent & row).fetch("KEY")  # get pk
                     if len(keys) > 1:
                         raise ValueError(
                             "Ambiguous entry. Data has mult rows in "
                             + f"{part_name}:\n\tData:{row}\n\t{keys}"
                         )
-                    master_pk = {  # make uuid
-                        cls()._reserved_pk: dj.hash.key_hash(keys[0]),
-                    }
-                    parts_entries[part].append({**master_pk, **keys[0]})
-                    master_entries.append(
-                        {**master_pk, cls()._reserved_sk: part_name}
-                    )
+                    key = keys[0]
+                    master_sk = {cls()._reserved_sk: part_name}
+                    uuid = dj.hash.key_hash(key | master_sk)
+                    master_pk = {cls()._reserved_pk: uuid}
+
+                    master_entries.append({**master_pk, **master_sk})
+                    parts_entries[part].append({**master_pk, **key})

             if not keys:
                 raise ValueError(
@@ -369,27 +364,22 @@ def _ensure_dependencies_loaded(cls) -> None:
         if not dj.conn.connection.dependencies._loaded:
             dj.conn.connection.dependencies.load()

-    def insert(self, rows: list, mutual_exclusvity=True, **kwargs):
-        """Merges table specific insert
-
-        Ensuring db integrity and mutual exclusivity
+    def insert(self, rows: list, **kwargs):
+        """Merge table specific insert, ensuring data exists in part parents.

         Parameters
         ---------
         rows: List[dict]
             An iterable where an element is a dictionary.
-        mutual_exclusvity: bool
-            Check for mutual exclusivity before insert. Default True.

         Raises
         ------
         TypeError
             If rows is not a list of dicts
         ValueError
-            If entry already exists, mutual exclusivity errors
             If data doesn't exist in part parents, integrity error
         """
-        self._merge_insert(rows, mutual_exclusvity=mutual_exclusvity, **kwargs)
+        self._merge_insert(rows, **kwargs)

     @classmethod
     def merge_view(cls, restriction: str = True):
@@ -586,6 +576,8 @@ def merge_get_part(
             + "Try adding a restriction before invoking `get_part`.\n\t"
             + "Or permitting multiple sources with `multi_source=True`."
         )
+        if len(sources) == 0:
+            return None

         parts = [
             (
@@ -770,12 +762,33 @@ def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list:
             )
         return results[0] if len(results) == 1 else results

-    @classmethod
-    def merge_populate(source: str, key=None):
-        raise NotImplementedError(
-            "CBroz: In the future, this command will support executing "
-            + "part_parent `make` and then inserting all entries into Merge"
-        )
+    def merge_populate(self, source: str, keys=None):
+        """Populate the merge table with entries from the source table."""
+        logger.warning("CBroz: Not fully tested. Use with caution.")
+        parent_class = self.merge_get_parent_class(source)
+        if not keys:
+            keys = parent_class.key_source
+        parent_class.populate(keys)
+        successes = (parent_class & keys).fetch("KEY", as_dict=True)
+        self.insert(successes)
+
+    def delete(self, force_permission=False, *args, **kwargs):
+        """Alias for cautious_delete, overwrites datajoint.table.Table.delete"""
+        for part in self.merge_get_part(
+            restriction=self.restriction,
+            multi_source=True,
+            return_empties=False,
+        ):
+            part.delete(force_permission=force_permission, *args, **kwargs)
+
+    def super_delete(self, *args, **kwargs):
+        """Alias for datajoint.table.Table.delete.
+
+        Added to support MRO of SpyglassMixin"""
+        logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
+
+        self._log_use(start=time(), super_delete=True)
+        super().delete(*args, **kwargs)


 _Merge = Merge
diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index 0e18e3a5c..349476600 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -215,8 +215,8 @@ def delete_downstream_merge(
         if not merge_join_dict and not disable_warning:
             logger.warning(
                 f"No merge deletes found w/ {self.table_name} & "
-                + f"{restriction}.\n\tIf this is unexpected, try running with "
-                + "`reload_cache`."
+                + f"{restriction}.\n\tIf this is unexpected, try importing "
+                + "Merge table(s) and running with `reload_cache`."
             )

         if dry_run:
@@ -365,7 +365,7 @@ def _usage_table(self):

         return CautiousDelete()

-    def _log_use(self, start, merge_deletes=None):
+    def _log_use(self, start, merge_deletes=None, super_delete=False):
         """Log use of cautious_delete."""
         if isinstance(merge_deletes, QueryExpression):
             merge_deletes = merge_deletes.fetch(as_dict=True)
@@ -374,15 +374,13 @@ def _log_use(self, start, merge_deletes=None):
             dj_user=dj.config["database.user"],
             origin=self.full_table_name,
         )
+        restr_str = "Super delete: " if super_delete else ""
+        restr_str += "".join(self.restriction) if self.restriction else "None"
         try:
             self._usage_table.insert1(
                 dict(
                     **safe_insert,
-                    restriction=(
-                        "".join(self.restriction)[255:]  # handle list
-                        if self.restriction
-                        else "None"
-                    ),
+                    restriction=restr_str[:255],
                     merge_deletes=merge_deletes,
                 )
             )
@@ -455,4 +453,5 @@ def delete(self, force_permission=False, *args, **kwargs):
     def super_delete(self, *args, **kwargs):
         """Alias for datajoint.table.Table.delete."""
         logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
+        self._log_use(start=time(), super_delete=True)
         super().delete(*args, **kwargs)
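
Below is a minimal usage sketch of the new fault-permitting populate flow in `populate_all_common` and the `InsertError` log it writes to, assuming a configured Spyglass/DataJoint connection. The NWB file name is hypothetical.

```python
from spyglass.common.common_usage import InsertError
from spyglass.common.populate_all_common import populate_all_common

nwb_file_name = "example20240101_.nwb"  # hypothetical copied-file name

# Failures are logged to common_usage.InsertError instead of halting the
# loop; keys for the logged errors are returned (None if all succeeded).
error_keys = populate_all_common(nwb_file_name)

if error_keys:
    for err in (InsertError & error_keys).fetch(as_dict=True):
        print(f"{err['table']}: {err['error_type']} - {err['error_message']}")
```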
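
The revised `_merge_insert` folds the part name into the master UUID via `dj.hash.key_hash(key | master_sk)`, which is why the mutual exclusivity check could be dropped: the same parent key offered by two sources now hashes to two distinct master entries. A small sketch of that property, assuming the merge table's reserved secondary key is `source`; the key values and part names are illustrative.

```python
import datajoint as dj

key = {"nwb_file_name": "example20240101_.nwb", "interval_list_name": "pos 0 valid times"}

# Hashing the parent key alone (pre-#824 behavior) cannot distinguish sources;
# hashing the key plus source yields a distinct uuid per part table.
uuid_a = dj.hash.key_hash({**key, "source": "TrodesPosV1"})
uuid_b = dj.hash.key_hash({**key, "source": "DLCPosV1"})
assert uuid_a != uuid_b
```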
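
Likewise, a sketch of the new `merge_populate` path, which runs the part parent's `populate` and then inserts the resulting keys into the merge table. The merge table, part name, and restriction below are illustrative, not prescribed by this PR.

```python
from spyglass.lfp.lfp_merge import LFPOutput  # an example Merge table

keys = [{"nwb_file_name": "example20240101_.nwb"}]  # hypothetical restriction
LFPOutput().merge_populate("LFPV1", keys=keys)
```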