Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Chris Brozdowski <[email protected]>
  • Loading branch information
samuelbray32 and CBroz1 authored Feb 1, 2024
1 parent 411c312 commit f3ea979
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
7 changes: 2 additions & 5 deletions src/spyglass/spikesorting/analysis/v1/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,8 @@ def fetch_spike_data(key, time_slice=None):
include_unit = SortedSpikesGroup.filter_units(
group_label_list, include_labels, exclude_labels
)
sorting_spike_times = [
times
for times, include in zip(sorting_spike_times, include_unit)
if include
]
from itertools import compress # worth bumping to top of script
sorting_spike_times = list(compress(sorting_spike_times, include_unit))

# filter the spike times based on the time slice if provided
if time_slice is not None:
Expand Down
37 changes: 19 additions & 18 deletions src/spyglass/spikesorting/imported.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,18 @@ def add_annotation(
merge_annotations : bool, optional
whether to merge with existing annotations, by default False
"""
if not len(self & key) == 1:
raise ValueError("ImportedSpikeSorting key must be unique")
if isinstance(label, str):
label = [label]
query = self & key
if not len(query) == 1:
raise ValueError(f"ImportedSpikeSorting key must be unique. Found: {query}")
unit_key = {**key, "id": id}
if ImportedSpikeSorting.Annotations & unit_key:
if not merge_annotations:
raise ValueError("Unit already has annotations")
existing_annotations = (
ImportedSpikeSorting.Annotations & unit_key
).fetch(as_dict=True)[0]
existing_annotations["label"] = (
existing_annotations["label"] + label
)
annotation_query = ImportedSpikeSorting.Annotations & unit_key
if annotation_query and not merge_annotations:
raise ValueError(f"Unit already has annotations: {annotation_query}")
elif annotation_query:
existing_annotations = annotation_query.fetch1()
existing_annotations["label"] += label
existing_annotations["annotations"].update(annotations)
self.Annotations.update1(existing_annotations)
else:
Expand Down Expand Up @@ -130,11 +130,12 @@ def fetch_nwb(self, *attrs, **kwargs):
nwbs = super().fetch_nwb(*attrs, **kwargs)
# for each nwb, get the annotations and add them to the spikes dataframe
for i, key in enumerate(self.fetch("KEY")):
if ImportedSpikeSorting.Annotations & key:
# make the annotation_df
annotation_df = (self & key).make_df_from_annotations()
# concatenate the annotations to the spikes dataframe in the returned nwb
nwbs[i]["object_id"] = pd.concat(
[nwbs[i]["object_id"], annotation_df], axis="columns"
)
if not ImportedSpikeSorting.Annotations & key:
continue
# make the annotation_df
annotation_df = (self & key).make_df_from_annotations()
# concatenate the annotations to the spikes dataframe in the returned nwb
nwbs[i]["object_id"] = pd.concat(
[nwbs[i]["object_id"], annotation_df], axis="columns"
)
return nwbs

0 comments on commit f3ea979

Please sign in to comment.