From 584bbf6d144cda8b21d459bc510282a8d48a074b Mon Sep 17 00:00:00 2001 From: Liezl Maree <38435167+roomrys@users.noreply.github.com> Date: Tue, 21 Jan 2025 09:58:59 -0800 Subject: [PATCH] Preserve datatype on triangulate and project Co-authored-by: Talmo Pereira --- sleap_io/model/camera.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sleap_io/model/camera.py b/sleap_io/model/camera.py index bfd9c14b..0a5fb1bf 100644 --- a/sleap_io/model/camera.py +++ b/sleap_io/model/camera.py @@ -140,6 +140,7 @@ def triangulate( n_points = points.shape[1] # Undistort points + points_dtype = points.dtype points = points.astype("float64") # Ensure float64 for opencv undistort for cam_idx, camera in enumerate(self.cameras): cam_points = camera.undistort_points(points[cam_idx]) @@ -159,7 +160,9 @@ def triangulate( f"received {n_points_returned} 3D points." ) - return points_3d.reshape(*points_shape[1:-1], 3) + # Reshape to (N, 3) and cast to the original dtype. + points_3d = points_3d.reshape(*points_shape[1:-1], 3).astype(points_dtype) + return points_3d def project(self, points: np.ndarray) -> np.ndarray: """Project 3D points to 2D using camera group. @@ -175,6 +178,7 @@ def project(self, points: np.ndarray) -> np.ndarray: # Validate points in points = points.astype(np.float64) points_shape = points.shape + points_dtype = points.dtype try: # Check if points are 3D if points_shape[-1] != 3: @@ -194,7 +198,7 @@ def project(self, points: np.ndarray) -> np.ndarray: cam_points = camera.project(points) projected_points[cam_idx] = cam_points.reshape(n_points, 2) - return projected_points.reshape(n_cameras, *points_shape[:-1], 2) + return projected_points.reshape(n_cameras, *points_shape[:-1], 2).astype(points_dtype) @classmethod def from_dict(cls, calibration_dict: dict) -> CameraGroup: