Skip to content

Commit

Permalink
Merge branch '1144' of https://github.com/CBroz1/spyglass into 1144
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Nov 4, 2024
2 parents 6920722 + 1883fee commit 14f49d5
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 29 deletions.
86 changes: 73 additions & 13 deletions src/spyglass/common/common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@ def stop_export(self, **kwargs) -> None:
# before actually exporting anything, which is more associated with
# Selection

def _list_raw_files(self, key: dict) -> list[str]:
    """Collect the distinct raw NWB file names associated with a key.

    Parameters
    ----------
    key : dict
        Restriction applied to the join of this table with ``self.File``.

    Returns
    -------
    list[str]
        De-duplicated ``nwb_file_name`` values. The values pass through a
        set, so their order is not defined.
    """
    restricted = self * self.File & key
    # Join against AnalysisNwbfile to reach the parent nwb_file_name column;
    # log_export=False is forwarded to the join unchanged.
    joined = AnalysisNwbfile.join(restricted, log_export=False)
    return list(set(joined.fetch("nwb_file_name")))

def _list_analysis_files(self, key: dict) -> list[str]:
    """Return the unique analysis file names for a given restriction/key.

    Parameters
    ----------
    key : dict
        Restriction applied to the join of this table with ``self.File``.

    Returns
    -------
    list[str]
        ``analysis_file_name`` values with duplicates removed. The order of
        first appearance in the fetch is preserved.
    """
    file_table = self * self.File & key
    # dict.fromkeys de-duplicates while keeping a deterministic order, so the
    # result actually matches the "unique" contract in the docstring.
    return list(dict.fromkeys(file_table.fetch("analysis_file_name")))

def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
"""Return a list of unique file paths for a given restriction/key.
Expand All @@ -159,18 +175,60 @@ def list_file_paths(self, key: dict, as_dict=True) -> list[str]:
If False, returns a list of strings without key.
"""
file_table = self * self.File & key
analysis_fp = [
AnalysisNwbfile().get_abs_path(fname)
for fname in file_table.fetch("analysis_file_name")
]
nwbfile_fp = [
Nwbfile().get_abs_path(fname)
for fname in AnalysisNwbfile.join(
file_table, log_export=False
).fetch("nwb_file_name")
]
unique_ft = list({*analysis_fp, *nwbfile_fp})
return [{"file_path": p} for p in unique_ft] if as_dict else unique_ft
unique_fp = {
*[
AnalysisNwbfile().get_abs_path(p)
for p in self._list_analysis_files(key)
],
*[Nwbfile().get_abs_path(p) for p in self._list_raw_files(key)],
}

return [{"file_path": p} for p in unique_fp] if as_dict else unique_fp

@property
def _externals(self) -> dj.external.ExternalMapping:
    """Build the DataJoint external mapping from ``AnalysisNwbfile``.

    A new ``ExternalMapping`` is constructed on every access.
    """
    # NOTE(review): the previous docstring called this the "common_n" schema,
    # presumably the schema that declares AnalysisNwbfile -- confirm.
    mapping = dj.external.ExternalMapping(schema=AnalysisNwbfile)
    return mapping

