From e94e68127d09aff142765fbc62237319b6c76cdf Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Tue, 17 Sep 2024 18:49:49 +0100 Subject: [PATCH] Make existing tests pass --- movement/io/load_bboxes.py | 16 +++++++++------- tests/test_unit/test_load_bboxes.py | 4 +++- .../test_datasets_validators.py | 18 ++++++++++++------ 3 files changed, 24 insertions(+), 14 deletions(-) diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 1fbee4fcb..3e239b127 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -350,7 +350,7 @@ def from_via_tracks_file( def _numpy_arrays_from_via_tracks_file( - file_path: Path, frame_regexp: str + file_path: Path, frame_regexp: str = r"(0\d*)\.\w+$" ) -> dict: """Extract numpy arrays from the input VIA tracks .csv file. @@ -376,9 +376,9 @@ def _numpy_arrays_from_via_tracks_file( frame_regexp : str Regular expression pattern to extract the frame number from the - filename. The frame number is expected to be encoded in the filename - as an integer number led by at least one zero, followed by the file - extension. + filename. By default, the frame number is expected to be encoded in + the filename as an integer number led by at least one zero, followed + by the file extension. Returns ------- @@ -428,7 +428,7 @@ def _numpy_arrays_from_via_tracks_file( def _df_from_via_tracks_file( - file_path: Path, frame_regexp: str + file_path: Path, frame_regexp: str = r"(0\d*)\.\w+$" ) -> pd.DataFrame: """Load VIA tracks .csv file as a dataframe. @@ -495,7 +495,7 @@ def _df_from_via_tracks_file( return df -def _extract_confidence_from_via_tracks_df(df) -> np.ndarray: +def _extract_confidence_from_via_tracks_df(df: pd.DataFrame) -> np.ndarray: """Extract confidence scores from the VIA tracks input dataframe. Parameters @@ -526,7 +526,9 @@ def _extract_confidence_from_via_tracks_df(df) -> np.ndarray: return bbox_confidence -def _extract_frame_number_from_via_tracks_df(df, frame_regexp) -> np.ndarray: +def _extract_frame_number_from_via_tracks_df( + df: pd.DataFrame, frame_regexp: str = r"(0\d*)\.\w+$" +) -> np.ndarray: """Extract frame numbers from the VIA tracks input dataframe. Parameters diff --git a/tests/test_unit/test_load_bboxes.py b/tests/test_unit/test_load_bboxes.py index 474e61183..5b7dff6ea 100644 --- a/tests/test_unit/test_load_bboxes.py +++ b/tests/test_unit/test_load_bboxes.py @@ -425,7 +425,9 @@ def test_df_from_via_tracks_file(via_tracks_file): """Test that the helper function correctly reads the VIA tracks .csv file as a dataframe. """ - df = load_bboxes._df_from_via_tracks_file(via_tracks_file) + df = load_bboxes._df_from_via_tracks_file( + file_path=via_tracks_file, + ) assert isinstance(df, pd.DataFrame) assert len(df.frame_number.unique()) == 5 diff --git a/tests/test_unit/test_validators/test_datasets_validators.py b/tests/test_unit/test_validators/test_datasets_validators.py index 493f1d460..fe5d7d5ee 100644 --- a/tests/test_unit/test_validators/test_datasets_validators.py +++ b/tests/test_unit/test_validators/test_datasets_validators.py @@ -360,10 +360,16 @@ def test_bboxes_dataset_validator_confidence_array( f"Expected a numpy array, but got {type(list())}.", ), # not an ndarray, should raise ValueError ( - np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 11]).reshape(-1, 1), + np.array([1, 2, 3, 6, 7, 8, 4, 9, 10, 11]).reshape(-1, 1), pytest.raises(ValueError), - "Frame numbers in frame_array are not continuous.", - ), # frame numbers are not continuous + "Frame numbers in frame_array are not monotonically increasing.", + ), + ( + np.array([1, 2, 3, 5, 6, 7, 8, 9, 10, 11]).reshape(-1, 1), + does_not_raise(), + "", + ), # valid, frame numbers are not continuous but are monotonically + # increasing ( None, does_not_raise(), @@ -389,10 +395,10 @@ def test_bboxes_dataset_validator_frame_array( frame_array=frame_array, ) - if frame_array is None: + if frame_array is not None: + assert str(getattr(excinfo, "value", "")) == log_message + else: n_frames = ds.position_array.shape[0] default_frame_array = np.arange(n_frames).reshape(-1, 1) assert np.array_equal(ds.frame_array, default_frame_array) assert ds.frame_array.shape == (ds.position_array.shape[0], 1) - else: - assert str(excinfo.value) == log_message