diff --git a/pyproject.toml b/pyproject.toml index 872a5d5c5..75ea034fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,7 @@ test = [ "pre-commit", # linting "pytest", # unit testing "pytest-cov", # code coverage + "pytest-xvfb", # for headless testing of Qt ] docs = [ "hatch", # Get version from env @@ -120,12 +121,12 @@ ignore-words-list = 'nevers' [tool.pytest.ini_options] minversion = "7.0" addopts = [ - "-sv", - "--sw", # stepwise: resume with next test after failure - "--pdb", # drop into debugger on failure + # "-sv", # verbose output + # "--sw", # stepwise: resume with next test after failure + # "--pdb", # drop into debugger on failure "-p no:warnings", - "--no-teardown", # don't teardown the database after tests - "--quiet-spy", # don't show logging from spyglass + # "--no-teardown", # don't teardown the database after tests + # "--quiet-spy", # don't show logging from spyglass "--show-capture=no", "--pdbcls=IPython.terminal.debugger:TerminalPdb", # use ipython debugger "--cov=spyglass", @@ -134,6 +135,12 @@ addopts = [ ] testpaths = ["tests"] log_level = "INFO" +env = [ + "QT_QPA_PLATFORM = offscreen", # QT fails headless without this + # "DISPLAY = :0", # QT fails headless without this + "TF_ENABLE_ONEDNN_OPTS = 0", # TF disable approx calcs + "TF_CPP_MIN_LOG_LEVEL = 2", # Disable TF warnings +] [tool.coverage.run] source = ["*/src/spyglass/*"] diff --git a/src/spyglass/decoding/v0/dj_decoder_conversion.py b/src/spyglass/decoding/v0/dj_decoder_conversion.py index edcb0d637..1cf6d30c4 100644 --- a/src/spyglass/decoding/v0/dj_decoder_conversion.py +++ b/src/spyglass/decoding/v0/dj_decoder_conversion.py @@ -26,6 +26,21 @@ ObservationModel, ) except ImportError as e: + ( + Identity, + RandomWalk, + RandomWalkDirection1, + RandomWalkDirection2, + Uniform, + DiagonalDiscrete, + RandomDiscrete, + UniformDiscrete, + UserDefinedDiscrete, + Environment, + UniformInitialConditions, + UniformOneEnvironmentInitialConditions, + ObservationModel, + ) = [None] * 13 logger.warning(e) from track_linearization import make_track_graph diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 7b14efa85..6d27615e4 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -385,7 +385,17 @@ def infer_output_dir(key, makedir=True): """ # TODO: add check to make sure interval_list_name refers to a single epoch # Or make key include epoch in and of itself instead of interval_list_name - nwb_file_name = key["nwb_file_name"].split("_.")[0] + + file_name = key.get("nwb_file_name") + dlc_model_name = key.get("dlc_model_name") + epoch = key.get("epoch") + + if not all([file_name, dlc_model_name, epoch]): + raise ValueError( + "Key must contain 'nwb_file_name', 'dlc_model_name', and 'epoch'" + ) + + nwb_file_name = file_name.split("_.")[0] output_dir = pathlib.Path(dlc_output_dir) / pathlib.Path( f"{nwb_file_name}/{nwb_file_name}_{key['epoch']:02}" f"_model_" + key["dlc_model_name"].replace(" ", "-") @@ -1021,7 +1031,10 @@ def make_video( video.release() out.release() print("destroying cv2 windows") - cv2.destroyAllWindows() + try: + cv2.destroyAllWindows() + except cv2.error: # if cv is already closed or does not have func + pass print("finished making video with opencv") return diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py index 42733fd0f..6ae7669bf 100644 --- a/src/spyglass/position/v1/position_dlc_pose_estimation.py +++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py @@ -35,7 +35,7 @@ class DLCPoseEstimationSelection(SpyglassMixin, dj.Manual): """ @classmethod - def get_video_crop(cls, video_path): + def get_video_crop(cls, video_path, crop_input=None): """ Queries the user to determine the cropping parameters for a given video @@ -61,9 +61,13 @@ def get_video_crop(cls, video_path): ax.set_yticks(np.arange(ylims[0], ylims[-1], -50)) ax.grid(visible=True, color="white", lw=0.5, alpha=0.5) display(fig) - crop_input = input( - "Please enter the crop parameters for your video in format xmin, xmax, ymin, ymax, or 'none'\n" - ) + + if crop_input is None: + crop_input = input( + "Please enter the crop parameters for your video in format " + + "xmin, xmax, ymin, ymax, or 'none'\n" + ) + plt.close() if crop_input.lower() == "none": return None diff --git a/src/spyglass/position/v1/position_dlc_selection.py b/src/spyglass/position/v1/position_dlc_selection.py index 70400dbba..02692ce14 100644 --- a/src/spyglass/position/v1/position_dlc_selection.py +++ b/src/spyglass/position/v1/position_dlc_selection.py @@ -304,6 +304,8 @@ class DLCPosVideo(SpyglassMixin, dj.Computed): --- """ + # TODO: Shoultn't this keep track of the video file it creates? + def make(self, key): from tqdm import tqdm as tqdm @@ -432,3 +434,4 @@ def make(self, key): crop=crop, **params["video_params"], ) + self.insert1(key) diff --git a/src/spyglass/position/v1/position_dlc_training.py b/src/spyglass/position/v1/position_dlc_training.py index 6393a8a29..393eb6af9 100644 --- a/src/spyglass/position/v1/position_dlc_training.py +++ b/src/spyglass/position/v1/position_dlc_training.py @@ -108,7 +108,7 @@ class DLCModelTrainingSelection(SpyglassMixin, dj.Manual): """ def insert1(self, key, **kwargs): - training_id = key["training_id"] + training_id = key.get("training_id") if training_id is None: training_id = ( dj.U().aggr(self & key, n="max(training_id)").fetch1("n") or 0 diff --git a/tests/README.md b/tests/README.md index 476dbb4c8..20d4bd2bc 100644 --- a/tests/README.md +++ b/tests/README.md @@ -1,5 +1,19 @@ # PyTests +## Environment + +To facilitate headless testing of various Qt-based tools as well as Tensorflow, +`pyproject.toml` includes some environment variables associated with the +display. These are... + +- `QT_QPA_PLATFORM`: Set to `offscreen` to prevent the need for a display. +- `TF_ENABLE_ONEDNN_OPTS`: Set to `1` to enable Tensorflow optimizations. +- `TF_CPP_MIN_LOG_LEVEL`: Set to `2` to suppress Tensorflow warnings. + + + +## Options + This directory is contains files for testing the code. Simply by running `pytest` from the root directory, all tests will be run with default parameters specified in `pyproject.toml`. Notable optional parameters include... @@ -7,7 +21,7 @@ specified in `pyproject.toml`. Notable optional parameters include... - Coverage items. The coverage report indicates what percentage of the code was included in tests. - - `--cov=spyglatss`: Which package should be described in the coverage report + - `--cov=spyglass`: Which package should be described in the coverage report - `--cov-report term-missing`: Include lines of items missing in coverage - Verbosity. diff --git a/tests/common/test_behav.py b/tests/common/test_behav.py index e659cef5c..56ba037b2 100644 --- a/tests/common/test_behav.py +++ b/tests/common/test_behav.py @@ -79,11 +79,6 @@ def test_populate_state_script(common, pop_state_script): ), "StateScript populate unexpected effect" -@pytest.fixture(scope="session") -def video_keys(common): - return common.VideoFile().fetch(as_dict=True) - - @pytest.mark.usefixtures("skipif_noextras") def test_videofile_update_entries(common, video_keys): """Test update entries""" diff --git a/tests/conftest.py b/tests/conftest.py index 2e2f5633b..05b1a0f72 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -388,6 +388,14 @@ def populate_exception(): yield PopulateException +# -------------------------- FIXTURES, COMMON TABLES -------------------------- + + +@pytest.fixture(scope="session") +def video_keys(common): + return common.VideoFile().fetch(as_dict=True) + + # ------------------------- FIXTURES, POSITION TABLES ------------------------- @@ -428,7 +436,7 @@ def trodes_params(trodes_params_table, teardown): }, }, } - trodes_params_table.get_default() + _ = trodes_params_table.get_default() trodes_params_table.insert( [v for k, v in paramsets.items()], skip_duplicates=True ) @@ -778,10 +786,16 @@ def dlc_project_tbl(sgp): yield sgp.v1.DLCProject() +@pytest.fixture(scope="session") +def dlc_project_name(): + yield "pytest_proj" + + @pytest.fixture(scope="session") def insert_project( verbose_context, teardown, + dlc_project_name, dlc_project_tbl, common, bodyparts, @@ -791,7 +805,7 @@ def insert_project( common.LabTeam.insert1({"team_name": team_name}, skip_duplicates=True) with verbose_context: project_key = dlc_project_tbl.insert_new_project( - project_name="pytest_proj", + project_name=dlc_project_name, bodyparts=bodyparts, lab_team=team_name, frames_per_video=100, @@ -858,8 +872,14 @@ def extract_frames( ) vid_name = list(dlc_config["video_sets"].keys())[0].split("/")[-1] label_dir = project_dir / "labeled-data" / vid_name.split(".")[0] + yield label_dir + for file in label_dir.glob("*png"): + if file.stem in ["img000", "img001"]: + continue + file.unlink() + @pytest.fixture(scope="session") def labeled_vid_dir(extract_frames): @@ -889,11 +909,18 @@ def add_training_files(dlc_project_tbl, project_key, fix_downloaded): @pytest.fixture(scope="session") -def training_params_key(verbose_context, sgp, project_key): - training_params_name = "pytest" +def dlc_training_params(sgp): + params_tbl = sgp.v1.DLCModelTrainingParams() + params_name = "pytest" + yield params_tbl, params_name + + +@pytest.fixture(scope="session") +def training_params_key(verbose_context, sgp, project_key, dlc_training_params): + params_tbl, params_name = dlc_training_params with verbose_context: - sgp.v1.DLCModelTrainingParams.insert_new_params( - paramset_name=training_params_name, + params_tbl.insert_new_params( + paramset_name=params_name, params={ "trainingsetindex": 0, "shuffle": 1, @@ -901,10 +928,11 @@ def training_params_key(verbose_context, sgp, project_key): "TFGPUinference": False, "net_type": "resnet_50", "augmenter_type": "imgaug", + "video_sets": "test skipping param", }, skip_duplicates=True, ) - yield {"dlc_training_params_name": training_params_name} + yield {"dlc_training_params_name": params_name} @pytest.fixture(scope="session") @@ -913,7 +941,6 @@ def model_train_key(sgp, project_key, training_params_key): model_train_key = { **project_key, **training_params_key, - "training_id": 0, } sgp.v1.DLCModelTrainingSelection().insert1( { @@ -974,19 +1001,17 @@ def pose_estimation_key(sgp, mini_copy_name, populate_model, model_key): @pytest.fixture(scope="session") def populate_pose_estimation(sgp, pose_estimation_key): - pose_est_tbl = sgp.v1.DLCPoseEstimation - if pose_est_tbl & pose_estimation_key: - yield - else: + pose_est_tbl = sgp.v1.DLCPoseEstimation() + if len(pose_est_tbl & pose_estimation_key) < 1: pose_est_tbl.populate(pose_estimation_key) - yield + yield pose_est_tbl @pytest.fixture(scope="session") def si_params_name(sgp, populate_pose_estimation): params_name = "low_bar" params_tbl = sgp.v1.DLCSmoothInterpParams - # if len(params_tbl & {"dlc_si_params_name": params_name}) == 0: + # if len(params_tbl & {"dlc_si_params_name": params_name}) < 1: if True: # TODO: remove before merge nan_params = params_tbl.get_nan_params() nan_params["dlc_si_params_name"] = params_name @@ -995,6 +1020,9 @@ def si_params_name(sgp, populate_pose_estimation): "likelihood_thresh": 0.4, "max_cm_between_pts": 100, "num_inds_to_span": 50, + # Smoothing and Interpolation added later - must check + "smoothing_params": {"smoothing_duration": 0.05}, + "interp_params": {"max_cm_to_interp": 100}, } ) params_tbl.insert1(nan_params, skip_duplicates=True) diff --git a/tests/position/conftest.py b/tests/position/conftest.py index 1aaec3384..caf88448a 100644 --- a/tests/position/conftest.py +++ b/tests/position/conftest.py @@ -11,15 +11,6 @@ 58, 61, 69, 72, 97-100, 104, 149-161, 232-235, 239-241, 246, 259, 280, 293-305, 310-316, 328-341, 356-373, 395, 404, 480, 487-488, 530, 548-561, 594-601, 611-612, 641-657, 682-736, 762-772, 787, 809-1286 - -TODO: tests for -pose_estimat 51-71, 102, 115, 256, 345-366 -position.py 53, 99, 114, 119, 197-198, 205-219, 349, 353-355, 360, 382, 385, 407, 443-466 -project.py 45-54, 128-205, 250-255, 259, 278-303, 316, 347, 361-413, 425, 457, 476-479, 486-489, 514-555, 582, 596 -selection.py 213, 282, 308-417 -training.py 55, 67-73, 85-87, 113, 143-144, 161, 207-210 -es_position.py 67, 282-283, 361-362, 496, 502-503 - """ from itertools import product as iter_prodect @@ -33,7 +24,7 @@ def dlc_video_params(sgp): sgp.v1.DLCPosVideoParams.insert_default() params_key = {"dlc_pos_video_params_name": "five_percent"} - sgp.v1.DLCPosVideoSelection.insert1( + sgp.v1.DLCPosVideoParams.insert1( { **params_key, "params": { @@ -47,7 +38,7 @@ def dlc_video_params(sgp): @pytest.fixture(scope="session") -def dlc_video_selection(sgp, dlc_key, dlc_video_params): +def dlc_video_selection(sgp, dlc_key, dlc_video_params, populate_dlc): s_key = {**dlc_key, **dlc_video_params} sgp.v1.DLCPosVideoSelection.insert1(s_key, skip_duplicates=True) yield dlc_key @@ -56,7 +47,7 @@ def dlc_video_selection(sgp, dlc_key, dlc_video_params): @pytest.fixture(scope="session") def populate_dlc_video(sgp, dlc_video_selection): sgp.v1.DLCPosVideo.populate(dlc_video_selection) - yield + yield sgp.v1.DLCPosVideo() @pytest.fixture(scope="session") @@ -95,4 +86,4 @@ def increment_count(): count[0] += 1 return count[0] - return df.map(lambda x: increment_count() if x == 1 else x) + return df.applymap(lambda x: increment_count() if x == 1 else x) diff --git a/tests/position/test_dlc_cent.py b/tests/position/test_dlc_cent.py index 86ccba275..b312ebff5 100644 --- a/tests/position/test_dlc_cent.py +++ b/tests/position/test_dlc_cent.py @@ -10,22 +10,11 @@ def centroid_df(sgp, centroid_key, populate_centroid): yield (sgp.v1.DLCCentroid & centroid_key).fetch1_dataframe() -@pytest.mark.parametrize( - "column, exp_sum", - [ - ("video_frame_ind", 36312), - ("position_x", 17987), - ("position_y", 2983), - ("velocity_x", -1.489), - ("velocity_y", 4.160), - ("speed", 12957), - ], -) -def test_centroid_fetch1_dataframe(centroid_df, column, exp_sum): - tolerance = abs(centroid_df[column].iloc[0] * 0.1) +def test_centroid_fetch1_dataframe(centroid_df): + df_sum = centroid_df.sum().sum() assert np_isclose( - centroid_df[column].sum(), exp_sum, atol=tolerance - ), f"Sum of {column} in Centroid dataframe is not as expected" + df_sum, 55_860, atol=1000 + ), f"Unexpected checksum for centroid dataframe: {df_sum}" @pytest.fixture(scope="session") @@ -45,8 +34,8 @@ def test_insert_default_params(params_tbl): def test_validate_params(params_tbl): params = params_tbl.get_default() - params["dlc_centroid_params_name"] = "test" - params_tbl.insert1(params) + params["dlc_centroid_params_name"] = "other test" + params_tbl.insert1(params, skip_duplicates=True) @pytest.mark.parametrize( diff --git a/tests/position/test_dlc_pos.py b/tests/position/test_dlc_pos.py deleted file mode 100644 index df878c90c..000000000 --- a/tests/position/test_dlc_pos.py +++ /dev/null @@ -1,52 +0,0 @@ -import pytest -from numpy import isclose as np_isclose - - -def test_si_params_default(sgp): - assert sgp.v1.DLCSmoothInterpParams.get_default() == { - "dlc_si_params_name": "default", - "params": { - "interp_params": {"max_cm_to_interp": 15}, - "interpolate": True, - "likelihood_thresh": 0.95, - "max_cm_between_pts": 20, - "num_inds_to_span": 20, - "smooth": True, - "smoothing_params": { - "smooth_method": "moving_avg", - "smoothing_duration": 0.05, - }, - }, - } - assert sgp.v1.DLCSmoothInterpParams.get_nan_params() == { - "dlc_si_params_name": "just_nan", - "params": { - "interpolate": False, - "likelihood_thresh": 0.95, - "max_cm_between_pts": 20, - "num_inds_to_span": 20, - "smooth": False, - }, - } - - -@pytest.fixture(scope="session") -def si_df(sgp, si_key, populate_si, bodyparts): - yield ( - sgp.v1.DLCSmoothInterp() & {**si_key, "bodypart": bodyparts[0]} - ).fetch1_dataframe() - - -@pytest.mark.parametrize( - "column, exp_sum", - [ - ("video_frame_ind", 36312), - ("x", 17987), - ("y", 2983), - ], -) -def test_centroid_fetch1_dataframe(si_df, column, exp_sum): - tolerance = abs(si_df[column].iloc[0] * 0.1) - assert np_isclose( - si_df[column].sum(), exp_sum, atol=tolerance - ), f"Sum of {column} in SmoothInterp dataframe is not as expected" diff --git a/tests/position/test_dlc_pos_est.py b/tests/position/test_dlc_pos_est.py new file mode 100644 index 000000000..e011664ad --- /dev/null +++ b/tests/position/test_dlc_pos_est.py @@ -0,0 +1,34 @@ +import pytest + + +@pytest.fixture(scope="session") +def pos_est_sel(sgp): + yield sgp.v1.position_dlc_pose_estimation.DLCPoseEstimationSelection() + + +def test_rename_non_default_columns(sgp, common, pos_est_sel, video_keys): + vid_path, vid_name, _, _ = sgp.v1.dlc_utils.get_video_path(video_keys[0]) + + input = "0, 10, 0, 1000" + output = pos_est_sel.get_video_crop(vid_path + vid_name, input) + + assert ( + input == output + ), f"{pos_est_sel.table_name}.get_video_crop did not return expected output" + + +def test_invalid_video(pos_est_sel, pose_estimation_key): + _ = pose_estimation_key # Ensure populated + example_key = pos_est_sel.fetch("KEY", as_dict=True)[0] + example_key["nwb_file_name"] = "invalid.nwb" + with pytest.raises(FileNotFoundError): + pos_est_sel.insert_estimation_task(example_key) + + +def test_pose_est_dataframe(populate_pose_estimation): + pose_cols = populate_pose_estimation.fetch_dataframe().columns + + for bp in ["tailBase", "tailMid", "tailTip"]: + for val in ["video_frame_ind", "x", "y"]: + col = (bp, val) + assert col in pose_cols, f"PoseEstimation df missing column {col}." diff --git a/tests/position/test_dlc_position.py b/tests/position/test_dlc_position.py new file mode 100644 index 000000000..94646f315 --- /dev/null +++ b/tests/position/test_dlc_position.py @@ -0,0 +1,64 @@ +import pytest + + +@pytest.fixture(scope="session") +def si_params_tbl(sgp): + yield sgp.v1.DLCSmoothInterpParams() + + +def test_si_params_default(si_params_tbl): + assert si_params_tbl.get_default() == { + "dlc_si_params_name": "default", + "params": { + "interp_params": {"max_cm_to_interp": 15}, + "interpolate": True, + "likelihood_thresh": 0.95, + "max_cm_between_pts": 20, + "num_inds_to_span": 20, + "smooth": True, + "smoothing_params": { + "smooth_method": "moving_avg", + "smoothing_duration": 0.05, + }, + }, + } + assert si_params_tbl.get_nan_params() == { + "dlc_si_params_name": "just_nan", + "params": { + "interpolate": False, + "likelihood_thresh": 0.95, + "max_cm_between_pts": 20, + "num_inds_to_span": 20, + "smooth": False, + }, + } + assert list(si_params_tbl.get_available_methods()) == [ + "moving_avg" + ], f"{si_params_tbl.table_name}: unexpected available methods" + + +def test_invalid_params_insert(si_params_tbl): + with pytest.raises(KeyError): + si_params_tbl.insert1({"params": "invalid"}) + + +@pytest.fixture(scope="session") +def si_df(sgp, si_key, populate_si, bodyparts): + yield ( + sgp.v1.DLCSmoothInterp() & {**si_key, "bodypart": bodyparts[0]} + ).fetch1_dataframe() + + +def test_cohort_fetch1_dataframe(si_df): + df_cols = si_df.columns + exp_cols = ["video_frame_ind", "x", "y"] + assert all( + e in df_cols for e in exp_cols + ), f"Unexpected cols in DLCSmoothInterp dataframe: {df_cols}" + + +def test_all_nans(populate_pose_estimation, sgp): + pose_est_tbl = populate_pose_estimation + df = pose_est_tbl.BodyPart().fetch1_dataframe() + with pytest.raises(ValueError): + sgp.v1.position_dlc_position.nan_inds(df, 10, 0.99, 10) diff --git a/tests/position/test_dlc_proj.py b/tests/position/test_dlc_proj.py index d3236b1c5..7eaba196d 100644 --- a/tests/position/test_dlc_proj.py +++ b/tests/position/test_dlc_proj.py @@ -1,7 +1,68 @@ +import pytest + + +def test_bp_insert(sgp): + bp_tbl = sgp.v1.position_dlc_project.BodyPart() + + bp_w_desc, desc = "test_bp", "test_desc" + bp_no_desc = "test_bp_no_desc" + + bp_tbl.add_from_config([bp_w_desc], [desc]) + bp_tbl.add_from_config([bp_no_desc]) + + assert bp_tbl & { + "bodypart": bp_w_desc, + "description": desc, + }, "Bodypart with description not inserted correctly" + assert bp_tbl & { + "bodypart": bp_no_desc, + "description": bp_no_desc, + }, "Bodypart without description not inserted correctly" + + def test_project_insert(dlc_project_tbl, project_key): assert dlc_project_tbl & project_key, "Project not inserted correctly" +@pytest.fixture +def new_project_key(): + return { + "project_name": "test_project_name", + "bodyparts": ["bp1"], + "lab_team": "any", + "frames_per_video": 1, + "video_list": ["any"], + "groupname": "fake group", + } + + +def test_failed_name_insert( + dlc_project_tbl, dlc_project_name, config_path, new_project_key +): + new_project_key.update({"project_name": dlc_project_name}) + existing_key = dlc_project_tbl.insert_new_project( + project_name=dlc_project_name, + bodyparts=["bp1"], + lab_team="any", + frames_per_video=1, + video_list=["any"], + groupname="any", + ) + expected_key = { + "project_name": dlc_project_name, + "config_path": config_path, + } + assert ( + existing_key == expected_key + ), "Project re-insert did not return expected key" + + +def test_failed_group_insert(dlc_project_tbl, new_project_key): + with pytest.raises(ValueError): + dlc_project_tbl.insert_new_project(**new_project_key) + + def test_extract_frames(extract_frames, labeled_vid_dir): extracted_files = list(labeled_vid_dir.glob("*.png")) - assert len(extracted_files) == 4, "Incorrect number of frames extracted" + stems = set([f.stem for f in extracted_files]) - {"img000", "img001"} + assert len(stems) == 2, "Incorrect number of frames extracted" diff --git a/tests/position/test_dlc_sel.py b/tests/position/test_dlc_sel.py index 9d869f8bb..35b33fe06 100644 --- a/tests/position/test_dlc_sel.py +++ b/tests/position/test_dlc_sel.py @@ -1,5 +1,5 @@ def test_dlcvideo_default(sgp): - assert sgp.v1.DLCPosVideoParams.get_default() == { + expected_default = { "dlc_pos_video_params_name": "default", "params": { "incl_likelihood": True, @@ -7,3 +7,11 @@ def test_dlcvideo_default(sgp): "video_params": {"arrow_radius": 20, "circle_radius": 6}, }, } + + # run twice to trigger fetch existing + assert sgp.v1.DLCPosVideoParams.get_default() == expected_default + assert sgp.v1.DLCPosVideoParams.get_default() == expected_default + + +def test_dlc_video_populate(populate_dlc_video): + assert len(populate_dlc_video) > 0, "DLCPosVideo table is empty" diff --git a/tests/position/test_dlc_train.py b/tests/position/test_dlc_train.py new file mode 100644 index 000000000..58204abb4 --- /dev/null +++ b/tests/position/test_dlc_train.py @@ -0,0 +1,30 @@ +def test_existing_params( + verbose_context, dlc_training_params, training_params_key +): + params_tbl, params_name = dlc_training_params + + _ = training_params_key # Ensure populated + params_query = params_tbl & {"dlc_training_params_name": params_name} + assert params_query, "Existing params not found" + + with verbose_context: + params_tbl.insert_new_params( + paramset_name=params_name, + params={ + "shuffle": 1, + "trainingsetindex": 0, + "net_type": "any", + "gputouse": None, + }, + skip_duplicates=False, + ) + + assert len(params_query) == 1, "Existing params duplicated" + + +def test_get_params(verbose_context, dlc_training_params): + params_tbl, _ = dlc_training_params + with verbose_context: + accepted_params = params_tbl.get_accepted_params() + + assert accepted_params is not None, "Failed to get accepted params" diff --git a/tests/position/test_pos_merge.py b/tests/position/test_pos_merge.py index af6b17e0f..047129cd5 100644 --- a/tests/position/test_pos_merge.py +++ b/tests/position/test_pos_merge.py @@ -1,5 +1,4 @@ import pytest -from numpy import isclose as np_isclose @pytest.fixture(scope="session") @@ -8,19 +7,18 @@ def merge_df(sgp, pos_merge, dlc_key, populate_dlc): yield (pos_merge & merge_key).fetch1_dataframe() -@pytest.mark.parametrize( - "column, exp_sum", - [ # NOTE: same as test_centroid_fetch1_dataframe - ("video_frame_ind", 36312), - ("position_x", 17987), - ("position_y", 2983), - ("velocity_x", -1.489), - ("velocity_y", 4.160), - ("speed", 12957), - ], -) -def test_merge_dlc_fetch1_dataframe(merge_df, column, exp_sum): - tolerance = abs(merge_df[column].iloc[0] * 0.1) - assert np_isclose( - merge_df[column].sum(), exp_sum, atol=tolerance - ), f"Sum of {column} in Merge.DLCPosV1 dataframe is not as expected" +def test_merge_dlc_fetch1_dataframe(merge_df): + df_cols = merge_df.columns + exp_cols = [ + "video_frame_ind", + "position_x", + "position_y", + "orientation", + "velocity_x", + "velocity_y", + "speed", + ] + + assert all( + e in df_cols for e in exp_cols + ), f"Unexpected cols in position merge dataframe: {df_cols}"