def _add_externals_to_restr_graph(
    self, restr_graph: "RestrGraph", key: dict
) -> "RestrGraph":
    """Add external tables to a RestrGraph for a given restriction/key.

    The "raw" and "analysis" external tables are added as graph nodes whose
    restrictions are built from the file names linked to ``key``. Their
    names are also added to the graph's ``visited`` set so that they appear
    in the restr_ft obj passed to SQLDumpHelper.

    Parameters
    ----------
    restr_graph : RestrGraph
        A RestrGraph object to add external tables to. Modified in place.
    key : dict
        Any valid restriction key for ExportSelection.Table

    Returns
    -------
    restr_graph : RestrGraph
        The same RestrGraph object, updated.
    """
    # The property builds a fresh mapping on every access; read it once.
    externals = self._externals

    raw_tbl = externals["raw"]
    raw_name = raw_tbl.full_table_name
    # An empty file list yields "filepath in ('')", which is still valid SQL.
    raw_restr = (
        "filepath in ('" + "','".join(self._list_raw_files(key)) + "')"
    )
    restr_graph.graph.add_node(raw_name, ft=raw_tbl, restr=raw_restr)

    analysis_tbl = externals["analysis"]
    analysis_name = analysis_tbl.full_table_name
    analysis_files = self._list_analysis_files(key)
    if analysis_files:
        # Filepaths have an analysis subdir, so match file names as regexp
        # substrings. Regexp is slow, but this runs once per export.
        # NOTE(review): names are not regex-escaped, so "." matches any
        # character; this can only over-match, never under-match.
        analysis_restr = (
            "filepath REGEXP '" + "|".join(analysis_files) + "'"
        )
    else:
        # An empty alternation would produce "REGEXP ''", an empty pattern.
        # Select nothing explicitly instead.
        analysis_restr = "FALSE"
    restr_graph.graph.add_node(
        analysis_name, ft=analysis_tbl, restr=analysis_restr
    )

    restr_graph.visited.update({raw_name, analysis_name})

    return restr_graph

def get_restr_graph(
self, key: dict, verbose=False, cascade=True
Expand All @@ -193,9 +251,11 @@ def get_restr_graph(
"table_name", "restriction", as_dict=True
)
)
return RestrGraph(

restr_graph = RestrGraph(
seed_table=self, leaves=leaves, verbose=verbose, cascade=cascade
)
return self._add_externals_to_restr_graph(restr_graph, key)

def preview_tables(self, **kwargs) -> list[dj.FreeTable]:
"""Return a list of restricted FreeTables for a given restriction/key.
Expand Down
15 changes: 8 additions & 7 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,13 +808,14 @@ def file_paths(self) -> List[str]:
directly by the user.
"""
self.cascade()
return [
self.analysis_file_tbl.get_abs_path(file)
for file in set(
[f for files in self.file_dict.values() for f in files]
)
if file is not None
]

files = {
file
for table in self.visited
for file in self._get_node(table).get("files", [])
}

return [self.analysis_file_tbl.get_abs_path(file) for file in files]


class TableChain(RestrGraph):
Expand Down
3 changes: 2 additions & 1 deletion src/spyglass/utils/dj_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def _nwb_table_tuple(self) -> tuple:
table_dict[resolved],
)

def fetch_nwb(self, log_export=True, *attrs, **kwargs):
def fetch_nwb(self, *attrs, **kwargs):
"""Fetch NWBFile object from relevant table.
Implementing class must have a foreign key reference to Nwbfile or
Expand All @@ -184,6 +184,7 @@ def fetch_nwb(self, log_export=True, *attrs, **kwargs):
"""
table, tbl_attr = self._nwb_table_tuple

log_export = kwargs.pop("log_export", True)
if log_export and self.export_id and "analysis" in tbl_attr:
self._log_fetch_nwb(table, tbl_attr)

Expand Down
7 changes: 6 additions & 1 deletion src/spyglass/utils/mixins/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from re import match as re_match

from datajoint.condition import make_condition
from datajoint.table import Table
from packaging.version import parse as version_parse

from spyglass.utils.logging import logger
Expand Down Expand Up @@ -264,6 +265,8 @@ def _run_join(self, **kwargs):

joined = self.proj().join(other.proj(), log_export=False)
for table in table_list: # log separate for unique pks
if isinstance(table, type) and issubclass(table, Table):
table = table() # adapted from dj.declare.compile_foreign_key
for r in joined.fetch(*table.primary_key, as_dict=True):
table._log_fetch(restriction=r)

Expand Down Expand Up @@ -316,7 +319,9 @@ def restrict(self, restriction):
return super().restrict(restriction)
log_export = "fetch_nwb" not in self._called_funcs()
return self._run_with_log(
super().restrict, restriction=restriction, log_export=log_export
super().restrict,
restriction=dj.AndList([restriction, self.restriction]),
log_export=log_export,
)

