Remove stored hashes from tests
CBroz1 committed Oct 3, 2024
1 parent 5446d07 commit b251d9e
Showing 7 changed files with 155 additions and 44 deletions.
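The pattern behind this diff: assertions that hashed an entire rounded DataFrame or fetched row with datajoint's key_hash are replaced by explicit expected values, with float columns summed and compared through pytest.approx(..., rel=1e-3), so small numeric differences across platforms no longer flip the tests and failures name the offending column. A minimal sketch of the new style follows; it is illustrative only, and check_dataframe_sums is a hypothetical helper name, not code from this commit.

import pytest


def check_dataframe_sums(df, expected, rel=1e-3):
    """Assert that each expected column's sum matches within a relative tolerance."""
    # Round then sum each column, mirroring the approach in the updated tests.
    sums = df.round(3).sum().to_dict()
    for column, value in expected.items():
        assert (
            pytest.approx(sums[column], rel=rel) == value
        ), f"Value differs from expected: {column}"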
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -29,6 +29,7 @@ dj.FreeTable(dj.conn(), "common_session.session_group").drop()
 - Update DataJoint install and password instructions #1131
 - Fix dandi upload process for nwb's with video or linked objects #1095, #1151
 - Minor docs fixes #1145
+- Remove stored hashes from pytests #115X

 ### Pipelines

3 changes: 3 additions & 0 deletions tests/conftest.py
@@ -116,6 +116,9 @@ def pytest_configure(config):


 def pytest_unconfigure(config):
+    from spyglass.utils.nwb_helper_fn import close_nwb_files
+
+    close_nwb_files()
     if TEARDOWN:
         SERVER.stop()

19 changes: 13 additions & 6 deletions tests/linearization/test_lin.py
@@ -1,12 +1,19 @@
-from datajoint.hash import key_hash
+import pytest


 def test_fetch1_dataframe(lin_v1, lin_merge, lin_merge_key):
-    hash_df = key_hash(
-        (lin_merge & lin_merge_key).fetch1_dataframe().round(3).to_dict()
-    )
-    hash_exp = "883a7b8aa47931ae7b265660ca27b462"
-    assert hash_df == hash_exp, "Dataframe differs from expected"
+    df = (lin_merge & lin_merge_key).fetch1_dataframe().round(3).sum().to_dict()
+    exp = {
+        "linear_position": 3249449.258,
+        "projected_x_position": 472245.797,
+        "projected_y_position": 317857.473,
+        "track_segment_id": 31158.0,
+    }
+
+    for k in exp:
+        assert (
+            pytest.approx(df[k], rel=1e-3) == exp[k]
+        ), f"Value differs from expected: {k}"


 # TODO: Add more tests of this pipeline, not just the fetch1_dataframe method
26 changes: 18 additions & 8 deletions tests/position/test_trodes.py
@@ -48,15 +48,25 @@ def test_sel_insert_error(trodes_sel_table, pos_interval_key):

 def test_fetch_df(trodes_pos_v1, trodes_params):
     upsampled = {"trodes_pos_params_name": "single_led_upsampled"}
-    hash_df = key_hash(
-        (
-            (trodes_pos_v1 & upsampled)
-            .fetch1_dataframe(add_frame_ind=True)
-            .round(3)  # float precision
-        ).to_dict()
+    df = (
+        (trodes_pos_v1 & upsampled)
+        .fetch1_dataframe(add_frame_ind=True)
+        .round(3)
+        .sum()
+        .to_dict()
     )
-    hash_exp = "5296e74dea2e5e68d39f81bc81723a12"
-    assert hash_df == hash_exp, "Dataframe differs from expected"
+    exp = {
+        "position_x": 230389.335,
+        "position_y": 295368.260,
+        "orientation": 4716.906,
+        "velocity_x": 1726.304,
+        "velocity_y": -1675.276,
+        "speed": 6257.273,
+    }
+    for k in exp:
+        assert (
+            pytest.approx(df[k], rel=1e-3) == exp[k]
+        ), f"Value differs from expected: {k}"


 def test_trodes_video(sgp, trodes_pos_v1):
10 changes: 0 additions & 10 deletions tests/spikesorting/conftest.py
@@ -198,16 +198,6 @@ def is_uuid(text):
     return uuid_pattern.fullmatch(str(text)) is not None


-def hash_sort_info(sort_info):
-    """Hashes attributes of a dj.Table object that are not randomly assigned."""
-    no_str_uuid = {
-        k: v
-        for k, v in sort_info.fetch(as_dict=True)[0].items()
-        if not is_uuid(v) and k != "analysis_file_name"
-    }
-    return key_hash(no_str_uuid)
-
-
 @pytest.fixture(scope="session")
 def spike_v1_group():
     from spyglass.spikesorting.analysis.v1 import group
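For context on the deletion above: hash_sort_info existed to strip non-deterministic attributes (UUIDs and analysis_file_name) before hashing a fetched row. The updated tests below reach the same end by comparing a fetch1() result against an explicit expected dict that simply omits those fields. A hedged sketch of that comparison pattern, with compare_row as a hypothetical helper name rather than code from the commit:

def compare_row(row, expected, label="unexpected value"):
    """Assert that every listed attribute of a fetched row matches its expected value."""
    # Only keys present in `expected` are checked, so non-deterministic fields
    # (UUIDs, analysis_file_name) can be left out instead of filtered by hand.
    for key, value in expected.items():
        assert row[key] == value, f"{label}: {key}"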
98 changes: 84 additions & 14 deletions tests/spikesorting/test_curation.py
@@ -3,8 +3,6 @@
 from spikeinterface import BaseSorting
 from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor

-from .conftest import hash_sort_info
-

 def test_curation_rec(spike_v1, pop_curation):
     rec = spike_v1.CurationV1.get_recording(pop_curation)
@@ -29,22 +27,94 @@ def test_curation_sort(spike_v1, pop_curation):
     assert isinstance(
         sort, BaseSorting
     ), "CurationV1.get_sorting failed to return a BaseSorting"
-    assert (
-        key_hash(sort_dict) == "612983fbf4958f6b2c7abe7ced86ab73"
-    ), "CurationV1.get_sorting unexpected value"
+    assert (
+        sort_dict["kwargs"]["spikes"].shape[0] == 918
+    ), "CurationV1.get_sorting unexpected shape"
+
+    expected = {
+        "class": "spikeinterface.core.numpyextractors.NumpySorting",
+        "module": "spikeinterface",
+        "relative_paths": False,
+    }
+    for k in expected:
+        assert (
+            sort_dict[k] == expected[k]
+        ), f"CurationV1.get_sorting unexpected value: {k}"


-def test_curation_sort_info(spike_v1, pop_curation, pop_curation_metric):
-    sort_info = spike_v1.CurationV1.get_sort_group_info(pop_curation)
-    sort_metric = spike_v1.CurationV1.get_sort_group_info(pop_curation_metric)
-
-    assert (
-        hash_sort_info(sort_info) == "be874e806a482ed2677fd0d0b449f965"
-    ), "CurationV1.get_sort_group_info unexpected value"
-    assert (
-        hash_sort_info(sort_metric) == "48e437bc116900fe64e492d74595b56d"
-    ), "CurationV1.get_sort_group_info unexpected value"
+def test_curation_sort_info(spike_v1, pop_curation):
+    sort_info = spike_v1.CurationV1.get_sort_group_info(pop_curation).fetch1()
+    exp = {
+        "bad_channel": "False",
+        "curation_id": 0,
+        "description": "testing sort",
+        "electrode_group_name": "0",
+        "electrode_id": 0,
+        "filtering": "None",
+        "impedance": 0.0,
+        "merges_applied": 0,
+        "name": "0",
+        "nwb_file_name": "minirec20230622_.nwb",
+        "original_reference_electrode": 0,
+        "parent_curation_id": -1,
+        "probe_electrode": 0,
+        "probe_id": "tetrode_12.5",
+        "probe_shank": 0,
+        "region_id": 1,
+        "sort_group_id": 0,
+        "sorter": "mountainsort4",
+        "sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
+        "subregion_name": None,
+        "subsubregion_name": None,
+        "x": 0.0,
+        "x_warped": 0.0,
+        "y": 0.0,
+        "y_warped": 0.0,
+        "z": 0.0,
+        "z_warped": 0.0,
+    }
+    for k in exp:
+        assert (
+            sort_info[k] == exp[k]
+        ), f"CurationV1.get_sort_group_info unexpected value: {k}"
+
+
+def test_curation_sort_metric(spike_v1, pop_curation, pop_curation_metric):
+    sort_metric = spike_v1.CurationV1.get_sort_group_info(
+        pop_curation_metric
+    ).fetch1()
+    expected = {
+        "bad_channel": "False",
+        "contacts": "",
+        "curation_id": 1,
+        "description": "after metric curation",
+        "electrode_group_name": "0",
+        "electrode_id": 0,
+        "filtering": "None",
+        "impedance": 0.0,
+        "merges_applied": 0,
+        "name": "0",
+        "nwb_file_name": "minirec20230622_.nwb",
+        "object_id": "a77cbb7a-b18c-47a3-982c-6c159ffdf40e",
+        "original_reference_electrode": 0,
+        "parent_curation_id": 0,
+        "probe_electrode": 0,
+        "probe_id": "tetrode_12.5",
+        "probe_shank": 0,
+        "region_id": 1,
+        "sort_group_id": 0,
+        "sorter": "mountainsort4",
+        "sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
+        "subregion_name": None,
+        "subsubregion_name": None,
+        "x": 0.0,
+        "x_warped": 0.0,
+        "y": 0.0,
+        "y_warped": 0.0,
+        "z": 0.0,
+        "z_warped": 0.0,
+    }
+    for k in expected:
+        assert (
+            sort_metric[k] == expected[k]
+        ), f"CurationV1.get_sort_group_info unexpected value: {k}"
42 changes: 36 additions & 6 deletions tests/spikesorting/test_merge.py
@@ -2,8 +2,6 @@
 from spikeinterface import BaseSorting
 from spikeinterface.extractors.nwbextractors import NwbRecordingExtractor

-from .conftest import hash_sort_info
-

 def test_merge_get_restr(spike_merge, pop_merge, pop_curation_metric):
     restr_id = spike_merge.get_restricted_merge_ids(
@@ -34,10 +32,42 @@ def test_merge_get_sorting(spike_merge, pop_merge):


 def test_merge_get_sort_group_info(spike_merge, pop_merge):
-    hash = hash_sort_info(spike_merge.get_sort_group_info(pop_merge))
-    assert (
-        hash == "48e437bc116900fe64e492d74595b56d"
-    ), "SpikeSortingOutput.get_sort_group_info unexpected value"
+    sort_info = spike_merge.get_sort_group_info(pop_merge).fetch1()
+    expected = {
+        "bad_channel": "False",
+        "contacts": "",
+        "curation_id": 1,
+        "description": "after metric curation",
+        "electrode_group_name": "0",
+        "electrode_id": 0,
+        "filtering": "None",
+        "impedance": 0.0,
+        "merges_applied": 0,
+        "name": "0",
+        "nwb_file_name": "minirec20230622_.nwb",
+        "original_reference_electrode": 0,
+        "parent_curation_id": 0,
+        "probe_electrode": 0,
+        "probe_id": "tetrode_12.5",
+        "probe_shank": 0,
+        "region_id": 1,
+        "sort_group_id": 0,
+        "sorter": "mountainsort4",
+        "sorter_param_name": "franklab_tetrode_hippocampus_30KHz",
+        "subregion_name": None,
+        "subsubregion_name": None,
+        "x": 0.0,
+        "x_warped": 0.0,
+        "y": 0.0,
+        "y_warped": 0.0,
+        "z": 0.0,
+        "z_warped": 0.0,
+    }
+
+    for k in expected:
+        assert (
+            sort_info[k] == expected[k]
+        ), f"SpikeSortingOutput.get_sort_group_info unexpected value: {k}"


 @pytest.fixture(scope="session")
