From b0b47cce87d01b95714caedbb1375c2c9d214b53 Mon Sep 17 00:00:00 2001
From: CBroz1
Date: Fri, 9 Feb 2024 13:47:16 -0600
Subject: [PATCH] #768

---
 CHANGELOG.md                          |  5 +-
 src/spyglass/common/common_behav.py   | 23 ++++++---
 src/spyglass/utils/dj_merge_tables.py | 72 +++++++++++++++------------
 3 files changed, 59 insertions(+), 41 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d8c473f5..ad5a17a69 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,7 +22,10 @@
 - Add overview of Spyglass to docs. #779
 - Update linting for Black 24. #808
 - Steamline dependency management. #822
-- Add catch errorst during `populate_all_common`, log in `common_usage`. #XXX
+- Add error catching during `populate_all_common`, log in `common_usage`. #XXX
+- Merge UUIDs #XXX
+  - Revise Merge table uuid generation to include source.
+  - Remove mutual exclusivity logic due to new UUID generation.
 
 ### Pipelines
 
diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py
index df195b31e..f2a2226d8 100644
--- a/src/spyglass/common/common_behav.py
+++ b/src/spyglass/common/common_behav.py
@@ -42,19 +42,24 @@ class SpatialSeries(SpyglassMixin, dj.Part):
         name=null: varchar(32)  # name of spatial series
         """
 
-    def populate(self, key=None):
+    def populate(self, keys=None):
         """Insert position source data from NWB file.
 
         WARNING: populate method on Manual table is not protected by
         transaction protections like other DataJoint tables.
         """
-        nwb_file_name = key.get("nwb_file_name")
-        if not nwb_file_name:
-            raise ValueError(
-                "PositionSource.populate is an alias for a non-computed table "
-                + "and must be passed a key with nwb_file_name"
-            )
-        self.insert_from_nwbfile(nwb_file_name)
+        if not isinstance(keys, list):
+            keys = [keys]
+        if isinstance(keys[0], dj.Table):
+            keys = [k for tbl in keys for k in tbl.fetch("KEY", as_dict=True)]
+        for key in keys:
+            nwb_file_name = (key or {}).get("nwb_file_name")
+            if not nwb_file_name:
+                raise ValueError(
+                    "PositionSource.populate is an alias for a non-computed table "
+                    + "and must be passed a key with nwb_file_name"
+                )
+            self.insert_from_nwbfile(nwb_file_name)
 
     @classmethod
     def insert_from_nwbfile(cls, nwb_file_name):
@@ -496,6 +501,7 @@ def _no_transaction_make(self, key):
 
         # Skip populating if no pos interval list names
         if len(pos_intervals) == 0:
+            # TODO: Now that populate_all accepts errors, raise here?
             logger.error(f"NO POS INTERVALS FOR {key}; {no_pop_msg}")
             return
 
@@ -533,6 +539,7 @@ def _no_transaction_make(self, key):
         # Check that each pos interval was matched to only one epoch
         if len(matching_pos_intervals) != 1:
+            # TODO: Now that populate_all accepts errors, raise here?
             logger.error(
                 f"Found {len(matching_pos_intervals)} pos intervals for {key}; "
                 + f"{no_pop_msg}\n{matching_pos_intervals}"
             )
diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py
index f15183dfe..f712e9179 100644
--- a/src/spyglass/utils/dj_merge_tables.py
+++ b/src/spyglass/utils/dj_merge_tables.py
@@ -274,10 +274,8 @@ def _proj_part(part):
         return query
 
     @classmethod
-    def _merge_insert(
-        cls, rows: list, part_name: str = None, mutual_exclusvity=True, **kwargs
-    ) -> None:
-        """Insert rows into merge, ensuring db integrity and mutual exclusivity
+    def _merge_insert(cls, rows: list, part_name: str = None, **kwargs) -> None:
+        """Insert rows into merge, ensuring data exists in part parent(s).
 
         Parameters
         ---------
@@ -291,18 +289,17 @@ def _merge_insert(
         TypeError
             If rows is not a list of dicts
         ValueError
-            If entry already exists, mutual exclusivity errors
             If data doesn't exist in part parents, integrity error
         """
         cls._ensure_dependencies_loaded()
 
+        type_err_msg = "Input `rows` must be a list of dictionaries"
         try:
             for r in iter(rows):
-                assert isinstance(
-                    r, dict
-                ), 'Input "rows" must be a list of dictionaries'
+                if not isinstance(r, dict):
+                    raise TypeError(type_err_msg)
         except TypeError:
-            raise TypeError('Input "rows" must be a list of dictionaries')
+            raise TypeError(type_err_msg)
 
         parts = cls._merge_restrict_parts(as_objects=True)
         if part_name:
@@ -315,30 +312,24 @@ def _merge_insert(
         master_entries = []
         parts_entries = {p: [] for p in parts}
         for row in rows:
-            keys = []  # empty to-be-inserted key
+            keys = []  # empty to-be-inserted keys
             for part in parts:  # check each part
-                part_parent = part.parents(as_objects=True)[-1]
                 part_name = cls._part_name(part)
+                part_parent = part.parents(as_objects=True)[-1]
                 if part_parent & row:  # if row is in part parent
-                    if keys and mutual_exclusvity:  # if key from other part
-                        raise ValueError(
-                            "Mutual Exclusivity Error! Entry exists in more "
-                            + f"than one table - Entry: {row}"
-                        )
-
                     keys = (part_parent & row).fetch("KEY")  # get pk
                     if len(keys) > 1:
                         raise ValueError(
                             "Ambiguous entry. Data has mult rows in "
                             + f"{part_name}:\n\tData:{row}\n\t{keys}"
                         )
-                    master_pk = {  # make uuid
-                        cls()._reserved_pk: dj.hash.key_hash(keys[0]),
-                    }
-                    parts_entries[part].append({**master_pk, **keys[0]})
-                    master_entries.append(
-                        {**master_pk, cls()._reserved_sk: part_name}
-                    )
+                    key = keys[0]
+                    master_sk = {cls()._reserved_sk: part_name}
+                    uuid = dj.hash.key_hash(key | master_sk)
+                    master_pk = {cls()._reserved_pk: uuid}
+
+                    master_entries.append({**master_pk, **master_sk})
+                    parts_entries[part].append({**master_pk, **key})
 
             if not keys:
                 raise ValueError(
@@ -369,27 +360,22 @@ def _ensure_dependencies_loaded(cls) -> None:
         if not dj.conn.connection.dependencies._loaded:
             dj.conn.connection.dependencies.load()
 
-    def insert(self, rows: list, mutual_exclusvity=True, **kwargs):
-        """Merges table specific insert
-
-        Ensuring db integrity and mutual exclusivity
+    def insert(self, rows: list, **kwargs):
+        """Merge-table-specific insert, ensuring data exists in part parents.
 
         Parameters
         ---------
         rows: List[dict]
             An iterable where an element is a dictionary.
-        mutual_exclusvity: bool
-            Check for mutual exclusivity before insert. Default True.
 
         Raises
         ------
         TypeError
             If rows is not a list of dicts
         ValueError
-            If entry already exists, mutual exclusivity errors
             If data doesn't exist in part parents, integrity error
         """
-        self._merge_insert(rows, mutual_exclusvity=mutual_exclusvity, **kwargs)
+        self._merge_insert(rows, **kwargs)
 
     @classmethod
     def merge_view(cls, restriction: str = True):
@@ -586,6 +572,8 @@ def merge_get_part(
                 + "Try adding a restriction before invoking `get_part`.\n\t"
                 + "Or permitting multiple sources with `multi_source=True`."
             )
+        if len(sources) == 0:
+            return None
 
         parts = [
             (
@@ -777,6 +765,26 @@ def merge_populate(source: str, key=None):
             + "part_parent `make` and then inserting all entries into Merge"
        )
 
+    def delete(self, force_permission=False, *args, **kwargs):
+        """Block direct deletes on Merge; overwrites datajoint.table.Table.delete"""
+        raise NotImplementedError(
+            "Please use delete_downstream_merge or cautious_delete "
+            + "to clear merge entries."
+        )
+        # for part in self.merge_get_part(
+        #     restriction=self.restriction,
+        #     multi_source=True,
+        #     return_empties=False,
+        # ):
+        #     part.delete(force_permission=force_permission, *args, **kwargs)
+
+    def super_delete(self, *args, **kwargs):
+        """Alias for datajoint.table.Table.delete.
+
+        Added to support MRO of SpyglassMixin"""
+        logger.warning("!! Using super_delete. Bypassing cautious_delete !!")
+        super().delete(*args, **kwargs)
+
 
 _Merge = Merge
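
Illustrative note on the `PositionSource.populate` change above: the method now accepts a single key dict, a list of key dicts, or DataJoint table(s) whose fetched keys supply `nwb_file_name`. Below is a minimal sketch of that key normalization, using a stand-in table class so it runs without a database; the class, function, and file names here are hypothetical and not part of the patch.

from typing import Any, Dict, List, Union


class FakeTable:
    """Stand-in for dj.Table: fetch("KEY", as_dict=True) returns key dicts."""

    def __init__(self, keys: List[Dict[str, Any]]):
        self._keys = keys

    def fetch(self, attr: str = "KEY", as_dict: bool = True) -> List[Dict[str, Any]]:
        return list(self._keys)


def normalize_keys(keys: Union[None, dict, FakeTable, list]) -> List[dict]:
    """Mirror the patch's handling: wrap single inputs, expand tables into keys."""
    if not isinstance(keys, list):
        keys = [keys]  # a lone dict or table becomes a one-item list
    if isinstance(keys[0], FakeTable):  # the patch checks isinstance(..., dj.Table)
        keys = [k for tbl in keys for k in tbl.fetch("KEY", as_dict=True)]
    return keys


tbl = FakeTable([{"nwb_file_name": "fileA_.nwb"}, {"nwb_file_name": "fileB_.nwb"}])
print(normalize_keys(tbl))  # two keys expanded from the "table"
print(normalize_keys({"nwb_file_name": "fileC_.nwb"}))  # single dict is wrapped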
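
Also for illustration, the revised Merge UUID generation: the merge primary key is now a hash of the part parent's key merged with the reserved secondary key (`source`), so the same upstream entry registered through two different part tables yields two distinct UUIDs, which is why the mutual-exclusivity check could be dropped. A minimal sketch, runnable without a database; the upstream key and source names are hypothetical, and `key_hash` is the DataJoint helper the patch calls as `dj.hash.key_hash`.

from datajoint.hash import key_hash

# Hypothetical upstream primary key and two hypothetical source (part) names
upstream_key = {
    "nwb_file_name": "example20240209_.nwb",
    "interval_list_name": "pos 0 valid times",
}

# New scheme: hash the key together with the reserved secondary key "source"
uuid_a = key_hash({**upstream_key, "source": "TrodesPosV1"})
uuid_b = key_hash({**upstream_key, "source": "DLCPosV1"})

# Same upstream entry, different sources -> different merge UUIDs, so entries
# from multiple part parents no longer collide and exclusivity need not be enforced.
assert uuid_a != uuid_b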