Skip to content

Commit

Permalink
Include externals
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Oct 31, 2024
1 parent 3e28542 commit 9d064b4
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 21 deletions.
86 changes: 73 additions & 13 deletions src/spyglass/common/common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ def stop_export(self, **kwargs) -> None:
# before actually exporting anything, which is more associated with
# Selection

def _list_raw_files(self, key: dict) -> list[str]:
    """List the distinct raw nwb file names for a given restriction/key.

    The join of this table with ``self.File`` is restricted by ``key`` and
    then joined onto AnalysisNwbfile (with ``log_export=False``); the
    fetched ``nwb_file_name`` values are de-duplicated before returning.
    """
    restricted = self * self.File & key
    with_parents = AnalysisNwbfile.join(restricted, log_export=False)
    raw_names = with_parents.fetch("nwb_file_name")
    return list(set(raw_names))

def _list_analysis_files(self, key: dict) -> list[str]:
    """Return a list of unique analysis file names for a given restriction/key.

    Parameters
    ----------
    key : dict
        Restriction applied to the join of this table with ``self.File``.

    Returns
    -------
    list[str]
        Each fetched ``analysis_file_name`` once, in the order fetched.
    """
    file_table = self * self.File & key
    # dict.fromkeys drops repeated names while keeping the fetch order, so
    # the result is actually unique (as documented) and deterministic.
    return list(dict.fromkeys(file_table.fetch("analysis_file_name")))

def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
"""Return a list of unique file paths for a given restriction/key.
Expand All @@ -159,18 +175,60 @@ def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
If False, returns a list of strings without key.
"""
file_table = self * self.File & key
analysis_fp = [
AnalysisNwbfile().get_abs_path(fname)
for fname in file_table.fetch("analysis_file_name")
]
nwbfile_fp = [
Nwbfile().get_abs_path(fname)
for fname in AnalysisNwbfile.join(
file_table, log_export=False
).fetch("nwb_file_name")
]
unique_ft = list({*analysis_fp, *nwbfile_fp})
return [{"file_path": p} for p in unique_ft] if as_dict else unique_ft
unique_fp = {
*[
AnalysisNwbfile().get_abs_path(p)
for p in self._list_analysis_files(key)
],
*[Nwbfile().get_abs_path(p) for p in self._list_raw_files(key)],
}

return [{"file_path": p} for p in unique_fp] if as_dict else unique_fp

@property
def _externals(self) -> dj.external.ExternalMapping:
    """Return a datajoint ExternalMapping constructed from AnalysisNwbfile.

    A new mapping object is created on every access.
    """
    mapping = dj.external.ExternalMapping(schema=AnalysisNwbfile)
    return mapping

def _add_externals_to_restr_graph(
    self, restr_graph: RestrGraph, key: dict
) -> RestrGraph:
    """Add external tables to a RestrGraph for a given restriction/key.

    Tables are added as nodes with restrictions based on file names. Their
    names are added to the visited set so they appear in the restr_ft obj
    passed to SQLDumpHelper. An external table is only registered when at
    least one file is found for it: an empty file list would otherwise
    produce the degenerate restrictions ``filepath in ('')`` and
    ``filepath REGEXP ''``.

    Parameters
    ----------
    restr_graph : RestrGraph
        A RestrGraph object to add external tables to.
    key : dict
        Any valid restriction key for ExportSelection.Table

    Returns
    -------
    restr_graph : RestrGraph
        The updated RestrGraph (the same object, modified in place).
    """
    externals = self._externals  # property builds a new mapping per access

    raw_files = self._list_raw_files(key)
    if raw_files:
        raw_tbl = externals["raw"]
        raw_name = raw_tbl.full_table_name
        raw_restr = "filepath in ('" + "','".join(raw_files) + "')"
        restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)
        restr_graph.visited.add(raw_name)

    analysis_files = self._list_analysis_files(key)
    if analysis_files:
        analysis_tbl = externals["analysis"]
        analysis_name = analysis_tbl.full_table_name
        # filepaths have analysis subdir, so match names as regexp
        # substrings. regexp is slow, but we're only doing this once, and
        # future-proof.
        # NOTE(review): names are used as-is in the pattern, so regex
        # metacharacters such as '.' are not escaped — confirm acceptable.
        analysis_restr = "filepath REGEXP '" + "|".join(analysis_files) + "'"
        restr_graph.graph.add_node(
            analysis_name, ft=analysis_tbl, restr=analysis_restr
        )
        restr_graph.visited.add(analysis_name)

    return restr_graph

def get_restr_graph(
self, key: dict, verbose=False, cascade=True
Expand All @@ -193,9 +251,11 @@ def get_restr_graph(
"table_name", "restriction", as_dict=True
)
)
return RestrGraph(

restr_graph = RestrGraph(
seed_table=self, leaves=leaves, verbose=verbose, cascade=cascade
)
return self._add_externals_to_restr_graph(restr_graph, key)

def preview_tables(self, **kwargs) -> list[dj.FreeTable]:
"""Return a list of restricted FreeTables for a given restriction/key.
Expand Down
15 changes: 8 additions & 7 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,13 +808,14 @@ def file_paths(self) -> List[str]:
directly by the user.
"""
self.cascade()
return [
self.analysis_file_tbl.get_abs_path(file)
for file in set(
[f for files in self.file_dict.values() for f in files]
)
if file is not None
]

files = {
file
for table in self.visited
for file in self._get_node(table).get("files", [])
}

return [self.analysis_file_tbl.get_abs_path(file) for file in files]


class TableChain(RestrGraph):
Expand Down
4 changes: 3 additions & 1 deletion src/spyglass/utils/mixins/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,9 @@ def restrict(self, restriction):
return super().restrict(restriction)
log_export = "fetch_nwb" not in self._called_funcs()
return self._run_with_log(
super().restrict, restriction=dj.AndList([restriction, self.restriction]), log_export=log_export
super().restrict,
restriction=dj.AndList([restriction, self.restriction]),
log_export=log_export,
)

def join(self, other, log_export=True, *args, **kwargs):
Expand Down

0 comments on commit 9d064b4

Please sign in to comment.