diff --git a/CHANGELOG.md b/CHANGELOG.md index b5dc7c78f..bbf2ca515 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,7 +20,7 @@ ### Pipelines - Spike sorting: Add SpikeSorting V1 pipeline. #651 -- LFP: Minor fixes to LFPBandV1 populator. #706 +- LFP: Minor fixes to LFPBandV1 populator and `make`. #706, #795 - Linearization: - Minor fixes to LinearizedPositionV1 pipeline #695 - Rename `position_linearization` -> `linearization`. #717 @@ -34,6 +34,7 @@ - Use the new `non_local_detector` package for decoding #731 - Allow multiple spike waveform features for clusterelss decoding #731 - Reorder notebooks #731 + - Add fetch class functionality to `Merge` table. #783, #786 ## [0.4.3] (November 7, 2023) diff --git a/dj_local_conf_example.json b/dj_local_conf_example.json index 437d77577..1360fb7a0 100644 --- a/dj_local_conf_example.json +++ b/dj_local_conf_example.json @@ -14,7 +14,7 @@ "display.show_tuple_count": true, "database.use_tls": null, "enable_python_native_blobs": true, - "filepath_checksum_size_limit": null, + "filepath_checksum_size_limit": 1073741824, "stores": { "raw": { "protocol": "file", @@ -28,14 +28,28 @@ } }, "custom": { + "debug_mode": "false", + "test_mode": "false", "spyglass_dirs": { - "base": "/your/path/like/stelmo/nwb/" + "base": "/your/base/path", + "raw": "/your/base/path/raw", + "analysis": "/your/base/path/analysis", + "recording": "/your/base/path/recording", + "sorting": "/your/base/path/spikesorting", + "waveforms": "/your/base/path/waveforms", + "temp": "/your/base/path/tmp", + "video": "/your/base/path/video" }, "kachery_dirs": { - "cloud": "/your/path/.kachery-cloud" + "cloud": "/your/base/path/kachery_storage", + "storage": "/your/base/path/kachery_storage", + "temp": "/your/base/path/tmp" }, "dlc_dirs": { - "base": "/your/path/like/nimbus/deeplabcut/" + "base": "/your/base/path/deeplabcut", + "project": "/your/base/path/deeplabcut/projects", + "video": "/your/base/path/deeplabcut/video", + "output": "/your/base/path/deeplabcut/output" }, "kachery_zone": "franklab.default" } diff --git a/docs/src/index.md b/docs/src/index.md index 4f0f7be74..1a0233192 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -2,46 +2,49 @@ ![Figure 1](./images/fig1.png) -**Spyglass** is an open-source software framework designed to offer reliable -and reproducible analysis of neuroscience data and sharing of the results -with collaborators and the broader community. +**Spyglass** is an open-source software framework designed to offer reliable and +reproducible analysis of neuroscience data and sharing of the results with +collaborators and the broader community. Features of Spyglass include: -+ **Standardized data storage** - Spyglass uses the open-source - [Neurodata Without Borders: Neurophysiology (NWB:N)](https://www.nwb.org/) - format to ingest and store processed data. NWB:N is a standard set by the BRAIN - Initiative for neurophysiological data ([Rübel et al., 2022](https://doi.org/10.7554/elife.78362)). -+ **Reproducible analysis** - Spyglass uses [DataJoint](https://datajoint.com/) - to ensure that all analysis is reproducible. DataJoint is a data management - system that automatically tracks dependencies between data and analysis code. This - ensures that all analysis is reproducible and that the results are - automatically updated when the data or analysis code changes. 
-+ **Common analysis tools** - Spyglass provides easy usage of the open-source packages - [SpikeInterface](https://github.com/SpikeInterface/spikeinterface), - [Ghostipy](https://github.com/kemerelab/ghostipy), and [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) - for common analysis tasks. These packages are well-documented and have active - developer communities. -+ **Interactive data visualization** - Spyglass uses [figurl](https://github.com/flatironinstitute/figurl) - to create interactive data visualizations that can be shared with collaborators - and the broader community. These visualizations are hosted on the web - and can be viewed in any modern web browser. The interactivity allows users to - explore the data and analysis results in detail. -+ **Sharing results** - Spyglass enables sharing of data and analysis results via - [Kachery](https://github.com/flatironinstitute/kachery-cloud), a - decentralized content addressable data sharing platform. Kachery Cloud allows - users to access the database and pull data and analysis results directly - to their local machine. -+ **Pipeline versioning** - Processing and analysis of data in neuroscience is - often dynamic, requiring new features. Spyglass uses *Merge tables* to ensure that - analysis pipelines can be versioned. This allows users to easily use and compare - results from different versions of the analysis pipeline while retaining - the ability to access previously generated results. -+ **Cautious Delete** - Spyglass uses a `cautious delete` feature to ensure - that data is not accidentally deleted by other users. When a user deletes data, - Spyglass will first check to see if the data belongs to another team of users. - This enables teams of users to work collaboratively on the same database without - worrying about accidentally deleting each other's data. +- **Standardized data storage** - Spyglass uses the open-source + [Neurodata Without Borders: Neurophysiology (NWB:N)](https://www.nwb.org/) + format to ingest and store processed data. NWB:N is a standard set by the + BRAIN Initiative for neurophysiological data + ([Rübel et al., 2022](https://doi.org/10.7554/elife.78362)). +- **Reproducible analysis** - Spyglass uses [DataJoint](https://datajoint.com/) + to ensure that all analysis is reproducible. DataJoint is a data management + system that automatically tracks dependencies between data and analysis + code. This ensures that all analysis is reproducible and that the results + are automatically updated when the data or analysis code changes. +- **Common analysis tools** - Spyglass provides easy usage of the open-source + packages [SpikeInterface](https://github.com/SpikeInterface/spikeinterface), + [Ghostipy](https://github.com/kemerelab/ghostipy), and + [DeepLabCut](https://github.com/DeepLabCut/DeepLabCut) for common analysis + tasks. These packages are well-documented and have active developer + communities. +- **Interactive data visualization** - Spyglass uses + [figurl](https://github.com/flatironinstitute/figurl) to create interactive + data visualizations that can be shared with collaborators and the broader + community. These visualizations are hosted on the web and can be viewed in + any modern web browser. The interactivity allows users to explore the data + and analysis results in detail. +- **Sharing results** - Spyglass enables sharing of data and analysis results + via [Kachery](https://github.com/flatironinstitute/kachery-cloud), a + decentralized content addressable data sharing platform. 
Kachery Cloud + allows users to access the database and pull data and analysis results + directly to their local machine. +- **Pipeline versioning** - Processing and analysis of data in neuroscience is + often dynamic, requiring new features. Spyglass uses *Merge tables* to + ensure that analysis pipelines can be versioned. This allows users to easily + use and compare results from different versions of the analysis pipeline + while retaining the ability to access previously generated results. +- **Cautious Delete** - Spyglass uses a `cautious delete` feature to ensure that + data is not accidentally deleted by other users. When a user deletes data, + Spyglass will first check to see if the data belongs to another team of + users. This enables teams of users to work collaboratively on the same + database without worrying about accidentally deleting each other's data. ## Getting Started diff --git a/docs/src/installation.md b/docs/src/installation.md index 4d111de09..d588d2daf 100644 --- a/docs/src/installation.md +++ b/docs/src/installation.md @@ -25,7 +25,7 @@ pip install spikeinterface[full,widgets] pip install mountainsort4 ``` -WARNING: If you are on an M1 Mac, you need to install `pyfftw` via `conda` +__WARNING:__ If you are on an M1 Mac, you need to install `pyfftw` via `conda` BEFORE installing `ghostipy`: ```bash @@ -49,40 +49,30 @@ additional details, see the #### Via File (Recommended) -A `dj_local_conf.json` file in your Spyglass directory (or wherever python is -launched) can hold all the specifics needed to connect to a database. This can -include different directories for different pipelines. If only the `base` is -specified, the subfolder names below are included as defaults. - -```json -{ - "custom": { - "database.prefix": "username_", - "spyglass_dirs": { - "base": "/your/base/path", - "raw": "/your/base/path/raw", - "analysis": "/your/base/path/analysis", - "recording": "/your/base/path/recording", - "spike_sorting_storage": "/your/base/path/spikesorting", - "waveforms": "/your/base/path/waveforms", - "temp": "/your/base/path/tmp" - } - } -} -``` - -`dj_local_conf_example.json` can be copied and saved as `dj_local_conf.json` to -set the configuration for a given folder. Alternatively, it can be saved as -`.datajoint_config.json` in a user's home directory to be accessed globally. See +A `dj_local_conf.json` file in your current directory when launching python can +hold all the specifics needed to connect to a database. This can include +different directories for different pipelines. If only the Spyglass `base` is +specified, other subfolder names are assumed from defaults. See +`dj_local_conf_example.json` for the full set of options. This example can be +copied and saved as `dj_local_conf.json` to set the configuration for a given +folder. Alternatively, it can be saved as `.datajoint_config.json` in a user's +home directory to be accessed globally. See [DataJoint docs](https://datajoint.com/docs/core/datajoint-python/0.14/quick-start/#connection) for more details. +Note that raw and analysis folder locations should be specified under both +`stores` and `custom` sections of the config file. The `stores` section is used +by DataJoint to store the location of files referenced in database, while the +`custom` section is used by Spyglass. Spyglass will check that these sections +match on startup. + #### Via Environment Variables Older versions of Spyglass relied exclusively on environment for config. 
If `spyglass_dirs` is not found in the config file, Spyglass will look for environment variables. These can be set either once in a terminal session, or -permanently in a `.bashrc` file. +permanently in a unix settings file (e.g., `.bashrc` or `.bash_profile`) in your +home directory. ```bash export SPYGLASS_BASE_DIR="/stelmo/nwb" @@ -102,14 +92,21 @@ A temporary directory will speed up spike sorting. If unspecified by either method above, it will be assumed as a `tmp` subfolder relative to the base path. Be sure it has enough free space (ideally at least 500GB). +#### Subfolders + +If subfolders do not exist, they will be created automatically. If unspecified +by either method above, they will be assumed as `recording`, `sorting`, `video`, +etc. subfolders relative to the base path. + ## File manager [`kachery-cloud`](https://github.com/flatironinstitute/kachery-cloud) is a file manager for Frank Lab collaborators who do not have access to the lab's production database. -To customize `kachery` file paths, the following can similarly be pasted into -your `.bashrc`. If unspecified, the defaults below are assumed. +To customize `kachery` file paths, see `dj_local_conf_example.json` or set the +following variables in your unix settings file (e.g., `.bashrc`). If +unspecified, the defaults below are assumed. ```bash export KACHERY_CLOUD_DIR="$SPYGLASS_BASE_DIR/.kachery-cloud" @@ -122,3 +119,9 @@ Be sure to load these with `source ~/.bashrc` to persist changes. Finally, open up a python console (e.g., run `ipython` from terminal) and import `spyglass` to check that the installation has worked. + +```python +from spyglass.common import Nwbfile + +Nwbfile() +``` diff --git a/notebooks/00_Setup.ipynb b/notebooks/00_Setup.ipynb index 8f55ea731..9bfeff14b 100644 --- a/notebooks/00_Setup.ipynb +++ b/notebooks/00_Setup.ipynb @@ -8,6 +8,14 @@ "# Setup\n" ] }, + { + "cell_type": "markdown", + "id": "6f423f76", + "metadata": {}, + "source": [ + "## Intro\n" + ] + }, { "cell_type": "markdown", "id": "cbb74150", @@ -31,9 +39,15 @@ "id": "65a5bf87", "metadata": {}, "source": [ - "## Local environment\n", - "\n", - "Codespace users can skip this step. Frank Lab members should first follow\n", + "## Local environment\n" + ] + }, + { + "cell_type": "markdown", + "id": "aa6bddcb", + "metadata": {}, + "source": [ + "JupyterHub users can skip this step. Frank Lab members should first follow\n", "'rec to nwb overview' steps on Google Drive to set up an ssh connection.\n", "\n", "For local use, download and install ...\n", @@ -76,7 +90,9 @@ "\n", "_Note:_ Spyglass is also installable via\n", "[pip]()\n", - "and [pypi](https://pypi.org/project/spyglass-neuro/) with `pip install spyglass-neuro`, but downloading from GitHub will also other files accessible.\n", + "and [pypi](https://pypi.org/project/spyglass-neuro/) with\n", + "`pip install spyglass-neuro`, but downloading from GitHub will also download\n", + "other files.\n", "\n", "Next, within VSCode,\n", "[select the kernel](https://code.visualstudio.com/docs/datascience/jupyter-kernel-management)\n", @@ -92,7 +108,7 @@ "id": "f87a0acc", "metadata": {}, "source": [ - "## Database Connection\n" + "## Database\n" ] }, { @@ -103,11 +119,24 @@ "You have a few options for databases.\n", "\n", "1. Connect to an existing database.\n", - "2. Use GitHub Codespaces (coming soon...)\n", - "3. Run your own database with [Docker](#running-your-own-database)\n", + "2. Run your own database with [Docker](#running-your-own-database)\n", + "3. 
JupyterHub (coming soon...)\n", + "\n", + "Your choice above should result in a set of credentials, including host name,\n", + "host port, user name, and password. Note these for the next step.\n", "\n", - "Once your database is set up, be sure to configure the connection\n", - "with your `dj_local_conf.json` file.\n" + "
Note for MySQL 8 users, including Frank Lab members\n", + "\n", + "Using a MySQL 8 server, like the server hosted by the Frank Lab, will\n", + "require the pre-release version of DataJoint to change one's password.\n", + "\n", + "```bash\n", + "cd /location/for/datajoint/source/files/\n", + "git clone https://github.com/datajoint/datajoint-python\n", + "pip install ./datajoint-python\n", + "```\n", + "\n", + "
\n" ] }, { @@ -123,85 +152,10 @@ "id": "580d3feb", "metadata": {}, "source": [ - "Members of the Frank Lab will need to use DataJoint 0.14.2 (currently in\n", - "pre-release) in order to change their password on the MySQL 8 server. DataJoint\n", - "0.14.2\n", - "\n", - "```bash\n", - "git clone https://github.com/datajoint/datajoint-python\n", - "pip install ./datajoint-python\n", - "```\n", - "\n", - "Members of the lab can run the `dj_config.py` helper script to generate a config\n", - "like the one below.\n", - "\n", - "```bash\n", - "cd spyglass\n", - "python config/dj_config.py \n", - "```\n", - "\n", - "Outside users should copy/paste `dj_local_conf_example` and adjust values\n", - "accordingly.\n", - "\n", - "The base path (formerly `SPYGLASS_BASE_DIR`) is the directory where all data\n", - "will be saved. See also\n", - "[docs](https://lorenfranklab.github.io/spyglass/0.4/installation/) for more\n", - "information on subdirectories.\n", - "\n", - "A different `output_filename` will save different files:\n", - "\n", - "- `dj_local_conf.json`: Recommended. Used for tutorials. A file in the current\n", - " directory DataJoint will automatically recognize when a Python session is\n", - " launched from this directory.\n", - "- `.datajoint_config.json` or no input: A file in the user's home directory\n", - " that will be referenced whenever no local version (see above) is present.\n", - "- Anything else: A custom name that will need to be loaded (e.g.,\n", - " `dj.load('x')`) for each python session.\n", - "\n", - "The config will be a `json` file like the following.\n", - "\n", - "```json\n", - "{\n", - " \"database.host\": \"lmf-db.cin.ucsf.edu\",\n", - " \"database.user\": \"\",\n", - " \"database.password\": \"Not recommended for shared machines\",\n", - " \"database.port\": 3306,\n", - " \"database.use_tls\": true,\n", - " \"enable_python_native_blobs\": true,\n", - " \"filepath_checksum_size_limit\": 1 * 1024**3,\n", - " \"loglevel\": \"INFO\",\n", - " \"stores\": {\n", - " \"raw\": {\n", - " \"protocol\": \"file\",\n", - " \"location\": \"/stelmo/nwb/raw\",\n", - " \"stage\": \"/stelmo/nwb/raw\"\n", - " },\n", - " \"analysis\": {\n", - " \"protocol\": \"file\",\n", - " \"location\": \"/stelmo/nwb/analysis\",\n", - " \"stage\": \"/stelmo/nwb/analysis\"\n", - " }\n", - " },\n", - " \"custom\": {\n", - " \"spyglass_dirs\": {\n", - " \"base\": \"/stelmo/nwb/\"\n", - " }\n", - " }\n", - "}\n", - "```\n", - "\n", - "Spyglass will use the log level present in your DataJoint config to decide the\n", - "appropriate logging level for this session. 
To change the messages you see,\n", - "select from one of [these options](https://docs.python.org/3/library/logging.html#levels).\n", - "\n", - "If you see an error saying `Could not find SPYGLASS_BASE_DIR`, try loading your\n", - "config before importing Spyglass.\n", + "Connecting to an existing database will require a user name and password.\n", + "Please contact your database administrator for this information.\n", "\n", - "```python\n", - "import datajoint as dj\n", - "dj.load('/path/to/config')\n", - "import spyglass\n", - "```\n" + "Frank Lab members should contact Chris.\n" ] }, { @@ -209,16 +163,14 @@ "id": "a63761cc-437f-4e4a-a777-664b321b9b94", "metadata": {}, "source": [ - "### Running your own database\n" + "### Running your own database with Docker\n" ] }, { "cell_type": "markdown", - "id": "97bac46e", + "id": "e8f57976", "metadata": {}, "source": [ - "#### Setup Docker\n", - "\n", "- First, [install Docker](https://docs.docker.com/engine/install/).\n", "- Add yourself to the\n", " [`docker` group](https://docs.docker.com/engine/install/linux-postinstall/) so\n", @@ -229,10 +181,6 @@ " docker pull datajoint/mysql:8.0\n", " ```\n", "\n", - "_Note_: For this demo, MySQL version won't matter. Some\n", - "[database management](https://lorenfranklab.github.io/spyglass/latest/misc/database_management/#mysql-version)\n", - "features of Spyglass, however, expect MySQL >= 8.\n", - "\n", "- When run, this is referred to as a 'Docker container'\n", "- Next start the container with a couple additional pieces of info...\n", "\n", @@ -252,26 +200,12 @@ " docker run --name spyglass-db -v dj-vol:/var/lib/mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=tutorial datajoint/mysql\n", " ```\n", "\n", - "#### Configure\n", - "\n", - "The `dj_local_conf_example.json` contains all the defaults for a Docker\n", - "connection. Simply rename to `dj_local_conf.json` and modify the contents\n", - "accordingly. This includes the host, password and user. For Spyglass, you'll\n", - "want to set your base path under `custom`:\n", - "\n", - "```json\n", - "{\n", - " \"database.host\": \"localhost\",\n", - " \"database.password\": \"tutorial\",\n", - " \"database.user\": \"root\",\n", - " \"custom\": {\n", - " \"database.prefix\": \"username_\",\n", - " \"spyglass_dirs\": {\n", - " \"base\": \"/your/base/path\"\n", - " }\n", - " }\n", - "}\n", - "```\n" + "Docker credentials are as follows:\n", + "\n", + "- Host: localhost\n", + "- Password: tutorial\n", + "- User: root\n", + "- Port: 3306\n" ] }, { @@ -279,197 +213,101 @@ "id": "706d0ed5", "metadata": {}, "source": [ - "### Loading the config\n", - "\n", - "We can check that the paths are correctly set up by loading the config from\n", - "the main Spyglass directory.\n" + "### Config and Connecting to the database\n" + ] + }, + { + "cell_type": "markdown", + "id": "22d3b72d", + "metadata": {}, + "source": [ + "Spyglass can load settings from either a DataJoint config file (recommended) or\n", + "environmental variables. The code below will generate a config file, but we\n", + "first need to decide a 'base path'. This is generally the parent directory\n", + "where the data will be stored, with subdirectories for `raw`, `analysis`, and\n", + "other data folders. If they don't exist already, they will be created.\n", + "\n", + "The function below will create a config file (`~/.datajoint.config` if global,\n", + "`./dj_local_conf.json` if local). Local is recommended for the notebooks, as\n", + "each will start by loading this file. 
Custom json configs can be saved elsewhere, but will need to be loaded in startup with\n", + "`dj.config.load('your-path')`.\n", + "\n", + "To point spyglass to a folder elsewhere (e.g., an external drive for waveform\n", + "data), simply edit the json file. Note that the `raw` and `analysis` paths\n", + "appear under both `stores` and `custom`.\n" ] }, { "cell_type": "code", - "execution_count": 1, - "id": "912ac84b", + "execution_count": null, + "id": "7ebcb0bf", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'debug_mode': True,\n", - " 'prepopulate': True,\n", - " 'SPYGLASS_BASE_DIR': '/stelmo/nwb',\n", - " 'SPYGLASS_RAW_DIR': '/stelmo/nwb/raw',\n", - " 'SPYGLASS_ANALYSIS_DIR': '/stelmo/nwb/analysis',\n", - " 'SPYGLASS_RECORDING_DIR': '/stelmo/nwb/recording',\n", - " 'SPYGLASS_SORTING_DIR': '/stelmo/nwb/spikesorting',\n", - " 'SPYGLASS_WAVEFORMS_DIR': '/stelmo/nwb/waveforms',\n", - " 'SPYGLASS_TEMP_DIR': '/stelmo/nwb/tmp',\n", - " 'SPYGLASS_VIDEO_DIR': '/stelmo/nwb/video',\n", - " 'KACHERY_CLOUD_DIR': '/stelmo/nwb/kachery_storage',\n", - " 'KACHERY_STORAGE_DIR': '/stelmo/nwb/kachery_storage',\n", - " 'KACHERY_TEMP_DIR': '/stelmo/nwb/tmp',\n", - " 'KACHERY_ZONE': 'franklab.default',\n", - " 'FIGURL_CHANNEL': 'franklab2',\n", - " 'DJ_SUPPORT_FILEPATH_MANAGEMENT': 'TRUE',\n", - " 'KACHERY_CLOUD_EPHEMERAL': 'TRUE'}" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import os\n", - "import datajoint as dj\n", + "from spyglass.settings import SpyglassConfig\n", "\n", + "# change to the root directory of the project\n", "if os.path.basename(os.getcwd()) == \"notebooks\":\n", " os.chdir(\"..\")\n", - "dj.config.load(\"dj_local_conf.json\")\n", "\n", - "from spyglass.settings import config\n", - "\n", - "config" + "SpyglassConfig().save_dj_config(\n", + " save_method=\"local\", # global or local\n", + " base_dir=\"/path/like/stelmo/nwb/\",\n", + " database_user=\"your username\",\n", + " database_password=\"your password\", # remove this line for shared machines\n", + " database_host=\"localhost or lmf-db.cin.ucsf.edu\",\n", + " database_port=3306,\n", + " set_password=False,\n", + ")" ] }, { "cell_type": "markdown", - "id": "b5c15d54", + "id": "06eef771", "metadata": {}, "source": [ - "### Connect\n", - "\n", - "Now, you should be able to connect to the database you set up.\n", - "\n", - "Let's demonstrate with an example table:\n" + "If you used either a local or global save method, we can check the connection\n", + "to the database with ...\n" ] }, { "cell_type": "code", - "execution_count": 2, - "id": "afb63913-4e6b-4049-ae1d-55ab1ac8d42c", + "execution_count": null, + "id": "2e34baaf", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2023-09-28 08:07:06,176][INFO]: Connecting root@localhost:3307\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[2023-09-28 08:07:06,254][INFO]: Connected root@localhost:3307\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Populate: Populating table DataAcquisitionDeviceSystem with data {'data_acquisition_device_system': 'SpikeGadgets'} using insert1.\n", - "Populate: Populating table DataAcquisitionDeviceAmplifier with data {'data_acquisition_device_amplifier': 'Intan'} using insert1.\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " \n", - " \n", - " Table for holding the NWB files.\n", - "
[output omitted: DataJoint HTML repr of the Nwbfile table (columns: nwb_file_name, nwb_file_abs_path; Total: 817 entries)]
\n", - " " - ], - "text/plain": [ - "*nwb_file_name nwb_file_a\n", - "+------------+ +--------+\n", - "\n", - " (Total: 0)" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ + "import datajoint as dj\n", + "\n", + "dj.conn() # test connection\n", + "dj.config # check config\n", + "\n", "from spyglass.common import Nwbfile\n", "\n", "Nwbfile()" ] }, + { + "cell_type": "markdown", + "id": "a492b8ba", + "metadata": {}, + "source": [ + "If you see an error saying `Could not find SPYGLASS_BASE_DIR`, try loading your\n", + "config before importing Spyglass, try setting this as an environmental variable\n", + "before importing Spyglass.\n", + "\n", + "```python\n", + "os.environ['SPYGLASS_BASE_DIR'] = '/your/base/path'\n", + "\n", + "import spyglass\n", + "from spyglass.settings import SpyglassConfig\n", + "import datajoint as dj\n", + "print(SpyglassConfig().config)\n", + "dj.config.save_local() # or global\n", + "```\n" + ] + }, { "cell_type": "markdown", "id": "13fd64af", diff --git a/notebooks/03_Merge_Tables.ipynb b/notebooks/03_Merge_Tables.ipynb index 0cbd4e1b7..04cc6ba13 100644 --- a/notebooks/03_Merge_Tables.ipynb +++ b/notebooks/03_Merge_Tables.ipynb @@ -31,8 +31,14 @@ "- For additional info on DataJoint syntax, including table definitions and\n", " inserts, see\n", " [these additional tutorials](https://github.com/datajoint/datajoint-tutorials)\n", - "- For information on why we use merge tables, and how to make one, see our \n", - " [documentation](https://lorenfranklab.github.io/spyglass/0.4/misc/merge_tables/)\n" + "- For information on why we use merge tables, and how to make one, see our\n", + " [documentation](https://lorenfranklab.github.io/spyglass/0.4/misc/merge_tables/)\n", + "\n", + "In short, merge tables represent the end processing point of a given way of\n", + "processing the data in our pipelines. Merge Tables allow us to build new\n", + "processing pipeline, or a new version of an existing pipeline, without having to\n", + "drop or migrate the old tables. They allow data to be processed in different\n", + "ways, but with a unified end result that downstream pipelines can all access.\n" ] }, { @@ -46,7 +52,6 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "\n", "Let's start by importing the `spyglass` package, along with a few others.\n" ] }, @@ -102,7 +107,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Check to make sure the data inserted in the previour notebook is still there." + "Check to make sure the data inserted in the previour notebook is still there.\n" ] }, { @@ -238,7 +243,7 @@ "_Note_: Some existing parents of Merge Tables perform the Merge Table insert as\n", "part of the populate methods. This practice will be revised in the future.\n", "\n", - "" + "\n" ] }, { @@ -309,10 +314,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "Merge Tables have multiple custom methods that begin with `merge`.\n", "\n", - "Merge Tables have multiple custom methods that begin with `merge`. 
\n", - "\n", - "`help` can show us the docstring of each" + "`help` can show us the docstring of each\n" ] }, { @@ -365,7 +369,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Showing data" + "## Showing data\n" ] }, { @@ -598,7 +602,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Selecting data" + "## Selecting data\n" ] }, { @@ -852,7 +856,7 @@ "metadata": {}, "source": [ "`fetch` will collect all relevant entries and return them as a list in\n", - " the format specified by keyword arguments and one's DataJoint config.\n" + "the format specified by keyword arguments and one's DataJoint config.\n" ] }, { @@ -880,8 +884,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "`merge_fetch` requires a restriction as the first argument. For no restriction, \n", - "use `True`." + "`merge_fetch` requires a restriction as the first argument. For no restriction,\n", + "use `True`.\n" ] }, { @@ -936,7 +940,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Deletion from Merge Tables" + "## Deletion from Merge Tables\n" ] }, { @@ -956,7 +960,7 @@ "\n", "The two latter cases can be destructive, so we include an extra layer of\n", "protection with `dry_run`. When true (by default), these functions return\n", - "a list of tables with the entries that would otherwise be deleted." + "a list of tables with the entries that would otherwise be deleted.\n" ] }, { @@ -978,8 +982,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To delete all merge table entries associated with an NWB file, use \n", - "`delete_downstream_merge` with the `Nwbfile` table. \n" + "To delete all merge table entries associated with an NWB file, use\n", + "`delete_downstream_merge` with the `Nwbfile` table.\n" ] }, { @@ -1000,15 +1004,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Up Next" + "## Up Next\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In the [next notebook](./10_Spike_Sorting.ipynb), we'll start working with \n", - "ephys data with spike sorting." + "In the [next notebook](./10_Spike_Sorting.ipynb), we'll start working with\n", + "ephys data with spike sorting.\n" ] } ], diff --git a/notebooks/py_scripts/00_Setup.py b/notebooks/py_scripts/00_Setup.py index 3be571e7b..2ea726aa8 100644 --- a/notebooks/py_scripts/00_Setup.py +++ b/notebooks/py_scripts/00_Setup.py @@ -15,6 +15,9 @@ # # Setup # +# ## Intro +# + # Welcome to [Spyglass](https://lorenfranklab.github.io/spyglass/0.4/), # a [DataJoint](https://github.com/datajoint/datajoint-python/) # pipeline maintained by the [Frank Lab](https://franklab.ucsf.edu/) at UCSF. @@ -30,7 +33,8 @@ # ## Local environment # -# Codespace users can skip this step. Frank Lab members should first follow + +# JupyterHub users can skip this step. Frank Lab members should first follow # 'rec to nwb overview' steps on Google Drive to set up an ssh connection. # # For local use, download and install ... @@ -73,7 +77,9 @@ # # _Note:_ Spyglass is also installable via # [pip]() -# and [pypi](https://pypi.org/project/spyglass-neuro/) with `pip install spyglass-neuro`, but downloading from GitHub will also other files accessible. +# and [pypi](https://pypi.org/project/spyglass-neuro/) with +# `pip install spyglass-neuro`, but downloading from GitHub will also download +# other files. 
# # Next, within VSCode, # [select the kernel](https://code.visualstudio.com/docs/datascience/jupyter-kernel-management) @@ -84,108 +90,44 @@ # details on each of these programs and the role they play in using the pipeline. # -# ## Database Connection +# ## Database # # You have a few options for databases. # # 1. Connect to an existing database. -# 2. Use GitHub Codespaces (coming soon...) -# 3. Run your own database with [Docker](#running-your-own-database) +# 2. Run your own database with [Docker](#running-your-own-database) +# 3. JupyterHub (coming soon...) # -# Once your database is set up, be sure to configure the connection -# with your `dj_local_conf.json` file. +# Your choice above should result in a set of credentials, including host name, +# host port, user name, and password. Note these for the next step. # - -# ### Existing Database +#
Note for MySQL 8 users, including Frank Lab members # - -# Members of the Frank Lab will need to use DataJoint 0.14.2 (currently in -# pre-release) in order to change their password on the MySQL 8 server. DataJoint -# 0.14.2 +# Using a MySQL 8 server, like the server hosted by the Frank Lab, will +# require the pre-release version of DataJoint to change one's password. # # ```bash +# # cd /location/for/datajoint/source/files/ # git clone https://github.com/datajoint/datajoint-python # pip install ./datajoint-python # ``` # -# Members of the lab can run the `dj_config.py` helper script to generate a config -# like the one below. -# -# ```bash -# # cd spyglass -# python config/dj_config.py -# ``` -# -# Outside users should copy/paste `dj_local_conf_example` and adjust values -# accordingly. -# -# The base path (formerly `SPYGLASS_BASE_DIR`) is the directory where all data -# will be saved. See also -# [docs](https://lorenfranklab.github.io/spyglass/0.4/installation/) for more -# information on subdirectories. -# -# A different `output_filename` will save different files: -# -# - `dj_local_conf.json`: Recommended. Used for tutorials. A file in the current -# directory DataJoint will automatically recognize when a Python session is -# launched from this directory. -# - `.datajoint_config.json` or no input: A file in the user's home directory -# that will be referenced whenever no local version (see above) is present. -# - Anything else: A custom name that will need to be loaded (e.g., -# `dj.load('x')`) for each python session. -# -# The config will be a `json` file like the following. -# -# ```json -# { -# "database.host": "lmf-db.cin.ucsf.edu", -# "database.user": "", -# "database.password": "Not recommended for shared machines", -# "database.port": 3306, -# "database.use_tls": true, -# "enable_python_native_blobs": true, -# "filepath_checksum_size_limit": 1 * 1024**3, -# "loglevel": "INFO", -# "stores": { -# "raw": { -# "protocol": "file", -# "location": "/stelmo/nwb/raw", -# "stage": "/stelmo/nwb/raw" -# }, -# "analysis": { -# "protocol": "file", -# "location": "/stelmo/nwb/analysis", -# "stage": "/stelmo/nwb/analysis" -# } -# }, -# "custom": { -# "spyglass_dirs": { -# "base": "/stelmo/nwb/" -# } -# } -# } -# ``` +#
# -# Spyglass will use the log level present in your DataJoint config to decide the -# appropriate logging level for this session. To change the messages you see, -# select from one of [these options](https://docs.python.org/3/library/logging.html#levels). + +# ### Existing Database # -# If you see an error saying `Could not find SPYGLASS_BASE_DIR`, try loading your -# config before importing Spyglass. + +# Connecting to an existing database will require a user name and password. +# Please contact your database administrator for this information. # -# ```python -# import datajoint as dj -# dj.load('/path/to/config') -# import spyglass -# ``` +# Frank Lab members should contact Chris. # -# ### Running your own database +# ### Running your own database with Docker # -# #### Setup Docker -# # - First, [install Docker](https://docs.docker.com/engine/install/). # - Add yourself to the # [`docker` group](https://docs.docker.com/engine/install/linux-postinstall/) so @@ -196,10 +138,6 @@ # docker pull datajoint/mysql:8.0 # ``` # -# _Note_: For this demo, MySQL version won't matter. Some -# [database management](https://lorenfranklab.github.io/spyglass/latest/misc/database_management/#mysql-version) -# features of Spyglass, however, expect MySQL >= 8. -# # - When run, this is referred to as a 'Docker container' # - Next start the container with a couple additional pieces of info... # @@ -219,60 +157,82 @@ # docker run --name spyglass-db -v dj-vol:/var/lib/mysql -p 3306:3306 -e MYSQL_ROOT_PASSWORD=tutorial datajoint/mysql # ``` # -# #### Configure -# -# The `dj_local_conf_example.json` contains all the defaults for a Docker -# connection. Simply rename to `dj_local_conf.json` and modify the contents -# accordingly. This includes the host, password and user. For Spyglass, you'll -# want to set your base path under `custom`: -# -# ```json -# { -# "database.host": "localhost", -# "database.password": "tutorial", -# "database.user": "root", -# "custom": { -# "database.prefix": "username_", -# "spyglass_dirs": { -# "base": "/your/base/path" -# } -# } -# } -# ``` +# Docker credentials are as follows: +# +# - Host: localhost +# - Password: tutorial +# - User: root +# - Port: 3306 +# + +# ### Config and Connecting to the database # -# ### Loading the config +# Spyglass can load settings from either a DataJoint config file (recommended) or +# environmental variables. The code below will generate a config file, but we +# first need to decide a 'base path'. This is generally the parent directory +# where the data will be stored, with subdirectories for `raw`, `analysis`, and +# other data folders. If they don't exist already, they will be created. +# +# The function below will create a config file (`~/.datajoint.config` if global, +# `./dj_local_conf.json` if local). Local is recommended for the notebooks, as +# each will start by loading this file. Custom json configs can be saved elsewhere, but will need to be loaded in startup with +# `dj.config.load('your-path')`. # -# We can check that the paths are correctly set up by loading the config from -# the main Spyglass directory. +# To point spyglass to a folder elsewhere (e.g., an external drive for waveform +# data), simply edit the json file. Note that the `raw` and `analysis` paths +# appear under both `stores` and `custom`. 
# # + import os -import datajoint as dj +from spyglass.settings import SpyglassConfig +# change to the root directory of the project if os.path.basename(os.getcwd()) == "notebooks": os.chdir("..") -dj.config.load("dj_local_conf.json") -from spyglass.settings import config - -config +SpyglassConfig().save_dj_config( + save_method="local", # global or local + base_dir="/path/like/stelmo/nwb/", + database_user="your username", + database_password="your password", # remove this line for shared machines + database_host="localhost or lmf-db.cin.ucsf.edu", + database_port=3306, + set_password=False, +) # - -# ### Connect -# -# Now, you should be able to connect to the database you set up. -# -# Let's demonstrate with an example table: +# If you used either a local or global save method, we can check the connection +# to the database with ... # # + +import datajoint as dj + +dj.conn() # test connection +dj.config # check config + from spyglass.common import Nwbfile Nwbfile() # - +# If you see an error saying `Could not find SPYGLASS_BASE_DIR`, try loading your +# config before importing Spyglass, try setting this as an environmental variable +# before importing Spyglass. +# +# ```python +# os.environ['SPYGLASS_BASE_DIR'] = '/your/base/path' +# +# import spyglass +# from spyglass.settings import SpyglassConfig +# import datajoint as dj +# print(SpyglassConfig().config) +# dj.config.save_local() # or global +# ``` +# + # # Up Next # diff --git a/notebooks/py_scripts/03_Merge_Tables.py b/notebooks/py_scripts/03_Merge_Tables.py index 69cb29600..c4c0abb48 100644 --- a/notebooks/py_scripts/03_Merge_Tables.py +++ b/notebooks/py_scripts/03_Merge_Tables.py @@ -32,11 +32,16 @@ # - For information on why we use merge tables, and how to make one, see our # [documentation](https://lorenfranklab.github.io/spyglass/0.4/misc/merge_tables/) # +# In short, merge tables represent the end processing point of a given way of +# processing the data in our pipelines. Merge Tables allow us to build new +# processing pipeline, or a new version of an existing pipeline, without having to +# drop or migrate the old tables. They allow data to be processed in different +# ways, but with a unified end result that downstream pipelines can all access. +# # ## Imports # -# # Let's start by importing the `spyglass` package, along with a few others. # @@ -70,6 +75,7 @@ # # Check to make sure the data inserted in the previour notebook is still there. +# nwb_file_name = "minirec20230622.nwb" nwb_copy_file_name = get_nwb_copy_filename(nwb_file_name) @@ -82,6 +88,7 @@ # part of the populate methods. This practice will be revised in the future. # # +# sgc.FirFilterParameters().create_standard_filters() lfp.lfp_electrode.LFPElectrodeGroup.create_lfp_electrode_group( @@ -103,10 +110,10 @@ # ## Helper functions # -# # Merge Tables have multiple custom methods that begin with `merge`. # # `help` can show us the docstring of each +# merge_methods = [d for d in dir(Merge) if d.startswith("merge")] print(merge_methods) @@ -114,6 +121,7 @@ help(getattr(Merge, merge_methods[-1])) # ## Showing data +# # `merge_view` shows a union of the master and all part tables. # @@ -143,6 +151,7 @@ result2 == result1 # ## Selecting data +# # There are also functions for retrieving part/parent table(s) and fetching data. # @@ -156,7 +165,7 @@ result5 # `fetch` will collect all relevant entries and return them as a list in -# the format specified by keyword arguments and one's DataJoint config. 
+# the format specified by keyword arguments and one's DataJoint config. # result6 = result5.fetch("lfp_sampling_rate") # Sample rate for all mini* files @@ -164,6 +173,7 @@ # `merge_fetch` requires a restriction as the first argument. For no restriction, # use `True`. +# result7 = LFPOutput.merge_fetch(True, "filter_name", "nwb_file_name") result7 @@ -172,6 +182,7 @@ result8 # ## Deletion from Merge Tables +# # When deleting from Merge Tables, we can either... # @@ -187,6 +198,7 @@ # The two latter cases can be destructive, so we include an extra layer of # protection with `dry_run`. When true (by default), these functions return # a list of tables with the entries that would otherwise be deleted. +# LFPOutput.merge_delete(nwb_file_dict) # Delete from merge table LFPOutput.merge_delete_parent(restriction=nwb_file_dict, dry_run=True) @@ -208,6 +220,8 @@ ) # ## Up Next +# # In the [next notebook](./10_Spike_Sorting.ipynb), we'll start working with # ephys data with spike sorting. +# diff --git a/pyproject.toml b/pyproject.toml index 521224737..16be96baf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,9 +117,11 @@ ignore-words-list = 'nevers' minversion = "7.0" addopts = [ "-sv", + # "--sw", # stepwise: resume with next test after failure + # "--pdb", # drop into debugger on failure "-p no:warnings", # "--no-teardown", # don't teardown the database after tests - # "--quiet-spy", # don't show logging from spyglass + "--quiet-spy", # don't show logging from spyglass "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger "--cov=spyglass", diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 0cff5f725..ed9673ecb 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -170,8 +170,6 @@ class PosObject(SpyglassMixin, dj.Part): _nwb_table = Nwbfile def fetch1_dataframe(self): - INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1) - id_rp = [(n["id"], n["raw_position"]) for n in self.fetch_nwb()] if len(set(rp.interval for _, rp in id_rp)) > 1: @@ -181,19 +179,29 @@ def fetch1_dataframe(self): pd.DataFrame( data=rp.data, index=pd.Index(rp.timestamps, name="time"), - columns=[ - col # use existing columns if already numbered - if "1" in rp.description or "2" in rp.description - # else number them by id - else col + str(id + INDEX_ADJUST) - for col in rp.description.split(", ") - ], + columns=self._get_column_names(rp, pos_id), ) - for id, rp in id_rp + for pos_id, rp in id_rp ] return reduce(lambda x, y: pd.merge(x, y, on="time"), df_list) + @staticmethod + def _get_column_names(rp, pos_id): + INDEX_ADJUST = 1 # adjust 0-index to 1-index (e.g., xloc0 -> xloc1) + n_pos_dims = rp.data.shape[1] + column_names = [ + col # use existing columns if already numbered + if "1" in rp.description or "2" in rp.description + # else number them by id + else col + str(pos_id + INDEX_ADJUST) + for col in rp.description.split(", ") + ] + if len(column_names) != n_pos_dims: + # if the string split didn't work, use default names + column_names = ["x", "y", "z"][:n_pos_dims] + return column_names + def make(self, key): nwb_file_name = key["nwb_file_name"] interval_list_name = key["interval_list_name"] diff --git a/src/spyglass/common/common_ephys.py b/src/spyglass/common/common_ephys.py index 8765b1965..df7ca206f 100644 --- a/src/spyglass/common/common_ephys.py +++ b/src/spyglass/common/common_ephys.py @@ -17,6 +17,7 @@ from spyglass.common.common_nwbfile import AnalysisNwbfile, 
Nwbfile from spyglass.common.common_region import BrainRegion # noqa: F401 from spyglass.common.common_session import Session # noqa: F401 +from spyglass.settings import test_mode from spyglass.utils import SpyglassMixin, logger from spyglass.utils.nwb_helper_fn import ( estimate_sampling_rate, @@ -369,7 +370,9 @@ def set_lfp_electrodes(self, nwb_file_name, electrode_list): """ # remove the session and then recreate the session and Electrode list - (LFPSelection() & {"nwb_file_name": nwb_file_name}).delete() + (LFPSelection() & {"nwb_file_name": nwb_file_name}).delete( + safemode=not test_mode + ) # check to see if the user allowed the deletion if ( len((LFPSelection() & {"nwb_file_name": nwb_file_name}).fetch()) diff --git a/src/spyglass/common/common_filter.py b/src/spyglass/common/common_filter.py index 9d2cdf9d6..988266d0d 100644 --- a/src/spyglass/common/common_filter.py +++ b/src/spyglass/common/common_filter.py @@ -212,7 +212,7 @@ def _time_bound_check(self, start, stop, all, nsamples): start = all[0] if stop > all[-1]: - warnings.warn( + logger.warning( timestamp_warn + "stop time larger than last timestamp, " + f"substituting last: {stop} < {all[-1]}" diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 732c9779e..7461f0c72 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -352,13 +352,17 @@ def calculate_position_info( dt = np.median(np.diff(time)) sampling_rate = 1 / dt - # Define LEDs - if led1_is_front: - front_LED = position[:, [0, 1]].astype(float) - back_LED = position[:, [2, 3]].astype(float) + if position.shape[1] < 4: + front_LED = position.astype(float) + back_LED = position.astype(float) else: - back_LED = position[:, [0, 1]].astype(float) - front_LED = position[:, [2, 3]].astype(float) + # If there are 4 columns, then there are 2 LEDs + if led1_is_front: + front_LED = position[:, [0, 1]].astype(float) + back_LED = position[:, [2, 3]].astype(float) + else: + back_LED = position[:, [0, 1]].astype(float) + front_LED = position[:, [2, 3]].astype(float) # Convert to cm back_LED *= meters_to_pixels * CM_TO_METERS @@ -783,16 +787,21 @@ def _fix_col_names(spatial_df): DEFAULT_COLS = ["xloc", "yloc", "xloc2", "yloc2"] ONE_IDX_COLS = ["xloc1", "yloc1", "xloc2", "yloc2"] ZERO_IDX_COLS = ["xloc0", "yloc0", "xloc1", "yloc1"] + THREE_D_COLS = ["x", "y", "z"] input_cols = list(spatial_df.columns) has_default = all([c in input_cols for c in DEFAULT_COLS]) has_0_idx = all([c in input_cols for c in ZERO_IDX_COLS]) has_1_idx = all([c in input_cols for c in ONE_IDX_COLS]) + has_other_default = all([c in input_cols for c in THREE_D_COLS]) if has_default: # move the 4 position columns to front, continue spatial_df = spatial_df[DEFAULT_COLS] + elif has_other_default: + # move the 4 position columns to front, continue + spatial_df = spatial_df[THREE_D_COLS] elif has_0_idx: # move the 4 position columns to front, rename to default, continue spatial_df = spatial_df[ZERO_IDX_COLS] diff --git a/src/spyglass/common/common_ripple.py b/src/spyglass/common/common_ripple.py index c3f31ea5f..154c937af 100644 --- a/src/spyglass/common/common_ripple.py +++ b/src/spyglass/common/common_ripple.py @@ -48,7 +48,7 @@ class RippleLFPElectrode(SpyglassMixin, dj.Part): def insert1(self, key, **kwargs): filter_name = (LFPBand & key).fetch1("filter_name") if "ripple" not in filter_name.lower(): - raise UserWarning("Please use a ripple filter") + logger.warning("Please use a ripple filter") super().insert1(key, 
**kwargs) @staticmethod diff --git a/src/spyglass/decoding/decoding_merge.py b/src/spyglass/decoding/decoding_merge.py index 6603c318f..7e6a2d90a 100644 --- a/src/spyglass/decoding/decoding_merge.py +++ b/src/spyglass/decoding/decoding_merge.py @@ -84,18 +84,36 @@ def cleanup(self, dry_run=False): except (PermissionError, FileNotFoundError): logger.warning(f"Unable to remove {path}, skipping") - @classmethod - def _get_source_class(cls, key): - if cls._source_class_dict is None: - cls._source_class_dict = {} - module = inspect.getmodule(cls) - for part_name in cls.parts(): + @property + def source_class_dict(self) -> dict: + """Dictionary of source class names to source classes + + { + 'ClusterlessDecodingV1': spy...ClusterlessDecodingV1, + 'SortedSpikesDecodingV1': spy...SortedSpikesDecodingV1 + } + + Returns + ------- + dict + Dictionary of source class names to source classes + """ + if not self._source_class_dict: + self._ensure_dependencies_loaded() + module = inspect.getmodule(self) + for part_name in self.parts(): part_name = to_camel_case(part_name.split("__")[-1].strip("`")) part = getattr(module, part_name) - cls._source_class_dict[part_name] = part + self._source_class_dict[part_name] = part + return self._source_class_dict - source = (cls & key).fetch1("source") - return cls._source_class_dict[source] + @classmethod + def _get_source_class(cls, key): + # CB: By making this a property, we can generate the source_class_dict + # without a key. Previously failed on empty table + # This demonstrates pipeline-specific implementation. See also + # merge_restrict_class edits that centralize this logic. + return cls.source_class_dict[(cls & key).fetch1("source")] @classmethod def load_results(cls, key): diff --git a/src/spyglass/decoding/v1/clusterless.py b/src/spyglass/decoding/v1/clusterless.py index 838ed12d9..3b179d7ee 100644 --- a/src/spyglass/decoding/v1/clusterless.py +++ b/src/spyglass/decoding/v1/clusterless.py @@ -13,10 +13,12 @@ from pathlib import Path import datajoint as dj +import non_local_detector.analysis as analysis import numpy as np import pandas as pd import xarray as xr from non_local_detector.models.base import ClusterlessDetector +from ripple_detection import get_multiunit_population_firing_rate from track_linearization import get_linearized_position from spyglass.common.common_interval import IntervalList # noqa: F401 @@ -289,6 +291,9 @@ def load_environments(key): model_params["decoding_kwargs"], ) + if decoding_kwargs is None: + decoding_kwargs = {} + ( position_info, position_variable_names, @@ -361,7 +366,10 @@ def load_linear_position_info(key): environment = ClusterlessDecodingV1.load_environments(key)[0] position_df = ClusterlessDecodingV1.load_position_info(key)[0] - position = np.asarray(position_df[["position_x", "position_y"]]) + position_variable_names = (PositionGroup & key).fetch1( + "position_variables" + ) + position = np.asarray(position_df[position_variable_names]) linear_position_df = get_linearized_position( position=position, @@ -415,3 +423,93 @@ def load_spike_data(key, filter_by_interval=True): new_waveform_features.append(elec_waveform_features[is_in_interval]) return new_spike_times, new_waveform_features + + @classmethod + def get_spike_indicator(cls, key, time): + time = np.asarray(time) + min_time, max_time = time[[0, -1]] + spike_times = cls.load_spike_data(key)[0] + spike_indicator = np.zeros((len(time), len(spike_times))) + + for ind, times in enumerate(spike_times): + times = times[np.logical_and(times >= min_time, times <= 
max_time)] + spike_indicator[:, ind] = np.bincount( + np.digitize(times, time[1:-1]), + minlength=time.shape[0], + ) + + return spike_indicator + + @classmethod + def get_firing_rate(cls, key, time, multiunit=False): + spike_indicator = cls.get_spike_indicator(key, time) + if spike_indicator.ndim == 1: + spike_indicator = spike_indicator[:, np.newaxis] + + sampling_frequency = 1 / np.median(np.diff(time)) + + if multiunit: + spike_indicator = spike_indicator.sum(axis=1, keepdims=True) + return np.stack( + [ + get_multiunit_population_firing_rate( + indicator[:, np.newaxis], sampling_frequency + ) + for indicator in spike_indicator.T + ], + axis=1, + ) + + def get_ahead_behind_distance(self): + # TODO: allow specification of specific time interval + # TODO: allow specification of track graph + # TODO: Handle decode intervals, store in table + + classifier = self.load_model() + results = self.load_results() + posterior = results.acausal_posterior.unstack("state_bins").sum("state") + + if getattr(classifier.environments[0], "track_graph") is not None: + linear_position_info = self.load_linear_position_info( + self.fetch1("KEY") + ) + + orientation_name = ( + "orientation" + if "orientation" in linear_position_info.columns + else "head_orientation" + ) + + traj_data = analysis.get_trajectory_data( + posterior=posterior, + track_graph=classifier.environments[0].track_graph, + decoder=classifier, + actual_projected_position=linear_position_info[ + ["projected_x_position", "projected_y_position"] + ], + track_segment_id=linear_position_info["track_segment_id"], + actual_orientation=linear_position_info[orientation_name], + ) + + return analysis.get_ahead_behind_distance( + classifier.environments[0].track_graph, *traj_data + ) + else: + position_info = self.load_position_info(self.fetch1("KEY")) + map_position = analysis.maximum_a_posteriori_estimate(posterior) + + orientation_name = ( + "orientation" + if "orientation" in position_info.columns + else "head_orientation" + ) + position_variable_names = ( + PositionGroup & self.fetch1("KEY") + ).fetch1("position_variables") + + return analysis.get_ahead_behind_distance2D( + position_info[position_variable_names].to_numpy(), + position_info[orientation_name].to_numpy(), + map_position, + classifier.environments[0].track_graphDD, + ) diff --git a/src/spyglass/decoding/v1/sorted_spikes.py b/src/spyglass/decoding/v1/sorted_spikes.py index 3c910102a..9f968d768 100644 --- a/src/spyglass/decoding/v1/sorted_spikes.py +++ b/src/spyglass/decoding/v1/sorted_spikes.py @@ -13,10 +13,12 @@ from pathlib import Path import datajoint as dj +import non_local_detector.analysis as analysis import numpy as np import pandas as pd import xarray as xr from non_local_detector.models.base import SortedSpikesDetector +from ripple_detection import get_multiunit_population_firing_rate from track_linearization import get_linearized_position from spyglass.common.common_interval import IntervalList # noqa: F401 @@ -281,6 +283,9 @@ def load_environments(key): model_params["decoding_kwargs"], ) + if decoding_kwargs is None: + decoding_kwargs = {} + ( position_info, position_variable_names, @@ -352,7 +357,10 @@ def load_linear_position_info(key): environment = SortedSpikesDecodingV1.load_environments(key)[0] position_df = SortedSpikesDecodingV1.load_position_info(key)[0] - position = np.asarray(position_df[["position_x", "position_y"]]) + position_variable_names = (PositionGroup & key).fetch1( + "position_variables" + ) + position = np.asarray(position_df[position_variable_names]) 
linear_position_df = get_linearized_position( position=position, @@ -384,7 +392,7 @@ def load_spike_data(key, filter_by_interval=True): spike_times = [] for merge_id in merge_ids: - nwb_file = SpikeSortingOutput.fetch_nwb({"merge_id": merge_id})[0] + nwb_file = SpikeSortingOutput().fetch_nwb({"merge_id": merge_id})[0] if "object_id" in nwb_file: # v1 spikesorting @@ -401,10 +409,126 @@ def load_spike_data(key, filter_by_interval=True): min_time, max_time = SortedSpikesDecodingV1._get_interval_range(key) new_spike_times = [] - for elec_spike_times in zip(spike_times): + for elec_spike_times in spike_times: is_in_interval = np.logical_and( elec_spike_times >= min_time, elec_spike_times <= max_time ) new_spike_times.append(elec_spike_times[is_in_interval]) return new_spike_times + + @classmethod + def get_spike_indicator(cls, key, time): + time = np.asarray(time) + min_time, max_time = time[[0, -1]] + spike_times = cls.load_spike_data(key) + spike_indicator = np.zeros((len(time), len(spike_times))) + + for ind, times in enumerate(spike_times): + times = times[np.logical_and(times >= min_time, times <= max_time)] + spike_indicator[:, ind] = np.bincount( + np.digitize(times, time[1:-1]), + minlength=time.shape[0], + ) + + return spike_indicator + + @classmethod + def get_firing_rate(cls, key, time, multiunit=False): + spike_indicator = cls.get_spike_indicator(key, time) + if spike_indicator.ndim == 1: + spike_indicator = spike_indicator[:, np.newaxis] + + sampling_frequency = 1 / np.median(np.diff(time)) + + if multiunit: + spike_indicator = spike_indicator.sum(axis=1, keepdims=True) + return np.stack( + [ + get_multiunit_population_firing_rate( + indicator[:, np.newaxis], sampling_frequency + ) + for indicator in spike_indicator.T + ], + axis=1, + ) + + def spike_times_sorted_by_place_field_peak(self, time_slice=None): + if time_slice is None: + time_slice = slice(-np.inf, np.inf) + + spike_times = self.load_spike_data(self.proj()) + classifier = self.load_model() + + new_spike_times = {} + + for encoding_model in classifier.encoding_model_: + place_fields = np.asarray( + classifier.encoding_model_[encoding_model]["place_fields"] + ) + neuron_sort_ind = np.argsort( + np.nanargmax(place_fields, axis=1).squeeze() + ) + new_spike_times[encoding_model] = [ + spike_times[neuron_ind][ + np.logical_and( + spike_times[neuron_ind] >= time_slice.start, + spike_times[neuron_ind] <= time_slice.stop, + ) + ] + for neuron_ind in neuron_sort_ind + ] + + def get_ahead_behind_distance(self): + # TODO: allow specification of specific time interval + # TODO: allow specification of track graph + # TODO: Handle decode intervals, store in table + + classifier = self.load_model() + results = self.load_results() + posterior = results.acausal_posterior.unstack("state_bins").sum("state") + + if classifier.environments[0].track_graph is not None: + linear_position_info = self.load_linear_position_info( + self.fetch1("KEY") + ) + + orientation_name = ( + "orientation" + if "orientation" in linear_position_info.columns + else "head_orientation" + ) + + traj_data = analysis.get_trajectory_data( + posterior=posterior, + track_graph=classifier.environments[0].track_graph, + decoder=classifier, + actual_projected_position=linear_position_info[ + ["projected_x_position", "projected_y_position"] + ], + track_segment_id=linear_position_info["track_segment_id"], + actual_orientation=linear_position_info[orientation_name], + ) + + return analysis.get_ahead_behind_distance( + classifier.environments[0].track_graph, *traj_data + ) + 
else:
+            position_info = self.load_position_info(self.fetch1("KEY"))
+            map_position = analysis.maximum_a_posteriori_estimate(posterior)
+
+            orientation_name = (
+                "orientation"
+                if "orientation" in position_info.columns
+                else "head_orientation"
+            )
+            position_variable_names = (
+                PositionGroup & self.fetch1("KEY")
+            ).fetch1("position_variables")
+
+            return analysis.get_ahead_behind_distance2D(
+                position_info[position_variable_names].to_numpy(),
+                position_info[orientation_name].to_numpy(),
+                map_position,
+                classifier.environments[0].track_graph,
+            )
diff --git a/src/spyglass/decoding/v1/waveform_features.py b/src/spyglass/decoding/v1/waveform_features.py
index 4bed99f35..b332df45e 100644
--- a/src/spyglass/decoding/v1/waveform_features.py
+++ b/src/spyglass/decoding/v1/waveform_features.py
@@ -146,7 +146,7 @@ def make(self, key):
             sorter,
         )

