Skip to content

Commit

Permalink
Fix and relax bboxes requirements (#313)
Browse files Browse the repository at this point in the history
* Fix for single individual

* Relax requirement of continuous frames for bboxes dataset (require only monotonically increasing frame numbers)

* Relax requirement for frame number

* Make existing tests pass

* test_from_via_tracks_file and test_via_attribute_column_to_numpy pass (WIP)

* Clarify docstring

* Fix and simplify existing tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix and simplify existing tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix sonarcloud issue

* Add test for case in which frame is defined under file_attribute

* Simplify frame number extraction test by splitting it

* Use default regexp global variable

* Add docstrings to fixtures

* Make fixtures for new frame check more evident

* Fix from numpy docstring

* Fix docstring in validator re regexp for frame number

* Add value error to assert_time_coordinates

* Remove note to self

* Expose frame regexp

* Add try except pattern for regexp match

* Extend and add tests for frame_regexp possible values

* Clarify error message for validator about frame regexp

* Simplify try except regexp block

* Fix tests

* Simplify

* Adapt tests

* Simplify frame extraction

* Small edits

* Clarify docstrings

* Specify exception caught when extracting frame number from file_attribute

* Fix error message referring to default regexp only, and adapt tests

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
sfmig and pre-commit-ci[bot] authored Dec 5, 2024
1 parent 4de4963 commit f7539b9
Show file tree
Hide file tree
Showing 6 changed files with 623 additions and 242 deletions.
82 changes: 65 additions & 17 deletions movement/io/load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@

from movement.utils.logging import log_error
from movement.validators.datasets import ValidBboxesDataset
from movement.validators.files import ValidFile, ValidVIATracksCSV
from movement.validators.files import (
DEFAULT_FRAME_REGEXP,
ValidFile,
ValidVIATracksCSV,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -58,7 +62,7 @@ def from_numpy(
bounding boxes are defined. If None (default), frame numbers will
be assigned based on the first dimension of the ``position_array``,
starting from 0. If a specific array of frame numbers is provided,
these need to be consecutive integers.
these need to be integers sorted in increasing order.
fps : float, optional
The video sampling rate. If None (default), the ``time`` coordinates
of the resulting ``movement`` dataset will be in frame numbers. If
Expand Down Expand Up @@ -151,6 +155,7 @@ def from_file(
source_software: Literal["VIA-tracks"],
fps: float | None = None,
use_frame_numbers_from_file: bool = False,
frame_regexp: str = DEFAULT_FRAME_REGEXP,
) -> xr.Dataset:
"""Create a ``movement`` bounding boxes dataset from a supported file.
Expand Down Expand Up @@ -180,6 +185,13 @@ def from_file(
full video as the time origin. If False (default), the frame numbers
in the VIA tracks .csv file are instead mapped to a 0-based sequence of
consecutive integers.
frame_regexp : str, optional
Regular expression pattern to extract the frame number from the frame
filename. By default, the frame number is expected to be encoded in
the filename as an integer number led by at least one zero, followed
by the file extension. Only used if ``use_frame_numbers_from_file`` is
True.
Returns
-------
Expand Down Expand Up @@ -214,6 +226,7 @@ def from_file(
file_path,
fps,
use_frame_numbers_from_file=use_frame_numbers_from_file,
frame_regexp=frame_regexp,
)
else:
raise log_error(
Expand All @@ -225,6 +238,7 @@ def from_via_tracks_file(
file_path: Path | str,
fps: float | None = None,
use_frame_numbers_from_file: bool = False,
frame_regexp: str = DEFAULT_FRAME_REGEXP,
) -> xr.Dataset:
"""Create a ``movement`` dataset from a VIA tracks .csv file.
Expand All @@ -248,6 +262,12 @@ def from_via_tracks_file(
but you want to maintain the start of the full video as the time
origin. If False (default), the frame numbers in the VIA tracks .csv
file are instead mapped to a 0-based sequence of consecutive integers.
frame_regexp : str, optional
Regular expression pattern to extract the frame number from the frame
filename. By default, the frame number is expected to be encoded in
the filename as an integer number led by at least one zero, followed
by the file extension. Only used if ``use_frame_numbers_from_file`` is
True.
Returns
-------
Expand Down Expand Up @@ -316,17 +336,19 @@ def from_via_tracks_file(
)

# Specific VIA-tracks .csv file validation
via_file = ValidVIATracksCSV(file.path)
via_file = ValidVIATracksCSV(file.path, frame_regexp=frame_regexp)
logger.debug(f"Validated VIA tracks .csv file {via_file.path}.")

# Create an xarray.Dataset from the data
bboxes_arrays = _numpy_arrays_from_via_tracks_file(via_file.path)
bboxes_arrays = _numpy_arrays_from_via_tracks_file(
via_file.path, via_file.frame_regexp
)
ds = from_numpy(
position_array=bboxes_arrays["position_array"],
shape_array=bboxes_arrays["shape_array"],
confidence_array=bboxes_arrays["confidence_array"],
individual_names=[
f"id_{id}" for id in bboxes_arrays["ID_array"].squeeze()
f"id_{id.item()}" for id in bboxes_arrays["ID_array"]
],
frame_array=(
bboxes_arrays["frame_array"]
Expand All @@ -346,7 +368,9 @@ def from_via_tracks_file(
return ds


def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
def _numpy_arrays_from_via_tracks_file(
file_path: Path, frame_regexp: str = DEFAULT_FRAME_REGEXP
) -> dict:
"""Extract numpy arrays from the input VIA tracks .csv file.
The extracted numpy arrays are returned in a dictionary with the following
Expand All @@ -369,6 +393,12 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
file_path : pathlib.Path
Path to the VIA tracks .csv file containing the bounding boxes' tracks.
frame_regexp : str
Regular expression pattern to extract the frame number from the frame
filename. By default, the frame number is expected to be encoded in
the filename as an integer number led by at least one zero, followed
by the file extension.
Returns
-------
dict
Expand All @@ -378,7 +408,7 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
# Extract 2D dataframe from input data
# (sort data by ID and frame number, and
# fill empty frame-ID pairs with nans)
df = _df_from_via_tracks_file(file_path)
df = _df_from_via_tracks_file(file_path, frame_regexp)

# Compute indices of the rows where the IDs switch
bool_id_diff_from_prev = df["ID"].ne(df["ID"].shift()) # pandas series
Expand All @@ -398,8 +428,11 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
df[map_key_to_columns[key]].to_numpy(),
indices_id_switch, # indices along axis=0
)
array_dict[key] = np.stack(list_arrays, axis=1)

array_dict[key] = np.stack(list_arrays, axis=1).squeeze()
# squeeze only last dimension if it is 1
if array_dict[key].shape[-1] == 1:
array_dict[key] = array_dict[key].squeeze(axis=-1)

# Transform position_array to represent centroid of bbox,
# rather than top-left corner
Expand All @@ -413,7 +446,9 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:
return array_dict


def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
def _df_from_via_tracks_file(
file_path: Path, frame_regexp: str = DEFAULT_FRAME_REGEXP
) -> pd.DataFrame:
"""Load VIA tracks .csv file as a dataframe.
Read the VIA tracks .csv file as a pandas dataframe with columns:
Expand All @@ -429,6 +464,10 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
empty frames are filled in with NaNs. The coordinates of the bboxes
are assumed to be in the image coordinate system (i.e., the top-left
corner of a bbox is its corner with minimum x and y coordinates).
The frame number is extracted from the filename using the provided
regexp if it is not defined as a 'file_attribute' in the VIA tracks .csv
file.
"""
# Read VIA tracks .csv file as a pandas dataframe
df_file = pd.read_csv(file_path, sep=",", header=0)
Expand All @@ -439,7 +478,9 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
"ID": _via_attribute_column_to_numpy(
df_file, "region_attributes", ["track"], int
),
"frame_number": _extract_frame_number_from_via_tracks_df(df_file),
"frame_number": _extract_frame_number_from_via_tracks_df(
df_file, frame_regexp
),
"x": _via_attribute_column_to_numpy(
df_file, "region_shape_attributes", ["x"], float
),
Expand Down Expand Up @@ -473,7 +514,7 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
return df


def _extract_confidence_from_via_tracks_df(df) -> np.ndarray:
def _extract_confidence_from_via_tracks_df(df: pd.DataFrame) -> np.ndarray:
"""Extract confidence scores from the VIA tracks input dataframe.
Parameters
Expand Down Expand Up @@ -504,7 +545,9 @@ def _extract_confidence_from_via_tracks_df(df) -> np.ndarray:
return bbox_confidence


def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray:
def _extract_frame_number_from_via_tracks_df(
df: pd.DataFrame, frame_regexp: str = DEFAULT_FRAME_REGEXP
) -> np.ndarray:
"""Extract frame numbers from the VIA tracks input dataframe.
Parameters
Expand All @@ -513,14 +556,20 @@ def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray:
The VIA tracks input dataframe is the one obtained from
``df = pd.read_csv(file_path, sep=",", header=0)``.
frame_regexp : str
Regular expression pattern to extract the frame number from the frame
filename. By default, the frame number is expected to be encoded in
the filename as an integer number led by at least one zero, followed by
the file extension.
Returns
-------
np.ndarray
A numpy array of size (n_frames, ) containing the frame numbers.
In the VIA tracks .csv file, the frame number is expected to be
defined as a 'file_attribute' , or encoded in the filename as an
integer number led by at least one zero, between "_" and ".", followed
by the file extension.
integer number led by at least one zero, followed by the file
extension.
"""
# Extract frame number from file_attributes if exists
Expand All @@ -534,10 +583,9 @@ def _extract_frame_number_from_via_tracks_df(df) -> np.ndarray:
)
# Else extract from filename
else:
pattern = r"_(0\d*)\.\w+$"
list_frame_numbers = [
int(re.search(pattern, f).group(1)) # type: ignore
if re.search(pattern, f)
int(re.search(frame_regexp, f).group(1)) # type: ignore
if re.search(frame_regexp, f)
else np.nan
for f in df["filename"]
]
Expand Down
7 changes: 4 additions & 3 deletions movement/validators/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,12 @@ def _validate_frame_array(self, attribute, value):
value,
expected_shape=(self.position_array.shape[0], 1),
)
# check frames are continuous: exactly one frame number per row
if not np.all(np.diff(value, axis=0) == 1):
# check frames are monotonically increasing
if not np.all(np.diff(value, axis=0) >= 1):
raise log_error(
ValueError,
f"Frame numbers in {attribute.name} are not continuous.",
f"Frame numbers in {attribute.name} are not monotonically "
"increasing.",
)

# Define defaults
Expand Down
105 changes: 72 additions & 33 deletions movement/validators/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from movement.utils.logging import log_error

DEFAULT_FRAME_REGEXP = r"(0\d*)\.\w+$"


@define
class ValidFile:
Expand Down Expand Up @@ -234,6 +236,11 @@ class ValidVIATracksCSV:
----------
path : pathlib.Path
Path to the VIA tracks .csv file.
frame_regexp : str
Regular expression pattern to extract the frame number from the
filename. By default, the frame number is expected to be encoded in
the filename as an integer number led by at least one zero, followed
by the file extension.
Raises
------
Expand All @@ -243,6 +250,7 @@ class ValidVIATracksCSV:
"""

path: Path = field(validator=validators.instance_of(Path))
frame_regexp: str = DEFAULT_FRAME_REGEXP

@path.validator
def _file_contains_valid_header(self, attribute, value):
Expand Down Expand Up @@ -281,8 +289,10 @@ def _file_contains_valid_frame_numbers(self, attribute, value):
files.
If the frame number is included as part of the image file name, then
it is expected as an integer led by at least one zero, between "_" and
".", followed by the file extension.
it is expected to be captured by the regular expression in the
`frame_regexp` attribute of the ValidVIATracksCSV object. The default
regexp matches an integer led by at least one zero, followed by the
file extension.
"""
df = pd.read_csv(value, sep=",", header=0)
Expand All @@ -294,40 +304,15 @@ def _file_contains_valid_frame_numbers(self, attribute, value):

# If 'frame' is a file_attribute for all files:
# extract frame number
list_frame_numbers = []
if all(["frame" in d for d in file_attributes_dicts]):
for k_i, k in enumerate(file_attributes_dicts):
try:
list_frame_numbers.append(int(k["frame"]))
except Exception as e:
raise log_error(
ValueError,
f"{df.filename.iloc[k_i]} (row {k_i}): "
"'frame' file attribute cannot be cast as an integer. "
f"Please review the file attributes: {k}.",
) from e

list_frame_numbers = (
self._extract_frame_numbers_from_file_attributes(
df, file_attributes_dicts
)
)
# else: extract frame number from filename.
else:
pattern = r"_(0\d*)\.\w+$"

for f_i, f in enumerate(df["filename"]):
regex_match = re.search(pattern, f)
if regex_match: # if there is a pattern match
list_frame_numbers.append(
int(regex_match.group(1)) # type: ignore
# the match will always be castable as integer
)
else:
raise log_error(
ValueError,
f"{f} (row {f_i}): "
"a frame number could not be extracted from the "
"filename. If included in the filename, the frame "
"number is expected as a zero-padded integer between "
"an underscore '_' and the file extension "
"(e.g. img_00234.png).",
)
list_frame_numbers = self._extract_frame_numbers_using_regexp(df)

# Check we have as many unique frame numbers as unique image files
if len(set(list_frame_numbers)) != len(df.filename.unique()):
Expand All @@ -339,6 +324,60 @@ def _file_contains_valid_frame_numbers(self, attribute, value):
"file. ",
)

def _extract_frame_numbers_from_file_attributes(
self, df, file_attributes_dicts
):
"""Get frame numbers from the 'frame' key under 'file_attributes'."""
list_frame_numbers = []
for k_i, k in enumerate(file_attributes_dicts):
try:
list_frame_numbers.append(int(k["frame"]))
except ValueError as e:
raise log_error(
ValueError,
f"{df.filename.iloc[k_i]} (row {k_i}): "
"'frame' file attribute cannot be cast as an integer. "
f"Please review the file attributes: {k}.",
) from e
return list_frame_numbers

def _extract_frame_numbers_using_regexp(self, df):
"""Get frame numbers from the file names using the provided regexp."""
list_frame_numbers = []
for f_i, f in enumerate(df["filename"]):
# try compiling the frame regexp
try:
regex_match = re.search(self.frame_regexp, f)
except re.error as e:
raise log_error(
re.error,
"The provided regular expression for the frame "
f"numbers ({self.frame_regexp}) could not be compiled."
" Please review its syntax.",
) from e
# try extracting the frame number from the filename using the
# compiled regexp
try:
list_frame_numbers.append(int(regex_match.group(1)))
except AttributeError as e:
raise log_error(
AttributeError,
f"{f} (row {f_i}): The provided frame regexp "
f"({self.frame_regexp}) did not "
"return any matches and a frame number could not "
"be extracted from the filename.",
) from e
except ValueError as e:
raise log_error(
ValueError,
f"{f} (row {f_i}): "
"The frame number extracted from the filename using "
f"the provided regexp ({self.frame_regexp}) could not "
"be cast as an integer.",
) from e

return list_frame_numbers

@path.validator
def _file_contains_tracked_bboxes(self, attribute, value):
"""Ensure that the VIA tracks .csv contains tracked bounding boxes.
Expand Down
Loading

0 comments on commit f7539b9

Please sign in to comment.