def join(self, other, log_export=True, *args, **kwargs):
Expand Down
53 changes: 46 additions & 7 deletions src/spyglass/utils/sql_helper_fn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from functools import cached_property
from os import system as os_system
from pathlib import Path
Expand Down Expand Up @@ -75,10 +76,11 @@ def _write_sql_cnf(self):

def _cmd_prefix(self, docker_id=None):
"""Get prefix for mysqldump command. Includes docker exec if needed."""
default = "mysqldump --hex-blob "
if not docker_id:
return "mysqldump "
return default
return (
f"docker exec -i {docker_id} \\\n\tmysqldump "
f"docker exec -i {docker_id} \\\n\t{default}"
+ "-u {user} --password={password} \\\n\t".format(
**self._get_credentials()
)
Expand Down Expand Up @@ -196,6 +198,34 @@ def _export_conda_env(self):
self._logger.info(f"Conda environment exported to {yml_path}")


def remove_redundant(s):
    """Remove redundant (directly nested) parentheses from a string.

    '((a=b)OR((c=d)AND((e=f))))' -> '((a=b) OR ((c=d) AND (e=f)))'

    Only duplicated wrappers such as '((x))' are collapsed; a full solve
    would require parsing the content. The result is always wrapped in one
    outer pair of parentheses, and whole-word AND/OR (any case) are padded
    with spaces for readability.

    Parameters
    ----------
    s : str
        Expression with balanced parentheses.

    Returns
    -------
    str
        The simplified expression, or '' when ``s`` is empty.

    Raises
    ------
    ValueError
        If the parentheses in ``s`` are not balanced.
    """
    if not s:
        return ""

    # Parse into nested lists with an explicit stack instead of eval().
    root = []  # children of the implicit outermost group
    stack = [root]
    for token in re.split(r"([()])", s):
        if token == "(":
            group = []
            stack[-1].append(group)
            stack.append(group)
        elif token == ")":
            if len(stack) == 1:
                raise ValueError(f"Unbalanced ')' in expression: {s!r}")
            stack.pop()
        elif token:  # skip empty strings between adjacent parentheses
            stack[-1].append(token)
    if len(stack) != 1:
        raise ValueError(f"Unbalanced '(' in expression: {s!r}")

    def collapse(node):
        """Replace a group whose only child is a group with that child."""
        children = [collapse(c) if isinstance(c, list) else c for c in node]
        if len(children) == 1 and isinstance(children[0], list):
            return children[0]
        return children

    def render(node):
        """Convert the nested lists back to a parenthesized string."""
        if isinstance(node, list):  # empty groups, e.g. NOW(), stay as '()'
            return "(%s)" % "".join(map(render, node))
        return node

    as_str = render(collapse(root))

    # space out AND and OR for readability
    return re.sub(r"\b(and|or)\b", r" \1 ", as_str, flags=re.IGNORECASE)


def bash_escape_sql(s, add_newline=True):
"""Escape restriction string for bash.
Expand All @@ -207,12 +237,22 @@ def bash_escape_sql(s, add_newline=True):
Add newlines for readability around AND & OR. Default True
"""
s = s.strip()
if s.startswith("WHERE"):
s = s[5:].strip()

# Balance parentheses - because make_condition may unbalance outside parens
n_open = s.count("(")
n_close = s.count(")")
add_open = max(0, n_close - n_open)
add_close = max(0, n_open - n_close)
balanced = "(" * add_open + s + ")" * add_close

s = remove_redundant(balanced)

replace_map = {
"WHERE ": "", # Remove preceding WHERE of dj.where_clause
" ": " ", # Squash double spaces
"( (": "((", # Squash double parens
") )": ")",
") )": "))",
'"': "'", # Replace double quotes with single
"`": "", # Remove backticks
}
Expand All @@ -231,7 +271,6 @@ def bash_escape_sql(s, add_newline=True):
replace_map.update({"%%%%": "%%"}) # Remove extra percent signs

for old, new in replace_map.items():
s = s.replace(old, new)
if s.startswith("(((") and s.endswith(")))"):
s = s[2:-2] # Remove extra parens for readability
s = re.sub(re.escape(old), new, s)

return s

0 comments on commit 14f49d5

Please sign in to comment.