Reorder dimensions #351

Merged
merged 10 commits into from Dec 6, 2024

Changes from 1 commit
Draft reorder poses dimensions
lochhh committed Dec 5, 2024
commit 4b1f35d2989ae93ba79bc9d41bc3489fc7257893
movement/io/load_poses.py (6 changes: 3 additions & 3 deletions)

@@ -699,9 +699,9 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
         },
         coords={
             DIM_NAMES[0]: time_coords,
-            DIM_NAMES[1]: data.individual_names,
-            DIM_NAMES[2]: data.keypoint_names,
             DIM_NAMES[3]: ["x", "y", "z"][:n_space],
+            DIM_NAMES[2]: data.keypoint_names,
+            DIM_NAMES[1]: data.individual_names,
         },
         attrs={
             "fps": data.fps,
@@ -710,4 +710,4 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
             "source_file": None,
             "ds_type": "poses",
         },
-    )
+    ).transpose("time", "space", "keypoints", "individuals")

movement/io/save_poses.py (13 changes: 8 additions & 5 deletions)

@@ -35,13 +35,16 @@ def _ds_to_dlc_style_df(
 
     """
     # Concatenate the pose tracks and confidence scores into one array
+    # and reverse the order of the dimensions except for the time dimension
     tracks_with_scores = np.concatenate(
         (
             ds.position.data,
-            ds.confidence.data[..., np.newaxis],
+            ds.confidence.data[:, np.newaxis, ...],
         ),
-        axis=-1,
+        axis=1,
     )
+    transpose_order = [0] + list(range(tracks_with_scores.ndim - 1, 0, -1))
+    tracks_with_scores = tracks_with_scores.transpose(transpose_order)
 
     # Create DataFrame with multi-index columns
     df = pd.DataFrame(
@@ -320,9 +323,9 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None:
     n_frames = frame_idxs[-1] - frame_idxs[0] + 1
     pos_x = ds.position.sel(space="x").values
     # Mask denoting which individuals are present in each frame
-    track_occupancy = (~np.all(np.isnan(pos_x), axis=2)).astype(int)
-    tracks = np.transpose(ds.position.data, (1, 3, 2, 0))
-    point_scores = np.transpose(ds.confidence.data, (1, 2, 0))
+    track_occupancy = (~np.all(np.isnan(pos_x), axis=1)).astype(int)
+    tracks = np.transpose(ds.position.data, (3, 1, 2, 0))
+    point_scores = np.transpose(ds.confidence.data, (2, 1, 0))
     instance_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)
     tracking_scores = np.full((n_individuals, n_frames), np.nan, dtype=float)
     labels_path = (
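
A hedged sketch of the concatenation and reversal step in _ds_to_dlc_style_df, using dummy NumPy arrays (the shapes below are assumptions for illustration; the real arrays come from the movement dataset):

import numpy as np

n_frames, n_space, n_keypoints, n_individuals = 10, 2, 3, 2
position = np.zeros((n_frames, n_space, n_keypoints, n_individuals))
confidence = np.zeros((n_frames, n_keypoints, n_individuals))

# Insert a singleton axis so confidence lines up with position's space axis,
# then append the scores as an extra entry along that axis: (x, y, score).
tracks_with_scores = np.concatenate(
    (position, confidence[:, np.newaxis, ...]),
    axis=1,
)
print(tracks_with_scores.shape)  # (10, 3, 3, 2)

# Reverse every dimension except time, giving (time, individuals, keypoints,
# x/y/score), which is the order the DLC-style DataFrame is built from.
transpose_order = [0] + list(range(tracks_with_scores.ndim - 1, 0, -1))
print(tracks_with_scores.transpose(transpose_order).shape)  # (10, 2, 3, 3)
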

tests/conftest.py (12 changes: 6 additions & 6 deletions)

@@ -397,9 +397,9 @@ def valid_poses_dataset(valid_position_array, request):
         },
         coords={
             "time": np.arange(n_frames),
-            "individuals": [f"ind{i}" for i in range(1, n_individuals + 1)],
-            "keypoints": [f"key{i}" for i in range(1, n_keypoints + 1)],
             "space": ["x", "y"],
+            "keypoints": [f"key{i}" for i in range(1, n_keypoints + 1)],
+            "individuals": [f"ind{i}" for i in range(1, n_individuals + 1)],
         },
         attrs={
             "fps": None,
@@ -408,7 +408,7 @@ def valid_poses_dataset(valid_position_array, request):
             "source_file": "test.h5",
             "ds_type": "poses",
         },
-    )
+    ).transpose("time", "space", "keypoints", "individuals")
 
 
 @pytest.fixture
@@ -504,9 +504,9 @@ def valid_poses_dataset_uniform_linear_motion(
         },
         coords={
             dim_names[0]: np.arange(n_frames),
-            dim_names[1]: [f"id_{i}" for i in range(1, n_individuals + 1)],
-            dim_names[2]: ["centroid", "left", "right"],
             dim_names[3]: ["x", "y"],
+            dim_names[2]: ["centroid", "left", "right"],
+            dim_names[1]: [f"id_{i}" for i in range(1, n_individuals + 1)],
         },
         attrs={
             "fps": None,
@@ -515,7 +515,7 @@ def valid_poses_dataset_uniform_linear_motion(
             "source_file": "test_poses.h5",
             "ds_type": "poses",
         },
-    )
+    ).transpose("time", "space", "keypoints", "individuals")
 
 
 @pytest.fixture
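
As a side note on why the fixtures still need the explicit transpose: in xarray, the order of keys in the coords dict does not set the dimension order; that comes from the data's own dims. A toy example (illustrative only, not the fixture data):

import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((4, 2, 3, 2)),
    dims=("time", "individuals", "keypoints", "space"),
    coords={
        "space": ["x", "y"],  # declared first, but not the first dim
        "time": np.arange(4),
        "keypoints": ["centroid", "left", "right"],
        "individuals": ["id_1", "id_2"],
    },
)
print(da.dims)  # ('time', 'individuals', 'keypoints', 'space')
print(da.transpose("time", "space", "keypoints", "individuals").dims)
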

tests/test_integration/test_kinematics_vector_transform.py (6 changes: 3 additions & 3 deletions)

@@ -72,16 +72,16 @@ def test_cart2pol_transform_on_kinematics(
 
     # Build expected data array
     expected_array_pol = xr.DataArray(
-        np.stack(expected_kinematics_polar, axis=1),
+        np.stack(expected_kinematics_polar, axis=-1),
         # Stack along the "individuals" axis
-        dims=["time", "individuals", "space"],
+        dims=["time", "space", "individuals"],
     )
     if "keypoints" in ds.position.coords:
         expected_array_pol = expected_array_pol.expand_dims(
             {"keypoints": ds.position.coords["keypoints"].size}
         )
     expected_array_pol = expected_array_pol.transpose(
-        "time", "individuals", "keypoints", "space"
+        "time", "space", "keypoints", "individuals"
     )
 
     # Compare the values of the kinematic_array against the expected_array
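
To see why the stacking axis moves to -1, here is a sketch with assumed shapes (ten frames, two polar components, two individuals); it is not the test's real data:

import numpy as np
import xarray as xr

# One (time, space) array of polar values per individual.
per_individual = [np.zeros((10, 2)), np.ones((10, 2))]
expected = xr.DataArray(
    np.stack(per_individual, axis=-1),  # (time, space, individuals)
    dims=["time", "space", "individuals"],
)
# Broadcast in a keypoints dim and reorder, mirroring what the test does.
expected = expected.expand_dims({"keypoints": 3}).transpose(
    "time", "space", "keypoints", "individuals"
)
print(expected.shape)  # (10, 2, 3, 2)
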

tests/test_unit/test_kinematics.py (12 changes: 6 additions & 6 deletions)

@@ -86,19 +86,19 @@ def test_kinematics_uniform_linear_motion(
     # and in the final xarray.DataArray
     expected_dims = ["time", "individuals"]
     if kinematic_variable in ["displacement", "velocity", "acceleration"]:
-        expected_dims.append("space")
+        expected_dims.insert(1, "space")
 
     # Build expected data array from the expected numpy array
     expected_array = xr.DataArray(
         # Stack along the "individuals" axis
-        np.stack(expected_kinematics, axis=1),
+        np.stack(expected_kinematics, axis=-1),
         dims=expected_dims,
     )
     if "keypoints" in position.coords:
         expected_array = expected_array.expand_dims(
             {"keypoints": position.coords["keypoints"].size}
         )
-        expected_dims.insert(2, "keypoints")
+        expected_dims.insert(-1, "keypoints")
     expected_array = expected_array.transpose(*expected_dims)
 
     # Compare the values of the kinematic_array against the expected_array
@@ -263,11 +263,11 @@ def test_path_length_across_time_ranges(
         num_segments -= 9 - np.floor(min(9, stop))
 
     expected_path_length = xr.DataArray(
-        np.ones((2, 3)) * np.sqrt(2) * num_segments,
-        dims=["individuals", "keypoints"],
+        np.ones((3, 2)) * np.sqrt(2) * num_segments,
+        dims=["keypoints", "individuals"],
         coords={
-            "individuals": position.coords["individuals"],
             "keypoints": position.coords["keypoints"],
+            "individuals": position.coords["individuals"],
         },
     )
     xr.testing.assert_allclose(path_length, expected_path_length)
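
The expected_dims bookkeeping above can be followed with plain lists (no test data involved):

expected_dims = ["time", "individuals"]
expected_dims.insert(1, "space")       # ["time", "space", "individuals"]
expected_dims.insert(-1, "keypoints")  # ["time", "space", "keypoints", "individuals"]
print(expected_dims)
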

tests/test_unit/test_load_poses.py (24 changes: 15 additions & 9 deletions)

@@ -76,16 +76,22 @@ def assert_dataset(
         assert var in dataset.data_vars
         assert isinstance(dataset[var], xr.DataArray)
     assert dataset.position.ndim == 4
-    assert dataset.confidence.shape == dataset.position.shape[:-1]
-    # Check the dims and coords
+    position_shape = dataset.position.shape
+    # Confidence has the same shape as position, except for the space dim
+    assert (
+        dataset.confidence.shape == position_shape[:1] + position_shape[2:]
+    )
+    # Check the dims
     DIM_NAMES = ValidPosesDataset.DIM_NAMES
     assert all([i in dataset.dims for i in DIM_NAMES])
-    for d, dim in enumerate(DIM_NAMES[1:]):
-        assert dataset.sizes[dim] == dataset.position.shape[d + 1]
-        assert all(
-            [isinstance(s, str) for s in dataset.coords[dim].values]
-        )
-    assert all([i in dataset.coords["space"] for i in ["x", "y"]])
+    expected_dim_length_dict = {
+        DIM_NAMES[idx]: position_shape[i]
+        for i, idx in enumerate([0, 3, 2, 1])
+    }
+    assert expected_dim_length_dict == dataset.sizes
+    # Check the coords
+    for dim in DIM_NAMES[1:]:
+        assert all(isinstance(s, str) for s in dataset.coords[dim].values)
+    assert all(coord in dataset.coords["space"] for coord in ["x", "y"])
     # Check the metadata attributes
     assert (
         dataset.source_file is None