diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py new file mode 100644 index 000000000..67fac5412 --- /dev/null +++ b/src/spyglass/common/common_usage.py @@ -0,0 +1,22 @@ +"""A schema to store the usage of advanced Spyglass features. + +Records show usage of features such as table chains, which will be used to +determine which features are used, how often, and by whom. This will help +plan future development of Spyglass. +""" +import datajoint as dj + +schema = dj.schema("common_usage") + + +@schema +class CautiousDelete(dj.Manual): + definition = """ + id: int auto_increment + --- + dj_user: varchar(64) + duration: float + origin: varchar(64) + restriction: varchar(64) + merge_deletes = null: blob + """ diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 2fd80ff5e..acd83bb9d 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -1,12 +1,13 @@ from collections.abc import Iterable +from time import time from typing import Dict, List, Union import datajoint as dj +import networkx as nx from datajoint.table import logger as dj_logger from datajoint.user_tables import Table, TableMeta from datajoint.utils import get_master, user_choice -from spyglass.settings import test_mode from spyglass.utils.database_settings import SHARED_MODULES from spyglass.utils.dj_helper_fn import fetch_nwb from spyglass.utils.dj_merge_tables import RESERVED_PRIMARY_KEY as MERGE_PK @@ -37,19 +38,27 @@ class SpyglassMixin: raised. `force_permission` can be set to True to bypass permission check. cdel(*args, **kwargs) Alias for cautious_delete. + delete_downstream_merge(restriction=None, dry_run=True, reload_cache=False) + Delete downstream merge table entries associated with restriction. + Requires caching of merge tables and links, which is slow on first call. + `restriction` can be set to a string to restrict the delete. `dry_run` + can be set to False to commit the delete. 
`reload_cache` can be set to + True to reload the merge cache. + ddm(*args, **kwargs) + Alias for delete_downstream_merge. """ _nwb_table_dict = {} # Dict mapping NWBFile table to path attribute name. # _nwb_table = None # NWBFile table class, defined at the table level _nwb_table_resolved = None # NWBFiletable class, resolved here from above _delete_dependencies = [] # Session, LabMember, LabTeam, delay import - _merge_delete_func = None # delete_downstream_merge, delay import # pks for delete permission check, assumed to be on field _session_pk = None # Session primary key. Mixin is ambivalent to Session pk _member_pk = None # LabMember primary key. Mixin ambivalent table structure - _merge_cache = {} # Cache of merge tables downstream of self - _merge_cache_links = {} # Cache of merge links downstream of self + _merge_table_cache = {} # Cache of merge tables downstream of self + _merge_chains_cache = {} # Cache of table chains to merges _session_connection_cache = None # Cache of path from Session to self + _usage_table_cache = None # Temporary inclusion for usage tracking # ------------------------------- fetch_nwb ------------------------------- @@ -144,167 +153,13 @@ def _delete_deps(self) -> list: return self._delete_dependencies @property - def _merge_del_func(self) -> callable: - """Callable: delete_downstream_merge function. - - Used to delay import of func until needed, avoiding circular imports. 
- """ - if not self._merge_delete_func: - from spyglass.utils.dj_merge_tables import ( # noqa F401 - delete_downstream_merge, - ) - - self._merge_delete_func = delete_downstream_merge - return self._merge_delete_func - - @staticmethod - def _get_instanced(*tables) -> Union[dj.user_tables.Table, list]: - """Return instance of table(s) if not already instanced.""" - ret = [] - if not isinstance(tables, Iterable): - tables = tuple(tables) - for table in tables: - if not isinstance(table, dj.user_tables.Table): - ret.append(table()) - else: - ret.append(table) - return ret[0] if len(ret) == 1 else ret + def _test_mode(self) -> bool: + """Return True if test mode is enabled.""" + if not self._test_mode_cache: + from spyglass.settings import test_mode - @staticmethod - def _link_repr(parent, child, width=120): - len_each = (width - 4) // 2 - p = parent.full_table_name[:len_each].ljust(len_each) - c = child.full_table_name[:len_each].ljust(len_each) - return f"{p} -> {c}" - - def _get_connection( - self, - child: dj.user_tables.TableMeta, - parent: dj.user_tables.TableMeta = None, - recurse_level: int = 4, - visited: set = None, - ) -> Union[List[dj.FreeTable], List[List[dj.FreeTable]]]: - """ - Return list of tables connecting the parent and child for a valid join. - - Parameters - ---------- - parent : dj.user_tables.TableMeta - DataJoint table upstream in pipeline. - child : dj.user_tables.TableMeta - DataJoint table downstream in pipeline. - recurse_level : int, optional - Maximum number of recursion levels. Default is 4. - visited : set, optional - Set of visited tables (used internally for recursion). - - Returns - ------- - List[dj.FreeTable] or List[List[dj.FreeTable]] - List of paths, with each path as a list FreeTables connecting the - parent and child for a valid join. 
- """ - parent = parent or self - parent, child = self._get_instanced(parent, child) - visited = visited or set() - child_is_merge = child.full_table_name in self._merge_tables - - if recurse_level < 1 or ( # if too much recursion - not child_is_merge # merge table ok - and ( # already visited, outside spyglass, or no connection - child.full_table_name in visited - or child.full_table_name.strip("`").split("_")[0] - not in SHARED_MODULES - or child.full_table_name not in parent.descendants() - ) - ): - return [] - - if child.full_table_name in parent.children(): - logger.debug(f"1-{recurse_level}:" + self._link_repr(parent, child)) - if isinstance(child, dict) or isinstance(parent, dict): - __import__("pdb").set_trace() - return [parent, child] - - if child_is_merge: - ret = [] - parts = child.parts(as_objects=True) - if not parts: - logger.warning(f"Merge has no parts: {child.full_table_name}") - for part in child.parts(as_objects=True): - links = self._get_connection( - parent=parent, - child=part, - recurse_level=recurse_level, - visited=visited, - ) - visited.add(part.full_table_name) - if links: - logger.debug( - f"2-{recurse_level}:" + self._link_repr(parent, part) - ) - ret.append(links + [child]) - - return ret - - for subchild in parent.children(as_objects=True): - links = self._get_connection( - parent=subchild, - child=child, - recurse_level=recurse_level - 1, - visited=visited, - ) - visited.add(subchild.full_table_name) - if links: - logger.debug( - f"3-{recurse_level}:" + self._link_repr(subchild, child) - ) - if parent.full_table_name in [l.full_table_name for l in links]: - return links - else: - return [parent] + links - - return [] - - def _join_list( - self, - tables: Union[List[dj.FreeTable], List[List[dj.FreeTable]]], - restriction: str = None, - ) -> dj.expression.QueryExpression: - """Return join of all tables in list. 
Omits empty items.""" - restriction = restriction or self.restriction or True - - if not isinstance(tables[0], (list, tuple)): - tables = [tables] - ret = [] - for table_list in tables: - join = table_list[0] & restriction - for table in table_list[1:]: - join = join * table - if join: - ret.append(join) - return ret[0] if len(ret) == 1 else ret - - def _connection_repr(self, connection) -> str: - if isinstance(connection[0], (Table, TableMeta)): - connection = [connection] - if not isinstance(connection[0], (list, tuple)): - connection = [connection] - ret = [] - for table_list in connection: - connection_str = "" - for table in table_list: - if isinstance(table, str): - connection_str += table + " -> " - else: - connection_str += table.table_name + " -> " - ret.append(f"\n\tPath: {connection_str[:-4]}") - return ret - - def _ensure_dependencies_loaded(self) -> None: - """Ensure connection dependencies loaded.""" - if not self.connection.dependencies._loaded: - self.connection.dependencies.load() + self._test_mode_cache = test_mode + return self._test_mode_cache @property def _merge_tables(self) -> Dict[str, dj.FreeTable]: @@ -313,13 +168,13 @@ def _merge_tables(self) -> Dict[str, dj.FreeTable]: Cache of items in parents of self.descendants(as_objects=True) that have a merge primary key. 
""" - if self._merge_cache: - return self._merge_cache + if self._merge_table_cache: + return self._merge_table_cache def has_merge_pk(table): return MERGE_PK in table.heading.names - self._ensure_dependencies_loaded() + self.connection.dependencies.load() for desc in self.descendants(as_objects=True): if not has_merge_pk(desc): continue @@ -327,13 +182,16 @@ def has_merge_pk(table): continue master = dj.FreeTable(self.connection, master_name) if has_merge_pk(master): - self._merge_cache[master_name] = master - logger.info(f"Found {len(self._merge_cache)} merge tables") + self._merge_table_cache[master_name] = master + logger.info( + f"Building merge cache for {self.table_name}.\n\t" + + f"Found {len(self._merge_table_cache)} downstream merge tables" + ) - return self._merge_cache + return self._merge_table_cache @property - def _merge_links(self) -> Dict[str, List[dj.FreeTable]]: + def _merge_chains(self) -> Dict[str, List[dj.FreeTable]]: """Dict of merge links downstream of self. For each merge table found in _merge_tables, find the path from self to @@ -341,21 +199,23 @@ def _merge_links(self) -> Dict[str, List[dj.FreeTable]]: to recompute whenever delete_downstream_merge is called with a new restriction. 
""" - if self._merge_cache_links: - return self._merge_cache_links + if self._merge_chains_cache: + return self._merge_chains_cache + for name, merge_table in self._merge_tables.items(): - connection = self._get_connection(child=merge_table) - if connection: - self._merge_cache_links[name] = connection - return self._merge_cache_links + chains = TableChains(self, merge_table, connection=self.connection) + if len(chains): + self._merge_chains_cache[name] = chains + return self._merge_chains_cache def _commit_merge_deletes(self, merge_join_dict, **kwargs): - ret = [] - for table, selection in merge_join_dict.items(): - keys = selection.fetch(MERGE_PK, as_dict=True) - ret.append(table & keys) # NEEDS TESTING WITH ACTUAL DELETE - # (table & keys).delete(**kwargs) # TODO: Run delete here - return ret + """Commit merge deletes. + + Extracted for use in cautious_delete and delete_downstream_merge.""" + for table_name, part_restr in merge_join_dict.items(): + table = self._merge_tables[table_name] + keys = [part.fetch(MERGE_PK, as_dict=True) for part in part_restr] + (table & keys).delete(**kwargs) def delete_downstream_merge( self, @@ -389,17 +249,16 @@ def delete_downstream_merge( Passed to datajoint.table.Table.delete. 
""" if reload_cache: - self._merge_cache = {} - self._merge_cache_links = {} + self._merge_table_cache = {} + self._merge_chains_cache = {} restriction = restriction or self.restriction or True merge_join_dict = {} - for merge_name, merge_link in self._merge_links.items(): - logger.debug(self._connection_repr(merge_link)) - joined = self._join_list(merge_link, restriction=self.restriction) - if joined: - merge_join_dict[self._merge_tables[merge_name]] = joined + for name, chain in self._merge_chains.items(): + join = chain.join(restriction) + if join: + merge_join_dict[name] = join if not merge_join_dict and not disable_warning: logger.warning( @@ -408,12 +267,30 @@ def delete_downstream_merge( ) if dry_run: - return merge_join_dict + return merge_join_dict.values() if return_parts else merge_join_dict + self._commit_merge_deletes(merge_join_dict, **kwargs) - def ddm(self, *args, **kwargs): + def ddm( + self, + restriction: str = None, + dry_run: bool = True, + reload_cache: bool = False, + disable_warning: bool = False, + return_parts: bool = True, + *args, + **kwargs, + ): """Alias for delete_downstream_merge.""" - return self.delete_downstream_merge(*args, **kwargs) + return self.delete_downstream_merge( + restriction=restriction, + dry_run=dry_run, + reload_cache=reload_cache, + disable_warning=disable_warning, + return_parts=return_parts, + *args, + **kwargs, + ) def _get_exp_summary(self): """Get summary of experimenters for session(s), including NULL. @@ -429,16 +306,15 @@ def _get_exp_summary(self): Summary of experimenters for session(s). 
""" Session = self._delete_deps[-1] + SesExp = Session.Experimenter + empty_pk = {self._member_pk: "NULL"} format = dj.U(self._session_pk, self._member_pk) + sess_link = self._session_connection.join(self.restriction) + + exp_missing = format & (sess_link - SesExp).proj(**empty_pk) + exp_present = format & (sess_link * SesExp - exp_missing).proj() - sess_link = self._join_list(self._session_connection) - exp_missing = format & (sess_link - Session.Experimenter).proj( - **{self._member_pk: "NULL"} - ) - exp_present = ( - format & (sess_link * Session.Experimenter - exp_missing).proj() - ) return exp_missing + exp_present @property @@ -448,9 +324,9 @@ def _session_connection(self) -> dj.expression.QueryExpression: None is not yet cached, False if no connection found. """ if self._session_connection_cache is None: + connection = TableChain(parent=self._delete_deps[-1], child=self) self._session_connection_cache = ( - self._get_connection(parent=self._delete_deps[-1], child=self) - or False + connection if connection.has_link else False ) return self._session_connection_cache @@ -508,6 +384,15 @@ def _check_delete_permission(self) -> None: ) logger.info(f"Queueing delete for session(s):\n{sess_summary}") + @property + def _usage_table(self): + """Temporary inclusion for usage tracking.""" + if not self._usage_table_cache: + from spyglass.common.common_usage import CautiousDelete + + self._usage_table_cache = CautiousDelete + return self._usage_table_cache + # Rename to `delete` when we're ready to use it # TODO: Intercept datajoint delete confirmation prompt for merge deletes def cautious_delete(self, force_permission: bool = False, *args, **kwargs): @@ -526,6 +411,12 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): *args, **kwargs : Any Passed to datajoint.table.Table.delete. 
""" + start = time() + usage_dict = dict( + dj_user=dj.config["database.user"], + origin=self.full_table_name, + restriction=self.restriction, + ) if not force_permission: self._check_delete_permission() @@ -547,17 +438,156 @@ def cautious_delete(self, force_permission: bool = False, *args, **kwargs): count, name = len(content), table.full_table_name dj_logger.info(f"Merge: Deleting {count} rows from {name}") if ( - not test_mode + not self._test_mode or not safemode or user_choice("Commit deletes?", default="no") == "yes" ): self._commit_merge_deletes(merge_deletes, **kwargs) else: logger.info("Delete aborted.") + self._usage_table.insert1( + dict(duration=time() - start, **usage_dict) + ) return super().delete(*args, **kwargs) # Additional confirm here + self._usage_table.insert1( + dict( + duration=time() - start, + merge_deletes=merge_deletes, + ) + ) + def cdel(self, *args, **kwargs): """Alias for cautious_delete.""" self.cautious_delete(*args, **kwargs) + + +class TableChains: + """Class for representing chains from parent to Merge table via parts.""" + + def __init__(self, parent, child, connection=None): + self.parent = parent + self.child = child + self.connection = connection or parent.connection + parts = child.parts(as_objects=True) + self.part_names = [part.full_table_name for part in parts] + self.chains = [TableChain(parent, part) for part in parts] + self.has_link = any([chain.has_link for chain in self.chains]) + + def __repr__(self): + return "\n".join([str(chain) for chain in self.chains]) + + def __len__(self): + return len([c for c in self.chains if c.has_link]) + + def join(self, restriction=None): + restriction = restriction or self.parent.restriction or True + joins = [] + for chain in self.chains: + if joined := chain.join(restriction): + joins.append(joined) + return joins + + +class TableChain: + """Class for representing a chain of tables. + + Note: Parent -> Merge should use TableChains instead. 
+ """ + + def __init__(self, parent: Table, child: Table, connection=None): + self._connection = connection or parent.connection + if not self._connection.dependencies._loaded: + self._connection.dependencies.load() + + if ( # if child is a merge table + get_master(child.full_table_name) == "" + and MERGE_PK in child.heading.names + ): + logger.error("Child is a merge table. Use TableChains instead.") + + self._link_symbol = " -> " + self.parent = parent + self.child = child + self._repr = None + self._names = None # full table names of tables in chain + self._objects = None # free tables in chain + self._has_link = child.full_table_name in parent.descendants() + + def __str__(self): + """Return string representation of chain: parent -> child.""" + if not self._has_link: + return "No link" + return ( + f"Chain: " + + self.parent.table_name + + self._link_symbol + + self.child.table_name + ) + + def __repr__(self): + """Return full representation of chain: parent -> {links} -> child.""" + if self._repr: + return self._repr + self._repr = ( + "Chain: " + + self._link_symbol.join([t.table_name for t in self.objects]) + if self.names + else "No link" + ) + return self._repr + + def __len__(self): + """Return number of tables in chain.""" + return len(self.names) + + @property + def has_link(self) -> bool: + """Return True if parent is linked to child. + + Cached as hidden attribute _has_link to set False if nx.NetworkXNoPath + is raised by nx.shortest_path. + """ + return self._has_link + + @property + def names(self) -> List[str]: + """Return list of full table names in chain. + + Uses networkx.shortest_path. 
+ """ + if not self._has_link: + return None + if self._names: + return self._names + try: + self._names = nx.shortest_path( + self.parent.connection.dependencies, + self.parent.full_table_name, + self.child.full_table_name, + ) + return self._names + except nx.NetworkXNoPath: + self._has_link = False + return None + + @property + def objects(self) -> List[dj.FreeTable]: + """Return list of FreeTable objects for each table in chain.""" + if not self._objects: + self._objects = ( + [dj.FreeTable(self._connection, name) for name in self.names] + if self.names + else None + ) + return self._objects + + def join(self, restricton: str = None) -> dj.expression.QueryExpression: + """Return join of tables in chain with restriction applied to parent.""" + restriction = restricton or self.parent.restriction or True + join = self.objects[0] & restriction + for table in self.objects[1:]: + join = join * table + return join if join else None