diff --git a/src/spyglass/spikesorting/analysis/v1/group.py b/src/spyglass/spikesorting/analysis/v1/group.py index 477f32c75..b46ee1b0b 100644 --- a/src/spyglass/spikesorting/analysis/v1/group.py +++ b/src/spyglass/spikesorting/analysis/v1/group.py @@ -133,11 +133,8 @@ def fetch_spike_data(key, time_slice=None): include_unit = SortedSpikesGroup.filter_units( group_label_list, include_labels, exclude_labels ) - sorting_spike_times = [ - times - for times, include in zip(sorting_spike_times, include_unit) - if include - ] + from itertools import compress # worth bumping to top of script + sorting_spike_times = list(compress(sorting_spike_times, include_unit)) # filter the spike times based on the time slice if provided if time_slice is not None: diff --git a/src/spyglass/spikesorting/imported.py b/src/spyglass/spikesorting/imported.py index ef6d60900..e397ee95e 100644 --- a/src/spyglass/spikesorting/imported.py +++ b/src/spyglass/spikesorting/imported.py @@ -86,18 +86,18 @@ def add_annotation( merge_annotations : bool, optional whether to merge with existing annotations, by default False """ - if not len(self & key) == 1: - raise ValueError("ImportedSpikeSorting key must be unique") + if isinstance(label, str): + label = [label] + query = self & key + if not len(query) == 1: + raise ValueError(f"ImportedSpikeSorting key must be unique. Found: {query}") unit_key = {**key, "id": id} - if ImportedSpikeSorting.Annotations & unit_key: - if not merge_annotations: - raise ValueError("Unit already has annotations") - existing_annotations = ( - ImportedSpikeSorting.Annotations & unit_key - ).fetch(as_dict=True)[0] - existing_annotations["label"] = ( - existing_annotations["label"] + label - ) + annotation_query = ImportedSpikeSorting.Annotations & unit_key + if annotation_query and not merge_annotations: + raise ValueError(f"Unit already has annotations: {annotation_query}") + elif annotation_query: + existing_annotations = annotation_query.fetch1() + existing_annotations["label"] += label existing_annotations["annotations"].update(annotations) self.Annotations.update1(existing_annotations) else: @@ -130,11 +130,12 @@ def fetch_nwb(self, *attrs, **kwargs): nwbs = super().fetch_nwb(*attrs, **kwargs) # for each nwb, get the annotations and add them to the spikes dataframe for i, key in enumerate(self.fetch("KEY")): - if ImportedSpikeSorting.Annotations & key: - # make the annotation_df - annotation_df = (self & key).make_df_from_annotations() - # concatenate the annotations to the spikes dataframe in the returned nwb - nwbs[i]["object_id"] = pd.concat( - [nwbs[i]["object_id"], annotation_df], axis="columns" - ) + if not ImportedSpikeSorting.Annotations & key: + continue + # make the annotation_df + annotation_df = (self & key).make_df_from_annotations() + # concatenate the annotations to the spikes dataframe in the returned nwb + nwbs[i]["object_id"] = pd.concat( + [nwbs[i]["object_id"], annotation_df], axis="columns" + ) return nwbs