diff --git a/CHANGELOG.md b/CHANGELOG.md index b00777da2..a5d467ec1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,8 @@ PositionGroup.alter() - Allow `ModuleNotFoundError` or `ImportError` for optional dependencies #1023 - Ensure integrity of group tables #1026 - Convert list of LFP artifact removed interval list to array #1046 -- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1062, #1069 +- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1058, + #1066 - Revise docs organization. - Misc -> Features/ForDevelopers. #1029 - Installation instructions -> Setup notebook. #1029 diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py index f437f6276..0ab4ab477 100644 --- a/src/spyglass/utils/dj_graph.py +++ b/src/spyglass/utils/dj_graph.py @@ -159,20 +159,6 @@ def _camel(self, table): # ------------------------------ Graph Nodes ------------------------------ - def _ensure_names( - self, table: Union[str, Table] = None - ) -> Union[str, List[str]]: - """Ensure table is a string.""" - if table is None: - return None - if isinstance(table, str): - return table - if isinstance(table, Iterable) and not isinstance( - table, (Table, TableMeta) - ): - return [ensure_names(t) for t in table] - return getattr(table, "full_table_name", None) - def _get_node(self, table: Union[str, Table]): """Get node from graph.""" table = ensure_names(table) diff --git a/src/spyglass/utils/dj_helper_fn.py b/src/spyglass/utils/dj_helper_fn.py index 90fd47cc0..0bf61b734 100644 --- a/src/spyglass/utils/dj_helper_fn.py +++ b/src/spyglass/utils/dj_helper_fn.py @@ -35,16 +35,35 @@ def ensure_names( - table: Union[str, Table, Iterable] = None + table: Union[str, Table, Iterable] = None, force_list: bool = False ) -> Union[str, List[str], None]: - """Ensure table is a string.""" + """Ensure table is a string. + + Parameters + ---------- + table : Union[str, Table, Iterable], optional + Table to ensure is a string, by default None. If passed as iterable, + will ensure all elements are strings. + force_list : bool, optional + Force the return to be a list, by default False, only used if input is + iterable. + + Returns + ------- + Union[str, List[str], None] + Table as a string or list of strings. + """ + # is iterable (list, set, set) but not a table/string + is_collection = isinstance(table, Iterable) and not isinstance( + table, (Table, TableMeta, str) + ) + if force_list and not is_collection: + return [ensure_names(table)] if table is None: return None if isinstance(table, str): return table - if isinstance(table, Iterable) and not isinstance( - table, (Table, TableMeta) - ): + if is_collection: return [ensure_names(t) for t in table] return getattr(table, "full_table_name", None) diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py index 533976329..ff3922087 100644 --- a/src/spyglass/utils/dj_mixin.py +++ b/src/spyglass/utils/dj_mixin.py @@ -914,11 +914,13 @@ def __rshift__(self, restriction) -> QueryExpression: def ban_search_table(self, table): """Ban table from search in restrict_by.""" - self._banned_search_tables.update(ensure_names(table)) + self._banned_search_tables.update(ensure_names(table, force_list=True)) def unban_search_table(self, table): """Unban table from search in restrict_by.""" - self._banned_search_tables.difference_update(ensure_names(table)) + self._banned_search_tables.difference_update( + ensure_names(table, force_list=True) + ) def see_banned_tables(self): """Print banned tables."""