From f1826767cc7a17a2f9d88d897bdf977ea0ad05aa Mon Sep 17 00:00:00 2001
From: Chris Broz
Date: Thu, 5 Sep 2024 13:11:22 -0500
Subject: [PATCH] Add tests for spikesorting (#1078)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* WIP: Add tests for spikesorting

* ✅ : Add tests for spikesorting 2

* Update changelog

* ✅ : Add tests of utils
---
 CHANGELOG.md                                  |   1 +
 pyproject.toml                                |   3 +-
 .../spikesorting/analysis/v1/group.py         |   3 +
 .../analysis/v1/unit_annotation.py            |   5 +-
 .../spikesorting/spikesorting_merge.py        |   2 +
 src/spyglass/spikesorting/v1/artifact.py      |   2 +
 src/spyglass/spikesorting/v1/recording.py     |   9 +-
 src/spyglass/spikesorting/v1/utils.py         |  34 +--
 tests/conftest.py                             |  10 +-
 tests/spikesorting/__init__.py                |   0
 tests/spikesorting/conftest.py                | 262 ++++++++++++++++++
 tests/spikesorting/test_analysis.py           |   9 +
 tests/spikesorting/test_artifact.py           |  28 ++
 tests/spikesorting/test_curation.py           |  51 ++++
 tests/spikesorting/test_figurl.py             |  11 +
 tests/spikesorting/test_merge.py              |  63 +++++
 tests/spikesorting/test_metric_curation.py    |   3 +
 tests/spikesorting/test_recording.py          |  10 +
 tests/spikesorting/test_sorting.py            |   3 +
 tests/spikesorting/test_utils.py              |  20 ++
 20 files changed, 505 insertions(+), 24 deletions(-)
 create mode 100644 tests/spikesorting/__init__.py
 create mode 100644 tests/spikesorting/conftest.py
 create mode 100644 tests/spikesorting/test_analysis.py
 create mode 100644 tests/spikesorting/test_artifact.py
 create mode 100644 tests/spikesorting/test_curation.py
 create mode 100644 tests/spikesorting/test_figurl.py
 create mode 100644 tests/spikesorting/test_merge.py
 create mode 100644 tests/spikesorting/test_metric_curation.py
 create mode 100644 tests/spikesorting/test_recording.py
 create mode 100644 tests/spikesorting/test_sorting.py
 create mode 100644 tests/spikesorting/test_utils.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e1afbe680..9930bcbf1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -102,6 +102,7 @@
 - Set `sparse` parameter to waveform extraction step in `spikesorting.v1` #1039
 - Efficiency improvement to `v0.Curation.insert_curation` #1072
+- Add pytests for `spikesorting.v1` #1078

 ## [0.5.2] (April 22, 2024)
diff --git a/pyproject.toml b/pyproject.toml
index a3f96c3e0..8db231c8d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -165,7 +165,8 @@ omit = [ # which submodules have no tests
     # "*/position/*",
     "*/ripple/*",
     "*/sharing/*",
-    "*/spikesorting/*",
+    "*/spikesorting/v0/*",
+    # "*/spikesorting/*",
     # "*/utils/*",
     "settings.py",
 ]
diff --git a/src/spyglass/spikesorting/analysis/v1/group.py b/src/spyglass/spikesorting/analysis/v1/group.py
index 8b3138e69..ad6517558 100644
--- a/src/spyglass/spikesorting/analysis/v1/group.py
+++ b/src/spyglass/spikesorting/analysis/v1/group.py
@@ -6,6 +6,7 @@
 from ripple_detection import get_multiunit_population_firing_rate

 from spyglass.common import Session  # noqa: F401
+from spyglass.settings import test_mode
 from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
 from spyglass.utils.dj_mixin import SpyglassMixin, SpyglassMixinPart
 from spyglass.utils.spikesorting import firing_rate_from_spike_indicator
@@ -72,6 +73,8 @@ def create_group(
             "unit_filter_params_name": unit_filter_params_name,
         }
         if self & group_key:
+            if test_mode:
+                return
             raise ValueError(
                 f"Group {nwb_file_name}: {group_name} already exists",
                 "please delete the group before creating a new one",
             )
diff --git a/src/spyglass/spikesorting/analysis/v1/unit_annotation.py b/src/spyglass/spikesorting/analysis/v1/unit_annotation.py
index 4e1328979..d1ac26a11 100644
--- a/src/spyglass/spikesorting/analysis/v1/unit_annotation.py
+++ b/src/spyglass/spikesorting/analysis/v1/unit_annotation.py
@@ -101,9 +101,8 @@ def fetch_unit_spikes(
         """
         if len(self) == len(UnitAnnotation()):
             logger.warning(
-                "fetching all unit spikes",
-                "if this is unintended, please call as: ",
-                "(UnitAnnotation & key).fetch_unit_spikes()",
+                "Fetching all unit spikes. If this is unintended, please "
+                + "call as: (UnitAnnotation & key).fetch_unit_spikes()"
             )
         # get the set of nwb files to load
         merge_keys = [
diff --git a/src/spyglass/spikesorting/spikesorting_merge.py b/src/spyglass/spikesorting/spikesorting_merge.py
index 7d12601e2..e7a27bae0 100644
--- a/src/spyglass/spikesorting/spikesorting_merge.py
+++ b/src/spyglass/spikesorting/spikesorting_merge.py
@@ -83,6 +83,8 @@ def get_restricted_merge_ids(
     merge_ids : list
         list of merge ids from the restricted sources
     """
+    # TODO: replace with long-distance restrictions
+
     merge_ids = []

     if "v1" in sources:
diff --git a/src/spyglass/spikesorting/v1/artifact.py b/src/spyglass/spikesorting/v1/artifact.py
index 139f30c81..04a7dd463 100644
--- a/src/spyglass/spikesorting/v1/artifact.py
+++ b/src/spyglass/spikesorting/v1/artifact.py
@@ -330,6 +330,8 @@ def merge_intervals(intervals):
     _type_
         _description_
     """
+    # TODO: Migrate to common_interval.py
+
     if len(intervals) == 0:
         return []
diff --git a/src/spyglass/spikesorting/v1/recording.py b/src/spyglass/spikesorting/v1/recording.py
index 72f099a18..fd5214e40 100644
--- a/src/spyglass/spikesorting/v1/recording.py
+++ b/src/spyglass/spikesorting/v1/recording.py
@@ -19,6 +19,7 @@
 )
 from spyglass.common.common_lab import LabTeam
 from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
+from spyglass.settings import test_mode
 from spyglass.spikesorting.utils import (
     _get_recording_timestamps,
     get_group_by_shank,
@@ -76,8 +77,12 @@ def set_group_by_shank(
         omit_unitrode : bool
             Optional. If True, no sort groups are defined for unitrodes.
         """
-        # delete any current groups
-        (SortGroup & {"nwb_file_name": nwb_file_name}).delete()
+        existing_entries = SortGroup & {"nwb_file_name": nwb_file_name}
+        if existing_entries and test_mode:
+            return
+        elif existing_entries:
+            # delete any current groups
+            (SortGroup & {"nwb_file_name": nwb_file_name}).delete()

         sg_keys, sge_keys = get_group_by_shank(
             nwb_file_name=nwb_file_name,
diff --git a/src/spyglass/spikesorting/v1/utils.py b/src/spyglass/spikesorting/v1/utils.py
index 6a511c43e..66cea0b41 100644
--- a/src/spyglass/spikesorting/v1/utils.py
+++ b/src/spyglass/spikesorting/v1/utils.py
@@ -9,17 +9,25 @@
 from spyglass.spikesorting.v1.sorting import SpikeSortingSelection


-def generate_nwb_uuid(nwb_file_name: str, initial: str, len_uuid: int = 6):
+def generate_nwb_uuid(
+    nwb_file_name: str, initial: str, len_uuid: int = 6
+) -> str:
     """Generates a unique identifier related to an NWB file.

     Parameters
     ----------
     nwb_file_name : str
-        _description_
+        NWB file name; the first part of the resulting string.
     initial : str
         R if recording; A if artifact; S if sorting etc
     len_uuid : int
         how many digits of uuid4 to keep
+
+    Returns
+    -------
+    str
+        A unique identifier for the NWB file,
+        formatted as "{nwbf}_{initial}_{uuid4[:len_uuid]}".
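+
+    Examples
+    --------
+    A hypothetical call; the six-character suffix is random, so the exact
+    output will vary:
+
+    >>> generate_nwb_uuid("test.nwb", "R")  # doctest: +SKIP
+    'test.nwb_R_1a2b3c'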
+ "{nwbf}_{initial}_{uuid4[:len_uuid]}" """ uuid4 = str(uuid.uuid4()) nwb_uuid = nwb_file_name + "_" + initial + "_" + uuid4[:len_uuid] @@ -44,6 +52,7 @@ def get_spiking_sorting_v1_merge_ids(restriction: dict): name of the artifact parameter curation_id : int, optional id of the curation (if not specified, uses the latest curation) + Returns ------- merge_id_list : list @@ -62,29 +71,22 @@ def get_spiking_sorting_v1_merge_ids(restriction: dict): ] # list of sorting ids for each recording sorting_restriction = restriction.copy() - del sorting_restriction["interval_list_name"] + _ = sorting_restriction.pop("interval_list_name", None) sorting_id_list = [] for r_id, a_id in zip(recording_id_list, artifact_id_list): + rec_dict = {"recording_id": str(r_id), "interval_list_name": str(a_id)} # if sorted with artifact detection - if ( - SpikeSortingSelection() - & sorting_restriction - & {"recording_id": r_id, "interval_list_name": a_id} - ): + if SpikeSortingSelection() & sorting_restriction & rec_dict: sorting_id_list.append( ( - SpikeSortingSelection() - & sorting_restriction - & {"recording_id": r_id, "interval_list_name": a_id} + SpikeSortingSelection() & sorting_restriction & rec_dict ).fetch1("sorting_id") ) # if sorted without artifact detection else: sorting_id_list.append( ( - SpikeSortingSelection() - & sorting_restriction - & {"recording_id": r_id, "interval_list_name": r_id} + SpikeSortingSelection() & sorting_restriction & rec_dict ).fetch1("sorting_id") ) # if curation_id is specified, use that id for each sorting_id @@ -100,8 +102,8 @@ def get_spiking_sorting_v1_merge_ids(restriction: dict): merge_id_list = [ ( SpikeSortingOutput.CurationV1() - & {"sorting_id": id, "curation_id": c_id} + & {"sorting_id": s_id, "curation_id": c_id} ).fetch1("merge_id") - for id, c_id in zip(sorting_id_list, curation_id) + for s_id, c_id in zip(sorting_id_list, curation_id) ] return merge_id_list diff --git a/tests/conftest.py b/tests/conftest.py index 04420ffe2..8a9bc1a79 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -815,6 +815,13 @@ def dlc_project_name(): yield "pytest_proj" +@pytest.fixture(scope="session") +def team_name(common): + team_name = "sc_eb" + common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) + yield team_name + + @pytest.fixture(scope="session") def insert_project( verbose_context, @@ -823,6 +830,7 @@ def insert_project( dlc_project_name, dlc_project_tbl, common, + team_name, bodyparts, mini_copy_name, ): @@ -845,8 +853,6 @@ def insert_project( RippleTimesV1, ) - team_name = "sc_eb" - common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) video_list = common.VideoFile().fetch( "nwb_file_name", "epoch", as_dict=True )[:2] diff --git a/tests/spikesorting/__init__.py b/tests/spikesorting/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/spikesorting/conftest.py b/tests/spikesorting/conftest.py new file mode 100644 index 000000000..d287d35f4 --- /dev/null +++ b/tests/spikesorting/conftest.py @@ -0,0 +1,262 @@ +import re + +import pytest +from datajoint.hash import key_hash + + +@pytest.fixture(scope="session") +def spike_v1(common): + from spyglass.spikesorting import v1 + + yield v1 + + +@pytest.fixture(scope="session") +def pop_rec(spike_v1, mini_dict, team_name): + spike_v1.SortGroup.set_group_by_shank(**mini_dict) + key = { + **mini_dict, + "sort_group_id": 0, + "preproc_param_name": "default", + "interval_list_name": "01_s1", + "team_name": team_name, + } + 
+
+
+@pytest.fixture(scope="session")
+def pop_rec(spike_v1, mini_dict, team_name):
+    spike_v1.SortGroup.set_group_by_shank(**mini_dict)
+    key = {
+        **mini_dict,
+        "sort_group_id": 0,
+        "preproc_param_name": "default",
+        "interval_list_name": "01_s1",
+        "team_name": team_name,
+    }
+    spike_v1.SpikeSortingRecordingSelection.insert_selection(key)
+    ssr_pk = (
+        (spike_v1.SpikeSortingRecordingSelection & key).proj().fetch1("KEY")
+    )
+    spike_v1.SpikeSortingRecording.populate(ssr_pk)
+
+    yield ssr_pk
+
+
+@pytest.fixture(scope="session")
+def pop_art(spike_v1, mini_dict, pop_rec):
+    key = {
+        "recording_id": pop_rec["recording_id"],
+        "artifact_param_name": "default",
+    }
+    spike_v1.ArtifactDetectionSelection.insert_selection(key)
+    spike_v1.ArtifactDetection.populate()
+
+    yield spike_v1.ArtifactDetection().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def sorter_dict():
+    return {"sorter": "mountainsort4"}
+
+
+@pytest.fixture(scope="session")
+def pop_sort(spike_v1, pop_rec, pop_art, mini_dict, sorter_dict):
+    key = {
+        **mini_dict,
+        **sorter_dict,
+        "recording_id": pop_rec["recording_id"],
+        "interval_list_name": str(pop_art["artifact_id"]),
+        "sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
+    }
+    spike_v1.SpikeSortingSelection.insert_selection(key)
+    spike_v1.SpikeSorting.populate()
+
+    yield spike_v1.SpikeSorting().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def sorting_objs(spike_v1, pop_sort):
+    sort_nwb = (spike_v1.SpikeSorting & pop_sort).fetch_nwb()
+    sort_si = spike_v1.SpikeSorting.get_sorting(pop_sort)
+    yield sort_nwb, sort_si
+
+
+@pytest.fixture(scope="session")
+def pop_curation(spike_v1, pop_sort):
+    spike_v1.CurationV1.insert_curation(
+        sorting_id=pop_sort["sorting_id"],
+        description="testing sort",
+    )
+
+    yield spike_v1.CurationV1().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def pop_metric(spike_v1, pop_sort, pop_curation):
+    _ = pop_curation  # make sure this happens first
+    key = {
+        "sorting_id": pop_sort["sorting_id"],
+        "curation_id": 0,
+        "waveform_param_name": "default_not_whitened",
+        "metric_param_name": "franklab_default",
+        "metric_curation_param_name": "default",
+    }
+
+    spike_v1.MetricCurationSelection.insert_selection(key)
+    spike_v1.MetricCuration.populate(key)
+
+    yield spike_v1.MetricCuration().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def metric_objs(spike_v1, pop_metric):
+    key = {"metric_curation_id": pop_metric["metric_curation_id"]}
+    labels = spike_v1.MetricCuration.get_labels(key)
+    merge_groups = spike_v1.MetricCuration.get_merge_groups(key)
+    metrics = spike_v1.MetricCuration.get_metrics(key)
+    yield labels, merge_groups, metrics
+
+
+@pytest.fixture(scope="session")
+def pop_curation_metric(spike_v1, pop_metric, metric_objs):
+    labels, merge_groups, metrics = metric_objs
+    parent_dict = {"parent_curation_id": 0}
+    spike_v1.CurationV1.insert_curation(
+        sorting_id=(
+            spike_v1.MetricCurationSelection
+            & {"metric_curation_id": pop_metric["metric_curation_id"]}
+        ).fetch1("sorting_id"),
+        **parent_dict,
+        labels=labels,
+        merge_groups=merge_groups,
+        metrics=metrics,
+        description="after metric curation",
+    )
+
+    yield (spike_v1.CurationV1 & parent_dict).fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def pop_figurl(spike_v1, pop_sort, metric_objs):
+    # WON'T WORK UNTIL CI/CD KACHERY_CLOUD INIT
+    sort_dict = {"sorting_id": pop_sort["sorting_id"], "curation_id": 1}
+    curation_uri = spike_v1.FigURLCurationSelection.generate_curation_uri(
+        sort_dict
+    )
+    _, _, metrics = metric_objs
+    key = {
+        **sort_dict,
+        "curation_uri": curation_uri,
+        "metrics_figurl": list(metrics.keys()),
+    }
+    spike_v1.FigURLCurationSelection.insert_selection(key)
+    spike_v1.FigURLCuration.populate()
+
+    yield spike_v1.FigURLCuration().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def pop_figurl_json(spike_v1, pop_metric, pop_sort, metric_objs):
+    # WON'T WORK UNTIL CI/CD KACHERY_CLOUD INIT
+    gh_curation_uri = (
+        "gh://LorenFrankLab/sorting-curations/main/khl02007/test/curation.json"
+    )
+    key = {
+        "sorting_id": pop_metric["sorting_id"],
+        "curation_id": 1,
+        "curation_uri": gh_curation_uri,
+        "metrics_figurl": [],
+    }
+    spike_v1.FigURLCurationSelection.insert_selection(key)
+    spike_v1.FigURLCuration.populate()
+
+    labels = spike_v1.FigURLCuration.get_labels(gh_curation_uri)
+    merge_groups = spike_v1.FigURLCuration.get_merge_groups(gh_curation_uri)
+    _, _, metrics = metric_objs
+    spike_v1.CurationV1.insert_curation(
+        sorting_id=pop_sort["sorting_id"],
+        parent_curation_id=1,
+        labels=labels,
+        merge_groups=merge_groups,
+        metrics=metrics,
+        description="after figurl curation",
+    )
+    yield spike_v1.CurationV1().fetch("KEY", as_dict=True)  # list of dicts
+
+
+@pytest.fixture(scope="session")
+def spike_merge(spike_v1):
+    from spyglass.spikesorting.spikesorting_merge import SpikeSortingOutput
+
+    yield SpikeSortingOutput()
+
+
+@pytest.fixture(scope="session")
+def pop_merge(
+    spike_v1, pop_curation_metric, spike_merge, mini_dict, sorter_dict
+):
+    # TODO: add figurl fixtures when kachery_cloud is initialized
+
+    spike_merge.insert([pop_curation_metric], part_name="CurationV1")
+    yield spike_merge.fetch("KEY", as_dict=True)[0]
+
+
+def is_uuid(text):
+    uuid_pattern = re.compile(
+        r"\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b"
+    )
+    return uuid_pattern.fullmatch(str(text)) is not None
+
+
+def hash_sort_info(sort_info):
+    """Hashes attributes of a dj.Table object that are not randomly assigned."""
+    no_str_uuid = {
+        k: v
+        for k, v in sort_info.fetch(as_dict=True)[0].items()
+        if not is_uuid(v) and k != "analysis_file_name"
+    }
+    return key_hash(no_str_uuid)
+
+
+@pytest.fixture(scope="session")
+def spike_v1_group():
+    from spyglass.spikesorting.analysis.v1 import group
+
+    yield group
+
+
+@pytest.fixture(scope="session")
+def pop_group(spike_v1_group, spike_merge, mini_dict, pop_merge):
+    _ = pop_merge  # make sure this happens first
+
+    spike_v1_group.UnitSelectionParams().insert_default()
+    spike_v1_group.SortedSpikesGroup().create_group(
+        **mini_dict,
+        group_name="demo_group",
+        keys=spike_merge.proj(spikesorting_merge_id="merge_id").fetch("KEY"),
+        unit_filter_params_name="default_exclusion",
+    )
+    yield spike_v1_group.SortedSpikesGroup().fetch("KEY", as_dict=True)[0]
+
+
+@pytest.fixture(scope="session")
+def spike_v1_ua():
+    from spyglass.spikesorting.analysis.v1.unit_annotation import UnitAnnotation
+
+    yield UnitAnnotation()
+
+
+@pytest.fixture(scope="session")
+def pop_annotations(spike_v1_group, spike_v1_ua, pop_group):
+    spike_times, unit_ids = spike_v1_group.SortedSpikesGroup().fetch_spike_data(
+        pop_group, return_unit_ids=True
+    )
+    for spikes, unit_key in zip(spike_times, unit_ids):
+        quant_key = {
+            **unit_key,
+            "annotation": "spike_count",
+            "quantification": len(spikes),
+        }
+        label_key = {
+            **unit_key,
+            "annotation": "cell_type",
+            "label": "pyramidal" if len(spikes) < 1000 else "interneuron",
+        }
+
+        spike_v1_ua.add_annotation(quant_key, skip_duplicates=True)
+        spike_v1_ua.add_annotation(label_key, skip_duplicates=True)
+
+    yield (
+        spike_v1_ua.Annotation
+        # * (spike_v1_group.SortedSpikesGroup.Units & pop_group)
+        & {"annotation": "spike_count"}
+    )
diff --git a/tests/spikesorting/test_analysis.py b/tests/spikesorting/test_analysis.py
new file mode 100644
index 000000000..aa95e24b2
--- /dev/null
+++ b/tests/spikesorting/test_analysis.py
@@ -0,0 +1,9 @@
+def test_analysis_units(pop_annotations):
+    selected_spike_times, selected_unit_ids = pop_annotations.fetch_unit_spikes(
+        return_unit_ids=True
+    )
+
+    assert selected_spike_times[0].shape[0] == 243, "Unexpected spike count"
+
+    units = [d["unit_id"] for d in selected_unit_ids]
+    assert units == [0, 1, 2], "Unexpected unit ids"
diff --git a/tests/spikesorting/test_artifact.py b/tests/spikesorting/test_artifact.py
new file mode 100644
index 000000000..5466f9571
--- /dev/null
+++ b/tests/spikesorting/test_artifact.py
@@ -0,0 +1,28 @@
+import numpy as np
+import pytest
+
+
+@pytest.fixture
+def art_interval(common, spike_v1, pop_art):
+    art_id = str((spike_v1.ArtifactDetection & pop_art).fetch1("artifact_id"))
+    yield (common.IntervalList & {"interval_list_name": art_id}).fetch1()
+
+
+def test_artifact_detection(art_interval):
+    assert (
+        art_interval["pipeline"] == "spikesorting_artifact_v1"
+    ), "Artifact detection failed to populate interval list"
+
+
+def test_null_artifact_detection(spike_v1, art_interval):
+    from spyglass.spikesorting.v1.artifact import _get_artifact_times
+
+    rec_key = spike_v1.SpikeSortingRecording.fetch("KEY")[0]
+    rec = spike_v1.SpikeSortingRecording.get_recording(rec_key)
+
+    input_times = art_interval["valid_times"]
+    null_times = _get_artifact_times(rec, input_times)
+
+    assert np.array_equal(
+        input_times[0], null_times[0]
+    ), "Null artifact detection failed"
diff --git a/tests/spikesorting/test_curation.py b/tests/spikesorting/test_curation.py
new file mode 100644
index 000000000..43df0fed5
--- /dev/null
+++ b/tests/spikesorting/test_curation.py
@@ -0,0 +1,51 @@
+import numpy as np
+from datajoint.hash import key_hash
+from spikeinterface import BaseSorting
+from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor
+
+from .conftest import hash_sort_info
+
+
+def test_curation_rec(spike_v1, pop_curation):
+    rec = spike_v1.CurationV1.get_recording(pop_curation)
+    assert isinstance(
+        rec, NwbRecordingExtractor
+    ), "CurationV1.get_recording failed to return a RecordingExtractor"
+
+    sample_freq = rec.get_sampling_frequency()
+    assert np.isclose(
+        29_959.3, sample_freq
+    ), "CurationV1.get_sampling_frequency unexpected value"
+
+    times = rec.get_times()
+    assert np.isclose(
+        1687474805.4, np.mean((times[0], times[-1]))
+    ), "CurationV1.get_times unexpected value"
+
+
+def test_curation_sort(spike_v1, pop_curation):
+    sort = spike_v1.CurationV1.get_sorting(pop_curation)
+    sort_dict = sort.to_dict()
+    assert isinstance(
+        sort, BaseSorting
+    ), "CurationV1.get_sorting failed to return a BaseSorting"
+    assert (
+        key_hash(sort_dict) == "612983fbf4958f6b2c7abe7ced86ab73"
+    ), "CurationV1.get_sorting unexpected value"
+    assert (
+        sort_dict["kwargs"]["spikes"].shape[0] == 918
+    ), "CurationV1.get_sorting unexpected shape"
+
+
+def test_curation_sort_info(spike_v1, pop_curation):
+    sort_info = spike_v1.CurationV1.get_sort_group_info(pop_curation)
+    assert (
+        hash_sort_info(sort_info) == "be874e806a482ed2677fd0d0b449f965"
+    ), "CurationV1.get_sort_group_info unexpected value"
+
+
+def test_curation_metric(spike_v1, pop_curation_metric):
+    sort_info = spike_v1.CurationV1.get_sort_group_info(pop_curation_metric)
+    assert (
+        hash_sort_info(sort_info) == "48e437bc116900fe64e492d74595b56d"
+    ), "CurationV1.get_sort_group_info unexpected value"
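+
+
+# Note: the hex digests above are regression pins computed with key_hash from
+# the shared test dataset; expect them to change if the dataset, sorter, or
+# curation parameters change.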
diff --git a/tests/spikesorting/test_figurl.py b/tests/spikesorting/test_figurl.py
new file mode 100644
index 000000000..cf8a98e8b
--- /dev/null
+++ b/tests/spikesorting/test_figurl.py
@@ -0,0 +1,11 @@
+import pytest
+
+
+@pytest.mark.skip(reason="Not testing kachery")
+def test_figurl(spike_v1):
+    pass
+
+
+@pytest.mark.skip(reason="Not testing kachery")
+def test_figurl_json(spike_v1):
+    pass
diff --git a/tests/spikesorting/test_merge.py b/tests/spikesorting/test_merge.py
new file mode 100644
index 000000000..25751684c
--- /dev/null
+++ b/tests/spikesorting/test_merge.py
@@ -0,0 +1,63 @@
+import pytest
+from spikeinterface import BaseSorting
+from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor
+
+from .conftest import hash_sort_info
+
+
+def test_merge_get_restr(spike_merge, pop_merge, pop_curation_metric):
+    restr_id = spike_merge.get_restricted_merge_ids(
+        pop_curation_metric, sources=["v1"]
+    )[0]
+    assert (
+        restr_id == pop_merge["merge_id"]
+    ), "SpikeSortingOutput merge_id mismatch"
+
+    non_artifact = spike_merge.get_restricted_merge_ids(
+        pop_curation_metric, sources=["v1"], restrict_by_artifact=False
+    )[0]
+    assert restr_id == non_artifact, "SpikeSortingOutput merge_id mismatch"
+
+
+def test_merge_get_recording(spike_merge, pop_merge):
+    rec = spike_merge.get_recording(pop_merge)
+    assert isinstance(
+        rec, NwbRecordingExtractor
+    ), "SpikeSortingOutput.get_recording failed to return a RecordingExtractor"
+
+
+def test_merge_get_sorting(spike_merge, pop_merge):
+    sort = spike_merge.get_sorting(pop_merge)
+    assert isinstance(
+        sort, BaseSorting
+    ), "SpikeSortingOutput.get_sorting failed to return a BaseSorting"
+
+
+def test_merge_get_sort_group_info(spike_merge, pop_merge):
+    info_hash = hash_sort_info(spike_merge.get_sort_group_info(pop_merge))
+    assert (
+        info_hash == "48e437bc116900fe64e492d74595b56d"
+    ), "SpikeSortingOutput.get_sort_group_info unexpected value"
+
+
+@pytest.fixture(scope="session")
+def merge_times(spike_merge, pop_merge):
+    yield spike_merge.get_spike_times(pop_merge)
+
+
+def test_merge_get_spike_times(merge_times):
+    assert (
+        merge_times[0].shape[0] == 243
+    ), "SpikeSortingOutput.get_spike_times unexpected shape"
+
+
+@pytest.mark.skip(reason="Not testing bc #1077")
+def test_merge_get_spike_indicators(spike_merge, pop_merge, merge_times):
+    ret = spike_merge.get_spike_indicator(pop_merge, time=merge_times)
+    raise NotImplementedError(ret)
+
+
+@pytest.mark.skip(reason="Not testing bc #1077")
+def test_merge_get_firing_rate(spike_merge, pop_merge, merge_times):
+    ret = spike_merge.get_firing_rate(pop_merge, time=merge_times)
+    raise NotImplementedError(ret)
diff --git a/tests/spikesorting/test_metric_curation.py b/tests/spikesorting/test_metric_curation.py
new file mode 100644
index 000000000..0f7dc7a9a
--- /dev/null
+++ b/tests/spikesorting/test_metric_curation.py
@@ -0,0 +1,3 @@
+def test_metric_curation(spike_v1, pop_curation_metric):
+    ret = spike_v1.CurationV1 & pop_curation_metric & "description LIKE 'a%'"
+    assert len(ret) == 1, "CurationV1.insert_curation failed to insert a record"
diff --git a/tests/spikesorting/test_recording.py b/tests/spikesorting/test_recording.py
new file mode 100644
index 000000000..780cbc46c
--- /dev/null
+++ b/tests/spikesorting/test_recording.py
@@ -0,0 +1,10 @@
+def test_sort_group(spike_v1, pop_rec):
+    max_id = max(spike_v1.SortGroup.fetch("sort_group_id"))
+    assert (
+        max_id == 31
+    ), "SortGroup.insert_sort_group failed to insert all records"
+
+
+def test_spike_sorting_recording(spike_v1, pop_rec):
+    n_records = len(spike_v1.SpikeSortingRecording())
+    assert n_records == 1, "SpikeSortingRecording failed to insert a record"
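+
+
+# Note: the pinned value of 31 assumes SortGroup.set_group_by_shank yields 32
+# sort groups (ids 0-31) for the test file's probe configuration; a different
+# test file would change this value.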
failed to insert a record" diff --git a/tests/spikesorting/test_sorting.py b/tests/spikesorting/test_sorting.py new file mode 100644 index 000000000..e908fed07 --- /dev/null +++ b/tests/spikesorting/test_sorting.py @@ -0,0 +1,3 @@ +def test_sorting(spike_v1, pop_sort): + n_sorts = len(spike_v1.SpikeSorting & pop_sort) + assert n_sorts >= 1, "SpikeSorting population failed" diff --git a/tests/spikesorting/test_utils.py b/tests/spikesorting/test_utils.py new file mode 100644 index 000000000..47638f993 --- /dev/null +++ b/tests/spikesorting/test_utils.py @@ -0,0 +1,20 @@ +from uuid import UUID + + +def test_uuid_generator(): + + from spyglass.spikesorting.v1.utils import generate_nwb_uuid + + nwb_file_name, initial = "test.nwb", "R" + ret_parts = generate_nwb_uuid(nwb_file_name, initial).split("_") + assert ret_parts[0] == nwb_file_name, "Unexpected nwb file name" + assert ret_parts[1] == initial, "Unexpected initial" + assert len(ret_parts[2]) == 6, "Unexpected uuid length" + + +def test_get_merge_ids(pop_merge, mini_dict): + from spyglass.spikesorting.v1.utils import get_spiking_sorting_v1_merge_ids + + ret = get_spiking_sorting_v1_merge_ids(dict(mini_dict, curation_id=1)) + assert isinstance(ret[0], UUID), "Unexpected type from util" + assert ret[0] == pop_merge["merge_id"], "Unexpected merge_id from util"