diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py index c15a169..88ad97c 100644 --- a/sleap_roots/convhull.py +++ b/sleap_roots/convhull.py @@ -4,7 +4,10 @@ from scipy.spatial import ConvexHull from scipy.spatial.distance import pdist from typing import Tuple, Optional, Union -from sleap_roots.points import get_line_equation_from_points +from sleap_roots.points import ( + extract_points_from_geometry, + get_line_equation_from_points, +) from shapely import box, LineString, normalize, Polygon @@ -382,13 +385,9 @@ def get_chull_areas_via_intersection( # Find the intersection between the hull perimeter and the extended line intersection = extended_line.intersection(hull_perimeter) - # Add intersection points to both lists + # Compute the intersection points and add to lists if not intersection.is_empty: - intersect_points = ( - np.array([[point.x, point.y] for point in intersection.geoms]) - if intersection.geom_type == "MultiPoint" - else np.array([[intersection.x, intersection.y]]) - ) + intersect_points = extract_points_from_geometry(intersection) above_line.extend(intersect_points) below_line.extend(intersect_points) @@ -452,6 +451,12 @@ def get_chull_intersection_vectors( Raises: ValueError: If pts does not have the expected shape. """ + if r0_pts.ndim == 1 or rn_pts.ndim == 1 or pts.ndim == 2: + print( + "Not enough instances or incorrect format to compute convex hull intersections." + ) + return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]])) + # Check for valid pts input if not isinstance(pts, np.ndarray) or pts.ndim != 3 or pts.shape[-1] != 2: raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).") @@ -460,7 +465,7 @@ def get_chull_intersection_vectors( raise ValueError("rn_pts must be a numpy array of shape (instances, 2).") # Ensure r0_pts is a numpy array of shape (instances, 2) if not isinstance(r0_pts, np.ndarray) or r0_pts.ndim != 2 or r0_pts.shape[-1] != 2: - raise ValueError("r0_pts must be a numpy array of shape (instances, 2).") + raise ValueError(f"r0_pts must be a numpy array of shape (instances, 2).") # Flatten pts to 2D array and remove NaN values flattened_pts = pts.reshape(-1, 2) @@ -481,6 +486,9 @@ def get_chull_intersection_vectors( # Ensuring r0_pts does not contain NaN values r0_pts_valid = r0_pts[~np.isnan(r0_pts).any(axis=1)] + # Expect two vectors in the end + if len(r0_pts_valid) < 2: + return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]])) # Get the vertices of the convex hull hull_vertices = hull.points[hull.vertices] diff --git a/sleap_roots/points.py b/sleap_roots/points.py index 479f564..6d5c5c1 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -3,11 +3,49 @@ import numpy as np from matplotlib import pyplot as plt from matplotlib.lines import Line2D -from shapely.geometry import LineString +from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection from shapely.ops import nearest_points from typing import List, Optional, Tuple +def extract_points_from_geometry(geometry): + """Extracts coordinates as a list of numpy arrays from any given Shapely geometry object. + + This function supports Point, MultiPoint, LineString, and GeometryCollection types. + It recursively extracts coordinates from complex geometries and aggregates them into a single list. + For unsupported geometry types, it returns an empty list. + + Args: + geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points. + + Returns: + List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. + The list will be empty if the geometry type is unsupported or contains no coordinates. + + Example: + >>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection + >>> point = Point(1, 2) + >>> multipoint = MultiPoint([(1, 2), (3, 4)]) + >>> linestring = LineString([(0, 0), (1, 1), (2, 2)]) + >>> geom_col = GeometryCollection([point, multipoint, linestring]) + >>> extract_points_from_geometry(geom_col) + [array([1, 2]), array([1, 2]), array([3, 4]), array([0, 0]), array([1, 1]), array([2, 2])] + """ + if isinstance(geometry, Point): + return [np.array([geometry.x, geometry.y])] + elif isinstance(geometry, MultiPoint): + return [np.array([point.x, point.y]) for point in geometry.geoms] + elif isinstance(geometry, LineString): + return [np.array([x, y]) for x, y in zip(*geometry.xy)] + elif isinstance(geometry, GeometryCollection): + points = [] + for geom in geometry.geoms: + points.extend(extract_points_from_geometry(geom)) + return points + else: + raise TypeError(f"Unsupported geometry type: {type(geometry).__name__}") + + def get_count(pts: np.ndarray): """Get number of roots. diff --git a/tests/test_convhull.py b/tests/test_convhull.py index a33c279..f506312 100644 --- a/tests/test_convhull.py +++ b/tests/test_convhull.py @@ -314,7 +314,6 @@ def test_basic_functionality(pts_shape_3_6_2): @pytest.mark.parametrize( "invalid_input", [ - (np.array([1, 2]), np.array([3, 4]), np.array([[[1, 2], [3, 4]]]), None), (np.array([[1, 2, 3]]), np.array([[3, 4]]), np.array([[[1, 2], [3, 4]]]), None), # Add more invalid inputs as needed ], diff --git a/tests/test_points.py b/tests/test_points.py index 6ac3a1d..54c37d6 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -1,9 +1,14 @@ import numpy as np import pytest -from shapely.geometry import LineString +from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection from sleap_roots import Series from sleap_roots.lengths import get_max_length_pts -from sleap_roots.points import filter_plants_with_unexpected_ct, get_count, join_pts +from sleap_roots.points import ( + extract_points_from_geometry, + filter_plants_with_unexpected_ct, + get_count, + join_pts, +) from sleap_roots.points import ( get_all_pts_array, get_nodes, @@ -738,3 +743,55 @@ def test_filter_plants_with_unexpected_ct_incorrect_input_types(): expected_count = "not a float" with pytest.raises(ValueError): filter_plants_with_unexpected_ct(primary_pts, lateral_pts, expected_count) + + +def test_extract_from_point(): + point = Point(1, 2) + expected = [np.array([1, 2])] + assert np.array_equal(extract_points_from_geometry(point), expected) + + +def test_extract_from_multipoint(): + multipoint = MultiPoint([(1, 2), (3, 4)]) + expected = [np.array([1, 2]), np.array([3, 4])] + results = extract_points_from_geometry(multipoint) + assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + + +def test_extract_from_linestring(): + linestring = LineString([(0, 0), (1, 1), (2, 2)]) + expected = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])] + results = extract_points_from_geometry(linestring) + assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + + +def test_extract_from_geometrycollection(): + geom_collection = GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)])]) + expected = [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])] + results = extract_points_from_geometry(geom_collection) + assert all(np.array_equal(result, exp) for result, exp in zip(results, expected)) + + +def test_extract_from_empty_multipoint(): + empty_multipoint = MultiPoint() + expected = [] + assert extract_points_from_geometry(empty_multipoint) == expected + + +def test_extract_from_empty_linestring(): + empty_linestring = LineString() + expected = [] + assert extract_points_from_geometry(empty_linestring) == expected + + +def test_extract_from_unsupported_type(): + with pytest.raises(NameError): + extract_points_from_geometry( + Polygon([(0, 0), (1, 1), (1, 0)]) + ) # Polygon is unsupported + + +def test_extract_from_empty_geometrycollection(): + empty_geom_collection = GeometryCollection() + expected = [] + assert extract_points_from_geometry(empty_geom_collection) == expected