Skip to content

Commit

Permalink
Make existing tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Sep 17, 2024
1 parent 64eb1de commit e94e681
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
16 changes: 9 additions & 7 deletions movement/io/load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def from_via_tracks_file(


def _numpy_arrays_from_via_tracks_file(
file_path: Path, frame_regexp: str
file_path: Path, frame_regexp: str = r"(0\d*)\.\w+$"
) -> dict:
"""Extract numpy arrays from the input VIA tracks .csv file.
Expand All @@ -376,9 +376,9 @@ def _numpy_arrays_from_via_tracks_file(
frame_regexp : str
Regular expression pattern to extract the frame number from the
filename. The frame number is expected to be encoded in the filename
as an integer number led by at least one zero, followed by the file
extension.
filename. By default, the frame number is expected to be encoded in
the filename as an integer number led by at least one zero, followed
by the file extension.
Returns
-------
Expand Down Expand Up @@ -428,7 +428,7 @@ def _numpy_arrays_from_via_tracks_file(


def _df_from_via_tracks_file(
file_path: Path, frame_regexp: str
file_path: Path, frame_regexp: str = r"(0\d*)\.\w+$"
) -> pd.DataFrame:
"""Load VIA tracks .csv file as a dataframe.
Expand Down Expand Up @@ -495,7 +495,7 @@ def _df_from_via_tracks_file(
return df


def _extract_confidence_from_via_tracks_df(df) -> np.ndarray:
def _extract_confidence_from_via_tracks_df(df: pd.DataFrame) -> np.ndarray:
"""Extract confidence scores from the VIA tracks input dataframe.
Parameters
Expand Down Expand Up @@ -526,7 +526,9 @@ def _extract_confidence_from_via_tracks_df(df) -> np.ndarray:
return bbox_confidence


def _extract_frame_number_from_via_tracks_df(df, frame_regexp) -> np.ndarray:
def _extract_frame_number_from_via_tracks_df(
df: pd.DataFrame, frame_regexp: str = r"(0\d*)\.\w+$"
) -> np.ndarray:
"""Extract frame numbers from the VIA tracks input dataframe.
Parameters
Expand Down
4 changes: 3 additions & 1 deletion tests/test_unit/test_load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,9 @@ def test_df_from_via_tracks_file(via_tracks_file):
"""Test that the helper function correctly reads the VIA tracks .csv file
as a dataframe.
"""
df = load_bboxes._df_from_via_tracks_file(via_tracks_file)
df = load_bboxes._df_from_via_tracks_file(
file_path=via_tracks_file,
)

assert isinstance(df, pd.DataFrame)
assert len(df.frame_number.unique()) == 5
Expand Down
18 changes: 12 additions & 6 deletions tests/test_unit/test_validators/test_datasets_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,16 @@ def test_bboxes_dataset_validator_confidence_array(
f"Expected a numpy array, but got {type(list())}.",
), # not an ndarray, should raise ValueError
(
np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 11]).reshape(-1, 1),
np.array([1, 2, 3, 6, 7, 8, 4, 9, 10, 11]).reshape(-1, 1),
pytest.raises(ValueError),
"Frame numbers in frame_array are not continuous.",
), # frame numbers are not continuous
"Frame numbers in frame_array are not monotonically increasing.",
),
(
np.array([1, 2, 3, 5, 6, 7, 8, 9, 10, 11]).reshape(-1, 1),
does_not_raise(),
"",
), # valid, frame numbers are not continuous but are monotonically
# increasing
(
None,
does_not_raise(),
Expand All @@ -389,10 +395,10 @@ def test_bboxes_dataset_validator_frame_array(
frame_array=frame_array,
)

if frame_array is None:
if frame_array is not None:
assert str(getattr(excinfo, "value", "")) == log_message
else:
n_frames = ds.position_array.shape[0]
default_frame_array = np.arange(n_frames).reshape(-1, 1)
assert np.array_equal(ds.frame_array, default_frame_array)
assert ds.frame_array.shape == (ds.position_array.shape[0], 1)
else:
assert str(excinfo.value) == log_message

0 comments on commit e94e681

Please sign in to comment.