From dfffae9a2db7a6f6668f5fab559af5a54e29b71b Mon Sep 17 00:00:00 2001 From: niksirbi Date: Thu, 24 Oct 2024 11:13:20 +0100 Subject: [PATCH] remove movement_dataset.py module --- movement/io/load_bboxes.py | 3 +- movement/io/load_poses.py | 3 +- movement/io/save_poses.py | 6 ++-- movement/movement_dataset.py | 35 ------------------- movement/validators/datasets.py | 9 ++++- tests/conftest.py | 8 ++--- tests/test_unit/test_load_bboxes.py | 4 +-- tests/test_unit/test_load_poses.py | 4 +-- .../test_datasets_validators.py | 2 +- 9 files changed, 22 insertions(+), 52 deletions(-) delete mode 100644 movement/movement_dataset.py diff --git a/movement/io/load_bboxes.py b/movement/io/load_bboxes.py index 5753a8daa..3e1b0e0db 100644 --- a/movement/io/load_bboxes.py +++ b/movement/io/load_bboxes.py @@ -11,7 +11,6 @@ import pandas as pd import xarray as xr -from movement.movement_dataset import BboxesDataset from movement.utils.logging import log_error from movement.validators.datasets import ValidBboxesDataset from movement.validators.files import ValidFile, ValidVIATracksCSV @@ -631,7 +630,7 @@ def _ds_from_valid_data(data: ValidBboxesDataset) -> xr.Dataset: # Convert data to an xarray.Dataset # with dimensions ('time', 'individuals', 'space') - DIM_NAMES = BboxesDataset.get_dim_names() + DIM_NAMES = ValidBboxesDataset.DIM_NAMES n_space = data.position_array.shape[-1] return xr.Dataset( data_vars={ diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py index e137b47ad..f425d8a15 100644 --- a/movement/io/load_poses.py +++ b/movement/io/load_poses.py @@ -11,7 +11,6 @@ from sleap_io.io.slp import read_labels from sleap_io.model.labels import Labels -from movement.movement_dataset import PosesDataset from movement.utils.logging import log_error, log_warning from movement.validators.datasets import ValidPosesDataset from movement.validators.files import ValidDeepLabCutCSV, ValidFile, ValidHDF5 @@ -654,7 +653,7 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: time_coords = time_coords / data.fps time_unit = "seconds" - DIM_NAMES = PosesDataset.get_dim_names() + DIM_NAMES = ValidPosesDataset.DIM_NAMES # Convert data to an xarray.Dataset return xr.Dataset( data_vars={ diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py index 57b2346d3..c47d28f12 100644 --- a/movement/io/save_poses.py +++ b/movement/io/save_poses.py @@ -9,8 +9,8 @@ import pandas as pd import xarray as xr -from movement.movement_dataset import PosesDataset from movement.utils.logging import log_error +from movement.validators.datasets import ValidPosesDataset from movement.validators.files import ValidFile logger = logging.getLogger(__name__) @@ -436,13 +436,13 @@ def _validate_dataset(ds: xr.Dataset) -> None: TypeError, f"Expected an xarray Dataset, but got {type(ds)}." ) - missing_vars = set(PosesDataset.get_var_names()) - set(ds.data_vars) + missing_vars = set(ValidPosesDataset.VAR_NAMES) - set(ds.data_vars) if missing_vars: raise ValueError( f"Missing required data variables: {sorted(missing_vars)}" ) # sort for a reproducible error message - missing_dims = set(PosesDataset.get_dim_names()) - set(ds.dims) + missing_dims = set(ValidPosesDataset.DIM_NAMES) - set(ds.dims) if missing_dims: raise ValueError( f"Missing required dimensions: {sorted(missing_dims)}" diff --git a/movement/movement_dataset.py b/movement/movement_dataset.py deleted file mode 100644 index e458cc6f1..000000000 --- a/movement/movement_dataset.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Define the canonical structure of Movement Datasets.""" - - -class MovementDataset: - """Base class to define the canonical structure of a Movement Dataset.""" - - # Base dimensions and variables common to all datasets - DIM_NAMES: tuple[str, ...] = ("time", "space") - VAR_NAMES: tuple[str, ...] = ("position", "confidence") - - @classmethod - def get_dim_names(cls): - """Get dimension names for the dataset.""" - return cls.DIM_NAMES - - @classmethod - def get_var_names(cls): - """Get variable names for the dataset.""" - return cls.VAR_NAMES - - -class PosesDataset(MovementDataset): - """Dataset class for pose data, extending MovementDataset.""" - - # Additional dimensions and variables specific to poses - DIM_NAMES: tuple[str, ...] = ("time", "individuals", "keypoints", "space") - VAR_NAMES: tuple[str, ...] = MovementDataset.VAR_NAMES - - -class BboxesDataset(MovementDataset): - """Dataset class for bounding boxes' data, extending MovementDataset.""" - - # Additional dimensions and variables specific to bounding boxes - DIM_NAMES: tuple[str, ...] = ("time", "individuals", "space") - VAR_NAMES: tuple[str, ...] = ("position", "shape", "confidence") diff --git a/movement/validators/datasets.py b/movement/validators/datasets.py index fd31246d3..99a68c102 100644 --- a/movement/validators/datasets.py +++ b/movement/validators/datasets.py @@ -1,7 +1,7 @@ """``attrs`` classes for validating data structures.""" from collections.abc import Iterable -from typing import Any +from typing import Any, ClassVar import attrs import numpy as np @@ -142,6 +142,10 @@ class ValidPosesDataset: validator=validators.optional(validators.instance_of(str)), ) + # Class variables + DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "keypoints", "space") + VAR_NAMES: ClassVar[tuple] = ("position", "confidence") + # Add validators @position_array.validator def _validate_position_array(self, attribute, value): @@ -293,6 +297,9 @@ class ValidBboxesDataset: validator=validators.optional(validators.instance_of(str)), ) + DIM_NAMES: ClassVar[tuple] = ("time", "individuals", "space") + VAR_NAMES: ClassVar[tuple] = ("position", "shape", "confidence") + # Validators @position_array.validator @shape_array.validator diff --git a/tests/conftest.py b/tests/conftest.py index 6843adee5..6da9a598f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,9 +11,9 @@ import pytest import xarray as xr -from movement.movement_dataset import BboxesDataset, PosesDataset from movement.sample_data import fetch_dataset_paths, list_datasets from movement.utils.logging import configure_logging +from movement.validators.datasets import ValidBboxesDataset, ValidPosesDataset def pytest_configure(): @@ -292,7 +292,7 @@ def valid_bboxes_dataset( """Return a valid bboxes dataset for two individuals moving in uniform linear motion, with 5 frames with low confidence values and time in frames. """ - dim_names = BboxesDataset.get_dim_names() + dim_names = ValidBboxesDataset.DIM_NAMES position_array = valid_bboxes_arrays["position"] shape_array = valid_bboxes_arrays["shape"] @@ -376,7 +376,7 @@ def _valid_position_array(array_type): @pytest.fixture def valid_poses_dataset(valid_position_array, request): """Return a valid pose tracks dataset.""" - dim_names = PosesDataset.get_dim_names() + dim_names = ValidPosesDataset.DIM_NAMES # create a multi_individual_array by default unless overridden via param try: array_format = request.param @@ -490,7 +490,7 @@ def valid_poses_dataset_uniform_linear_motion( """Return a valid poses dataset for two individuals moving in uniform linear motion, with 5 frames with low confidence values and time in frames. """ - dim_names = PosesDataset.get_dim_names() + dim_names = ValidPosesDataset.DIM_NAMES position_array = valid_poses_array_uniform_linear_motion["position"] confidence_array = valid_poses_array_uniform_linear_motion["confidence"] diff --git a/tests/test_unit/test_load_bboxes.py b/tests/test_unit/test_load_bboxes.py index 97fd6efa1..2f80459dc 100644 --- a/tests/test_unit/test_load_bboxes.py +++ b/tests/test_unit/test_load_bboxes.py @@ -9,7 +9,7 @@ import xarray as xr from movement.io import load_bboxes -from movement.movement_dataset import BboxesDataset +from movement.validators.datasets import ValidBboxesDataset @pytest.fixture() @@ -127,7 +127,7 @@ def assert_dataset( assert dataset.confidence.shape == dataset.position.shape[:-1] # Check the dims and coords - DIM_NAMES = BboxesDataset.get_dim_names() + DIM_NAMES = ValidBboxesDataset.DIM_NAMES assert all([i in dataset.dims for i in DIM_NAMES]) for d, dim in enumerate(DIM_NAMES[1:]): assert dataset.sizes[dim] == dataset.position.shape[d + 1] diff --git a/tests/test_unit/test_load_poses.py b/tests/test_unit/test_load_poses.py index e8d40c18d..77990a429 100644 --- a/tests/test_unit/test_load_poses.py +++ b/tests/test_unit/test_load_poses.py @@ -9,7 +9,7 @@ from sleap_io.model.labels import LabeledFrame, Labels from movement.io import load_poses -from movement.movement_dataset import PosesDataset +from movement.validators.datasets import ValidPosesDataset class TestLoadPoses: @@ -78,7 +78,7 @@ def assert_dataset( assert dataset.position.ndim == 4 assert dataset.confidence.shape == dataset.position.shape[:-1] # Check the dims and coords - DIM_NAMES = PosesDataset.get_dim_names() + DIM_NAMES = ValidPosesDataset.DIM_NAMES assert all([i in dataset.dims for i in DIM_NAMES]) for d, dim in enumerate(DIM_NAMES[1:]): assert dataset.sizes[dim] == dataset.position.shape[d + 1] diff --git a/tests/test_unit/test_validators/test_datasets_validators.py b/tests/test_unit/test_validators/test_datasets_validators.py index 493f1d460..e41331f77 100644 --- a/tests/test_unit/test_validators/test_datasets_validators.py +++ b/tests/test_unit/test_validators/test_datasets_validators.py @@ -352,7 +352,7 @@ def test_bboxes_dataset_validator_confidence_array( ( np.arange(10).reshape(-1, 2), pytest.raises(ValueError), - "Expected 'frame_array' to have shape (10, 1), " "but got (5, 2).", + "Expected 'frame_array' to have shape (10, 1), but got (5, 2).", ), # frame_array should be a column vector ( [1, 2, 3],