Skip to content

Commit

Permalink
Add tests for spikesorting (#1078)
Browse files Browse the repository at this point in the history
* WIP: Add tests for spikesorting

* ✅ : Add tests for spikesorting 2

* Update changelog

* ✅ : Add tests of utils
  • Loading branch information
CBroz1 authored Sep 5, 2024
1 parent d4dbc23 commit f182676
Show file tree
Hide file tree
Showing 20 changed files with 505 additions and 24 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@
- Set `sparse` parameter to waveform extraction step in `spikesorting.v1`
#1039
- Efficiency improvement to `v0.Curation.insert_curation` #1072
- Add pytests for `spikesorting.v1` #1078

## [0.5.2] (April 22, 2024)

Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ omit = [ # which submodules have no tests
# "*/position/*",
"*/ripple/*",
"*/sharing/*",
"*/spikesorting/*",
"*/spikesorting/v0/*",
# "*/spikesorting/*",
# "*/utils/*",
"settings.py",
]
Expand Down
3 changes: 3 additions & 0 deletions src/spyglass/spikesorting/analysis/v1/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ripple_detection import get_multiunit_population_firing_rate

from spyglass.common import Session # noqa: F401
from spyglass.settings import test_mode
from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
from spyglass.utils.dj_mixin import SpyglassMixin, SpyglassMixinPart
from spyglass.utils.spikesorting import firing_rate_from_spike_indicator
Expand Down Expand Up @@ -72,6 +73,8 @@ def create_group(
"unit_filter_params_name": unit_filter_params_name,
}
if self & group_key:
if test_mode:
return
raise ValueError(
f"Group {nwb_file_name}: {group_name} already exists",
"please delete the group before creating a new one",
Expand Down
5 changes: 2 additions & 3 deletions src/spyglass/spikesorting/analysis/v1/unit_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,8 @@ def fetch_unit_spikes(
"""
if len(self) == len(UnitAnnotation()):
logger.warning(
"fetching all unit spikes",
"if this is unintended, please call as: ",
"(UnitAnnotation & key).fetch_unit_spikes()",
"fetching all unit spikes if this is unintended, please call as"
+ ": (UnitAnnotation & key).fetch_unit_spikes()"
)
# get the set of nwb files to load
merge_keys = [
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/spikesorting_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ def get_restricted_merge_ids(
merge_ids : list
list of merge ids from the restricted sources
"""
# TODO: replace with long-distance restrictions

merge_ids = []

if "v1" in sources:
Expand Down
2 changes: 2 additions & 0 deletions src/spyglass/spikesorting/v1/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,8 @@ def merge_intervals(intervals):
_type_
_description_
"""
# TODO: Migrate to common_interval.py

if len(intervals) == 0:
return []

Expand Down
9 changes: 7 additions & 2 deletions src/spyglass/spikesorting/v1/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)
from spyglass.common.common_lab import LabTeam
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
from spyglass.settings import test_mode
from spyglass.spikesorting.utils import (
_get_recording_timestamps,
get_group_by_shank,
Expand Down Expand Up @@ -76,8 +77,12 @@ def set_group_by_shank(
omit_unitrode : bool
Optional. If True, no sort groups are defined for unitrodes.
"""
# delete any current groups
(SortGroup & {"nwb_file_name": nwb_file_name}).delete()
existing_entries = SortGroup & {"nwb_file_name": nwb_file_name}
if existing_entries and test_mode:
return
elif existing_entries:
# delete any current groups
(SortGroup & {"nwb_file_name": nwb_file_name}).delete()

sg_keys, sge_keys = get_group_by_shank(
nwb_file_name=nwb_file_name,
Expand Down
34 changes: 18 additions & 16 deletions src/spyglass/spikesorting/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,25 @@
from spyglass.spikesorting.v1.sorting import SpikeSortingSelection


def generate_nwb_uuid(nwb_file_name: str, initial: str, len_uuid: int = 6):
def generate_nwb_uuid(
nwb_file_name: str, initial: str, len_uuid: int = 6
) -> str:
"""Generates a unique identifier related to an NWB file.
Parameters
----------
nwb_file_name : str
_description_
Nwb file name, first part of resulting string.
initial : str
    R if recording; A if artifact; S if sorting, etc.
len_uuid : int
how many digits of uuid4 to keep
Returns
-------
str
A unique identifier for the NWB file.
"{nwbf}_{initial}_{uuid4[:len_uuid]}"
"""
uuid4 = str(uuid.uuid4())
nwb_uuid = nwb_file_name + "_" + initial + "_" + uuid4[:len_uuid]
Expand All @@ -44,6 +52,7 @@ def get_spiking_sorting_v1_merge_ids(restriction: dict):
name of the artifact parameter
curation_id : int, optional
id of the curation (if not specified, uses the latest curation)
Returns
-------
merge_id_list : list
Expand All @@ -62,29 +71,22 @@ def get_spiking_sorting_v1_merge_ids(restriction: dict):
]
# list of sorting ids for each recording
sorting_restriction = restriction.copy()
del sorting_restriction["interval_list_name"]
_ = sorting_restriction.pop("interval_list_name", None)
sorting_id_list = []
for r_id, a_id in zip(recording_id_list, artifact_id_list):
rec_dict = {"recording_id": str(r_id), "interval_list_name": str(a_id)}
# if sorted with artifact detection
if (
SpikeSortingSelection()
& sorting_restriction
& {"recording_id": r_id, "interval_list_name": a_id}
):
if SpikeSortingSelection() & sorting_restriction & rec_dict:
sorting_id_list.append(
(
SpikeSortingSelection()
& sorting_restriction
& {"recording_id": r_id, "interval_list_name": a_id}
SpikeSortingSelection() & sorting_restriction & rec_dict
).fetch1("sorting_id")
)
# if sorted without artifact detection
else:
sorting_id_list.append(
(
SpikeSortingSelection()
& sorting_restriction
& {"recording_id": r_id, "interval_list_name": r_id}
SpikeSortingSelection() & sorting_restriction & rec_dict
).fetch1("sorting_id")
)
# if curation_id is specified, use that id for each sorting_id
Expand All @@ -100,8 +102,8 @@ def get_spiking_sorting_v1_merge_ids(restriction: dict):
merge_id_list = [
(
SpikeSortingOutput.CurationV1()
& {"sorting_id": id, "curation_id": c_id}
& {"sorting_id": s_id, "curation_id": c_id}
).fetch1("merge_id")
for id, c_id in zip(sorting_id_list, curation_id)
for s_id, c_id in zip(sorting_id_list, curation_id)
]
return merge_id_list
10 changes: 8 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,13 @@ def dlc_project_name():
yield "pytest_proj"


@pytest.fixture(scope="session")
def team_name(common):
team_name = "sc_eb"
common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True)
yield team_name


@pytest.fixture(scope="session")
def insert_project(
verbose_context,
Expand All @@ -823,6 +830,7 @@ def insert_project(
dlc_project_name,
dlc_project_tbl,
common,
team_name,
bodyparts,
mini_copy_name,
):
Expand All @@ -845,8 +853,6 @@ def insert_project(
RippleTimesV1,
)

team_name = "sc_eb"
common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True)
video_list = common.VideoFile().fetch(
"nwb_file_name", "epoch", as_dict=True
)[:2]
Expand Down
Empty file added tests/spikesorting/__init__.py
Empty file.
Loading

0 comments on commit f182676

Please sign in to comment.