-        spike_times = SpikeSortingOutput.fetch_nwb(merge_key)[0][
+        spike_times = SpikeSortingOutput().fetch_nwb(merge_key)[0][
             analysis_nwb_key
         ]["spike_times"]

diff --git a/src/spyglass/lfp/analysis/v1/lfp_band.py b/src/spyglass/lfp/analysis/v1/lfp_band.py
index 059bd276c..3927a0604 100644
--- a/src/spyglass/lfp/analysis/v1/lfp_band.py
+++ b/src/spyglass/lfp/analysis/v1/lfp_band.py
@@ -177,7 +177,7 @@ class LFPBandV1(SpyglassMixin, dj.Computed):
     def make(self, key):
         # get the NWB object with the lfp data; FIX: change to fetch with additional infrastructure
         lfp_key = {"merge_id": key["lfp_merge_id"]}
-        lfp_object = LFPOutput.fetch_nwb(lfp_key)[0]["lfp"]
+        lfp_object = (LFPOutput & lfp_key).fetch_nwb()[0]["lfp"]

         # get the electrodes to be filtered and their references
         lfp_band_elect_id, lfp_band_ref_id = (
diff --git a/src/spyglass/ripple/v1/ripple.py b/src/spyglass/ripple/v1/ripple.py
index ef7483f01..4d2397bc3 100644
--- a/src/spyglass/ripple/v1/ripple.py
+++ b/src/spyglass/ripple/v1/ripple.py
@@ -57,7 +57,7 @@ class RippleLFPElectrode(SpyglassMixin, dj.Part):
     def validate_key(key):
         filter_name = (LFPBandV1 & key).fetch1("filter_name")
         if "ripple" not in filter_name.lower():
-            raise UserWarning("Please use a ripple filter")
+            raise ValueError("Please use a ripple filter")

     @staticmethod
     def set_lfp_electrodes(
diff --git a/src/spyglass/settings.py b/src/spyglass/settings.py
index e2e0a2142..68fe1e528 100644
--- a/src/spyglass/settings.py
+++ b/src/spyglass/settings.py
@@ -7,6 +7,8 @@
 import yaml
 from pymysql.err import OperationalError

+from spyglass.utils import logger
+

 class SpyglassConfig:
     """Gets Spyglass dirs from dj.config or environment variables.
@@ -26,6 +28,27 @@ def __init__(self, base_dir: str = None, **kwargs):
         ----------
         base_dir (str)
             The base directory.
+
+        Attributes
+        ----------
+        supplied_base_dir (str)
+            The base directory passed to the class.
+        config_defaults (dict)
+            Default settings for the config.
+        relative_dirs (dict)
+            Relative dirs for each prefix (spyglass, kachery, dlc). Relative
+            to respective base_dir. Created on init.
+        dj_defaults (dict)
+            Default settings for datajoint.
+        env_defaults (dict)
+            Default settings for environment variables.
+        _config (dict)
+            Cached config settings.
+        _debug_mode (bool)
+            True if debug_mode is set. Supports skipping known bugs in test env.
+        _test_mode (bool)
+            True if test_mode is set. Required for pytests to run without
+            prompts.
""" self.supplied_base_dir = base_dir self._config = dict() @@ -33,6 +56,7 @@ def __init__(self, base_dir: str = None, **kwargs): self._debug_mode = kwargs.get("debug_mode", False) self._test_mode = kwargs.get("test_mode", False) self._dlc_base = None + self.load_failed = False self.relative_dirs = { # {PREFIX}_{KEY}_DIR, default dir relative to base_dir @@ -56,7 +80,6 @@ def __init__(self, base_dir: str = None, **kwargs): "output": "output", }, } - self.dj_defaults = { "database.host": kwargs.get("database_host", "lmf-db.cin.ucsf.edu"), "database.user": kwargs.get("database_user"), @@ -65,7 +88,6 @@ def __init__(self, base_dir: str = None, **kwargs): "filepath_checksum_size_limit": 1 * 1024**3, "enable_python_native_blobs": True, } - self.env_defaults = { "FIGURL_CHANNEL": "franklab2", "DJ_SUPPORT_FILEPATH_MANAGEMENT": "TRUE", @@ -73,7 +95,9 @@ def __init__(self, base_dir: str = None, **kwargs): "HD5_USE_FILE_LOCKING": "FALSE", } - def load_config(self, force_reload=False): + def load_config( + self, base_dir=None, force_reload=False, on_startup: bool = False + ): """ Loads the configuration settings for the object. @@ -85,6 +109,9 @@ def load_config(self, force_reload=False): Parameters ---------- + base_dir: str + Optional. Default None. The base directory. If not provided, will + use the env variable or existing config. force_reload: bool Optional. Default False. Default skip load if already completed. @@ -98,7 +125,7 @@ def load_config(self, force_reload=False): dict list of relative_dirs and other settings (e.g., prepopulate). """ - if self._config and not force_reload: + if not force_reload and self._config: return self._config dj_custom = dj.config.get("custom", {}) @@ -110,17 +137,23 @@ def load_config(self, force_reload=False): self._test_mode = dj_custom.get("test_mode", False) resolved_base = ( - self.supplied_base_dir + base_dir + or self.supplied_base_dir or dj_spyglass.get("base") or os.environ.get("SPYGLASS_BASE_DIR") ) + if resolved_base and not Path(resolved_base).exists(): + resolved_base = Path(resolved_base).expanduser() if not resolved_base or not Path(resolved_base).exists(): - raise ValueError( - f"Could not find SPYGLASS_BASE_DIR: {resolved_base}" - + "\n\tCheck dj.config['custom']['spyglass_dirs']['base']" - + "\n\tand os.environ['SPYGLASS_BASE_DIR']" - ) + if not on_startup: # Only warn if not on startup + logger.error( + f"Could not find SPYGLASS_BASE_DIR: {resolved_base}" + + "\n\tCheck dj.config['custom']['spyglass_dirs']['base']" + + "\n\tand os.environ['SPYGLASS_BASE_DIR']" + ) + self.load_failed = True + return self._dlc_base = ( dj_dlc.get("base") @@ -130,7 +163,7 @@ def load_config(self, force_reload=False): ) Path(self._dlc_base).mkdir(exist_ok=True) - config_dirs = {"SPYGLASS_BASE_DIR": resolved_base} + config_dirs = {"SPYGLASS_BASE_DIR": str(resolved_base)} for prefix, dirs in self.relative_dirs.items(): this_base = self._dlc_base if prefix == "dlc" else resolved_base for dir, dir_str in dirs.items(): @@ -150,7 +183,7 @@ def load_config(self, force_reload=False): or str(Path(this_base) / dir_str) ).replace('"', "") - config_dirs.update({dir_env_fmt: dir_location}) + config_dirs.update({dir_env_fmt: str(dir_location)}) kachery_zone_dict = { "KACHERY_ZONE": ( @@ -175,7 +208,7 @@ def load_config(self, force_reload=False): **loaded_env, ) - self._set_dj_config_stores(config_dirs) + self._set_dj_config_stores() return self._config @@ -206,38 +239,42 @@ def _set_dj_config_stores(self, check_match=True, set_stores=True): dir_dict: dict Dictionary of resolved 
dirs. check_match: bool - Optional. Default True. Check that dj.config['stores'] match resolved dirs. + Optional. Default True. Check that dj.config['stores'] match + resolved dirs. set_stores: bool Optional. Default True. Set dj.config['stores'] to resolved dirs. """ + + mismatch_analysis = False + mismatch_raw = False + if check_match: dj_stores = dj.config.get("stores", {}) - store_raw = dj_stores.get("raw", {}).get("location") - store_analysis = dj_stores.get("analysis", {}).get("location") - - err_template = ( - "dj.config['stores'] does not match resolved dir." - + "\n\tdj.config['stores']['{0}']['location']:\n\t\t{1}" - + "\n\tSPYGLASS_{2}_DIR:\n\t\t{3}." - ) - if store_raw and Path(store_raw) != Path(self.raw_dir): - raise ValueError( - err_template.format("raw", store_raw, "RAW", self.raw_dir) - ) - if store_analysis and Path(store_analysis) != Path( + store_r = dj_stores.get("raw", {}).get("location") + store_a = dj_stores.get("analysis", {}).get("location") + mismatch_raw = store_r and Path(store_r) != Path(self.raw_dir) + mismatch_analysis = store_a and Path(store_a) != Path( self.analysis_dir - ): - raise ValueError( - err_template.format( - "analysis", - store_analysis, - "ANALYSIS", - self.analysis_dir, - ) - ) + ) if set_stores: + if mismatch_raw or mismatch_analysis: + logger.warning( + "Setting config DJ stores to resolve mismatch.\n\t" + + f"raw : {self.raw_dir}\n\t" + + f"analysis: {self.analysis_dir}" + ) dj.config.update(self._dj_stores) + return + + if mismatch_raw or mismatch_analysis: + raise ValueError( + "dj.config['stores'] does not match resolved dirs." + + f"\n\tdj.config['stores']: {dj_stores}" + + f"\n\tResolved dirs: {self._dj_stores}" + ) + + return def dir_to_var(self, dir: str, dir_type: str = "spyglass"): """Converts a dir string to an env variable name.""" @@ -247,6 +284,7 @@ def _generate_dj_config( self, base_dir: str = None, database_user: str = None, + database_password: str = None, database_host: str = "lmf-db.cin.ucsf.edu", database_port: int = 3306, database_use_tls: bool = True, @@ -256,12 +294,12 @@ def _generate_dj_config( Parameters ---------- - base_dir : str, optional - The base directory. If not provided, will use the env variable or - existing config. database_user : str, optional The database user. If not provided, resulting config will not specify. + database_password : str, optional + The database password. If not provided, resulting config will not + specify. database_host : str, optional Default lmf-db.cin.ucsf.edu. MySQL host name. dapabase_port : int, optional @@ -273,12 +311,10 @@ def _generate_dj_config( Note: python will raise error for params with `.` in name. """ - if base_dir: - self.supplied_base_dir = base_dir - self.load_config(force_reload=True) - if database_user: kwargs.update({"database.user": database_user}) + if database_password: + kwargs.update({"database.password": database_password}) kwargs.update( { @@ -294,9 +330,8 @@ def _generate_dj_config( def save_dj_config( self, save_method: str = "global", - filename: str = None, + output_filename: str = None, base_dir=None, - database_user=None, set_password=True, **kwargs, ): @@ -307,35 +342,51 @@ def save_dj_config( save_method : {'local', 'global', 'custom'}, optional The method to use to save the config. If either 'local' or 'global', datajoint builtins will be used to save. - filename : str or Path, optional + output_filename : str or Path, optional Default to datajoint global config. If save_method = 'custom', name of file to generate. 
Must end in either be either yaml or json. base_dir : str, optional The base directory. If not provided, will default to the env var - database_user : str, optional - The database user. If not provided, resulting config will not - specify. set_password : bool, optional Default True. Set the database password. + kwargs: dict, optional + Any other valid datajoint configuration parameters, including + database_user, database_password, database_host, database_port, etc. + Note: python will raise error for params with `.` in name, so use + underscores instead. """ - if save_method == "local": + if base_dir: + self.load_config( + base_dir=base_dir, force_reload=True, on_startup=False + ) + + if output_filename: + save_method = "custom" + path = Path(output_filename).expanduser() # Expand ~ + filepath = path if path.is_absolute() else path.absolute() + filepath.parent.mkdir(exist_ok=True, parents=True) + filepath = ( + filepath.with_suffix(".json") # ensure suffix, default json + if filepath.suffix not in [".json", ".yaml"] + else filepath + ) + elif save_method == "local": filepath = Path(".") / dj.settings.LOCALCONFIG - elif not filename or save_method == "global": - save_method = "global" + elif save_method == "global": filepath = Path("~").expanduser() / dj.settings.GLOBALCONFIG - - dj.config.update( - self._generate_dj_config( - base_dir=base_dir, database_user=database_user, **kwargs + else: + raise ValueError( + "For save_dj_config, either (a) save_method must be 'local' " + + " or 'global' or (b) must provide custom output_filename." ) - ) + + dj.config.update(self._generate_dj_config(**kwargs)) if set_password: try: dj.set_password() except OperationalError as e: - warnings.warn(f"Database connection issues. Wrong pass? {e}") - # NOTE: Save anyway? Or raise error? + warnings.warn(f"Database connection issues. Wrong pass?\n\t{e}") user_warn = ( f"Replace existing file? 
{filepath.resolve()}\n\t" @@ -343,8 +394,12 @@ def save_dj_config( + "\n" ) - if filepath.exists() and dj.utils.user_choice(user_warn)[0] != "y": - return dj.config + if ( + not self.test_mode + and filepath.exists() + and dj.utils.user_choice(user_warn)[0] != "y" + ): + return if save_method == "global": dj.config.save_global(verbose=True) @@ -354,11 +409,12 @@ def save_dj_config( dj.config.save_local(verbose=True) return - with open(filename, "w") as outfile: - if filename.endswith("json"): - json.dump(dj.config, outfile, indent=2) + with open(filepath, "w") as outfile: + if filepath.suffix == ".yaml": + yaml.dump(dj.config._conf, outfile, default_flow_style=False) else: - yaml.dump(dj.config, outfile, default_flow_style=False) + json.dump(dj.config._conf, outfile, indent=2) + logger.info(f"Saved config to {filepath}") @property def _dj_stores(self) -> dict: @@ -402,7 +458,7 @@ def _dj_custom(self) -> dict: "storage": self.config.get( self.dir_to_var("storage", "kachery") ), - "temp": self.config.get(self.dir_to_var("tmp", "kachery")), + "temp": self.config.get(self.dir_to_var("temp", "kachery")), }, "dlc_dirs": { "base": self._dlc_base, @@ -484,18 +540,31 @@ def dlc_output_dir(self) -> str: sg_config = SpyglassConfig() -config = sg_config.config -base_dir = sg_config.base_dir -raw_dir = sg_config.raw_dir -recording_dir = sg_config.recording_dir -temp_dir = sg_config.temp_dir -analysis_dir = sg_config.analysis_dir -sorting_dir = sg_config.sorting_dir -waveform_dir = sg_config.waveform_dir -video_dir = sg_config.video_dir -debug_mode = sg_config.debug_mode -test_mode = sg_config.test_mode -prepopulate = config.get("prepopulate", False) -dlc_project_dir = sg_config.dlc_project_dir -dlc_video_dir = sg_config.dlc_video_dir -dlc_output_dir = sg_config.dlc_output_dir +sg_config.load_config(on_startup=True) +if sg_config.load_failed: # Failed to load + logger.warning("Failed to load SpyglassConfig. 
Please set up config file.") + config = {} # Let __intit__ fetch empty config for first time setup + config, prepopulate, test_mode, base_dir, raw_dir, analysis_dir = ( + {}, + False, + False, + None, + None, + None, + ) +else: + config = sg_config.config + base_dir = sg_config.base_dir + raw_dir = sg_config.raw_dir + recording_dir = sg_config.recording_dir + temp_dir = sg_config.temp_dir + analysis_dir = sg_config.analysis_dir + sorting_dir = sg_config.sorting_dir + waveform_dir = sg_config.waveform_dir + video_dir = sg_config.video_dir + debug_mode = sg_config.debug_mode + test_mode = sg_config.test_mode + prepopulate = config.get("prepopulate", False) + dlc_project_dir = sg_config.dlc_project_dir + dlc_video_dir = sg_config.dlc_video_dir + dlc_output_dir = sg_config.dlc_output_dir diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py index a5b48491f..bb2fb2fd0 100644 --- a/src/spyglass/spikesorting/imported.py +++ b/src/spyglass/spikesorting/imported.py @@ -1,5 +1,8 @@ +import copy + import datajoint as dj import pynwb +from datajoint.utils import to_camel_case from spyglass.common.common_nwbfile import Nwbfile from spyglass.common.common_session import Session # noqa: F401 @@ -16,18 +19,28 @@ class ImportedSpikeSorting(SpyglassMixin, dj.Imported): object_id: varchar(40) """ + _nwb_table = Nwbfile + def make(self, key): + orig_key = copy.deepcopy(key) nwb_file_abs_path = Nwbfile.get_abs_path(key["nwb_file_name"]) with pynwb.NWBHDF5IO( nwb_file_abs_path, "r", load_namespaces=True ) as io: nwbfile = io.read() - if nwbfile.units: - key["object_id"] = nwbfile.units.object_id - self.insert1(key, skip_duplicates=True) - else: + if not nwbfile.units: logger.warn("No units found in NWB file") + return + + from spyglass.spikesorting.merge import SpikeSortingOutput # noqa: F401 + + key["object_id"] = nwbfile.units.object_id + part_name = SpikeSortingOutput._part_name(self.table_name) + self.insert1(key, skip_duplicates=True) + SpikeSortingOutput._merge_insert( + [orig_key], part_name=part_name, skip_duplicates=True + ) @classmethod def get_recording(cls, key): diff --git a/src/spyglass/spikesorting/merge.py b/src/spyglass/spikesorting/merge.py index 12baefd34..db8893040 100644 --- a/src/spyglass/spikesorting/merge.py +++ b/src/spyglass/spikesorting/merge.py @@ -81,13 +81,11 @@ def get_spike_times(cls, key): def get_spike_indicator(cls, key, time): time = np.asarray(time) min_time, max_time = time[[0, -1]] - spike_times = cls.get_spike_times(key) + spike_times = cls.load_spike_data(key) spike_indicator = np.zeros((len(time), len(spike_times))) for ind, times in enumerate(spike_times): - times = times[ - np.logical_and(spike_times >= min_time, spike_times <= max_time) - ] + times = times[np.logical_and(times >= min_time, times <= max_time)] spike_indicator[:, ind] = np.bincount( np.digitize(times, time[1:-1]), minlength=time.shape[0], diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py index a710972ff..c37122c70 100644 --- a/src/spyglass/utils/dj_merge_tables.py +++ b/src/spyglass/utils/dj_merge_tables.py @@ -1,7 +1,9 @@ import re from contextlib import nullcontext +from inspect import getmodule from itertools import chain as iter_chain from pprint import pprint +from typing import Union import datajoint as dj from datajoint.condition import make_condition @@ -10,7 +12,6 @@ from datajoint.utils import from_camel_case, get_master, to_camel_case from IPython.core.display import HTML -from spyglass.utils.dj_helper_fn import 
fetch_nwb from spyglass.utils.logging import logger RESERVED_PRIMARY_KEY = "merge_id" @@ -51,15 +52,7 @@ def __init__(self): + f"\n\tExpected: {self.primary_key}" + f"\n\tActual : {part.primary_key}" ) - self._analysis_nwbfile = None - - @property # CB: This is a property to avoid circular import - def analysis_nwbfile(self): - if self._analysis_nwbfile is None: - from spyglass.common import AnalysisNwbfile # noqa F401 - - self._analysis_nwbfile = AnalysisNwbfile - return self._analysis_nwbfile + self._source_class_dict = {} def _remove_comments(self, definition): """Use regular expressions to remove comments and blank lines""" @@ -67,6 +60,38 @@ def _remove_comments(self, definition): r"\n\s*\n", "\n", re.sub(r"#.*\n", "\n", definition) ) + @staticmethod + def _part_name(part=None): + """Return the CamelCase name of a part table""" + if not isinstance(part, str): + part = part.table_name + return to_camel_case(part.split("__")[-1].strip("`")) + + def get_source_from_key(self, key: dict) -> str: + """Return the source of a given key""" + return self._normalize_source(key) + + def parts(self, camel_case=False, *args, **kwargs) -> list: + """Return a list of part tables, add option for CamelCase names. + + See DataJoint `parts` for additional arguments. If camel_case is True, + forces return of strings rather than objects. + """ + self._ensure_dependencies_loaded() + + if camel_case and kwargs.get("as_objects"): + logger.warning( + "Overriding as_objects=True to return CamelCase part names." + ) + kwargs["as_objects"] = False + + parts = super().parts(*args, **kwargs) + + if camel_case: + parts = [self._part_name(part) for part in parts] + + return parts + @classmethod def _merge_restrict_parts( cls, @@ -220,33 +245,32 @@ def _merge_repr(cls, restriction: str = True) -> dj.expression.Union: for p in cls._merge_restrict_parts( restriction=restriction, add_invalid_restrict=False, - return_empties=True, + return_empties=False, # motivated by SpikeSortingOutput.Import ) ] - primary_attrs = list( - dict.fromkeys( # get all columns from parts - iter_chain.from_iterable([p.heading.names for p in parts]) + attr_dict = { # NULL for non-numeric, 0 for numeric + attr.name: "0" if attr.numeric else "NULL" + for attr in iter_chain.from_iterable( + part.heading.attributes.values() for part in parts ) - ) - # primary_attrs.append(cls()._reserved_sk) - query = dj.U(*primary_attrs) * parts[0].proj( # declare query - ..., # include all attributes from part 0 - **{ - a: "NULL" # add null value where part has no column - for a in primary_attrs - if a not in parts[0].heading.names - }, - ) - for part in parts[1:]: # add to declared query for each part - query += dj.U(*primary_attrs) * part.proj( - ..., + } + + def _proj_part(part): + """Project part, adding NULL/0 for missing attributes""" + return dj.U(*attr_dict.keys()) * part.proj( + ..., # include all attributes from part **{ - a: "NULL" - for a in primary_attrs - if a not in part.heading.names + k: v + for k, v in attr_dict.items() + if k not in part.heading.names }, ) + + query = _proj_part(parts[0]) # start with first part + for part in parts[1:]: # add remaining parts + query += _proj_part(part) + return query @classmethod @@ -294,7 +318,7 @@ def _merge_insert( keys = [] # empty to-be-inserted key for part in parts: # check each part part_parent = part.parents(as_objects=True)[-1] - part_name = to_camel_case(part.table_name.split("__")[-1]) + part_name = cls._part_name(part) if part_parent & row: # if row is in part parent if keys and mutual_exclusvity: 
# if key from other part raise ValueError( @@ -475,16 +499,18 @@ def merge_delete_parent( for part_parent in part_parents: super().delete(part_parent, **kwargs) - @classmethod def fetch_nwb( - cls, + self, restriction: str = True, multi_source=False, disable_warning=False, *attrs, **kwargs, ): - """Return the AnalysisNwbfile file linked in the source. + """Return the (Analysis)Nwbfile file linked in the source. + + Relies on SpyglassMixin._nwb_table_tuple to determine the table to + fetch from and the appropriate path attribute to return. Parameters ---------- @@ -493,32 +519,14 @@ def fetch_nwb( multi_source: bool Return from multiple parents. Default False. """ - if not disable_warning: - _warn_on_restriction(table=cls, restriction=restriction) - - part_parents = cls._merge_restrict_parents( - restriction=restriction, - return_empties=False, - add_invalid_restrict=False, - ) - - if not multi_source and len(part_parents) != 1: - raise ValueError( - f"{len(part_parents)} possible sources found in Merge Table:" - + " and ".join([p.full_table_name for p in part_parents]) - ) + if isinstance(self, dict): + raise ValueError("Try replacing Merge.method with Merge().method") + if restriction is True and self.restriction: + if not disable_warning: + _warn_on_restriction(self, restriction) + restriction = self.restriction - nwbs = [] - for part_parent in part_parents: - nwbs.extend( - fetch_nwb( - part_parent, - (cls().analysis_nwbfile, "analysis_file_abs_path"), - *attrs, - **kwargs, - ) - ) - return nwbs + return self.merge_restrict_class(restriction).fetch_nwb() @classmethod def merge_get_part( @@ -527,6 +535,7 @@ def merge_get_part( join_master: bool = False, restrict_part=True, multi_source=False, + return_empties=False, ) -> dj.Table: """Retrieve part table from a restricted Merge table. @@ -545,6 +554,8 @@ def merge_get_part( native part table. multi_source: bool Return multiple parts. Default False. + return_empties: bool + Default False. Return empty part tables. Returns ------ @@ -563,11 +574,11 @@ def merge_get_part( restricting """ sources = [ - to_camel_case(n.split("__")[-1].strip("`")) # friendly part name - for n in cls._merge_restrict_parts( + cls._part_name(part) # friendly part name + for part in cls._merge_restrict_parts( restriction=restriction, as_objects=False, - return_empties=False, + return_empties=return_empties, add_invalid_restrict=False, ) ] @@ -595,7 +606,9 @@ def merge_get_parent( cls, restriction: str = True, join_master: bool = False, - multi_source=False, + multi_source: bool = False, + return_empties: bool = False, + add_invalid_restrict: bool = True, ) -> dj.FreeTable: """Returns a list of part parents with restrictions applied. @@ -610,6 +623,12 @@ def merge_get_parent( Default True. join_master: bool Default False. Join part with Merge master to show uuid and source + multi_source: bool + Return multiple parents. Default False. + return_empties: bool + Default False. Return empty parent tables. + add_invalid_restrict: bool + Default True. Include parent for which the restriction is invalid. 
Returns
         ------
@@ -620,11 +639,11 @@
         part_parents = cls._merge_restrict_parents(
             restriction=restriction,
             as_objects=True,
-            return_empties=False,
-            add_invalid_restrict=False,
+            return_empties=return_empties,
+            add_invalid_restrict=add_invalid_restrict,
         )

         if not multi_source and len(part_parents) != 1:
             raise ValueError(
                 f"Found {len(part_parents)} potential parents: {part_parents}"
                 + "\n\tTry adding a string restriction when invoking "
@@ -637,6 +656,71 @@

         return part_parents if multi_source else part_parents[0]

+    @property
+    def source_class_dict(self) -> dict:
+        if not self._source_class_dict:
+            module = getmodule(self)
+            self._source_class_dict = {
+                part_name: getattr(module, part_name)
+                for part_name in self.parts(camel_case=True)
+            }
+        return self._source_class_dict
+
+    def _normalize_source(
+        self, source: Union[str, dj.Table, dj.condition.AndList, dict]
+    ) -> str:
+        fetched_source = None
+        if isinstance(source, (Merge, dj.condition.AndList)):
+            try:
+                fetched_source = (self & source).fetch(self._reserved_sk)
+            except DataJointError:
+                raise ValueError(f"Unable to find source for {source}")
+            source = fetched_source[0]
+            if len(fetched_source) > 1:
+                logger.warn(f"Multiple sources. Selecting first: {source}.")
+        if isinstance(source, dj.Table):
+            source = self._part_name(source)
+        if isinstance(source, dict):
+            source = self._part_name(self.merge_get_parent(source))
+
+        return source
+
+    def merge_get_parent_class(self, source: str) -> dj.Table:
+        """Return the class of the parent table for a given CamelCase source.
+
+        Parameters
+        ----------
+        source: Union[str, dict, dj.Table]
+            Accepts a CamelCase name of the source, or key as a dict, or a part
+            table.
+
+        Returns
+        -------
+        dj.Table
+            Class instance of the parent table, including class methods.
+        """
+
+        ret = self.source_class_dict.get(self._normalize_source(source))
+        if not ret:
+            logger.error(
+                f"No source class found for {source}: \n\t"
+                + f"{self.parts(camel_case=True)}"
+            )
+        return ret
+
+    def merge_restrict_class(self, key: dict) -> dj.Table:
+        """Returns native parent class, restricted with key."""
+        parent_key = self.merge_get_parent(key).fetch("KEY", as_dict=True)
+
+        if len(parent_key) > 1:
+            raise ValueError(
+                f"Ambiguous entry. Data has mult rows in parent:\n\tData:{key}"
+                + f"\n\t{parent_key}"
+            )
+
+        parent_class = self.merge_get_parent_class(key)
+        return parent_class & parent_key
+
     @classmethod
     def merge_fetch(self, restriction: str = True, *attrs, **kwargs) -> list:
         """Perform a fetch across all parts. If >1 result, return as a list.
diff --git a/src/spyglass/utils/dj_mixin.py b/src/spyglass/utils/dj_mixin.py
index 3ee0f6292..490274fe0 100644
--- a/src/spyglass/utils/dj_mixin.py
+++ b/src/spyglass/utils/dj_mixin.py
@@ -31,11 +31,14 @@ class SpyglassMixin:
         Alias for cautious_delete.
     """

-    _nwb_table_dict = {}
-    _delete_dependencies = []
-    _merge_delete_func = None
-    _session_pk = None
-    _member_pk = None
+    _nwb_table_dict = {}  # Dict mapping NWBFile table to path attribute name.
+    # _nwb_table = None # NWBFile table class, defined at the table level
+    _nwb_table_resolved = None  # NWBFile table class, resolved here from above
+    _delete_dependencies = []  # Session, LabMember, LabTeam, delay import
+    _merge_delete_func = None  # delete_downstream_merge, delay import
+    # pks for delete permission check, assumed to be one field
+    _session_pk = None  # Session primary key.
Mixin is ambivalent to Session pk + _member_pk = None # LabMember primary key. Mixin ambivalent table structure # ------------------------------- fetch_nwb ------------------------------- @@ -58,6 +61,47 @@ def _table_dict(self): } return self._nwb_table_dict + @property + def _nwb_table_tuple(self): + """NWBFile table class. + + Used to determine fetch_nwb behavior. Also used in Merge.fetch_nwb. + Multiple copies for different purposes. + + - _nwb_table may be user-set. Don't overwrite. + - _nwb_table_resolved is set here from either _nwb_table or definition. + - _nwb_table_tuple is used to cache result of _nwb_table_resolved and + return the appropriate path_attr from _table_dict above. + """ + if not self._nwb_table_resolved: + from spyglass.common.common_nwbfile import ( # noqa F401 + AnalysisNwbfile, + Nwbfile, + ) + + if hasattr(self, "_nwb_table"): + self._nwb_table_resolved = self._nwb_table + + if not hasattr(self, "_nwb_table"): + self._nwb_table_resolved = ( + AnalysisNwbfile + if "-> AnalysisNwbfile" in self.definition + else Nwbfile + if "-> Nwbfile" in self.definition + else None + ) + + if getattr(self, "_nwb_table_resolved", None) is None: + raise NotImplementedError( + f"{self.__class__.__name__} does not have a " + "(Analysis)Nwbfile foreign key or _nwb_table attribute." + ) + + return ( + self._nwb_table_resolved, + self._table_dict[self._nwb_table_resolved], + ) + def fetch_nwb(self, *attrs, **kwargs): """Fetch NWBFile object from relevant table. @@ -68,30 +112,9 @@ def fetch_nwb(self, *attrs, **kwargs): '-> AnalysisNwbfile' in its definition can use a _nwb_table attribute to specify which table to use. """ - _nwb_table_dict = self._table_dict - analysis_table, nwb_table = _nwb_table_dict.keys() - - if not hasattr(self, "_nwb_table"): - self._nwb_table = ( - analysis_table - if "-> AnalysisNwbfile" in self.definition - else nwb_table - if "-> Nwbfile" in self.definition - else None - ) - - if getattr(self, "_nwb_table", None) is None: - raise NotImplementedError( - f"{self.__class__.__name__} does not have a (Analysis)Nwbfile " - "foreign key or _nwb_table attribute." - ) + nwb_table, path_attr = self._nwb_table_tuple - return fetch_nwb( - self, - (self._nwb_table, _nwb_table_dict[self._nwb_table]), - *attrs, - **kwargs, - ) + return fetch_nwb(self, (nwb_table, path_attr), *attrs, **kwargs) # -------------------------------- delete --------------------------------- diff --git a/src/spyglass/utils/nwb_helper_fn.py b/src/spyglass/utils/nwb_helper_fn.py index d09b5b9fd..6b7947b2d 100644 --- a/src/spyglass/utils/nwb_helper_fn.py +++ b/src/spyglass/utils/nwb_helper_fn.py @@ -2,7 +2,6 @@ import os import os.path -import warnings from itertools import groupby from pathlib import Path @@ -127,7 +126,7 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): Warns ----- - UserWarning + LoggerWarning If multiple NWBDataInterface and DynamicTable objects with the matching name are found. @@ -146,7 +145,7 @@ def get_data_interface(nwbfile, data_interface_name, data_interface_class=None): continue ret.append(match) if len(ret) > 1: - warnings.warn( + logger.warning( f"Multiple data interfaces with name '{data_interface_name}' " f"found in NWBFile with identifier {nwbfile.identifier}. " + "Using the first one found. 
" diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py index c21ed96f6..73b435c8a 100644 --- a/tests/common/test_behav.py +++ b/tests/common/test_behav.py @@ -1,6 +1,8 @@ import pytest from pandas import DataFrame +from ..conftest import TEARDOWN + def test_invalid_interval(pos_src): """Test invalid interval""" @@ -44,6 +46,7 @@ def test_videofile_getabspath(common, mini_restr): common.VideoFile().getabspath(mini_restr) +@pytest.mark.skipif(not TEARDOWN, reason="No teardown: expect no change.") def test_posinterval_no_transaction(verbose_context, common, mini_restr): """Test no transaction""" before = common.PositionIntervalMap().fetch() diff --git a/tests/common/test_ephys.py b/tests/common/test_ephys.py index 9ad1ea0a4..bcce8ddf2 100644 --- a/tests/common/test_ephys.py +++ b/tests/common/test_ephys.py @@ -1,6 +1,8 @@ import pytest from numpy import array_equal +from ..conftest import TEARDOWN + def test_create_from_config(mini_insert, common_ephys, mini_path): before = common_ephys.Electrode().fetch() @@ -18,11 +20,11 @@ def test_raw_object(mini_insert, common_ephys, mini_dict, mini_content): assert obj_fetch == obj_raw, "Raw.nwb_object did not return expected object" +@pytest.mark.skipif(not TEARDOWN, reason="No teardown: expect no change.") def test_set_lfp_electrodes(mini_insert, common_ephys, mini_copy_name): before = common_ephys.LFPSelection().fetch() common_ephys.LFPSelection().set_lfp_electrodes(mini_copy_name, [0]) after = common_ephys.LFPSelection().fetch() - # Because already inserted, expect no change assert ( len(after) == len(before) + 1 ), "Set LFP electrodes had unexpected effect" diff --git a/tests/common/test_position.py b/tests/common/test_position.py index 47f285977..8a7261c74 100644 --- a/tests/common/test_position.py +++ b/tests/common/test_position.py @@ -1,5 +1,4 @@ import pytest -from datajoint.hash import key_hash @pytest.fixture @@ -96,7 +95,9 @@ def upsample_position_error( skip_duplicates=True, ) interval_pos_key = {**interval_key, **upsample_param_key} - common_position.IntervalPositionInfoSelection.insert1(interval_pos_key) + common_position.IntervalPositionInfoSelection.insert1( + interval_pos_key, skip_duplicates=not teardown + ) yield interval_pos_key if teardown: (param_table & upsample_param_key).delete(safemode=False) diff --git a/tests/common/test_region.py b/tests/common/test_region.py index 95f62fe1b..8241cb304 100644 --- a/tests/common/test_region.py +++ b/tests/common/test_region.py @@ -1,6 +1,8 @@ import pytest from datajoint import U as dj_U +from ..conftest import TEARDOWN + @pytest.fixture def region_dict(): @@ -15,6 +17,7 @@ def brain_region(common, region_dict): (brain_region & "region_id > 1").delete(safemode=False) +@pytest.mark.skipif(not TEARDOWN, reason="No teardown: no test autoincrement") def test_region_add(brain_region, region_dict): next_id = ( dj_U().aggr(brain_region, n="max(region_id)").fetch1("n") or 0 diff --git a/tests/conftest.py b/tests/conftest.py index 3c2bc866b..759ca43fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,7 +78,7 @@ def pytest_configure(config): os.environ["SPYGLASS_BASE_DIR"] = str(BASE_DIR) SERVER = DockerMySQLManager( - restart=False, + restart=TEARDOWN, shutdown=TEARDOWN, null_server=config.option.no_server, verbose=VERBOSE, @@ -162,10 +162,13 @@ def server(request, teardown): @pytest.fixture(scope="session") def dj_conn(request, server, verbose, teardown): """Fixture for datajoint connection.""" - config_file = "dj_local_conf.json_pytest" + config_file = 
"dj_local_conf.json_test" + if Path(config_file).exists(): + os.remove(config_file) dj.config.update(server.creds) dj.config["loglevel"] = "INFO" if verbose else "ERROR" + dj.config["custom"]["spyglass_dirs"] = {"base": str(BASE_DIR)} dj.config.save(config_file) dj.conn() yield dj.conn() @@ -242,6 +245,7 @@ def mini_closed(mini_path): def mini_insert(mini_path, teardown, server, dj_conn): from spyglass.common import Nwbfile, Session # noqa: E402 from spyglass.data_import import insert_sessions # noqa: E402 + from spyglass.spikesorting.merge import SpikeSortingOutput # noqa: E402 from spyglass.utils.nwb_helper_fn import close_nwb_files # noqa: E402 dj_logger.info("Inserting test data.")