Skip to content

Commit

Permalink
remove movement_dataset.py module
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Oct 24, 2024
1 parent f7a4a4b commit dfffae9
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 52 deletions.
3 changes: 1 addition & 2 deletions movement/io/load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pandas as pd
import xarray as xr

from movement.movement_dataset import BboxesDataset
from movement.utils.logging import log_error
from movement.validators.datasets import ValidBboxesDataset
from movement.validators.files import ValidFile, ValidVIATracksCSV
Expand Down Expand Up @@ -631,7 +630,7 @@ def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset:

# Convert data to an xarray.Dataset
# with dimensions ('time', 'individuals', 'space')
DIM_NAMES = BboxesDataset.get_dim_names()
DIM_NAMES = ValidBboxesDataset.DIM_NAMES
n_space = data.position_array.shape[-1]
return xr.Dataset(
data_vars={
Expand Down
3 changes: 1 addition & 2 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from sleap_io.io.slp import read_labels
from sleap_io.model.labels import Labels

from movement.movement_dataset import PosesDataset
from movement.utils.logging import log_error, log_warning
from movement.validators.datasets import ValidPosesDataset
from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5
Expand Down Expand Up @@ -654,7 +653,7 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset:
time_coords = time_coords / data.fps
time_unit = "seconds"

DIM_NAMES = PosesDataset.get_dim_names()
DIM_NAMES = ValidPosesDataset.DIM_NAMES
# Convert data to an xarray.Dataset
return xr.Dataset(
data_vars={
Expand Down
6 changes: 3 additions & 3 deletions movement/io/save_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import pandas as pd
import xarray as xr

from movement.movement_dataset import PosesDataset
from movement.utils.logging import log_error
from movement.validators.datasets import ValidPosesDataset
from movement.validators.files import ValidFile

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -436,13 +436,13 @@ def _validate_dataset(ds: xr.Dataset) -> None:
TypeError, f"Expected an xarray Dataset, but got {type(ds)}."
)

missing_vars = set(PosesDataset.get_var_names()) - set(ds.data_vars)
missing_vars = set(ValidPosesDataset.VAR_NAMES) - set(ds.data_vars)
if missing_vars:
raise ValueError(
f"Missing required data variables: {sorted(missing_vars)}"
) # sort for a reproducible error message

missing_dims = set(PosesDataset.get_dim_names()) - set(ds.dims)
missing_dims = set(ValidPosesDataset.DIM_NAMES) - set(ds.dims)
if missing_dims:
raise ValueError(
f"Missing required dimensions: {sorted(missing_dims)}"
Expand Down
35 changes: 0 additions & 35 deletions movement/movement_dataset.py

This file was deleted.

9 changes: 8 additions & 1 deletion movement/validators/datasets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""``attrs`` classes for validating data structures."""

from collections.abc import Iterable
from typing import Any
from typing import Any, ClassVar

import attrs
import numpy as np
Expand Down Expand Up @@ -142,6 +142,10 @@ class ValidPosesDataset:
validator=validators.optional(validators.instance_of(str)),
)

# Class variables
DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "keypoints", "space")
VAR_NAMES: ClassVar[tuple] = ("position", "confidence")

# Add validators
@position_array.validator
def _validate_position_array(self, attribute, value):
Expand Down Expand Up @@ -293,6 +297,9 @@ class ValidBboxesDataset:
validator=validators.optional(validators.instance_of(str)),
)

DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "space")
VAR_NAMES: ClassVar[tuple] = ("position", "shape", "confidence")

# Validators
@position_array.validator
@shape_array.validator
Expand Down
8 changes: 4 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
import pytest
import xarray as xr

from movement.movement_dataset import BboxesDataset, PosesDataset
from movement.sample_data import fetch_dataset_paths, list_datasets
from movement.utils.logging import configure_logging
from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset


def pytest_configure():
Expand Down Expand Up @@ -292,7 +292,7 @@ def valid_bboxes_dataset(
"""Return a valid bboxes dataset for two individuals moving in uniform
linear motion, with 5 frames with low confidence values and time in frames.
"""
dim_names = BboxesDataset.get_dim_names()
dim_names = ValidBboxesDataset.DIM_NAMES

position_array = valid_bboxes_arrays["position"]
shape_array = valid_bboxes_arrays["shape"]
Expand Down Expand Up @@ -376,7 +376,7 @@ def _valid_position_array(array_type):
@pytest.fixture
def valid_poses_dataset(valid_position_array, request):
"""Return a valid pose tracks dataset."""
dim_names = PosesDataset.get_dim_names()
dim_names = ValidPosesDataset.DIM_NAMES
# create a multi_individual_array by default unless overridden via param
try:
array_format = request.param
Expand Down Expand Up @@ -490,7 +490,7 @@ def valid_poses_dataset_uniform_linear_motion(
"""Return a valid poses dataset for two individuals moving in uniform
linear motion, with 5 frames with low confidence values and time in frames.
"""
dim_names = PosesDataset.get_dim_names()
dim_names = ValidPosesDataset.DIM_NAMES

position_array = valid_poses_array_uniform_linear_motion["position"]
confidence_array = valid_poses_array_uniform_linear_motion["confidence"]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_unit/test_load_bboxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import xarray as xr

from movement.io import load_bboxes
from movement.movement_dataset import BboxesDataset
from movement.validators.datasets import ValidBboxesDataset


@pytest.fixture()
Expand Down Expand Up @@ -127,7 +127,7 @@ def assert_dataset(
assert dataset.confidence.shape == dataset.position.shape[:-1]

# Check the dims and coords
DIM_NAMES = BboxesDataset.get_dim_names()
DIM_NAMES = ValidBboxesDataset.DIM_NAMES
assert all([i in dataset.dims for i in DIM_NAMES])
for d, dim in enumerate(DIM_NAMES[1:]):
assert dataset.sizes[dim] == dataset.position.shape[d + 1]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from sleap_io.model.labels import LabeledFrame, Labels

from movement.io import load_poses
from movement.movement_dataset import PosesDataset
from movement.validators.datasets import ValidPosesDataset


class TestLoadPoses:
Expand Down Expand Up @@ -78,7 +78,7 @@ def assert_dataset(
assert dataset.position.ndim == 4
assert dataset.confidence.shape == dataset.position.shape[:-1]
# Check the dims and coords
DIM_NAMES = PosesDataset.get_dim_names()
DIM_NAMES = ValidPosesDataset.DIM_NAMES
assert all([i in dataset.dims for i in DIM_NAMES])
for d, dim in enumerate(DIM_NAMES[1:]):
assert dataset.sizes[dim] == dataset.position.shape[d + 1]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def test_bboxes_dataset_validator_confidence_array(
(
np.arange(10).reshape(-1, 2),
pytest.raises(ValueError),
"Expected 'frame_array' to have shape (10, 1), " "but got (5, 2).",
"Expected 'frame_array' to have shape (10, 1), but got (5, 2).",
), # frame_array should be a column vector
(
[1, 2, 3],
Expand Down

0 comments on commit dfffae9

Please sign in to comment.