Skip to content

Commit

Permalink
Preserve datatype on triangulate and project
Browse files Browse the repository at this point in the history
Co-authored-by: Talmo Pereira <[email protected]>
  • Loading branch information
roomrys and talmo authored Jan 21, 2025
1 parent 8233ee9 commit 584bbf6
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sleap_io/model/camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def triangulate(
n_points = points.shape[1]

# Undistort points
points_dtype = points.dtype
points = points.astype("float64") # Ensure float64 for opencv undistort
for cam_idx, camera in enumerate(self.cameras):
cam_points = camera.undistort_points(points[cam_idx])
Expand All @@ -159,7 +160,9 @@ def triangulate(
f"received {n_points_returned} 3D points."
)

return points_3d.reshape(*points_shape[1:-1], 3)
# Reshape to (N, 3) and cast to the original dtype.
points_3d = points_3d.reshape(*points_shape[1:-1], 3).astype(points_dtype)
return points_3d

def project(self, points: np.ndarray) -> np.ndarray:
"""Project 3D points to 2D using camera group.
Expand All @@ -175,6 +178,7 @@ def project(self, points: np.ndarray) -> np.ndarray:
# Validate points in
points = points.astype(np.float64)
points_shape = points.shape
points_dtype = points.dtype
try:
# Check if points are 3D
if points_shape[-1] != 3:
Expand All @@ -194,7 +198,7 @@ def project(self, points: np.ndarray) -> np.ndarray:
cam_points = camera.project(points)
projected_points[cam_idx] = cam_points.reshape(n_points, 2)

return projected_points.reshape(n_cameras, *points_shape[:-1], 2)
return projected_points.reshape(n_cameras, *points_shape[:-1], 2).astype(points_dtype)

@classmethod
def from_dict(cls, calibration_dict: dict) -> CameraGroup:
Expand Down

0 comments on commit 584bbf6

Please sign in to comment.