From 20d753a0b83f77a74a4236d093995815d32a8678 Mon Sep 17 00:00:00 2001 From: CBroz1 Date: Wed, 8 Nov 2023 10:53:23 -0800 Subject: [PATCH] Add v1s --- src/spyglass/common/common_interval.py | 315 ++-- src/spyglass/spikesorting/v1/artifact.py | 392 ++--- src/spyglass/spikesorting/v1/curation.py | 1302 ++++------------- .../spikesorting/v1/figurl_curation.py | 223 +-- .../spikesorting/v1/metric_curation.py | 555 +++++++ src/spyglass/spikesorting/v1/metric_utils.py | 81 + src/spyglass/spikesorting/v1/recording.py | 999 ++++++++----- src/spyglass/spikesorting/v1/sorting.py | 515 ++++--- src/spyglass/spikesorting/v1/utils.py | 18 + 9 files changed, 2405 insertions(+), 1995 deletions(-) create mode 100644 src/spyglass/spikesorting/v1/metric_curation.py create mode 100644 src/spyglass/spikesorting/v1/metric_utils.py create mode 100644 src/spyglass/spikesorting/v1/utils.py diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 99c9a5bdf..e0658182f 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -18,49 +18,44 @@ class IntervalList(dj.Manual): definition = """ # Time intervals used for analysis -> Session - interval_list_name: varchar(170) # descriptive name of this interval list + interval_list_name: varchar(200) # descriptive name of this interval list --- - valid_times: longblob # numpy array with start/end times for each interval + valid_times: longblob # numpy array with start and end times for each interval """ - # See #630, #664. Excessive key length. - @classmethod def insert_from_nwbfile(cls, nwbf, *, nwb_file_name): - """Add each entry in the NWB file epochs table to the IntervalList. + """Add each entry in the NWB file epochs table to the IntervalList table. - The interval list name for each epoch is set to the first tag for the - epoch. If the epoch has no tags, then 'interval_x' will be used as the - interval list name, where x is the index (0-indexed) of the epoch in the - epochs table. The start time and stop time of the epoch are stored in - the valid_times field as a numpy array of [start time, stop time] for - each epoch. + The interval list name for each epoch is set to the first tag for the epoch. + If the epoch has no tags, then 'interval_x' will be used as the interval list name, where x is the index + (0-indexed) of the epoch in the epochs table. + The start time and stop time of the epoch are stored in the valid_times field as a numpy array of + [start time, stop time] for each epoch. Parameters ---------- nwbf : pynwb.NWBFile The source NWB file object. nwb_file_name : str - The file name of the NWB file, used as a primary key to the Session - table. + The file name of the NWB file, used as a primary key to the Session table. 
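        A minimal usage sketch (the file path is hypothetical; in Spyglass this
        method is typically called for you during session ingestion, after the
        file has been registered in Session):

        >>> import pynwb
        >>> with pynwb.NWBHDF5IO("/path/to/example20230622_.nwb", "r") as io:
        ...     nwbf = io.read()
        ...     IntervalList.insert_from_nwbfile(
        ...         nwbf, nwb_file_name="example20230622_.nwb"
        ...     )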
""" if nwbf.epochs is None: print("No epochs found in NWB file.") return - epochs = nwbf.epochs.to_dataframe() - - for _, epoch_data in epochs.iterrows(): - epoch_dict = { - "nwb_file_name": nwb_file_name, - "interval_list_name": epoch_data.tags[0] - if epoch_data.tags - else f"interval_{epoch_data[0]}", - "valid_times": np.asarray( - [[epoch_data.start_time, epoch_data.stop_time]] - ), - } - + for epoch_index, epoch_data in epochs.iterrows(): + epoch_dict = dict() + epoch_dict["nwb_file_name"] = nwb_file_name + if epoch_data.tags[0]: + epoch_dict["interval_list_name"] = epoch_data.tags[0] + else: + epoch_dict["interval_list_name"] = "interval_" + str( + epoch_index + ) + epoch_dict["valid_times"] = np.asarray( + [[epoch_data.start_time, epoch_data.stop_time]] + ) cls.insert1(epoch_dict, skip_duplicates=True) def plot_intervals(self, figsize=(20, 5)): @@ -150,7 +145,7 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval in seconds. + Each element is (start time, stop time), i.e. an interval. Unit is seconds. min_length : float, optional Minimum interval length in seconds. Defaults to 0.0. max_length : float, optional @@ -163,12 +158,12 @@ def intervals_by_length(interval_list, min_length=0.0, max_length=1e10): def interval_list_contains_ind(interval_list, timestamps): - """Find indices of list of timestamps contained in an interval list. + """Find indices of a list of timestamps that are contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval in seconds. + Each element is (start time, stop time), i.e. an interval. Unit is seconds. timestamps : array_like """ ind = [] @@ -189,7 +184,7 @@ def interval_list_contains(interval_list, timestamps): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval in seconds. + Each element is (start time, stop time), i.e. an interval. Unit is seconds. timestamps : array_like """ ind = [] @@ -205,17 +200,28 @@ def interval_list_contains(interval_list, timestamps): def interval_list_excludes_ind(interval_list, timestamps): - """Find indices of timestamps that are not contained in an interval list. + """Find indices of a list of timestamps that are not contained in an interval list. Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval in seconds. + Each element is (start time, stop time), i.e. an interval. Unit is seconds. 
timestamps : array_like """ contained_inds = interval_list_contains_ind(interval_list, timestamps) return np.setdiff1d(np.arange(len(timestamps)), contained_inds) + # # add the first and last times to the list and creat a list of invalid intervals + # valid_times_list = np.ndarray.ravel(interval_list).tolist() + # valid_times_list.insert(0, timestamps[0] - 0.00001) + # valid_times_list.append(timestamps[-1] + 0.001) + # invalid_times = np.array(valid_times_list).reshape(-1, 2) + # # add the first and last timestamp indices + # ind = [] + # for invalid_time in invalid_times: + # ind += np.ravel(np.argwhere(np.logical_and(timestamps > invalid_time[0], + # timestamps < invalid_time[1]))).tolist() + # return np.asarray(ind) def interval_list_excludes(interval_list, timestamps): @@ -224,24 +230,22 @@ def interval_list_excludes(interval_list, timestamps): Parameters ---------- interval_list : array_like - Each element is (start time, stop time), i.e. an interval in seconds. + Each element is (start time, stop time), i.e. an interval. Unit is seconds. timestamps : array_like """ contained_times = interval_list_contains(interval_list, timestamps) return np.setdiff1d(timestamps, contained_times) - - -def consolidate_intervals(interval_list): - if interval_list.ndim == 1: - interval_list = np.expand_dims(interval_list, 0) - else: - interval_list = interval_list[np.argsort(interval_list[:, 0])] - interval_list = reduce(_union_concat, interval_list) - # the following check is needed in the case where the interval list is a - # single element (behavior of reduce) - if interval_list.ndim == 1: - interval_list = np.expand_dims(interval_list, 0) - return interval_list + # # add the first and last times to the list and creat a list of invalid intervals + # valid_times_list = np.ravel(valid_times).tolist() + # valid_times_list.insert(0, timestamps[0] - 0.00001) + # valid_times_list.append(timestamps[-1] + 0.00001) + # invalid_times = np.array(valid_times_list).reshape(-1, 2) + # # add the first and last timestamp indices + # ind = [] + # for invalid_time in invalid_times: + # ind += np.ravel(np.argwhere(np.logical_and(timestamps > invalid_time[0], + # timestamps < invalid_time[1]))).tolist() + # return timestamps[ind] def interval_list_intersect(interval_list1, interval_list2, min_length=0): @@ -261,51 +265,76 @@ def interval_list_intersect(interval_list1, interval_list2, min_length=0): interval_list: np.array, (N,2) """ - # Consolidate interval lists to disjoint int'ls by sorting & applying union - interval_list1 = consolidate_intervals(interval_list1) - interval_list2 = consolidate_intervals(interval_list2) + # first, consolidate interval lists to disjoint intervals by sorting and applying union + if interval_list1.ndim == 1: + interval_list1 = np.expand_dims(interval_list1, 0) + else: + interval_list1 = interval_list1[np.argsort(interval_list1[:, 0])] + interval_list1 = reduce(_union_concat, interval_list1) + # the following check is needed in the case where the interval list is a single element (behavior of reduce) + if interval_list1.ndim == 1: + interval_list1 = np.expand_dims(interval_list1, 0) + + if interval_list2.ndim == 1: + interval_list2 = np.expand_dims(interval_list2, 0) + else: + interval_list2 = interval_list2[np.argsort(interval_list2[:, 0])] + interval_list2 = reduce(_union_concat, interval_list2) + # the following check is needed in the case where the interval list is a single element (behavior of reduce) + if interval_list2.ndim == 1: + interval_list2 = np.expand_dims(interval_list2, 0) 
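    # At this point both interval lists are sorted and non-overlapping; e.g.
    # (illustrative values) reduce(_union_concat, ...) collapses
    # [[0, 2], [1, 3], [5, 6]] to [[0, 3], [5, 6]].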
# then do pairwise comparison and collect intersections - intersecting_intervals = [ - _intersection(interval2, interval1) - for interval2 in interval_list2 - for interval1 in interval_list1 - if _intersection(interval2, interval1) is not None - ] + intersecting_intervals = [] + for interval2 in interval_list2: + for interval1 in interval_list1: + if _intersection(interval2, interval1) is not None: + intersecting_intervals.append( + _intersection(interval1, interval2) + ) # if no intersection, then return an empty list if not intersecting_intervals: return [] + else: + intersecting_intervals = np.asarray(intersecting_intervals) + intersecting_intervals = intersecting_intervals[ + np.argsort(intersecting_intervals[:, 0]) + ] - intersecting_intervals = np.asarray(intersecting_intervals) - intersecting_intervals = intersecting_intervals[ - np.argsort(intersecting_intervals[:, 0]) - ] - - return intervals_by_length(intersecting_intervals, min_length=min_length) + return intervals_by_length( + intersecting_intervals, min_length=min_length + ) def _intersection(interval1, interval2): - """Takes the (set-theoretic) intersection of two intervals""" - start = max(interval1[0], interval2[0]) - end = min(interval1[1], interval2[1]) - intersection = np.array([start, end]) if end > start else None - return intersection + "Takes the (set-theoretic) intersection of two intervals" + intersection = np.array( + [max([interval1[0], interval2[0]]), min([interval1[1], interval2[1]])] + ) + if intersection[1] >= intersection[0]: + return intersection + else: + return None def _union(interval1, interval2): - """Takes the (set-theoretic) union of two intervals""" + "Takes the (set-theoretic) union of two intervals" if _intersection(interval1, interval2) is None: return np.array([interval1, interval2]) - return np.array( - [min(interval1[0], interval2[0]), max(interval1[1], interval2[1])] - ) + else: + return np.array( + [ + min([interval1[0], interval2[0]]), + max([interval1[1], interval2[1]]), + ] + ) def _union_concat(interval_list, interval): - """Compare last interval of interval list to given interval. - - If overlapping, take union. If not, concatenate interval to interval list. + """Compares the last interval of the interval list to the given interval and + * takes their union if overlapping + * concatenates the interval to the interval list if not Recursively called with `reduce`. """ @@ -315,23 +344,27 @@ def _union_concat(interval_list, interval): interval = np.expand_dims(interval, 0) x = _union(interval_list[-1], interval[0]) - x = np.expand_dims(x, 0) if x.ndim == 1 else x - + if x.ndim == 1: + x = np.expand_dims(x, 0) return np.concatenate((interval_list[:-1], x), axis=0) def union_adjacent_index(interval1, interval2): - """Union index-adjacent intervals. If not adjacent, just concatenate. - + """unions two intervals that are adjacent in index e.g. 
[a,b] and [b+1, c] is converted to [a,c] + if not adjacent, just concatenates interval2 at the end of interval1 Parameters ---------- interval1 : np.array + [description] interval2 : np.array + [description] """ - interval1 = np.atleast_2d(interval1) - interval2 = np.atleast_2d(interval2) + if interval1.ndim == 1: + interval1 = np.expand_dims(interval1, 0) + if interval2.ndim == 1: + interval2 = np.expand_dims(interval2, 0) if ( interval1[-1][1] + 1 == interval2[0][0] @@ -353,63 +386,50 @@ def union_adjacent_index(interval1, interval2): # TODO: test interval_list_union code -def _parallel_union(interval_list): - """Create a parallel list where 1 is start and -1 the end""" - interval_list = np.ravel(interval_list) - interval_list_start_end = np.ones(interval_list.shape) - interval_list_start_end[1::2] = -1 - return interval_list, interval_list_start_end - - def interval_list_union( - interval_list1: np.ndarray, - interval_list2: np.ndarray, - min_length: float = 0.0, - max_length: float = 1e10, -) -> np.ndarray: + interval_list1, interval_list2, min_length=0.0, max_length=1e10 +): """Finds the union (all times in one or both) for two interval lists - Parameters - ---------- - interval_list1 : np.ndarray - The first interval list [start, stop] - interval_list2 : np.ndarray - The second interval list [start, stop] - min_length : float, optional - Minimum length of interval for inclusion in output, default 0.0 - max_length : float, optional - Maximum length of interval for inclusion in output, default 1e10 - - Returns - ------- - np.ndarray - Array of intervals [start, stop] + :param interval_list1: The first interval list + :type interval_list1: numpy array of intervals [start, stop] + :param interval_list2: The second interval list + :type interval_list2: numpy array of intervals [start, stop] + :param min_length: optional minimum length of interval for inclusion in output, default 0.0 + :type min_length: float + :param max_length: optional maximum length of interval for inclusion in output, default 1e10 + :type max_length: float + :return: interval_list + :rtype: numpy array of intervals [start, stop] """ - - il1, il1_start_end = _parallel_union(interval_list1) - il2, il2_start_end = _parallel_union(interval_list2) - - # Concatenate the two lists so we can resort the intervals and apply the - # same sorting to the start-end arrays - combined_intervals = np.concatenate((il1, il2)) - ss = np.concatenate((il1_start_end, il2_start_end)) + # return np.array([min(interval_list1[0],interval_list2[0]), + # max(interval_list1[1],interval_list2[1])]) + interval_list1 = np.ravel(interval_list1) + # create a parallel list where 1 indicates the start and -1 the end of an interval + interval_list1_start_end = np.ones(interval_list1.shape) + interval_list1_start_end[1::2] = -1 + + interval_list2 = np.ravel(interval_list2) + # create a parallel list for the second interval where 1 indicates the start and -1 the end of an interval + interval_list2_start_end = np.ones(interval_list2.shape) + interval_list2_start_end[1::2] = -1 + + # concatenate the two lists so we can resort the intervals and apply the same sorting to the start-end arrays + combined_intervals = np.concatenate((interval_list1, interval_list2)) + ss = np.concatenate((interval_list1_start_end, interval_list2_start_end)) sort_ind = np.argsort(combined_intervals) combined_intervals = combined_intervals[sort_ind] - - # a cumulative sum of 1 indicates the beginning of a joint interval; a - # cumulative sum of 0 indicates the end + # a cumulative 
sum of 1 indicates the beginning of a joint interval; a cumulative sum of 0 indicates the end union_starts = np.ravel(np.array(np.where(np.cumsum(ss[sort_ind]) == 1))) union_stops = np.ravel(np.array(np.where(np.cumsum(ss[sort_ind]) == 0))) - union = [ - [combined_intervals[start], combined_intervals[stop]] - for start, stop in zip(union_starts, union_stops) - ] - + union = [] + for start, stop in zip(union_starts, union_stops): + union.append([combined_intervals[start], combined_intervals[stop]]) return np.asarray(union) def interval_list_censor(interval_list, timestamps): - """Returns new interval list that starts/ends at first/last timestamp + """returns a new interval list that starts and ends at the first and last timestamp Parameters ---------- @@ -422,10 +442,9 @@ def interval_list_censor(interval_list, timestamps): interval_list (numpy array of intervals [start, stop]) """ # check that all timestamps are in the interval list - if len(interval_list_contains_ind(interval_list, timestamps)) != len( + assert len(interval_list_contains_ind(interval_list, timestamps)) == len( timestamps - ): - raise ValueError("Interval_list must contain all timestamps") + ), "interval_list must contain all timestamps" timestamps_interval = np.asarray([[timestamps[0], timestamps[-1]]]) return interval_list_intersect(interval_list, timestamps_interval) @@ -433,7 +452,6 @@ def interval_list_censor(interval_list, timestamps): def interval_from_inds(list_frames): """Converts a list of indices to a list of intervals. - e.g. [2,3,4,6,7,8,9,10] -> [[2,4],[6,10]] Parameters @@ -488,3 +506,44 @@ def interval_set_difference_inds(intervals1, intervals2): i += 1 result += intervals1[i:] return result + + +def interval_list_complement(intervals1, intervals2, min_length=0.0): + """ + Finds intervals in intervals1 that are not in intervals2 + + Parameters + ---------- + min_length : float, optional + Minimum interval length in seconds. Defaults to 0.0. 
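    A toy illustration (values are illustrative; output shown assuming
    intervals_by_length keeps both remaining pieces at the default min_length):

    >>> intervals1 = np.array([[0.0, 10.0]])
    >>> intervals2 = np.array([[2.0, 3.0], [8.0, 12.0]])
    >>> interval_list_complement(intervals1, intervals2)
    array([[0., 2.],
           [3., 8.]])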
+ """ + + result = [] + + for start1, end1 in intervals1: + subtracted = [(start1, end1)] + + for start2, end2 in intervals2: + new_subtracted = [] + + for s, e in subtracted: + if start2 <= s and e <= end2: + continue + + if e <= start2 or end2 <= s: + new_subtracted.append((s, e)) + continue + + if start2 > s: + new_subtracted.append((s, start2)) + + if end2 < e: + new_subtracted.append((end2, e)) + + subtracted = new_subtracted + + result.extend(subtracted) + + return intervals_by_length( + np.asarray(result), min_length=min_length, max_length=1e100 + ) diff --git a/src/spyglass/spikesorting/v1/artifact.py b/src/spyglass/spikesorting/v1/artifact.py index 0532357c1..3625cae7e 100644 --- a/src/spyglass/spikesorting/v1/artifact.py +++ b/src/spyglass/spikesorting/v1/artifact.py @@ -1,174 +1,197 @@ import warnings from functools import reduce -from typing import Union +from typing import Union, List import datajoint as dj import numpy as np import scipy.stats as stats import spikeinterface as si +import spikeinterface.extractors as se from spikeinterface.core.job_tools import ChunkRecordingExecutor, ensure_n_jobs -from ..common.common_interval import ( +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.common.common_interval import ( IntervalList, _union_concat, interval_from_inds, - interval_set_difference_inds, + interval_list_complement, +) +from spyglass.spikesorting.v1.utils import generate_nwb_uuid +from spyglass.spikesorting.v1.recording import ( + SpikeSortingRecording, + SpikeSortingRecordingSelection, ) -from ..utils.nwb_helper_fn import get_valid_intervals -from .spikesorting_recording import SpikeSortingRecording -schema = dj.schema("spikesorting_artifact") +schema = dj.schema("spikesorting_v1_artifact") @schema -class ArtifactDetectionParameters(dj.Manual): +class ArtifactDetectionParameters(dj.Lookup): definition = """ - # Parameters for detecting artifact times within a sort group. - artifact_params_name: varchar(200) + # Parameter for detecting artifacts (non-neural high amplitude events) + artifact_param_name : varchar(200) --- - artifact_params: blob # dictionary of parameters + artifact_params : blob """ - def insert_default(self): - """Insert the default artifact parameters with an appropriate parameter dict.""" - artifact_params = {} - artifact_params["zscore_thresh"] = None # must be None or >= 0 - artifact_params["amplitude_thresh"] = 3000 # must be None or >= 0 - # all electrodes of sort group - artifact_params["proportion_above_thresh"] = 1.0 - artifact_params["removal_window_ms"] = 1.0 # in milliseconds - self.insert1(["default", artifact_params], skip_duplicates=True) + contents = [ + [ + "default", + { + "zscore_thresh": None, + "amplitude_thresh_uV": 3000, + "proportion_above_thresh": 1.0, + "removal_window_ms": 1.0, + "chunk_duration": "10s", + "n_jobs": 4, + "progress_bar": "True", + }, + ], + [ + "none", + { + "zscore_thresh": None, + "amplitude_thresh_uV": None, + "chunk_duration": "10s", + "n_jobs": 4, + "progress_bar": "True", + }, + ], + ] - artifact_params_none = {} - artifact_params_none["zscore_thresh"] = None - artifact_params_none["amplitude_thresh"] = None - self.insert1(["none", artifact_params_none], skip_duplicates=True) + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) @schema class ArtifactDetectionSelection(dj.Manual): definition = """ - # Specifies artifact detection parameters to apply to a sort group's recording. + # Processed recording and artifact detection parameters. 
Use `insert_selection` method to insert new rows. + artifact_id: varchar(30) + --- -> SpikeSortingRecording -> ArtifactDetectionParameters - --- - custom_artifact_detection=0 : tinyint """ + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into ArtifactDetectionSelection with an + automatically generated unique artifact ID as the sole primary key. + + Parameters + ---------- + key : dict + primary key of SpikeSortingRecording and ArtifactDetectionParameters + + Returns + ------- + artifact_id : str + the unique artifact ID serving as primary key for ArtifactDetectionSelection + """ + if len((cls & key).fetch()) > 0: + print( + "This row has already been inserted into ArtifactDetectionSelection." + ) + return (cls & key).fetch1() + key["artifact_id"] = generate_nwb_uuid( + (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), + "A", + 6, + ) + cls.insert1(key, skip_duplicates=True) + return key + @schema class ArtifactDetection(dj.Computed): definition = """ - # Stores artifact times and valid no-artifact times as intervals. + # Detects artifacts (e.g. large transients from movement) and saves artifact-free intervals in IntervalList. -> ArtifactDetectionSelection - --- - artifact_times: longblob # np array of artifact intervals - artifact_removed_valid_times: longblob # np array of valid no-artifact intervals - artifact_removed_interval_list_name: varchar(200) # name of the array of no-artifact valid time intervals """ def make(self, key): - if not (ArtifactDetectionSelection & key).fetch1( - "custom_artifact_detection" - ): - # get the dict of artifact params associated with this artifact_params_name - artifact_params = (ArtifactDetectionParameters & key).fetch1( - "artifact_params" - ) - - recording_path = (SpikeSortingRecording & key).fetch1( - "recording_path" - ) - recording_name = SpikeSortingRecording._get_recording_name(key) - recording = si.load_extractor(recording_path) - - job_kwargs = { - "chunk_duration": "10s", - "n_jobs": 4, - "progress_bar": "True", + # FETCH: + # - artifact parameters + # - recording analysis nwb file + artifact_params, recording_analysis_nwb_file = ( + ArtifactDetectionParameters + * SpikeSortingRecording + * ArtifactDetectionSelection + & key + ).fetch1("artifact_params", "analysis_file_name") + sort_interval_valid_times = ( + IntervalList + & { + "nwb_file_name": ( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("nwb_file_name"), + "interval_list_name": ( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("interval_list_name"), } + ).fetch1("valid_times") + # DO: + # - load recording + recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( + recording_analysis_nwb_file + ) + recording = se.read_nwb_recording( + recording_analysis_nwb_file_abs_path, load_time_vector=True + ) - artifact_removed_valid_times, artifact_times = _get_artifact_times( - recording, **artifact_params, **job_kwargs - ) - - # NOTE: decided not to do this but to just create a single long segment; keep for now - # get artifact times by segment - # if AppendSegmentRecording, get artifact times for each segment - # if isinstance(recording, AppendSegmentRecording): - # artifact_removed_valid_times = [] - # artifact_times = [] - # for rec in recording.recording_list: - # rec_valid_times, rec_artifact_times = _get_artifact_times(rec, **artifact_params) - # for valid_times in rec_valid_times: - # artifact_removed_valid_times.append(valid_times) - # for artifact_times in 
rec_artifact_times: - # artifact_times.append(artifact_times) - # artifact_removed_valid_times = np.asarray(artifact_removed_valid_times) - # artifact_times = np.asarray(artifact_times) - # else: - # artifact_removed_valid_times, artifact_times = _get_artifact_times(recording, **artifact_params) - - key["artifact_times"] = artifact_times - key["artifact_removed_valid_times"] = artifact_removed_valid_times - - # set up a name for no-artifact times using recording id - key["artifact_removed_interval_list_name"] = ( - recording_name - + "_" - + key["artifact_params_name"] - + "_artifact_removed_valid_times" - ) - - ArtifactRemovedIntervalList.insert1(key, replace=True) - - # also insert into IntervalList - tmp_key = {} - tmp_key["nwb_file_name"] = key["nwb_file_name"] - tmp_key["interval_list_name"] = key[ - "artifact_removed_interval_list_name" - ] - tmp_key["valid_times"] = key["artifact_removed_valid_times"] - IntervalList.insert1(tmp_key, replace=True) - - # insert into computed table - self.insert1(key) - + # - detect artifacts + artifact_removed_valid_times, _ = _get_artifact_times( + recording, + sort_interval_valid_times, + **artifact_params, + ) -@schema -class ArtifactRemovedIntervalList(dj.Manual): - definition = """ - # Stores intervals without detected artifacts. - # Note that entries can come from either ArtifactDetection() or alternative artifact removal analyses. - artifact_removed_interval_list_name: varchar(200) - --- - -> ArtifactDetectionSelection - artifact_removed_valid_times: longblob - artifact_times: longblob # np array of artifact intervals - """ + # INSERT + # - into IntervalList + IntervalList.insert1( + dict( + nwb_file_name=( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("nwb_file_name"), + interval_list_name=key["artifact_id"], + valid_times=artifact_removed_valid_times, + ), + skip_duplicates=True, + ) + # - into ArtifactRemovedInterval + self.insert1(key) def _get_artifact_times( recording: si.BaseRecording, + sort_interval_valid_times: List[List], zscore_thresh: Union[float, None] = None, - amplitude_thresh: Union[float, None] = None, + amplitude_thresh_uV: Union[float, None] = None, proportion_above_thresh: float = 1.0, removal_window_ms: float = 1.0, verbose: bool = False, **job_kwargs, ): """Detects times during which artifacts do and do not occur. - Artifacts are defined as periods where the absolute value of the recording signal exceeds one - or both specified amplitude or zscore thresholds on the proportion of channels specified, - with the period extended by the removal_window_ms/2 on each side. Z-score and amplitude - threshold values of None are ignored. + + Artifacts are defined as periods where the absolute value of the recording + signal exceeds one or both specified amplitude or z-score thresholds on the + proportion of channels specified, with the period extended by the + removal_window_ms/2 on each side. Z-score and amplitude threshold values of + None are ignored. 
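    A minimal sketch of the per-sample detection rule (illustrative values;
    the actual computation is chunked and parallelized via
    ChunkRecordingExecutor):

    >>> import numpy as np
    >>> traces = np.array([[10.0, 4000.0], [20.0, 30.0]])  # (samples, channels) in uV
    >>> amplitude_thresh_uV, proportion_above_thresh = 3000.0, 0.5
    >>> n_channels_above = np.ceil(proportion_above_thresh * traces.shape[1])
    >>> above = np.abs(traces) > amplitude_thresh_uV
    >>> np.where(above.sum(axis=1) >= n_channels_above)[0]
    array([0])

    Each flagged sample is then padded by removal_window_ms / 2 on both sides,
    and the padded intervals are subtracted from sort_interval_valid_times.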
Parameters ---------- recording : si.BaseRecording + sort_interval_valid_times : List[List] + The sort interval for the recording, unit: seconds zscore_thresh : float, optional Stdev threshold for exclusion, should be >=0, defaults to None - amplitude_thresh : float, optional + amplitude_thresh_uV : float, optional Amplitude threshold for exclusion, should be >=0, defaults to None proportion_above_thresh : float, optional, should be>0 and <=1 Proportion of electrodes that need to have threshold crossings, defaults to 1 @@ -184,54 +207,44 @@ def _get_artifact_times( Intervals in which artifacts are detected (including removal windows), unit: seconds """ - if recording.get_num_segments() > 1: - valid_timestamps = np.array([]) - for segment in range(recording.get_num_segments()): - valid_timestamps = np.concatenate( - (valid_timestamps, recording.get_times(segment_index=segment)) - ) - recording = si.concatenate_recordings([recording]) - elif recording.get_num_segments() == 1: - valid_timestamps = recording.get_times(0) + valid_timestamps = recording.get_times() # if both thresholds are None, we skip artifract detection - if (amplitude_thresh is None) and (zscore_thresh is None): - recording_interval = np.asarray( - [[valid_timestamps[0], valid_timestamps[-1]]] - ) - artifact_times_empty = np.asarray([]) + if amplitude_thresh_uV is zscore_thresh is None: print( - "Amplitude and zscore thresholds are both None, skipping artifact detection" + "Amplitude and zscore thresholds are both None, " + + "skipping artifact detection" ) - return recording_interval, artifact_times_empty + return np.asarray( + [valid_timestamps[0], valid_timestamps[-1]] + ), np.asarray([]) # verify threshold parameters ( - amplitude_thresh, + amplitude_thresh_uV, zscore_thresh, proportion_above_thresh, ) = _check_artifact_thresholds( - amplitude_thresh, zscore_thresh, proportion_above_thresh + amplitude_thresh_uV, zscore_thresh, proportion_above_thresh ) # detect frames that are above threshold in parallel n_jobs = ensure_n_jobs(recording, n_jobs=job_kwargs.get("n_jobs", 1)) - print(f"using {n_jobs} jobs...") - + print(f"Using {n_jobs} jobs...") func = _compute_artifact_chunk init_func = _init_artifact_worker if n_jobs == 1: init_args = ( recording, zscore_thresh, - amplitude_thresh, + amplitude_thresh_uV, proportion_above_thresh, ) else: init_args = ( recording.to_dict(), zscore_thresh, - amplitude_thresh, + amplitude_thresh_uV, proportion_above_thresh, ) @@ -245,11 +258,12 @@ def _get_artifact_times( job_name="detect_artifact_frames", **job_kwargs, ) + artifact_frames = executor.run() artifact_frames = np.concatenate(artifact_frames) # turn ms to remove total into s to remove from either side of each detected artifact - half_removal_window_s = removal_window_ms / 1000 * 0.5 + half_removal_window_s = removal_window_ms / 2 / 1000 if len(artifact_frames) == 0: recording_interval = np.asarray( @@ -267,38 +281,31 @@ def _get_artifact_times( (len(artifact_intervals), 2), dtype=np.float64 ) for interval_idx, interval in enumerate(artifact_intervals): + interv_ind = [ + np.searchsorted( + valid_timestamps, + valid_timestamps[interval[0]] - half_removal_window_s, + ), + np.searchsorted( + valid_timestamps, + valid_timestamps[interval[1]] + half_removal_window_s, + ), + ] artifact_intervals_s[interval_idx] = [ - valid_timestamps[interval[0]] - half_removal_window_s, - valid_timestamps[interval[1]] + half_removal_window_s, + valid_timestamps[interv_ind[0]], + valid_timestamps[interv_ind[1]], ] + # make the artifact intervals 
disjoint artifact_intervals_s = reduce(_union_concat, artifact_intervals_s) - # convert seconds back to indices - artifact_intervals_new = [] - for artifact_interval_s in artifact_intervals_s: - artifact_intervals_new.append( - np.searchsorted(valid_timestamps, artifact_interval_s) - ) - - # compute set difference between intervals (of indices) - try: - # if artifact_intervals_new is a list of lists then len(artifact_intervals_new[0]) is the number of intervals - # otherwise artifact_intervals_new is a list of ints and len(artifact_intervals_new[0]) is not defined - len(artifact_intervals_new[0]) - except TypeError: - # convert to list of lists - artifact_intervals_new = [artifact_intervals_new] - artifact_removed_valid_times_ind = interval_set_difference_inds( - [(0, len(valid_timestamps) - 1)], artifact_intervals_new + # find non-artifact intervals in timestamps + artifact_removed_valid_times = interval_list_complement( + sort_interval_valid_times, artifact_intervals_s, min_length=1 + ) + artifact_removed_valid_times = reduce( + _union_concat, artifact_removed_valid_times ) - - # convert back to seconds - artifact_removed_valid_times = [] - for i in artifact_removed_valid_times_ind: - artifact_removed_valid_times.append( - (valid_timestamps[i[0]], valid_timestamps[i[1]]) - ) return artifact_removed_valid_times, artifact_intervals_s @@ -306,7 +313,7 @@ def _get_artifact_times( def _init_artifact_worker( recording, zscore_thresh=None, - amplitude_thresh=None, + amplitude_thresh_uV=None, proportion_above_thresh=1.0, ): # create a local dict per worker @@ -316,7 +323,7 @@ def _init_artifact_worker( else: worker_ctx["recording"] = recording worker_ctx["zscore_thresh"] = zscore_thresh - worker_ctx["amplitude_thresh"] = amplitude_thresh + worker_ctx["amplitude_thresh_uV"] = amplitude_thresh_uV worker_ctx["proportion_above_thresh"] = proportion_above_thresh return worker_ctx @@ -324,7 +331,7 @@ def _init_artifact_worker( def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx): recording = worker_ctx["recording"] zscore_thresh = worker_ctx["zscore_thresh"] - amplitude_thresh = worker_ctx["amplitude_thresh"] + amplitude_thresh_uV = worker_ctx["amplitude_thresh_uV"] proportion_above_thresh = worker_ctx["proportion_above_thresh"] # compute the number of electrodes that have to be above threshold nelect_above = np.ceil( @@ -338,13 +345,13 @@ def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx): ) # find the artifact occurrences using one or both thresholds, across channels - if (amplitude_thresh is not None) and (zscore_thresh is None): - above_a = np.abs(traces) > amplitude_thresh + if (amplitude_thresh_uV is not None) and (zscore_thresh is None): + above_a = np.abs(traces) > amplitude_thresh_uV above_thresh = ( np.ravel(np.argwhere(np.sum(above_a, axis=1) >= nelect_above)) + start_frame ) - elif (amplitude_thresh is None) and (zscore_thresh is not None): + elif (amplitude_thresh_uV is None) and (zscore_thresh is not None): dataz = np.abs(stats.zscore(traces, axis=1)) above_z = dataz > zscore_thresh above_thresh = ( @@ -352,7 +359,7 @@ def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx): + start_frame ) else: - above_a = np.abs(traces) > amplitude_thresh + above_a = np.abs(traces) > amplitude_thresh_uV dataz = np.abs(stats.zscore(traces, axis=1)) above_z = dataz > zscore_thresh above_thresh = ( @@ -369,20 +376,20 @@ def _compute_artifact_chunk(segment_index, start_frame, end_frame, worker_ctx): def 
_check_artifact_thresholds( - amplitude_thresh, zscore_thresh, proportion_above_thresh + amplitude_thresh_uV, zscore_thresh, proportion_above_thresh ): """Alerts user to likely unintended parameters. Not an exhaustive verification. Parameters ---------- zscore_thresh: float - amplitude_thresh: float + amplitude_thresh_uV: float proportion_above_thresh: float Return ------ zscore_thresh: float - amplitude_thresh: float + amplitude_thresh_uV: float proportion_above_thresh: float Raise @@ -391,7 +398,7 @@ def _check_artifact_thresholds( """ # amplitude or zscore thresholds should be negative, as they are applied to an absolute signal signal_thresholds = [ - t for t in [amplitude_thresh, zscore_thresh] if t is not None + t for t in [amplitude_thresh_uV, zscore_thresh] if t is not None ] for t in signal_thresholds: if t < 0: @@ -412,4 +419,43 @@ def _check_artifact_thresholds( f"Using proportion_above_thresh = 1 instead of {str(proportion_above_thresh)}" ) proportion_above_thresh = 1 - return amplitude_thresh, zscore_thresh, proportion_above_thresh + return amplitude_thresh_uV, zscore_thresh, proportion_above_thresh + + +def merge_intervals(intervals): + """Takes a list of intervals each of which is [start_time, stop_time] + and takes union over intervals that are intersecting + + Parameters + ---------- + intervals : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ + if len(intervals) == 0: + return [] + + # Sort the intervals based on their start times + intervals.sort(key=lambda x: x[0]) + + merged = [intervals[0]] + + for i in range(1, len(intervals)): + current_start, current_stop = intervals[i] + last_merged_start, last_merged_stop = merged[-1] + + if current_start <= last_merged_stop: + # Overlapping intervals, merge them + merged[-1] = [ + last_merged_start, + max(last_merged_stop, current_stop), + ] + else: + # Non-overlapping intervals, add the current one to the list + merged.append([current_start, current_stop]) + + return np.asarray(merged) diff --git a/src/spyglass/spikesorting/v1/curation.py b/src/spyglass/spikesorting/v1/curation.py index b0699ba42..75ea4fc3f 100644 --- a/src/spyglass/spikesorting/v1/curation.py +++ b/src/spyglass/spikesorting/v1/curation.py @@ -1,1097 +1,419 @@ -import json -import os -import shutil -import time -import uuid -import warnings -from pathlib import Path -from typing import List +from typing import List, Union, Dict import datajoint as dj -import numpy as np +import pynwb import spikeinterface as si -import spikeinterface.preprocessing as sip -import spikeinterface.qualitymetrics as sq +import spikeinterface.extractors as se +import spikeinterface.curation as sc -from ..common.common_interval import IntervalList -from ..common.common_nwbfile import AnalysisNwbfile -from ..utils.dj_helper_fn import fetch_nwb -from .merged_sorting_extractor import MergedSortingExtractor -from .spikesorting_recording import SortInterval, SpikeSortingRecording -from .spikesorting_sorting import SpikeSorting +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.common.common_ephys import Raw +from spyglass.spikesorting.v1.recording import ( + SpikeSortingRecording, +) +from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection -schema = dj.schema("spikesorting_curation") +schema = dj.schema("spikesorting_v1_curation") valid_labels = ["reject", "noise", "artifact", "mua", "accept"] -def apply_merge_groups_to_sorting( - sorting: si.BaseSorting, merge_groups: List[List[int]] -): - # return a new sorting 
where the units are merged according to merge_groups - # merge_groups is a list of lists of unit_ids. - # for example: merge_groups = [[1, 2], [5, 8, 4]]] - - return MergedSortingExtractor( - parent_sorting=sorting, merge_groups=merge_groups - ) - - @schema -class Curation(dj.Manual): +class CurationV1(dj.Manual): definition = """ - # Stores each spike sorting; similar to IntervalList - curation_id: int # a number corresponding to the index of this curation + # Curation of a SpikeSorting. Use `insert_curation` to insert rows. -> SpikeSorting + curation_id=0: int --- parent_curation_id=-1: int - curation_labels: blob # a dictionary of labels for the units - merge_groups: blob # a list of merge groups for the units - quality_metrics: blob # a list of quality metrics for the units (if available) - description='': varchar(1000) #optional description for this curated sort - time_of_creation: int # in Unix time, to the nearest second + -> AnalysisNwbfile + object_id: varchar(72) + merges_applied: enum("True", "False") + description: varchar(100) """ - @staticmethod + @classmethod def insert_curation( - sorting_key: dict, + cls, + sorting_id: str, parent_curation_id: int = -1, - labels=None, - merge_groups=None, - metrics=None, - description="", + labels: Union[None, Dict[str, List[str]]] = None, + merge_groups: Union[None, List[List[str]]] = None, + apply_merge: bool = False, + metrics: Union[None, Dict[str, Dict[str, float]]] = None, + description: str = "", ): - """Given a SpikeSorting key and the parent_sorting_id (and optional - arguments) insert an entry into Curation. - + """Insert an row into CurationV1. Parameters ---------- - sorting_key : dict + sorting_id : str The key for the original SpikeSorting parent_curation_id : int, optional - The id of the parent sorting + The curation id of the parent curation labels : dict or None, optional + curation labels (e.g. good, noise, mua) merge_groups : dict or None, optional + groups of unit IDs to be merged metrics : dict or None, optional - Computed metrics for sorting + Computed quality metrics, one for each neuron description : str, optional - text description of this sort + description of this curation or where it originates; e.g. 
"FigURL", by default "" + + Note + ---- + Example curation.json (output of figurl): + { + "labelsByUnit": + {"1":["noise","reject"],"10":["noise","reject"]}, + "mergeGroups": + [[11,12],[46,48],[51,54],[50,53]] + } Returns ------- curation_key : dict - """ - if parent_curation_id == -1: + if parent_curation_id <= -1: + parent_curation_id = -1 # check to see if this sorting with a parent of -1 has already been inserted and if so, warn the user - inserted_curation = (Curation & sorting_key).fetch("KEY") - if len(inserted_curation) > 0: - Warning( - "Sorting has already been inserted, returning key to previously" - "inserted curation" + if ( + len( + ( + cls + & { + "sorting_id": sorting_id, + "parent_curation_id": parent_curation_id, + } + ).fetch("KEY") ) - return inserted_curation[0] - - if labels is None: - labels = {} - if merge_groups is None: - merge_groups = [] - if metrics is None: - metrics = {} - - # generate a unique number for this curation - id = (Curation & sorting_key).fetch("curation_id") - if len(id) > 0: - curation_id = max(id) + 1 + > 0 + ): + Warning(f"Sorting has already been inserted.") + return ( + cls + & { + "sorting_id": sorting_id, + "parent_curation_id": parent_curation_id, + } + ).fetch("KEY") + + # generate curation ID + existing_curation_ids = (cls & {"sorting_id": sorting_id}).fetch( + "curation_id" + ) + if len(existing_curation_ids) > 0: + curation_id = max(existing_curation_ids) + 1 else: curation_id = 0 - # convert unit_ids in labels to integers for labels from sortingview. - new_labels = {int(unit_id): labels[unit_id] for unit_id in labels} - - sorting_key["curation_id"] = curation_id - sorting_key["parent_curation_id"] = parent_curation_id - sorting_key["description"] = description - sorting_key["curation_labels"] = new_labels - sorting_key["merge_groups"] = merge_groups - sorting_key["quality_metrics"] = metrics - sorting_key["time_of_creation"] = int(time.time()) - - # mike: added skip duplicates - Curation.insert1(sorting_key, skip_duplicates=True) + # write the curation labels, merge groups, and metrics as columns in the units table of NWB + analysis_file_name, object_id = _write_sorting_to_nwb_with_curation( + sorting_id=sorting_id, + labels=labels, + merge_groups=merge_groups, + metrics=metrics, + apply_merge=apply_merge, + ) + # INSERT + AnalysisNwbfile().add( + (SpikeSortingSelection & {"sorting_id": sorting_id}).fetch1( + "nwb_file_name" + ), + analysis_file_name, + ) - # get the primary key for this curation - c_key = Curation.fetch("KEY")[0] - curation_key = {item: sorting_key[item] for item in c_key} + key = { + "sorting_id": sorting_id, + "curation_id": curation_id, + "parent_curation_id": parent_curation_id, + "analysis_file_name": analysis_file_name, + "object_id": object_id, + "merges_applied": str(apply_merge), + "description": description, + } + cls.insert1( + key, + skip_duplicates=True, + ) - return curation_key + return key - @staticmethod - def get_recording(key: dict): - """Returns the recording extractor for the recording related to this curation + @classmethod + def get_recording(cls, key: dict) -> si.BaseRecording: + """Get recording related to this curation as spikeinterface BaseRecording Parameters ---------- key : dict - SpikeSortingRecording key - - Returns - ------- - recording_extractor : spike interface recording extractor - + primary key of CurationV1 table """ - recording_path = (SpikeSortingRecording & key).fetch1("recording_path") - return si.load_extractor(recording_path) - @staticmethod - def 
get_curated_sorting(key: dict): - """Returns the sorting extractor related to this curation, - with merges applied. - - Parameters - ---------- - key : dict - Curation key - - Returns - ------- - sorting_extractor: spike interface sorting extractor + recording_id = (SpikeSortingSelection & key).fetch1("recording_id") + analysis_file_name = ( + SpikeSortingRecording & {"recording_id": recording_id} + ).fetch1("analysis_file_name") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + recording = se.read_nwb_recording( + analysis_file_abs_path, load_time_vector=True + ) - """ - sorting_path = (SpikeSorting & key).fetch1("sorting_path") - sorting = si.load_extractor(sorting_path) - merge_groups = (Curation & key).fetch1("merge_groups") - # TODO: write code to get merged sorting extractor - if len(merge_groups) != 0: - return MergedSortingExtractor( - parent_sorting=sorting, merge_groups=merge_groups - ) - else: - return sorting + return recording - @staticmethod - def save_sorting_nwb( - key, - sorting, - timestamps, - sort_interval_list_name, - sort_interval, - labels=None, - metrics=None, - unit_ids=None, - ): - """Store a sorting in a new AnalysisNwbfile + @classmethod + def get_sorting(cls, key: dict) -> si.BaseSorting: + """Get sorting in the analysis NWB file as spikeinterface BaseSorting Parameters ---------- key : dict - key to SpikeSorting table - sorting : si.Sorting - sorting - timestamps : array_like - Time stamps of the sorted recoridng; - used to convert the spike timings from index to real time - sort_interval_list_name : str - name of sort interval - sort_interval : list - interval for start and end of sort - labels : dict, optional - curation labels, by default None - metrics : dict, optional - quality metrics, by default None - unit_ids : list, optional - IDs of units whose spiketrains to save, by default None + primary key of CurationV1 table Returns ------- - analysis_file_name : str - units_object_id : str + sorting : si.BaseSorting """ + recording = cls.get_recording(key) - sort_interval_valid_times = ( - IntervalList & {"interval_list_name": sort_interval_list_name} - ).fetch1("valid_times") - - units = dict() - units_valid_times = dict() - units_sort_interval = dict() - - if unit_ids is None: - unit_ids = sorting.get_unit_ids() - - for unit_id in unit_ids: - spike_times_in_samples = sorting.get_unit_spike_train( - unit_id=unit_id - ) - units[unit_id] = timestamps[spike_times_in_samples] - units_valid_times[unit_id] = sort_interval_valid_times - units_sort_interval[unit_id] = [sort_interval] - - analysis_file_name = AnalysisNwbfile().create(key["nwb_file_name"]) - object_ids = AnalysisNwbfile().add_units( - analysis_file_name, - units, - units_valid_times, - units_sort_interval, - metrics=metrics, - labels=labels, - ) - AnalysisNwbfile().add(key["nwb_file_name"], analysis_file_name) - - if object_ids == "": - print( - "Sorting contains no units." - "Created an empty analysis nwb file anyway." 
- ) - units_object_id = "" - else: - units_object_id = object_ids[0] - - return analysis_file_name, units_object_id - - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) - - -@schema -class WaveformParameters(dj.Manual): - definition = """ - waveform_params_name: varchar(80) # name of waveform extraction parameters - --- - waveform_params: blob # a dict of waveform extraction parameters - """ - - def insert_default(self): - waveform_params_name = "default_not_whitened" - waveform_params = { - "ms_before": 0.5, - "ms_after": 0.5, - "max_spikes_per_unit": 5000, - "n_jobs": 5, - "total_memory": "5G", - "whiten": False, - } - self.insert1( - [waveform_params_name, waveform_params], skip_duplicates=True - ) - waveform_params_name = "default_whitened" - waveform_params = { - "ms_before": 0.5, - "ms_after": 0.5, - "max_spikes_per_unit": 5000, - "n_jobs": 5, - "total_memory": "5G", - "whiten": True, - } - self.insert1( - [waveform_params_name, waveform_params], skip_duplicates=True - ) - - -@schema -class WaveformSelection(dj.Manual): - definition = """ - -> Curation - -> WaveformParameters - --- - """ - - -@schema -class Waveforms(dj.Computed): - definition = """ - -> WaveformSelection - --- - waveform_extractor_path: varchar(400) - -> AnalysisNwbfile - waveforms_object_id: varchar(40) # Object ID for the waveforms in NWB file - """ - - def make(self, key): - recording = Curation.get_recording(key) - if recording.get_num_segments() > 1: - recording = si.concatenate_recordings([recording]) - - sorting = Curation.get_curated_sorting(key) - - print("Extracting waveforms...") - waveform_params = (WaveformParameters & key).fetch1("waveform_params") - if "whiten" in waveform_params: - if waveform_params.pop("whiten"): - recording = sip.whiten(recording, dtype="float32") - - waveform_extractor_name = self._get_waveform_extractor_name(key) - key["waveform_extractor_path"] = str( - Path(os.environ["SPYGLASS_WAVEFORMS_DIR"]) - / Path(waveform_extractor_name) - ) - if os.path.exists(key["waveform_extractor_path"]): - shutil.rmtree(key["waveform_extractor_path"]) - waveforms = si.extract_waveforms( - recording=recording, - sorting=sorting, - folder=key["waveform_extractor_path"], - **waveform_params, - ) - - key["analysis_file_name"] = AnalysisNwbfile().create( - key["nwb_file_name"] + analysis_file_name = (CurationV1 & key).fetch1("analysis_file_name") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name ) - object_id = AnalysisNwbfile().add_units_waveforms( - key["analysis_file_name"], waveform_extractor=waveforms + sorting = se.read_nwb_sorting( + analysis_file_abs_path, + sampling_frequency=recording.get_sampling_frequency(), ) - key["waveforms_object_id"] = object_id - AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) - self.insert1(key) + return sorting - def load_waveforms(self, key: dict): - """Returns a spikeinterface waveform extractor specified by key + @classmethod + def get_merged_sorting(cls, key: dict) -> si.BaseSorting: + """Get sorting with merges applied. 
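        A toy sketch of what the merge step does, using spikeinterface
        directly (unit IDs and spike frames are illustrative; the real method
        reads merge_groups from the curation's NWB units table and converts
        them with _merge_dict_to_list):

        >>> import numpy as np
        >>> toy = si.NumpySorting.from_unit_dict(
        ...     {1: np.array([100, 200]), 2: np.array([1500]), 3: np.array([3000])},
        ...     sampling_frequency=30000.0,
        ... )
        >>> merged = sc.MergeUnitsSorting(parent_sorting=toy, units_to_merge=[[1, 2]])
        >>> len(merged.get_unit_ids())
        2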
Parameters ---------- key : dict - Could be an entry in Waveforms, or some other key that uniquely defines - an entry in Waveforms + CurationV1 key Returns ------- - we : spikeinterface.WaveformExtractor + sorting : si.BaseSorting + """ - we_path = (self & key).fetch1("waveform_extractor_path") - we = si.WaveformExtractor.load_from_folder(we_path) - return we + recording = cls.get_recording(key) - def fetch_nwb(self, key): - # TODO: implement fetching waveforms from NWB - return NotImplementedError + curation_key = (cls & key).fetch1() - def _get_waveform_extractor_name(self, key): - waveform_params_name = (WaveformParameters & key).fetch1( - "waveform_params_name" + sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + curation_key["analysis_file_name"] ) - - return ( - f'{key["nwb_file_name"]}_{str(uuid.uuid4())[0:8]}_' - f'{key["curation_id"]}_{waveform_params_name}_waveforms' + si_sorting = se.read_nwb_sorting( + sorting_analysis_file_abs_path, + sampling_frequency=recording.get_sampling_frequency(), ) - -@schema -class MetricParameters(dj.Manual): - definition = """ - # Parameters for computing quality metrics of sorted units - metric_params_name: varchar(200) - --- - metric_params: blob - """ - metric_default_params = { - "snr": { - "peak_sign": "neg", - "random_chunk_kwargs_dict": { - "num_chunks_per_segment": 20, - "chunk_size": 10000, - "seed": 0, - }, - }, - "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0.0}, - "nn_isolation": { - "max_spikes": 1000, - "min_spikes": 10, - "n_neighbors": 5, - "n_components": 7, - "radius_um": 100, - "seed": 0, - }, - "nn_noise_overlap": { - "max_spikes": 1000, - "min_spikes": 10, - "n_neighbors": 5, - "n_components": 7, - "radius_um": 100, - "seed": 0, - }, - "peak_channel": {"peak_sign": "neg"}, - "num_spikes": {}, - } - # Example of peak_offset parameters 'peak_offset': {'peak_sign': 'neg'} - available_metrics = [ - "snr", - "isi_violation", - "nn_isolation", - "nn_noise_overlap", - "peak_offset", - "peak_channel", - "num_spikes", - ] - - def get_metric_default_params(self, metric: str): - "Returns default params for the given metric" - return self.metric_default_params(metric) - - def insert_default(self): - self.insert1( - ["franklab_default3", self.metric_default_params], - skip_duplicates=True, - ) - - def get_available_metrics(self): - for metric in _metric_name_to_func: - if metric in self.available_metrics: - metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0] - metric_string = ("{metric_name} : {metric_doc}").format( - metric_name=metric, metric_doc=metric_doc - ) - print(metric_string + "\n") - - # TODO - def _validate_metrics_list(self, key): - """Checks whether a row to be inserted contains only the available metrics""" - # get available metrics list - # get metric list from key - # compare - return NotImplementedError - - -@schema -class MetricSelection(dj.Manual): - definition = """ - -> Waveforms - -> MetricParameters - --- - """ - - def insert1(self, key, **kwargs): - waveform_params = (WaveformParameters & key).fetch1("waveform_params") - metric_params = (MetricParameters & key).fetch1("metric_params") - if "peak_offset" in metric_params: - if waveform_params["whiten"]: - warnings.warn( - "Calculating 'peak_offset' metric on " - "whitened waveforms may result in slight " - "discrepancies" - ) - if "peak_channel" in metric_params: - if waveform_params["whiten"]: - Warning( - "Calculating 'peak_channel' metric on " - "whitened waveforms may result in slight " - "discrepancies" - ) - 
super().insert1(key, **kwargs) - - -@schema -class QualityMetrics(dj.Computed): - definition = """ - -> MetricSelection - --- - quality_metrics_path: varchar(500) - -> AnalysisNwbfile - object_id: varchar(40) # Object ID for the metrics in NWB file - """ - - def make(self, key): - waveform_extractor = Waveforms().load_waveforms(key) - qm = {} - params = (MetricParameters & key).fetch1("metric_params") - for metric_name, metric_params in params.items(): - metric = self._compute_metric( - waveform_extractor, metric_name, **metric_params + with pynwb.NWBHDF5IO( + sorting_analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbfile = io.read() + nwb_sorting = nwbfile.objects[curation_key["object_id"]] + merge_groups = nwb_sorting["merge_groups"][:] + + if merge_groups: + units_to_merge = _merge_dict_to_list(merge_groups) + return sc.MergeUnitsSorting( + parent_sorting=si_sorting, units_to_merge=units_to_merge ) - qm[metric_name] = metric - qm_name = self._get_quality_metrics_name(key) - key["quality_metrics_path"] = str( - Path(os.environ["SPYGLASS_WAVEFORMS_DIR"]) / Path(qm_name + ".json") - ) - # save metrics dict as json - print(f"Computed all metrics: {qm}") - self._dump_to_json(qm, key["quality_metrics_path"]) - - key["analysis_file_name"] = AnalysisNwbfile().create( - key["nwb_file_name"] - ) - key["object_id"] = AnalysisNwbfile().add_units_metrics( - key["analysis_file_name"], metrics=qm - ) - AnalysisNwbfile().add(key["nwb_file_name"], key["analysis_file_name"]) - - self.insert1(key) - - def _get_quality_metrics_name(self, key): - wf_name = Waveforms()._get_waveform_extractor_name(key) - qm_name = wf_name + "_qm" - return qm_name - - def _compute_metric(self, waveform_extractor, metric_name, **metric_params): - peak_sign_metrics = ["snr", "peak_offset", "peak_channel"] - metric_func = _metric_name_to_func[metric_name] - # TODO clean up code below - if metric_name == "isi_violation": - metric = metric_func(waveform_extractor, **metric_params) - elif metric_name in peak_sign_metrics: - if "peak_sign" in metric_params: - metric = metric_func( - waveform_extractor, - peak_sign=metric_params.pop("peak_sign"), - **metric_params, - ) - else: - raise Exception( - f"{peak_sign_metrics} metrics require peak_sign", - f"to be defined in the metric parameters", - ) else: - metric = {} - for unit_id in waveform_extractor.sorting.get_unit_ids(): - metric[str(unit_id)] = metric_func( - waveform_extractor, this_unit_id=unit_id, **metric_params - ) - # nn_isolation returns tuple with isolation and unit number. We only want isolation. 
- if metric_name == "nn_isolation": - metric[str(unit_id)] = metric[str(unit_id)][0] - return metric - - def _dump_to_json(self, qm_dict, save_path): - new_qm = {} - for key, value in qm_dict.items(): - m = {} - for unit_id, metric_val in value.items(): - m[str(unit_id)] = np.float64(metric_val) - new_qm[str(key)] = m - with open(save_path, "w", encoding="utf-8") as f: - json.dump(new_qm, f, ensure_ascii=False, indent=4) - - -def _compute_isi_violation_fractions(waveform_extractor, **metric_params): - """Computes the per unit fraction of interspike interval violations to total spikes.""" - isi_threshold_ms = metric_params["isi_threshold_ms"] - min_isi_ms = metric_params["min_isi_ms"] - - # Extract the total number of spikes that violated the isi_threshold for each unit - isi_violation_counts = sq.compute_isi_violations( - waveform_extractor, - isi_threshold_ms=isi_threshold_ms, - min_isi_ms=min_isi_ms, - ).isi_violations_count - - # Extract the total number of spikes from each unit. The number of ISIs is one less than this - num_spikes = sq.compute_num_spikes(waveform_extractor) + return si_sorting - # Calculate the fraction of ISIs that are violations - isi_viol_frac_metric = { - str(unit_id): isi_violation_counts[unit_id] / (num_spikes[unit_id] - 1) - for unit_id in waveform_extractor.sorting.get_unit_ids() - } - return isi_viol_frac_metric - -def _get_peak_offset( - waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params +def _write_sorting_to_nwb_with_curation( + sorting_id: str, + labels: Union[None, Dict[str, List[str]]] = None, + merge_groups: Union[None, List[List[str]]] = None, + metrics: Union[None, Dict[str, Dict[str, float]]] = None, + apply_merge: bool = False, ): - """Computes the shift of the waveform peak from center of window.""" - if "peak_sign" in metric_params: - del metric_params["peak_sign"] - peak_offset_inds = ( - si.postprocessing.get_template_extremum_channel_peak_shift( - waveform_extractor=waveform_extractor, - peak_sign=peak_sign, - **metric_params, - ) + """Save sorting to NWB with curation information. + Curation information is saved as columns in the units table of the NWB file. + + Parameters + ---------- + sorting_id : str + key for the sorting + labels : dict or None, optional + curation labels (e.g. 
good, noise, mua) + merge_groups : list or None, optional + groups of unit IDs to be merged + metrics : dict or None, optional + Computed quality metrics, one for each cell + apply_merge : bool, optional + whether to apply the merge groups to the sorting before saving, by default False + + Returns + ------- + analysis_nwb_file : str + name of analysis NWB file containing the sorting and curation information + object_id : str + object_id of the units table in the analysis NWB file + """ + # FETCH: + # - primary key for the associated sorting and recording + nwb_file_name = (SpikeSortingSelection & {"sorting_id": sorting_id}).fetch1( + "nwb_file_name" ) - peak_offset = {key: int(abs(val)) for key, val in peak_offset_inds.items()} - return peak_offset - -def _get_peak_channel( - waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params -): - """Computes the electrode_id of the channel with the extremum peak for each unit.""" - if "peak_sign" in metric_params: - del metric_params["peak_sign"] - peak_channel_dict = si.postprocessing.get_template_extremum_channel( - waveform_extractor=waveform_extractor, - peak_sign=peak_sign, - **metric_params, + # get sorting + sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + (SpikeSorting & {"sorting_id": sorting_id}).fetch1("analysis_file_name") ) - peak_channel = {key: int(val) for key, val in peak_channel_dict.items()} - return peak_channel - - -def _get_num_spikes( - waveform_extractor: si.WaveformExtractor, this_unit_id: int -): - """Computes the number of spikes for each unit.""" - all_spikes = sq.compute_num_spikes(waveform_extractor) - cluster_spikes = all_spikes[this_unit_id] - return cluster_spikes - - -_metric_name_to_func = { - "snr": sq.compute_snrs, - "isi_violation": _compute_isi_violation_fractions, - "nn_isolation": sq.nearest_neighbors_isolation, - "nn_noise_overlap": sq.nearest_neighbors_noise_overlap, - "peak_offset": _get_peak_offset, - "peak_channel": _get_peak_channel, - "num_spikes": _get_num_spikes, -} - - -@schema -class AutomaticCurationParameters(dj.Manual): - definition = """ - auto_curation_params_name: varchar(200) # name of this parameter set - --- - merge_params: blob # dictionary of params to merge units - label_params: blob # dictionary params to label units - """ - - def insert1(self, key, **kwargs): - # validate the labels and then insert - # TODO: add validation for merge_params - for metric in key["label_params"]: - if metric not in _metric_name_to_func: - raise Exception(f"{metric} not in list of available metrics") - comparison_list = key["label_params"][metric] - if comparison_list[0] not in _comparison_to_function: - raise Exception( - f'{metric}: "{comparison_list[0]}" ' - f"not in list of available comparisons" - ) - if not isinstance(comparison_list[1], (int, float)): - raise Exception( - f"{metric}: {comparison_list[1]} is of type " - f"{type(comparison_list[1])} and not a number" - ) - for label in comparison_list[2]: - if label not in valid_labels: - raise Exception( - f'{metric}: "{label}" ' - f"not in list of valid labels: {valid_labels}" - ) - super().insert1(key, **kwargs) - - def insert_default(self): - # label_params parsing: Each key is the name of a metric, - # the contents are a three value list with the comparison, a value, - # and a list of labels to apply if the comparison is true - default_params = { - "auto_curation_params_name": "default", - "merge_params": {}, - "label_params": { - "nn_noise_overlap": [">", 0.1, ["noise", "reject"]] - }, - } - 
self.insert1(default_params, skip_duplicates=True) - - # Second default parameter set for not applying any labels, - # or merges, but adding metrics - no_label_params = { - "auto_curation_params_name": "none", - "merge_params": {}, - "label_params": {}, - } - self.insert1(no_label_params, skip_duplicates=True) - - -@schema -class AutomaticCurationSelection(dj.Manual): - definition = """ - -> QualityMetrics - -> AutomaticCurationParameters - """ - - -_comparison_to_function = { - "<": np.less, - "<=": np.less_equal, - ">": np.greater, - ">=": np.greater_equal, - "==": np.equal, -} - - -@schema -class AutomaticCuration(dj.Computed): - definition = """ - -> AutomaticCurationSelection - --- - auto_curation_key: blob # the key to the curation inserted by make - """ - - def make(self, key): - metrics_path = (QualityMetrics & key).fetch1("quality_metrics_path") - with open(metrics_path) as f: - quality_metrics = json.load(f) - - # get the curation information and the curated sorting - parent_curation = (Curation & key).fetch(as_dict=True)[0] - parent_merge_groups = parent_curation["merge_groups"] - parent_labels = parent_curation["curation_labels"] - parent_curation_id = parent_curation["curation_id"] - parent_sorting = Curation.get_curated_sorting(key) - - merge_params = (AutomaticCurationParameters & key).fetch1( - "merge_params" - ) - merge_groups, units_merged = self.get_merge_groups( - parent_sorting, parent_merge_groups, quality_metrics, merge_params - ) - - label_params = (AutomaticCurationParameters & key).fetch1( - "label_params" - ) - labels = self.get_labels( - parent_sorting, parent_labels, quality_metrics, label_params - ) - - # keep the quality metrics only if no merging occurred. - metrics = quality_metrics if not units_merged else None - - # insert this sorting into the CuratedSpikeSorting Table - # first remove keys that aren't part of the Sorting (the primary key of curation) - c_key = (SpikeSorting & key).fetch("KEY")[0] - curation_key = {item: key[item] for item in key if item in c_key} - key["auto_curation_key"] = Curation.insert_curation( - curation_key, - parent_curation_id=parent_curation_id, - labels=labels, - merge_groups=merge_groups, - metrics=metrics, - description="auto curated", + sorting = se.read_nwb_sorting( + sorting_analysis_file_abs_path, + sampling_frequency=(Raw & {"nwb_file_name": nwb_file_name}).fetch1( + "sampling_rate" + ), + ) + if apply_merge: + sorting = sc.MergeUnitsSorting( + parent_sorting=sorting, units_to_merge=merge_groups ) - - self.insert1(key) - - @staticmethod - def get_merge_groups( - sorting, parent_merge_groups, quality_metrics, merge_params - ): - """Identifies units to be merged based on the quality_metrics and - merge parameters and returns an updated list of merges for the curation. - - Parameters - --------- - sorting : spikeinterface.sorting - parent_merge_groups : list - Information about previous merges - quality_metrics : list - merge_params : dict - - Returns - ------- - merge_groups : list of lists - merge_occurred : bool - - """ - - # overview: - # 1. Use quality metrics to determine merge groups for units - # 2. Combine merge groups with current merge groups to produce union of merges - - if not merge_params: - return parent_merge_groups, False - else: - # TODO: use the metrics to identify clusters that should be merged - # new_merges should then reflect those merges and the line below should be deleted. 
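            # One possible sketch of that TODO (illustrative only, assuming a
            # pairwise metric such as "cosine_similarity" under quality_metrics
            # and a rule like [">", 0.9] in merge_params):
            #
            #     compare = _comparison_to_function[merge_params["cosine_similarity"][0]]
            #     threshold = merge_params["cosine_similarity"][1]
            #     new_merges = [
            #         sorted([u1, u2])
            #         for u1, row in quality_metrics["cosine_similarity"].items()
            #         for u2, score in row.items()
            #         if u1 < u2 and compare(score, threshold)
            #     ]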
- new_merges = [] - # append these merges to the parent merge_groups - for new_merge in new_merges: - # check to see if the first cluster listed is in a current merge group - for previous_merge in parent_merge_groups: - if new_merge[0] == previous_merge[0]: - # add the additional units in new_merge to the identified merge group. - previous_merge.extend(new_merge[1:]) - previous_merge.sort() - break - else: - # append this merge group to the list if no previous merge - parent_merge_groups.append(new_merge) - return parent_merge_groups.sort(), True - - @staticmethod - def get_labels(sorting, parent_labels, quality_metrics, label_params): - """Returns a dictionary of labels using quality_metrics and label - parameters. - - Parameters - --------- - sorting : spikeinterface.sorting - parent_labels : list - Information about previous merges - quality_metrics : list - label_params : dict - - Returns - ------- - parent_labels : list - - """ - # overview: - # 1. Use quality metrics to determine labels for units - # 2. Append labels to current labels, checking for inconsistencies - if not label_params: - return parent_labels - else: - for metric in label_params: - if metric not in quality_metrics: - Warning(f"{metric} not found in quality metrics; skipping") + merge_groups = None + + unit_ids = sorting.get_unit_ids() + + # create new analysis nwb file + analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( + path=analysis_nwb_file_abs_path, + mode="a", + load_namespaces=True, + ) as io: + nwbf = io.read() + # write sorting to the nwb file + for unit_id in unit_ids: + spike_times = sorting.get_unit_spike_train(unit_id) + nwbf.add_unit( + spike_times=spike_times, + id=unit_id, + ) + # add labels, merge groups, metrics + if labels is not None: + label_values = [] + for unit_id in unit_ids: + if unit_id not in labels: + label_values.append([""]) else: - compare = _comparison_to_function[label_params[metric][0]] - - for unit_id in quality_metrics[metric].keys(): - # compare the quality metric to the threshold with the specified operator - # note that label_params[metric] is a three element list with a comparison operator as a string, - # the threshold value, and a list of labels to be applied if the comparison is true - if compare( - quality_metrics[metric][unit_id], - label_params[metric][1], - ): - if unit_id not in parent_labels: - parent_labels[unit_id] = label_params[metric][2] - # check if the label is already there, and if not, add it - elif ( - label_params[metric][2] - not in parent_labels[unit_id] - ): - parent_labels[unit_id].extend( - label_params[metric][2] - ) - return parent_labels - - -@schema -class CuratedSpikeSortingSelection(dj.Manual): - definition = """ - -> Curation - """ - - -@schema -class CuratedSpikeSorting(dj.Computed): - definition = """ - -> CuratedSpikeSortingSelection - --- - -> AnalysisNwbfile - units_object_id: varchar(40) - """ - - class Unit(dj.Part): - definition = """ - # Table for holding sorted units - -> CuratedSpikeSorting - unit_id: int # ID for each unit - --- - label='': varchar(200) # optional set of labels for each unit - nn_noise_overlap=-1: float # noise overlap metric for each unit - nn_isolation=-1: float # isolation score metric for each unit - isi_violation=-1: float # ISI violation score for each unit - snr=0: float # SNR for each unit - firing_rate=-1: float # firing rate - num_spikes=-1: int # total number of spikes - peak_channel=null: int # 
channel of maximum amplitude for each unit - """ - - def make(self, key): - unit_labels_to_remove = ["reject"] - # check that the Curation has metrics - metrics = (Curation & key).fetch1("quality_metrics") - if metrics == {}: - Warning( - f"Metrics for Curation {key} should normally be calculated before insertion here" + label_values.append(labels[unit_id]) + nwbf.add_unit_column( + name="curation_label", + description="curation label", + data=label_values, ) + if merge_groups is not None: + merge_groups_dict = _list_to_merge_dict(merge_groups, unit_ids) + merge_groups_list = [ + [""] if value == [] else value + for value in merge_groups_dict.values() + ] + nwbf.add_unit_column( + name="merge_groups", + description="merge groups", + data=merge_groups_list, + ) + if metrics is not None: + for metric, metric_dict in metrics.items(): + metric_values = [] + for unit_id in unit_ids: + if unit_id not in metric_dict: + metric_values.append([]) + else: + metric_values.append(metric_dict[unit_id]) + nwbf.add_unit_column( + name=metric, + description=metric, + data=metric_values, + ) - sorting = Curation.get_curated_sorting(key) - unit_ids = sorting.get_unit_ids() - # Get the labels for the units, add only those units that do not have 'reject' or 'noise' labels - unit_labels = (Curation & key).fetch1("curation_labels") - accepted_units = [] - for unit_id in unit_ids: - if unit_id in unit_labels: - if ( - len(set(unit_labels_to_remove) & set(unit_labels[unit_id])) - == 0 - ): - accepted_units.append(unit_id) - else: - accepted_units.append(unit_id) - - # get the labels for the accepted units - labels = {} - for unit_id in accepted_units: - if unit_id in unit_labels: - labels[unit_id] = ",".join(unit_labels[unit_id]) - - # convert unit_ids in metrics to integers, including only accepted units. - # TODO: convert to int this somewhere else - final_metrics = {} - for metric in metrics: - final_metrics[metric] = { - int(unit_id): metrics[metric][unit_id] - for unit_id in metrics[metric] - if int(unit_id) in accepted_units - } + units_object_id = nwbf.units.object_id + io.write(nwbf) + return analysis_nwb_file, units_object_id - print(f"Found {len(accepted_units)} accepted units") - # get the sorting and save it in the NWB file - sorting = Curation.get_curated_sorting(key) - recording = Curation.get_recording(key) +def _union_intersecting_lists(lists): + result = [] - # get the sort_interval and sorting interval list - sort_interval_name = (SpikeSortingRecording & key).fetch1( - "sort_interval_name" - ) - sort_interval = (SortInterval & key).fetch1("sort_interval") - sort_interval_list_name = (SpikeSorting & key).fetch1( - "artifact_removed_interval_list_name" - ) + while lists: + first, *rest = lists + first = set(first) - timestamps = SpikeSortingRecording._get_recording_timestamps(recording) + merged = True + while merged: + merged = False + for idx, other in enumerate(rest): + if first.intersection(other): + first.update(other) + del rest[idx] + merged = True + break - ( - key["analysis_file_name"], - key["units_object_id"], - ) = Curation().save_sorting_nwb( - key, - sorting, - timestamps, - sort_interval_list_name, - sort_interval, - metrics=final_metrics, - unit_ids=accepted_units, - labels=labels, - ) - self.insert1(key) + result.append(list(first)) + lists = rest - # now add the units - # Remove the non primary key entries. 
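# For reference, the intended behavior of the _union_intersecting_lists helper
# added in this hunk (values illustrative; ordering within each group follows
# set iteration order):
#
#     _union_intersecting_lists([[1, 2], [2, 3], [4, 5]])
#     # -> [[1, 2, 3], [4, 5]]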
- del key["units_object_id"] - del key["analysis_file_name"] + return result - metric_fields = self.metrics_fields() - for unit_id in accepted_units: - key["unit_id"] = unit_id - if unit_id in labels: - key["label"] = labels[unit_id] - for field in metric_fields: - if field in final_metrics: - key[field] = final_metrics[field][unit_id] - else: - Warning( - f"No metric named {field} in computed unit quality metrics; skipping" - ) - CuratedSpikeSorting.Unit.insert1(key) - def metrics_fields(self): - """Returns a list of the metrics that are currently in the Units table.""" - unit_info = self.Unit().fetch(limit=1, format="frame") - unit_fields = [column for column in unit_info.columns] - unit_fields.remove("label") - return unit_fields +def _list_to_merge_dict(lists_of_strings: List, target_strings: List) -> dict: + """Converts a list of merge groups to a dict. + The keys of the dict (unit ids) are provided separately in case + the merge groups do not contain all the unit ids. + Example: [[1,2,3],[4,5]], [1,2,3,4,5,6] -> {1: [2, 3], 2:[1,3], 3:[1,2] 4: [5], 5: [4], 6: []} - def fetch_nwb(self, *attrs, **kwargs): - return fetch_nwb( - self, (AnalysisNwbfile, "analysis_file_abs_path"), *attrs, **kwargs - ) + Parameters + ---------- + lists_of_strings : _type_ + _description_ + target_strings : _type_ + _description_ - -@schema -class UnitInclusionParameters(dj.Manual): - definition = """ - unit_inclusion_param_name: varchar(80) # the name of the list of thresholds for unit inclusion - --- - inclusion_param_dict: blob # the dictionary of inclusion / exclusion parameters + Returns + ------- + _type_ + _description_ """ + lists_of_strings = _union_intersecting_lists(lists_of_strings) + result = {string: [] for string in target_strings} - def insert1(self, key, **kwargs): - # check to see that the dictionary fits the specifications - # The inclusion parameter dict has the following form: - # param_dict['metric_name'] = (operator, value) - # where operator is '<', '>', <=', '>=', or '==' and value is the comparison (float) value to be used () - # param_dict['exclude_labels'] = [list of labels to exclude] - pdict = key["inclusion_param_dict"] - metrics_list = CuratedSpikeSorting().metrics_fields() + for lst in lists_of_strings: + for string in target_strings: + if string in lst: + result[string].extend([item for item in lst if item != string]) - for k in pdict: - if k not in metrics_list and k != "exclude_labels": - raise Exception( - f"key {k} is not a valid element of the inclusion_param_dict" - ) - if k in metrics_list: - if pdict[k][0] not in _comparison_to_function: - raise Exception( - f"operator {pdict[k][0]} for metric {k} is not in the valid operators list: {_comparison_to_function.keys()}" - ) - if k == "exclude_labels": - for label in pdict[k]: - if label not in valid_labels: - raise Exception( - f"exclude label {label} is not in the valid_labels list: {valid_labels}" - ) - super().insert1(key, **kwargs) + return result - def get_included_units( - self, curated_sorting_key, unit_inclusion_param_name - ): - """given a reference to a set of curated sorting units and the name of a unit inclusion parameter list, returns - Parameters - ---------- - curated_sorting_key : dict - key to select a set of curated sorting - unit_inclusion_param_name : str - name of a unit inclusion parameter entry +def _reverse_associations(assoc_dict): + result = [] - Returns - ------unit key - dict - key to select all of the included units - """ - curated_sortings = (CuratedSpikeSorting() & 
curated_sorting_key).fetch() - inc_param_dict = ( - UnitInclusionParameters - & {"unit_inclusion_param_name": unit_inclusion_param_name} - ).fetch1("inclusion_param_dict") - units = (CuratedSpikeSorting().Unit() & curated_sortings).fetch() - units_key = (CuratedSpikeSorting().Unit() & curated_sortings).fetch( - "KEY" - ) - # get a list of the metrics in the units table - metrics_list = CuratedSpikeSorting().metrics_fields() - # get the list of labels to exclude if there is one - if "exclude_labels" in inc_param_dict: - exclude_labels = inc_param_dict["exclude_labels"] - del inc_param_dict["exclude_labels"] + for key, values in assoc_dict.items(): + if values: + result.append([key] + values) else: - exclude_labels = [] + result.append([key]) - # create a list of the units to kepp. - keep = np.asarray([True] * len(units)) - for metric in inc_param_dict: - # for all units, go through each metric, compare it to the value specified, and update the list to be kept - keep = np.logical_and( - keep, - _comparison_to_function[inc_param_dict[metric][0]]( - units[metric], inc_param_dict[metric][1] - ), - ) + return result - # now exclude by label if it is specified - if len(exclude_labels): - included_units = [] - for unit_ind in np.ravel(np.argwhere(keep)): - labels = units[unit_ind]["label"].split(",") - exclude = False - for label in labels: - if label in exclude_labels: - keep[unit_ind] = False - break - # return units that passed all of the tests - # TODO: Make this more efficient - return {i: units_key[i] for i in np.ravel(np.argwhere(keep))} + +def _merge_dict_to_list(merge_groups: dict) -> List: + """Converts dict of merge groups to list of merge groups. + Example: {1: [2, 3], 4: [5]} -> [[1, 2, 3], [4, 5]] + """ + units_to_merge = _union_intersecting_lists( + _reverse_associations(merge_groups) + ) + units_to_merge = [lst for lst in units_to_merge if len(lst) >= 2] + return units_to_merge diff --git a/src/spyglass/spikesorting/v1/figurl_curation.py b/src/spyglass/spikesorting/v1/figurl_curation.py index 2280f0c2b..20e2109b5 100644 --- a/src/spyglass/spikesorting/v1/figurl_curation.py +++ b/src/spyglass/spikesorting/v1/figurl_curation.py @@ -1,130 +1,169 @@ -import datajoint as dj - from typing import Any, Union, List, Dict -from .spikesorting_curation import Curation -from .spikesorting_recording import SpikeSortingRecording -from .spikesorting_sorting import SpikeSorting +import datajoint as dj +import pynwb import spikeinterface as si -from sortingview.SpikeSortingView import SpikeSortingView +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.spikesorting.v1.sorting import SpikeSorting +from spyglass.spikesorting.v1.curation import CurationV1, _merge_dict_to_list + import kachery_cloud as kcl import sortingview.views as vv +from sortingview.SpikeSortingView import SpikeSortingView -schema = dj.schema("spikesorting_curation_figurl") - -# A curation figURL is a link to a visualization of a curation. -# Optionally you can specify a new_curation_uri which will be -# the location of the new manually-edited curation. The -# new_curation_uri should be a github uri of the form -# gh://user/repo/branch/path/to/curation.json -# and ideally the path should be determined by the primary key -# of the curation. The new_curation_uri can also be blank if no -# further manual curation is planned. 
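+# A FigURL curation is a link to a sortingview visualization of a spike
+# sorting, seeded with an initial curation. For orientation, the curation
+# document built by `FigURLCurationSelection.generate_curation_uri` below and
+# stored with `kcl.store_json` (which returns its kachery-cloud URI) is a JSON
+# object of the form (unit ids and labels illustrative only):
+#
+#     {
+#         "labelsByUnit": {"1": ["good"], "2": ["noise", "reject"]},
+#         "mergeGroups": [["3", "4"]],
+#     }
+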
+schema = dj.schema("spikesorting_v1_figurl_curation") @schema -class CurationFigurlSelection(dj.Manual): +class FigURLCurationSelection(dj.Manual): definition = """ - -> Curation + -> CurationV1 + curation_uri: varchar(1000) # GitHub-based URI to a file to which the manual curation will be saved --- - new_curation_uri: varchar(2000) + metrics_figurl: blob # metrics to display in the figURL """ + @staticmethod + def generate_curation_uri(key: Dict) -> str: + """Generates a kachery-cloud URI containing curation info from a row in CurationV1 table + + Parameters + ---------- + key : dict + primary key from CurationV1 + """ + curation_key = (CurationV1 & key).fetch1() + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + curation_key["analysis_file_name"] + ) + with pynwb.NWBHDF5IO( + analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbfile = io.read() + nwb_sorting = nwbfile.objects[curation_key["object_id"]] + unit_ids = nwb_sorting["id"][:] + labels = nwb_sorting["labels"][:] + merge_groups = nwb_sorting["merge_groups"][:] + + unit_ids = [str(unit_id) for unit_id in unit_ids] + + if labels: + labels_dict = dict(zip(unit_ids, labels)) + else: + labels_dict = {} + + if merge_groups: + merge_groups_list = _merge_dict_to_list(merge_groups) + merge_groups_list = [ + [str(unit_id) for unit_id in merge_group] + for merge_group in merge_groups_list + ] + else: + merge_groups_list = [] + + curation_dict = { + "labelsByUnit": labels_dict, + "mergeGroups": merge_groups_list, + } + curation_uri = kcl.store_json(curation_dict) + + return curation_uri + @schema -class CurationFigurl(dj.Computed): +class FigURLCuration(dj.Computed): definition = """ - -> CurationFigurlSelection + -> FigURLCurationSelection --- - url: varchar(2000) - initial_curation_uri: varchar(2000) - new_curation_uri: varchar(2000) + url: varchar(1000) """ def make(self, key: dict): - """Create a Curation Figurl - Parameters - ---------- - key : dict - primary key of an entry from CurationFigurlSelection table - """ - - # get new_curation_uri from selection table - new_curation_uri = (CurationFigurlSelection & key).fetch1( - "new_curation_uri" + # FETCH + sorting_analysis_file_name = (CurationV1 & key).fetch1( + "analysis_file_name" ) + object_id = (CurationV1 & key).fetch1("object_id") + recording_label = (SpikeSorting & key).fetch1("recording_id") - # fetch - recording_path = (SpikeSortingRecording & key).fetch1("recording_path") - sorting_path = (SpikeSorting & key).fetch1("sorting_path") - recording_label = SpikeSortingRecording._get_recording_name(key) - sorting_label = SpikeSorting._get_sorting_name(key) - unit_metrics = _reformat_metrics( - (Curation & key).fetch1("quality_metrics") + # DO + sorting_analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + sorting_analysis_file_name ) - initial_labels = (Curation & key).fetch1("curation_labels") - initial_merge_groups = (Curation & key).fetch1("merge_groups") - - # new_curation_uri = key["new_curation_uri"] - - # Create the initial curation and store it in kachery - for k, v in initial_labels.items(): - new_list = [] - for item in v: - if item not in new_list: - new_list.append(item) - initial_labels[k] = new_list - initial_curation = { - "labelsByUnit": initial_labels, - "mergeGroups": initial_merge_groups, - } - initial_curation_uri = kcl.store_json(initial_curation) + recording = CurationV1.get_recording(key) + sorting = CurationV1.get_sorting(key) + sorting_label = key["sorting_id"] + curation_uri = key["curation_uri"] + + metric_dict = {} + with 
pynwb.NWBHDF5IO( + sorting_analysis_file_abs_path, "r", load_namespaces=True + ) as io: + nwbf = io.read() + nwb_sorting = nwbf.objects[object_id] + unit_ids = nwb_sorting["id"][:] + for metric in key["metrics_figurl"]: + metric_dict[metric] = dict( + zip(unit_ids, nwb_sorting[metric][:]) + ) - # Get the recording/sorting extractors - R = si.load_extractor(recording_path) - if R.get_num_segments() > 1: - R = si.concatenate_recordings([R]) - S = si.load_extractor(sorting_path) + unit_metrics = _reformat_metrics(metric_dict) + + # TODO: figure out a way to specify the similarity metrics # Generate the figURL - url = _generate_the_figurl( - R=R, - S=S, - initial_curation_uri=initial_curation_uri, - new_curation_uri=new_curation_uri, + key["url"] = _generate_figurl( + R=recording, + S=sorting, + initial_curation_uri=curation_uri, recording_label=recording_label, sorting_label=sorting_label, unit_metrics=unit_metrics, ) - # insert - key["url"] = url - key["initial_curation_uri"] = initial_curation_uri - key["new_curation_uri"] = new_curation_uri - self.insert1(key) + # INSERT + self.insert1(key, skip_duplicates=True) + + @classmethod + def get_labels(cls): + return NotImplementedError + @classmethod + def get_merge_groups(cls): + return NotImplementedError -def _generate_the_figurl( - *, + +def _generate_figurl( R: si.BaseRecording, S: si.BaseSorting, - unit_metrics: Union[List[Any], None] = None, initial_curation_uri: str, recording_label: str, sorting_label: str, - new_curation_uri: str, + unit_metrics: Union[List[Any], None] = None, + segment_duration_sec=1200, + snippet_ms_before=1, + snippet_ms_after=1, + max_num_snippets_per_segment=1000, + channel_neighborhood_size=5, + raster_plot_subsample_max_firing_rate=50, + spike_amplitudes_subsample_max_firing_rate=50, ): print("Preparing spikesortingview data") + sampling_frequency = R.get_sampling_frequency() X = SpikeSortingView.create( recording=R, sorting=S, - segment_duration_sec=60 * 20, - snippet_len=(20, 20), - max_num_snippets_per_segment=100, - channel_neighborhood_size=7, + segment_duration_sec=segment_duration_sec, + snippet_len=( + int(snippet_ms_before * sampling_frequency / 1000), + int(snippet_ms_after * sampling_frequency / 1000), + ), + max_num_snippets_per_segment=max_num_snippets_per_segment, + channel_neighborhood_size=channel_neighborhood_size, ) + # create a fake unit similarity matrix (for future reference) # similarity_scores = [] # for u1 in X.unit_ids: @@ -136,7 +175,7 @@ def _generate_the_figurl( # similarity=similarity_matrix[(X.unit_ids==u1),(X.unit_ids==u2)] # ) # ) - # Create the similarity matrix view + # # Create the similarity matrix view # unit_similarity_matrix_view = vv.UnitSimilarityMatrix( # unit_ids=X.unit_ids, # similarity_scores=similarity_scores @@ -144,8 +183,12 @@ def _generate_the_figurl( # Assemble the views in a layout # You can replace this with other layouts - raster_plot_subsample_max_firing_rate = 50 - spike_amplitudes_subsample_max_firing_rate = 50 + raster_plot_subsample_max_firing_rate = ( + raster_plot_subsample_max_firing_rate + ) + spike_amplitudes_subsample_max_firing_rate = ( + spike_amplitudes_subsample_max_firing_rate + ) view = vv.MountainLayout( items=[ vv.MountainLayoutItem( @@ -195,14 +238,10 @@ def _generate_the_figurl( ), ] ) - url_state = ( - { - "initialSortingCuration": initial_curation_uri, - "sortingCuration": new_curation_uri, - } - if new_curation_uri - else {"sortingCuration": initial_curation_uri} - ) + url_state = { + "initialSortingCuration": initial_curation_uri, + 
"sortingCuration": initial_curation_uri, + } label = f"{recording_label} {sorting_label}" url = view.url(label=label, state=url_state) return url diff --git a/src/spyglass/spikesorting/v1/metric_curation.py b/src/spyglass/spikesorting/v1/metric_curation.py new file mode 100644 index 000000000..7e278e54a --- /dev/null +++ b/src/spyglass/spikesorting/v1/metric_curation.py @@ -0,0 +1,555 @@ +import os +from typing import List, Union, Dict, Any + +import datajoint as dj +import numpy as np +import pynwb +import spikeinterface as si +import spikeinterface.preprocessing as sp +import spikeinterface.qualitymetrics as sq + +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.spikesorting.v1.recording import SpikeSortingRecording +from spyglass.spikesorting.v1.sorting import SpikeSorting, SpikeSortingSelection +from spyglass.spikesorting.v1.curation import ( + CurationV1, + _list_to_merge_dict, + _merge_dict_to_list, +) +from spyglass.spikesorting.v1.metric_utils import ( + get_num_spikes, + get_peak_channel, + get_peak_offset, + compute_isi_violation_fractions, +) +from spyglass.spikesorting.v1.utils import generate_nwb_uuid + +schema = dj.schema("spikesorting_v1_metric_curation") + + +_metric_name_to_func = { + "snr": sq.compute_snrs, + "isi_violation": compute_isi_violation_fractions, + "nn_isolation": sq.nearest_neighbors_isolation, + "nn_noise_overlap": sq.nearest_neighbors_noise_overlap, + "peak_offset": get_peak_offset, + "peak_channel": get_peak_channel, + "num_spikes": get_num_spikes, +} + +_comparison_to_function = { + "<": np.less, + "<=": np.less_equal, + ">": np.greater, + ">=": np.greater_equal, + "==": np.equal, +} + + +@schema +class WaveformParameters(dj.Lookup): + definition = """ + waveform_param_name: varchar(80) # name of waveform extraction parameters + --- + waveform_params: blob # a dict of waveform extraction parameters + """ + + contents = [ + [ + "default_not_whitened", + { + "ms_before": 0.5, + "ms_after": 0.5, + "max_spikes_per_unit": 5000, + "n_jobs": 5, + "total_memory": "5G", + "whiten": False, + }, + ], + [ + "default_whitened", + { + "ms_before": 0.5, + "ms_after": 0.5, + "max_spikes_per_unit": 5000, + "n_jobs": 5, + "total_memory": "5G", + "whiten": True, + }, + ], + ] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + +@schema +class MetricParameters(dj.Lookup): + definition = """ + # Parameters for computing quality metrics of sorted units + metric_param_name: varchar(200) + --- + metric_params: blob + """ + metric_default_param_name = "franklab_default" + metric_default_param = { + "snr": { + "peak_sign": "neg", + "random_chunk_kwargs_dict": { + "num_chunks_per_segment": 20, + "chunk_size": 10000, + "seed": 0, + }, + }, + "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0.0}, + "nn_isolation": { + "max_spikes": 1000, + "min_spikes": 10, + "n_neighbors": 5, + "n_components": 7, + "radius_um": 100, + "seed": 0, + }, + "nn_noise_overlap": { + "max_spikes": 1000, + "min_spikes": 10, + "n_neighbors": 5, + "n_components": 7, + "radius_um": 100, + "seed": 0, + }, + "peak_channel": {"peak_sign": "neg"}, + "num_spikes": {}, + } + contents = [[metric_default_param_name, metric_default_param]] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + @classmethod + def show_available_metrics(self): + for metric in _metric_name_to_func: + metric_doc = _metric_name_to_func[metric].__doc__.split("\n")[0] + print(f"{metric} : {metric_doc}\n") + + +@schema +class 
MetricCurationParameters(dj.Lookup): + definition = """ + metric_curation_param_name: varchar(200) + --- + label_params: blob # dict of param to label units + merge_params: blob # dict of param to merge units + """ + + contents = [ + ["default", {"nn_noise_overlap": [">", 0.1, ["noise", "reject"]]}, {}], + ["none", {}, {}], + ] + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) + + +@schema +class MetricCurationSelection(dj.Manual): + definition = """ + metric_curation_id: varchar(32) + --- + -> CurationV1 + -> WaveformParameters + -> MetricParameters + -> MetricCurationParameters + """ + + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into MetricCurationSelection with an + automatically generated unique metric curation ID as the sole primary key. + + Parameters + ---------- + key : dict + primary key of CurationV1, WaveformParameters, MetricParameters MetricCurationParameters + + Returns + ------- + key : dict + key for the inserted row + """ + if len((cls & key).fetch()) > 0: + print( + "This row has already been inserted into MetricCurationSelection." + ) + return (cls & key).fetch1() + key["metric_curation_id"] = generate_nwb_uuid( + (SpikeSortingSelection & key).fetch1("nwb_file_name"), + "MC", + 6, + ) + cls.insert1(key, skip_duplicates=True) + return key + + +@schema +class MetricCuration(dj.Computed): + definition = """ + -> MetricCurationSelection + --- + -> AnalysisNwbfile + object_id: varchar(40) # Object ID for the metrics in NWB file + """ + + def make(self, key): + # FETCH + nwb_file_name = (SpikeSortingSelection * + MetricCurationSelection & key).fetch1("nwb_file_name") + + waveform_params = (WaveformParameters * MetricCurationSelection & key).fetch1("waveform_params") + metric_params = (MetricParameters * MetricCurationSelection & key).fetch1("metric_params") + label_params, merge_params = (MetricCurationParameters* MetricCurationSelection & key).fetch1("label_params", "merge_params") + sorting_id, curation_id = (MetricCurationSelection & key).fetch1("sorting_id","curation_id") + # DO + # load recording and sorting + recording = CurationV1.get_recording({"sorting_id":sorting_id, "curation_id":curation_id}) + sorting = CurationV1.get_sorting({"sorting_id":sorting_id, "curation_id":curation_id}) + # extract waveforms + if "whiten" in waveform_params: + if waveform_params.pop("whiten"): + recording = sp.whiten(recording, dtype=np.float64) + + waveforms_dir = ( + os.environ.get("SPYGLASS_TEMP_DIR") + + "/" + + key["metric_curation_id"] + ) + try: + os.mkdir(waveforms_dir) + except FileExistsError: + pass + print("Extracting waveforms...") + waveforms = si.extract_waveforms( + recording=recording, + sorting=sorting, + folder=waveforms_dir, + **waveform_params, + ) + # compute metrics + print("Computing metrics...") + metrics = {} + for metric_name, metric_param_dict in metric_params.items(): + metrics[metric_name] = self._compute_metric( + nwb_file_name, waveforms, metric_name, **metric_param_dict + ) + + print("Applying curation...") + labels = self._compute_labels(metrics, label_params) + merge_groups = self._compute_merge_groups(metrics, merge_params) + + print("Saving to NWB...") + ( + key["analysis_file_name"], + key["object_id"], + ) = _write_metric_curation_to_nwb( + waveforms, metrics, labels, merge_groups + ) + + + # INSERT + AnalysisNwbfile().add( + nwb_file_name, + key["analysis_file_name"], + ) + self.insert1(key) + + @classmethod + def get_waveforms(cls): + return NotImplementedError + + @classmethod 
+ def get_metrics(cls, key: dict): + analysis_file_name = (cls & key).fetch1("analysis_file_name") + object_id = (cls & key).fetch1("object_id") + metric_param_name = (cls & key).fetch1("metric_param_name") + metric_params = ( + MetricParameters & {"metric_param_name": metric_param_name} + ).fetch1("metric_params") + metric_names = list(metric_params.keys()) + + metrics = {} + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + with pynwb.NWBHDF5IO( + path=analysis_file_abs_path, + mode="r", + load_namespaces=True, + ) as io: + nwbf = io.read() + units = nwbf.objects[object_id] + for metric_name in metric_names: + metrics[metric_name] = dict( + zip(units[id][:], units[metric_name][:]) + ) + + return metrics + + @classmethod + def get_labels(cls, key: dict): + analysis_file_name = (cls & key).fetch1("analysis_file_name") + object_id = (cls & key).fetch1("object_id") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + with pynwb.NWBHDF5IO( + path=analysis_file_abs_path, + mode="r", + load_namespaces=True, + ) as io: + nwbf = io.read() + units = nwbf.objects[object_id] + labels = dict(zip(units[id][:], units["curation_labels"][:])) + + return labels + + @classmethod + def get_merge_groups(cls, key: dict): + analysis_file_name = (cls & key).fetch1("analysis_file_name") + object_id = (cls & key).fetch1("object_id") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + with pynwb.NWBHDF5IO( + path=analysis_file_abs_path, + mode="r", + load_namespaces=True, + ) as io: + nwbf = io.read() + units = nwbf.objects[object_id] + labels = dict(zip(units[id][:], units["merge_groups"][:])) + + return _merge_dict_to_list(labels) + + @staticmethod + def _compute_metric(waveform_extractor, metric_name, **metric_params): + peak_sign_metrics = ["snr", "peak_offset", "peak_channel"] + metric_func = _metric_name_to_func[metric_name] + # Not sure what this is doing; leaving in case someone is using them + if metric_name in peak_sign_metrics: + if "peak_sign" in metric_params: + metric = metric_func( + waveform_extractor, + peak_sign=metric_params.pop("peak_sign"), + **metric_params, + ) + else: + raise Exception( + f"{peak_sign_metrics} metrics require peak_sign", + f"to be defined in the metric parameters", + ) + else: + metric = {} + for unit_id in waveform_extractor.sorting.get_unit_ids(): + metric[str(unit_id)] = metric_func( + waveform_extractor, this_unit_id=unit_id, **metric_params + ) + return metric + + @staticmethod + def _compute_labels( + metrics: Dict[str, Dict[str, Union[float, List[float]]]], + label_params: Dict[str, List[Any]], + ) -> Dict[str, List[str]]: + """Computes the labels based on the metric and label parameters. + + Parameters + ---------- + quality_metrics : dict + Example: {"snr" : {"1" : 2, "2" : 0.1, "3" : 2.3}} + This indicates that the values of the "snr" quality metric + for the units "1", "2", "3" are 2, 0.1, and 2.3, respectively. + + label_params : dict + Example: { + "snr" : [(">", 1, ["good", "mua"]), + ("<", 1, ["noise"])] + } + This indicates that units with values of the "snr" quality metric + greater than 1 should be given the labels "good" and "mua" and values + less than 1 should be given the label "noise". 
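+            Units that satisfy none of the conditions are returned with an
+            empty label list.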
+ + Returns + ------- + labels : dict + Example: {"1" : ["good", "mua"], "2" : ["noise"], "3" : ["good", "mua"]} + + """ + if not label_params: + return {} + else: + unit_ids = list(metrics[list(metrics.keys())[0]].keys()) + labels = {unit_id: [] for unit_id in unit_ids} + for metric in label_params: + if metric not in metrics: + Warning(f"{metric} not found in quality metrics; skipping") + else: + for condition in label_params[metric]: + assert ( + len(condition) == 3 + ), f"Condition {condition} must be of length 3" + compare = _comparison_to_function[condition[0]] + for unit_id in unit_ids: + if compare( + metrics[metric][unit_id], + condition[1], + ): + labels[unit_id].extend(label_params[metric][2]) + return labels + + @staticmethod + def _compute_merge_groups( + metrics: Dict[str, Dict[str, Union[float, List[float]]]], + merge_params: Dict[str, List[Any]], + ) -> Dict[str, List[str]]: + """Identifies units to be merged based on the metrics and merge parameters. + + Parameters + --------- + quality_metrics : dict + Example: {"cosine_similarity" : { + "1" : {"1" : 1.00, "2" : 0.10, "3": 0.95}, + "2" : {"1" : 0.10, "2" : 1.00, "3": 0.70}, + "3" : {"1" : 0.95, "2" : 0.70, "3": 1.00} + }} + This shows the pairwise values of the "cosine_similarity" quality metric + for the units "1", "2", "3" as a nested dict. + + merge_params : dict + Example: {"cosine_similarity" : [">", 0.9]} + This indicates that units with values of the "cosine_similarity" quality metric + greater than 0.9 should be placed in the same merge group. + + + Returns + ------- + merge_groups : dict + Example: {"1" : ["3"], "2" : [], "3" : ["1"]} + + """ + + if not merge_params: + return [] + else: + unit_ids = list(metrics[list(metrics.keys())[0]].keys()) + merge_groups = {unit_id: [] for unit_id in unit_ids} + for metric in merge_params: + if metric not in metrics: + Warning(f"{metric} not found in quality metrics; skipping") + else: + compare = _comparison_to_function[merge_params[metric][0]] + for unit_id in unit_ids: + other_unit_ids = [ + other_unit_id + for other_unit_id in unit_ids + if other_unit_id != unit_id + ] + for other_unit_id in other_unit_ids: + if compare( + metrics[metric][unit_id][other_unit_id], + merge_params[metric][1], + ): + merge_groups[unit_id].extend(other_unit_id) + return merge_groups + + +def _write_metric_curation_to_nwb( + nwb_file_name: str, + waveforms: si.WaveformExtractor, + metrics: Union[None, Dict[str, Dict[str, float]]] = None, + labels: Union[None, Dict[str, List[str]]] = None, + merge_groups: Union[None, List[List[str]]] = None, +): + """Save waveforms, metrics, labels, and merge groups to NWB + + Parameters + ---------- + sorting_id : str + key for the sorting + labels : dict or None, optional + curation labels (e.g. 
good, noise, mua) + merge_groups : list or None, optional + groups of unit IDs to be merged + metrics : dict or None, optional + Computed quality metrics, one for each cell + apply_merge : bool, optional + whether to apply the merge groups to the sorting before saving, by default False + + Returns + ------- + analysis_nwb_file : str + name of analysis NWB file containing the sorting and curation information + object_id : str + object_id of the units table in the analysis NWB file + """ + + unit_ids = [int(i) for i in waveforms.sorting.get_unit_ids()] + + # create new analysis nwb file + analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( + path=analysis_nwb_file_abs_path, + mode="a", + load_namespaces=True, + ) as io: + nwbf = io.read() + # write waveforms to the nwb file + for unit_id in unit_ids: + nwbf.add_unit( + spike_times=waveforms.sorting.get_unit_spike_train(unit_id), + id=unit_id, + electrodes=waveforms.recording.get_channel_ids(), + waveforms=waveforms.get_waveforms(unit_id), + waveform_mean=waveforms.get_template(unit_id), + ) + + # add labels, merge groups, metrics + if labels is not None: + label_values = [] + for unit_id in unit_ids: + if unit_id not in labels: + label_values.append([]) + else: + label_values.append(labels[unit_id]) + nwbf.add_unit_column( + name="curation_label", + description="curation label", + data=label_values, + ) + if merge_groups is not None: + merge_groups_dict = _list_to_merge_dict(merge_groups, unit_ids) + nwbf.add_unit_column( + name="merge_groups", + description="merge groups", + data=list(merge_groups_dict.values()), + ) + if metrics is not None: + for metric, metric_dict in metrics.items(): + metric_values = [] + for unit_id in unit_ids: + if unit_id not in metric_dict: + metric_values.append([]) + else: + metric_values.append(metric_dict[unit_id]) + nwbf.add_unit_column( + name=metric, + description=metric, + data=metric_values, + ) + + units_object_id = nwbf.units.object_id + io.write(nwbf) + return analysis_nwb_file, units_object_id diff --git a/src/spyglass/spikesorting/v1/metric_utils.py b/src/spyglass/spikesorting/v1/metric_utils.py new file mode 100644 index 000000000..d76b478ec --- /dev/null +++ b/src/spyglass/spikesorting/v1/metric_utils.py @@ -0,0 +1,81 @@ +import spikeinterface as si +import spikeinterface.qualitymetrics as sq +import numpy as np + + +def compute_isi_violation_fractions( + waveform_extractor: si.WaveformExtractor, + this_unit_id: str, + isi_threshold_ms: float = 2.0, + min_isi_ms: float = 0.0, +): + """Computes the fraction of interspike interval violations. + + Parameters + ---------- + waveform_extractor: si.WaveformExtractor + The extractor object for the recording. + + """ + + # Extract the total number of spikes that violated the isi_threshold for each unit + isi_violation_counts = np.asarray( + sq.compute_isi_violations( + waveform_extractor, + isi_threshold_ms=isi_threshold_ms, + min_isi_ms=min_isi_ms, + ).isi_violations_count + ) + + isi_violation_count = isi_violation_counts[ + waveform_extractor.sorting.get_unit_ids() == this_unit_id + ] + total_spike_count = get_num_spikes(waveform_extractor, this_unit_id) + return isi_violation_count / (total_spike_count - 1) + + +def get_peak_offset( + waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params +): + """Computes the shift of the waveform peak from center of window. 
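+
+    The returned dict maps each unit id to the absolute shift, in samples, of
+    the template extremum from the center of the waveform window.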
+ + Parameters + ---------- + waveform_extractor: si.WaveformExtractor + The extractor object for the recording. + peak_sign: str + The sign of the peak to compute. ('neg', 'pos', 'both') + """ + if "peak_sign" in metric_params: + del metric_params["peak_sign"] + peak_offset_inds = si.get_template_extremum_channel_peak_shift( + waveform_extractor=waveform_extractor, + peak_sign=peak_sign, + **metric_params, + ) + peak_offset = {key: int(abs(val)) for key, val in peak_offset_inds.items()} + return peak_offset + + +def get_peak_channel( + waveform_extractor: si.WaveformExtractor, peak_sign: str, **metric_params +): + """Computes the electrode_id of the channel with the extremum peak for each unit.""" + if "peak_sign" in metric_params: + del metric_params["peak_sign"] + peak_channel_dict = si.get_template_extremum_channel( + waveform_extractor=waveform_extractor, + peak_sign=peak_sign, + **metric_params, + ) + peak_channel = {key: int(val) for key, val in peak_channel_dict.items()} + return peak_channel + + +def get_num_spikes(waveform_extractor: si.WaveformExtractor, this_unit_id: str): + """Computes the number of spikes for each unit.""" + all_spikes = sq.compute_num_spikes(waveform_extractor) + unit_spikes = all_spikes[ + waveform_extractor.sorting.get_unit_ids() == this_unit_id + ] + return unit_spikes diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py index cdad793e9..6d62b0152 100644 --- a/src/spyglass/spikesorting/v1/recording.py +++ b/src/spyglass/spikesorting/v1/recording.py @@ -1,39 +1,36 @@ -import os -import shutil -from functools import reduce -from pathlib import Path +from typing import Tuple, Iterable, Optional, Union, List import datajoint as dj import numpy as np +import pynwb import probeinterface as pi import spikeinterface as si import spikeinterface.extractors as se +from hdmf.data_utils import GenericDataChunkIterator -from ..common.common_device import Probe, ProbeType # noqa: F401 -from ..common.common_ephys import Electrode, ElectrodeGroup -from ..common.common_interval import ( +from spyglass.common import Session # noqa: F401 +from spyglass.common.common_ephys import Electrode, Raw # noqa: F401 +from spyglass.common.common_device import Probe +from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile +from spyglass.common.common_interval import ( IntervalList, interval_list_intersect, - intervals_by_length, - union_adjacent_index, ) -from ..common.common_lab import LabTeam # noqa: F401 -from ..common.common_nwbfile import Nwbfile -from ..common.common_session import Session # noqa: F401 -from ..utils.dj_helper_fn import dj_replace -from ..settings import recording_dir +from spyglass.common.common_lab import LabTeam +from spyglass.spikesorting.v1.utils import generate_nwb_uuid -schema = dj.schema("spikesorting_recording") +schema = dj.schema("spikesorting_v1_recording") @schema class SortGroup(dj.Manual): definition = """ - # Set of electrodes that will be sorted together + # Set of electrodes to spike sort together -> Session - sort_group_id: int # identifier for a group of electrodes + sort_group_name: varchar(30) --- - sort_reference_electrode_id = -1: int # the electrode to use for reference. 
-1: no reference, -2: common median + sort_reference_electrode_id = -1: int # the electrode to use for referencing + # -1: no reference, -2: common median """ class SortGroupElectrode(dj.Part): @@ -42,41 +39,36 @@ class SortGroupElectrode(dj.Part): -> Electrode """ + @classmethod def set_group_by_shank( - self, + cls, nwb_file_name: str, references: dict = None, omit_ref_electrode_group=False, omit_unitrode=True, ): - """Divides electrodes into groups based on their shank position. - - * Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a + """Create sort group for each shank. + - Electrodes from probes with 1 shank (e.g. tetrodes) are placed in a single group - * Electrodes from probes with multiple shanks (e.g. polymer probes) are + - Electrodes from probes with multiple shanks (e.g. polymer probes) are placed in one group per shank - * Bad channels are omitted + - Bad channels are omitted Parameters ---------- nwb_file_name : str - the name of the NWB file whose electrodes should be put into - sorting groups + the name of the NWB file whose electrodes should be put into sorting groups references : dict, optional If passed, used to set references. Otherwise, references set using - original reference electrodes from config. Keys: electrode groups. - Values: reference electrode. - omit_ref_electrode_group : bool - Optional. If True, no sort group is defined for electrode group of - reference. - omit_unitrode : bool - Optional. If True, no sort groups are defined for unitrodes. + original reference electrodes from config. Keys: electrode groups. Values: reference electrode. + omit_ref_electrode_group : bool, optional + If True, no sort group is defined for electrode group of reference. + omit_unitrode : bool, optional + If True, no sort groups are defined for unitrodes. """ - # delete any current groups - (SortGroup & {"nwb_file_name": nwb_file_name}).delete() # get the electrodes from this NWB file electrodes = ( - Electrode() + Electrode & {"nwb_file_name": nwb_file_name} & {"bad_channel": "False"} ).fetch() @@ -96,7 +88,9 @@ def set_group_by_shank( sge_key["electrode_group_name"] = e_group # get the indices of all electrodes in this group / shank and set their sorting group for shank in shank_list: - sg_key["sort_group_id"] = sge_key["sort_group_id"] = sort_group + sg_key["sort_group_name"] = sge_key[ + "sort_group_name" + ] = sort_group # specify reference electrode. Use 'references' if passed, otherwise use reference from config if not references: shank_elect_ref = electrodes[ @@ -113,14 +107,12 @@ def set_group_by_shank( ] else: ValueError( - f"Error in electrode group {e_group}: reference " - + "electrodes are not all the same" + f"Error in electrode group {e_group}: reference electrodes are not all the same" ) else: if e_group not in references.keys(): raise Exception( - f"electrode group {e_group} not a key in " - + "references, so cannot set reference" + f"electrode group {e_group} not a key in references, so cannot set reference" ) else: sg_key["sort_reference_electrode_id"] = references[ @@ -141,16 +133,14 @@ def set_group_by_shank( len(reference_electrode_group) != 1 ): raise Exception( - "Should have found exactly one electrode group for " - + "reference electrode, but found " - + f"{len(reference_electrode_group)}." + f"Should have found exactly one electrode group for reference electrode," + f"but found {len(reference_electrode_group)}." 
) if omit_ref_electrode_group and ( str(e_group) == str(reference_electrode_group) ): print( - f"Omitting electrode group {e_group} from sort groups " - + "because contains reference." + f"Omitting electrode group {e_group} from sort groups because contains reference." ) continue shank_elect = electrodes["electrode_id"][ @@ -166,454 +156,375 @@ def set_group_by_shank( f"Omitting electrode group {e_group}, shank {shank} from sort groups because unitrode." ) continue - self.insert1(sg_key) + cls.insert1(sg_key, skip_duplicates=True) for elect in shank_elect: sge_key["electrode_id"] = elect - self.SortGroupElectrode().insert1(sge_key) - sort_group += 1 - - def set_group_by_electrode_group(self, nwb_file_name: str): - """Assign groups to all non-bad channel electrodes based on their electrode group - and sets the reference for each group to the reference for the first channel of the group. - - Parameters - ---------- - nwb_file_name: str - the name of the nwb whose electrodes should be put into sorting groups - """ - # delete any current groups - (SortGroup & {"nwb_file_name": nwb_file_name}).delete() - # get the electrodes from this NWB file - electrodes = ( - Electrode() - & {"nwb_file_name": nwb_file_name} - & {"bad_channel": "False"} - ).fetch() - e_groups = np.unique(electrodes["electrode_group_name"]) - sg_key = dict() - sge_key = dict() - sg_key["nwb_file_name"] = sge_key["nwb_file_name"] = nwb_file_name - sort_group = 0 - for e_group in e_groups: - sge_key["electrode_group_name"] = e_group - # sg_key['sort_group_id'] = sge_key['sort_group_id'] = sort_group - # TEST - sg_key["sort_group_id"] = sge_key["sort_group_id"] = int(e_group) - # get the list of references and make sure they are all the same - shank_elect_ref = electrodes["original_reference_electrode"][ - electrodes["electrode_group_name"] == e_group - ] - if np.max(shank_elect_ref) == np.min(shank_elect_ref): - sg_key["sort_reference_electrode_id"] = shank_elect_ref[0] - else: - ValueError( - f"Error in electrode group {e_group}: reference electrodes are not all the same" - ) - self.insert1(sg_key) - - shank_elect = electrodes["electrode_id"][ - electrodes["electrode_group_name"] == e_group - ] - for elect in shank_elect: - sge_key["electrode_id"] = elect - self.SortGroupElectrode().insert1(sge_key) - sort_group += 1 - - def set_reference_from_list(self, nwb_file_name, sort_group_ref_list): - """ - Set the reference electrode from a list containing sort groups and reference electrodes - :param: sort_group_ref_list - 2D array or list where each row is [sort_group_id reference_electrode] - :param: nwb_file_name - The name of the NWB file whose electrodes' references should be updated - :return: Null - """ - key = dict() - key["nwb_file_name"] = nwb_file_name - sort_group_list = (SortGroup() & key).fetch1() - for sort_group in sort_group_list: - key["sort_group_id"] = sort_group - self.insert( - dj_replace( - sort_group_list, - sort_group_ref_list, - "sort_group_id", - "sort_reference_electrode_id", - ), - replace="True", - ) - - def get_geometry(self, sort_group_id, nwb_file_name): - """ - Returns a list with the x,y coordinates of the electrodes in the sort group - for use with the SpikeInterface package. - - Converts z locations to y where appropriate. 
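        (The first two coordinate axes that contain nonzero values are kept,
        so a probe whose positions are given in the x-z plane, with y all
        zero, ends up with z as its second coordinate.)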
- - Parameters - ---------- - sort_group_id : int - nwb_file_name : str - - Returns - ------- - geometry : list - List of coordinate pairs, one per electrode - """ - - # create the channel_groups dictiorary - channel_group = dict() - key = dict() - key["nwb_file_name"] = nwb_file_name - electrodes = (Electrode() & key).fetch() - - key["sort_group_id"] = sort_group_id - sort_group_electrodes = (SortGroup.SortGroupElectrode() & key).fetch() - electrode_group_name = sort_group_electrodes["electrode_group_name"][0] - probe_id = ( - ElectrodeGroup - & { - "nwb_file_name": nwb_file_name, - "electrode_group_name": electrode_group_name, - } - ).fetch1("probe_id") - channel_group[sort_group_id] = dict() - channel_group[sort_group_id]["channels"] = sort_group_electrodes[ - "electrode_id" - ].tolist() - - n_chan = len(channel_group[sort_group_id]["channels"]) - - geometry = np.zeros((n_chan, 2), dtype="float") - tmp_geom = np.zeros((n_chan, 3), dtype="float") - for i, electrode_id in enumerate( - channel_group[sort_group_id]["channels"] - ): - # get the relative x and y locations of this channel from the probe table - probe_electrode = int( - electrodes["probe_electrode"][ - electrodes["electrode_id"] == electrode_id - ] - ) - rel_x, rel_y, rel_z = ( - Probe().Electrode() - & {"probe_id": probe_id, "probe_electrode": probe_electrode} - ).fetch("rel_x", "rel_y", "rel_z") - # TODO: Fix this HACK when we can use probeinterface: - rel_x = float(rel_x) - rel_y = float(rel_y) - rel_z = float(rel_z) - tmp_geom[i, :] = [rel_x, rel_y, rel_z] - - # figure out which columns have coordinates - n_found = 0 - for i in range(3): - if np.any(np.nonzero(tmp_geom[:, i])): - if n_found < 2: - geometry[:, n_found] = tmp_geom[:, i] - n_found += 1 - else: - Warning( - "Relative electrode locations have three coordinates; only two are currently supported" + cls.SortGroupElectrode.insert1( + sge_key, skip_duplicates=True ) - return np.ndarray.tolist(geometry) - - -@schema -class SortInterval(dj.Manual): - definition = """ - -> Session - sort_interval_name: varchar(200) # name for this interval - --- - sort_interval: longblob # 1D numpy array with start and end time for a single interval to be used for spike sorting - """ + sort_group += 1 @schema -class SpikeSortingPreprocessingParameters(dj.Manual): +class SpikeSortingPreprocessingParameters(dj.Lookup): definition = """ - preproc_params_name: varchar(200) + # Parameters for denoising (filtering and referencing/whitening) recording + # prior to spike sorting + preproc_param_name: varchar(200) --- preproc_params: blob """ - def insert_default(self): - # set up the default filter parameters - freq_min = 300 # high pass filter value - freq_max = 6000 # low pass filter value - margin_ms = 5 # margin in ms on border to avoid border effect - seed = 0 # random seed for whitening + contents = [ + [ + "default", + { + "frequency_min": 300, # high pass filter value + "frequency_max": 6000, # low pass filter value + "margin_ms": 5, # margin in ms on border to avoid border effect + "seed": 0, # random seed for whitening + "min_segment_length": 1, # minimum segment length in seconds + }, + ] + ] - key = dict() - key["preproc_params_name"] = "default" - key["preproc_params"] = { - "frequency_min": freq_min, - "frequency_max": freq_max, - "margin_ms": margin_ms, - "seed": seed, - } - self.insert1(key, skip_duplicates=True) + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) @schema class SpikeSortingRecordingSelection(dj.Manual): definition = """ - # 
Defines recordings to be sorted + # Raw voltage traces and parameters. Use `insert_selection` method to insert rows. + recording_id: varchar(50) + --- + -> Raw -> SortGroup - -> SortInterval + -> IntervalList -> SpikeSortingPreprocessingParameters -> LabTeam - --- - -> IntervalList """ + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into SpikeSortingRecordingSelection with an + automatically generated unique recording ID as the sole primary key. + + Parameters + ---------- + key : dict + primary key of Raw, SortGroup, IntervalList, SpikeSortingPreprocessingParameters, LabTeam tables + + Returns + ------- + recording_id : str + the unique recording ID serving as primary key for SpikeSortingRecordingSelection + """ + if len((cls & key).fetch()) > 0: + print( + "This row has already been inserted into SpikeSortingRecordingSelection." + ) + return (cls & key).fetch1() + key["recording_id"] = generate_nwb_uuid(key["nwb_file_name"], "R", 6) + cls.insert1(key, skip_duplicates=True) + return key + @schema class SpikeSortingRecording(dj.Computed): definition = """ + # Processed recording -> SpikeSortingRecordingSelection --- - recording_path: varchar(1000) - -> IntervalList.proj(sort_interval_list_name='interval_list_name') + -> AnalysisNwbfile + object_id: varchar(40) # Object ID for the processed recording in NWB file """ def make(self, key): + # DO: + # - get valid times for sort interval + # - proprocess recording + # - write recording to NWB file sort_interval_valid_times = self._get_sort_interval_valid_times(key) - recording = self._get_filtered_recording(key) - recording_name = self._get_recording_name(key) - - # Path to files that will hold the recording extractors - recording_path = str(recording_dir / Path(recording_name)) - if os.path.exists(recording_path): - shutil.rmtree(recording_path) - - recording.save( - folder=recording_path, chunk_duration="10000ms", n_jobs=8 + recording, timestamps = self._get_preprocessed_recording(key) + recording_nwb_file_name, recording_object_id = _write_recording_to_nwb( + recording, + timestamps, + (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), ) + key["analysis_file_name"] = recording_nwb_file_name + key["object_id"] = recording_object_id + # INSERT: + # - valid times into IntervalList + # - analysis NWB file holding processed recording into AnalysisNwbfile + # - entry into SpikeSortingRecording IntervalList.insert1( { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": recording_name, + "nwb_file_name": (SpikeSortingRecordingSelection & key).fetch1( + "nwb_file_name" + ), + "interval_list_name": key["recording_id"], "valid_times": sort_interval_valid_times, - }, - replace=True, - ) - - self.insert1( - { - **key, - # store the list of valid times for the sort - "sort_interval_list_name": recording_name, - "recording_path": recording_path, } ) - - @staticmethod - def _get_recording_name(key): - return "_".join( - [ - key["nwb_file_name"], - key["sort_interval_name"], - str(key["sort_group_id"]), - key["preproc_params_name"], - ] + AnalysisNwbfile().add( + (SpikeSortingRecordingSelection & key).fetch1("nwb_file_name"), + key["analysis_file_name"], ) + self.insert1(key) - @staticmethod - def _get_recording_timestamps(recording): - num_segments = recording.get_num_segments() + @classmethod + def get_recording(cls, key: dict) -> si.BaseRecording: + """Get recording related to this curation as spikeinterface BaseRecording - if num_segments <= 1: - return recording.get_times() + Parameters + ---------- + 
key : dict + primary key of SpikeSorting table + """ - frames_per_segment = [0] + [ - recording.get_num_frames(segment_index=i) - for i in range(num_segments) - ] + analysis_file_name = (cls & key).fetch1("analysis_file_name") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + recording = se.read_nwb_recording( + analysis_file_abs_path, load_time_vector=True + ) - cumsum_frames = np.cumsum(frames_per_segment) - total_frames = np.sum(frames_per_segment) + return recording - timestamps = np.zeros((total_frames,)) - for i in range(num_segments): - start_index = cumsum_frames[i] - end_index = cumsum_frames[i + 1] - timestamps[start_index:end_index] = recording.get_times( - segment_index=i - ) + @staticmethod + def _get_recording_timestamps(recording): + if recording.get_num_segments() > 1: + frames_per_segment = [0] + for i in range(recording.get_num_segments()): + frames_per_segment.append( + recording.get_num_frames(segment_index=i) + ) + cumsum_frames = np.cumsum(frames_per_segment) + total_frames = np.sum(frames_per_segment) + + timestamps = np.zeros((total_frames,)) + for i in range(recording.get_num_segments()): + timestamps[ + cumsum_frames[i] : cumsum_frames[i + 1] + ] = recording.get_times(segment_index=i) + else: + timestamps = recording.get_times() return timestamps - def _get_sort_interval_valid_times(self, key): + def _get_sort_interval_valid_times(self, key: dict): """Identifies the intersection between sort interval specified by the user - and the valid times (times for which neural data exist) + and the valid times (times for which neural data exist, excluding e.g. dropped packets). Parameters ---------- key: dict - specifies a (partially filled) entry of SpikeSorting table + primary key of SpikeSortingRecordingSelection table Returns ------- sort_interval_valid_times: ndarray of tuples - (start, end) times for valid stretches of the sorting interval + (start, end) times for valid intervals in the sort interval """ + # FETCH: + # - sort interval + # - valid times + # - preprocessing parameters sort_interval = ( - SortInterval + IntervalList & { - "nwb_file_name": key["nwb_file_name"], - "sort_interval_name": key["sort_interval_name"], + "nwb_file_name": (SpikeSortingRecordingSelection & key).fetch1( + "nwb_file_name" + ), + "interval_list_name": ( + SpikeSortingRecordingSelection & key + ).fetch1("interval_list_name"), } - ).fetch1("sort_interval") - - interval_list_name = (SpikeSortingRecordingSelection & key).fetch1( - "interval_list_name" - ) - + ).fetch1("valid_times") valid_interval_times = ( IntervalList & { - "nwb_file_name": key["nwb_file_name"], - "interval_list_name": interval_list_name, + "nwb_file_name": (SpikeSortingRecordingSelection & key).fetch1( + "nwb_file_name" + ), + "interval_list_name": "raw data valid times", } ).fetch1("valid_times") + params = ( + SpikeSortingPreprocessingParameters * SpikeSortingRecordingSelection + & key + ).fetch1("preproc_params") + # DO: + # - take intersection between sort interval and valid times valid_sort_times = interval_list_intersect( - sort_interval, valid_interval_times - ) - # Exclude intervals shorter than specified length - params = (SpikeSortingPreprocessingParameters & key).fetch1( - "preproc_params" + sort_interval, + valid_interval_times, + min_length=params["min_segment_length"], ) - if "min_segment_length" in params: - valid_sort_times = intervals_by_length( - valid_sort_times, min_length=params["min_segment_length"] - ) + return valid_sort_times - def _get_filtered_recording(self, 
key: dict): - """Filters and references a recording - * Loads the NWB file created during insertion as a spikeinterface Recording - * Slices recording in time (interval) and space (channels); + def _get_preprocessed_recording(self, key: dict): + """Filters and references a recording. + - Loads the NWB file created during insertion as a spikeinterface Recording + - Slices recording in time (interval) and space (channels); recording chunks from disjoint intervals are concatenated - * Applies referencing and bandpass filtering + - Applies referencing and bandpass filtering Parameters ---------- - key: dict, - primary key of SpikeSortingRecording table + key: dict + primary key of SpikeSortingRecordingSelection table Returns ------- recording: si.Recording """ + # FETCH: + # - full path to NWB file + # - channels to be included in the sort + # - the reference channel + # - probe type + # - filter parameters + nwb_file_name = (SpikeSortingRecordingSelection & key).fetch1( + "nwb_file_name" + ) + sort_group_name = (SpikeSortingRecordingSelection & key).fetch1( + "sort_group_name" + ) + nwb_file_abs_path = Nwbfile().get_abs_path(nwb_file_name) + channel_ids = ( + SortGroup.SortGroupElectrode + & { + "nwb_file_name": nwb_file_name, + "sort_group_name": sort_group_name, + } + ).fetch("electrode_id") + ref_channel_id = ( + SortGroup + & { + "nwb_file_name": nwb_file_name, + "sort_group_name": sort_group_name, + } + ).fetch1("sort_reference_electrode_id") + recording_channel_ids = np.setdiff1d(channel_ids, ref_channel_id) - nwb_file_abs_path = Nwbfile().get_abs_path(key["nwb_file_name"]) + probe_type_by_channel = [] + electrode_group_by_channel = [] + for channel_id in channel_ids: + probe_type_by_channel.append( + ( + Electrode * Probe + & { + "nwb_file_name": nwb_file_name, + "electrode_id": channel_id, + } + ).fetch1("probe_type") + ) + electrode_group_by_channel.append( + ( + Electrode + & { + "nwb_file_name": nwb_file_name, + "electrode_id": channel_id, + } + ).fetch1("electrode_group_name") + ) + probe_type = np.unique(probe_type_by_channel) + filter_params = ( + SpikeSortingPreprocessingParameters * SpikeSortingRecordingSelection + & key + ).fetch1("preproc_params") + + # DO: + # - load NWB file as a spikeinterface Recording + # - slice the recording object in time and channels + # - apply referencing depending on the option chosen by the user + # - apply bandpass filter + # - set probe to recording recording = se.read_nwb_recording( nwb_file_abs_path, load_time_vector=True ) - + # TODO: make sure the following works for recordings that don't have explicit timestamps valid_sort_times = self._get_sort_interval_valid_times(key) - # shape is (N, 2) - valid_sort_times_indices = np.array( - [ - np.searchsorted(recording.get_times(), interval) - for interval in valid_sort_times - ] + valid_sort_times_indices = _consolidate_intervals( + valid_sort_times, recording.get_times() ) - # join intervals of indices that are adjacent - valid_sort_times_indices = reduce( - union_adjacent_index, valid_sort_times_indices - ) - if valid_sort_times_indices.ndim == 1: - valid_sort_times_indices = np.expand_dims( - valid_sort_times_indices, 0 - ) - - # create an AppendRecording if there is more than one disjoint sort interval + # slice in time; concatenate disjoint sort intervals if len(valid_sort_times_indices) > 1: recordings_list = [] + timestamps = [] for interval_indices in valid_sort_times_indices: recording_single = recording.frame_slice( start_frame=interval_indices[0], end_frame=interval_indices[1], ) 
recordings_list.append(recording_single) - recording = si.append_recordings(recordings_list) + timestamps.append( + recording.get_times()[ + interval_indices[0] : interval_indices[1] + ] + ) + recording = si.concatenate_recordings(recordings_list) else: recording = recording.frame_slice( start_frame=valid_sort_times_indices[0][0], end_frame=valid_sort_times_indices[0][1], ) - - channel_ids = ( - SortGroup.SortGroupElectrode - & { - "nwb_file_name": key["nwb_file_name"], - "sort_group_id": key["sort_group_id"], - } - ).fetch("electrode_id") - ref_channel_id = ( - SortGroup - & { - "nwb_file_name": key["nwb_file_name"], - "sort_group_id": key["sort_group_id"], - } - ).fetch1("sort_reference_electrode_id") - channel_ids = np.setdiff1d(channel_ids, ref_channel_id) - - # include ref channel in first slice, then exclude it in second slice + timestamps = recording.get_times()[ + valid_sort_times_indices[0][0] : valid_sort_times_indices[0][1] + ] + # slice in channels; include ref channel in first slice, then exclude it in second slice if ref_channel_id >= 0: - channel_ids_ref = np.append(channel_ids, ref_channel_id) - recording = recording.channel_slice(channel_ids=channel_ids_ref) - + recording = recording.channel_slice(channel_ids=channel_ids) recording = si.preprocessing.common_reference( - recording, reference="single", ref_channel_ids=ref_channel_id + recording, + reference="single", + ref_channel_ids=ref_channel_id, + dtype=np.float64, + ) + recording = recording.channel_slice( + channel_ids=recording_channel_ids ) - recording = recording.channel_slice(channel_ids=channel_ids) elif ref_channel_id == -2: - recording = recording.channel_slice(channel_ids=channel_ids) + recording = recording.channel_slice( + channel_ids=recording_channel_ids + ) recording = si.preprocessing.common_reference( - recording, reference="global", operator="median" + recording, + reference="global", + operator="median", + dtype=np.float64, + ) + elif ref_channel_id == -1: + recording = recording.channel_slice( + channel_ids=recording_channel_ids ) else: - raise ValueError("Invalid reference channel ID") - filter_params = (SpikeSortingPreprocessingParameters & key).fetch1( - "preproc_params" - ) + raise ValueError( + "Invalid reference channel ID. Use -1 to skip referencing. Use -2 to reference via global median. Use positive integer to reference to a specific channel." 
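The referencing logic in the `ref_channel_id >= 0` branch above follows standard spikeinterface preprocessing. A minimal, self-contained sketch of that path on a toy recording (toy data only, not spyglass objects; the exact keyword values are illustrative):

import numpy as np
import spikeinterface as si
import spikeinterface.preprocessing as sip

# Toy 4-channel, 2-second recording standing in for one sort group
traces = np.random.randn(60000, 4).astype("float32")
rec = si.NumpyRecording(
    traces_list=[traces], sampling_frequency=30000.0, channel_ids=[0, 1, 2, 3]
)

# Reference to channel 3, drop it, then bandpass for the spike band
rec = sip.common_reference(rec, reference="single", ref_channel_ids=3)
rec = rec.channel_slice(channel_ids=[0, 1, 2])
rec = sip.bandpass_filter(rec, freq_min=300.0, freq_max=6000.0)
print(rec.get_num_channels(), rec.get_sampling_frequency())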
+ ) + recording = si.preprocessing.bandpass_filter( recording, freq_min=filter_params["frequency_min"], freq_max=filter_params["frequency_max"], + dtype=np.float64, ) # if the sort group is a tetrode, change the channel location - # note that this is a workaround that would be deprecated when spikeinterface uses 3D probe locations - probe_type = [] - electrode_group = [] - for channel_id in channel_ids: - probe_type.append( - ( - Electrode * Probe - & { - "nwb_file_name": key["nwb_file_name"], - "electrode_id": channel_id, - } - ).fetch1("probe_type") - ) - electrode_group.append( - ( - Electrode - & { - "nwb_file_name": key["nwb_file_name"], - "electrode_id": channel_id, - } - ).fetch1("electrode_group_name") - ) + # (necessary because the channel location for tetrodes are not set properly) if ( - all(p == "tetrode_12.5" for p in probe_type) - and len(probe_type) == 4 - and all(eg == electrode_group[0] for eg in electrode_group) + len(probe_type) == 1 + and probe_type[0] == "tetrode_12.5" + and len(recording_channel_ids) == 4 + and len(np.unique(electrode_group_by_channel)) == 1 ): tetrode = pi.Probe(ndim=2) position = [[0, 0], [0, 12.5], [12.5, 0], [12.5, 12.5]] @@ -624,4 +535,332 @@ def _get_filtered_recording(self, key: dict): tetrode.set_device_channel_indices(np.arange(4)) recording = recording.set_probe(tetrode, in_place=True) - return recording + return recording, np.asarray(timestamps) + + +def _consolidate_intervals(intervals, timestamps): + """Convert a list of intervals (start_time, stop_time) + to a list of intervals (start_index, stop_index) by comparing to a list of timestamps; + then consolidates overlapping or adjacent intervals + + Parameters + ---------- + intervals : iterable of tuples + _description_ + timestamps : numpy.ndarray + _description_ + + Returns + ------- + _type_ + _description_ + """ + # Convert intervals to a numpy array if it's not + intervals = np.array(intervals) + if intervals.shape[1] != 2: + raise ValueError( + "Input array must have shape (N, 2) where N is the number of intervals." + ) + # Check if intervals are sorted. If not, sort them. 
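The hunk above elides the contact-layout call between the `position` list and `set_device_channel_indices`. A standalone probeinterface sketch of the same 2 x 2 tetrode geometry, leaving the contact shape at probeinterface defaults (an assumption, not part of this patch), would look roughly like:

import numpy as np
import probeinterface as pi

# Illustrative 2 x 2 tetrode layout with 12.5 um spacing, matching the positions above
position = [[0, 0], [0, 12.5], [12.5, 0], [12.5, 12.5]]
tetrode = pi.Probe(ndim=2)
tetrode.set_contacts(positions=position)  # contact shape left at library defaults
tetrode.set_device_channel_indices(np.arange(4))
print(tetrode.contact_positions)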
+ if not np.all(intervals[:-1] <= intervals[1:]): + intervals = np.sort(intervals, axis=0) + + # Initialize an empty list to store the consolidated intervals + consolidated = [] + + # Convert start and stop times to indices + start_indices = np.searchsorted(timestamps, intervals[:, 0], side="left") + stop_indices = ( + np.searchsorted(timestamps, intervals[:, 1], side="right") - 1 + ) + + # Start with the first interval + start, stop = start_indices[0], stop_indices[0] + + # Loop through the rest of the intervals to join them if needed + for i in range(1, len(start_indices)): + next_start, next_stop = start_indices[i], stop_indices[i] + + # If the stop time of the current interval is equal to or greater than the next start time minus 1 + if stop >= next_start - 1: + stop = max( + stop, next_stop + ) # Extend the current interval to include the next one + else: + # Add the current interval to the consolidated list + consolidated.append((start, stop)) + start, stop = next_start, next_stop # Start a new interval + + # Add the last interval to the consolidated list + consolidated.append((start, stop)) + + # Convert the consolidated list to a NumPy array and return + return np.array(consolidated) + + +def _write_recording_to_nwb( + recording: si.BaseRecording, + timestamps: Iterable, + nwb_file_name: str, +): + """Write a recording in NWB format + + Parameters + ---------- + recording : si.Recording + timestamps : iterable + nwb_file_name : str + name of NWB file the recording originates + + Returns + ------- + analysis_nwb_file : str + name of analysis NWB file containing the preprocessed recording + """ + + analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( + path=analysis_nwb_file_abs_path, + mode="a", + load_namespaces=True, + ) as io: + nwbfile = io.read() + table_region = nwbfile.create_electrode_table_region( + region=[i for i in recording.get_channel_ids()], + description="Sort group", + ) + data_iterator = SpikeInterfaceRecordingDataChunkIterator( + recording=recording, return_scaled=False, buffer_gb=7 + ) + timestamps_iterator = TimestampsDataChunkIterator( + recording=TimestampsExtractor(timestamps), buffer_gb=5 + ) + processed_electrical_series = pynwb.ecephys.ElectricalSeries( + name="ProcessedElectricalSeries", + data=data_iterator, + electrodes=table_region, + timestamps=timestamps_iterator, + filtering="Bandpass filtered for spike band", + description=f"Referenced and filtered recording from {nwb_file_name} for spike sorting", + conversion=np.unique(recording.get_channel_gains())[0] * 1e-6, + ) + nwbfile.add_acquisition(processed_electrical_series) + recording_object_id = nwbfile.acquisition[ + "ProcessedElectricalSeries" + ].object_id + io.write(nwbfile) + return analysis_nwb_file, recording_object_id + + +# For writing recording to NWB file + + +class SpikeInterfaceRecordingDataChunkIterator(GenericDataChunkIterator): + """DataChunkIterator specifically for use on RecordingExtractor objects.""" + + def __init__( + self, + recording: si.BaseRecording, + segment_index: int = 0, + return_scaled: bool = False, + buffer_gb: Optional[float] = None, + buffer_shape: Optional[tuple] = None, + chunk_mb: Optional[float] = None, + chunk_shape: Optional[tuple] = None, + display_progress: bool = False, + progress_bar_options: Optional[dict] = None, + ): + """ + Initialize an Iterable object which returns DataChunks with data and their selections on each iteration. 
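With `_write_recording_to_nwb` in place, the path from selection to processed recording runs through `insert_selection` and `populate`. A hedged usage sketch; all key values below are hypothetical placeholders, and the attribute names follow the foreign keys in SpikeSortingRecordingSelection:

from spyglass.spikesorting.v1.recording import (
    SpikeSortingRecording,
    SpikeSortingRecordingSelection,
)

# Hypothetical placeholder values for one sort group of one session
key = {
    "nwb_file_name": "example20231108_.nwb",
    "sort_group_name": "0",
    "interval_list_name": "02_r1",
    "preproc_param_name": "default",
    "team_name": "my_team",
}
selection = SpikeSortingRecordingSelection.insert_selection(key)
SpikeSortingRecording.populate({"recording_id": selection["recording_id"]})

# The processed recording can then be loaded back as a spikeinterface object
recording = SpikeSortingRecording.get_recording(
    {"recording_id": selection["recording_id"]}
)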
+ + Parameters + ---------- + recording : si.BaseRecording + The SpikeInterfaceRecording object which handles the data access. + segment_index : int, optional + The recording segment to iterate on. + Defaults to 0. + return_scaled : bool, optional + Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False). + Defaults to False. + buffer_gb : float, optional + The upper bound on size in gigabytes (GB) of each selection from the iteration. + The buffer_shape will be set implicitly by this argument. + Cannot be set if `buffer_shape` is also specified. + The default is 1GB. + buffer_shape : tuple, optional + Manual specification of buffer shape to return on each iteration. + Must be a multiple of chunk_shape along each axis. + Cannot be set if `buffer_gb` is also specified. + The default is None. + chunk_mb : float, optional + The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset. + The chunk_shape will be set implicitly by this argument. + Cannot be set if `chunk_shape` is also specified. + The default is 1MB, as recommended by the HDF5 group. For more details, see + https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf + chunk_shape : tuple, optional + Manual specification of the internal chunk shape for the HDF5 dataset. + Cannot be set if `chunk_mb` is also specified. + The default is None. + display_progress : bool, optional + Display a progress bar with iteration rate and estimated completion time. + progress_bar_options : dict, optional + Dictionary of keyword arguments to be passed directly to tqdm. + See https://github.com/tqdm/tqdm#parameters for options. + """ + self.recording = recording + self.segment_index = segment_index + self.return_scaled = return_scaled + self.channel_ids = recording.get_channel_ids() + super().__init__( + buffer_gb=buffer_gb, + buffer_shape=buffer_shape, + chunk_mb=chunk_mb, + chunk_shape=chunk_shape, + display_progress=display_progress, + progress_bar_options=progress_bar_options, + ) + + def _get_data(self, selection: Tuple[slice]) -> Iterable: + return self.recording.get_traces( + segment_index=self.segment_index, + channel_ids=self.channel_ids[selection[1]], + start_frame=selection[0].start, + end_frame=selection[0].stop, + return_scaled=self.return_scaled, + ) + + def _get_dtype(self): + return self.recording.get_dtype() + + def _get_maxshape(self): + return ( + self.recording.get_num_samples(segment_index=self.segment_index), + self.recording.get_num_channels(), + ) + + +class TimestampsExtractor(si.BaseRecording): + def __init__( + self, + timestamps, + sampling_frequency=30e3, + ): + si.BaseRecording.__init__( + self, sampling_frequency, channel_ids=[0], dtype=np.float64 + ) + rec_segment = TimestampsSegment( + timestamps=timestamps, + sampling_frequency=sampling_frequency, + t_start=None, + dtype=np.float64, + ) + self.add_recording_segment(rec_segment) + + +class TimestampsSegment(si.BaseRecordingSegment): + def __init__(self, timestamps, sampling_frequency, t_start, dtype): + si.BaseRecordingSegment.__init__( + self, sampling_frequency=sampling_frequency, t_start=t_start + ) + self._timeseries = timestamps + + def get_num_samples(self) -> int: + return self._timeseries.shape[0] + + def get_traces( + self, + start_frame: Union[int, None] = None, + end_frame: Union[int, None] = None, + channel_indices: Union[List, None] = None, + ) -> np.ndarray: + return np.squeeze(self._timeseries[start_frame:end_frame]) + + +class 
TimestampsDataChunkIterator(GenericDataChunkIterator):
+    """DataChunkIterator for streaming the timestamps of a recording,
+    wrapped as a single-channel recording (see TimestampsExtractor)."""
+
+    def __init__(
+        self,
+        recording: si.BaseRecording,
+        segment_index: int = 0,
+        return_scaled: bool = False,
+        buffer_gb: Optional[float] = None,
+        buffer_shape: Optional[tuple] = None,
+        chunk_mb: Optional[float] = None,
+        chunk_shape: Optional[tuple] = None,
+        display_progress: bool = False,
+        progress_bar_options: Optional[dict] = None,
+    ):
+        """
+        Initialize an Iterable object which returns DataChunks with data and their selections on each iteration.
+
+        Parameters
+        ----------
+        recording : si.BaseRecording
+            The recording object (here, a TimestampsExtractor wrapping the timestamps vector) which handles the data access.
+        segment_index : int, optional
+            The recording segment to iterate on.
+            Defaults to 0.
+        return_scaled : bool, optional
+            Whether to return the trace data in scaled units (uV, if True) or in the raw data type (if False).
+            Defaults to False.
+        buffer_gb : float, optional
+            The upper bound on size in gigabytes (GB) of each selection from the iteration.
+            The buffer_shape will be set implicitly by this argument.
+            Cannot be set if `buffer_shape` is also specified.
+            The default is 1GB.
+        buffer_shape : tuple, optional
+            Manual specification of buffer shape to return on each iteration.
+            Must be a multiple of chunk_shape along each axis.
+            Cannot be set if `buffer_gb` is also specified.
+            The default is None.
+        chunk_mb : float, optional
+            The upper bound on size in megabytes (MB) of the internal chunk for the HDF5 dataset.
+            The chunk_shape will be set implicitly by this argument.
+            Cannot be set if `chunk_shape` is also specified.
+            The default is 1MB, as recommended by the HDF5 group. For more details, see
+            https://support.hdfgroup.org/HDF5/doc/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf
+        chunk_shape : tuple, optional
+            Manual specification of the internal chunk shape for the HDF5 dataset.
+            Cannot be set if `chunk_mb` is also specified.
+            The default is None.
+        display_progress : bool, optional
+            Display a progress bar with iteration rate and estimated completion time.
+        progress_bar_options : dict, optional
+            Dictionary of keyword arguments to be passed directly to tqdm.
+            See https://github.com/tqdm/tqdm#parameters for options.
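For timestamps, the same chunked-streaming pattern is reused: the vector is wrapped as a one-channel recording and iterated. A rough usage sketch, assuming these classes are importable from this module and that the iterator exposes hdmf's `maxshape` property:

import numpy as np
from spyglass.spikesorting.v1.recording import (
    TimestampsDataChunkIterator,
    TimestampsExtractor,
)

# Two seconds of 30 kHz timestamps wrapped for streaming into the NWB file
timestamps = np.arange(0.0, 2.0, 1.0 / 30000.0)
ts_iter = TimestampsDataChunkIterator(
    recording=TimestampsExtractor(timestamps), buffer_gb=0.1
)
print(ts_iter.maxshape)  # (60000,)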
+ """ + self.recording = recording + self.segment_index = segment_index + self.return_scaled = return_scaled + self.channel_ids = recording.get_channel_ids() + super().__init__( + buffer_gb=buffer_gb, + buffer_shape=buffer_shape, + chunk_mb=chunk_mb, + chunk_shape=chunk_shape, + display_progress=display_progress, + progress_bar_options=progress_bar_options, + ) + + # change channel id to always be first channel + def _get_data(self, selection: Tuple[slice]) -> Iterable: + return self.recording.get_traces( + segment_index=self.segment_index, + channel_ids=[0], + start_frame=selection[0].start, + end_frame=selection[0].stop, + return_scaled=self.return_scaled, + ) + + def _get_dtype(self): + return self.recording.get_dtype() + + # remove the last dim for the timestamps since it is always just a 1D vector + def _get_maxshape(self): + return ( + self.recording.get_num_samples(segment_index=self.segment_index), + ) diff --git a/src/spyglass/spikesorting/v1/sorting.py b/src/spyglass/spikesorting/v1/sorting.py index 8178ef513..05cb03346 100644 --- a/src/spyglass/spikesorting/v1/sorting.py +++ b/src/spyglass/spikesorting/v1/sorting.py @@ -1,199 +1,202 @@ +from typing import Iterable + import os -import shutil import tempfile import time -import uuid -from pathlib import Path import datajoint as dj import numpy as np +import pynwb import spikeinterface as si +import spikeinterface.extractors as se import spikeinterface.preprocessing as sip import spikeinterface.sorters as sis from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from ..common.common_lab import LabMember, LabTeam -from ..common.common_nwbfile import AnalysisNwbfile -from ..settings import temp_dir, sorting_dir -from .spikesorting_artifact import ArtifactRemovedIntervalList -from .spikesorting_recording import ( +from spyglass.common.common_lab import LabMember, LabTeam +from spyglass.common.common_interval import IntervalList +from spyglass.common.common_nwbfile import AnalysisNwbfile +from spyglass.spikesorting.v1.recording import ( SpikeSortingRecording, SpikeSortingRecordingSelection, ) +from .recording import _consolidate_intervals +from .utils import generate_nwb_uuid -schema = dj.schema("spikesorting_sorting") +schema = dj.schema("spikesorting_v1_sorting") @schema -class SpikeSorterParameters(dj.Manual): +class SpikeSorterParameters(dj.Lookup): definition = """ sorter: varchar(200) - sorter_params_name: varchar(200) + sorter_param_name: varchar(200) --- sorter_params: blob """ - - def insert_default(self): - """Default params from spike sorters available via spikeinterface""" - sorters = sis.available_sorters() - for sorter in sorters: - sorter_params = sis.get_default_sorter_params(sorter) - self.insert1( - [sorter, "default", sorter_params], skip_duplicates=True - ) - - # Insert Frank lab defaults - # Hippocampus tetrode default - sorter = "mountainsort4" - sorter_params_name = "franklab_tetrode_hippocampus_30KHz" - sorter_params = { - "detect_sign": -1, - "adjacency_radius": 100, - "freq_min": 600, - "freq_max": 6000, - "filter": False, - "whiten": True, - "num_workers": 1, - "clip_size": 40, - "detect_threshold": 3, - "detect_interval": 10, - } - self.insert1( - [sorter, sorter_params_name, sorter_params], skip_duplicates=True - ) - - # Cortical probe default - sorter = "mountainsort4" - sorter_params_name = "franklab_probe_ctx_30KHz" - sorter_params = { - "detect_sign": -1, - "adjacency_radius": 100, - "freq_min": 300, - "freq_max": 6000, - "filter": False, - "whiten": True, - "num_workers": 1, - 
"clip_size": 40, - "detect_threshold": 3, - "detect_interval": 10, - } - self.insert1( - [sorter, sorter_params_name, sorter_params], skip_duplicates=True - ) - - # clusterless defaults - sorter = "clusterless_thresholder" - sorter_params_name = "default_clusterless" - sorter_params = dict( - detect_threshold=100.0, # uV - # Locally exclusive means one unit per spike detected - method="locally_exclusive", - peak_sign="neg", - exclude_sweep_ms=0.1, - local_radius_um=100, - # noise levels needs to be 1.0 so the units are in uV and not MAD - noise_levels=np.asarray([1.0]), - random_chunk_kwargs={}, - # output needs to be set to sorting for the rest of the pipeline - outputs="sorting", - ) - self.insert1( - [sorter, sorter_params_name, sorter_params], skip_duplicates=True - ) + contents = [ + [ + "mountainsort4", + "franklab_tetrode_hippocampus_30KHz", + { + "detect_sign": -1, + "adjacency_radius": 100, + "freq_min": 600, + "freq_max": 6000, + "filter": False, + "whiten": True, + "num_workers": 1, + "clip_size": 40, + "detect_threshold": 3, + "detect_interval": 10, + }, + ], + [ + "mountainsort4", + "franklab_probe_ctx_30KHz", + { + "detect_sign": -1, + "adjacency_radius": 100, + "freq_min": 300, + "freq_max": 6000, + "filter": False, + "whiten": True, + "num_workers": 1, + "clip_size": 40, + "detect_threshold": 3, + "detect_interval": 10, + }, + ], + [ + "clusterless_thresholder", + "default_clusterless", + { + "detect_threshold": 100.0, # uV + # Locally exclusive means one unit per spike detected + "method": "locally_exclusive", + "peak_sign": "neg", + "exclude_sweep_ms": 0.1, + "local_radius_um": 100, + # noise levels needs to be 1.0 so the units are in uV and not MAD + "noise_levels": np.asarray([1.0]), + "random_chunk_kwargs": {}, + # output needs to be set to sorting for the rest of the pipeline + "outputs": "sorting", + }, + ], + ] + contents.extend( + [ + [sorter, "default", sis.get_default_sorter_params(sorter)] + for sorter in sis.available_sorters() + ] + ) + + @classmethod + def insert_default(cls): + cls.insert(cls.contents, skip_duplicates=True) @schema class SpikeSortingSelection(dj.Manual): definition = """ - # Table for holding selection of recording and parameters for each spike sorting run + # Processed recording and parameters. Use `insert_selection` method to insert rows. + sorting_id: varchar(50) + --- -> SpikeSortingRecording -> SpikeSorterParameters - -> ArtifactRemovedIntervalList - --- - import_path = "": varchar(200) # optional path to previous curated sorting output + -> IntervalList """ + @classmethod + def insert_selection(cls, key: dict): + """Insert a row into SpikeSortingSelection with an + automatically generated unique sorting ID as the sole primary key. + + Parameters + ---------- + key : dict + primary key of SpikeSortingRecording, SpikeSorterParameters, IntervalList tables + + Returns + ------- + sorting_id : str + the unique sorting ID serving as primary key for SpikeSorting + """ + if len((cls & key).fetch()) > 0: + print( + "This row has already been inserted into SpikeSortingSelection." 
+ ) + return (cls & key).fetch1() + key["sorting_id"] = generate_nwb_uuid( + key["nwb_file_name"], + "S", + 6, + ) + cls.insert1(key, skip_duplicates=True) + return key + @schema class SpikeSorting(dj.Computed): definition = """ -> SpikeSortingSelection --- - sorting_path: varchar(1000) - time_of_sort: int # in Unix time, to the nearest second + -> AnalysisNwbfile + object_id: varchar(40) # Object ID for the sorting in NWB file + time_of_sort: int # in Unix time, to the nearest second """ def make(self, key: dict): """Runs spike sorting on the data and parameters specified by the SpikeSortingSelection table and inserts a new entry to SpikeSorting table. - - Specifically, - 1. Loads saved recording and runs the sort on it with spikeinterface - 2. Saves the sorting with spikeinterface - 3. Creates an analysis NWB file and saves the sorting there - (this is redundant with 2; will change in the future) - """ - # CBroz: does this not work w/o arg? as .populate() ? - recording_path = (SpikeSortingRecording & key).fetch1("recording_path") - recording = si.load_extractor(recording_path) - - # first, get the timestamps - timestamps = SpikeSortingRecording._get_recording_timestamps(recording) - _ = recording.get_sampling_frequency() - # then concatenate the recordings - # Note: the timestamps are lost upon concatenation, - # i.e. concat_recording.get_times() doesn't return true timestamps anymore. - # but concat_recording.recoring_list[i].get_times() will return correct - # timestamps for ith recording. - if recording.get_num_segments() > 1 and isinstance( - recording, si.AppendSegmentRecording - ): - recording = si.concatenate_recordings(recording.recording_list) - elif recording.get_num_segments() > 1 and isinstance( - recording, si.BinaryRecordingExtractor - ): - recording = si.concatenate_recordings([recording]) - - # load artifact intervals - artifact_times = ( - ArtifactRemovedIntervalList + # FETCH: + # - information about the recording + # - artifact free intervals + # - spike sorter and sorter params + recording_key = ( + SpikeSortingRecording * SpikeSortingSelection & key + ).fetch1() + artifact_removed_intervals = ( + IntervalList & { - "artifact_removed_interval_list_name": key[ - "artifact_removed_interval_list_name" - ] + "nwb_file_name": (SpikeSortingSelection & key).fetch1( + "nwb_file_name" + ), + "interval_list_name": (SpikeSortingSelection & key).fetch1( + "interval_list_name" + ), } - ).fetch1("artifact_times") - if len(artifact_times): - if artifact_times.ndim == 1: - artifact_times = np.expand_dims(artifact_times, 0) - - # convert artifact intervals to indices - list_triggers = [] - for interval in artifact_times: - list_triggers.append( - np.arange( - np.searchsorted(timestamps, interval[0]), - np.searchsorted(timestamps, interval[1]), - ) - ) - list_triggers = [list(np.concatenate(list_triggers))] - recording = sip.remove_artifacts( - recording=recording, - list_triggers=list_triggers, - ms_before=None, - ms_after=None, - mode="zeros", - ) + ).fetch1("valid_times") + sorter, sorter_params = ( + SpikeSorterParameters * SpikeSortingSelection & key + ).fetch1("sorter", "sorter_params") + + # DO: + # - load recording + # - concatenate artifact removed intervals + # - run spike sorting + # - save output to NWB file + recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( + recording_key["analysis_file_name"] + ) + recording = se.read_nwb_recording( + recording_analysis_nwb_file_abs_path, load_time_vector=True + ) + + timestamps = recording.get_times() - 
print(f"Running spike sorting on {key}...") - sorter, sorter_params = (SpikeSorterParameters & key).fetch1( - "sorter", "sorter_params" + artifact_removed_intervals_ind = _consolidate_intervals( + artifact_removed_intervals, timestamps ) - sorter_temp_dir = tempfile.TemporaryDirectory(dir=temp_dir) - # add tempdir option for mountainsort - sorter_params["tempdir"] = sorter_temp_dir.name + recording_list = [] + for interval in artifact_removed_intervals_ind: + recording_list.append( + recording.frame_slice(interval[0], interval[1]) + ) + recording = si.concatenate_recordings(recording_list) if sorter == "clusterless_thresholder": # need to remove tempdir and whiten from sorter_params @@ -204,107 +207,155 @@ def make(self, key: dict): # Detect peaks for clusterless decoding detected_spikes = detect_peaks(recording, **sorter_params) sorting = si.NumpySorting.from_times_labels( - times_list=detected_spikes["sample_index"], + times_list=detected_spikes["sample_ind"], labels_list=np.zeros(len(detected_spikes), dtype=np.int), sampling_frequency=recording.get_sampling_frequency(), ) else: - if "whiten" in sorter_params.keys(): - if sorter_params["whiten"]: - sorter_params["whiten"] = False # set whiten to False - # whiten recording separately; make sure dtype is float32 - # to avoid downstream error with svd - recording = sip.whiten(recording, dtype="float32") + # Specify tempdir (expected by some sorters like mountainsort4) + sorter_temp_dir = tempfile.TemporaryDirectory( + dir=os.getenv("SPYGLASS_TEMP_DIR") + ) + sorter_params["tempdir"] = sorter_temp_dir.name + # if whitening is specified in sorter params, apply whitening separately + # prior to sorting and turn off "sorter whitening" + if sorter_params["whiten"]: + recording = sip.whiten(recording, dtype=np.float64) + sorter_params["whiten"] = False sorting = sis.run_sorter( sorter, recording, output_folder=sorter_temp_dir.name, - delete_output_folder=True, **sorter_params, ) key["time_of_sort"] = int(time.time()) - print("Saving sorting results...") - - sorting_folder = Path(sorting_dir) - - sorting_name = self._get_sorting_name(key) - key["sorting_path"] = str(sorting_folder / Path(sorting_name)) - if os.path.exists(key["sorting_path"]): - shutil.rmtree(key["sorting_path"]) - sorting = sorting.save(folder=key["sorting_path"]) - self.insert1(key) - - def delete(self): - """Extends the delete method of base class to implement permission checking. - Note that this is NOT a security feature, as anyone that has access to source code - can disable it; it just makes it less likely to accidentally delete entries. - """ - current_user_name = dj.config["database.user"] - entries = self.fetch() - permission_bool = np.zeros((len(entries),)) - print( - f"Attempting to delete {len(entries)} entries, checking permission..." 
+ key["analysis_file_name"], key["object_id"] = _write_sorting_to_nwb( + sorting, + timestamps, + artifact_removed_intervals, + (SpikeSortingSelection & key).fetch1("nwb_file_name"), ) - for entry_idx in range(len(entries)): - # check the team name for the entry, then look up the members in that team, - # then get their datajoint user names - team_name = ( - SpikeSortingRecordingSelection - & (SpikeSortingRecordingSelection & entries[entry_idx]).proj() - ).fetch1()["team_name"] - lab_member_name_list = ( - LabTeam.LabTeamMember & {"team_name": team_name} - ).fetch("lab_member_name") - datajoint_user_names = [] - for lab_member_name in lab_member_name_list: - datajoint_user_names.append( - ( - LabMember.LabMemberInfo - & {"lab_member_name": lab_member_name} - ).fetch1("datajoint_user_name") - ) - permission_bool[entry_idx] = ( - current_user_name in datajoint_user_names - ) - if np.sum(permission_bool) == len(entries): - print("Permission to delete all specified entries granted.") - super().delete() - else: - raise Exception( - "You do not have permission to delete all specified" - "entries. Not deleting anything." - ) + # INSERT + # - new entry to AnalysisNwbfile + # - new entry to SpikeSorting + AnalysisNwbfile().add( + (SpikeSortingSelection & key).fetch1("nwb_file_name"), + key["analysis_file_name"], + ) + self.insert1(key) - def fetch_nwb(self, *attrs, **kwargs): - raise NotImplementedError - return None - # return fetch_nwb(self, (AnalysisNwbfile, 'analysis_file_abs_path'), *attrs, **kwargs) + # def delete(self): + # """Extends the delete method of base class to implement permission checking. + # Note that this is NOT a security feature, as anyone that has access to source code + # can disable it; it just makes it less likely to accidentally delete entries. + # """ + # current_user_name = dj.config["database.user"] + # entries = self.fetch() + # permission_bool = np.zeros((len(entries),)) + # print( + # f"Attempting to delete {len(entries)} entries, checking permission..." + # ) + + # for entry_idx in range(len(entries)): + # # check the team name for the entry, then look up the members in that team, + # # then get their datajoint user names + # team_name = ( + # SpikeSortingRecordingSelection + # & (SpikeSortingRecordingSelection & entries[entry_idx]).proj() + # ).fetch1()["team_name"] + # lab_member_name_list = ( + # LabTeam.LabTeamMember & {"team_name": team_name} + # ).fetch("lab_member_name") + # datajoint_user_names = [] + # for lab_member_name in lab_member_name_list: + # datajoint_user_names.append( + # ( + # LabMember.LabMemberInfo + # & {"lab_member_name": lab_member_name} + # ).fetch1("datajoint_user_name") + # ) + # permission_bool[entry_idx] = ( + # current_user_name in datajoint_user_names + # ) + # if np.sum(permission_bool) == len(entries): + # print("Permission to delete all specified entries granted.") + # super().delete() + # else: + # raise Exception( + # "You do not have permission to delete all specified" + # "entries. Not deleting anything." + # ) + + @classmethod + def get_sorting(cls, key: dict) -> si.BaseSorting: + """Get sorting in the analysis NWB file as spikeinterface BaseSorting + + Parameters + ---------- + key : dict + primary key of SpikeSorting + + Returns + ------- + sorting : si.BaseSorting - def nightly_cleanup(self): - """Clean up spike sorting directories that are not in the SpikeSorting table. 
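Since the units table written by `_write_sorting_to_nwb` (defined below) carries a `curation_label` column, the sorting can also be inspected directly from the analysis NWB file. A hedged sketch; the path is a hypothetical placeholder:

import pynwb

with pynwb.NWBHDF5IO(
    "/path/to/analysis/example_sorting.nwb", mode="r", load_namespaces=True
) as io:
    nwbfile = io.read()
    units = nwbfile.units.to_dataframe()
    print(units[["spike_times", "curation_label"]].head())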
- This should be run after AnalysisNwbFile().nightly_cleanup() """ - # get a list of the files in the spike sorting storage directory - dir_names = next(os.walk(sorting_dir))[1] - # now retrieve a list of the currently used analysis nwb files - analysis_file_names = self.fetch("analysis_file_name") - for dir in dir_names: - if dir not in analysis_file_names: - full_path = str(Path(sorting_dir) / dir) - print(f"removing {full_path}") - shutil.rmtree(str(Path(sorting_dir) / dir)) - - @staticmethod - def _get_sorting_name(key): - recording_name = SpikeSortingRecording._get_recording_name(key) - sorting_name = ( - recording_name + "_" + str(uuid.uuid4())[0:8] + "_spikesorting" - ) - return sorting_name - # TODO: write a function to import sorting done outside of dj + analysis_file_name = (cls & key).fetch1("analysis_file_name") + analysis_file_abs_path = AnalysisNwbfile.get_abs_path( + analysis_file_name + ) + sorting = se.read_nwb_sorting(analysis_file_abs_path) + + return sorting + + +def _write_sorting_to_nwb( + sorting: si.BaseSorting, + timestamps: np.ndarray, + sort_interval: Iterable, + nwb_file_name: str, +): + """Write a sorting in NWB format. + + Parameters + ---------- + sorting : si.BaseSorting + spike times are in samples + timestamps: np.ndarray + the absolute time of each sample, in seconds + sort_interval : Iterable + nwb_file_name : str + Name of NWB file the recording originates from + + Returns + ------- + analysis_nwb_file : str + Name of analysis NWB file containing the sorting + """ - def _import_sorting(self, key): - raise NotImplementedError + analysis_nwb_file = AnalysisNwbfile().create(nwb_file_name) + analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path(analysis_nwb_file) + with pynwb.NWBHDF5IO( + path=analysis_nwb_file_abs_path, + mode="a", + load_namespaces=True, + ) as io: + nwbf = io.read() + nwbf.add_unit_column( + name="curation_label", + description="curation label applyed to a unit", + ) + for unit_id in sorting.get_unit_ids(): + spike_times = sorting.get_unit_spike_train(unit_id) + nwbf.add_unit( + spike_times=timestamps[spike_times], + id=unit_id, + obs_intervals=sort_interval, + curation_label="uncurated", + ) + units_object_id = nwbf.units.object_id + io.write(nwbf) + return analysis_nwb_file, units_object_id diff --git a/src/spyglass/spikesorting/v1/utils.py b/src/spyglass/spikesorting/v1/utils.py new file mode 100644 index 000000000..ad8657ef6 --- /dev/null +++ b/src/spyglass/spikesorting/v1/utils.py @@ -0,0 +1,18 @@ +import uuid + + +def generate_nwb_uuid(nwb_file_name: str, initial: str, len_uuid: int = 6): + """Generates a unique identifier related to an NWB file. + + Parameters + ---------- + nwb_file_name : str + _description_ + initial : str + R if recording; A if artifact; S if sorting etc + len_uuid : int + how many digits of uuid4 to keep + """ + uuid4 = str(uuid.uuid4()) + nwb_uuid = nwb_file_name + "_" + initial + "_" + uuid4[:len_uuid] + return nwb_uuid
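The IDs produced by `generate_nwb_uuid` are short, human-readable strings tied to the source NWB file. An example of the expected format (file name hypothetical, suffix random):

from spyglass.spikesorting.v1.utils import generate_nwb_uuid

recording_id = generate_nwb_uuid("example20231108_.nwb", "R")
print(recording_id)  # e.g. "example20231108_.nwb_R_1a2b3c"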