Skip to content

Commit

Permalink
replace drop policy with ffill
Browse files Browse the repository at this point in the history
  • Loading branch information
niksirbi committed Nov 4, 2024
1 parent 0ef2d6b commit edf4352
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 66 deletions.
83 changes: 20 additions & 63 deletions movement/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,7 @@ def compute_path_length(
data: xr.DataArray,
start: float | None = None,
stop: float | None = None,
nan_policy: Literal["drop", "scale"] = "drop",
nan_policy: Literal["ffill", "scale"] = "ffill",
nan_warn_threshold: float = 0.2,
) -> xr.DataArray:
"""Compute the length of a path travelled between two time points.
Expand All @@ -728,10 +728,10 @@ def compute_path_length(
stop : float, optional
The end time of the path. If None (default),
the maximum time coordinate in the data is used.
nan_policy : Literal["drop", "scale"], optional
Policy to handle NaN (missing) values. Can be one of the ``"drop"``
or ``"scale"``. Defaults to ``"drop"``. See Notes for more details
on the two policies.
nan_policy : Literal["ffill", "scale"], optional
Policy to handle NaN (missing) values. Can be one of the ``"ffill"``
or ``"scale"``. Defaults to ``"ffill"`` (forward fill).
See Notes for more details on the two policies.
nan_warn_threshold : float, optional
If more than this proportion of values are missing in any point track,
a warning will be emitted. Defaults to 0.2 (20%).
Expand All @@ -745,13 +745,13 @@ def compute_path_length(
Notes
-----
Choosing ``nan_policy="drop"`` will drop NaN values from each point track
before computing path length. This equates to assuming that a track
follows a straight line between two valid points adjacent to a missing
segment. Missing segments at the beginning or end of the specified
time range are not counted. This approach tends to underestimate
the path length, and the error increases with the number of missing
values.
Choosing ``nan_policy="ffill"`` will use :meth:`xarray.DataArray.ffill`
to forward-fill missing segments (NaN values) across time.
This equates to assuming that a track remains stationary for
the duration of the missing segment and then instantaneously moves to
the next valid position, following a straight line. This approach tends
to underestimate the path length, and the error increases with the number
of missing values.
Choosing ``nan_policy="scale"`` will adjust the path length based on the
the proportion of valid segments per point track. For example, if only
Expand All @@ -774,16 +774,20 @@ def compute_path_length(

_warn_about_nan_proportion(data, nan_warn_threshold)

if nan_policy == "drop":
return _compute_path_length_drop_nan(data)
if nan_policy == "ffill":
return compute_norm(
compute_displacement(data.ffill(dim="time")).isel(
time=slice(1, None)
) # skip first displacement (always 0)
).sum(dim="time", min_count=1) # return NaN if no valid segment

elif nan_policy == "scale":
return _compute_scaled_path_length(data)
else:
raise log_error(
ValueError,
f"Invalid value for nan_policy: {nan_policy}. "
"Must be one of 'drop' or 'scale'.",
"Must be one of 'ffill' or 'scale'.",
)


Expand Down Expand Up @@ -820,7 +824,7 @@ def _warn_about_nan_proportion(
log_warning(
"The result may be unreliable for point tracks with many "
"missing values. The following tracks have more than "
f"{nan_warn_threshold * 100:.3} %) NaN values:",
f"{nan_warn_threshold * 100:.3} % NaN values:",
)
print(report_nan_values(data_to_warn_about))

Expand Down Expand Up @@ -857,50 +861,3 @@ def _compute_scaled_path_length(
valid_proportion = valid_segments / (data.sizes["time"] - 1)
# return scaled path length
return compute_norm(displacement).sum(dim="time") / valid_proportion


def _compute_path_length_drop_nan(
data: xr.DataArray,
) -> xr.DataArray:
"""Compute path length by dropping NaN values before computation.
This function iterates over point tracks, drops NaN values from each
track, and then computes the path length for the remaining valid
segments (takes the sum of the norms of the displacement vectors).
If there is no valid segment in a track, the path length for that
track will be NaN.
Parameters
----------
data : xarray.DataArray
The input data containing position information, with ``time``
and ``space`` (in Cartesian coordinates) as required dimensions.
Returns
-------
xarray.DataArray
An xarray DataArray containing the computed path length,
with dimensions matching those of the input data,
except ``time`` and ``space`` are removed.
"""
# Create array for holding results
path_length = xr.full_like(
data.isel(time=0, space=0, drop=True), fill_value=np.nan
)

# Stack data to iterate over point tracks
dims_to_stack = [d for d in data.dims if d not in ["time", "space"]]
stacked_data = data.stack(tracks=dims_to_stack)
for track_name in stacked_data.tracks.values:
# Drop NaN values from current point track
track_data = stacked_data.sel(tracks=track_name, drop=True).dropna(
dim="time", how="any"
)
# Compute path length for current point track
# and store it in the result array
target_loc = {k: v for k, v in zip(dims_to_stack, track_name)}
path_length.loc[target_loc] = compute_norm(
compute_displacement(track_data)
).sum(dim="time", min_count=1) # returns NaN if no valid segment
return path_length
6 changes: 3 additions & 3 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_path_length_across_time_ranges(
"nan_policy, expected_path_lengths_id_1, expected_exception",
[
(
"drop",
"ffill",
np.array([np.sqrt(2) * 8, np.sqrt(2) * 9, np.nan]),
does_not_raise(),
),
Expand Down Expand Up @@ -316,8 +316,8 @@ def test_path_length_with_nans(
Because the underlying motion is uniform linear, the "scale" policy should
perfectly restore the path length for individual "id_1" to its true value.
The "drop" policy should do likewise if frames are missing in the middle,
but will not count any missing frames at the edges.
The "ffill" policy should do likewise if frames are missing in the middle,
but will not "correct" for missing values at the edges.
"""
position = valid_poses_dataset_uniform_linear_motion_with_nans.position
with expected_exception:
Expand Down

0 comments on commit edf4352

Please sign in to comment.