From 52ab11755c9b65bd635022927404ad984c047d2d Mon Sep 17 00:00:00 2001 From: Kyu Hyun Lee Date: Mon, 30 Oct 2023 14:58:59 -0700 Subject: [PATCH] Update artifact detection --- src/spyglass/common/common_interval.py | 31 +++++++ src/spyglass/spikesorting/v1/artifact.py | 100 ++++++----------------- 2 files changed, 58 insertions(+), 73 deletions(-) diff --git a/src/spyglass/common/common_interval.py b/src/spyglass/common/common_interval.py index 6cc024430..6b431bdc6 100644 --- a/src/spyglass/common/common_interval.py +++ b/src/spyglass/common/common_interval.py @@ -506,3 +506,34 @@ def interval_set_difference_inds(intervals1, intervals2): i += 1 result += intervals1[i:] return result + + +def interval_list_complement(intervals1, intervals2): + "Finds intervals in intervals1 that are not in intervals2" + result = [] + + for start1, end1 in intervals1: + subtracted = [(start1, end1)] + + for start2, end2 in intervals2: + new_subtracted = [] + + for s, e in subtracted: + if start2 <= s and e <= end2: + continue + + if e <= start2 or end2 <= s: + new_subtracted.append((s, e)) + continue + + if start2 > s: + new_subtracted.append((s, start2)) + + if end2 < e: + new_subtracted.append((end2, e)) + + subtracted = new_subtracted + + result.extend(subtracted) + + return np.asarray(result) diff --git a/src/spyglass/spikesorting/v1/artifact.py b/src/spyglass/spikesorting/v1/artifact.py index ae740ab73..7bd58cf17 100644 --- a/src/spyglass/spikesorting/v1/artifact.py +++ b/src/spyglass/spikesorting/v1/artifact.py @@ -1,6 +1,6 @@ import warnings from functools import reduce -from typing import Union +from typing import Union, List import datajoint as dj import numpy as np @@ -14,6 +14,7 @@ IntervalList, _union_concat, interval_from_inds, + interval_list_complement, ) from spyglass.spikesorting.v1.utils import generate_nwb_uuid from spyglass.spikesorting.v1.recording import ( @@ -119,7 +120,19 @@ def make(self, key): * ArtifactDetectionSelection & key ).fetch1("artifact_params", "analysis_file_name") - + sort_interval_valid_times = ( + IntervalList + & { + "nwb_file_name": ( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("nwb_file_name"), + "interval_list_name": ( + SpikeSortingRecordingSelection * ArtifactDetectionSelection + & key + ).fetch1("interval_list_name"), + } + ).fetch1("valid_times") # DO: # - load recording recording_analysis_nwb_file_abs_path = AnalysisNwbfile.get_abs_path( @@ -128,9 +141,11 @@ def make(self, key): recording = se.read_nwb_recording( recording_analysis_nwb_file_abs_path, load_time_vector=True ) + # - detect artifacts artifact_removed_valid_times, _ = _get_artifact_times( recording, + sort_interval_valid_times, **artifact_params, ) @@ -153,6 +168,7 @@ def make(self, key): def _get_artifact_times( recording: si.BaseRecording, + sort_interval_valid_times: List[List], zscore_thresh: Union[float, None] = None, amplitude_thresh_uV: Union[float, None] = None, proportion_above_thresh: float = 1.0, @@ -171,6 +187,8 @@ def _get_artifact_times( Parameters ---------- recording : si.BaseRecording + sort_interval_valid_times : List[List] + The sort interval for the recording, unit: seconds zscore_thresh : float, optional Stdev threshold for exclusion, should be >=0, defaults to None amplitude_thresh_uV : float, optional @@ -243,6 +261,7 @@ def _get_artifact_times( artifact_frames = executor.run() artifact_frames = np.concatenate(artifact_frames) + print(f"artifact_frames: {artifact_frames}") # turn ms to remove total into s to remove from either side of each detected artifact half_removal_window_s = removal_window_ms / 2 / 1000 @@ -257,6 +276,7 @@ def _get_artifact_times( # convert indices to intervals artifact_intervals = interval_from_inds(artifact_frames) + print(f"artifact_intervals: {artifact_intervals}") # convert to seconds and pad with window artifact_intervals_s = np.zeros( @@ -282,8 +302,11 @@ def _get_artifact_times( artifact_intervals_s = reduce(_union_concat, artifact_intervals_s) # find non-artifact intervals in timestamps - artifact_removed_valid_times = find_missing_intervals( - artifact_intervals_s, valid_timestamps + artifact_removed_valid_times = interval_list_complement( + sort_interval_valid_times, artifact_intervals_s + ) + artifact_removed_valid_times = reduce( + _union_concat, artifact_removed_valid_times ) return artifact_removed_valid_times, artifact_intervals_s @@ -401,75 +424,6 @@ def _check_artifact_thresholds( return amplitude_thresh_uV, zscore_thresh, proportion_above_thresh -def find_missing_intervals(intervals, timestamps): - """Given a list of intervals each of which is [start_time, end_time] and an array of timestamps, - find intervals are not contained in the input list of intervals but contained in the array of timestamps. - Note that the start and stop times of such intervals must be explicitly contained in the array of timestamps - - Parameters - ---------- - intervals : _type_ - _description_ - timestamps : _type_ - _description_ - - Returns - ------- - _type_ - _description_ - """ - # Sort the list of intervals and timestamps - intervals.sort() - timestamps.sort() - - missing_intervals = [] - timestamp_idx = 0 - - # Initialize an empty interval - new_interval = [] - - for start, end in intervals: - # Look for potential missing intervals - while ( - timestamp_idx < len(timestamps) - and timestamps[timestamp_idx] < start - ): - new_interval.append(timestamps[timestamp_idx]) - timestamp_idx += 1 - - if len(new_interval) == 1: - continue - - if timestamps[timestamp_idx] > new_interval[-1]: - new_interval.append(timestamps[timestamp_idx - 1]) - missing_intervals.append(new_interval) - new_interval = [] - - # Move the index to the point after the end of the current interval - while ( - timestamp_idx < len(timestamps) and timestamps[timestamp_idx] <= end - ): - timestamp_idx += 1 - - # Check for any remaining missing intervals - while timestamp_idx < len(timestamps): - new_interval.append(timestamps[timestamp_idx]) - timestamp_idx += 1 - - if len(new_interval) == 1: - continue - - if ( - timestamp_idx == len(timestamps) - or timestamps[timestamp_idx] > new_interval[-1] - ): - new_interval.append(timestamps[timestamp_idx - 1]) - missing_intervals.append(new_interval) - new_interval = [] - - return np.asarray(missing_intervals) - - def merge_intervals(intervals): """Takes a list of intervals each of which is [start_time, stop_time] and takes union over intervals that are intersecting