From 8bee9b950afbd943a3dcb20a32f2685ddfe560c7 Mon Sep 17 00:00:00 2001 From: niksirbi Date: Mon, 4 Nov 2024 10:24:02 +0100 Subject: [PATCH] replace drop policy with ffill --- movement/analysis/kinematics.py | 83 +++++++----------------------- tests/test_unit/test_kinematics.py | 6 +-- 2 files changed, 23 insertions(+), 66 deletions(-) diff --git a/movement/analysis/kinematics.py b/movement/analysis/kinematics.py index 8f20748f..0bd79b8b 100644 --- a/movement/analysis/kinematics.py +++ b/movement/analysis/kinematics.py @@ -200,7 +200,7 @@ def compute_path_length( data: xr.DataArray, start: float | None = None, stop: float | None = None, - nan_policy: Literal["drop", "scale"] = "drop", + nan_policy: Literal["ffill", "scale"] = "ffill", nan_warn_threshold: float = 0.2, ) -> xr.DataArray: """Compute the length of a path travelled between two time points. @@ -222,10 +222,10 @@ def compute_path_length( stop : float, optional The end time of the path. If None (default), the maximum time coordinate in the data is used. - nan_policy : Literal["drop", "scale"], optional - Policy to handle NaN (missing) values. Can be one of the ``"drop"`` - or ``"scale"``. Defaults to ``"drop"``. See Notes for more details - on the two policies. + nan_policy : Literal["ffill", "scale"], optional + Policy to handle NaN (missing) values. Can be one of the ``"ffill"`` + or ``"scale"``. Defaults to ``"ffill"`` (forward fill). + See Notes for more details on the two policies. nan_warn_threshold : float, optional If more than this proportion of values are missing in any point track, a warning will be emitted. Defaults to 0.2 (20%). @@ -239,13 +239,13 @@ def compute_path_length( Notes ----- - Choosing ``nan_policy="drop"`` will drop NaN values from each point track - before computing path length. This equates to assuming that a track - follows a straight line between two valid points adjacent to a missing - segment. Missing segments at the beginning or end of the specified - time range are not counted. This approach tends to underestimate - the path length, and the error increases with the number of missing - values. + Choosing ``nan_policy="ffill"`` will use :meth:`xarray.DataArray.ffill` + to forward-fill missing segments (NaN values) across time. + This equates to assuming that a track remains stationary for + the duration of the missing segment and then instantaneously moves to + the next valid position, following a straight line. This approach tends + to underestimate the path length, and the error increases with the number + of missing values. Choosing ``nan_policy="scale"`` will adjust the path length based on the the proportion of valid segments per point track. For example, if only @@ -268,8 +268,12 @@ def compute_path_length( _warn_about_nan_proportion(data, nan_warn_threshold) - if nan_policy == "drop": - return _compute_path_length_drop_nan(data) + if nan_policy == "ffill": + return compute_norm( + compute_displacement(data.ffill(dim="time")).isel( + time=slice(1, None) + ) # skip first displacement (always 0) + ).sum(dim="time", min_count=1) # return NaN if no valid segment elif nan_policy == "scale": return _compute_scaled_path_length(data) @@ -277,7 +281,7 @@ def compute_path_length( raise log_error( ValueError, f"Invalid value for nan_policy: {nan_policy}. " - "Must be one of 'drop' or 'scale'.", + "Must be one of 'ffill' or 'scale'.", ) @@ -488,7 +492,7 @@ def _warn_about_nan_proportion( log_warning( "The result may be unreliable for point tracks with many " "missing values. The following tracks have more than " - f"{nan_warn_threshold * 100:.3} %) NaN values:", + f"{nan_warn_threshold * 100:.3} % NaN values:", ) print(report_nan_values(data_to_warn_about)) @@ -525,50 +529,3 @@ def _compute_scaled_path_length( valid_proportion = valid_segments / (data.sizes["time"] - 1) # return scaled path length return compute_norm(displacement).sum(dim="time") / valid_proportion - - -def _compute_path_length_drop_nan( - data: xr.DataArray, -) -> xr.DataArray: - """Compute path length by dropping NaN values before computation. - - This function iterates over point tracks, drops NaN values from each - track, and then computes the path length for the remaining valid - segments (takes the sum of the norms of the displacement vectors). - If there is no valid segment in a track, the path length for that - track will be NaN. - - Parameters - ---------- - data : xarray.DataArray - The input data containing position information, with ``time`` - and ``space`` (in Cartesian coordinates) as required dimensions. - - Returns - ------- - xarray.DataArray - An xarray DataArray containing the computed path length, - with dimensions matching those of the input data, - except ``time`` and ``space`` are removed. - - """ - # Create array for holding results - path_length = xr.full_like( - data.isel(time=0, space=0, drop=True), fill_value=np.nan - ) - - # Stack data to iterate over point tracks - dims_to_stack = [d for d in data.dims if d not in ["time", "space"]] - stacked_data = data.stack(tracks=dims_to_stack) - for track_name in stacked_data.tracks.values: - # Drop NaN values from current point track - track_data = stacked_data.sel(tracks=track_name, drop=True).dropna( - dim="time", how="any" - ) - # Compute path length for current point track - # and store it in the result array - target_loc = {k: v for k, v in zip(dims_to_stack, track_name)} - path_length.loc[target_loc] = compute_norm( - compute_displacement(track_data) - ).sum(dim="time", min_count=1) # returns NaN if no valid segment - return path_length diff --git a/tests/test_unit/test_kinematics.py b/tests/test_unit/test_kinematics.py index 9c043379..4d422d75 100644 --- a/tests/test_unit/test_kinematics.py +++ b/tests/test_unit/test_kinematics.py @@ -277,7 +277,7 @@ def test_path_length_across_time_ranges( "nan_policy, expected_path_lengths_id_1, expected_exception", [ ( - "drop", + "ffill", np.array([np.sqrt(2) * 8, np.sqrt(2) * 9, np.nan]), does_not_raise(), ), @@ -316,8 +316,8 @@ def test_path_length_with_nans( Because the underlying motion is uniform linear, the "scale" policy should perfectly restore the path length for individual "id_1" to its true value. - The "drop" policy should do likewise if frames are missing in the middle, - but will not count any missing frames at the edges. + The "ffill" policy should do likewise if frames are missing in the middle, + but will not "correct" for missing values at the edges. """ position = valid_poses_dataset_uniform_linear_motion_with_nans.position with expected_exception: