From 734e3aa9b2500a9f7379f80bb8a626114bfdebbf Mon Sep 17 00:00:00 2001 From: Chris Broz Date: Mon, 15 Jul 2024 15:09:03 -0700 Subject: [PATCH] Bugfixes from 870 (#1034) * Bugfixes from 870 * Update changelog * Declare y_eq0 var. Use over --- CHANGELOG.md | 2 + src/spyglass/position/v1/dlc_utils.py | 54 ++++++++++++--------------- 2 files changed, 26 insertions(+), 30 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4f987bf95..7a003d5ad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,8 @@ PositionGroup.alter() - Replace `OutputLogger` context manager with decorator #870 - Rename `check_videofile` -> `find_mp4` and `get_video_path` -> `get_video_info` to reflect actual use #870 + - Fix `red_led_bisector` `np.nan` handling issue from #870. Fixed in #1034 + - Fix `one_pt_centoid` `np.nan` handling issue from #870. Fixed in #1034 - Spikesorting - Allow user to set smoothing timescale in `SortedSpikesGroup.get_firing_rate` #994 diff --git a/src/spyglass/position/v1/dlc_utils.py b/src/spyglass/position/v1/dlc_utils.py index 1523e01b4..f8a911148 100644 --- a/src/spyglass/position/v1/dlc_utils.py +++ b/src/spyglass/position/v1/dlc_utils.py @@ -714,30 +714,27 @@ def red_led_bisector_orientation(pos_df: pd.DataFrame, **params): LED2 = params.pop("led2", None) LED3 = params.pop("led3", None) - x_vec = pos_df[[LED1, LED2]].diff(axis=1).iloc[:, 0] - y_vec = pos_df[[LED1, LED2]].diff(axis=1).iloc[:, 1] - - y_is_zero = y_vec.eq(0) - perp_direction = pos_df[[LED3]].diff(axis=1) - - # Handling the special case where y_vec is zero all Ys are the same - special_case = ( - y_is_zero - & (pos_df[LED3]["y"] == pos_df[LED1]["y"]) - & (pos_df[LED3]["y"] == pos_df[LED2]["y"]) - ) - if special_case.any(): + orient = np.full(len(pos_df), np.nan) # Initialize with NaNs + x_vec = pos_df[LED1]["x"] - pos_df[LED2]["x"] + y_vec = pos_df[LED1]["y"] - pos_df[LED2]["y"] + y_eq0 = np.isclose(y_vec, 0) + + # when y_vec is zero, 1&2 are equal. Compare to 3, determine if up or down + orient[y_eq0 & pos_df[LED3]["y"].gt(pos_df[LED1]["y"])] = np.pi / 2 + orient[y_eq0 & pos_df[LED3]["y"].lt(pos_df[LED1]["y"])] = -np.pi / 2 + + # Handling error case where y_vec is zero and all Ys are the same + y_1, y_2, y_3 = pos_df[LED1]["y"], pos_df[LED2]["y"], pos_df[LED3]["y"] + if np.any(y_eq0 & np.isclose(y_1, y_2) & np.isclose(y_2, y_3)): raise Exception("Cannot determine head direction from bisector") - orientation = np.zeros(len(pos_df)) - orientation[y_is_zero & perp_direction.iloc[:, 0].gt(0)] = np.pi / 2 - orientation[y_is_zero & perp_direction.iloc[:, 0].lt(0)] = -np.pi / 2 - - orientation[~y_is_zero & ~x_vec.eq(0)] = np.arctan2( - y_vec[~y_is_zero], x_vec[~x_vec.eq(0)] - ) + # General case where y_vec is not zero. Use arctan2 to determine orientation + length = np.sqrt(x_vec**2 + y_vec**2) + norm_x = (-y_vec / length)[~y_eq0] + norm_y = (x_vec / length)[~y_eq0] + orient[~y_eq0] = np.arctan2(norm_y, norm_x) - return orientation + return orient # Add new functions for orientation calculation here @@ -834,7 +831,8 @@ def calc_centroid( if isinstance(mask, list): mask = [reduce(np.logical_and, m) for m in mask] - if points is not None: # Check that combinations of points close enough + # Check that combinations of points close enough + if points is not None and len(points) > 1: for pair in combinations(points, 2): mask = (*mask, ~self.too_sep(pair[0], pair[1])) @@ -846,10 +844,7 @@ def calc_centroid( if replace: self.centroid[mask] = np.nan return - if len(points) == 1: # only one point - self.centroid[mask] = self.coords[points[0]][mask] - return - elif len(points) == 3: + if len(points) == 3: self.coords["midpoint"] = ( self.coords[points[0]] + self.coords[points[1]] ) / 2 @@ -867,10 +862,9 @@ def too_sep(self, point1, point2): def get_1pt_centroid(self): """Passthrough. If point is NaN, then centroid is NaN.""" PT1 = self.points_dict.get("point1", None) - self.calc_centroid( - mask=(~self.nans[PT1],), - points=[PT1], - ) + mask = ~self.nans[PT1] # For good points, centroid is the point + self.centroid[mask] = self.coords[PT1][mask] + self.centroid[~mask] = np.nan # For bad points, centroid is NaN def get_2pt_centroid(self): self.calc_centroid( # Good points