From 0a49b12c4943f46f5d0788a45df23e3d4fbeb32a Mon Sep 17 00:00:00 2001 From: Chris Brozdowski Date: Thu, 19 Oct 2023 16:27:12 -0500 Subject: [PATCH] Generate utils from `config/` scripts (#662) * Contrib doc edits * #531 * #532: PR Template * Fix numbering * Typo * #658 config. Start to phase out config dir * Update changelog * Add docstrings. Separate sql content from funcs --- CHANGELOG.md | 6 + config/add_dj_collaborator.py | 33 +--- config/add_dj_guest.py | 38 +--- config/add_dj_module.py | 42 +---- config/add_dj_user.py | 46 +---- config/dj_config.py | 114 ++---------- dj_local_conf_example.json | 7 +- docs/src/installation.md | 29 ++- notebooks/00_Setup.ipynb | 27 ++- notebooks/py_scripts/00_Setup.py | 27 ++- src/spyglass/settings.py | 229 +++++++++++++++++++++--- src/spyglass/utils/database_settings.py | 165 +++++++++++++++++ 12 files changed, 489 insertions(+), 274 deletions(-) create mode 100644 src/spyglass/utils/database_settings.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 514c172ab..db2d7182d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## [Unreleased] + +- Migrate `config` helper scripts to Spyglass codebase. #662 +- Revise contribution guidelines. #655 +- Minor bug fixes. #656, #657, #659, #651 + ## [0.4.2] (October 10, 2023) ### Infrastructure / Support diff --git a/config/add_dj_collaborator.py b/config/add_dj_collaborator.py index 4efdd4653..cf6c0c48e 100644 --- a/config/add_dj_collaborator.py +++ b/config/add_dj_collaborator.py @@ -1,31 +1,12 @@ #!/usr/bin/env python -import os import sys -import tempfile - - -def add_collab_user(user_name): - # create a temporary file for the command - file = tempfile.NamedTemporaryFile(mode="w") - - # Create the user (if not already created) and set the password - file.write( - f"CREATE USER IF NOT EXISTS '{user_name}'@'%' IDENTIFIED BY 'temppass';\n" - ) - - # Grant privileges to databases matching the user_name pattern - file.write( - f"GRANT ALL PRIVILEGES ON `{user_name}\_%`.* TO '{user_name}'@'%';\n" - ) - - # Grant SELECT privileges on all databases - file.write(f"GRANT SELECT ON `%`.* TO '{user_name}'@'%';\n") - - file.flush() - - # run those commands in sql - os.system(f"mysql -p -h lmf-db.cin.ucsf.edu < {file.name}") +from warnings import warn +from spyglass.utils.database_settings import DatabaseSettings if __name__ == "__main__": - add_collab_user(sys.argv[1]) + warn( + "This script is deprecated. " + + "Use spyglass.utils.database_settings.DatabaseSettings instead." + ) + DatabaseSettings(user_name=sys.argv[1]).add_collab_user() diff --git a/config/add_dj_guest.py b/config/add_dj_guest.py index 6487af831..1a087a3a8 100644 --- a/config/add_dj_guest.py +++ b/config/add_dj_guest.py @@ -1,36 +1,12 @@ #!/usr/bin/env python -import os import sys -import tempfile - -shared_modules = [ - "common\_%", - "spikesorting\_%", - "decoding\_%", - "position\_%", - "position_linearization\_%", - "ripple\_%", - "lfp\_%", -] - - -def add_user(user_name): - # create a temporary file for the command - file = tempfile.NamedTemporaryFile(mode="w") - - # Create the user (if not already created) and set password - file.write( - f"CREATE USER IF NOT EXISTS '{user_name}'@'%' IDENTIFIED BY 'Data_$haring';\n" - ) - - # Grant privileges - file.write(f"GRANT SELECT ON `%`.* TO '{user_name}'@'%';\n") - - file.flush() - - # run those commands in sql - os.system(f"mysql -p -h lmf-db.cin.ucsf.edu < {file.name}") +from warnings import warn +from spyglass.utils.database_settings import DatabaseSettings if __name__ == "__main__": - add_user(sys.argv[1]) + warn( + "This script is deprecated. " + + "Use spyglass.utils.database_settings.DatabaseSettings instead." + ) + DatabaseSettings(user_name=sys.argv[1]).add_dj_guest() diff --git a/config/add_dj_module.py b/config/add_dj_module.py index 78213f8ee..218737909 100644 --- a/config/add_dj_module.py +++ b/config/add_dj_module.py @@ -1,40 +1,12 @@ #!/usr/bin/env python -import grp -import os import sys -import tempfile - -TARGET_GROUP = "kachery-users" - - -def add_module(module_name): - print(f"Granting everyone permissions to module {module_name}") - - # create a tempoary file for the command - file = tempfile.NamedTemporaryFile(mode="w") - - # find the kachery-users group - groups = grp.getgrall() - group_found = False # initialize the flag as False - for group in groups: - if group.gr_name == TARGET_GROUP: - group_found = True # set the flag to True when the group is found - break - - # Check if the group was found - if not group_found: - sys.exit(f"Error: The target group {TARGET_GROUP} was not found.") - - # get a list of usernames - for user in group.gr_mem: - file.write( - f"GRANT ALL PRIVILEGES ON `{module_name}\_%`.* TO `{user}`@'%';\n" - ) - file.flush() - - # run those commands in sql - os.system(f"mysql -p -h lmf-db.cin.ucsf.edu < {file.name}") +from warnings import warn +from spyglass.utils.database_settings import DatabaseSettings if __name__ == "__main__": - add_module(sys.argv[1]) + warn( + "This script is deprecated. " + + "Use spyglass.utils.database_settings.DatabaseSettings instead." + ) + DatabaseSettings().add_module(sys.argv[1]) diff --git a/config/add_dj_user.py b/config/add_dj_user.py index 77ca7c457..ba7dabb3b 100755 --- a/config/add_dj_user.py +++ b/config/add_dj_user.py @@ -1,44 +1,12 @@ #!/usr/bin/env python -import os import sys -import tempfile - -shared_modules = [ - "common\_%", - "spikesorting\_%", - "decoding\_%", - "position\_%", - "position_linearization\_%", - "ripple\_%", - "lfp\_%", -] - - -def add_user(user_name): - if os.path.isdir(f"/home/{user_name}"): - print("Creating database user ", user_name) - else: - sys.exit(f"Error: user_name {user_name} does not exist in /home.") - - # create a tempoary file for the command - file = tempfile.NamedTemporaryFile(mode="w") - create_user_query = f"CREATE USER IF NOT EXISTS '{user_name}'@'%' IDENTIFIED BY 'temppass';\n" - grant_privileges_query = ( - f"GRANT ALL PRIVILEGES ON `{user_name}\_%`.* TO '{user_name}'@'%';" - ) - - file.write(create_user_query + "\n") - file.write(grant_privileges_query + "\n") - for module in shared_modules: - file.write( - f"GRANT ALL PRIVILEGES ON `{module}`.* TO '{user_name}'@'%';\n" - ) - file.write(f"GRANT SELECT ON `%`.* TO '{user_name}'@'%';\n") - file.flush() - - # run those commands in sql - os.system(f"mysql -p -h lmf-db.cin.ucsf.edu < {file.name}") +from warnings import warn +from spyglass.utils.database_settings import DatabaseSettings if __name__ == "__main__": - add_user(sys.argv[1]) + warn( + "This script is deprecated. " + + "Use spyglass.utils.database_settings.DatabaseSettings instead." + ) + DatabaseSettings(user_name=sys.argv[1]).add_dj_user() diff --git a/config/dj_config.py b/config/dj_config.py index 06a9e725c..55fd8a9ad 100644 --- a/config/dj_config.py +++ b/config/dj_config.py @@ -2,110 +2,30 @@ import os import sys -import tempfile -import datajoint as dj -import yaml -import json -import warnings -from pymysql.err import OperationalError - - -def generate_config(filename: str = None, **kwargs): - """Generate a datajoint configuration file. - - Parameters - ---------- - filename : str - The name of the file to generate. Must be either yaml or json - **kwargs: list of parameters names and values that can include - base_dir : SPYGLASS_BASE_DIR - database_user : user name of system running mysql - database_host : mysql host name (default lmf-db.cin.ucsf.edu) - database_port : port number for mysql server (default 3306) - database_use_tls : Default True. Use TLS encryption. - """ - # TODO: merge with existing spyglass.settings.py - - base_dir = os.environ.get("SPYGLASS_BASE_DIR") or kwargs.get("base_dir") - if not base_dir: - raise ValueError( - "Please set base directory environment variable SPYGLASS_BASE_DIR" - ) - - base_dir = os.path.abspath(base_dir) - if not os.path.exists(base_dir): - warnings.warn(f"Base dir does not exist on this machine: {base_dir}") - - raw_dir = os.path.join(base_dir, "raw") - analysis_dir = os.path.join(base_dir, "analysis") - - config = { - "database.host": kwargs.get("database_host", "lmf-db.cin.ucsf.edu"), - "database.user": kwargs.get("database_user"), - "database.port": kwargs.get("database_port", 3306), - "database.use_tls": kwargs.get("database_use_tls", True), - "filepath_checksum_size_limit": 1 * 1024**3, - "enable_python_native_blobs": True, - "stores": { - "raw": { - "protocol": "file", - "location": raw_dir, - "stage": raw_dir, - }, - "analysis": { - "protocol": "file", - "location": analysis_dir, - "stage": analysis_dir, - }, - }, - "custom": {"spyglass_dirs": {"base": base_dir}}, - } - if not kwargs.get("database_user"): - # Adding then removing if empty retains order to make easier to read - config.pop("database.user") - - if not filename: - filename = "dj_local_config.json" - if os.path(filename).exists(): - warnings.warn(f"File already exists: {filename}") - else: - with open(filename, "w") as outfile: - if filename.endswith("json"): - json.dump(config, outfile, indent=2) - else: - yaml.dump(config, outfile, default_flow_style=False) - - return config - - -def set_configuration(config: dict): - """Sets the dj.config parameters. - - Parameters - ---------- - config : dict - Datajoint config as dictionary - """ - # copy the elements of config to dj.config - for key, value in config.items(): - dj.config[key] = value +def main(*args): + database_user, base_dir, filename = args + (None,) * (3 - len(args)) - dj.set_password() # set the users password - dj.config.save_global() # save these settings + os.environ["SPYGLASS_BASE_DIR"] = base_dir # need to set for import to work + from spyglass.settings import SpyglassConfig # noqa F401 -def main(*args): - user_name, base_dir, outfile = args + (None,) * (3 - len(args)) + config = SpyglassConfig(base_dir=base_dir) + save_method = ( + "local" + if filename == "dj_local_conf.json" + else "global" + if filename is None + else "custom" + ) - config = generate_config( - outfile, database_user=user_name, base_dir=base_dir + config.save_dj_config( + save_method=save_method, + filename=filename, + base_dir=base_dir, + database_user=database_user, ) - try: - set_configuration(config) - except OperationalError as e: - warnings.warn(f"Database connections issues: {e}") if __name__ == "__main__": diff --git a/dj_local_conf_example.json b/dj_local_conf_example.json index cf57ae7be..e962481d1 100644 --- a/dj_local_conf_example.json +++ b/dj_local_conf_example.json @@ -1,7 +1,7 @@ { - "database.host": "localhost", - "database.password": "tutorial", - "database.user": "root", + "database.host": "localhost or lmf-db.cin.ucsf.edu", + "database.password": "Delete this line for shared machines", + "database.user": "Your username", "database.port": 3306, "database.reconnect": true, "connection.init_function": null, @@ -28,7 +28,6 @@ } }, "custom": { - "database.prefix": "username_", "spyglass_dirs": { "base": "/your/base/path" }, diff --git a/docs/src/installation.md b/docs/src/installation.md index 3b6e95ddb..622c97ddb 100644 --- a/docs/src/installation.md +++ b/docs/src/installation.md @@ -47,6 +47,8 @@ additional details, see the ### Config +#### Via File (Recommended) + A `dj_local_conf.json` file in your Spyglass directory (or wherever python is launched) can hold all the specifics needed to connect to a database. This can include different directories for different pipelines. If only the `base` is @@ -69,8 +71,19 @@ specified, the subfolder names below are included as defaults. } ``` -For those who prefer environment variables, the following can pasted into a -file like `~/.bashrc`. +`dj_local_conf_example.json` can be copied and saved as `dj_local_conf.json` +to set the configuration for a given folder. Alternatively, it can be saved as +`.datajoint_config.json` in a user's home directory to be accessed globally. +See +[DataJoint docs](https://datajoint.com/docs/core/datajoint-python/0.14/quick-start/#connection) +for more details. + +#### Via Environment Variables + +Older versions of Spyglass relied exclusively on environment for config. If +`spyglass_dirs` is not found in the config file, Spyglass will look for +environment variables. These can be set either once in a terminal session, or +permanently in a `.bashrc` file. ```bash export SPYGLASS_BASE_DIR="/stelmo/nwb" @@ -78,15 +91,17 @@ export SPYGLASS_RECORDING_DIR="$SPYGLASS_BASE_DIR/recording" export SPYGLASS_SORTING_DIR="$SPYGLASS_BASE_DIR/sorting" export SPYGLASS_VIDEO_DIR="$SPYGLASS_BASE_DIR/video" export SPYGLASS_WAVEFORMS_DIR="$SPYGLASS_BASE_DIR/waveforms" -export SPYGLASS_TEMP_DIR="$SPYGLASS_BASE_DIR/tmp/spyglass" +export SPYGLASS_TEMP_DIR="$SPYGLASS_BASE_DIR/tmp" export DJ_SUPPORT_FILEPATH_MANAGEMENT="TRUE" ``` -And then loaded with `source ~/.bashrc`. +To load variables from a `.bashrc` file, run `source ~/.bashrc` in a terminal. + +#### Temporary directory -Note that a local `SPYGLASS_TEMP_DIR` (e.g., one on your machine) will speed -up spike sorting, but make sure it has enough free space (ideally at least -500GB) +A temporary directory will speed up spike sorting. If unspecified by either +method above, it will be assumed as a `tmp` subfolder relative to the base +path. Be sure it has enough free space (ideally at least 500GB). ## File manager diff --git a/notebooks/00_Setup.ipynb b/notebooks/00_Setup.ipynb index 86beb6b32..ff462b42b 100644 --- a/notebooks/00_Setup.ipynb +++ b/notebooks/00_Setup.ipynb @@ -33,7 +33,8 @@ "source": [ "## Local environment\n", "\n", - "Codespace users can skip this step.\n", + "Codespace users can skip this step. Frank Lab members should first follow\n", + "'rec to nwb overview' steps on Google Drive to set up an ssh connection.\n", "\n", "For local use, download and install ...\n", "\n", @@ -119,20 +120,36 @@ "metadata": {}, "source": [ "Members of the Frank Lab can run the `dj_config.py` helper script to generate\n", - "a default `dj_local_conf.json` like the one below. Outside users should adjust\n", - "values accordingly.\n", + "a config like the one below. Outside users should copy/paste `dj_local_conf_example` and adjust values accordingly.\n", "\n", "```bash\n", "cd spyglass\n", - "python config/dj_config.py \n", + "python config/dj_config.py \n", "```\n", "\n", - "Producing a json config like the following.\n", + "The base path (formerly `SPYGLASS_BASE_DIR`) is the directory where all data\n", + "will be saved. See also\n", + "[docs](https://lorenfranklab.github.io/spyglass/0.4/installation/) for more\n", + "information on subdirectories.\n", + "\n", + "A different `output_filename` will save different files: \n", + "\n", + "- `dj_local_conf.json`: Recommended. Used for tutorials. A file in the current\n", + " directory DataJoint will automatically recognize when a Python session is\n", + " launched from this directory.\n", + "- `.datajoint_config.json` or no input: A file in the user's home directory \n", + " that will be referenced whenever no local version (see above) is present.\n", + "- Anything else: A custom name that will need to be loaded (e.g.,\n", + " `dj.load('x')`) for each python session.\n", + "\n", + "\n", + "The config will be a `json` file like the following.\n", "\n", "```json\n", "{\n", " \"database.host\": \"lmf-db.cin.ucsf.edu\",\n", " \"database.user\": \"\",\n", + " \"database.password\": \"Not recommended for shared machines\",\n", " \"database.port\": 3306,\n", " \"database.use_tls\": true,\n", " \"enable_python_native_blobs\": true,\n", diff --git a/notebooks/py_scripts/00_Setup.py b/notebooks/py_scripts/00_Setup.py index 561da0c73..583777de8 100644 --- a/notebooks/py_scripts/00_Setup.py +++ b/notebooks/py_scripts/00_Setup.py @@ -30,7 +30,8 @@ # ## Local environment # -# Codespace users can skip this step. +# Codespace users can skip this step. Frank Lab members should first follow +# 'rec to nwb overview' steps on Google Drive to set up an ssh connection. # # For local use, download and install ... # @@ -96,20 +97,36 @@ # # Members of the Frank Lab can run the `dj_config.py` helper script to generate -# a default `dj_local_conf.json` like the one below. Outside users should adjust -# values accordingly. +# a config like the one below. Outside users should copy/paste `dj_local_conf_example` and adjust values accordingly. # # ```bash # # cd spyglass -# python config/dj_config.py +# python config/dj_config.py # ``` # -# Producing a json config like the following. +# The base path (formerly `SPYGLASS_BASE_DIR`) is the directory where all data +# will be saved. See also +# [docs](https://lorenfranklab.github.io/spyglass/0.4/installation/) for more +# information on subdirectories. +# +# A different `output_filename` will save different files: +# +# - `dj_local_conf.json`: Recommended. Used for tutorials. A file in the current +# directory DataJoint will automatically recognize when a Python session is +# launched from this directory. +# - `.datajoint_config.json` or no input: A file in the user's home directory +# that will be referenced whenever no local version (see above) is present. +# - Anything else: A custom name that will need to be loaded (e.g., +# `dj.load('x')`) for each python session. +# +# +# The config will be a `json` file like the following. # # ```json # { # "database.host": "lmf-db.cin.ucsf.edu", # "database.user": "", +# "database.password": "Not recommended for shared machines", # "database.port": 3306, # "database.use_tls": true, # "enable_python_native_blobs": true, diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py index 57a8ead05..7f1ff9572 100644 --- a/src/spyglass/settings.py +++ b/src/spyglass/settings.py @@ -1,7 +1,11 @@ +import json import os +import warnings from pathlib import Path import datajoint as dj +import yaml +from pymysql.err import OperationalError class SpyglassConfig: @@ -14,7 +18,7 @@ class SpyglassConfig: facilitate testing. """ - def __init__(self, base_dir=None): + def __init__(self, base_dir=None, **kwargs): """ Initializes a new instance of the class. @@ -48,6 +52,15 @@ def __init__(self, base_dir=None): }, } + self.dj_defaults = { + "database.host": kwargs.get("database_host", "lmf-db.cin.ucsf.edu"), + "database.user": kwargs.get("database_user"), + "database.port": kwargs.get("database_port", 3306), + "database.use_tls": kwargs.get("database_use_tls", True), + "filepath_checksum_size_limit": 1 * 1024**3, + "enable_python_native_blobs": True, + } + self.env_defaults = { "FIGURL_CHANNEL": "franklab2", "DJ_SUPPORT_FILEPATH_MANAGEMENT": "TRUE", @@ -92,9 +105,11 @@ def load_config(self, force_reload=False): or os.environ.get("SPYGLASS_BASE_DIR") ) - if not resolved_base: + if not resolved_base or not Path(resolved_base).exists(): raise ValueError( - "SPYGLASS_BASE_DIR not defined in dj.config or os env vars" + f"Could not find SPYGLASS_BASE_DIR: {resolved_base}" + + "\n\tCheck dj.config['custom']['spyglass_dirs']['base']" + + "\n\tand os.environ['SPYGLASS_BASE_DIR']" ) config_dirs = {"SPYGLASS_BASE_DIR": resolved_base} @@ -112,7 +127,7 @@ def load_config(self, force_reload=False): dj_spyglass.get(dir) or dj_kachery.get(dir) or env_loc - or resolved_base + "/" + dir_str + or str(Path(resolved_base) / dir_str) ).replace('"', "") config_dirs.update({dir_env_fmt: dir_location}) @@ -130,7 +145,6 @@ def load_config(self, force_reload=False): {**config_dirs, **kachery_zone_dict, **loaded_env} ) self._mkdirs_from_dict_vals(config_dirs) - self._set_dj_config_stores(config_dirs) self._config = dict( debug_mode=dj_custom.get("debug_mode", False), @@ -139,6 +153,9 @@ def load_config(self, force_reload=False): **kachery_zone_dict, **loaded_env, ) + + self._set_dj_config_stores(config_dirs) + return self._config def _load_env_vars(self): @@ -149,6 +166,7 @@ def _load_env_vars(self): @staticmethod def _set_env_with_dict(env_dict): + # NOTE: Kept for backwards compatibility. Should be removed in future. for var, val in env_dict.items(): os.environ[var] = str(val) @@ -157,8 +175,7 @@ def _mkdirs_from_dict_vals(dir_dict): for dir_str in dir_dict.values(): Path(dir_str).mkdir(exist_ok=True) - @staticmethod - def _set_dj_config_stores(dir_dict, check_match=True, set_stores=True): + def _set_dj_config_stores(self, check_match=True, set_stores=True): """ Checks dj.config['stores'] match resolved dirs. Ensures stores set. @@ -171,9 +188,6 @@ def _set_dj_config_stores(dir_dict, check_match=True, set_stores=True): set_stores: bool Optional. Default True. Set dj.config['stores'] to resolved dirs. """ - raw_dir = Path(dir_dict["SPYGLASS_RAW_DIR"]) - analysis_dir = Path(dir_dict["SPYGLASS_ANALYSIS_DIR"]) - if check_match: dj_stores = dj.config.get("stores", {}) store_raw = dj_stores.get("raw", {}).get("location") @@ -184,40 +198,197 @@ def _set_dj_config_stores(dir_dict, check_match=True, set_stores=True): + "\n\tdj.config['stores']['{0}']['location']:\n\t\t{1}" + "\n\tSPYGLASS_{2}_DIR:\n\t\t{3}." ) - if store_raw and Path(store_raw) != raw_dir: + if store_raw and Path(store_raw) != Path(self.raw_dir): raise ValueError( - err_template.format("raw", store_raw, "RAW", raw_dir) + err_template.format("raw", store_raw, "RAW", self.raw_dir) ) - if store_analysis and Path(store_analysis) != analysis_dir: + if store_analysis and Path(store_analysis) != Path( + self.analysis_dir + ): raise ValueError( err_template.format( - "analysis", store_analysis, "ANALYSIS", analysis_dir + "analysis", + store_analysis, + "ANALYSIS", + self.analysis_dir, ) ) if set_stores: - dj.config["stores"] = { + dj.config.update(self._dj_stores) + + def dir_to_var(self, dir: str, dir_type: str = "spyglass"): + """Converts a dir string to an env variable name.""" + dir_string = self.relative_dirs.get(dir_type, {}).get(dir, "base") + return f"{dir_type.upper()}_{dir_string.upper()}_DIR" + + def _generate_dj_config( + self, + base_dir: str = None, + database_user: str = None, + database_host: str = "lmf-db.cin.ucsf.edu", + database_port: int = 3306, + database_use_tls: bool = True, + **kwargs, + ): + """Generate a datajoint configuration file. + + Parameters + ---------- + base_dir : str, optional + The base directory. If not provided, will use the env variable or + existing config. + database_user : str, optional + The database user. If not provided, resulting config will not + specify. + database_host : str, optional + Default lmf-db.cin.ucsf.edu. MySQL host name. + dapabase_port : int, optional + Default 3306. Port number for MySQL server. + database_use_tls : bool, optional + Default True. Use TLS encryption. + **kwargs: dict, optional + Any other valid datajoint configuration parameters. + Note: python will raise error for params with `.` in name. + """ + + if base_dir: + self.supplied_base_dir = base_dir + self.load_config(force_reload=True) + + if database_user: + kwargs.update({"database.user": database_user}) + + kwargs.update( + { + "database.host": database_host, + "database.port": database_port, + "database.use_tls": database_use_tls, + } + ) + + # `|` merges dictionaries + return self.dj_defaults | self._dj_stores | self._dj_custom | kwargs + + def save_dj_config( + self, + save_method: str = "global", + filename: str = None, + base_dir=None, + database_user=None, + set_password=True, + **kwargs, + ): + """Set the dj.config parameters, set password, and save config to file. + + Parameters + ---------- + save_method : {'local', 'global', 'custom'}, optional + The method to use to save the config. If either 'local' or 'global', + datajoint builtins will be used to save. + filename : str or Path, optional + Default to datajoint global config. If save_method = 'custom', name + of file to generate. Must end in either be either yaml or json. + base_dir : str, optional + The base directory. If not provided, will default to the env var + database_user : str, optional + The database user. If not provided, resulting config will not + specify. + set_password : bool, optional + Default True. Set the database password. + """ + if save_method == "local": + filepath = Path(".") / dj.settings.LOCALCONFIG + elif not filename or save_method == "global": + save_method = "global" + filepath = Path("~").expanduser() / dj.settings.GLOBALCONFIG + + dj.config.update( + self._generate_dj_config( + base_dir=base_dir, database_user=database_user, **kwargs + ) + ) + + if set_password: + try: + dj.set_password() + except OperationalError as e: + warnings.warn(f"Database connection issues. Wrong pass? {e}") + # NOTE: Save anyway? Or raise error? + + user_warn = ( + f"Replace existing file? {filepath.resolve()}\n\t" + + "\n\t".join([f"{k}: {v}" for k, v in config.items()]) + + "\n" + ) + + if filepath.exists() and dj.utils.user_choice(user_warn)[0] != "y": + return dj.config + + if save_method == "global": + dj.config.save_global(verbose=True) + return + + if save_method == "local": + dj.config.save_local(verbose=True) + return + + with open(filename, "w") as outfile: + if filename.endswith("json"): + json.dump(dj.config, outfile, indent=2) + else: + yaml.dump(dj.config, outfile, default_flow_style=False) + + @property + def _dj_stores(self) -> dict: + self.load_config() + return { + "stores": { "raw": { "protocol": "file", - "location": str(raw_dir), - "stage": str(raw_dir), + "location": self.raw_dir, + "stage": self.raw_dir, }, "analysis": { "protocol": "file", - "location": str(analysis_dir), - "stage": str(analysis_dir), + "location": self.analysis_dir, + "stage": self.analysis_dir, }, } + } - def dir_to_var(self, dir: str, dir_type: str = "spyglass"): - """Converts a dir string to an env variable name.""" - dir_string = self.relative_dirs.get(dir_type, {}).get(dir, "base") - return f"{dir_type.upper()}_{dir_string.upper()}_DIR" + @property + def _dj_custom(self) -> dict: + self.load_config() + return { + "custom": { + "debug_mode": str(self.debug_mode).lower(), + "spyglass_dirs": { + "base": self.base_dir, + "raw": self.raw_dir, + "analysis": self.analysis_dir, + "recording": self.recording_dir, + "sorting": self.sorting_dir, + "waveforms": self.waveforms_dir, + "temp": self.temp_dir, + "video": self.video_dir, + }, + "kachery_dirs": { + "cloud": self.config.get( + self.dir_to_var("cloud", "kachery") + ), + "storage": self.config.get( + self.dir_to_var("storage", "kachery") + ), + "temp": self.config.get(self.dir_to_var("tmp", "kachery")), + }, + "kachery_zone": "franklab.default", + } + } @property def config(self) -> dict: - if not self._config: - self.load_config() + self.load_config() return self._config @property @@ -240,6 +411,10 @@ def recording_dir(self) -> str: def sorting_dir(self) -> str: return self.config.get(self.dir_to_var("sorting")) + @property + def waveforms_dir(self) -> str: + return self.config.get(self.dir_to_var("waveforms")) + @property def temp_dir(self) -> str: return self.config.get(self.dir_to_var("temp")) @@ -248,6 +423,10 @@ def temp_dir(self) -> str: def waveform_dir(self) -> str: return self.config.get(self.dir_to_var("waveform")) + @property + def video_dir(self) -> str: + return self.config.get(self.dir_to_var("video")) + @property def debug_mode(self) -> bool: return self.config.get("debug_mode", False) diff --git a/src/spyglass/utils/database_settings.py b/src/spyglass/utils/database_settings.py new file mode 100644 index 000000000..3a29b3834 --- /dev/null +++ b/src/spyglass/utils/database_settings.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python +import grp +import os +import sys +import tempfile +from pathlib import Path + +import datajoint as dj + +GRANT_ALL = "GRANT ALL PRIVILEGES ON " +GRANT_SEL = "GRANT SELECT ON " +CREATE_USR = "CREATE USER IF NOT EXISTS " +TEMP_PASS = " IDENTIFIED BY 'temppass';" +ESC = r"\_%" + + +class DatabaseSettings: + def __init__( + self, user_name=None, host_name=None, target_group=None, debug=False + ): + """Class to manage common database settings + + Parameters + ---------- + user_name : str, optional + The name of the user to add to the database. Default from dj.config + host_name : str, optional + The name of the host to add to the database. Default from dj.config + target_group : str, optional + Group to which user belongs. Default is kachery-users + debug : bool, optional + Default False. If True, print sql instead of running + """ + self.shared_modules = [ + f"common{ESC}", + f"spikesorting{ESC}", + f"decoding{ESC}", + f"position{ESC}", + f"position_linearization{ESC}", + f"ripple{ESC}", + f"lfp{ESC}", + ] + self.user = user_name or dj.config["database.user"] + self.host = ( + host_name or dj.config["database.host"] or "lmf-db.cin.ucsf.edu" + ) + self.target_group = target_group or "kachery-users" + self.debug = debug + + @property + def _add_collab_usr_sql(self): + return [ + # Create the user (if not already created) and set the password + f"{CREATE_USR}'{self.user}'@'%'{TEMP_PASS}\n", + # Grant privileges to databases matching the user_name pattern + f"{GRANT_ALL}`{self.user}{ESC}`.* TO '{self.user}'@'%';\n", + # Grant SELECT privileges on all databases + f"{GRANT_SEL}`%`.* TO '{self.user}'@'%';\n", + ] + + def add_collab_user(self): + """Add collaborator user with full permissions to shared modules""" + file = self.write_temp_file(self._add_collab_usr_sql) + self.run_file(file) + + @property + def _add_dj_guest_sql(self): + return [ + # Create the user (if not already created) and set the password + f"{CREATE_USR}'{self.user}'@'%' IDENTIFIED BY 'Data_$haring';\n", + # Grant privileges + f"{GRANT_SEL}`%`.* TO '{self.user}'@'%';\n", + ] + + def add_dj_guest(self): + """Add guest user with select permissions to shared modules""" + file = self.write_temp_file(self._add_dj_guest_sql) + self.run_file(file) + + def _find_group(self): + # find the kachery-users group + groups = grp.getgrall() + group_found = False # initialize the flag as False + for group in groups: + if group.gr_name == self.target_group: + group_found = ( + True # set the flag to True when the group is found + ) + break + + # Check if the group was found + if not group_found: + if self.debug: + print(f"All groups: {[g.gr_name for g in groups]}") + sys.exit( + f"Error: The target group {self.target_group} was not found." + ) + + return group + + def _add_module_sql(self, module_name, group): + return [ + f"{GRANT_ALL}`{module_name}{ESC}`.* TO `{user}`@'%';\n" + # get a list of usernames + for user in group.gr_mem + ] + + def add_module(self, module_name): + """Add module to database. Grant permissions to all users in group""" + print(f"Granting everyone permissions to module {module_name}") + group = self._find_group() + file = self.write_temp_file(self._add_module_sql(module_name, group)) + self.run_file(file) + + @property + def _add_dj_user_sql(self): + return ( + [ + f"{CREATE_USR}'{self.user}'@'%' " + + "IDENTIFIED BY 'temppass';\n", + f"{GRANT_ALL}`{self.user}{ESC}`.* TO '{self.user}'@'%';" + "\n", + ] + + [ + f"{GRANT_ALL}`{module}`.* TO '{self.user}'@'%';\n" + for module in self.shared_modules + ] + + [f"{GRANT_SEL}`%`.* TO '{self.user}'@'%';\n"] + ) + + def add_dj_user(self, check_exists=True): + """Add user to database with permissions to shared modules""" + if check_exists: + user_home = Path.home().parent / self.user + if user_home.exists(): + print("Creating database user ", self.user) + else: + sys.exit( + f"Error: could not find {self.user} in home dir: {user_home}" + ) + + file = self.write_temp_file(self._add_dj_user_sql) + self.run_file(file) + + def write_temp_file(self, content: list) -> tempfile.NamedTemporaryFile: + """Write content to a temporary file and return the file object""" + file = tempfile.NamedTemporaryFile(mode="w") + for line in content: + file.write(line) + file.flush() + + if self.debug: + from pprint import pprint # noqa F401 + + pprint(file.name) + pprint(content) + + return file + + def run_file(self, file): + """Run commands saved to file in sql""" + + if self.debug: + return + + os.system(f"mysql -p -h {self.host} < {file.name}")