From 205874a0cab45052dc2c2696f3ef0e8c60d01eb3 Mon Sep 17 00:00:00 2001 From: Eric Denovellis Date: Mon, 29 Jan 2024 13:54:56 -0800 Subject: [PATCH] Add intervallist to exclude artifacts --- src/spyglass/ripple/v1/ripple.py | 2 +- src/spyglass/spikesorting/analysis/v1/mua.py | 19 ++++++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/src/spyglass/ripple/v1/ripple.py b/src/spyglass/ripple/v1/ripple.py index 4d2397bc3..9df25cecd 100644 --- a/src/spyglass/ripple/v1/ripple.py +++ b/src/spyglass/ripple/v1/ripple.py @@ -149,7 +149,7 @@ class RippleTimesV1(SpyglassMixin, dj.Computed): -> RippleLFPSelection -> RippleParameters -> PositionOutput.proj(pos_merge_id='merge_id') - + -> IntervalList.proj(artifact_interval_list_name='interval_list_name') --- -> AnalysisNwbfile ripple_times_object_id : varchar(40) diff --git a/src/spyglass/spikesorting/analysis/v1/mua.py b/src/spyglass/spikesorting/analysis/v1/mua.py index 158be9f61..168df65f8 100644 --- a/src/spyglass/spikesorting/analysis/v1/mua.py +++ b/src/spyglass/spikesorting/analysis/v1/mua.py @@ -2,6 +2,7 @@ import numpy as np from ripple_detection import multiunit_HSE_detector +from spyglass.common.common_interval import IntervalList from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.position import PositionOutput # noqa: F401 from spyglass.spikesorting.analysis.v1.group import ( @@ -44,7 +45,7 @@ class MuaEventsV1(SpyglassMixin, dj.Computed): -> MuaEventsParameters -> SortedSpikesGroup -> PositionOutput.proj(pos_merge_id='merge_id') - + -> IntervalList.proj(artifact_interval_list_name='interval_list_name') # exclude artifact times --- -> AnalysisNwbfile mua_times_object_id : varchar(40) @@ -70,6 +71,22 @@ def make(self, key): mua_params = (MuaEventsParameters & key).fetch1("mua_param_dict") + # Exclude artifact times + # Alternatively could set to NaN and leave them out of the firing rate calculation + # in the multiunit_HSE_detector function + artifact_key = { + "nwb_file_name": key["nwb_file_name"], + "interval_list_name": key["artifact_interval_list_name"], + } + artifact_times = (IntervalList & artifact_key).fetch1("valid_times") + mean_n_spikes = np.mean(spike_indicator) + for artifact_time in artifact_times: + spike_indicator[ + np.logical_and( + time >= artifact_time.start, time <= artifact_time.stop + ) + ] = mean_n_spikes + mua_times = multiunit_HSE_detector( time, spike_indicator, speed, sampling_frequency, **mua_params )