Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add bbox centroid fix #303

Merged
merged 6 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions movement/io/load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,11 @@ def _numpy_arrays_from_via_tracks_file(file_path: Path) -> dict:

array_dict[key] = np.stack(list_arrays, axis=1).squeeze()

# Transform position_array to represent centroid of bbox,
# rather than top-left corner
# (top left corner: corner of the bbox with minimum x and y coordinates)
array_dict["position_array"] += array_dict["shape_array"] / 2

# Add remaining arrays to dict
array_dict["ID_array"] = df["ID"].unique().reshape(-1, 1)
array_dict["frame_array"] = df["frame_number"].unique().reshape(-1, 1)
Expand All @@ -415,14 +420,16 @@ def _df_from_via_tracks_file(file_path: Path) -> pd.DataFrame:
Read the VIA tracks .csv file as a pandas dataframe with columns:
- ID: the integer ID of the tracked bounding box.
- frame_number: the frame number of the tracked bounding box.
- x: the x-coordinate of the tracked bounding box centroid.
- y: the y-coordinate of the tracked bounding box centroid.
- x: the x-coordinate of the tracked bounding box's top-left corner.
- y: the y-coordinate of the tracked bounding box's top-left corner.
- w: the width of the tracked bounding box.
- h: the height of the tracked bounding box.
- confidence: the confidence score of the tracked bounding box.

The dataframe is sorted by ID and frame number, and for each ID,
empty frames are filled in with NaNs.
empty frames are filled in with NaNs. The coordinates of the bboxes
are assumed to be in the image coordinate system (i.e., the top-left
corner of a bbox is its corner with minimum x and y coordinates).
"""
# Read VIA tracks .csv file as a pandas dataframe
df_file = pd.read_csv(file_path, sep=",", header=0)
Expand Down
51 changes: 51 additions & 0 deletions tests/test_unit/test_load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,3 +419,54 @@ def test_fps_and_time_coords(
else:
start_frame = 0
assert_time_coordinates(ds, expected_fps, start_frame)


def test_df_from_via_tracks_file(via_tracks_file):
"""Test that the helper function correctly reads the VIA tracks .csv file
as a dataframe.
"""
df = load_bboxes._df_from_via_tracks_file(via_tracks_file)

assert isinstance(df, pd.DataFrame)
assert len(df.frame_number.unique()) == 5
assert (
df.shape[0] == len(df.ID.unique()) * 5
) # all individuals in all frames (even if nan)
assert list(df.columns) == [
"ID",
"frame_number",
"x",
"y",
"w",
"h",
"confidence",
]


def test_position_numpy_array_from_via_tracks_file(via_tracks_file):
"""Test the extracted position array from the VIA tracks .csv file
represents the centroid of the bbox.
"""
# Extract numpy arrays from VIA tracks .csv file
bboxes_arrays = load_bboxes._numpy_arrays_from_via_tracks_file(
via_tracks_file
)

# Read VIA tracks .csv file as a dataframe
df = load_bboxes._df_from_via_tracks_file(via_tracks_file)

# Compute centroid positions from the dataframe
# (go thru in the same order as ID array)
list_derived_centroids = []
for id in bboxes_arrays["ID_array"]:
df_one_id = df[df["ID"] == id.item()]
centroid_position = np.array(
[df_one_id.x + df_one_id.w / 2, df_one_id.y + df_one_id.h / 2]
).T # frames, xy
list_derived_centroids.append(centroid_position)

# Compare to extracted position array
assert np.allclose(
bboxes_arrays["position_array"], # frames, individuals, xy
np.stack(list_derived_centroids, axis=1),
)
Loading