From 7fd8fc9ce5639f7ea9824c80b464847ec2e652fd Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Tue, 7 Nov 2023 16:31:30 -0800 Subject: [PATCH] ssv0 copy files to dest --- src/spyglass/spikesorting/v1/__init__.py | 39 + src/spyglass/spikesorting/v1/artifact.py | 415 +++++++ src/spyglass/spikesorting/v1/curation.py | 1097 +++++++++++++++++ .../spikesorting/v1/figurl_curation.py | 226 ++++ src/spyglass/spikesorting/v1/populator.py | 300 +++++ src/spyglass/spikesorting/v1/recording.py | 627 ++++++++++ src/spyglass/spikesorting/v1/sorting.py | 310 +++++ 7 files changed, 3014 insertions(+) create mode 100644 src/spyglass/spikesorting/v1/__init__.py create mode 100644 src/spyglass/spikesorting/v1/artifact.py create mode 100644 src/spyglass/spikesorting/v1/curation.py create mode 100644 src/spyglass/spikesorting/v1/figurl_curation.py create mode 100644 src/spyglass/spikesorting/v1/populator.py create mode 100644 src/spyglass/spikesorting/v1/recording.py create mode 100644 src/spyglass/spikesorting/v1/sorting.py diff --git a/src/spyglass/spikesorting/v1/__init__.py b/src/spyglass/spikesorting/v1/__init__.py new file mode 100644 index 000000000..05c7e9948 --- /dev/null +++ b/src/spyglass/spikesorting/v1/__init__.py @@ -0,0 +1,39 @@ +from .curation_figurl import CurationFigurl, CurationFigurlSelection +from .sortingview import SortingviewWorkspace, SortingviewWorkspaceSelection +from .spikesorting_artifact import ( + ArtifactDetection, + ArtifactDetectionParameters, + ArtifactDetectionSelection, + ArtifactRemovedIntervalList, +) +from .spikesorting_curation import ( + AutomaticCuration, + AutomaticCurationParameters, + AutomaticCurationSelection, + CuratedSpikeSorting, + CuratedSpikeSortingSelection, + Curation, + MetricParameters, + MetricSelection, + QualityMetrics, + UnitInclusionParameters, + WaveformParameters, + Waveforms, + WaveformSelection, +) +from .spikesorting_populator import ( + SpikeSortingPipelineParameters, + spikesorting_pipeline_populator, +) +from .spikesorting_recording import ( + SortGroup, + SortInterval, + SpikeSortingPreprocessingParameters, + SpikeSortingRecording, + SpikeSortingRecordingSelection, +) +from .spikesorting_sorting import ( + SpikeSorterParameters, + SpikeSorting, + SpikeSortingSelection, +) diff --git a/src/spyglass/spikesorting/v1/artifact.py b/src/spyglass/spikesorting/v1/artifact.py new file mode 100644 index 000000000..0532357c1 --- /dev/null +++ b/src/spyglass/spikesorting/v1/artifact.py @@ -0,0 +1,415 @@ +import warnings +from functools import reduce +from typing import Union + +import datajoint as dj +import numpy as np +import scipy.stats as stats +import spikeinterface as si +from spikeinterface.core.job_tools import ChunkRecordingExecutor, ensure_n_jobs + +from ..common.common_interval import ( + IntervalList, + _union_concat, + interval_from_inds, + interval_set_difference_inds, +) +from ..utils.nwb_helper_fn import get_valid_intervals +from .spikesorting_recording import SpikeSortingRecording + +schema = dj.schema("spikesorting_artifact") + + +@schema +class ArtifactDetectionParameters(dj.Manual): + definition = """ + # Parameters for detecting artifact times within a sort group. 
+ artifact_params_name: varchar(200) + --- + artifact_params: blob # dictionary of parameters + """ + + def insert_default(self): + """Insert the default artifact parameters with an appropriate parameter dict.""" + artifact_params = {} + artifact_params["zscore_thresh"] = None # must be None or >= 0 + artifact_params["amplitude_thresh"] = 3000 # must be None or >= 0 + # all electrodes of sort group + artifact_params["proportion_above_thresh"] = 1.0 + artifact_params["removal_window_ms"] = 1.0 # in milliseconds + self.insert1(["default", artifact_params], skip_duplicates=True) + + artifact_params_none = {} + artifact_params_none["zscore_thresh"] = None + artifact_params_none["amplitude_thresh"] = None + self.insert1(["none", artifact_params_none], skip_duplicates=True) + + +@schema +class ArtifactDetectionSelection(dj.Manual): + definition = """ + # Specifies artifact detection parameters to apply to a sort group's recording. + -> SpikeSortingRecording + -> ArtifactDetectionParameters + --- + custom_artifact_detection=0 : tinyint + """ + + +@schema +class ArtifactDetection(dj.Computed): + definition = """ + # Stores artifact times and valid no-artifact times as intervals. + -> ArtifactDetectionSelection + --- + artifact_times: longblob # np array of artifact intervals + artifact_removed_valid_times: longblob # np array of valid no-artifact intervals + artifact_removed_interval_list_name: varchar(200) # name of the array of no-artifact valid time intervals + """ + + def make(self, key): + if not (ArtifactDetectionSelection & key).fetch1( + "custom_artifact_detection" + ): + # get the dict of artifact params associated with this artifact_params_name + artifact_params = (ArtifactDetectionParameters & key).fetch1( + "artifact_params" + ) + + recording_path = (SpikeSortingRecording & key).fetch1( + "recording_path" + ) + recording_name = SpikeSortingRecording._get_recording_name(key) + recording = si.load_extractor(recording_path) + + job_kwargs = { + "chunk_duration": "10s", + "n_jobs": 4, + "progress_bar": "True", + } + + artifact_removed_valid_times, artifact_times = _get_artifact_times( + recording, **artifact_params, **job_kwargs + ) + + # NOTE: decided not to do this but to just create a single long segment; keep for now + # get artifact times by segment + # if AppendSegmentRecording, get artifact times for each segment + # if isinstance(recording, AppendSegmentRecording): + # artifact_removed_valid_times = [] + # artifact_times = [] + # for rec in recording.recording_list: + # rec_valid_times, rec_artifact_times = _get_artifact_times(rec, **artifact_params) + # for valid_times in rec_valid_times: + # artifact_removed_valid_times.append(valid_times) + # for artifact_times in rec_artifact_times: + # artifact_times.append(artifact_times) + # artifact_removed_valid_times = np.asarray(artifact_removed_valid_times) + # artifact_times = np.asarray(artifact_times) + # else: + # artifact_removed_valid_times, artifact_times = _get_artifact_times(recording, **artifact_params) + + key["artifact_times"] = artifact_times + key["artifact_removed_valid_times"] = artifact_removed_valid_times + + # set up a name for no-artifact times using recording id + key["artifact_removed_interval_list_name"] = ( + recording_name + + "_" + + key["artifact_params_name"] + + "_artifact_removed_valid_times" + ) + + ArtifactRemovedIntervalList.insert1(key, replace=True) + + # also insert into IntervalList + tmp_key = {} + tmp_key["nwb_file_name"] = key["nwb_file_name"] + tmp_key["interval_list_name"] = key[ + 
"artifact_removed_interval_list_name" + ] + tmp_key["valid_times"] = key["artifact_removed_valid_times"] + IntervalList.insert1(tmp_key, replace=True) + + # insert into computed table + self.insert1(key) + + +@schema +class ArtifactRemovedIntervalList(dj.Manual): + definition = """ + # Stores intervals without detected artifacts. + # Note that entries can come from either ArtifactDetection() or alternative artifact removal analyses. + artifact_removed_interval_list_name: varchar(200) + --- + -> ArtifactDetectionSelection + artifact_removed_valid_times: longblob + artifact_times: longblob # np array of artifact intervals + """ + + +def _get_artifact_times( + recording: si.BaseRecording, + zscore_thresh: Union[float, None] = None, + amplitude_thresh: Union[float, None] = None, + proportion_above_thresh: float = 1.0, + removal_window_ms: float = 1.0, + verbose: bool = False, + **job_kwargs, +): + """Detects times during which artifacts do and do not occur. + Artifacts are defined as periods where the absolute value of the recording signal exceeds one + or both specified amplitude or zscore thresholds on the proportion of channels specified, + with the period extended by the removal_window_ms/2 on each side. Z-score and amplitude + threshold values of None are ignored. + + Parameters + ---------- + recording : si.BaseRecording + zscore_thresh : float, optional + Stdev threshold for exclusion, should be >=0, defaults to None + amplitude_thresh : float, optional + Amplitude threshold for exclusion, should be >=0, defaults to None + proportion_above_thresh : float, optional, should be>0 and <=1 + Proportion of electrodes that need to have threshold crossings, defaults to 1 + removal_window_ms : float, optional + Width of the window in milliseconds to mask out per artifact + (window/2 removed on each side of threshold crossing), defaults to 1 ms + + Returns + ------- + artifact_removed_valid_times : np.ndarray + Intervals of valid times where artifacts were not detected, unit: seconds + artifact_intervals : np.ndarray + Intervals in which artifacts are detected (including removal windows), unit: seconds + """ + + if recording.get_num_segments() > 1: + valid_timestamps = np.array([]) + for segment in range(recording.get_num_segments()): + valid_timestamps = np.concatenate( + (valid_timestamps, recording.get_times(segment_index=segment)) + ) + recording = si.concatenate_recordings([recording]) + elif recording.get_num_segments() == 1: + valid_timestamps = recording.get_times(0) + + # if both thresholds are None, we skip artifract detection + if (amplitude_thresh is None) and (zscore_thresh is None): + recording_interval = np.asarray( + [[valid_timestamps[0], valid_timestamps[-1]]] + ) + artifact_times_empty = np.asarray([]) + print( + "Amplitude and zscore thresholds are both None, skipping artifact detection" + ) + return recording_interval, artifact_times_empty + + # verify threshold parameters + ( + amplitude_thresh, + zscore_thresh, + proportion_above_thresh, + ) = _check_artifact_thresholds( + amplitude_thresh, zscore_thresh, proportion_above_thresh + ) + + # detect frames that are above threshold in parallel + n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) + print(f"using {n_jobs} jobs...") + + func = _compute_artifact_chunk + init_func = _init_artifact_worker + if n_jobs == 1: + init_args = ( + recording, + zscore_thresh, + amplitude_thresh, + proportion_above_thresh, + ) + else: + init_args = ( + recording.to_dict(), + zscore_thresh, + amplitude_thresh, + 
proportion_above_thresh, + ) + + executor = ChunkRecordingExecutor( + recording, + func, + init_func, + init_args, + verbose=verbose, + handle_returns=True, + job_name="detect_artifact_frames", + **job_kwargs, + ) + artifact_frames = executor.run() + artifact_frames = np.concatenate(artifact_frames) + + # turn ms to remove total into s to remove from either side of each detected artifact + half_removal_window_s = removal_window_ms / 1000 * 0.5 + + if len(artifact_frames) == 0: + recording_interval = np.asarray( + [[valid_timestamps[0], valid_timestamps[-1]]] + ) + artifact_times_empty = np.asarray([]) + print("No artifacts detected.") + return recording_interval, artifact_times_empty + + # convert indices to intervals + artifact_intervals = interval_from_inds(artifact_frames) + + # convert to seconds and pad with window + artifact_intervals_s = np.zeros( + (len(artifact_intervals), 2), dtype=np.float64 + ) + for interval_idx, interval in enumerate(artifact_intervals): + artifact_intervals_s[interval_idx] = [ + valid_timestamps[interval[0]] - half_removal_window_s, + valid_timestamps[interval[1]] + half_removal_window_s, + ] + # make the artifact intervals disjoint + artifact_intervals_s = reduce(_union_concat, artifact_intervals_s) + + # convert seconds back to indices + artifact_intervals_new = [] + for artifact_interval_s in artifact_intervals_s: + artifact_intervals_new.append( + np.searchsorted(valid_timestamps, artifact_interval_s) + ) + + # compute set difference between intervals (of indices) + try: + # if artifact_intervals_new is a list of lists then len(artifact_intervals_new[0]) is the number of intervals + # otherwise artifact_intervals_new is a list of ints and len(artifact_intervals_new[0]) is not defined + len(artifact_intervals_new[0]) + except TypeError: + # convert to list of lists + artifact_intervals_new = [artifact_intervals_new] + artifact_removed_valid_times_ind = interval_set_difference_inds( + [(0, len(valid_timestamps) - 1)], artifact_intervals_new + ) + + # convert back to seconds + artifact_removed_valid_times = [] + for i in artifact_removed_valid_times_ind: + artifact_removed_valid_times.append( + (valid_timestamps[i[0]], valid_timestamps[i[1]]) + ) + + return artifact_removed_valid_times, artifact_intervals_s + + +def _init_artifact_worker( + recording, + zscore_thresh=None, + amplitude_thresh=None, + proportion_above_thresh=1.0, +): + # create a local dict per worker + worker_ctx = {} + if isinstance(recording, dict): + worker_ctx["recording"] = si.load_extractor(recording) + else: + worker_ctx["recording"] = recording + worker_ctx["zscore_thresh"] = zscore_thresh + worker_ctx["amplitude_thresh"] = amplitude_thresh + worker_ctx["proportion_above_thresh"] = proportion_above_thresh + return worker_ctx + + +def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx): + recording = worker_ctx["recording"] + zscore_thresh = worker_ctx["zscore_thresh"] + amplitude_thresh = worker_ctx["amplitude_thresh"] + proportion_above_thresh = worker_ctx["proportion_above_thresh"] + # compute the number of electrodes that have to be above threshold + nelect_above = np.ceil( + proportion_above_thresh * len(recording.get_channel_ids()) + ) + + traces = recording.get_traces( + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + ) + + # find the artifact occurrences using one or both thresholds, across channels + if (amplitude_thresh is not None) and (zscore_thresh is None): + above_a = np.abs(traces) > amplitude_thresh + 
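+ # above_a is a boolean (frames x channels) array; keep the absolute frame
+ # indices where at least nelect_above channels cross the amplitude threshold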
above_thresh = ( + np.ravel(np.argwhere(np.sum(above_a, axis=1) >= nelect_above)) + + start_frame + ) + elif (amplitude_thresh is None) and (zscore_thresh is not None): + dataz = np.abs(stats.zscore(traces, axis=1)) + above_z = dataz > zscore_thresh + above_thresh = ( + np.ravel(np.argwhere(np.sum(above_z, axis=1) >= nelect_above)) + + start_frame + ) + else: + above_a = np.abs(traces) > amplitude_thresh + dataz = np.abs(stats.zscore(traces, axis=1)) + above_z = dataz > zscore_thresh + above_thresh = ( + np.ravel( + np.argwhere( + np.sum(np.logical_or(above_z, above_a), axis=1) + >= nelect_above + ) + ) + + start_frame + ) + + return above_thresh + + +def _check_artifact_thresholds( + amplitude_thresh, zscore_thresh, proportion_above_thresh +): + """Alerts user to likely unintended parameters. Not an exhaustive verification. + + Parameters + ---------- + zscore_thresh: float + amplitude_thresh: float + proportion_above_thresh: float + + Return + ------ + zscore_thresh: float + amplitude_thresh: float + proportion_above_thresh: float + + Raise + ------ + ValueError: if signal thresholds are negative + """ + # amplitude or zscore thresholds should be negative, as they are applied to an absolute signal + signal_thresholds = [ + t for t in [amplitude_thresh, zscore_thresh] if t is not None + ] + for t in signal_thresholds: + if t < 0: + raise ValueError( + "Amplitude and Z-Score thresholds must be >= 0, or None" + ) + + # proportion_above_threshold should be in [0:1] inclusive + if proportion_above_thresh < 0: + warnings.warn( + "Warning: proportion_above_thresh must be a proportion >0 and <=1." + f" Using proportion_above_thresh = 0.01 instead of {str(proportion_above_thresh)}" + ) + proportion_above_thresh = 0.01 + elif proportion_above_thresh > 1: + warnings.warn( + "Warning: proportion_above_thresh must be a proportion >0 and <=1. " + f"Using proportion_above_thresh = 1 instead of {str(proportion_above_thresh)}" + ) + proportion_above_thresh = 1 + return amplitude_thresh, zscore_thresh, proportion_above_thresh diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py new file mode 100644 index 000000000..b0699ba42 --- /dev/null +++ b/src/spyglass/spikesorting/v1/curation.py @@ -0,0 +1,1097 @@ +import json +import os +import shutil +import time +import uuid +import warnings +from pathlib import Path +from typing import List + +import datajoint as dj +import numpy as np +import spikeinterface as si +import spikeinterface.preprocessing as sip +import spikeinterface.qualitymetrics as sq + +from ..common.common_interval import IntervalList +from ..common.common_nwbfile import AnalysisNwbfile +from ..utils.dj_helper_fn import fetch_nwb +from .merged_sorting_extractor import MergedSortingExtractor +from .spikesorting_recording import SortInterval, SpikeSortingRecording +from .spikesorting_sorting import SpikeSorting + +schema = dj.schema("spikesorting_curation") + +valid_labels = ["reject", "noise", "artifact", "mua", "accept"] + + +def apply_merge_groups_to_sorting( + sorting: si.BaseSorting, merge_groups: List[List[int]] +): + # return a new sorting where the units are merged according to merge_groups + # merge_groups is a list of lists of unit_ids. 
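+ # each inner list holds the unit_ids that are combined into a single unit.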
+ # for example: merge_groups = [[1, 2], [5, 8, 4]]] + + return MergedSortingExtractor( + parent_sorting=sorting, merge_groups=merge_groups + ) + + +@schema +class Curation(dj.Manual): + definition = """ + # Stores each spike sorting; similar to IntervalList + curation_id: int # a number corresponding to the index of this curation + -> SpikeSorting + --- + parent_curation_id=-1: int + curation_labels: blob # a dictionary of labels for the units + merge_groups: blob # a list of merge groups for the units + quality_metrics: blob # a list of quality metrics for the units (if available) + description='': varchar(1000) #optional description for this curated sort + time_of_creation: int # in Unix time, to the nearest second + """ + + @staticmethod + def insert_curation( + sorting_key: dict, + parent_curation_id: int = -1, + labels=None, + merge_groups=None, + metrics=None, + description="", + ): + """Given a SpikeSorting key and the parent_sorting_id (and optional + arguments) insert an entry into Curation. + + + Parameters + ---------- + sorting_key : dict + The key for the original SpikeSorting + parent_curation_id : int, optional + The id of the parent sorting + labels : dict or None, optional + merge_groups : dict or None, optional + metrics : dict or None, optional + Computed metrics for sorting + description : str, optional + text description of this sort + + Returns + ------- + curation_key : dict + + """ + if parent_curation_id == -1: + # check to see if this sorting with a parent of -1 has already been inserted and if so, warn the user + inserted_curation = (Curation & sorting_key).fetch("KEY") + if len(inserted_curation) > 0: + Warning( + "Sorting has already been inserted, returning key to previously" + "inserted curation" + ) + return inserted_curation[0] + + if labels is None: + labels = {} + if merge_groups is None: + merge_groups = [] + if metrics is None: + metrics = {} + + # generate a unique number for this curation + id = (Curation & sorting_key).fetch("curation_id") + if len(id) > 0: + curation_id = max(id) + 1 + else: + curation_id = 0 + + # convert unit_ids in labels to integers for labels from sortingview. + new_labels = {int(unit_id): labels[unit_id] for unit_id in labels} + + sorting_key["curation_id"] = curation_id + sorting_key["parent_curation_id"] = parent_curation_id + sorting_key["description"] = description + sorting_key["curation_labels"] = new_labels + sorting_key["merge_groups"] = merge_groups + sorting_key["quality_metrics"] = metrics + sorting_key["time_of_creation"] = int(time.time()) + + # mike: added skip duplicates + Curation.insert1(sorting_key, skip_duplicates=True) + + # get the primary key for this curation + c_key = Curation.fetch("KEY")[0] + curation_key = {item: sorting_key[item] for item in c_key} + + return curation_key + + @staticmethod + def get_recording(key: dict): + """Returns the recording extractor for the recording related to this curation + + Parameters + ---------- + key : dict + SpikeSortingRecording key + + Returns + ------- + recording_extractor : spike interface recording extractor + + """ + recording_path = (SpikeSortingRecording & key).fetch1("recording_path") + return si.load_extractor(recording_path) + + @staticmethod + def get_curated_sorting(key: dict): + """Returns the sorting extractor related to this curation, + with merges applied. 
+ + Parameters + ---------- + key : dict + Curation key + + Returns + ------- + sorting_extractor: spike interface sorting extractor + + """ + sorting_path = (SpikeSorting & key).fetch1("sorting_path") + sorting = si.load_extractor(sorting_path) + merge_groups = (Curation & key).fetch1("merge_groups") + # TODO: write code to get merged sorting extractor + if len(merge_groups) != 0: + return MergedSortingExtractor( + parent_sorting=sorting, merge_groups=merge_groups + ) + else: + return sorting + + @staticmethod + def save_sorting_nwb( + key, + sorting, + timestamps, + sort_interval_list_name, + sort_interval, + labels=None, + metrics=None, + unit_ids=None, + ): + """Store a sorting in a new AnalysisNwbfile + + Parameters + ---------- + key : dict + key to SpikeSorting table + sorting : si.Sorting + sorting + timestamps : array_like + Time stamps of the sorted recoridng; + used to convert the spike timings from index to real time + sort_interval_list_name : str + name of sort interval + sort_interval : list + interval for start and end of sort + labels : dict, optional + curation labels, by default None + metrics : dict, optional + quality metrics, by default None + unit_ids : list, optional + IDs of units whose spiketrains to save, by default None + + Returns + ------- + analysis_file_name : str + units_object_id : str + + """ + + sort_interval_valid_times = ( + IntervalList & {"interval_list_name": sort_interval_list_name} + ).fetch1("valid_times") + + units = dict() + units_valid_times = dict() + units_sort_interval = dict() + + if unit_ids is None: + unit_ids = sorting.get_unit_ids() + + for unit_id in unit_ids: + spike_times_in_samples = sorting.get_unit_spike_train( + unit_id=unit_id + ) + units[unit_id] = timestamps[spike_times_in_samples] + units_valid_times[unit_id] = sort_interval_valid_times + units_sort_interval[unit_id] = [sort_interval] + + analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) + object_ids = AnalysisNwbfile().add_units( + analysis_file_name, + units, + units_valid_times, + units_sort_interval, + metrics=metrics, + labels=labels, + ) + AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name) + + if object_ids == "": + print( + "Sorting contains no units." + "Created an empty analysis nwb file anyway." 
+ ) + units_object_id = "" + else: + units_object_id = object_ids[0] + + return analysis_file_name, units_object_id + + def fetch_nwb(self, *attrs, **kwargs): + return fetch_nwb( + self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs + ) + + +@schema +class WaveformParameters(dj.Manual): + definition = """ + waveform_params_name: varchar(80) # name of waveform extraction parameters + --- + waveform_params: blob # a dict of waveform extraction parameters + """ + + def insert_default(self): + waveform_params_name = "default_not_whitened" + waveform_params = { + "ms_before": 0.5, + "ms_after": 0.5, + "max_spikes_per_unit": 5000, + "n_jobs": 5, + "total_memory": "5G", + "whiten": False, + } + self.insert1( + [waveform_params_name, waveform_params], skip_duplicates=True + ) + waveform_params_name = "default_whitened" + waveform_params = { + "ms_before": 0.5, + "ms_after": 0.5, + "max_spikes_per_unit": 5000, + "n_jobs": 5, + "total_memory": "5G", + "whiten": True, + } + self.insert1( + [waveform_params_name, waveform_params], skip_duplicates=True + ) + + +@schema +class WaveformSelection(dj.Manual): + definition = """ + -> Curation + -> WaveformParameters + --- + """ + + +@schema +class Waveforms(dj.Computed): + definition = """ + -> WaveformSelection + --- + waveform_extractor_path: varchar(400) + -> AnalysisNwbfile + waveforms_object_id: varchar(40) # Object ID for the waveforms in NWB file + """ + + def make(self, key): + recording = Curation.get_recording(key) + if recording.get_num_segments() > 1: + recording = si.concatenate_recordings([recording]) + + sorting = Curation.get_curated_sorting(key) + + print("Extracting waveforms...") + waveform_params = (WaveformParameters & key).fetch1("waveform_params") + if "whiten" in waveform_params: + if waveform_params.pop("whiten"): + recording = sip.whiten(recording, dtype="float32") + + waveform_extractor_name = self._get_waveform_extractor_name(key) + key["waveform_extractor_path"] = str( + Path(os.environ["SPYGLASS_WAVEFORMS_DIR"]) + / Path(waveform_extractor_name) + ) + if os.path.exists(key["waveform_extractor_path"]): + shutil.rmtree(key["waveform_extractor_path"]) + waveforms = si.extract_waveforms( + recording=recording, + sorting=sorting, + folder=key["waveform_extractor_path"], + **waveform_params, + ) + + key["analysis_file_name"] = AnalysisNwbfile().create( + key["nwb_file_name"] + ) + object_id = AnalysisNwbfile().add_units_waveforms( + key["analysis_file_name"], waveform_extractor=waveforms + ) + key["waveforms_object_id"] = object_id + AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) + + self.insert1(key) + + def load_waveforms(self, key: dict): + """Returns a spikeinterface waveform extractor specified by key + + Parameters + ---------- + key : dict + Could be an entry in Waveforms, or some other key that uniquely defines + an entry in Waveforms + + Returns + ------- + we : spikeinterface.WaveformExtractor + """ + we_path = (self & key).fetch1("waveform_extractor_path") + we = si.WaveformExtractor.load_from_folder(we_path) + return we + + def fetch_nwb(self, key): + # TODO: implement fetching waveforms from NWB + return NotImplementedError + + def _get_waveform_extractor_name(self, key): + waveform_params_name = (WaveformParameters & key).fetch1( + "waveform_params_name" + ) + + return ( + f'{key["nwb_file_name"]}_{str(uuid.uuid4())[0:8]}_' + f'{key["curation_id"]}_{waveform_params_name}_waveforms' + ) + + +@schema +class MetricParameters(dj.Manual): + definition = """ + # Parameters for computing 
quality metrics of sorted units + metric_params_name: varchar(200) + --- + metric_params: blob + """ + metric_default_params = { + "snr": { + "peak_sign": "neg", + "random_chunk_kwargs_dict": { + "num_chunks_per_segment": 20, + "chunk_size": 10000, + "seed": 0, + }, + }, + "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0.0}, + "nn_isolation": { + "max_spikes": 1000, + "min_spikes": 10, + "n_neighbors": 5, + "n_components": 7, + "radius_um": 100, + "seed": 0, + }, + "nn_noise_overlap": { + "max_spikes": 1000, + "min_spikes": 10, + "n_neighbors": 5, + "n_components": 7, + "radius_um": 100, + "seed": 0, + }, + "peak_channel": {"peak_sign": "neg"}, + "num_spikes": {}, + } + # Example of peak_offset parameters 'peak_offset': {'peak_sign': 'neg'} + available_metrics = [ + "snr", + "isi_violation", + "nn_isolation", + "nn_noise_overlap", + "peak_offset", + "peak_channel", + "num_spikes", + ] + + def get_metric_default_params(self, metric: str): + "Returns default params for the given metric" + return self.metric_default_params(metric) + + def insert_default(self): + self.insert1( + ["franklab_default3", self.metric_default_params], + skip_duplicates=True, + ) + + def get_available_metrics(self): + for metric in _metric_name_to_func: + if metric in self.available_metrics: + metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0] + metric_string = ("{metric_name} : {metric_doc}").format( + metric_name=metric, metric_doc=metric_doc + ) + print(metric_string + "\n") + + # TODO + def _validate_metrics_list(self, key): + """Checks whether a row to be inserted contains only the available metrics""" + # get available metrics list + # get metric list from key + # compare + return NotImplementedError + + +@schema +class MetricSelection(dj.Manual): + definition = """ + -> Waveforms + -> MetricParameters + --- + """ + + def insert1(self, key, **kwargs): + waveform_params = (WaveformParameters & key).fetch1("waveform_params") + metric_params = (MetricParameters & key).fetch1("metric_params") + if "peak_offset" in metric_params: + if waveform_params["whiten"]: + warnings.warn( + "Calculating 'peak_offset' metric on " + "whitened waveforms may result in slight " + "discrepancies" + ) + if "peak_channel" in metric_params: + if waveform_params["whiten"]: + Warning( + "Calculating 'peak_channel' metric on " + "whitened waveforms may result in slight " + "discrepancies" + ) + super().insert1(key, **kwargs) + + +@schema +class QualityMetrics(dj.Computed): + definition = """ + -> MetricSelection + --- + quality_metrics_path: varchar(500) + -> AnalysisNwbfile + object_id: varchar(40) # Object ID for the metrics in NWB file + """ + + def make(self, key): + waveform_extractor = Waveforms().load_waveforms(key) + qm = {} + params = (MetricParameters & key).fetch1("metric_params") + for metric_name, metric_params in params.items(): + metric = self._compute_metric( + waveform_extractor, metric_name, **metric_params + ) + qm[metric_name] = metric + qm_name = self._get_quality_metrics_name(key) + key["quality_metrics_path"] = str( + Path(os.environ["SPYGLASS_WAVEFORMS_DIR"]) / Path(qm_name + ".json") + ) + # save metrics dict as json + print(f"Computed all metrics: {qm}") + self._dump_to_json(qm, key["quality_metrics_path"]) + + key["analysis_file_name"] = AnalysisNwbfile().create( + key["nwb_file_name"] + ) + key["object_id"] = AnalysisNwbfile().add_units_metrics( + key["analysis_file_name"], metrics=qm + ) + AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) + + self.insert1(key) + + 
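+ # The JSON written above maps metric name -> {unit_id: value}, e.g.
+ # (hypothetical values): {"snr": {"1": 5.2, "2": 3.9}, "isi_violation": {"1": 0.01, "2": 0.0}}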
def _get_quality_metrics_name(self, key): + wf_name = Waveforms()._get_waveform_extractor_name(key) + qm_name = wf_name + "_qm" + return qm_name + + def _compute_metric(self, waveform_extractor, metric_name, **metric_params): + peak_sign_metrics = ["snr", "peak_offset", "peak_channel"] + metric_func = _metric_name_to_func[metric_name] + # TODO clean up code below + if metric_name == "isi_violation": + metric = metric_func(waveform_extractor, **metric_params) + elif metric_name in peak_sign_metrics: + if "peak_sign" in metric_params: + metric = metric_func( + waveform_extractor, + peak_sign=metric_params.pop("peak_sign"), + **metric_params, + ) + else: + raise Exception( + f"{peak_sign_metrics} metrics require peak_sign", + f"to be defined in the metric parameters", + ) + else: + metric = {} + for unit_id in waveform_extractor.sorting.get_unit_ids(): + metric[str(unit_id)] = metric_func( + waveform_extractor, this_unit_id=unit_id, **metric_params + ) + # nn_isolation returns tuple with isolation and unit number. We only want isolation. + if metric_name == "nn_isolation": + metric[str(unit_id)] = metric[str(unit_id)][0] + return metric + + def _dump_to_json(self, qm_dict, save_path): + new_qm = {} + for key, value in qm_dict.items(): + m = {} + for unit_id, metric_val in value.items(): + m[str(unit_id)] = np.float64(metric_val) + new_qm[str(key)] = m + with open(save_path, "w", encoding="utf-8") as f: + json.dump(new_qm, f, ensure_ascii=False, indent=4) + + +def _compute_isi_violation_fractions(waveform_extractor, **metric_params): + """Computes the per unit fraction of interspike interval violations to total spikes.""" + isi_threshold_ms = metric_params["isi_threshold_ms"] + min_isi_ms = metric_params["min_isi_ms"] + + # Extract the total number of spikes that violated the isi_threshold for each unit + isi_violation_counts = sq.compute_isi_violations( + waveform_extractor, + isi_threshold_ms=isi_threshold_ms, + min_isi_ms=min_isi_ms, + ).isi_violations_count + + # Extract the total number of spikes from each unit. 
The number of ISIs is one less than this + num_spikes = sq.compute_num_spikes(waveform_extractor) + + # Calculate the fraction of ISIs that are violations + isi_viol_frac_metric = { + str(unit_id): isi_violation_counts[unit_id] / (num_spikes[unit_id] - 1) + for unit_id in waveform_extractor.sorting.get_unit_ids() + } + return isi_viol_frac_metric + + +def _get_peak_offset( + waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params +): + """Computes the shift of the waveform peak from center of window.""" + if "peak_sign" in metric_params: + del metric_params["peak_sign"] + peak_offset_inds = ( + si.postprocessing.get_template_extremum_channel_peak_shift( + waveform_extractor=waveform_extractor, + peak_sign=peak_sign, + **metric_params, + ) + ) + peak_offset = {key: int(abs(val)) for key, val in peak_offset_inds.items()} + return peak_offset + + +def _get_peak_channel( + waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params +): + """Computes the electrode_id of the channel with the extremum peak for each unit.""" + if "peak_sign" in metric_params: + del metric_params["peak_sign"] + peak_channel_dict = si.postprocessing.get_template_extremum_channel( + waveform_extractor=waveform_extractor, + peak_sign=peak_sign, + **metric_params, + ) + peak_channel = {key: int(val) for key, val in peak_channel_dict.items()} + return peak_channel + + +def _get_num_spikes( + waveform_extractor: si.WaveformExtractor, this_unit_id: int +): + """Computes the number of spikes for each unit.""" + all_spikes = sq.compute_num_spikes(waveform_extractor) + cluster_spikes = all_spikes[this_unit_id] + return cluster_spikes + + +_metric_name_to_func = { + "snr": sq.compute_snrs, + "isi_violation": _compute_isi_violation_fractions, + "nn_isolation": sq.nearest_neighbors_isolation, + "nn_noise_overlap": sq.nearest_neighbors_noise_overlap, + "peak_offset": _get_peak_offset, + "peak_channel": _get_peak_channel, + "num_spikes": _get_num_spikes, +} + + +@schema +class AutomaticCurationParameters(dj.Manual): + definition = """ + auto_curation_params_name: varchar(200) # name of this parameter set + --- + merge_params: blob # dictionary of params to merge units + label_params: blob # dictionary params to label units + """ + + def insert1(self, key, **kwargs): + # validate the labels and then insert + # TODO: add validation for merge_params + for metric in key["label_params"]: + if metric not in _metric_name_to_func: + raise Exception(f"{metric} not in list of available metrics") + comparison_list = key["label_params"][metric] + if comparison_list[0] not in _comparison_to_function: + raise Exception( + f'{metric}: "{comparison_list[0]}" ' + f"not in list of available comparisons" + ) + if not isinstance(comparison_list[1], (int, float)): + raise Exception( + f"{metric}: {comparison_list[1]} is of type " + f"{type(comparison_list[1])} and not a number" + ) + for label in comparison_list[2]: + if label not in valid_labels: + raise Exception( + f'{metric}: "{label}" ' + f"not in list of valid labels: {valid_labels}" + ) + super().insert1(key, **kwargs) + + def insert_default(self): + # label_params parsing: Each key is the name of a metric, + # the contents are a three value list with the comparison, a value, + # and a list of labels to apply if the comparison is true + default_params = { + "auto_curation_params_name": "default", + "merge_params": {}, + "label_params": { + "nn_noise_overlap": [">", 0.1, ["noise", "reject"]] + }, + } + self.insert1(default_params, skip_duplicates=True) + + # 
Second default parameter set for not applying any labels, + # or merges, but adding metrics + no_label_params = { + "auto_curation_params_name": "none", + "merge_params": {}, + "label_params": {}, + } + self.insert1(no_label_params, skip_duplicates=True) + + +@schema +class AutomaticCurationSelection(dj.Manual): + definition = """ + -> QualityMetrics + -> AutomaticCurationParameters + """ + + +_comparison_to_function = { + "<": np.less, + "<=": np.less_equal, + ">": np.greater, + ">=": np.greater_equal, + "==": np.equal, +} + + +@schema +class AutomaticCuration(dj.Computed): + definition = """ + -> AutomaticCurationSelection + --- + auto_curation_key: blob # the key to the curation inserted by make + """ + + def make(self, key): + metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path") + with open(metrics_path) as f: + quality_metrics = json.load(f) + + # get the curation information and the curated sorting + parent_curation = (Curation & key).fetch(as_dict=True)[0] + parent_merge_groups = parent_curation["merge_groups"] + parent_labels = parent_curation["curation_labels"] + parent_curation_id = parent_curation["curation_id"] + parent_sorting = Curation.get_curated_sorting(key) + + merge_params = (AutomaticCurationParameters & key).fetch1( + "merge_params" + ) + merge_groups, units_merged = self.get_merge_groups( + parent_sorting, parent_merge_groups, quality_metrics, merge_params + ) + + label_params = (AutomaticCurationParameters & key).fetch1( + "label_params" + ) + labels = self.get_labels( + parent_sorting, parent_labels, quality_metrics, label_params + ) + + # keep the quality metrics only if no merging occurred. + metrics = quality_metrics if not units_merged else None + + # insert this sorting into the CuratedSpikeSorting Table + # first remove keys that aren't part of the Sorting (the primary key of curation) + c_key = (SpikeSorting & key).fetch("KEY")[0] + curation_key = {item: key[item] for item in key if item in c_key} + key["auto_curation_key"] = Curation.insert_curation( + curation_key, + parent_curation_id=parent_curation_id, + labels=labels, + merge_groups=merge_groups, + metrics=metrics, + description="auto curated", + ) + + self.insert1(key) + + @staticmethod + def get_merge_groups( + sorting, parent_merge_groups, quality_metrics, merge_params + ): + """Identifies units to be merged based on the quality_metrics and + merge parameters and returns an updated list of merges for the curation. + + Parameters + --------- + sorting : spikeinterface.sorting + parent_merge_groups : list + Information about previous merges + quality_metrics : list + merge_params : dict + + Returns + ------- + merge_groups : list of lists + merge_occurred : bool + + """ + + # overview: + # 1. Use quality metrics to determine merge groups for units + # 2. Combine merge groups with current merge groups to produce union of merges + + if not merge_params: + return parent_merge_groups, False + else: + # TODO: use the metrics to identify clusters that should be merged + # new_merges should then reflect those merges and the line below should be deleted. + new_merges = [] + # append these merges to the parent merge_groups + for new_merge in new_merges: + # check to see if the first cluster listed is in a current merge group + for previous_merge in parent_merge_groups: + if new_merge[0] == previous_merge[0]: + # add the additional units in new_merge to the identified merge group. 
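+ # e.g. (hypothetical) previous_merge [1, 2] plus new_merge [1, 5] becomes [1, 2, 5]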
+ previous_merge.extend(new_merge[1:]) + previous_merge.sort() + break + else: + # append this merge group to the list if no previous merge + parent_merge_groups.append(new_merge) + return parent_merge_groups.sort(), True + + @staticmethod + def get_labels(sorting, parent_labels, quality_metrics, label_params): + """Returns a dictionary of labels using quality_metrics and label + parameters. + + Parameters + --------- + sorting : spikeinterface.sorting + parent_labels : list + Information about previous merges + quality_metrics : list + label_params : dict + + Returns + ------- + parent_labels : list + + """ + # overview: + # 1. Use quality metrics to determine labels for units + # 2. Append labels to current labels, checking for inconsistencies + if not label_params: + return parent_labels + else: + for metric in label_params: + if metric not in quality_metrics: + Warning(f"{metric} not found in quality metrics; skipping") + else: + compare = _comparison_to_function[label_params[metric][0]] + + for unit_id in quality_metrics[metric].keys(): + # compare the quality metric to the threshold with the specified operator + # note that label_params[metric] is a three element list with a comparison operator as a string, + # the threshold value, and a list of labels to be applied if the comparison is true + if compare( + quality_metrics[metric][unit_id], + label_params[metric][1], + ): + if unit_id not in parent_labels: + parent_labels[unit_id] = label_params[metric][2] + # check if the label is already there, and if not, add it + elif ( + label_params[metric][2] + not in parent_labels[unit_id] + ): + parent_labels[unit_id].extend( + label_params[metric][2] + ) + return parent_labels + + +@schema +class CuratedSpikeSortingSelection(dj.Manual): + definition = """ + -> Curation + """ + + +@schema +class CuratedSpikeSorting(dj.Computed): + definition = """ + -> CuratedSpikeSortingSelection + --- + -> AnalysisNwbfile + units_object_id: varchar(40) + """ + + class Unit(dj.Part): + definition = """ + # Table for holding sorted units + -> CuratedSpikeSorting + unit_id: int # ID for each unit + --- + label='': varchar(200) # optional set of labels for each unit + nn_noise_overlap=-1: float # noise overlap metric for each unit + nn_isolation=-1: float # isolation score metric for each unit + isi_violation=-1: float # ISI violation score for each unit + snr=0: float # SNR for each unit + firing_rate=-1: float # firing rate + num_spikes=-1: int # total number of spikes + peak_channel=null: int # channel of maximum amplitude for each unit + """ + + def make(self, key): + unit_labels_to_remove = ["reject"] + # check that the Curation has metrics + metrics = (Curation & key).fetch1("quality_metrics") + if metrics == {}: + Warning( + f"Metrics for Curation {key} should normally be calculated before insertion here" + ) + + sorting = Curation.get_curated_sorting(key) + unit_ids = sorting.get_unit_ids() + # Get the labels for the units, add only those units that do not have 'reject' or 'noise' labels + unit_labels = (Curation & key).fetch1("curation_labels") + accepted_units = [] + for unit_id in unit_ids: + if unit_id in unit_labels: + if ( + len(set(unit_labels_to_remove) & set(unit_labels[unit_id])) + == 0 + ): + accepted_units.append(unit_id) + else: + accepted_units.append(unit_id) + + # get the labels for the accepted units + labels = {} + for unit_id in accepted_units: + if unit_id in unit_labels: + labels[unit_id] = ",".join(unit_labels[unit_id]) + + # convert unit_ids in metrics to integers, including only 
accepted units. + # TODO: convert to int this somewhere else + final_metrics = {} + for metric in metrics: + final_metrics[metric] = { + int(unit_id): metrics[metric][unit_id] + for unit_id in metrics[metric] + if int(unit_id) in accepted_units + } + + print(f"Found {len(accepted_units)} accepted units") + + # get the sorting and save it in the NWB file + sorting = Curation.get_curated_sorting(key) + recording = Curation.get_recording(key) + + # get the sort_interval and sorting interval list + sort_interval_name = (SpikeSortingRecording & key).fetch1( + "sort_interval_name" + ) + sort_interval = (SortInterval & key).fetch1("sort_interval") + sort_interval_list_name = (SpikeSorting & key).fetch1( + "artifact_removed_interval_list_name" + ) + + timestamps = SpikeSortingRecording._get_recording_timestamps(recording) + + ( + key["analysis_file_name"], + key["units_object_id"], + ) = Curation().save_sorting_nwb( + key, + sorting, + timestamps, + sort_interval_list_name, + sort_interval, + metrics=final_metrics, + unit_ids=accepted_units, + labels=labels, + ) + self.insert1(key) + + # now add the units + # Remove the non primary key entries. + del key["units_object_id"] + del key["analysis_file_name"] + + metric_fields = self.metrics_fields() + for unit_id in accepted_units: + key["unit_id"] = unit_id + if unit_id in labels: + key["label"] = labels[unit_id] + for field in metric_fields: + if field in final_metrics: + key[field] = final_metrics[field][unit_id] + else: + Warning( + f"No metric named {field} in computed unit quality metrics; skipping" + ) + CuratedSpikeSorting.Unit.insert1(key) + + def metrics_fields(self): + """Returns a list of the metrics that are currently in the Units table.""" + unit_info = self.Unit().fetch(limit=1, format="frame") + unit_fields = [column for column in unit_info.columns] + unit_fields.remove("label") + return unit_fields + + def fetch_nwb(self, *attrs, **kwargs): + return fetch_nwb( + self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs + ) + + +@schema +class UnitInclusionParameters(dj.Manual): + definition = """ + unit_inclusion_param_name: varchar(80) # the name of the list of thresholds for unit inclusion + --- + inclusion_param_dict: blob # the dictionary of inclusion / exclusion parameters + """ + + def insert1(self, key, **kwargs): + # check to see that the dictionary fits the specifications + # The inclusion parameter dict has the following form: + # param_dict['metric_name'] = (operator, value) + # where operator is '<', '>', <=', '>=', or '==' and value is the comparison (float) value to be used () + # param_dict['exclude_labels'] = [list of labels to exclude] + pdict = key["inclusion_param_dict"] + metrics_list = CuratedSpikeSorting().metrics_fields() + + for k in pdict: + if k not in metrics_list and k != "exclude_labels": + raise Exception( + f"key {k} is not a valid element of the inclusion_param_dict" + ) + if k in metrics_list: + if pdict[k][0] not in _comparison_to_function: + raise Exception( + f"operator {pdict[k][0]} for metric {k} is not in the valid operators list: {_comparison_to_function.keys()}" + ) + if k == "exclude_labels": + for label in pdict[k]: + if label not in valid_labels: + raise Exception( + f"exclude label {label} is not in the valid_labels list: {valid_labels}" + ) + super().insert1(key, **kwargs) + + def get_included_units( + self, curated_sorting_key, unit_inclusion_param_name + ): + """given a reference to a set of curated sorting units and the name of a unit inclusion parameter list, returns + + 
Parameters + ---------- + curated_sorting_key : dict + key to select a set of curated sorting + unit_inclusion_param_name : str + name of a unit inclusion parameter entry + + Returns + ------unit key + dict + key to select all of the included units + """ + curated_sortings = (CuratedSpikeSorting() & curated_sorting_key).fetch() + inc_param_dict = ( + UnitInclusionParameters + & {"unit_inclusion_param_name": unit_inclusion_param_name} + ).fetch1("inclusion_param_dict") + units = (CuratedSpikeSorting().Unit() & curated_sortings).fetch() + units_key = (CuratedSpikeSorting().Unit() & curated_sortings).fetch( + "KEY" + ) + # get a list of the metrics in the units table + metrics_list = CuratedSpikeSorting().metrics_fields() + # get the list of labels to exclude if there is one + if "exclude_labels" in inc_param_dict: + exclude_labels = inc_param_dict["exclude_labels"] + del inc_param_dict["exclude_labels"] + else: + exclude_labels = [] + + # create a list of the units to kepp. + keep = np.asarray([True] * len(units)) + for metric in inc_param_dict: + # for all units, go through each metric, compare it to the value specified, and update the list to be kept + keep = np.logical_and( + keep, + _comparison_to_function[inc_param_dict[metric][0]]( + units[metric], inc_param_dict[metric][1] + ), + ) + + # now exclude by label if it is specified + if len(exclude_labels): + included_units = [] + for unit_ind in np.ravel(np.argwhere(keep)): + labels = units[unit_ind]["label"].split(",") + exclude = False + for label in labels: + if label in exclude_labels: + keep[unit_ind] = False + break + # return units that passed all of the tests + # TODO: Make this more efficient + return {i: units_key[i] for i in np.ravel(np.argwhere(keep))} diff --git a/src/spyglass/spikesorting/v1/figurl_curation.py b/src/spyglass/spikesorting/v1/figurl_curation.py new file mode 100644 index 000000000..2280f0c2b --- /dev/null +++ b/src/spyglass/spikesorting/v1/figurl_curation.py @@ -0,0 +1,226 @@ +import datajoint as dj + +from typing import Any, Union, List, Dict + +from .spikesorting_curation import Curation +from .spikesorting_recording import SpikeSortingRecording +from .spikesorting_sorting import SpikeSorting + +import spikeinterface as si + +from sortingview.SpikeSortingView import SpikeSortingView +import kachery_cloud as kcl +import sortingview.views as vv + +schema = dj.schema("spikesorting_curation_figurl") + +# A curation figURL is a link to a visualization of a curation. +# Optionally you can specify a new_curation_uri which will be +# the location of the new manually-edited curation. The +# new_curation_uri should be a github uri of the form +# gh://user/repo/branch/path/to/curation.json +# and ideally the path should be determined by the primary key +# of the curation. The new_curation_uri can also be blank if no +# further manual curation is planned. 
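+# For example, a hypothetical uri following that convention:
+# gh://LorenFrankLab/sorting-curations/main/user/my_session_my_sort_interval/4/curation.json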
+ + +@schema +class CurationFigurlSelection(dj.Manual): + definition = """ + -> Curation + --- + new_curation_uri: varchar(2000) + """ + + +@schema +class CurationFigurl(dj.Computed): + definition = """ + -> CurationFigurlSelection + --- + url: varchar(2000) + initial_curation_uri: varchar(2000) + new_curation_uri: varchar(2000) + """ + + def make(self, key: dict): + """Create a Curation Figurl + Parameters + ---------- + key : dict + primary key of an entry from CurationFigurlSelection table + """ + + # get new_curation_uri from selection table + new_curation_uri = (CurationFigurlSelection & key).fetch1( + "new_curation_uri" + ) + + # fetch + recording_path = (SpikeSortingRecording & key).fetch1("recording_path") + sorting_path = (SpikeSorting & key).fetch1("sorting_path") + recording_label = SpikeSortingRecording._get_recording_name(key) + sorting_label = SpikeSorting._get_sorting_name(key) + unit_metrics = _reformat_metrics( + (Curation & key).fetch1("quality_metrics") + ) + initial_labels = (Curation & key).fetch1("curation_labels") + initial_merge_groups = (Curation & key).fetch1("merge_groups") + + # new_curation_uri = key["new_curation_uri"] + + # Create the initial curation and store it in kachery + for k, v in initial_labels.items(): + new_list = [] + for item in v: + if item not in new_list: + new_list.append(item) + initial_labels[k] = new_list + initial_curation = { + "labelsByUnit": initial_labels, + "mergeGroups": initial_merge_groups, + } + initial_curation_uri = kcl.store_json(initial_curation) + + # Get the recording/sorting extractors + R = si.load_extractor(recording_path) + if R.get_num_segments() > 1: + R = si.concatenate_recordings([R]) + S = si.load_extractor(sorting_path) + + # Generate the figURL + url = _generate_the_figurl( + R=R, + S=S, + initial_curation_uri=initial_curation_uri, + new_curation_uri=new_curation_uri, + recording_label=recording_label, + sorting_label=sorting_label, + unit_metrics=unit_metrics, + ) + + # insert + key["url"] = url + key["initial_curation_uri"] = initial_curation_uri + key["new_curation_uri"] = new_curation_uri + self.insert1(key) + + +def _generate_the_figurl( + *, + R: si.BaseRecording, + S: si.BaseSorting, + unit_metrics: Union[List[Any], None] = None, + initial_curation_uri: str, + recording_label: str, + sorting_label: str, + new_curation_uri: str, +): + print("Preparing spikesortingview data") + X = SpikeSortingView.create( + recording=R, + sorting=S, + segment_duration_sec=60 * 20, + snippet_len=(20, 20), + max_num_snippets_per_segment=100, + channel_neighborhood_size=7, + ) + # create a fake unit similarity matrix (for future reference) + # similarity_scores = [] + # for u1 in X.unit_ids: + # for u2 in X.unit_ids: + # similarity_scores.append( + # vv.UnitSimilarityScore( + # unit_id1=u1, + # unit_id2=u2, + # similarity=similarity_matrix[(X.unit_ids==u1),(X.unit_ids==u2)] + # ) + # ) + # Create the similarity matrix view + # unit_similarity_matrix_view = vv.UnitSimilarityMatrix( + # unit_ids=X.unit_ids, + # similarity_scores=similarity_scores + # ) + + # Assemble the views in a layout + # You can replace this with other layouts + raster_plot_subsample_max_firing_rate = 50 + spike_amplitudes_subsample_max_firing_rate = 50 + view = vv.MountainLayout( + items=[ + vv.MountainLayoutItem( + label="Summary", view=X.sorting_summary_view() + ), + vv.MountainLayoutItem( + label="Units table", + view=X.units_table_view( + unit_ids=X.unit_ids, unit_metrics=unit_metrics + ), + ), + vv.MountainLayoutItem( + label="Raster plot", + 
view=X.raster_plot_view( + unit_ids=X.unit_ids, + _subsample_max_firing_rate=raster_plot_subsample_max_firing_rate, + ), + ), + vv.MountainLayoutItem( + label="Spike amplitudes", + view=X.spike_amplitudes_view( + unit_ids=X.unit_ids, + _subsample_max_firing_rate=spike_amplitudes_subsample_max_firing_rate, + ), + ), + vv.MountainLayoutItem( + label="Autocorrelograms", + view=X.autocorrelograms_view(unit_ids=X.unit_ids), + ), + vv.MountainLayoutItem( + label="Cross correlograms", + view=X.cross_correlograms_view(unit_ids=X.unit_ids), + ), + vv.MountainLayoutItem( + label="Avg waveforms", + view=X.average_waveforms_view(unit_ids=X.unit_ids), + ), + vv.MountainLayoutItem( + label="Electrode geometry", view=X.electrode_geometry_view() + ), + # vv.MountainLayoutItem( + # label='Unit similarity matrix', + # view=unit_similarity_matrix_view + # ), + vv.MountainLayoutItem( + label="Curation", view=vv.SortingCuration2(), is_control=True + ), + ] + ) + url_state = ( + { + "initialSortingCuration": initial_curation_uri, + "sortingCuration": new_curation_uri, + } + if new_curation_uri + else {"sortingCuration": initial_curation_uri} + ) + label = f"{recording_label} {sorting_label}" + url = view.url(label=label, state=url_state) + return url + + +def _reformat_metrics(metrics: Dict[str, Dict[str, float]]) -> List[Dict]: + for metric_name in metrics: + metrics[metric_name] = { + str(unit_id): metric_value + for unit_id, metric_value in metrics[metric_name].items() + } + new_external_metrics = [ + { + "name": metric_name, + "label": metric_name, + "tooltip": metric_name, + "data": metric, + } + for metric_name, metric in metrics.items() + ] + return new_external_metrics diff --git a/src/spyglass/spikesorting/v1/populator.py b/src/spyglass/spikesorting/v1/populator.py new file mode 100644 index 000000000..2c23edcfc --- /dev/null +++ b/src/spyglass/spikesorting/v1/populator.py @@ -0,0 +1,300 @@ +import datajoint as dj + +from ..common import ElectrodeGroup, IntervalList +from .curation_figurl import CurationFigurl, CurationFigurlSelection +from .spikesorting_artifact import ( + ArtifactDetection, + ArtifactDetectionSelection, + ArtifactRemovedIntervalList, +) +from .spikesorting_curation import ( + AutomaticCuration, + AutomaticCurationSelection, + CuratedSpikeSorting, + CuratedSpikeSortingSelection, + Curation, + MetricSelection, + QualityMetrics, + Waveforms, + WaveformSelection, +) +from .spikesorting_recording import ( + SortGroup, + SortInterval, + SpikeSortingRecording, + SpikeSortingRecordingSelection, +) +from .spikesorting_sorting import SpikeSorting, SpikeSortingSelection + +schema = dj.schema("spikesorting_sorting") + + +@schema +class SpikeSortingPipelineParameters(dj.Manual): + definition = """ + pipeline_parameters_name: varchar(200) + --- + preproc_params_name: varchar(200) + artifact_parameters: varchar(200) + sorter: varchar(200) + sorter_params_name: varchar(200) + waveform_params_name: varchar(200) + metric_params_name: varchar(200) + auto_curation_params_name: varchar(200) + """ + + +def spikesorting_pipeline_populator( + nwb_file_name: str, + team_name: str, + fig_url_repo: str = None, + interval_list_name: str = None, + sort_interval_name: str = None, + pipeline_parameters_name: str = None, + probe_restriction: dict = {}, + artifact_parameters: str = "ampl_2000_prop_75", + preproc_params_name: str = "franklab_tetrode_hippocampus", + sorter: str = "mountainsort4", + sorter_params_name: str = "franklab_tetrode_hippocampus_30KHz_tmp", + waveform_params_name: str = "default_whitened", + 
metric_params_name: str = "peak_offest_num_spikes_2", + auto_curation_params_name: str = "mike_noise_03_offset_2_isi_0025_mua", +): + """Automatically populate the spike sorting pipeline for a given epoch + + Parameters + ---------- + nwb_file_name : str + Session ID + team_name : str + Which team to assign the spike sorting to + fig_url_repo : str, optional + Where to store the curation figurl json files (e.g., + 'gh://LorenFrankLab/sorting-curations/main/user/'). Default None to + skip figurl + interval_list_name : str, + if sort_interval_name not provided, will create a sort interval for the + given interval with the same name + sort_interval_name : str, default None + if provided, will use the given sort interval, requires making this + interval yourself + pipeline_parameters_name : str, optional + If provided, will lookup pipeline parameters from the + SpikeSortingPipelineParameters table, supersedes other values provided, + by default None + probe_restriction : dict, optional + Restricts analysis to sort groups with matching keys. Can use keys from + the SortGroup and ElectrodeGroup Tables (e.g. electrode_group_name, + probe_id, target_hemisphere), by default {} + artifact_parameters : str, optional + parameter set for artifact detection, by default "ampl_2000_prop_75" + preproc_params_name : str, optional + parameter set for spikesorting recording, by default + "franklab_tetrode_hippocampus" + sorter : str, optional + which spikesorting algorithm to use, by default "mountainsort4" + sorter_params_name : str, optional + parameters for the spike sorting algorithm, by default + "franklab_tetrode_hippocampus_30KHz_tmp" + waveform_params_name : str, optional + Parameters for spike waveform extraction. If empty string, will skip + automatic curation steps, by default "default_whitened" + metric_params_name : str, optional + Parameters defining which QualityMetrics to calculate and how. If empty + string, will skip automatic curation steps, by default + "peak_offest_num_spikes_2" + auto_curation_params_name : str, optional + Thresholds applied to Quality metrics for automatic unit curation. If + empty string, will skip automatic curation steps, by default + "mike_noise_03_offset_2_isi_0025_mua" + """ + nwbf_dict = dict(nwb_file_name=nwb_file_name) + # Define pipeline parameters + if pipeline_parameters_name is not None: + print(f"Using pipeline parameters {pipeline_parameters_name}") + ( + artifact_parameters, + preproc_params_name, + sorter, + sorter_params_name, + waveform_params_name, + metric_params_name, + auto_curation_params_name, + ) = ( + SpikeSortingPipelineParameters + & {"pipeline_parameters_name": pipeline_parameters_name} + ).fetch1( + "artifact_parameters", + "preproc_params_name", + "sorter", + "sorter_params_name", + "waveform_params_name", + "metric_params_name", + "auto_curation_params_name", + ) + + # make sort groups only if not currently available + # don't overwrite existing ones! 
+ if not SortGroup() & nwbf_dict: + print("Generating sort groups") + SortGroup().set_group_by_shank(nwb_file_name) + + # Define sort interval + interval_dict = dict(**nwbf_dict, interval_list_name=interval_list_name) + + if sort_interval_name is not None: + print(f"Using sort interval {sort_interval_name}") + if not ( + SortInterval + & nwbf_dict + & {"sort_interval_name": sort_interval_name} + ): + raise KeyError(f"Sort interval {sort_interval_name} not found") + else: + print(f"Generating sort interval from {interval_list_name}") + interval_list = (IntervalList & interval_dict).fetch1("valid_times")[0] + + sort_interval_name = interval_list_name + sort_interval = interval_list + + SortInterval.insert1( + { + **nwbf_dict, + "sort_interval_name": sort_interval_name, + "sort_interval": sort_interval, + }, + skip_duplicates=True, + ) + + sort_dict = dict(**nwbf_dict, sort_interval_name=sort_interval_name) + + # find desired sort group(s) for these settings + sort_group_id_list = ( + (SortGroup.SortGroupElectrode * ElectrodeGroup) + & nwbf_dict + & probe_restriction + ).fetch("sort_group_id") + + # make spike sorting recording + print("Generating spike sorting recording") + for sort_group_id in sort_group_id_list: + ssr_key = dict( + **sort_dict, + sort_group_id=sort_group_id, # See SortGroup + preproc_params_name=preproc_params_name, # See preproc_params + interval_list_name=interval_list_name, + team_name=team_name, + ) + SpikeSortingRecordingSelection.insert1(ssr_key, skip_duplicates=True) + + SpikeSortingRecording.populate(interval_dict) + + # Artifact detection + print("Running artifact detection") + artifact_keys = [ + {**k, "artifact_params_name": artifact_parameters} + for k in (SpikeSortingRecordingSelection() & interval_dict).fetch("KEY") + ] + ArtifactDetectionSelection().insert(artifact_keys, skip_duplicates=True) + ArtifactDetection.populate(interval_dict) + + # Spike sorting + print("Running spike sorting") + for artifact_key in artifact_keys: + ss_key = dict( + **(ArtifactDetection & artifact_key).fetch1("KEY"), + **(ArtifactRemovedIntervalList() & artifact_key).fetch1("KEY"), + sorter=sorter, + sorter_params_name=sorter_params_name, + ) + ss_key.pop("artifact_params_name") + SpikeSortingSelection.insert1(ss_key, skip_duplicates=True) + SpikeSorting.populate(sort_dict) + + # initial curation + print("Beginning curation") + for sorting_key in (SpikeSorting() & sort_dict).fetch("KEY"): + Curation.insert_curation(sorting_key) + + # Calculate quality metrics and perform automatic curation if specified + if ( + len(waveform_params_name) > 0 + and len(metric_params_name) > 0 + and len(auto_curation_params_name) > 0 + ): + # Extract waveforms + print("Extracting waveforms") + curation_keys = [ + {**k, "waveform_params_name": waveform_params_name} + for k in (Curation() & sort_dict).fetch("KEY") + ] + WaveformSelection.insert(curation_keys, skip_duplicates=True) + Waveforms.populate(sort_dict) + + # Quality Metrics + print("Calculating quality metrics") + waveform_keys = [ + {**k, "metric_params_name": metric_params_name} + for k in (Waveforms() & sort_dict).fetch("KEY") + ] + MetricSelection.insert(waveform_keys, skip_duplicates=True) + QualityMetrics().populate(sort_dict) + + # Automatic Curation + print("Creating automatic curation") + metric_keys = [ + {**k, "auto_curation_params_name": auto_curation_params_name} + for k in (QualityMetrics() & sort_dict).fetch("KEY") + ] + AutomaticCurationSelection.insert(metric_keys, skip_duplicates=True) + 
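+        # Sketch (standard DataJoint usage assumed): the selection keys just
+        # inserted can be reviewed before the populate call below with, e.g.,
+        #   (AutomaticCurationSelection & sort_dict).fetch("KEY")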
AutomaticCuration().populate(sort_dict) + + # Curated Spike Sorting + # get curation keys of the automatic curation to populate into curated + # spike sorting selection + print("Creating curated spike sorting") + auto_key_list = (AutomaticCuration() & sort_dict).fetch( + "auto_curation_key" + ) + for auto_key in auto_key_list: + curation_auto_key = (Curation() & auto_key).fetch1("KEY") + CuratedSpikeSortingSelection.insert1( + curation_auto_key, skip_duplicates=True + ) + + else: + # Perform no automatic curation, just populate curated spike sorting + # selection with the initial curation. Used in case of clusterless + # decoding + print("Creating curated spike sorting") + curation_keys = (Curation() & sort_dict).fetch("KEY") + for curation_key in curation_keys: + CuratedSpikeSortingSelection.insert1( + curation_auto_key, skip_duplicates=True + ) + + # Populate curated spike sorting + CuratedSpikeSorting.populate(sort_dict) + + if fig_url_repo: + # Curation Figurl + print("Creating curation figurl") + sort_interval_name = interval_list_name + "_entire" + gh_url = ( + fig_url_repo + + str(nwb_file_name + "_" + sort_interval_name) # session id + + "/{}" # tetrode using auto_id['sort_group_id'] + + "/curation.json" + ) + + for auto_id in (AutomaticCuration() & sort_dict).fetch( + "auto_curation_key" + ): + auto_curation_out_key = dict( + **(Curation() & auto_id).fetch1("KEY"), + new_curation_uri=gh_url.format(str(auto_id["sort_group_id"])), + ) + CurationFigurlSelection.insert1( + auto_curation_out_key, skip_duplicates=True + ) + CurationFigurl.populate(auto_curation_out_key) diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py new file mode 100644 index 000000000..cdad793e9 --- /dev/null +++ b/src/spyglass/spikesorting/v1/recording.py @@ -0,0 +1,627 @@ +import os +import shutil +from functools import reduce +from pathlib import Path + +import datajoint as dj +import numpy as np +import probeinterface as pi +import spikeinterface as si +import spikeinterface.extractors as se + +from ..common.common_device import Probe, ProbeType # noqa: F401 +from ..common.common_ephys import Electrode, ElectrodeGroup +from ..common.common_interval import ( + IntervalList, + interval_list_intersect, + intervals_by_length, + union_adjacent_index, +) +from ..common.common_lab import LabTeam # noqa: F401 +from ..common.common_nwbfile import Nwbfile +from ..common.common_session import Session # noqa: F401 +from ..utils.dj_helper_fn import dj_replace +from ..settings import recording_dir + +schema = dj.schema("spikesorting_recording") + + +@schema +class SortGroup(dj.Manual): + definition = """ + # Set of electrodes that will be sorted together + -> Session + sort_group_id: int # identifier for a group of electrodes + --- + sort_reference_electrode_id = -1: int # the electrode to use for reference. -1: no reference, -2: common median + """ + + class SortGroupElectrode(dj.Part): + definition = """ + -> SortGroup + -> Electrode + """ + + def set_group_by_shank( + self, + nwb_file_name: str, + references: dict = None, + omit_ref_electrode_group=False, + omit_unitrode=True, + ): + """Divides electrodes into groups based on their shank position. + + * Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a + single group + * Electrodes from probes with multiple shanks (e.g. 
polymer probes) are + placed in one group per shank + * Bad channels are omitted + + Parameters + ---------- + nwb_file_name : str + the name of the NWB file whose electrodes should be put into + sorting groups + references : dict, optional + If passed, used to set references. Otherwise, references set using + original reference electrodes from config. Keys: electrode groups. + Values: reference electrode. + omit_ref_electrode_group : bool + Optional. If True, no sort group is defined for electrode group of + reference. + omit_unitrode : bool + Optional. If True, no sort groups are defined for unitrodes. + """ + # delete any current groups + (SortGroup & {"nwb_file_name": nwb_file_name}).delete() + # get the electrodes from this NWB file + electrodes = ( + Electrode() + & {"nwb_file_name": nwb_file_name} + & {"bad_channel": "False"} + ).fetch() + e_groups = list(np.unique(electrodes["electrode_group_name"])) + e_groups.sort(key=int) # sort electrode groups numerically + sort_group = 0 + sg_key = dict() + sge_key = dict() + sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name + for e_group in e_groups: + # for each electrode group, get a list of the unique shank numbers + shank_list = np.unique( + electrodes["probe_shank"][ + electrodes["electrode_group_name"] == e_group + ] + ) + sge_key["electrode_group_name"] = e_group + # get the indices of all electrodes in this group / shank and set their sorting group + for shank in shank_list: + sg_key["sort_group_id"] = sge_key["sort_group_id"] = sort_group + # specify reference electrode. Use 'references' if passed, otherwise use reference from config + if not references: + shank_elect_ref = electrodes[ + "original_reference_electrode" + ][ + np.logical_and( + electrodes["electrode_group_name"] == e_group, + electrodes["probe_shank"] == shank, + ) + ] + if np.max(shank_elect_ref) == np.min(shank_elect_ref): + sg_key["sort_reference_electrode_id"] = shank_elect_ref[ + 0 + ] + else: + ValueError( + f"Error in electrode group {e_group}: reference " + + "electrodes are not all the same" + ) + else: + if e_group not in references.keys(): + raise Exception( + f"electrode group {e_group} not a key in " + + "references, so cannot set reference" + ) + else: + sg_key["sort_reference_electrode_id"] = references[ + e_group + ] + # Insert sort group and sort group electrodes + reference_electrode_group = electrodes[ + electrodes["electrode_id"] + == sg_key["sort_reference_electrode_id"] + ][ + "electrode_group_name" + ] # reference for this electrode group + if ( + len(reference_electrode_group) == 1 + ): # unpack single reference + reference_electrode_group = reference_electrode_group[0] + elif (int(sg_key["sort_reference_electrode_id"]) > 0) and ( + len(reference_electrode_group) != 1 + ): + raise Exception( + "Should have found exactly one electrode group for " + + "reference electrode, but found " + + f"{len(reference_electrode_group)}." + ) + if omit_ref_electrode_group and ( + str(e_group) == str(reference_electrode_group) + ): + print( + f"Omitting electrode group {e_group} from sort groups " + + "because contains reference." + ) + continue + shank_elect = electrodes["electrode_id"][ + np.logical_and( + electrodes["electrode_group_name"] == e_group, + electrodes["probe_shank"] == shank, + ) + ] + if ( + omit_unitrode and len(shank_elect) == 1 + ): # omit unitrodes if indicated + print( + f"Omitting electrode group {e_group}, shank {shank} from sort groups because unitrode." 
+ ) + continue + self.insert1(sg_key) + for elect in shank_elect: + sge_key["electrode_id"] = elect + self.SortGroupElectrode().insert1(sge_key) + sort_group += 1 + + def set_group_by_electrode_group(self, nwb_file_name: str): + """Assign groups to all non-bad channel electrodes based on their electrode group + and sets the reference for each group to the reference for the first channel of the group. + + Parameters + ---------- + nwb_file_name: str + the name of the nwb whose electrodes should be put into sorting groups + """ + # delete any current groups + (SortGroup & {"nwb_file_name": nwb_file_name}).delete() + # get the electrodes from this NWB file + electrodes = ( + Electrode() + & {"nwb_file_name": nwb_file_name} + & {"bad_channel": "False"} + ).fetch() + e_groups = np.unique(electrodes["electrode_group_name"]) + sg_key = dict() + sge_key = dict() + sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name + sort_group = 0 + for e_group in e_groups: + sge_key["electrode_group_name"] = e_group + # sg_key['sort_group_id'] = sge_key['sort_group_id'] = sort_group + # TEST + sg_key["sort_group_id"] = sge_key["sort_group_id"] = int(e_group) + # get the list of references and make sure they are all the same + shank_elect_ref = electrodes["original_reference_electrode"][ + electrodes["electrode_group_name"] == e_group + ] + if np.max(shank_elect_ref) == np.min(shank_elect_ref): + sg_key["sort_reference_electrode_id"] = shank_elect_ref[0] + else: + ValueError( + f"Error in electrode group {e_group}: reference electrodes are not all the same" + ) + self.insert1(sg_key) + + shank_elect = electrodes["electrode_id"][ + electrodes["electrode_group_name"] == e_group + ] + for elect in shank_elect: + sge_key["electrode_id"] = elect + self.SortGroupElectrode().insert1(sge_key) + sort_group += 1 + + def set_reference_from_list(self, nwb_file_name, sort_group_ref_list): + """ + Set the reference electrode from a list containing sort groups and reference electrodes + :param: sort_group_ref_list - 2D array or list where each row is [sort_group_id reference_electrode] + :param: nwb_file_name - The name of the NWB file whose electrodes' references should be updated + :return: Null + """ + key = dict() + key["nwb_file_name"] = nwb_file_name + sort_group_list = (SortGroup() & key).fetch1() + for sort_group in sort_group_list: + key["sort_group_id"] = sort_group + self.insert( + dj_replace( + sort_group_list, + sort_group_ref_list, + "sort_group_id", + "sort_reference_electrode_id", + ), + replace="True", + ) + + def get_geometry(self, sort_group_id, nwb_file_name): + """ + Returns a list with the x,y coordinates of the electrodes in the sort group + for use with the SpikeInterface package. + + Converts z locations to y where appropriate. 
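+
+        For example, a four-channel tetrode yields four [x, y] pairs such as
+        [[0.0, 0.0], [0.0, 12.5], [12.5, 0.0], [12.5, 12.5]] (coordinates as
+        stored in the probe table).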
+ + Parameters + ---------- + sort_group_id : int + nwb_file_name : str + + Returns + ------- + geometry : list + List of coordinate pairs, one per electrode + """ + + # create the channel_groups dictiorary + channel_group = dict() + key = dict() + key["nwb_file_name"] = nwb_file_name + electrodes = (Electrode() & key).fetch() + + key["sort_group_id"] = sort_group_id + sort_group_electrodes = (SortGroup.SortGroupElectrode() & key).fetch() + electrode_group_name = sort_group_electrodes["electrode_group_name"][0] + probe_id = ( + ElectrodeGroup + & { + "nwb_file_name": nwb_file_name, + "electrode_group_name": electrode_group_name, + } + ).fetch1("probe_id") + channel_group[sort_group_id] = dict() + channel_group[sort_group_id]["channels"] = sort_group_electrodes[ + "electrode_id" + ].tolist() + + n_chan = len(channel_group[sort_group_id]["channels"]) + + geometry = np.zeros((n_chan, 2), dtype="float") + tmp_geom = np.zeros((n_chan, 3), dtype="float") + for i, electrode_id in enumerate( + channel_group[sort_group_id]["channels"] + ): + # get the relative x and y locations of this channel from the probe table + probe_electrode = int( + electrodes["probe_electrode"][ + electrodes["electrode_id"] == electrode_id + ] + ) + rel_x, rel_y, rel_z = ( + Probe().Electrode() + & {"probe_id": probe_id, "probe_electrode": probe_electrode} + ).fetch("rel_x", "rel_y", "rel_z") + # TODO: Fix this HACK when we can use probeinterface: + rel_x = float(rel_x) + rel_y = float(rel_y) + rel_z = float(rel_z) + tmp_geom[i, :] = [rel_x, rel_y, rel_z] + + # figure out which columns have coordinates + n_found = 0 + for i in range(3): + if np.any(np.nonzero(tmp_geom[:, i])): + if n_found < 2: + geometry[:, n_found] = tmp_geom[:, i] + n_found += 1 + else: + Warning( + "Relative electrode locations have three coordinates; only two are currently supported" + ) + return np.ndarray.tolist(geometry) + + +@schema +class SortInterval(dj.Manual): + definition = """ + -> Session + sort_interval_name: varchar(200) # name for this interval + --- + sort_interval: longblob # 1D numpy array with start and end time for a single interval to be used for spike sorting + """ + + +@schema +class SpikeSortingPreprocessingParameters(dj.Manual): + definition = """ + preproc_params_name: varchar(200) + --- + preproc_params: blob + """ + + def insert_default(self): + # set up the default filter parameters + freq_min = 300 # high pass filter value + freq_max = 6000 # low pass filter value + margin_ms = 5 # margin in ms on border to avoid border effect + seed = 0 # random seed for whitening + + key = dict() + key["preproc_params_name"] = "default" + key["preproc_params"] = { + "frequency_min": freq_min, + "frequency_max": freq_max, + "margin_ms": margin_ms, + "seed": seed, + } + self.insert1(key, skip_duplicates=True) + + +@schema +class SpikeSortingRecordingSelection(dj.Manual): + definition = """ + # Defines recordings to be sorted + -> SortGroup + -> SortInterval + -> SpikeSortingPreprocessingParameters + -> LabTeam + --- + -> IntervalList + """ + + +@schema +class SpikeSortingRecording(dj.Computed): + definition = """ + -> SpikeSortingRecordingSelection + --- + recording_path: varchar(1000) + -> IntervalList.proj(sort_interval_list_name='interval_list_name') + """ + + def make(self, key): + sort_interval_valid_times = self._get_sort_interval_valid_times(key) + recording = self._get_filtered_recording(key) + recording_name = self._get_recording_name(key) + + # Path to files that will hold the recording extractors + recording_path = 
str(recording_dir / Path(recording_name)) + if os.path.exists(recording_path): + shutil.rmtree(recording_path) + + recording.save( + folder=recording_path, chunk_duration="10000ms", n_jobs=8 + ) + + IntervalList.insert1( + { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": recording_name, + "valid_times": sort_interval_valid_times, + }, + replace=True, + ) + + self.insert1( + { + **key, + # store the list of valid times for the sort + "sort_interval_list_name": recording_name, + "recording_path": recording_path, + } + ) + + @staticmethod + def _get_recording_name(key): + return "_".join( + [ + key["nwb_file_name"], + key["sort_interval_name"], + str(key["sort_group_id"]), + key["preproc_params_name"], + ] + ) + + @staticmethod + def _get_recording_timestamps(recording): + num_segments = recording.get_num_segments() + + if num_segments <= 1: + return recording.get_times() + + frames_per_segment = [0] + [ + recording.get_num_frames(segment_index=i) + for i in range(num_segments) + ] + + cumsum_frames = np.cumsum(frames_per_segment) + total_frames = np.sum(frames_per_segment) + + timestamps = np.zeros((total_frames,)) + for i in range(num_segments): + start_index = cumsum_frames[i] + end_index = cumsum_frames[i + 1] + timestamps[start_index:end_index] = recording.get_times( + segment_index=i + ) + + return timestamps + + def _get_sort_interval_valid_times(self, key): + """Identifies the intersection between sort interval specified by the user + and the valid times (times for which neural data exist) + + Parameters + ---------- + key: dict + specifies a (partially filled) entry of SpikeSorting table + + Returns + ------- + sort_interval_valid_times: ndarray of tuples + (start, end) times for valid stretches of the sorting interval + + """ + sort_interval = ( + SortInterval + & { + "nwb_file_name": key["nwb_file_name"], + "sort_interval_name": key["sort_interval_name"], + } + ).fetch1("sort_interval") + + interval_list_name = (SpikeSortingRecordingSelection & key).fetch1( + "interval_list_name" + ) + + valid_interval_times = ( + IntervalList + & { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": interval_list_name, + } + ).fetch1("valid_times") + + valid_sort_times = interval_list_intersect( + sort_interval, valid_interval_times + ) + # Exclude intervals shorter than specified length + params = (SpikeSortingPreprocessingParameters & key).fetch1( + "preproc_params" + ) + if "min_segment_length" in params: + valid_sort_times = intervals_by_length( + valid_sort_times, min_length=params["min_segment_length"] + ) + return valid_sort_times + + def _get_filtered_recording(self, key: dict): + """Filters and references a recording + * Loads the NWB file created during insertion as a spikeinterface Recording + * Slices recording in time (interval) and space (channels); + recording chunks from disjoint intervals are concatenated + * Applies referencing and bandpass filtering + + Parameters + ---------- + key: dict, + primary key of SpikeSortingRecording table + + Returns + ------- + recording: si.Recording + """ + + nwb_file_abs_path = Nwbfile().get_abs_path(key["nwb_file_name"]) + recording = se.read_nwb_recording( + nwb_file_abs_path, load_time_vector=True + ) + + valid_sort_times = self._get_sort_interval_valid_times(key) + # shape is (N, 2) + valid_sort_times_indices = np.array( + [ + np.searchsorted(recording.get_times(), interval) + for interval in valid_sort_times + ] + ) + # join intervals of indices that are adjacent + valid_sort_times_indices = reduce( + 
union_adjacent_index, valid_sort_times_indices + ) + if valid_sort_times_indices.ndim == 1: + valid_sort_times_indices = np.expand_dims( + valid_sort_times_indices, 0 + ) + + # create an AppendRecording if there is more than one disjoint sort interval + if len(valid_sort_times_indices) > 1: + recordings_list = [] + for interval_indices in valid_sort_times_indices: + recording_single = recording.frame_slice( + start_frame=interval_indices[0], + end_frame=interval_indices[1], + ) + recordings_list.append(recording_single) + recording = si.append_recordings(recordings_list) + else: + recording = recording.frame_slice( + start_frame=valid_sort_times_indices[0][0], + end_frame=valid_sort_times_indices[0][1], + ) + + channel_ids = ( + SortGroup.SortGroupElectrode + & { + "nwb_file_name": key["nwb_file_name"], + "sort_group_id": key["sort_group_id"], + } + ).fetch("electrode_id") + ref_channel_id = ( + SortGroup + & { + "nwb_file_name": key["nwb_file_name"], + "sort_group_id": key["sort_group_id"], + } + ).fetch1("sort_reference_electrode_id") + channel_ids = np.setdiff1d(channel_ids, ref_channel_id) + + # include ref channel in first slice, then exclude it in second slice + if ref_channel_id >= 0: + channel_ids_ref = np.append(channel_ids, ref_channel_id) + recording = recording.channel_slice(channel_ids=channel_ids_ref) + + recording = si.preprocessing.common_reference( + recording, reference="single", ref_channel_ids=ref_channel_id + ) + recording = recording.channel_slice(channel_ids=channel_ids) + elif ref_channel_id == -2: + recording = recording.channel_slice(channel_ids=channel_ids) + recording = si.preprocessing.common_reference( + recording, reference="global", operator="median" + ) + else: + raise ValueError("Invalid reference channel ID") + filter_params = (SpikeSortingPreprocessingParameters & key).fetch1( + "preproc_params" + ) + recording = si.preprocessing.bandpass_filter( + recording, + freq_min=filter_params["frequency_min"], + freq_max=filter_params["frequency_max"], + ) + + # if the sort group is a tetrode, change the channel location + # note that this is a workaround that would be deprecated when spikeinterface uses 3D probe locations + probe_type = [] + electrode_group = [] + for channel_id in channel_ids: + probe_type.append( + ( + Electrode * Probe + & { + "nwb_file_name": key["nwb_file_name"], + "electrode_id": channel_id, + } + ).fetch1("probe_type") + ) + electrode_group.append( + ( + Electrode + & { + "nwb_file_name": key["nwb_file_name"], + "electrode_id": channel_id, + } + ).fetch1("electrode_group_name") + ) + if ( + all(p == "tetrode_12.5" for p in probe_type) + and len(probe_type) == 4 + and all(eg == electrode_group[0] for eg in electrode_group) + ): + tetrode = pi.Probe(ndim=2) + position = [[0, 0], [0, 12.5], [12.5, 0], [12.5, 12.5]] + tetrode.set_contacts( + position, shapes="circle", shape_params={"radius": 6.25} + ) + tetrode.set_contact_ids(channel_ids) + tetrode.set_device_channel_indices(np.arange(4)) + recording = recording.set_probe(tetrode, in_place=True) + + return recording diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py new file mode 100644 index 000000000..8178ef513 --- /dev/null +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -0,0 +1,310 @@ +import os +import shutil +import tempfile +import time +import uuid +from pathlib import Path + +import datajoint as dj +import numpy as np +import spikeinterface as si +import spikeinterface.preprocessing as sip +import spikeinterface.sorters as sis +from 
spikeinterface.sortingcomponents.peak_detection import detect_peaks + +from ..common.common_lab import LabMember, LabTeam +from ..common.common_nwbfile import AnalysisNwbfile +from ..settings import temp_dir, sorting_dir +from .spikesorting_artifact import ArtifactRemovedIntervalList +from .spikesorting_recording import ( + SpikeSortingRecording, + SpikeSortingRecordingSelection, +) + +schema = dj.schema("spikesorting_sorting") + + +@schema +class SpikeSorterParameters(dj.Manual): + definition = """ + sorter: varchar(200) + sorter_params_name: varchar(200) + --- + sorter_params: blob + """ + + def insert_default(self): + """Default params from spike sorters available via spikeinterface""" + sorters = sis.available_sorters() + for sorter in sorters: + sorter_params = sis.get_default_sorter_params(sorter) + self.insert1( + [sorter, "default", sorter_params], skip_duplicates=True + ) + + # Insert Frank lab defaults + # Hippocampus tetrode default + sorter = "mountainsort4" + sorter_params_name = "franklab_tetrode_hippocampus_30KHz" + sorter_params = { + "detect_sign": -1, + "adjacency_radius": 100, + "freq_min": 600, + "freq_max": 6000, + "filter": False, + "whiten": True, + "num_workers": 1, + "clip_size": 40, + "detect_threshold": 3, + "detect_interval": 10, + } + self.insert1( + [sorter, sorter_params_name, sorter_params], skip_duplicates=True + ) + + # Cortical probe default + sorter = "mountainsort4" + sorter_params_name = "franklab_probe_ctx_30KHz" + sorter_params = { + "detect_sign": -1, + "adjacency_radius": 100, + "freq_min": 300, + "freq_max": 6000, + "filter": False, + "whiten": True, + "num_workers": 1, + "clip_size": 40, + "detect_threshold": 3, + "detect_interval": 10, + } + self.insert1( + [sorter, sorter_params_name, sorter_params], skip_duplicates=True + ) + + # clusterless defaults + sorter = "clusterless_thresholder" + sorter_params_name = "default_clusterless" + sorter_params = dict( + detect_threshold=100.0, # uV + # Locally exclusive means one unit per spike detected + method="locally_exclusive", + peak_sign="neg", + exclude_sweep_ms=0.1, + local_radius_um=100, + # noise levels needs to be 1.0 so the units are in uV and not MAD + noise_levels=np.asarray([1.0]), + random_chunk_kwargs={}, + # output needs to be set to sorting for the rest of the pipeline + outputs="sorting", + ) + self.insert1( + [sorter, sorter_params_name, sorter_params], skip_duplicates=True + ) + + +@schema +class SpikeSortingSelection(dj.Manual): + definition = """ + # Table for holding selection of recording and parameters for each spike sorting run + -> SpikeSortingRecording + -> SpikeSorterParameters + -> ArtifactRemovedIntervalList + --- + import_path = "": varchar(200) # optional path to previous curated sorting output + """ + + +@schema +class SpikeSorting(dj.Computed): + definition = """ + -> SpikeSortingSelection + --- + sorting_path: varchar(1000) + time_of_sort: int # in Unix time, to the nearest second + """ + + def make(self, key: dict): + """Runs spike sorting on the data and parameters specified by the + SpikeSortingSelection table and inserts a new entry to SpikeSorting table. + + Specifically, + 1. Loads saved recording and runs the sort on it with spikeinterface + 2. Saves the sorting with spikeinterface + 3. Creates an analysis NWB file and saves the sorting there + (this is redundant with 2; will change in the future) + + """ + # CBroz: does this not work w/o arg? as .populate() ? 
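+        # Usage sketch (sel_key is a hypothetical selection key): this make()
+        # is normally triggered indirectly via
+        #   SpikeSortingSelection.insert1(sel_key, skip_duplicates=True)
+        #   SpikeSorting.populate(sel_key)
+        # rather than being called directly.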
+ recording_path = (SpikeSortingRecording & key).fetch1("recording_path") + recording = si.load_extractor(recording_path) + + # first, get the timestamps + timestamps = SpikeSortingRecording._get_recording_timestamps(recording) + _ = recording.get_sampling_frequency() + # then concatenate the recordings + # Note: the timestamps are lost upon concatenation, + # i.e. concat_recording.get_times() doesn't return true timestamps anymore. + # but concat_recording.recoring_list[i].get_times() will return correct + # timestamps for ith recording. + if recording.get_num_segments() > 1 and isinstance( + recording, si.AppendSegmentRecording + ): + recording = si.concatenate_recordings(recording.recording_list) + elif recording.get_num_segments() > 1 and isinstance( + recording, si.BinaryRecordingExtractor + ): + recording = si.concatenate_recordings([recording]) + + # load artifact intervals + artifact_times = ( + ArtifactRemovedIntervalList + & { + "artifact_removed_interval_list_name": key[ + "artifact_removed_interval_list_name" + ] + } + ).fetch1("artifact_times") + if len(artifact_times): + if artifact_times.ndim == 1: + artifact_times = np.expand_dims(artifact_times, 0) + + # convert artifact intervals to indices + list_triggers = [] + for interval in artifact_times: + list_triggers.append( + np.arange( + np.searchsorted(timestamps, interval[0]), + np.searchsorted(timestamps, interval[1]), + ) + ) + list_triggers = [list(np.concatenate(list_triggers))] + recording = sip.remove_artifacts( + recording=recording, + list_triggers=list_triggers, + ms_before=None, + ms_after=None, + mode="zeros", + ) + + print(f"Running spike sorting on {key}...") + sorter, sorter_params = (SpikeSorterParameters & key).fetch1( + "sorter", "sorter_params" + ) + + sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir) + # add tempdir option for mountainsort + sorter_params["tempdir"] = sorter_temp_dir.name + + if sorter == "clusterless_thresholder": + # need to remove tempdir and whiten from sorter_params + sorter_params.pop("tempdir", None) + sorter_params.pop("whiten", None) + sorter_params.pop("outputs", None) + + # Detect peaks for clusterless decoding + detected_spikes = detect_peaks(recording, **sorter_params) + sorting = si.NumpySorting.from_times_labels( + times_list=detected_spikes["sample_index"], + labels_list=np.zeros(len(detected_spikes), dtype=np.int), + sampling_frequency=recording.get_sampling_frequency(), + ) + else: + if "whiten" in sorter_params.keys(): + if sorter_params["whiten"]: + sorter_params["whiten"] = False # set whiten to False + # whiten recording separately; make sure dtype is float32 + # to avoid downstream error with svd + recording = sip.whiten(recording, dtype="float32") + sorting = sis.run_sorter( + sorter, + recording, + output_folder=sorter_temp_dir.name, + delete_output_folder=True, + **sorter_params, + ) + key["time_of_sort"] = int(time.time()) + + print("Saving sorting results...") + + sorting_folder = Path(sorting_dir) + + sorting_name = self._get_sorting_name(key) + key["sorting_path"] = str(sorting_folder / Path(sorting_name)) + if os.path.exists(key["sorting_path"]): + shutil.rmtree(key["sorting_path"]) + sorting = sorting.save(folder=key["sorting_path"]) + self.insert1(key) + + def delete(self): + """Extends the delete method of base class to implement permission checking. + Note that this is NOT a security feature, as anyone that has access to source code + can disable it; it just makes it less likely to accidentally delete entries. 
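+
+        Deletion proceeds only when the current DataJoint user is a member of
+        the LabTeam associated with every entry selected for deletion (looked
+        up via SpikeSortingRecordingSelection).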
+ """ + current_user_name = dj.config["database.user"] + entries = self.fetch() + permission_bool = np.zeros((len(entries),)) + print( + f"Attempting to delete {len(entries)} entries, checking permission..." + ) + + for entry_idx in range(len(entries)): + # check the team name for the entry, then look up the members in that team, + # then get their datajoint user names + team_name = ( + SpikeSortingRecordingSelection + & (SpikeSortingRecordingSelection & entries[entry_idx]).proj() + ).fetch1()["team_name"] + lab_member_name_list = ( + LabTeam.LabTeamMember & {"team_name": team_name} + ).fetch("lab_member_name") + datajoint_user_names = [] + for lab_member_name in lab_member_name_list: + datajoint_user_names.append( + ( + LabMember.LabMemberInfo + & {"lab_member_name": lab_member_name} + ).fetch1("datajoint_user_name") + ) + permission_bool[entry_idx] = ( + current_user_name in datajoint_user_names + ) + if np.sum(permission_bool) == len(entries): + print("Permission to delete all specified entries granted.") + super().delete() + else: + raise Exception( + "You do not have permission to delete all specified" + "entries. Not deleting anything." + ) + + def fetch_nwb(self, *attrs, **kwargs): + raise NotImplementedError + return None + # return fetch_nwb(self, (AnalysisNwbfile, 'analysis_file_abs_path'), *attrs, **kwargs) + + def nightly_cleanup(self): + """Clean up spike sorting directories that are not in the SpikeSorting table. + This should be run after AnalysisNwbFile().nightly_cleanup() + """ + # get a list of the files in the spike sorting storage directory + dir_names = next(os.walk(sorting_dir))[1] + # now retrieve a list of the currently used analysis nwb files + analysis_file_names = self.fetch("analysis_file_name") + for dir in dir_names: + if dir not in analysis_file_names: + full_path = str(Path(sorting_dir) / dir) + print(f"removing {full_path}") + shutil.rmtree(str(Path(sorting_dir) / dir)) + + @staticmethod + def _get_sorting_name(key): + recording_name = SpikeSortingRecording._get_recording_name(key) + sorting_name = ( + recording_name + "_" + str(uuid.uuid4())[0:8] + "_spikesorting" + ) + return sorting_name + + # TODO: write a function to import sorting done outside of dj + + def _import_sorting(self, key): + raise NotImplementedError