From 91ce55f4b0a8275b32ea52400bae459ece46bed6 Mon Sep 17 00:00:00 2001 From: Kyu Hyun Lee Date: Fri, 8 Dec 2023 16:17:25 -0800 Subject: [PATCH] Spike sorting v1 pipeline (#651) * Save LFP as pynwb.ecephys.LFP * Fix formatting * Fix formatting * Add new tables * Change name of Figurl table * Update pipeline * Edit merge * Remove methods * Minor fix * Save preproc rec as NWB * Add artfiact changes * Minor change * Fix lint * Update artifact and sorting * Minor update * Update sorting * Start curation * Update curation * Update curation * Write sorting with curation * Finish curation * Remove unused imports * Add new schema * Update data type * Update figurl curation * Modify metric curation * Reorg metric curation * Streamline metric * Add to MetricCuration * Add user exposed methods * Minor update * Minor edit * Update metric methods * Add docstring * Update metric curation * Add comments * Fix merge * Fix spelling error * Change Curation to CurationV1 * Fix comments * Update src/spyglass/spikesorting/v1/artifact.py Co-authored-by: Chris Brozdowski * Remove neuroconv from dep and add copied class * Update init * Remove unused import * Remove unused import * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/artifact.py Co-authored-by: Chris Brozdowski * Make ->Session primary key in SortGroup * Make insert1 insert for multi row insertion * Remove neuroconv dep * Update settings.json * Change MetricCuration dep * Add `insert_selection` * Update based on comments * Move util * Add insert_selection to SpikeSortingSelection * Minor changes to sorting.py * Fix id gen * Change fetch * Minor change * Update recording * More changes * Add dep * Add curation to AnalysisNWB * Minor change * Update artifact detection * Minor update * Fix while testing * Formatting * Update figurl * Add ss v1 tutorial * Update merge * Update populate * Update notebook and merge insert * Update import issue * Fix typo * Fix timestamp extend * Handle ref channel not in sort group * Apply black * Add insert metric curation * Fix spelling error * Fix concat ref channel * Fix insert metric curation * Add SpikeSortingOutput.CuratedSpikeSorting * Sortings -> sorting * Update src/spyglass/spikesorting/v1/recording.py Co-authored-by: Chris Brozdowski * Revert back to sort_group_id int * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Changes based on comments * Update src/spyglass/spikesorting/merge.py Co-authored-by: Eric Denovellis * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/recording.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/recording.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * add util function to get spiking merge_ids from 
restriction * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/figurl_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/figurl_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/figurl_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/artifact.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/artifact.py Co-authored-by: Chris Brozdowski * Remove custom merge insert * Update src/spyglass/spikesorting/v1/metric_utils.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/recording.py Co-authored-by: Chris Brozdowski * Minor fix * Update src/spyglass/spikesorting/v1/figurl_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/artifact.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/recording.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/figurl_curation.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_utils.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/metric_utils.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/figurl_curation.py Co-authored-by: Chris Brozdowski * Update name * Update src/spyglass/spikesorting/v1/recording.py Co-authored-by: Chris Brozdowski * Update src/spyglass/spikesorting/v1/recording.py Co-authored-by: Chris Brozdowski * Lint * Fix changes --------- Co-authored-by: Eric Denovellis Co-authored-by: Chris Brozdowski Co-authored-by: Eric Denovellis Co-authored-by: Sam Bray --- .vscode/settings.json | 1 - notebooks/10_Spike_SortingV1.ipynb | 829 +++++++++++++++++ src/spyglass/common/common_interval.py | 41 + src/spyglass/common/populate_all_common.py | 4 + src/spyglass/spikesorting/imported.py | 28 + src/spyglass/spikesorting/merge.py | 43 + src/spyglass/spikesorting/v1/__init__.py | 22 + src/spyglass/spikesorting/v1/artifact.py | 457 ++++++++++ src/spyglass/spikesorting/v1/curation.py | 481 ++++++++++ .../spikesorting/v1/figurl_curation.py | 292 ++++++ .../spikesorting/v1/metric_curation.py | 590 ++++++++++++ src/spyglass/spikesorting/v1/metric_utils.py | 69 ++ src/spyglass/spikesorting/v1/recording.py | 860 ++++++++++++++++++ src/spyglass/spikesorting/v1/sorting.py | 348 +++++++ src/spyglass/spikesorting/v1/utils.py | 107 +++ 15 files changed, 4171 insertions(+), 1 deletion(-) create mode 100644 notebooks/10_Spike_SortingV1.ipynb create mode 100644 src/spyglass/spikesorting/imported.py create mode 100644 src/spyglass/spikesorting/merge.py create mode 100644 src/spyglass/spikesorting/v1/__init__.py create mode 100644 
src/spyglass/spikesorting/v1/artifact.py create mode 100644 src/spyglass/spikesorting/v1/curation.py create mode 100644 src/spyglass/spikesorting/v1/figurl_curation.py create mode 100644 src/spyglass/spikesorting/v1/metric_curation.py create mode 100644 src/spyglass/spikesorting/v1/metric_utils.py create mode 100644 src/spyglass/spikesorting/v1/recording.py create mode 100644 src/spyglass/spikesorting/v1/sorting.py create mode 100644 src/spyglass/spikesorting/v1/utils.py diff --git a/.vscode/settings.json b/.vscode/settings.json index bc2c1fc8c..f94239ef5 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -2,7 +2,6 @@ "editor.formatOnSave": true, "files.trimTrailingWhitespace": true, "files.trimFinalNewlines": true, - "editor.multiCursorModifier": "ctrlCmd", "autoDocstring.docstringFormat": "numpy", "remote.SSH.remoteServerListenOnSocket": true, "git.confirmSync": false, diff --git a/notebooks/10_Spike_SortingV1.ipynb b/notebooks/10_Spike_SortingV1.ipynb new file mode 100644 index 000000000..adccebc26 --- /dev/null +++ b/notebooks/10_Spike_SortingV1.ipynb @@ -0,0 +1,829 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5fce5c22-caab-473b-a103-5009a2798d12", + "metadata": {}, + "source": [ + "Connect to db (remove in later version that works with production database)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5778bf96-740c-4e4b-a695-ed4385fc9b58", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import datajoint as dj\n", + "import os\n", + "from pathlib import Path\n", + "\n", + "# set dirs\n", + "base_dir = Path('/hdd/dj') # change this to your desired directory\n", + "if (base_dir).exists() is False:\n", + " os.mkdir(base_dir)\n", + "raw_dir = base_dir / 'raw'\n", + "if (raw_dir).exists() is False:\n", + " os.mkdir(raw_dir)\n", + "analysis_dir = base_dir / 'analysis'\n", + "if (analysis_dir).exists() is False:\n", + " os.mkdir(analysis_dir)\n", + "tmp_dir = base_dir / 'tmp'\n", + "if (tmp_dir).exists() is False:\n", + " os.mkdir(tmp_dir)\n", + "\n", + "# set dj config\n", + "dj.config['database.host'] = 'localhost'\n", + "dj.config['database.user'] = 'root'\n", + "dj.config['database.password'] = 'tutorial'\n", + "dj.config['database.port'] = 3306\n", + "dj.config['stores'] = {\n", + " 'raw': {\n", + " 'protocol': 'file',\n", + " 'location': str(raw_dir),\n", + " 'stage': str(raw_dir)\n", + " },\n", + " 'analysis': {\n", + " 'protocol': 'file',\n", + " 'location': str(analysis_dir),\n", + " 'stage': str(analysis_dir)\n", + " }\n", + "}\n", + "dj.config[\"enable_python_native_blobs\"] = True\n", + "\n", + "\n", + "# set env vars\n", + "os.environ['SPYGLASS_BASE_DIR'] = str(base_dir)\n", + "os.environ['SPYGLASS_TEMP_DIR'] = str(tmp_dir)\n", + "os.environ['KACHERY_CLOUD_DIR'] = '/hdd/dj/.kachery-cloud'\n", + "os.environ['KACHERY_ZONE'] = \"franklab.default\"\n", + "os.environ['DJ_SUPPORT_FILEPATH_MANAGEMENT'] = 'TRUE'\n", + "\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "id": "84e7c0b5-f660-4304-9b87-08f5bbf4dbac", + "metadata": {}, + "source": [ + "import" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "16345184-c012-486c-b0b6-c914168f2449", + "metadata": {}, + "outputs": [], + "source": [ + "import spyglass as sg\n", + "import spyglass.common as sgc\n", + "import spyglass.spikesorting.v1 as sgs\n", + "import spyglass.data_import as sgi" + ] + }, + { + "cell_type": "markdown", + "id": "48d2c06a-feb6-438c-94b3-4028127e2101", + "metadata": 
{}, + "source": [ + "insert LabMember and Session" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3a0ecdf-8dad-41d5-9ee2-fa60f80c746d", + "metadata": {}, + "outputs": [], + "source": [ + "nwb_file_name = \"wilbur20210326.nwb\"\n", + "nwb_file_name2 = \"wilbur20210326_.nwb\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69e258e0-2c55-434a-afb5-00cf635bac84", + "metadata": {}, + "outputs": [], + "source": [ + "sgc.LabMember.insert_from_nwbfile(nwb_file_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfa1b73e-da6e-470f-aff6-0d45c3ddff5c", + "metadata": {}, + "outputs": [], + "source": [ + "sgi.insert_sessions(nwb_file_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e11b5f5d-e9e0-4949-9fc1-4a34cc975fb1", + "metadata": {}, + "outputs": [], + "source": [ + "sgc.Session()" + ] + }, + { + "cell_type": "markdown", + "id": "5f3dfe2d-4645-44f9-b169-479292215afe", + "metadata": {}, + "source": [ + "insert SortGroup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a269f6af-eb16-4551-b511-a264368c9490", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SortGroup.set_group_by_shank(nwb_file_name=nwb_file_name2)" + ] + }, + { + "cell_type": "markdown", + "id": "1c55792e-f9ba-4e0d-a4d2-8c60bf0e8f34", + "metadata": {}, + "source": [ + "insert SpikeSortingRecordingSelection. use `insert_selection` method. this automatically generates a unique recording id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b307631-3cc5-4859-9e95-aeedf6a3de56", + "metadata": {}, + "outputs": [], + "source": [ + "key = {\"nwb_file_name\" : nwb_file_name2,\n", + " \"sort_group_id\" : 0,\n", + " \"interval_list_name\" : \"03_r1\",\n", + " \"preproc_param_name\" : \"default\",\n", + " \"team_name\" : \"Alison Comrie\"}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6d0dcb0-acfe-4adb-8da6-a5570b97f48a", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSortingRecordingSelection.insert_selection(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e61248f-1f55-4782-9018-ff1891acfc16", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSortingRecordingSelection()" + ] + }, + { + "cell_type": "markdown", + "id": "2fc85684-71a2-4e37-9ad7-a3a923749c8f", + "metadata": {}, + "source": [ + "preprocess recording (filtering and referencing)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bd48e28-d40e-4cf3-a89e-58d4c3cb08e8", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSortingRecording.populate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a9bf343-5b5e-457c-8bf4-f12b194a5489", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSortingRecording()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c6c7ea3-9538-4fa9-890b-ee16cc18af31", + "metadata": {}, + "outputs": [], + "source": [ + "key = (sgs.SpikeSortingRecordingSelection & {\"nwb_file_name\":nwb_file_name2}).fetch1()" + ] + }, + { + "cell_type": "markdown", + "id": "1955ed06-d754-470a-b5b3-94df6c3e03eb", + "metadata": {}, + "source": [ + "insert ArtifactDetectionSelection" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74415172-f2da-4fd3-ab43-01857d682b0d", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.ArtifactDetectionSelection.insert_selection({'recording_id':key['recording_id'],\n", + " 
'artifact_param_name':'default'})" + ] + }, + { + "cell_type": "markdown", + "id": "a1fe3cf8-07c2-4743-90f6-8a8025bec696", + "metadata": {}, + "source": [ + "detect artifact; note the output is stored in IntervalList" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd89c36c-c05b-4b4a-85d9-7679fed173d1", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.ArtifactDetection.populate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca9a9f64-0afc-4c83-b22c-0ed120cb87f6", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.ArtifactDetection()" + ] + }, + { + "cell_type": "markdown", + "id": "65ae0f70-2d8d-40d4-86c9-2ab206b28ca9", + "metadata": {}, + "source": [ + "insert SpikeSortingSelection. again use `insert_selection` method" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "34246883-9dc4-43c5-a438-009215a3a35e", + "metadata": {}, + "outputs": [], + "source": [ + "key = {\n", + " \"recording_id\":key['recording_id'],\n", + " \"sorter\":\"mountainsort4\",\n", + " \"sorter_param_name\": \"franklab_tetrode_hippocampus_30KHz\",\n", + " \"nwb_file_name\":nwb_file_name2,\n", + " \"interval_list_name\":str((sgs.ArtifactDetectionSelection & {'recording_id':key['recording_id']}).fetch1(\"artifact_id\"))\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f431dc4-57d6-4a6d-82e0-9b313ac0ce3f", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSortingSelection()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68856fb6-b5c2-4ee4-b300-43a117e453a1", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSortingSelection.insert_selection(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2caed63b-6094-4a59-b8d9-6a0f186b2d3f", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSortingSelection()" + ] + }, + { + "cell_type": "markdown", + "id": "bb343fb7-04d6-48fc-bf67-9919769a7a52", + "metadata": {}, + "source": [ + "run spike sorting" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cdc09d9-af18-420c-9707-7a439d5686a8", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSorting.populate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54ccf059-b1ae-42e8-aede-4af30a61fd2b", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.SpikeSorting()" + ] + }, + { + "cell_type": "markdown", + "id": "ea8fcaa0-9dd7-4870-9f5b-be039e3579cc", + "metadata": {}, + "source": [ + "we have two main ways of curating spike sorting: by computing quality metrics and applying threshold; and manually applying curation labels. to do so, we first insert CurationV1. use `insert_curation` method." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0589a3f2-4977-407f-b49d-4ae3f882ae21", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.CurationV1.insert_curation(sorting_id=(sgs.SpikeSortingSelection & {'recording_id':key['recording_id']}).fetch1(\"sorting_id\"),\n", + " description=\"testing sort\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bec5b97-4e9f-4ee9-a6b5-4f05f4726744", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.CurationV1()" + ] + }, + { + "cell_type": "markdown", + "id": "97317b6f-a40a-4f84-8042-4361064f010a", + "metadata": {}, + "source": [ + "we will first do an automatic curation based on quality metrics" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7207abda-ea84-43af-97d4-e5be3464d28d", + "metadata": {}, + "outputs": [], + "source": [ + "key = {\n", + " \"sorting_id\":(sgs.SpikeSortingSelection & {'recording_id':key['recording_id']}).fetch1(\"sorting_id\"),\n", + " \"curation_id\":0,\n", + " \"waveform_param_name\":\"default_not_whitened\",\n", + " \"metric_param_name\":\"franklab_default\",\n", + " \"metric_curation_param_name\":\"default\"\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "14c2eacc-cc45-4e61-9919-04785a721079", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.MetricCurationSelection.insert_selection(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d73244b3-f754-4701-be52-ea261eb4185c", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.MetricCurationSelection()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d22f5725-4fd1-42ea-a1d4-590bd1353d46", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.MetricCuration.populate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eda6577c-3ed2-480a-b6ed-107d7c479084", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.MetricCuration()" + ] + }, + { + "cell_type": "markdown", + "id": "54f354bf-0bfa-4148-9c5d-c5593f3f3915", + "metadata": {}, + "source": [ + "to do another round of curation, fetch the relevant info and insert back into CurationV1 using `insert_curation`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "544ba8c0-560e-471b-9eaf-5924f6051faa", + "metadata": {}, + "outputs": [], + "source": [ + "key = {\"metric_curation_id\":(sgs.MetricCurationSelection & {'sorting_id':key['sorting_id']}).fetch1(\"metric_curation_id\")}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45f3bc4a-2842-4802-ad0f-4f333dda171e", + "metadata": {}, + "outputs": [], + "source": [ + "labels = sgs.MetricCuration.get_labels(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df84263f-db5a-44b7-8309-4d63d10fd883", + "metadata": {}, + "outputs": [], + "source": [ + "merge_groups = sgs.MetricCuration.get_merge_groups(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "895c85a5-5b4f-44de-a003-c942ba231c22", + "metadata": {}, + "outputs": [], + "source": [ + "metrics = sgs.MetricCuration.get_metrics(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "863f6e17-40a6-4b8d-82b5-d14a059c5c77", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.CurationV1.insert_curation(sorting_id=(sgs.MetricCurationSelection & {'metric_curation_id':key['metric_curation_id']}).fetch1(\"sorting_id\"),\n", + " parent_curation_id=0,\n", + " labels=labels,\n", + " merge_groups= merge_groups,\n", + " 
metrics=metrics,\n", + " description=\"after metric curation\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7c6bfd9-5985-41e1-bf37-8c8874b59191", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.CurationV1()" + ] + }, + { + "cell_type": "markdown", + "id": "cf8708a4-0a55-4309-b3c4-dbf47d61ad31", + "metadata": {}, + "source": [ + "next we will do manual curation. this is done with figurl. to incorporate info from other stages of processing (e.g. metrics) we have to store that with kachery cloud and get curation uri referring to it. it can be done with `generate_curation_uri`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "924cdfce-e287-41d7-abf9-872797637777", + "metadata": {}, + "outputs": [], + "source": [ + "curation_uri = sgs.FigURLCurationSelection.generate_curation_uri({\"sorting_id\":(sgs.MetricCurationSelection & {'metric_curation_id':key['metric_curation_id']}).fetch1(\"sorting_id\"),\n", + " \"curation_id\":1})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1987fab4-9d4f-47dc-9546-90086fba7919", + "metadata": {}, + "outputs": [], + "source": [ + "key = {\"sorting_id\":(sgs.MetricCurationSelection & {'metric_curation_id':key['metric_curation_id']}).fetch1(\"sorting_id\"),\n", + " \"curation_id\":1,\n", + " \"curation_uri\": curation_uri,\n", + " \"metrics_figurl\":list(metrics.keys())}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18c05728-9a87-4624-bd3b-82038ef68bd8", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.FigURLCurationSelection()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ecb9106b-8f92-4725-a68c-d5233453b3a4", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.FigURLCurationSelection.insert_selection(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd9667da-794f-4196-9e3d-527d8932d1e9", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.FigURLCurationSelection()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c3b029c6-8dc8-4af3-ad42-8a9443e70023", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.FigURLCuration.populate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bf6eb76-4883-4436-a320-7ade5c3af910", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.FigURLCuration()" + ] + }, + { + "cell_type": "markdown", + "id": "9ca0d48c-900b-4985-a27a-be1ff82616a4", + "metadata": {}, + "source": [ + "or you can manually specify it if you already have a `curation.json`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2e9b018-9a8b-4344-9b8e-9e2141324bfa", + "metadata": {}, + "outputs": [], + "source": [ + "gh_curation_uri = \"gh://LorenFrankLab/sorting-curations/main/khl02007/test/curation.json\"\n", + "\n", + "key = {\"sorting_id\" : key[\"sorting_id\"],\n", + " \"curation_id\" : 1,\n", + " \"curation_uri\" : gh_curation_uri,\n", + " \"metrics_figurl\" : []}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "449cdcdc-dcff-4aa6-a541-d674ccfbb0b5", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.FigURLCurationSelection.insert_selection(key)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ad86c81e-1424-4fa2-a022-7cc0a3425fc0", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.FigURLCuration.populate()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37847473-1c46-4991-93a0-e315568e675a", + "metadata": {}, + "outputs": [], 
+ "source": [ + "sgs.FigURLCuration()" + ] + }, + { + "cell_type": "markdown", + "id": "6d68f93e-1586-4d3b-b680-0fe2115c0ab4", + "metadata": {}, + "source": [ + "once you apply manual curation (curation labels and merge groups) you can store them as nwb by inserting another row in CurationV1. And then you can do more rounds of curation if you want." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "15694ca0-3ec1-49a8-9ac4-66cf6d6f49ee", + "metadata": {}, + "outputs": [], + "source": [ + "labels = sgs.FigURLCuration.get_labels(gh_curation_uri)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f467a487-100e-4217-914a-60e852805faf", + "metadata": {}, + "outputs": [], + "source": [ + "merge_groups = sgs.FigURLCuration.get_merge_groups(gh_curation_uri)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5873ad89-64cb-427a-a183-f15c2c42907a", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.CurationV1.insert_curation(sorting_id=key[\"sorting_id\"],\n", + " parent_curation_id=1,\n", + " labels=labels,\n", + " merge_groups= merge_groups,\n", + " metrics=metrics,\n", + " description=\"after figurl curation\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d40eb3a-34c5-4771-8fc0-730fafb5cb8a", + "metadata": {}, + "outputs": [], + "source": [ + "sgs.CurationV1()" + ] + }, + { + "cell_type": "markdown", + "id": "9ff6aff5-7020-40d6-832f-006d66d54a7e", + "metadata": {}, + "source": [ + "We now insert the curated spike sorting to a `Merge` table for feeding into downstream processing pipelines." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "511ecb19-7d8d-4db6-be71-c0ed66e2b0f2", + "metadata": {}, + "outputs": [], + "source": [ + "from spyglass.spikesorting.merge import SpikeSortingOutput" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5047f866-7435-4dea-9ed8-a9b2d8365682", + "metadata": {}, + "outputs": [], + "source": [ + "SpikeSortingOutput()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2702410-01e1-4af0-a987-891c42c6c099", + "metadata": {}, + "outputs": [], + "source": [ + "key" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b20c2c9e-0c97-4669-b45d-4b1c50fd2fcc", + "metadata": {}, + "outputs": [], + "source": [ + "SpikeSortingOutput.insert([key], part_name='CurationV1')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "184c3401-8df3-46f0-9dd0-c9fa98395c34", + "metadata": {}, + "outputs": [], + "source": [ + "SpikeSortingOutput.merge_view()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2b083a5-b700-438a-8a06-2e2eb041072d", + "metadata": {}, + "outputs": [], + "source": [ + "SpikeSortingOutput.CurationV1()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10b8afa1-d4a6-4ac1-959b-f4e84e582f2e", + "metadata": {}, + "outputs": [], + "source": [ + "SpikeSortingOutput.CuratedSpikeSorting()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9eebf75-6fef-43c4-80b8-12e59e5d743c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + } + }, 
+ "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 99c9a5bdf..bd518db91 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -488,3 +488,44 @@ def interval_set_difference_inds(intervals1, intervals2): i += 1 result += intervals1[i:] return result + + +def interval_list_complement(intervals1, intervals2, min_length=0.0): + """ + Finds intervals in intervals1 that are not in intervals2 + + Parameters + ---------- + min_length : float, optional + Minimum interval length in seconds. Defaults to 0.0. + """ + + result = [] + + for start1, end1 in intervals1: + subtracted = [(start1, end1)] + + for start2, end2 in intervals2: + new_subtracted = [] + + for s, e in subtracted: + if start2 <= s and e <= end2: + continue + + if e <= start2 or end2 <= s: + new_subtracted.append((s, e)) + continue + + if start2 > s: + new_subtracted.append((s, start2)) + + if end2 < e: + new_subtracted.append((end2, e)) + + subtracted = new_subtracted + + result.extend(subtracted) + + return intervals_by_length( + np.asarray(result), min_length=min_length, max_length=1e100 + ) diff --git a/src/spyglass/common/populate_all_common.py b/src/spyglass/common/populate_all_common.py index da0bedac0..e672610d1 100644 --- a/src/spyglass/common/populate_all_common.py +++ b/src/spyglass/common/populate_all_common.py @@ -9,6 +9,7 @@ from .common_nwbfile import Nwbfile from .common_session import Session from .common_task import TaskEpoch +from spyglass.spikesorting.imported import ImportedSpikeSorting def populate_all_common(nwb_file_name): @@ -51,3 +52,6 @@ def populate_all_common(nwb_file_name): print("RawPosition...") PositionSource.insert_from_nwbfile(nwb_file_name) RawPosition.populate(fp) + + print("Populate ImportedSpikeSorting...") + ImportedSpikeSorting.populate(fp) diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py new file mode 100644 index 000000000..028b8ed90 --- /dev/null +++ b/src/spyglass/spikesorting/imported.py @@ -0,0 +1,28 @@ +import datajoint as dj +from spyglass.common.common_session import Session +from spyglass.common.common_nwbfile import Nwbfile +import pynwb + +schema = dj.schema("spikesorting_imported") + + +@schema +class ImportedSpikeSorting(dj.Imported): + definition = """ + -> Session + --- + object_id: varchar(32) + """ + + def make(self, key): + nwb_file_abs_path = Nwbfile.get_abs_path(key["nwb_file_name"]) + + with pynwb.NWBHDF5IO( + nwb_file_abs_path, "r", load_namespaces=True + ) as io: + nwbfile = io.read() + if nwbfile.units: + key["object_id"] = nwbfile.units.object_id + self.insert1(key, skip_duplicates=True) + else: + print("No units found in NWB file") diff --git a/src/spyglass/spikesorting/merge.py b/src/spyglass/spikesorting/merge.py new file mode 100644 index 000000000..587775481 --- /dev/null +++ b/src/spyglass/spikesorting/merge.py @@ -0,0 +1,43 @@ +import uuid + +import datajoint as dj + +from spyglass.spikesorting.imported import ImportedSpikeSorting # noqa: F401 +from spyglass.spikesorting.spikesorting_curation import ( # noqa: F401 + CuratedSpikeSorting, +) +from spyglass.spikesorting.v1.curation import CurationV1 # noqa: F401 +from spyglass.utils.dj_merge_tables import _Merge + +schema = dj.schema("spikesorting_merge") + + +@schema +class SpikeSortingOutput(_Merge): + definition = """ + # Output of spike sorting pipelines. 
+ merge_id: uuid + --- + source: varchar(32) + """ + + class CurationV1(dj.Part): # noqa: F811 + definition = """ + -> master + --- + -> CurationV1 + """ + + class ImportedSpikeSorting(dj.Part): # noqa: F811 + definition = """ + -> master + --- + -> ImportedSpikeSorting + """ + + class CuratedSpikeSorting(dj.Part): # noqa: F811 + definition = """ + -> master + --- + -> CuratedSpikeSorting + """ diff --git a/src/spyglass/spikesorting/v1/__init__.py b/src/spyglass/spikesorting/v1/__init__.py new file mode 100644 index 000000000..992818bdd --- /dev/null +++ b/src/spyglass/spikesorting/v1/__init__.py @@ -0,0 +1,22 @@ +from .artifact import ( + ArtifactDetection, + ArtifactDetectionParameters, + ArtifactDetectionSelection, +) +from .curation import CurationV1 +from .figurl_curation import FigURLCuration, FigURLCurationSelection +from .metric_curation import ( + MetricCuration, + MetricCurationParameters, + MetricCurationSelection, + MetricParameters, + WaveformParameters, +) +from .recording import ( + SortGroup, + SpikeSortingPreprocessingParameters, + SpikeSortingRecording, + SpikeSortingRecordingSelection, +) +from .sorting import SpikeSorterParameters, SpikeSorting, SpikeSortingSelection +from .utils import get_spiking_sorting_v1_merge_ids diff --git a/src/spyglass/spikesorting/v1/artifact.py b/src/spyglass/spikesorting/v1/artifact.py new file mode 100644 index 000000000..d43f0b8b9 --- /dev/null +++ b/src/spyglass/spikesorting/v1/artifact.py @@ -0,0 +1,457 @@ +import uuid +import warnings +from functools import reduce +from typing import List, Union + +import datajoint as dj +import numpy as np +import scipy.stats as stats +import spikeinterface as si +import spikeinterface.extractors as se +from spikeinterface.core.job_tools import ChunkRecordingExecutor, ensure_n_jobs + +from spyglass.common.common_interval import ( + IntervalList, + _union_concat, + interval_from_inds, + interval_list_complement, +) +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.spikesorting.v1.recording import ( + SpikeSortingRecording, + SpikeSortingRecordingSelection, +) + +schema = dj.schema("spikesorting_v1_artifact") + + +@schema +class ArtifactDetectionParameters(dj.Lookup): + definition = """ + # Parameters for detecting artifacts (non-neural high amplitude events). + artifact_param_name : varchar(200) + --- + artifact_params : blob + """ + + contents = [ + [ + "default", + { + "zscore_thresh": None, + "amplitude_thresh_uV": 3000, + "proportion_above_thresh": 1.0, + "removal_window_ms": 1.0, + "chunk_duration": "10s", + "n_jobs": 4, + "progress_bar": "True", + }, + ], + [ + "none", + { + "zscore_thresh": None, + "amplitude_thresh_uV": None, + "chunk_duration": "10s", + "n_jobs": 4, + "progress_bar": "True", + }, + ], + ] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + +@schema +class ArtifactDetectionSelection(dj.Manual): + definition = """ + # Processed recording and artifact detection parameters. Use `insert_selection` method to insert new rows. + artifact_id: uuid + --- + -> SpikeSortingRecording + -> ArtifactDetectionParameters + """ + + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into ArtifactDetectionSelection with an + automatically generated unique artifact ID as the sole primary key. 
+ + Parameters + ---------- + key : dict + primary key of SpikeSortingRecording and ArtifactDetectionParameters + + Returns + ------- + artifact_id : str + the unique artifact ID serving as primary key for ArtifactDetectionSelection + """ + query = cls & key + if query: + print("Similar row(s) already inserted.") + return query.fetch(as_dict=True) + key["artifact_id"] = uuid.uuid4() + cls.insert1(key, skip_duplicates=True) + return key + + +@schema +class ArtifactDetection(dj.Computed): + definition = """ + # Detected artifacts (e.g. large transients from movement). + # Intervals are stored in IntervalList with `artifact_id` as `interval_list_name`. + -> ArtifactDetectionSelection + """ + + def make(self, key): + # FETCH: + # - artifact parameters + # - recording analysis nwb file + artifact_params, recording_analysis_nwb_file = ( + ArtifactDetectionParameters + * SpikeSortingRecording + * ArtifactDetectionSelection + & key + ).fetch1("artifact_params", "analysis_file_name") + sort_interval_valid_times = ( + IntervalList + & { + "nwb_file_name": ( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("nwb_file_name"), + "interval_list_name": ( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("interval_list_name"), + } + ).fetch1("valid_times") + # DO: + # - load recording + recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( + recording_analysis_nwb_file + ) + recording = se.read_nwb_recording( + recording_analysis_nwb_file_abs_path, load_time_vector=True + ) + + # - detect artifacts + artifact_removed_valid_times, _ = _get_artifact_times( + recording, + sort_interval_valid_times, + **artifact_params, + ) + + # INSERT + # - into IntervalList + IntervalList.insert1( + dict( + nwb_file_name=( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("nwb_file_name"), + interval_list_name=str(key["artifact_id"]), + valid_times=artifact_removed_valid_times, + ), + skip_duplicates=True, + ) + # - into ArtifactRemovedInterval + self.insert1(key) + + +def _get_artifact_times( + recording: si.BaseRecording, + sort_interval_valid_times: List[List], + zscore_thresh: Union[float, None] = None, + amplitude_thresh_uV: Union[float, None] = None, + proportion_above_thresh: float = 1.0, + removal_window_ms: float = 1.0, + verbose: bool = False, + **job_kwargs, +): + """Detects times during which artifacts do and do not occur. + + Artifacts are defined as periods where the absolute value of the recording + signal exceeds one or both specified amplitude or z-score thresholds on the + proportion of channels specified, with the period extended by the + removal_window_ms/2 on each side. Z-score and amplitude threshold values of + None are ignored. 
+ + Parameters + ---------- + recording : si.BaseRecording + sort_interval_valid_times : List[List] + The sort interval for the recording, unit: seconds + zscore_thresh : float, optional + Stdev threshold for exclusion, should be >=0, defaults to None + amplitude_thresh_uV : float, optional + Amplitude threshold for exclusion, should be >=0, defaults to None + proportion_above_thresh : float, optional, should be >0 and <=1 + Proportion of electrodes that need to have threshold crossings, defaults to 1 + removal_window_ms : float, optional + Width of the window in milliseconds to mask out per artifact + (window/2 removed on each side of threshold crossing), defaults to 1 ms + + Returns + ------- + artifact_removed_valid_times : np.ndarray + Intervals of valid times where artifacts were not detected, unit: seconds + artifact_intervals : np.ndarray + Intervals in which artifacts are detected (including removal windows), unit: seconds + """ + + valid_timestamps = recording.get_times() + + # if both thresholds are None, we skip artifact detection + if amplitude_thresh_uV is zscore_thresh is None: + print( + "Amplitude and zscore thresholds are both None, " + + "skipping artifact detection" + ) + return np.asarray( + [valid_timestamps[0], valid_timestamps[-1]] + ), np.asarray([]) + + # verify threshold parameters + ( + amplitude_thresh_uV, + zscore_thresh, + proportion_above_thresh, + ) = _check_artifact_thresholds( + amplitude_thresh_uV, zscore_thresh, proportion_above_thresh + ) + + # detect frames that are above threshold in parallel + n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) + print(f"Using {n_jobs} jobs...") + func = _compute_artifact_chunk + init_func = _init_artifact_worker + if n_jobs == 1: + init_args = ( + recording, + zscore_thresh, + amplitude_thresh_uV, + proportion_above_thresh, + ) + else: + init_args = ( + recording.to_dict(), + zscore_thresh, + amplitude_thresh_uV, + proportion_above_thresh, + ) + + executor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + verbose=verbose, + handle_returns=True, + job_name="detect_artifact_frames", + **job_kwargs, + ) + + artifact_frames = executor.run() + artifact_frames = np.concatenate(artifact_frames) + + # turn ms to remove total into s to remove from either side of each detected artifact + half_removal_window_s = removal_window_ms / 2 / 1000 + + if len(artifact_frames) == 0: + recording_interval = np.asarray( + [[valid_timestamps[0], valid_timestamps[-1]]] + ) + artifact_times_empty = np.asarray([]) + print("No artifacts detected.") + return recording_interval, artifact_times_empty + + # convert indices to intervals + artifact_intervals = interval_from_inds(artifact_frames) + + # convert to seconds and pad with window + artifact_intervals_s = np.zeros( + (len(artifact_intervals), 2), dtype=np.float64 + ) + for interval_idx, interval in enumerate(artifact_intervals): + interv_ind = [ + np.searchsorted( + valid_timestamps, + valid_timestamps[interval[0]] - half_removal_window_s, + ), + np.searchsorted( + valid_timestamps, + valid_timestamps[interval[1]] + half_removal_window_s, + ), + ] + artifact_intervals_s[interval_idx] = [ + valid_timestamps[interv_ind[0]], + valid_timestamps[interv_ind[1]], + ] + + # make the artifact intervals disjoint + artifact_intervals_s = reduce(_union_concat, artifact_intervals_s) + + # find non-artifact intervals in timestamps + artifact_removed_valid_times = interval_list_complement( + sort_interval_valid_times, artifact_intervals_s, min_length=1 + ) + 
artifact_removed_valid_times = reduce( + _union_concat, artifact_removed_valid_times + ) + + return artifact_removed_valid_times, artifact_intervals_s + + +def _init_artifact_worker( + recording, + zscore_thresh=None, + amplitude_thresh_uV=None, + proportion_above_thresh=1.0, +): + # create a local dict per worker + worker_ctx = {} + if isinstance(recording, dict): + worker_ctx["recording"] = si.load_extractor(recording) + else: + worker_ctx["recording"] = recording + worker_ctx["zscore_thresh"] = zscore_thresh + worker_ctx["amplitude_thresh_uV"] = amplitude_thresh_uV + worker_ctx["proportion_above_thresh"] = proportion_above_thresh + return worker_ctx + + +def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + zscore_thresh = worker_ctx["zscore_thresh"] + amplitude_thresh_uV = worker_ctx["amplitude_thresh_uV"] + proportion_above_thresh = worker_ctx["proportion_above_thresh"] + # compute the number of electrodes that have to be above threshold + nelect_above = np.ceil( + proportion_above_thresh * len(recording.get_channel_ids()) + ) + + traces = recording.get_traces( + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + ) + + # find the artifact occurrences using one or both thresholds, across channels + if (amplitude_thresh_uV is not None) and (zscore_thresh is None): + above_a = np.abs(traces) > amplitude_thresh_uV + above_thresh = ( + np.ravel(np.argwhere(np.sum(above_a, axis=1) >= nelect_above)) + + start_frame + ) + elif (amplitude_thresh_uV is None) and (zscore_thresh is not None): + dataz = np.abs(stats.zscore(traces, axis=1)) + above_z = dataz > zscore_thresh + above_thresh = ( + np.ravel(np.argwhere(np.sum(above_z, axis=1) >= nelect_above)) + + start_frame + ) + else: + above_a = np.abs(traces) > amplitude_thresh_uV + dataz = np.abs(stats.zscore(traces, axis=1)) + above_z = dataz > zscore_thresh + above_thresh = ( + np.ravel( + np.argwhere( + np.sum(np.logical_or(above_z, above_a), axis=1) + >= nelect_above + ) + ) + + start_frame + ) + + return above_thresh + + +def _check_artifact_thresholds( + amplitude_thresh_uV, zscore_thresh, proportion_above_thresh +): + """Alerts user to likely unintended parameters. Not an exhaustive verification. + + Parameters + ---------- + zscore_thresh: float + amplitude_thresh_uV: float + proportion_above_thresh: float + + Returns + ------- + zscore_thresh: float + amplitude_thresh_uV: float + proportion_above_thresh: float + + Raises + ------ + ValueError: if signal thresholds are negative + """ + # amplitude and zscore thresholds must be non-negative, as they are applied to the absolute value of the signal + signal_thresholds = [ + t for t in [amplitude_thresh_uV, zscore_thresh] if t is not None + ] + for t in signal_thresholds: + if t < 0: + raise ValueError( + "Amplitude and Z-Score thresholds must be >= 0, or None" + ) + + # proportion_above_thresh should be in the range (0, 1] + if proportion_above_thresh < 0: + warnings.warn( + "Warning: proportion_above_thresh must be a proportion >0 and <=1." + f" Using proportion_above_thresh = 0.01 instead of {str(proportion_above_thresh)}" + ) + proportion_above_thresh = 0.01 + elif proportion_above_thresh > 1: + warnings.warn( + "Warning: proportion_above_thresh must be a proportion >0 and <=1. 
" + f"Using proportion_above_thresh = 1 instead of {str(proportion_above_thresh)}" + ) + proportion_above_thresh = 1 + return amplitude_thresh_uV, zscore_thresh, proportion_above_thresh + + +def merge_intervals(intervals): + """Takes a list of intervals each of which is [start_time, stop_time] + and takes union over intervals that are intersecting + + Parameters + ---------- + intervals : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ + if len(intervals) == 0: + return [] + + # Sort the intervals based on their start times + intervals.sort(key=lambda x: x[0]) + + merged = [intervals[0]] + + for i in range(1, len(intervals)): + current_start, current_stop = intervals[i] + last_merged_start, last_merged_stop = merged[-1] + + if current_start <= last_merged_stop: + # Overlapping intervals, merge them + merged[-1] = [ + last_merged_start, + max(last_merged_stop, current_stop), + ] + else: + # Non-overlapping intervals, add the current one to the list + merged.append([current_start, current_stop]) + + return np.asarray(merged) diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py new file mode 100644 index 000000000..6a334d844 --- /dev/null +++ b/src/spyglass/spikesorting/v1/curation.py @@ -0,0 +1,481 @@ +from typing import Dict, List, Union + +import datajoint as dj +import numpy as np +import pynwb +import spikeinterface as si +import spikeinterface.curation as sc +import spikeinterface.extractors as se + +from spyglass.common.common_ephys import Raw +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.spikesorting.v1.recording import SpikeSortingRecording +from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection + +schema = dj.schema("spikesorting_v1_curation") + +valid_labels = ["reject", "noise", "artifact", "mua", "accept"] + + +@schema +class CurationV1(dj.Manual): + definition = """ + # Curation of a SpikeSorting. Use `insert_curation` to insert rows. + -> SpikeSorting + curation_id=0: int + --- + parent_curation_id=-1: int + -> AnalysisNwbfile + object_id: varchar(72) + merges_applied: bool + description: varchar(100) + """ + + @classmethod + def insert_curation( + cls, + sorting_id: str, + parent_curation_id: int = -1, + labels: Union[None, Dict[str, List[str]]] = None, + merge_groups: Union[None, List[List[str]]] = None, + apply_merge: bool = False, + metrics: Union[None, Dict[str, Dict[str, float]]] = None, + description: str = "", + ): + """Insert a row into CurationV1. + + Parameters + ---------- + sorting_id : str + The key for the original SpikeSorting + parent_curation_id : int, optional + The curation id of the parent curation + labels : dict or None, optional + curation labels (e.g. good, noise, mua) + merge_groups : dict or None, optional + groups of unit IDs to be merged + metrics : dict or None, optional + Computed quality metrics, one for each neuron + description : str, optional + description of this curation or where it originates; e.g. 
FigURL + + Note + ---- + Example curation.json (output of figurl): + { + "labelsByUnit": + {"1":["noise","reject"],"10":["noise","reject"]}, + "mergeGroups": + [[11,12],[46,48],[51,54],[50,53]] + } + + Returns + ------- + curation_key : dict + """ + sort_query = cls & {"sorting_id": sorting_id} + parent_curation_id = max(parent_curation_id, -1) + if parent_curation_id == -1: + parent_curation_id = -1 + # check to see if this sorting with a parent of -1 + # has already been inserted and if so, warn the user + query = sort_query & {"parent_curation_id": -1} + if query: + Warning("Sorting has already been inserted.") + return query.fetch("KEY") + + # generate curation ID + existing_curation_ids = sort_query.fetch("curation_id") + curation_id = max(existing_curation_ids, default=-1) + 1 + + # write the curation labels, merge groups, + # and metrics as columns in the units table of NWB + analysis_file_name, object_id = _write_sorting_to_nwb_with_curation( + sorting_id=sorting_id, + labels=labels, + merge_groups=merge_groups, + metrics=metrics, + apply_merge=apply_merge, + ) + + # INSERT + AnalysisNwbfile().add( + (SpikeSortingSelection & {"sorting_id": sorting_id}).fetch1( + "nwb_file_name" + ), + analysis_file_name, + ) + + key = { + "sorting_id": sorting_id, + "curation_id": curation_id, + "parent_curation_id": parent_curation_id, + "analysis_file_name": analysis_file_name, + "object_id": object_id, + "merges_applied": apply_merge, + "description": description, + } + cls.insert1( + key, + skip_duplicates=True, + ) + + return key + + @classmethod + def insert_metric_curation(cls, key: Dict, apply_merge=False): + """Insert a row into CurationV1. + + Parameters + ---------- + key : Dict + primary key of MetricCuration + + Returns + ------- + curation_key : Dict + """ + from spyglass.spikesorting.v1.metric_curation import ( + MetricCuration, + MetricCurationSelection, + ) + + sorting_id, parent_curation_id = (MetricCurationSelection & key).fetch1( + "sorting_id", "curation_id" + ) + + curation_key = cls.insert_curation( + sorting_id=sorting_id, + parent_curation_id=parent_curation_id, + labels=MetricCuration.get_labels(key) or None, + merge_groups=MetricCuration.get_merge_groups(key) or None, + apply_merge=apply_merge, + description=( + "metric curation of sorting id " + + f"{sorting_id}, curation id {parent_curation_id}" + ), + ) + + return curation_key + + @classmethod + def get_recording(cls, key: dict) -> si.BaseRecording: + """Get recording related to this curation as spikeinterface BaseRecording + + Parameters + ---------- + key : dict + primary key of CurationV1 table + """ + + analysis_file_name = ( + SpikeSortingRecording * SpikeSortingSelection & key + ).fetch1("analysis_file_name") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + recording = se.read_nwb_recording( + analysis_file_abs_path, load_time_vector=True + ) + recording.annotate(is_filtered=True) + + return recording + + @classmethod + def get_sorting(cls, key: dict) -> si.BaseSorting: + """Get sorting in the analysis NWB file as spikeinterface BaseSorting + + Parameters + ---------- + key : dict + primary key of CurationV1 table + + Returns + ------- + sorting : si.BaseSorting + + """ + recording = cls.get_recording(key) + sampling_frequency = recording.get_sampling_frequency() + analysis_file_name = (CurationV1 & key).fetch1("analysis_file_name") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + with pynwb.NWBHDF5IO( + analysis_file_abs_path, "r", 
load_namespaces=True + ) as io: + nwbf = io.read() + units = nwbf.units.to_dataframe() + units_dict_list = [ + { + unit_id: np.searchsorted(recording.get_times(), spike_times) + for unit_id, spike_times in zip( + units.index, units["spike_times"] + ) + } + ] + + sorting = si.NumpySorting.from_unit_dict( + units_dict_list, sampling_frequency=sampling_frequency + ) + + return sorting + + @classmethod + def get_merged_sorting(cls, key: dict) -> si.BaseSorting: + """Get sorting with merges applied. + + Parameters + ---------- + key : dict + CurationV1 key + + Returns + ------- + sorting : si.BaseSorting + + """ + recording = cls.get_recording(key) + + curation_key = (cls & key).fetch1() + + sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + curation_key["analysis_file_name"] + ) + si_sorting = se.read_nwb_sorting( + sorting_analysis_file_abs_path, + sampling_frequency=recording.get_sampling_frequency(), + ) + + with pynwb.NWBHDF5IO( + sorting_analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbfile = io.read() + nwb_sorting = nwbfile.objects[curation_key["object_id"]] + merge_groups = nwb_sorting["merge_groups"][:] + + if merge_groups: + units_to_merge = _merge_dict_to_list(merge_groups) + return sc.MergeUnitsSorting( + parent_sorting=si_sorting, units_to_merge=units_to_merge + ) + else: + return si_sorting + + +def _write_sorting_to_nwb_with_curation( + sorting_id: str, + labels: Union[None, Dict[str, List[str]]] = None, + merge_groups: Union[None, List[List[str]]] = None, + metrics: Union[None, Dict[str, Dict[str, float]]] = None, + apply_merge: bool = False, +): + """Save sorting to NWB with curation information. + Curation information is saved as columns in the units table of the NWB file. + + Parameters + ---------- + sorting_id : str + key for the sorting + labels : dict or None, optional + curation labels (e.g. 
good, noise, mua) + merge_groups : list or None, optional + groups of unit IDs to be merged + metrics : dict or None, optional + Computed quality metrics, one for each cell + apply_merge : bool, optional + whether to apply the merge groups to the sorting before saving, by default False + + Returns + ------- + analysis_nwb_file : str + name of analysis NWB file containing the sorting and curation information + object_id : str + object_id of the units table in the analysis NWB file + """ + # FETCH: + # - primary key for the associated sorting and recording + nwb_file_name = (SpikeSortingSelection & {"sorting_id": sorting_id}).fetch1( + "nwb_file_name" + ) + + # get sorting + sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + (SpikeSorting & {"sorting_id": sorting_id}).fetch1("analysis_file_name") + ) + with pynwb.NWBHDF5IO( + sorting_analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbf = io.read() + units = nwbf.units.to_dataframe() + units_dict = { + unit_id: spike_times + for unit_id, spike_times in zip(units.index, units["spike_times"]) + } + + if apply_merge: + for merge_group in merge_groups: + new_unit_id = np.max(list(units_dict.keys())) + 1 + units_dict[new_unit_id] = np.concatenate( + [units_dict[merge_unit_id] for merge_unit_id in merge_group] + ) + for merge_unit_id in merge_group: + units_dict.pop(merge_unit_id, None) + merge_groups = None + + unit_ids = list(units_dict.keys()) + + # create new analysis nwb file + analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( + path=analysis_nwb_file_abs_path, + mode="a", + load_namespaces=True, + ) as io: + nwbf = io.read() + # write sorting to the nwb file + for unit_id in unit_ids: + # spike_times = sorting.get_unit_spike_train(unit_id) + nwbf.add_unit( + spike_times=units_dict[unit_id], + id=unit_id, + ) + # add labels, merge groups, metrics + if labels is not None: + label_values = [] + for unit_id in unit_ids: + if unit_id not in labels: + label_values.append([]) + else: + label_values.append(labels[unit_id]) + nwbf.add_unit_column( + name="curation_label", + description="curation label", + data=label_values, + index=True, + ) + if merge_groups is not None: + merge_groups_dict = _list_to_merge_dict(merge_groups, unit_ids) + merge_groups_list = [ + [""] if value == [] else value + for value in merge_groups_dict.values() + ] + nwbf.add_unit_column( + name="merge_groups", + description="merge groups", + data=merge_groups_list, + index=True, + ) + if metrics is not None: + for metric, metric_dict in metrics.items(): + metric_values = [] + for unit_id in unit_ids: + if unit_id not in metric_dict: + metric_values.append([]) + else: + metric_values.append(metric_dict[unit_id]) + nwbf.add_unit_column( + name=metric, + description=metric, + data=metric_values, + ) + + units_object_id = nwbf.units.object_id + io.write(nwbf) + return analysis_nwb_file, units_object_id + + +def _union_intersecting_lists(lists): + result = [] + + while lists: + first, *rest = lists + first = set(first) + + merged = True + while merged: + merged = False + for idx, other in enumerate(rest): + if first.intersection(other): + first.update(other) + del rest[idx] + merged = True + break + + result.append(list(first)) + lists = rest + + return result + + +def _list_to_merge_dict( + merge_group_list: List[List], all_unit_ids: List +) -> dict: + """Converts a list of merge groups to a dict. 
+ + Parameters + ---------- + merge_group_list : list of list + list of merge groups (list of unit IDs to be merged) + all_unit_ids : list + list of unit IDs for all units in the sorting + + Returns + ------- + merge_dict : dict + dict of merge groups; + keys are unit IDs and values are the units to be merged + + Example + ------- + Input: [[1,2,3],[4,5]], [1,2,3,4,5,6] + Output: {1: [2, 3], 2:[1,3], 3:[1,2] 4: [5], 5: [4], 6: []} + """ + merge_group_list = _union_intersecting_lists(merge_group_list) + merge_dict = {unit_id: [] for unit_id in all_unit_ids} + + for merge_group in merge_group_list: + for unit_id in all_unit_ids: + if unit_id in merge_group: + merge_dict[unit_id].extend( + [ + str(merge_unit_id) + for merge_unit_id in merge_group + if merge_unit_id != unit_id + ] + ) + + return merge_dict + + +def _reverse_associations(assoc_dict): + return [ + [key] + values if values else [key] + for key, values in assoc_dict.items() + ] + + +def _merge_dict_to_list(merge_groups: dict) -> List: + """Converts dict of merge groups to list of merge groups. + Undoes `_list_to_merge_dict`. + + Parameters + ---------- + merge_dict : dict + dict of merge groups; + keys are unit IDs and values are the units to be merged + + Returns + ------- + merge_group_list : list of list + list of merge groups (list of unit IDs to be merged) + + Example + ------- + {1: [2, 3], 4: [5]} -> [[1, 2, 3], [4, 5]] + """ + units_to_merge = _union_intersecting_lists( + _reverse_associations(merge_groups) + ) + return [lst for lst in units_to_merge if len(lst) >= 2] diff --git a/src/spyglass/spikesorting/v1/figurl_curation.py b/src/spyglass/spikesorting/v1/figurl_curation.py new file mode 100644 index 000000000..868098325 --- /dev/null +++ b/src/spyglass/spikesorting/v1/figurl_curation.py @@ -0,0 +1,292 @@ +import uuid +from typing import Any, Dict, List, Union + +import datajoint as dj +import kachery_cloud as kcl +import numpy as np +import pynwb +import sortingview.views as vv +import spikeinterface as si +from sortingview.SpikeSortingView import SpikeSortingView + +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.spikesorting.v1.curation import CurationV1, _merge_dict_to_list +from spyglass.spikesorting.v1.sorting import SpikeSortingSelection + +schema = dj.schema("spikesorting_v1_figurl_curation") + + +@schema +class FigURLCurationSelection(dj.Manual): + definition = """ + # Use `insert_selection` method to insert a row. Use `generate_curation_uri` method to generate a curation uri. + figurl_curation_id: uuid + --- + -> CurationV1 + curation_uri: varchar(1000) # GitHub-based URI to a file to which the manual curation will be saved + metrics_figurl: blob # metrics to display in the figURL + """ + + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into FigURLCurationSelection. + + Parameters + ---------- + key : dict + primary key of `CurationV1`, `curation_uri`, and `metrics_figurl`. + - If `curation_uri` is not provided, it will be generated from `generate_curation_uri` method. + - If `metrics_figurl` is not provided, it will be set to []. + + Returns + ------- + key : dict + primary key of `FigURLCurationSelection` table. 
+ """ + if "curation_uri" not in key: + key["curation_uri"] = cls.generate_curation_uri(key) + if "metrics_figurl" not in key: + key["metrics_figurl"] = [] + if "figurl_curation_id" in key: + query = cls & {"figurl_curation_id": key["figurl_curation_id"]} + if query: + print("Similar row(s) already inserted.") + return query.fetch(as_dict=True) + key["figurl_curation_id"] = uuid.uuid4() + cls.insert1(key, skip_duplicates=True) + return key + + @staticmethod + def generate_curation_uri(key: Dict) -> str: + """Generates a kachery-cloud URI from a row in CurationV1 table + + Parameters + ---------- + key : dict + primary key from CurationV1 + """ + curation_key = (CurationV1 & key).fetch1() + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + curation_key["analysis_file_name"] + ) + with pynwb.NWBHDF5IO( + analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbfile = io.read() + nwb_sorting = nwbfile.objects[ + curation_key["object_id"] + ].to_dataframe() + unit_ids = list(nwb_sorting.index) + labels = list(nwb_sorting["curation_label"]) + merge_groups = list(nwb_sorting["merge_groups"]) + + unit_ids = [str(unit_id) for unit_id in unit_ids] + + labels_dict = ( + {unit_id: list(label) for unit_id, label in zip(unit_ids, labels)} + if labels + else {} + ) + + merge_groups_list = ( + [ + [str(unit_id) for unit_id in merge_group] + for merge_group in _merge_dict_to_list( + dict(zip(unit_ids, merge_groups)) + ) + ] + if merge_groups + else [] + ) + + return kcl.store_json( + { + "labelsByUnit": labels_dict, + "mergeGroups": merge_groups_list, + } + ) + + +@schema +class FigURLCuration(dj.Computed): + definition = """ + # URL to the FigURL for manual curation of a spike sorting. + -> FigURLCurationSelection + --- + url: varchar(1000) + """ + + def make(self, key: dict): + # FETCH + sorting_analysis_file_name = ( + FigURLCurationSelection * CurationV1 & key + ).fetch1("analysis_file_name") + object_id = (FigURLCurationSelection * CurationV1 & key).fetch1( + "object_id" + ) + recording_label = (SpikeSortingSelection & key).fetch1("recording_id") + metrics_figurl = (FigURLCurationSelection & key).fetch1( + "metrics_figurl" + ) + + # DO + sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + sorting_analysis_file_name + ) + recording = CurationV1.get_recording( + (FigURLCurationSelection & key).fetch1() + ) + sorting = CurationV1.get_sorting( + (FigURLCurationSelection & key).fetch1() + ) + sorting_label = (FigURLCurationSelection & key).fetch1("sorting_id") + curation_uri = (FigURLCurationSelection & key).fetch1("curation_uri") + + metric_dict = {} + with pynwb.NWBHDF5IO( + sorting_analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbf = io.read() + nwb_sorting = nwbf.objects[object_id].to_dataframe() + unit_ids = nwb_sorting.index + for metric in metrics_figurl: + metric_dict[metric] = dict(zip(unit_ids, nwb_sorting[metric])) + + unit_metrics = _reformat_metrics(metric_dict) + + # TODO: figure out a way to specify the similarity metrics + + # Generate the figURL + key["url"] = _generate_figurl( + R=recording, + S=sorting, + initial_curation_uri=curation_uri, + recording_label=recording_label, + sorting_label=sorting_label, + unit_metrics=unit_metrics, + ) + + # INSERT + self.insert1(key, skip_duplicates=True) + + @classmethod + def get_labels(cls, curation_json): + labels_by_unit = kcl.load_json(curation_json).get("labelsByUnit") + return ( + { + int(unit_id): curation_label_list + for unit_id, curation_label_list in labels_by_unit.items() + } + if 
labels_by_unit + else {} + ) + + @classmethod + def get_merge_groups(cls, curation_json): + return kcl.load_json(curation_json).get("mergeGroups", {}) + + +def _generate_figurl( + R: si.BaseRecording, + S: si.BaseSorting, + initial_curation_uri: str, + recording_label: str, + sorting_label: str, + unit_metrics: Union[List[Any], None] = None, + segment_duration_sec=1200, + snippet_ms_before=1, + snippet_ms_after=1, + max_num_snippets_per_segment=1000, + channel_neighborhood_size=5, + raster_plot_subsample_max_firing_rate=50, + spike_amplitudes_subsample_max_firing_rate=50, +) -> str: + print("Preparing spikesortingview data") + recording = R + sorting = S + + sampling_frequency = recording.get_sampling_frequency() + + this_view = SpikeSortingView.create( + recording=recording, + sorting=sorting, + segment_duration_sec=segment_duration_sec, + snippet_len=( + int(snippet_ms_before * sampling_frequency / 1000), + int(snippet_ms_after * sampling_frequency / 1000), + ), + max_num_snippets_per_segment=max_num_snippets_per_segment, + channel_neighborhood_size=channel_neighborhood_size, + ) + + # Assemble the views in a layout. Can be replaced with other layouts. + raster_max_fire = raster_plot_subsample_max_firing_rate + spike_amp_max_fire = spike_amplitudes_subsample_max_firing_rate + + sort_items = [ + vv.MountainLayoutItem( + label="Summary", view=this_view.sorting_summary_view() + ), + vv.MountainLayoutItem( + label="Units table", + view=this_view.units_table_view( + unit_ids=this_view.unit_ids, unit_metrics=unit_metrics + ), + ), + vv.MountainLayoutItem( + label="Raster plot", + view=this_view.raster_plot_view( + unit_ids=this_view.unit_ids, + _subsample_max_firing_rate=raster_max_fire, + ), + ), + vv.MountainLayoutItem( + label="Spike amplitudes", + view=this_view.spike_amplitudes_view( + unit_ids=this_view.unit_ids, + _subsample_max_firing_rate=spike_amp_max_fire, + ), + ), + vv.MountainLayoutItem( + label="Autocorrelograms", + view=this_view.autocorrelograms_view(unit_ids=this_view.unit_ids), + ), + vv.MountainLayoutItem( + label="Cross correlograms", + view=this_view.cross_correlograms_view(unit_ids=this_view.unit_ids), + ), + vv.MountainLayoutItem( + label="Avg waveforms", + view=this_view.average_waveforms_view(unit_ids=this_view.unit_ids), + ), + vv.MountainLayoutItem( + label="Electrode geometry", + view=this_view.electrode_geometry_view(), + ), + vv.MountainLayoutItem( + label="Curation", view=vv.SortingCuration2(), is_control=True + ), + ] + + return vv.MountainLayout(items=sort_items).url( + label=f"{recording_label} {sorting_label}", + state={ + "initialSortingCuration": initial_curation_uri, + "sortingCuration": initial_curation_uri, + }, + ) + + +def _reformat_metrics(metrics: Dict[str, Dict[str, float]]) -> List[Dict]: + return [ + { + "name": metric_name, + "label": metric_name, + "tooltip": metric_name, + "data": { + str(unit_id): metric_value + for unit_id, metric_value in metric.items() + }, + } + for metric_name, metric in metrics.items() + ] diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py new file mode 100644 index 000000000..86fa08ab2 --- /dev/null +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -0,0 +1,590 @@ +import os +import uuid +from typing import Any, Dict, List, Union + +import datajoint as dj +import numpy as np +import pynwb +import spikeinterface as si +import spikeinterface.preprocessing as sp +import spikeinterface.qualitymetrics as sq + +from spyglass.common.common_nwbfile import 
AnalysisNwbfile +from spyglass.settings import temp_dir +from spyglass.spikesorting.v1.curation import ( + CurationV1, + _list_to_merge_dict, + _merge_dict_to_list, +) +from spyglass.spikesorting.v1.metric_utils import ( + compute_isi_violation_fractions, + get_num_spikes, + get_peak_channel, + get_peak_offset, +) +from spyglass.spikesorting.v1.sorting import SpikeSortingSelection + +schema = dj.schema("spikesorting_v1_metric_curation") + + +_metric_name_to_func = { + "snr": sq.compute_snrs, + "isi_violation": compute_isi_violation_fractions, + "nn_isolation": sq.nearest_neighbors_isolation, + "nn_noise_overlap": sq.nearest_neighbors_noise_overlap, + "peak_offset": get_peak_offset, + "peak_channel": get_peak_channel, + "num_spikes": get_num_spikes, +} + +_comparison_to_function = { + "<": np.less, + "<=": np.less_equal, + ">": np.greater, + ">=": np.greater_equal, + "==": np.equal, +} + + +@schema +class WaveformParameters(dj.Lookup): + definition = """ + # Parameters for extracting waveforms from the recording based on the sorting. + waveform_param_name: varchar(80) # name of waveform extraction parameters + --- + waveform_params: blob # a dict of waveform extraction parameters + """ + + contents = [ + [ + "default_not_whitened", + { + "ms_before": 0.5, + "ms_after": 0.5, + "max_spikes_per_unit": 5000, + "n_jobs": 5, + "total_memory": "5G", + "whiten": False, + }, + ], + [ + "default_whitened", + { + "ms_before": 0.5, + "ms_after": 0.5, + "max_spikes_per_unit": 5000, + "n_jobs": 5, + "total_memory": "5G", + "whiten": True, + }, + ], + ] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + +@schema +class MetricParameters(dj.Lookup): + definition = """ + # Parameters for computing quality metrics of sorted units. + metric_param_name: varchar(200) + --- + metric_params: blob + """ + metric_default_param_name = "franklab_default" + metric_default_param = { + "snr": { + "peak_sign": "neg", + "random_chunk_kwargs_dict": { + "num_chunks_per_segment": 20, + "chunk_size": 10000, + "seed": 0, + }, + }, + "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0.0}, + "nn_isolation": { + "max_spikes": 1000, + "min_spikes": 10, + "n_neighbors": 5, + "n_components": 7, + "radius_um": 100, + "seed": 0, + }, + "nn_noise_overlap": { + "max_spikes": 1000, + "min_spikes": 10, + "n_neighbors": 5, + "n_components": 7, + "radius_um": 100, + "seed": 0, + }, + "peak_channel": {"peak_sign": "neg"}, + "num_spikes": {}, + } + contents = [[metric_default_param_name, metric_default_param]] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + @classmethod + def show_available_metrics(self): + for metric in _metric_name_to_func: + metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0] + print(f"{metric} : {metric_doc}\n") + + +@schema +class MetricCurationParameters(dj.Lookup): + definition = """ + # Parameters for curating a spike sorting based on the metrics. + metric_curation_param_name: varchar(200) + --- + label_params: blob # dict of param to label units + merge_params: blob # dict of param to merge units + """ + + contents = [ + ["default", {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}, {}], + ["none", {}, {}], + ] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + +@schema +class MetricCurationSelection(dj.Manual): + definition = """ + # Spike sorting and parameters for metric curation. Use `insert_selection` to insert a row into this table. 
+ metric_curation_id: uuid + --- + -> CurationV1 + -> WaveformParameters + -> MetricParameters + -> MetricCurationParameters + """ + + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into MetricCurationSelection with an + automatically generated unique metric curation ID as the sole primary key. + + Parameters + ---------- + key : dict + primary key of CurationV1, WaveformParameters, MetricParameters MetricCurationParameters + + Returns + ------- + key : dict + key for the inserted row + """ + if cls & key: + print( + "This row has already been inserted into MetricCurationSelection." + ) + return (cls & key).fetch1() + key["metric_curation_id"] = uuid.uuid4() + cls.insert1(key, skip_duplicates=True) + return key + + +@schema +class MetricCuration(dj.Computed): + definition = """ + # Results of applying curation based on quality metrics. To do additional curation, insert another row in `CurationV1` + -> MetricCurationSelection + --- + -> AnalysisNwbfile + object_id: varchar(40) # Object ID for the metrics in NWB file + """ + + def make(self, key): + # FETCH + nwb_file_name = ( + SpikeSortingSelection * MetricCurationSelection & key + ).fetch1("nwb_file_name") + + waveform_params = ( + WaveformParameters * MetricCurationSelection & key + ).fetch1("waveform_params") + metric_params = ( + MetricParameters * MetricCurationSelection & key + ).fetch1("metric_params") + label_params, merge_params = ( + MetricCurationParameters * MetricCurationSelection & key + ).fetch1("label_params", "merge_params") + sorting_id, curation_id = (MetricCurationSelection & key).fetch1( + "sorting_id", "curation_id" + ) + # DO + # load recording and sorting + recording = CurationV1.get_recording( + {"sorting_id": sorting_id, "curation_id": curation_id} + ) + sorting = CurationV1.get_sorting( + {"sorting_id": sorting_id, "curation_id": curation_id} + ) + # extract waveforms + if "whiten" in waveform_params: + if waveform_params.pop("whiten"): + recording = sp.whiten(recording, dtype=np.float64) + + waveforms_dir = temp_dir + "/" + str(key["metric_curation_id"]) + try: + os.mkdir(waveforms_dir) + except FileExistsError: + pass + print("Extracting waveforms...") + waveforms = si.extract_waveforms( + recording=recording, + sorting=sorting, + folder=waveforms_dir, + overwrite=True, + **waveform_params, + ) + # compute metrics + print("Computing metrics...") + metrics = {} + for metric_name, metric_param_dict in metric_params.items(): + metrics[metric_name] = self._compute_metric( + waveforms, metric_name, **metric_param_dict + ) + if metrics["nn_isolation"]: + metrics["nn_isolation"] = { + unit_id: value[0] + for unit_id, value in metrics["nn_isolation"].items() + } + + print("Applying curation...") + labels = self._compute_labels(metrics, label_params) + merge_groups = self._compute_merge_groups(metrics, merge_params) + + print("Saving to NWB...") + ( + key["analysis_file_name"], + key["object_id"], + ) = _write_metric_curation_to_nwb( + nwb_file_name, waveforms, metrics, labels, merge_groups + ) + + # INSERT + AnalysisNwbfile().add( + nwb_file_name, + key["analysis_file_name"], + ) + self.insert1(key) + + @classmethod + def get_waveforms(cls): + return NotImplementedError + + @classmethod + def get_metrics(cls, key: dict): + """Returns metrics identified by metric curation + + Parameters + ---------- + key : dict + primary key to MetricCuration + """ + analysis_file_name, object_id, metric_param_name, metric_params = ( + cls * MetricCurationSelection * MetricParameters & key + ).fetch1( + 
"analysis_file_name", + "object_id", + "metric_param_name", + "metric_params", + ) + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + with pynwb.NWBHDF5IO( + path=analysis_file_abs_path, + mode="r", + load_namespaces=True, + ) as io: + nwbf = io.read() + units = nwbf.objects[object_id].to_dataframe() + return { + name: dict(zip(units.index, units[name])) for name in metric_params + } + + @classmethod + def get_labels(cls, key: dict): + """Returns curation labels identified by metric curation + + Parameters + ---------- + key : dict + primary key to MetricCuration + """ + analysis_file_name, object_id = (cls & key).fetch1( + "analysis_file_name", "object_id" + ) + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + with pynwb.NWBHDF5IO( + path=analysis_file_abs_path, + mode="r", + load_namespaces=True, + ) as io: + nwbf = io.read() + units = nwbf.objects[object_id].to_dataframe() + return dict(zip(units.index, units["curation_label"])) + + @classmethod + def get_merge_groups(cls, key: dict): + """Returns merge groups identified by metric curation + + Parameters + ---------- + key : dict + primary key to MetricCuration + """ + analysis_file_name, object_id = (cls & key).fetch1( + "analysis_file_name", "object_id" + ) + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + with pynwb.NWBHDF5IO( + path=analysis_file_abs_path, + mode="r", + load_namespaces=True, + ) as io: + nwbf = io.read() + units = nwbf.objects[object_id].to_dataframe() + merge_group_dict = dict(zip(units.index, units["merge_groups"])) + + return _merge_dict_to_list(merge_group_dict) + + @staticmethod + def _compute_metric(waveform_extractor, metric_name, **metric_params): + metric_func = _metric_name_to_func[metric_name] + + peak_sign_metrics = ["snr", "peak_offset", "peak_channel"] + if metric_name in peak_sign_metrics: + if "peak_sign" not in metric_params: + raise Exception( + f"{peak_sign_metrics} metrics require peak_sign", + "to be defined in the metric parameters", + ) + return metric_func( + waveform_extractor, + peak_sign=metric_params.pop("peak_sign"), + **metric_params, + ) + + return { + unit_id: metric_func(waveform_extractor, this_unit_id=unit_id) + for unit_id in waveform_extractor.sorting.get_unit_ids() + } + + @staticmethod + def _compute_labels( + metrics: Dict[str, Dict[str, Union[float, List[float]]]], + label_params: Dict[str, List[Any]], + ) -> Dict[str, List[str]]: + """Computes the labels based on the metric and label parameters. + + Parameters + ---------- + quality_metrics : dict + Example: {"snr" : {"1" : 2, "2" : 0.1, "3" : 2.3}} + This indicates that the values of the "snr" quality metric + for the units "1", "2", "3" are 2, 0.1, and 2.3, respectively. + + label_params : dict + Example: { + "snr" : [(">", 1, ["good", "mua"]), + ("<", 1, ["noise"])] + } + This indicates that units with values of the "snr" quality metric + greater than 1 should be given the labels "good" and "mua" and values + less than 1 should be given the label "noise". 
+ + Returns + ------- + labels : dict + Example: {"1" : ["good", "mua"], "2" : ["noise"], "3" : ["good", "mua"]} + + """ + if not label_params: + return {} + + unit_ids = [ + unit_id for unit_id in metrics[list(metrics.keys())[0]].keys() + ] + labels = {unit_id: [] for unit_id in unit_ids} + + for metric in label_params: + if metric not in metrics: + Warning(f"{metric} not found in quality metrics; skipping") + continue + + condition = label_params[metric] + if not len(condition) == 3: + raise ValueError(f"Condition {condition} must be of length 3") + + compare = _comparison_to_function[condition[0]] + for unit_id in unit_ids: + if compare( + metrics[metric][unit_id], + condition[1], + ): + labels[unit_id].extend(label_params[metric][2]) + return labels + + @staticmethod + def _compute_merge_groups( + metrics: Dict[str, Dict[str, Union[float, List[float]]]], + merge_params: Dict[str, List[Any]], + ) -> Dict[str, List[str]]: + """Identifies units to be merged based on the metrics and merge parameters. + + Parameters + --------- + quality_metrics : dict + Example: {"cosine_similarity" : { + "1" : {"1" : 1.00, "2" : 0.10, "3": 0.95}, + "2" : {"1" : 0.10, "2" : 1.00, "3": 0.70}, + "3" : {"1" : 0.95, "2" : 0.70, "3": 1.00} + }} + This shows the pairwise values of the "cosine_similarity" quality metric + for the units "1", "2", "3" as a nested dict. + + merge_params : dict + Example: {"cosine_similarity" : [">", 0.9]} + This indicates that units with values of the "cosine_similarity" quality metric + greater than 0.9 should be placed in the same merge group. + + + Returns + ------- + merge_groups : dict + Example: {"1" : ["3"], "2" : [], "3" : ["1"]} + + """ + + if not merge_params: + return [] + + unit_ids = list(metrics[list(metrics.keys())[0]].keys()) + merge_groups = {unit_id: [] for unit_id in unit_ids} + for metric in merge_params: + if metric not in metrics: + Warning(f"{metric} not found in quality metrics; skipping") + continue + compare = _comparison_to_function[merge_params[metric][0]] + for unit_id in unit_ids: + other_unit_ids = [ + other_unit_id + for other_unit_id in unit_ids + if other_unit_id != unit_id + ] + for other_unit_id in other_unit_ids: + if compare( + metrics[metric][unit_id][other_unit_id], + merge_params[metric][1], + ): + merge_groups[unit_id].extend(other_unit_id) + return merge_groups + + +def _write_metric_curation_to_nwb( + nwb_file_name: str, + waveforms: si.WaveformExtractor, + metrics: Union[None, Dict[str, Dict[str, float]]] = None, + labels: Union[None, Dict[str, List[str]]] = None, + merge_groups: Union[None, List[List[str]]] = None, +): + """Save waveforms, metrics, labels, and merge groups to NWB in the units table. + + Parameters + ---------- + sorting_id : str + key for the sorting + labels : dict or None, optional + curation labels (e.g. 
good, noise, mua) + merge_groups : list or None, optional + groups of unit IDs to be merged + metrics : dict or None, optional + Computed quality metrics, one for each cell + apply_merge : bool, optional + whether to apply the merge groups to the sorting before saving, by default False + + Returns + ------- + analysis_nwb_file : str + name of analysis NWB file containing the sorting and curation information + object_id : str + object_id of the units table in the analysis NWB file + """ + + unit_ids = [int(i) for i in waveforms.sorting.get_unit_ids()] + + # create new analysis nwb file + analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( + path=analysis_nwb_file_abs_path, + mode="a", + load_namespaces=True, + ) as io: + nwbf = io.read() + # Write waveforms to the nwb file + for unit_id in unit_ids: + nwbf.add_unit( + spike_times=waveforms.sorting.get_unit_spike_train(unit_id), + id=unit_id, + electrodes=waveforms.recording.get_channel_ids(), + waveforms=waveforms.get_waveforms(unit_id), + waveform_mean=waveforms.get_template(unit_id), + ) + + # add labels, merge groups, metrics + if labels is not None: + label_values = [] + for unit_id in unit_ids: + if unit_id not in labels: + label_values.append([]) + else: + label_values.append(labels[unit_id]) + nwbf.add_unit_column( + name="curation_label", + description="curation label", + data=label_values, + index=True, + ) + if merge_groups is not None: + merge_groups_dict = _list_to_merge_dict(merge_groups, unit_ids) + merge_groups_list = [ + [""] for i in merge_groups_dict.values() if i == [] + ] + nwbf.add_unit_column( + name="merge_groups", + description="merge groups", + data=merge_groups_list, + index=True, + ) + if metrics is not None: + for metric, metric_dict in metrics.items(): + metric_values = [ + metric_dict[unit_id] if unit_id in metric_dict else [] + for unit_id in unit_ids + ] + nwbf.add_unit_column( + name=metric, + description=metric, + data=metric_values, + ) + + units_object_id = nwbf.units.object_id + io.write(nwbf) + return analysis_nwb_file, units_object_id diff --git a/src/spyglass/spikesorting/v1/metric_utils.py b/src/spyglass/spikesorting/v1/metric_utils.py new file mode 100644 index 000000000..ab440c3d9 --- /dev/null +++ b/src/spyglass/spikesorting/v1/metric_utils.py @@ -0,0 +1,69 @@ +import spikeinterface as si +import spikeinterface.qualitymetrics as sq +import numpy as np + + +def compute_isi_violation_fractions( + waveform_extractor: si.WaveformExtractor, + this_unit_id: str, + isi_threshold_ms: float = 2.0, + min_isi_ms: float = 0.0, +): + """Computes the fraction of interspike interval violations. + + Parameters + ---------- + waveform_extractor: si.WaveformExtractor + The extractor object for the recording. + + """ + + # Extract the total number of spikes that violated the isi_threshold for each unit + _, isi_violation_counts = sq.compute_isi_violations( + waveform_extractor, + isi_threshold_ms=isi_threshold_ms, + min_isi_ms=min_isi_ms, + ) + num_spikes = sq.compute_num_spikes(waveform_extractor) + return isi_violation_counts[this_unit_id] / (num_spikes[this_unit_id] - 1) + + +def get_peak_offset( + waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params +): + """Computes the shift of the waveform peak from center of window. + + Parameters + ---------- + waveform_extractor: si.WaveformExtractor + The extractor object for the recording. + peak_sign: str + The sign of the peak to compute. 
('neg', 'pos', 'both') + """ + if "peak_sign" in metric_params: + del metric_params["peak_sign"] + peak_offset_inds = si.get_template_extremum_channel_peak_shift( + waveform_extractor=waveform_extractor, + peak_sign=peak_sign, + **metric_params, + ) + return {key: int(abs(val)) for key, val in peak_offset_inds.items()} + + +def get_peak_channel( + waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params +): + """Computes the electrode_id of the channel with the extremum peak for each unit.""" + if "peak_sign" in metric_params: + del metric_params["peak_sign"] + peak_channel_dict = si.get_template_extremum_channel( + waveform_extractor=waveform_extractor, + peak_sign=peak_sign, + **metric_params, + ) + return {key: int(val) for key, val in peak_channel_dict.items()} + + +def get_num_spikes(waveform_extractor: si.WaveformExtractor, this_unit_id: str): + """Computes the number of spikes for each unit.""" + return sq.compute_num_spikes(waveform_extractor)[this_unit_id] diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py new file mode 100644 index 000000000..7ec18ebff --- /dev/null +++ b/src/spyglass/spikesorting/v1/recording.py @@ -0,0 +1,860 @@ +import uuid +from typing import Iterable, List, Optional, Tuple, Union + +import datajoint as dj +import numpy as np +import probeinterface as pi +import pynwb +import spikeinterface as si +import spikeinterface.extractors as se +from hdmf.data_utils import GenericDataChunkIterator + +from spyglass.common import Session # noqa: F401 +from spyglass.common.common_device import Probe +from spyglass.common.common_ephys import Electrode, Raw # noqa: F401 +from spyglass.common.common_interval import ( + IntervalList, + interval_list_intersect, +) +from spyglass.common.common_lab import LabTeam +from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile + +schema = dj.schema("spikesorting_v1_recording") + + +@schema +class SortGroup(dj.Manual): + definition = """ + # Set of electrodes to spike sort together + -> Session + sort_group_id: int + --- + sort_reference_electrode_id = -1: int # the electrode to use for referencing + # -1: no reference, -2: common median + """ + + class SortGroupElectrode(dj.Part): + definition = """ + -> SortGroup + -> Electrode + """ + + @classmethod + def set_group_by_shank( + cls, + nwb_file_name: str, + references: dict = None, + omit_ref_electrode_group=False, + omit_unitrode=True, + ): + """Divides electrodes into groups based on their shank position. + + * Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a + single group + * Electrodes from probes with multiple shanks (e.g. polymer probes) are + placed in one group per shank + * Bad channels are omitted + + Parameters + ---------- + nwb_file_name : str + the name of the NWB file whose electrodes should be put into + sorting groups + references : dict, optional + If passed, used to set references. Otherwise, references set using + original reference electrodes from config. Keys: electrode groups. + Values: reference electrode. + omit_ref_electrode_group : bool + Optional. If True, no sort group is defined for electrode group of + reference. + omit_unitrode : bool + Optional. If True, no sort groups are defined for unitrodes. 
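+
+ Example
+ -------
+ A minimal sketch with hypothetical values; keys of `references` are
+ electrode group names and values are the reference electrode id::
+
+     SortGroup.set_group_by_shank(
+         nwb_file_name="example20231201_.nwb",
+         references={"0": 1},
+         omit_unitrode=True,
+     )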
+ """ + # delete any current groups + # (SortGroup & {"nwb_file_name": nwb_file_name}).delete() + # get the electrodes from this NWB file + electrodes = ( + Electrode() + & {"nwb_file_name": nwb_file_name} + & {"bad_channel": "False"} + ).fetch() + e_groups = list(np.unique(electrodes["electrode_group_name"])) + e_groups.sort(key=int) # sort electrode groups numerically + sort_group = 0 + sg_key = dict() + sge_key = dict() + sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name + for e_group in e_groups: + # for each electrode group, get a list of the unique shank numbers + shank_list = np.unique( + electrodes["probe_shank"][ + electrodes["electrode_group_name"] == e_group + ] + ) + sge_key["electrode_group_name"] = e_group + # get the indices of all electrodes in this group / shank and set their sorting group + for shank in shank_list: + sg_key["sort_group_id"] = sge_key["sort_group_id"] = sort_group + # specify reference electrode. Use 'references' if passed, otherwise use reference from config + if not references: + shank_elect_ref = electrodes[ + "original_reference_electrode" + ][ + np.logical_and( + electrodes["electrode_group_name"] == e_group, + electrodes["probe_shank"] == shank, + ) + ] + if np.max(shank_elect_ref) == np.min(shank_elect_ref): + sg_key["sort_reference_electrode_id"] = shank_elect_ref[ + 0 + ] + else: + ValueError( + f"Error in electrode group {e_group}: reference " + + "electrodes are not all the same" + ) + else: + if e_group not in references.keys(): + raise Exception( + f"electrode group {e_group} not a key in " + + "references, so cannot set reference" + ) + else: + sg_key["sort_reference_electrode_id"] = references[ + e_group + ] + # Insert sort group and sort group electrodes + reference_electrode_group = electrodes[ + electrodes["electrode_id"] + == sg_key["sort_reference_electrode_id"] + ][ + "electrode_group_name" + ] # reference for this electrode group + if ( + len(reference_electrode_group) == 1 + ): # unpack single reference + reference_electrode_group = reference_electrode_group[0] + elif (int(sg_key["sort_reference_electrode_id"]) > 0) and ( + len(reference_electrode_group) != 1 + ): + raise Exception( + "Should have found exactly one electrode group for " + + "reference electrode, but found " + + f"{len(reference_electrode_group)}." + ) + if omit_ref_electrode_group and ( + str(e_group) == str(reference_electrode_group) + ): + print( + f"Omitting electrode group {e_group} from sort groups " + + "because contains reference." + ) + continue + shank_elect = electrodes["electrode_id"][ + np.logical_and( + electrodes["electrode_group_name"] == e_group, + electrodes["probe_shank"] == shank, + ) + ] + if ( + omit_unitrode and len(shank_elect) == 1 + ): # omit unitrodes if indicated + print( + f"Omitting electrode group {e_group}, shank {shank} from sort groups because unitrode." + ) + continue + cls.insert1(sg_key, skip_duplicates=True) + for elect in shank_elect: + sge_key["electrode_id"] = elect + cls.SortGroupElectrode().insert1( + sge_key, skip_duplicates=True + ) + sort_group += 1 + + +@schema +class SpikeSortingPreprocessingParameters(dj.Lookup): + definition = """ + # Parameters for denoising a recording prior to spike sorting. 
+ preproc_param_name: varchar(200) + --- + preproc_params: blob + """ + + contents = [ + [ + "default", + { + "frequency_min": 300, # high pass filter value + "frequency_max": 6000, # low pass filter value + "margin_ms": 5, # margin in ms on border to avoid border effect + "seed": 0, # random seed for whitening + "min_segment_length": 1, # minimum segment length in seconds + }, + ] + ] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + +@schema +class SpikeSortingRecordingSelection(dj.Manual): + definition = """ + # Raw voltage traces and parameters. Use `insert_selection` method to insert rows. + recording_id: uuid + --- + -> Raw + -> SortGroup + -> IntervalList + -> SpikeSortingPreprocessingParameters + -> LabTeam + """ + + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into SpikeSortingRecordingSelection with an + automatically generated unique recording ID as the sole primary key. + + Parameters + ---------- + key : dict + primary key of Raw, SortGroup, IntervalList, + SpikeSortingPreprocessingParameters, LabTeam tables + + Returns + ------- + primary key of SpikeSortingRecordingSelection table + """ + query = cls & key + if query: + print("Similar row(s) already inserted.") + return query.fetch(as_dict=True) + key["recording_id"] = uuid.uuid4() + cls.insert1(key, skip_duplicates=True) + return key + + +@schema +class SpikeSortingRecording(dj.Computed): + definition = """ + # Processed recording. + -> SpikeSortingRecordingSelection + --- + -> AnalysisNwbfile + object_id: varchar(40) # Object ID for the processed recording in NWB file + """ + + def make(self, key): + # DO: + # - get valid times for sort interval + # - proprocess recording + # - write recording to NWB file + sort_interval_valid_times = self._get_sort_interval_valid_times(key) + recording, timestamps = self._get_preprocessed_recording(key) + recording_nwb_file_name, recording_object_id = _write_recording_to_nwb( + recording, + timestamps, + (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), + ) + key["analysis_file_name"] = recording_nwb_file_name + key["object_id"] = recording_object_id + + # INSERT: + # - valid times into IntervalList + # - analysis NWB file holding processed recording into AnalysisNwbfile + # - entry into SpikeSortingRecording + IntervalList.insert1( + { + "nwb_file_name": (SpikeSortingRecordingSelection & key).fetch1( + "nwb_file_name" + ), + "interval_list_name": key["recording_id"], + "valid_times": sort_interval_valid_times, + } + ) + AnalysisNwbfile().add( + (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), + key["analysis_file_name"], + ) + self.insert1(key) + + @classmethod + def get_recording(cls, key: dict) -> si.BaseRecording: + """Get recording related to this curation as spikeinterface BaseRecording + + Parameters + ---------- + key : dict + primary key of SpikeSorting table + """ + + analysis_file_name = (cls & key).fetch1("analysis_file_name") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + recording = se.read_nwb_recording( + analysis_file_abs_path, load_time_vector=True + ) + + return recording + + @staticmethod + def _get_recording_timestamps(recording): + if recording.get_num_segments() > 1: + frames_per_segment = [0] + for i in range(recording.get_num_segments()): + frames_per_segment.append( + recording.get_num_frames(segment_index=i) + ) + + cumsum_frames = np.cumsum(frames_per_segment) + total_frames = np.sum(frames_per_segment) + + timestamps = 
np.zeros((total_frames,)) + for i in range(recording.get_num_segments()): + timestamps[ + cumsum_frames[i] : cumsum_frames[i + 1] + ] = recording.get_times(segment_index=i) + else: + timestamps = recording.get_times() + return timestamps + + def _get_sort_interval_valid_times(self, key: dict): + """Identifies the intersection between sort interval specified by the user + and the valid times (times for which neural data exist, excluding e.g. dropped packets). + + Parameters + ---------- + key: dict + primary key of SpikeSortingRecordingSelection table + + Returns + ------- + sort_interval_valid_times: ndarray of tuples + (start, end) times for valid intervals in the sort interval + + """ + # FETCH: - sort interval - valid times - preprocessing parameters + nwb_file_name, sort_interval_name, params = ( + SpikeSortingPreprocessingParameters * SpikeSortingRecordingSelection + & key + ).fetch1("nwb_file_name", "interval_list_name", "preproc_params") + + sort_interval = ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": sort_interval_name, + } + ).fetch1("valid_times") + + valid_interval_times = ( + IntervalList + & { + "nwb_file_name": nwb_file_name, + "interval_list_name": "raw data valid times", + } + ).fetch1("valid_times") + + # DO: - take intersection between sort interval and valid times + return interval_list_intersect( + sort_interval, + valid_interval_times, + min_length=params["min_segment_length"], + ) + + def _get_preprocessed_recording(self, key: dict): + """Filters and references a recording. + + - Loads the NWB file created during insertion as a spikeinterface Recording + - Slices recording in time (interval) and space (channels); + recording chunks from disjoint intervals are concatenated + - Applies referencing and bandpass filtering + + Parameters + ---------- + key: dict + primary key of SpikeSortingRecordingSelection table + + Returns + ------- + recording: si.Recording + """ + # FETCH: + # - full path to NWB file + # - channels to be included in the sort + # - the reference channel + # - probe type + # - filter parameters + nwb_file_name = (SpikeSortingRecordingSelection & key).fetch1( + "nwb_file_name" + ) + sort_group_id = (SpikeSortingRecordingSelection & key).fetch1( + "sort_group_id" + ) + nwb_file_abs_path = Nwbfile().get_abs_path(nwb_file_name) + channel_ids = ( + SortGroup.SortGroupElectrode + & { + "nwb_file_name": nwb_file_name, + "sort_group_id": sort_group_id, + } + ).fetch("electrode_id") + ref_channel_id = ( + SortGroup + & { + "nwb_file_name": nwb_file_name, + "sort_group_id": sort_group_id, + } + ).fetch1("sort_reference_electrode_id") + recording_channel_ids = np.setdiff1d(channel_ids, ref_channel_id) + all_channel_ids = np.unique(channel_ids + ref_channel_id) + + probe_type_by_channel = [] + electrode_group_by_channel = [] + for channel_id in channel_ids: + probe_type_by_channel.append( + ( + Electrode * Probe + & { + "nwb_file_name": nwb_file_name, + "electrode_id": channel_id, + } + ).fetch1("probe_type") + ) + electrode_group_by_channel.append( + ( + Electrode + & { + "nwb_file_name": nwb_file_name, + "electrode_id": channel_id, + } + ).fetch1("electrode_group_name") + ) + probe_type = np.unique(probe_type_by_channel) + filter_params = ( + SpikeSortingPreprocessingParameters * SpikeSortingRecordingSelection + & key + ).fetch1("preproc_params") + + # DO: + # - load NWB file as a spikeinterface Recording + # - slice the recording object in time and channels + # - apply referencing depending on the option chosen by the user + 
# - apply bandpass filter + # - set probe to recording + recording = se.read_nwb_recording( + nwb_file_abs_path, load_time_vector=True + ) + all_timestamps = recording.get_times() + + # TODO: make sure the following works for recordings that don't have explicit timestamps + valid_sort_times = self._get_sort_interval_valid_times(key) + valid_sort_times_indices = _consolidate_intervals( + valid_sort_times, all_timestamps + ) + + # slice in time; concatenate disjoint sort intervals + if len(valid_sort_times_indices) > 1: + recordings_list = [] + timestamps = [] + for interval_indices in valid_sort_times_indices: + recording_single = recording.frame_slice( + start_frame=interval_indices[0], + end_frame=interval_indices[1], + ) + recordings_list.append(recording_single) + timestamps.extend( + all_timestamps[interval_indices[0] : interval_indices[1]] + ) + recording = si.concatenate_recordings(recordings_list) + else: + recording = recording.frame_slice( + start_frame=valid_sort_times_indices[0][0], + end_frame=valid_sort_times_indices[0][1], + ) + timestamps = all_timestamps[ + valid_sort_times_indices[0][0] : valid_sort_times_indices[0][1] + ] + + # slice in channels; include ref channel in first slice, then exclude it in second slice + if ref_channel_id >= 0: + recording = recording.channel_slice(channel_ids=all_channel_ids) + recording = si.preprocessing.common_reference( + recording, + reference="single", + ref_channel_ids=ref_channel_id, + dtype=np.float64, + ) + recording = recording.channel_slice( + channel_ids=recording_channel_ids + ) + elif ref_channel_id == -2: + recording = recording.channel_slice( + channel_ids=recording_channel_ids + ) + recording = si.preprocessing.common_reference( + recording, + reference="global", + operator="median", + dtype=np.float64, + ) + elif ref_channel_id == -1: + recording = recording.channel_slice( + channel_ids=recording_channel_ids + ) + else: + raise ValueError( + "Invalid reference channel ID. Use -1 to skip referencing. Use " + + "-2 to reference via global median. Use positive integer to " + + "reference to a specific channel." + ) + + recording = si.preprocessing.bandpass_filter( + recording, + freq_min=filter_params["frequency_min"], + freq_max=filter_params["frequency_max"], + dtype=np.float64, + ) + + # if the sort group is a tetrode, change the channel location + # (necessary because the channel location for tetrodes are not set properly) + if ( + len(probe_type) == 1 + and probe_type[0] == "tetrode_12.5" + and len(recording_channel_ids) == 4 + and len(np.unique(electrode_group_by_channel)) == 1 + ): + tetrode = pi.Probe(ndim=2) + position = [[0, 0], [0, 12.5], [12.5, 0], [12.5, 12.5]] + tetrode.set_contacts( + position, shapes="circle", shape_params={"radius": 6.25} + ) + tetrode.set_contact_ids(channel_ids) + tetrode.set_device_channel_indices(np.arange(4)) + recording = recording.set_probe(tetrode, in_place=True) + + return recording, np.asarray(timestamps) + + +def _consolidate_intervals(intervals, timestamps): + """Convert a list of intervals (start_time, stop_time) + to a list of intervals (start_index, stop_index) by comparing to a list of timestamps; + then consolidates overlapping or adjacent intervals + + Parameters + ---------- + intervals : iterable of tuples + timestamps : numpy.ndarray + + """ + # Convert intervals to a numpy array if it's not + intervals = np.array(intervals) + if intervals.shape[1] != 2: + raise ValueError( + "Input array must have shape (N, 2) where N is the number of intervals." 
+ ) + # Check if intervals are sorted. If not, sort them. + if not np.all(intervals[:-1] <= intervals[1:]): + intervals = np.sort(intervals, axis=0) + + # Initialize an empty list to store the consolidated intervals + consolidated = [] + + # Convert start and stop times to indices + start_indices = np.searchsorted(timestamps, intervals[:, 0], side="left") + stop_indices = ( + np.searchsorted(timestamps, intervals[:, 1], side="right") - 1 + ) + + # Start with the first interval + start, stop = start_indices[0], stop_indices[0] + + # Loop through the rest of the intervals to join them if needed + for next_start, next_stop in zip(start_indices, stop_indices): + # If the stop time of the current interval is equal to or greater than the next start time minus 1 + if stop >= next_start - 1: + stop = max( + stop, next_stop + ) # Extend the current interval to include the next one + else: + # Add the current interval to the consolidated list + consolidated.append((start, stop)) + start, stop = next_start, next_stop # Start a new interval + + # Add the last interval to the consolidated list + consolidated.append((start, stop)) + + # Convert the consolidated list to a NumPy array and return + return np.array(consolidated) + + +def _write_recording_to_nwb( + recording: si.BaseRecording, + timestamps: Iterable, + nwb_file_name: str, +): + """Write a recording in NWB format + + Parameters + ---------- + recording : si.Recording + timestamps : iterable + nwb_file_name : str + name of NWB file the recording originates + + Returns + ------- + analysis_nwb_file : str + name of analysis NWB file containing the preprocessed recording + """ + + analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( + path=analysis_nwb_file_abs_path, + mode="a", + load_namespaces=True, + ) as io: + nwbfile = io.read() + table_region = nwbfile.create_electrode_table_region( + region=[i for i in recording.get_channel_ids()], + description="Sort group", + ) + data_iterator = SpikeInterfaceRecordingDataChunkIterator( + recording=recording, return_scaled=False, buffer_gb=5 + ) + timestamps_iterator = TimestampsDataChunkIterator( + recording=TimestampsExtractor(timestamps), buffer_gb=5 + ) + processed_electrical_series = pynwb.ecephys.ElectricalSeries( + name="ProcessedElectricalSeries", + data=data_iterator, + electrodes=table_region, + timestamps=timestamps_iterator, + filtering="Bandpass filtered for spike band", + description=f"Referenced and filtered recording from {nwb_file_name} for spike sorting", + conversion=np.unique(recording.get_channel_gains())[0] * 1e-6, + ) + nwbfile.add_acquisition(processed_electrical_series) + recording_object_id = nwbfile.acquisition[ + "ProcessedElectricalSeries" + ].object_id + io.write(nwbfile) + return analysis_nwb_file, recording_object_id + + +# For writing recording to NWB file + + +class SpikeInterfaceRecordingDataChunkIterator(GenericDataChunkIterator): + """DataChunkIterator specifically for use on RecordingExtractor objects.""" + + def __init__( + self, + recording: si.BaseRecording, + segment_index: int = 0, + return_scaled: bool = False, + buffer_gb: Optional[float] = None, + buffer_shape: Optional[tuple] = None, + chunk_mb: Optional[float] = None, + chunk_shape: Optional[tuple] = None, + display_progress: bool = False, + progress_bar_options: Optional[dict] = None, + ): + """ + Initialize an Iterable object which returns DataChunks with data and their selections on each 
iteration. + + Parameters + ---------- + recording : si.BaseRecording + The SpikeInterfaceRecording object which handles the data access. + segment_index : int, optional + The recording segment to iterate on. + Defaults to 0. + return_scaled : bool, optional + Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False). + Defaults to False. + buffer_gb : float, optional + The upper bound on size in gigabytes (GB) of each selection from the iteration. + The buffer_shape will be set implicitly by this argument. + Cannot be set if `buffer_shape` is also specified. + The default is 1GB. + buffer_shape : tuple, optional + Manual specification of buffer shape to return on each iteration. + Must be a multiple of chunk_shape along each axis. + Cannot be set if `buffer_gb` is also specified. + The default is None. + chunk_mb : float, optional + The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset. + The chunk_shape will be set implicitly by this argument. + Cannot be set if `chunk_shape` is also specified. + The default is 1MB, as recommended by the HDF5 group. For more details, see + https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf + chunk_shape : tuple, optional + Manual specification of the internal chunk shape for the HDF5 dataset. + Cannot be set if `chunk_mb` is also specified. + The default is None. + display_progress : bool, optional + Display a progress bar with iteration rate and estimated completion time. + progress_bar_options : dict, optional + Dictionary of keyword arguments to be passed directly to tqdm. + See https://github.com/tqdm/tqdm#parameters for options. + """ + self.recording = recording + self.segment_index = segment_index + self.return_scaled = return_scaled + self.channel_ids = recording.get_channel_ids() + super().__init__( + buffer_gb=buffer_gb, + buffer_shape=buffer_shape, + chunk_mb=chunk_mb, + chunk_shape=chunk_shape, + display_progress=display_progress, + progress_bar_options=progress_bar_options, + ) + + def _get_data(self, selection: Tuple[slice]) -> Iterable: + return self.recording.get_traces( + segment_index=self.segment_index, + channel_ids=self.channel_ids[selection[1]], + start_frame=selection[0].start, + end_frame=selection[0].stop, + return_scaled=self.return_scaled, + ) + + def _get_dtype(self): + return self.recording.get_dtype() + + def _get_maxshape(self): + return ( + self.recording.get_num_samples(segment_index=self.segment_index), + self.recording.get_num_channels(), + ) + + +class TimestampsExtractor(si.BaseRecording): + def __init__( + self, + timestamps, + sampling_frequency=30e3, + ): + si.BaseRecording.__init__( + self, sampling_frequency, channel_ids=[0], dtype=np.float64 + ) + rec_segment = TimestampsSegment( + timestamps=timestamps, + sampling_frequency=sampling_frequency, + t_start=None, + dtype=np.float64, + ) + self.add_recording_segment(rec_segment) + + +class TimestampsSegment(si.BaseRecordingSegment): + def __init__(self, timestamps, sampling_frequency, t_start, dtype): + si.BaseRecordingSegment.__init__( + self, sampling_frequency=sampling_frequency, t_start=t_start + ) + self._timeseries = timestamps + + def get_num_samples(self) -> int: + return self._timeseries.shape[0] + + def get_traces( + self, + start_frame: Union[int, None] = None, + end_frame: Union[int, None] = None, + channel_indices: Union[List, None] = None, + ) -> np.ndarray: + return np.squeeze(self._timeseries[start_frame:end_frame]) + + 
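+# Note: timestamps are wrapped in a one-channel "recording" (TimestampsExtractor)
+# so they can be streamed into the NWB file with the same chunked-iterator
+# machinery as the traces; see `_write_recording_to_nwb` above, e.g.
+#
+#     timestamps_iterator = TimestampsDataChunkIterator(
+#         recording=TimestampsExtractor(timestamps), buffer_gb=5
+#     )
+
+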
+class TimestampsDataChunkIterator(GenericDataChunkIterator): + """DataChunkIterator specifically for use on RecordingExtractor objects.""" + + def __init__( + self, + recording: si.BaseRecording, + segment_index: int = 0, + return_scaled: bool = False, + buffer_gb: Optional[float] = None, + buffer_shape: Optional[tuple] = None, + chunk_mb: Optional[float] = None, + chunk_shape: Optional[tuple] = None, + display_progress: bool = False, + progress_bar_options: Optional[dict] = None, + ): + """ + Initialize an Iterable object which returns DataChunks with data and their selections on each iteration. + + Parameters + ---------- + recording : SpikeInterfaceRecording + The SpikeInterfaceRecording object (RecordingExtractor or BaseRecording) which handles the data access. + segment_index : int, optional + The recording segment to iterate on. + Defaults to 0. + return_scaled : bool, optional + Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False). + Defaults to False. + buffer_gb : float, optional + The upper bound on size in gigabytes (GB) of each selection from the iteration. + The buffer_shape will be set implicitly by this argument. + Cannot be set if `buffer_shape` is also specified. + The default is 1GB. + buffer_shape : tuple, optional + Manual specification of buffer shape to return on each iteration. + Must be a multiple of chunk_shape along each axis. + Cannot be set if `buffer_gb` is also specified. + The default is None. + chunk_mb : float, optional + The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset. + The chunk_shape will be set implicitly by this argument. + Cannot be set if `chunk_shape` is also specified. + The default is 1MB, as recommended by the HDF5 group. For more details, see + https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf + chunk_shape : tuple, optional + Manual specification of the internal chunk shape for the HDF5 dataset. + Cannot be set if `chunk_mb` is also specified. + The default is None. + display_progress : bool, optional + Display a progress bar with iteration rate and estimated completion time. + progress_bar_options : dict, optional + Dictionary of keyword arguments to be passed directly to tqdm. + See https://github.com/tqdm/tqdm#parameters for options. 
+ """ + self.recording = recording + self.segment_index = segment_index + self.return_scaled = return_scaled + self.channel_ids = recording.get_channel_ids() + super().__init__( + buffer_gb=buffer_gb, + buffer_shape=buffer_shape, + chunk_mb=chunk_mb, + chunk_shape=chunk_shape, + display_progress=display_progress, + progress_bar_options=progress_bar_options, + ) + + # change channel id to always be first channel + def _get_data(self, selection: Tuple[slice]) -> Iterable: + return self.recording.get_traces( + segment_index=self.segment_index, + channel_ids=[0], + start_frame=selection[0].start, + end_frame=selection[0].stop, + return_scaled=self.return_scaled, + ) + + def _get_dtype(self): + return self.recording.get_dtype() + + # remove the last dim for the timestamps since it is always just a 1D vector + def _get_maxshape(self): + return ( + self.recording.get_num_samples(segment_index=self.segment_index), + ) diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py new file mode 100644 index 000000000..729485d9a --- /dev/null +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -0,0 +1,348 @@ +import os +import tempfile +import time +import uuid +from typing import Iterable + +import datajoint as dj +import numpy as np +import pynwb +import spikeinterface as si +import spikeinterface.curation as sic +import spikeinterface.extractors as se +import spikeinterface.preprocessing as sip +import spikeinterface.sorters as sis +from spikeinterface.sortingcomponents.peak_detection import detect_peaks + +from spyglass.common.common_interval import IntervalList +from spyglass.common.common_lab import LabMember, LabTeam +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.spikesorting.v1.recording import ( + SpikeSortingRecording, + SpikeSortingRecordingSelection, +) + +from .recording import _consolidate_intervals + +schema = dj.schema("spikesorting_v1_sorting") + + +@schema +class SpikeSorterParameters(dj.Lookup): + definition = """ + # Spike sorting algorithm and associated parameters. 
+ sorter: varchar(200) + sorter_param_name: varchar(200) + --- + sorter_params: blob + """ + contents = [ + [ + "mountainsort4", + "franklab_tetrode_hippocampus_30KHz", + { + "detect_sign": -1, + "adjacency_radius": 100, + "freq_min": 600, + "freq_max": 6000, + "filter": False, + "whiten": True, + "num_workers": 1, + "clip_size": 40, + "detect_threshold": 3, + "detect_interval": 10, + }, + ], + [ + "mountainsort4", + "franklab_probe_ctx_30KHz", + { + "detect_sign": -1, + "adjacency_radius": 100, + "freq_min": 300, + "freq_max": 6000, + "filter": False, + "whiten": True, + "num_workers": 1, + "clip_size": 40, + "detect_threshold": 3, + "detect_interval": 10, + }, + ], + [ + "clusterless_thresholder", + "default_clusterless", + { + "detect_threshold": 100.0, # uV + # Locally exclusive means one unit per spike detected + "method": "locally_exclusive", + "peak_sign": "neg", + "exclude_sweep_ms": 0.1, + "local_radius_um": 100, + # noise levels needs to be 1.0 so the units are in uV and not MAD + "noise_levels": np.asarray([1.0]), + "random_chunk_kwargs": {}, + # output needs to be set to sorting for the rest of the pipeline + "outputs": "sorting", + }, + ], + ] + contents.extend( + [ + [sorter, "default", sis.get_default_sorter_params(sorter)] + for sorter in sis.available_sorters() + ] + ) + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + +@schema +class SpikeSortingSelection(dj.Manual): + definition = """ + # Processed recording and spike sorting parameters. Use `insert_selection` method to insert rows. + sorting_id: uuid + --- + -> SpikeSortingRecording + -> SpikeSorterParameters + -> IntervalList + """ + + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into SpikeSortingSelection with an + automatically generated unique sorting ID as the sole primary key. + + Parameters + ---------- + key : dict + primary key of SpikeSortingRecording, SpikeSorterParameters, IntervalList tables + + Returns + ------- + sorting_id : uuid + the unique sorting ID serving as primary key for SpikeSorting + """ + query = cls & key + if query: + print("Similar row(s) already inserted.") + return query.fetch(as_dict=True) + key["sorting_id"] = uuid.uuid4() + cls.insert1(key, skip_duplicates=True) + return key + + +@schema +class SpikeSorting(dj.Computed): + definition = """ + -> SpikeSortingSelection + --- + -> AnalysisNwbfile + object_id: varchar(40) # Object ID for the sorting in NWB file + time_of_sort: int # in Unix time, to the nearest second + """ + + def make(self, key: dict): + """Runs spike sorting on the data and parameters specified by the + SpikeSortingSelection table and inserts a new entry to SpikeSorting table. 
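+
+ Example
+ -------
+ `make` is normally called indirectly, via `populate` on rows of
+ `SpikeSortingSelection`; the key below is hypothetical::
+
+     SpikeSorting.populate({"sorting_id": sorting_id})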
+ """ + # FETCH: + # - information about the recording + # - artifact free intervals + # - spike sorter and sorter params + recording_key = ( + SpikeSortingRecording * SpikeSortingSelection & key + ).fetch1() + artifact_removed_intervals = ( + IntervalList + & { + "nwb_file_name": (SpikeSortingSelection & key).fetch1( + "nwb_file_name" + ), + "interval_list_name": (SpikeSortingSelection & key).fetch1( + "interval_list_name" + ), + } + ).fetch1("valid_times") + sorter, sorter_params = ( + SpikeSorterParameters * SpikeSortingSelection & key + ).fetch1("sorter", "sorter_params") + + # DO: + # - load recording + # - concatenate artifact removed intervals + # - run spike sorting + # - save output to NWB file + recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( + recording_key["analysis_file_name"] + ) + recording = se.read_nwb_recording( + recording_analysis_nwb_file_abs_path, load_time_vector=True + ) + + timestamps = recording.get_times() + + artifact_removed_intervals_ind = _consolidate_intervals( + artifact_removed_intervals, timestamps + ) + + # if the artifact removed intervals do not span the entire time range + if ( + (len(artifact_removed_intervals_ind) > 1) + or (artifact_removed_intervals_ind[0][0] > 0) + or (artifact_removed_intervals_ind[-1][1] < len(timestamps)) + ): + # set the artifact intervals to zero + list_triggers = [] + if artifact_removed_intervals_ind[0][0] > 0: + list_triggers.append( + np.array([0, artifact_removed_intervals_ind[0][0]]) + ) + for interval_ind in range(len(artifact_removed_intervals_ind) - 1): + list_triggers.append( + np.arange( + (artifact_removed_intervals_ind[interval_ind][1] + 1), + artifact_removed_intervals_ind[interval_ind + 1][0], + ) + ) + if artifact_removed_intervals_ind[-1][1] < len(timestamps): + list_triggers.append( + np.array( + [ + artifact_removed_intervals_ind[-1][1], + len(timestamps) - 1, + ] + ) + ) + + list_triggers = [list(np.concatenate(list_triggers))] + recording = sip.remove_artifacts( + recording=recording, + list_triggers=list_triggers, + ms_before=None, + ms_after=None, + mode="zeros", + ) + + if sorter == "clusterless_thresholder": + # need to remove tempdir and whiten from sorter_params + sorter_params.pop("tempdir", None) + sorter_params.pop("whiten", None) + sorter_params.pop("outputs", None) + + # Detect peaks for clusterless decoding + detected_spikes = detect_peaks(recording, **sorter_params) + sorting = si.NumpySorting.from_times_labels( + times_list=detected_spikes["sample_ind"], + labels_list=np.zeros(len(detected_spikes), dtype=np.int), + sampling_frequency=recording.get_sampling_frequency(), + ) + else: + # Specify tempdir (expected by some sorters like mountainsort4) + sorter_temp_dir = tempfile.TemporaryDirectory( + dir=os.getenv("SPYGLASS_TEMP_DIR") + ) + sorter_params["tempdir"] = sorter_temp_dir.name + # if whitening is specified in sorter params, apply whitening separately + # prior to sorting and turn off "sorter whitening" + if sorter_params["whiten"]: + recording = sip.whiten(recording, dtype=np.float64) + sorter_params["whiten"] = False + sorting = sis.run_sorter( + sorter, + recording, + output_folder=sorter_temp_dir.name, + remove_existing_folder=True, + **sorter_params, + ) + key["time_of_sort"] = int(time.time()) + sorting = sic.remove_excess_spikes(sorting, recording) + key["analysis_file_name"], key["object_id"] = _write_sorting_to_nwb( + sorting, + timestamps, + artifact_removed_intervals, + (SpikeSortingSelection & key).fetch1("nwb_file_name"), + ) + + # INSERT + # - new 
+
+    @classmethod
+    def get_sorting(cls, key: dict) -> si.BaseSorting:
+        """Get the sorting in the analysis NWB file as a spikeinterface BaseSorting
+
+        Parameters
+        ----------
+        key : dict
+            primary key of SpikeSorting
+
+        Returns
+        -------
+        sorting : si.BaseSorting
+        """
+        analysis_file_name = (cls & key).fetch1("analysis_file_name")
+        analysis_file_abs_path = AnalysisNwbfile.get_abs_path(
+            analysis_file_name
+        )
+        sorting = se.read_nwb_sorting(analysis_file_abs_path)
+
+        return sorting
+
+
+def _write_sorting_to_nwb(
+    sorting: si.BaseSorting,
+    timestamps: np.ndarray,
+    sort_interval: Iterable,
+    nwb_file_name: str,
+):
+    """Write a sorting in NWB format.
+
+    Parameters
+    ----------
+    sorting : si.BaseSorting
+        spike times are in samples
+    timestamps : np.ndarray
+        the absolute time of each sample, in seconds
+    sort_interval : Iterable
+        intervals over which the sort was run; stored as the observation
+        intervals of each unit
+    nwb_file_name : str
+        name of the NWB file the recording originates from
+
+    Returns
+    -------
+    analysis_nwb_file : str
+        name of the analysis NWB file containing the sorting
+    units_object_id : str
+        NWB object ID of the units table containing the sorting
+    """
+    analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name)
+    analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file)
+    with pynwb.NWBHDF5IO(
+        path=analysis_nwb_file_abs_path,
+        mode="a",
+        load_namespaces=True,
+    ) as io:
+        nwbf = io.read()
+        nwbf.add_unit_column(
+            name="curation_label",
+            description="curation label applied to a unit",
+        )
+        for unit_id in sorting.get_unit_ids():
+            spike_times = sorting.get_unit_spike_train(unit_id)
+            nwbf.add_unit(
+                spike_times=timestamps[spike_times],
+                id=unit_id,
+                obs_intervals=sort_interval,
+                curation_label="uncurated",
+            )
+        units_object_id = nwbf.units.object_id
+        io.write(nwbf)
+    return analysis_nwb_file, units_object_id
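
Once a row exists in SpikeSorting, the units can be pulled back into spikeinterface for downstream steps. A minimal sketch, assuming a selection already inserted by `insert_selection`; the `sorting_id` below is a hypothetical placeholder:

import uuid
from spyglass.spikesorting.v1.sorting import SpikeSorting

# hypothetical sorting_id; in practice this comes from SpikeSortingSelection
key = {"sorting_id": uuid.UUID("00000000-0000-0000-0000-000000000000")}

SpikeSorting.populate(key)               # runs make() above for this selection
sorting = SpikeSorting.get_sorting(key)  # spikeinterface BaseSorting
print(sorting.get_unit_ids())            # unit ids written by _write_sorting_to_nwb
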
 diff --git a/src/spyglass/spikesorting/v1/utils.py b/src/spyglass/spikesorting/v1/utils.py
new file mode 100644
index 000000000..eaff64739
--- /dev/null
+++ b/src/spyglass/spikesorting/v1/utils.py
@@ -0,0 +1,107 @@
+import uuid
+
+import numpy as np
+
+from spyglass.spikesorting.merge import SpikeSortingOutput
+from spyglass.spikesorting.v1.artifact import ArtifactDetectionSelection
+from spyglass.spikesorting.v1.curation import CurationV1
+from spyglass.spikesorting.v1.recording import SpikeSortingRecordingSelection
+from spyglass.spikesorting.v1.sorting import SpikeSortingSelection
+
+
+def generate_nwb_uuid(nwb_file_name: str, initial: str, len_uuid: int = 6):
+    """Generates a unique identifier related to an NWB file.
+
+    Parameters
+    ----------
+    nwb_file_name : str
+        name of the NWB file the identifier is associated with
+    initial : str
+        R if recording; A if artifact; S if sorting etc
+    len_uuid : int
+        how many digits of uuid4 to keep
+    """
+    uuid4 = str(uuid.uuid4())
+    nwb_uuid = nwb_file_name + "_" + initial + "_" + uuid4[:len_uuid]
+    return nwb_uuid
+
+
+def get_spiking_sorting_v1_merge_ids(restriction: dict):
+    """Parses the SpikeSorting V1 pipeline to get a list of merge ids
+    for a given restriction.
+
+    Parameters
+    ----------
+    restriction : dict
+        A dictionary containing some or all of the following key-value pairs:
+        nwb_file_name : str
+            name of the nwb file
+        interval_list_name : str
+            name of the interval list
+        sort_group_name : str
+            name of the sort group
+        artifact_param_name : str
+            name of the artifact parameter
+        curation_id : int, optional
+            id of the curation (if not specified, the latest curation is used)
+
+    Returns
+    -------
+    merge_id_list : list
+        list of merge ids for the given restriction
+    """
+    # list of recording ids
+    recording_id_list = (SpikeSortingRecordingSelection() & restriction).fetch(
+        "recording_id"
+    )
+    # list of artifact ids for each recording
+    artifact_id_list = [
+        (
+            ArtifactDetectionSelection() & restriction & {"recording_id": id}
+        ).fetch1("artifact_id")
+        for id in recording_id_list
+    ]
+    # list of sorting ids for each recording
+    sorting_restriction = restriction.copy()
+    sorting_restriction.pop("interval_list_name", None)
+    sorting_id_list = []
+    for r_id, a_id in zip(recording_id_list, artifact_id_list):
+        # if sorted with artifact detection
+        if (
+            SpikeSortingSelection()
+            & sorting_restriction
+            & {"recording_id": r_id, "interval_list_name": a_id}
+        ):
+            sorting_id_list.append(
+                (
+                    SpikeSortingSelection()
+                    & sorting_restriction
+                    & {"recording_id": r_id, "interval_list_name": a_id}
+                ).fetch1("sorting_id")
+            )
+        # if sorted without artifact detection
+        else:
+            sorting_id_list.append(
+                (
+                    SpikeSortingSelection()
+                    & sorting_restriction
+                    & {"recording_id": r_id, "interval_list_name": r_id}
+                ).fetch1("sorting_id")
+            )
+    # if curation_id is specified, use that id for each sorting_id
+    if "curation_id" in restriction:
+        curation_id = [restriction["curation_id"] for _ in sorting_id_list]
+    # otherwise, use the latest curation_id for each sorting_id
+    else:
+        curation_id = [
+            np.max((CurationV1 & {"sorting_id": id}).fetch("curation_id"))
+            for id in sorting_id_list
+        ]
+    # list of merge ids for the desired curation(s)
+    merge_id_list = [
+        (
+            SpikeSortingOutput.CurationV1()
+            & {"sorting_id": id, "curation_id": c_id}
+        ).fetch1("merge_id")
+        for id, c_id in zip(sorting_id_list, curation_id)
+    ]
+    return merge_id_list
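
A usage sketch for the helper above (not part of the patch): the restriction values are hypothetical placeholders following the docstring, and the returned merge ids can restrict the SpikeSortingOutput merge table for downstream pipelines.

from spyglass.spikesorting.merge import SpikeSortingOutput
from spyglass.spikesorting.v1.utils import get_spiking_sorting_v1_merge_ids

# hypothetical restriction; keys follow the docstring above
restriction = {
    "nwb_file_name": "example20231208_.nwb",
    "interval_list_name": "02_r1",
    "sort_group_name": "0",
    "artifact_param_name": "default",
    # "curation_id": 1,  # optional; latest curation is used if omitted
}
merge_ids = get_spiking_sorting_v1_merge_ids(restriction)

# restrict the merge table to these entries
merge_entries = SpikeSortingOutput & [{"merge_id": m} for m in merge_ids]
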