Skip to content

Commit

Permalink
Ban tables in distance restrict bugfix (#1066)
Browse files Browse the repository at this point in the history
* Ban tables in distance restrict bugfix

* Update changelog
  • Loading branch information
CBroz1 authored Aug 20, 2024
1 parent 7b3eae9 commit 012ea30
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 22 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ PositionGroup.alter()
- Allow `ModuleNotFoundError` or `ImportError` for optional dependencies #1023
- Ensure integrity of group tables #1026
- Convert list of LFP artifact removed interval list to array #1046
- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1062, #1069
- Merge duplicate functions in decoding and spikesorting #1050, #1053, #1058,
#1066
- Revise docs organization.
- Misc -> Features/ForDevelopers. #1029
- Installation instructions -> Setup notebook. #1029
Expand Down
14 changes: 0 additions & 14 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,6 @@ def _camel(self, table):

# ------------------------------ Graph Nodes ------------------------------

def _ensure_names(
self, table: Union[str, Table] = None
) -> Union[str, List[str]]:
"""Ensure table is a string."""
if table is None:
return None
if isinstance(table, str):
return table
if isinstance(table, Iterable) and not isinstance(
table, (Table, TableMeta)
):
return [ensure_names(t) for t in table]
return getattr(table, "full_table_name", None)

def _get_node(self, table: Union[str, Table]):
"""Get node from graph."""
table = ensure_names(table)
Expand Down
29 changes: 24 additions & 5 deletions src/spyglass/utils/dj_helper_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,35 @@


def ensure_names(
table: Union[str, Table, Iterable] = None
table: Union[str, Table, Iterable] = None, force_list: bool = False
) -> Union[str, List[str], None]:
"""Ensure table is a string."""
"""Ensure table is a string.
Parameters
----------
table : Union[str, Table, Iterable], optional
Table to ensure is a string, by default None. If passed as iterable,
will ensure all elements are strings.
force_list : bool, optional
Force the return to be a list, by default False, only used if input is
iterable.
Returns
-------
Union[str, List[str], None]
Table as a string or list of strings.
"""
# is iterable (list, set, set) but not a table/string
is_collection = isinstance(table, Iterable) and not isinstance(
table, (Table, TableMeta, str)
)
if force_list and not is_collection:
return [ensure_names(table)]
if table is None:
return None
if isinstance(table, str):
return table
if isinstance(table, Iterable) and not isinstance(
table, (Table, TableMeta)
):
if is_collection:
return [ensure_names(t) for t in table]
return getattr(table, "full_table_name", None)

Expand Down
6 changes: 4 additions & 2 deletions src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -914,11 +914,13 @@ def __rshift__(self, restriction) -> QueryExpression:

def ban_search_table(self, table):
"""Ban table from search in restrict_by."""
self._banned_search_tables.update(ensure_names(table))
self._banned_search_tables.update(ensure_names(table, force_list=True))

def unban_search_table(self, table):
"""Unban table from search in restrict_by."""
self._banned_search_tables.difference_update(ensure_names(table))
self._banned_search_tables.difference_update(
ensure_names(table, force_list=True)
)

def see_banned_tables(self):
"""Print banned tables."""
Expand Down

0 comments on commit 012ea30

Please sign in to comment.