Reduce test code duplication
lochhh committed Nov 28, 2024
1 parent 6711a06 commit 24240af
Showing 3 changed files with 125 additions and 106 deletions.
47 changes: 47 additions & 0 deletions tests/conftest.py
@@ -882,3 +882,50 @@ def count_consecutive_nans(da):
def helpers():
"""Return an instance of the ``Helpers`` class."""
return Helpers


# --------- movement dataset assertion fixtures ---------
class MovementDatasetAsserts:
"""Class for asserting valid ``movement`` poses or bboxes datasets."""

@staticmethod
def valid_dataset(dataset, expected_values):
"""Assert the dataset is a proper ``movement`` Dataset."""
expected_dim_names = expected_values.get("dim_names")
expected_file_path = expected_values.get("file_path")
assert isinstance(dataset, xr.Dataset)
# Expected variables are present and of right shape/type
for var, ndim in expected_values.get("vars_dims").items():
data_var = dataset.get(var)
assert isinstance(data_var, xr.DataArray)
assert data_var.ndim == ndim
position_shape = dataset.position.shape
# Confidence has the same shape as position, except for the space dim
assert (
dataset.confidence.shape == position_shape[:1] + position_shape[2:]
)
        # Check the dims
expected_dim_length_dict = dict(
zip(expected_dim_names, position_shape, strict=True)
)
assert expected_dim_length_dict == dataset.sizes
# Check the coords
for dim in expected_dim_names[1:]:
assert all(isinstance(s, str) for s in dataset.coords[dim].values)
assert all(coord in dataset.coords["space"] for coord in ["x", "y"])
# Check the metadata attributes
assert dataset.source_file == (
expected_file_path.as_posix()
if expected_file_path is not None
else None
)
assert dataset.source_software == expected_values.get(
"source_software"
)
assert dataset.fps == expected_values.get("fps")


@pytest.fixture
def movement_dataset_asserts():
"""Return an instance of the ``MovementDatasetAsserts`` class."""
return MovementDatasetAsserts
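Note: the snippet below is a minimal, hypothetical sketch (not part of this commit) of how a test can consume the new movement_dataset_asserts fixture, mirroring the usage introduced in the two test modules further down. The test name, file path, and fps value are placeholders; the expected-values keys follow the dicts defined in the diffs below, and the from_via_tracks_file call follows the signature used there.

# Hypothetical usage sketch of the shared assertion fixture; the test name,
# path, and fps below are placeholders, not files shipped with the test suite.
from pathlib import Path

from movement.io import load_bboxes
from movement.validators.datasets import ValidBboxesDataset


def test_my_via_tracks_file(movement_dataset_asserts):
    file_path = Path("my_tracks.csv")  # placeholder VIA tracks .csv file
    ds = load_bboxes.from_via_tracks_file(file_path, fps=30)
    expected_values = {
        "vars_dims": {"position": 3, "shape": 3, "confidence": 2},
        "dim_names": ValidBboxesDataset.DIM_NAMES,
        "source_software": "VIA-tracks",
        "fps": 30,
        "file_path": file_path,
    }
    movement_dataset_asserts.valid_dataset(ds, expected_values)

As a side note on the shape check in valid_dataset above: the expected confidence shape is derived by dropping position's second ("space") dimension, so for example a poses position array of shape (10, 2, 3, 1) is expected to come with a confidence array of shape (10, 3, 1).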
98 changes: 37 additions & 61 deletions tests/test_unit/test_load_bboxes.py
@@ -6,7 +6,6 @@
import numpy as np
import pandas as pd
import pytest
import xarray as xr

from movement.io import load_bboxes
from movement.validators.datasets import ValidBboxesDataset
@@ -112,49 +111,6 @@ def update_attribute_column(df_input, attribute_column_name, dict_to_append):
return df


def assert_dataset(
dataset, file_path=None, expected_source_software=None, expected_fps=None
):
"""Assert that the dataset is a proper ``movement`` Dataset."""
assert isinstance(dataset, xr.Dataset)

# Expected variables are present and of right shape/type
for var in ["position", "shape", "confidence"]:
assert var in dataset.data_vars
assert isinstance(dataset[var], xr.DataArray)
assert dataset.position.ndim == 3
assert dataset.shape.ndim == 3
position_shape = dataset.position.shape
# Confidence has the same shape as position, except for the space dim
assert dataset.confidence.shape == position_shape[:1] + position_shape[2:]
# Check the dims and coords
dim_names = ValidBboxesDataset.DIM_NAMES
expected_dim_length_dict = dict(
zip(dim_names, position_shape, strict=True)
)
assert expected_dim_length_dict == dataset.sizes
# Check the coords
for dim in dim_names[1:]:
assert all(isinstance(s, str) for s in dataset.coords[dim].values)
assert all(coord in dataset.coords["space"] for coord in ["x", "y"])
# Check the metadata attributes
assert (
dataset.source_file is None
if file_path is None
else dataset.source_file == file_path.as_posix()
)
assert (
dataset.source_software is None
if expected_source_software is None
else dataset.source_software == expected_source_software
)
assert (
dataset.fps is None
if expected_fps is None
else dataset.fps == expected_fps
)


def assert_time_coordinates(ds, fps, start_frame):
"""Assert that the time coordinates are as expected, depending on
fps value and start_frame.
@@ -210,10 +166,16 @@ def test_from_file(source_software, fps, use_frame_numbers_from_file):
)


expected_values_bboxes = {
"vars_dims": {"position": 3, "shape": 3, "confidence": 2},
"dim_names": ValidBboxesDataset.DIM_NAMES,
}


@pytest.mark.parametrize("fps", [None, 30, 60.0])
@pytest.mark.parametrize("use_frame_numbers_from_file", [True, False])
def test_from_via_tracks_file(
via_tracks_file, fps, use_frame_numbers_from_file
via_tracks_file, fps, use_frame_numbers_from_file, movement_dataset_asserts
):
"""Test that loading tracked bounding box data from
a valid VIA tracks .csv file returns a proper Dataset
@@ -223,8 +185,13 @@ def test_from_via_tracks_file(
ds = load_bboxes.from_via_tracks_file(
via_tracks_file, fps, use_frame_numbers_from_file
)
assert_dataset(ds, via_tracks_file, "VIA-tracks", fps)

expected_values = {
**expected_values_bboxes,
"source_software": "VIA-tracks",
"fps": fps,
"file_path": via_tracks_file,
}
movement_dataset_asserts.valid_dataset(ds, expected_values)
# check time coordinates are as expected
# in sample VIA tracks .csv file frame numbers start from 1
start_frame = 1 if use_frame_numbers_from_file else 0
@@ -240,28 +207,36 @@ def test_from_numpy(valid_from_numpy_inputs, fps, source_software, request):
)
@pytest.mark.parametrize("fps", [None, 30, 60.0])
@pytest.mark.parametrize("source_software", [None, "VIA-tracks"])
def test_from_numpy(valid_from_numpy_inputs, fps, source_software, request):
def test_from_numpy(
valid_from_numpy_inputs,
fps,
source_software,
movement_dataset_asserts,
request,
):
"""Test that loading bounding boxes trajectories from the input
numpy arrays returns a proper Dataset.
"""
# get the input arrays
from_numpy_inputs = request.getfixturevalue(valid_from_numpy_inputs)

# run general dataset checks
ds = load_bboxes.from_numpy(
**from_numpy_inputs,
fps=fps,
source_software=source_software,
)
assert_dataset(
ds, expected_source_software=source_software, expected_fps=fps
)

expected_values = {
**expected_values_bboxes,
"source_software": source_software,
"fps": fps,
}
movement_dataset_asserts.valid_dataset(ds, expected_values)
# check time coordinates are as expected
if "frame_array" in from_numpy_inputs:
start_frame = from_numpy_inputs["frame_array"][0, 0]
else:
start_frame = 0
start_frame = (
from_numpy_inputs["frame_array"][0, 0]
if "frame_array" in from_numpy_inputs
else 0
)
assert_time_coordinates(ds, fps, start_frame)


@@ -417,10 +392,11 @@ def test_fps_and_time_coords(
assert ds.fps == expected_fps

# check time coordinates
if use_frame_numbers_from_file:
start_frame = ds_in_frames_from_file.coords["time"].data[0]
else:
start_frame = 0
start_frame = (
ds_in_frames_from_file.coords["time"].data[0]
if use_frame_numbers_from_file
else 0
)
assert_time_coordinates(ds, expected_fps, start_frame)


86 changes: 41 additions & 45 deletions tests/test_unit/test_load_poses.py
@@ -68,48 +68,23 @@ def sleap_file_without_tracks(request):
return request.getfixturevalue(request.param)


def assert_dataset(dataset, file_path=None, expected_source_software=None):
"""Assert that the dataset is a proper xarray Dataset."""
assert isinstance(dataset, xr.Dataset)
# Expected variables are present and of right shape/type
for var in ["position", "confidence"]:
assert var in dataset.data_vars
assert isinstance(dataset[var], xr.DataArray)
assert dataset.position.ndim == 4

position_shape = dataset.position.shape
# Confidence has the same shape as position, except for the space dim
assert dataset.confidence.shape == position_shape[:1] + position_shape[2:]
# Check the dims
dim_names = ValidPosesDataset.DIM_NAMES
expected_dim_length_dict = dict(
zip(dim_names, position_shape, strict=True)
)
assert expected_dim_length_dict == dataset.sizes
# Check the coords
for dim in dim_names[1:]:
assert all(isinstance(s, str) for s in dataset.coords[dim].values)
assert all(coord in dataset.coords["space"] for coord in ["x", "y"])
# Check the metadata attributes
assert (
dataset.source_file is None
if file_path is None
else dataset.source_file == file_path.as_posix()
)
assert (
dataset.source_software is None
if expected_source_software is None
else dataset.source_software == expected_source_software
)
assert dataset.fps is None
expected_values_poses = {
"vars_dims": {"position": 4, "confidence": 3},
"dim_names": ValidPosesDataset.DIM_NAMES,
}


def test_load_from_sleap_file(sleap_file):
def test_load_from_sleap_file(sleap_file, movement_dataset_asserts):
"""Test that loading pose tracks from valid SLEAP files
returns a proper Dataset.
"""
ds = load_poses.from_sleap_file(sleap_file)
assert_dataset(ds, sleap_file, "SLEAP")
expected_values = {
**expected_values_poses,
"source_software": "SLEAP",
"file_path": sleap_file,
}
movement_dataset_asserts.valid_dataset(ds, expected_values)


def test_load_from_sleap_file_without_tracks(sleap_file_without_tracks):
@@ -167,26 +142,37 @@ def test_load_from_sleap_slp_file_or_h5_file_returns_same(slp_file, h5_file):
"DLC_two-mice.predictions.csv",
],
)
def test_load_from_dlc_file(file_name):
def test_load_from_dlc_file(file_name, movement_dataset_asserts):
"""Test that loading pose tracks from valid DLC files
returns a proper Dataset.
"""
file_path = DATA_PATHS.get(file_name)
ds = load_poses.from_dlc_file(file_path)
assert_dataset(ds, file_path, "DeepLabCut")
expected_values = {
**expected_values_poses,
"source_software": "DeepLabCut",
"file_path": file_path,
}
movement_dataset_asserts.valid_dataset(ds, expected_values)


@pytest.mark.parametrize(
"source_software", ["DeepLabCut", "LightningPose", None]
)
def test_load_from_dlc_style_df(dlc_style_df, source_software):
def test_load_from_dlc_style_df(
dlc_style_df, source_software, movement_dataset_asserts
):
"""Test that loading pose tracks from a valid DLC-style DataFrame
returns a proper Dataset.
"""
ds = load_poses.from_dlc_style_df(
dlc_style_df, source_software=source_software
)
assert_dataset(ds, expected_source_software=source_software)
expected_values = {
**expected_values_poses,
"source_software": source_software,
}
movement_dataset_asserts.valid_dataset(ds, expected_values)


def test_load_from_dlc_file_csv_or_h5_file_returns_same():
@@ -234,13 +220,18 @@ def test_fps_and_time_coords(fps, expected_fps, expected_time_unit):
"LP_mouse-twoview_AIND.predictions.csv",
],
)
def test_load_from_lp_file(file_name):
def test_load_from_lp_file(file_name, movement_dataset_asserts):
"""Test that loading pose tracks from valid LightningPose (LP) files
returns a proper Dataset.
"""
file_path = DATA_PATHS.get(file_name)
ds = load_poses.from_lp_file(file_path)
assert_dataset(ds, file_path, "LightningPose")
expected_values = {
**expected_values_poses,
"source_software": "LightningPose",
"file_path": file_path,
}
movement_dataset_asserts.valid_dataset(ds, expected_values)


def test_load_from_lp_or_dlc_file_returns_same():
@@ -289,14 +280,15 @@ def test_from_file_delegates_correctly(source_software, fps):


@pytest.mark.parametrize("source_software", [None, "SLEAP"])
def test_from_numpy_valid(valid_position_array, source_software):
def test_from_numpy_valid(
valid_position_array, source_software, movement_dataset_asserts
):
"""Test that loading pose tracks from a multi-animal numpy array
with valid parameters returns a proper Dataset.
"""
valid_position = valid_position_array("multi_individual_array")
rng = np.random.default_rng(seed=42)
valid_confidence = rng.random(valid_position.shape[:-1])

ds = load_poses.from_numpy(
valid_position,
valid_confidence,
@@ -305,7 +297,11 @@ def test_from_numpy_valid(valid_position_array, source_software):
fps=None,
source_software=source_software,
)
assert_dataset(ds, expected_source_software=source_software)
expected_values = {
**expected_values_poses,
"source_software": source_software,
}
movement_dataset_asserts.valid_dataset(ds, expected_values)


def test_from_multiview_files():
