diff --git a/sleap_roots/__init__.py b/sleap_roots/__init__.py index fe5c29e..aa3a417 100644 --- a/sleap_roots/__init__.py +++ b/sleap_roots/__init__.py @@ -12,9 +12,14 @@ import sleap_roots.series import sleap_roots.summary import sleap_roots.trait_pipelines -from sleap_roots.trait_pipelines import DicotPipeline, TraitDef, YoungerMonocotPipeline +from sleap_roots.trait_pipelines import ( + DicotPipeline, + TraitDef, + YoungerMonocotPipeline, + OlderMonocotPipeline, +) from sleap_roots.series import Series, find_all_series # Define package version. # This is read dynamically by setuptools in pyproject.toml to determine the release version. -__version__ = "0.0.5" +__version__ = "0.0.6" diff --git a/sleap_roots/angle.py b/sleap_roots/angle.py index 12c470b..5455307 100644 --- a/sleap_roots/angle.py +++ b/sleap_roots/angle.py @@ -74,7 +74,8 @@ def get_root_angle( pts = np.expand_dims(pts, axis=0) angs_root = [] - for i in range(len(node_ind)): + # Calculate the angle for each instance + for i in range(pts.shape[0]): # if the node_ind is 0, do NOT calculate angs if node_ind[i] == 0: angs = np.nan @@ -90,3 +91,30 @@ def get_root_angle( if angs_root.shape[0] == 1: return angs_root[0] return angs_root + + +def get_vector_angles_from_gravity(vectors: np.ndarray) -> np.ndarray: + """Calculate the angle of given vectors from the gravity vector. + + Args: + vectors: An array of vectorss with shape (instances, 2), each representing a vector + from start to end in an instance. + + Returns: + An array of angles in degrees with shape (instances,), representing the angle + between each vector and the downward-pointing gravity vector. 
+ """ + gravity_vector = np.array([0, 1]) # Downwards along the positive y-axis + # Calculate the angle between the vectors and the gravity vectors + angles = np.arctan2(vectors[:, 1], vectors[:, 0]) - np.arctan2( + gravity_vector[1], gravity_vector[0] + ) + angles = np.degrees(angles) + # Normalize angles to the range [0, 180] since direction doesn't matter + angles = np.abs(angles) + angles[angles > 180] = 360 - angles[angles > 180] + + # If only one root, return a scalar instead of a single-element array + if angles.shape[0] == 1: + return angles[0] + return angles diff --git a/sleap_roots/bases.py b/sleap_roots/bases.py index 9861efe..790092e 100644 --- a/sleap_roots/bases.py +++ b/sleap_roots/bases.py @@ -1,4 +1,4 @@ -"""Trait calculations that rely on bases (i.e., dicot-only).""" +"""Trait calculations that rely on bases.""" import numpy as np from shapely.geometry import LineString, Point @@ -7,20 +7,16 @@ from typing import Union, Tuple -def get_bases(pts: np.ndarray, monocots: bool = False) -> np.ndarray: +def get_bases(pts: np.ndarray) -> np.ndarray: """Return bases (r1) from each root. Args: pts: Root landmarks as array of shape `(instances, nodes, 2)` or `(nodes, 2)`. - monocots: Boolean value, where false is dicot (default), true is rice. Returns: Array of bases `(instances, (x, y))`. If the input is `(nodes, 2)`, an array of - shape `(2,)` will be returned. + shape `(2,)` will be returned. """ - if monocots: - return np.nan - # If the input has shape `(nodes, 2)`, reshape it for consistency if pts.ndim == 2: pts = pts[np.newaxis, ...] @@ -73,57 +69,33 @@ def get_base_tip_dist( return distances -def get_lateral_count(pts: np.ndarray): - """Get number of lateral roots. - - Args: - pts: lateral root landmarks as array of shape `(instance, node, 2)`. - - Return: - Scalar of number of lateral roots. 
- """ - lateral_count = pts.shape[0] - return lateral_count - - def get_base_xs(base_pts: np.ndarray) -> np.ndarray: """Get x coordinates of the base of each lateral root. Args: - base_pts: Array of bases as returned by `get_bases`, shape `(instances, 2)` or - `(2,)`. + base_pts: root bases as array of shape `(instances, 2)` or `(2)` when there is + only one root, as is the case for primary roots. - Returns: - An array of the x-coordinates of bases `(instances,)` or a single x-coordinate. + Return: + An array of base x-coordinates (instances,) or (1,) when there is only one root. """ - # If the input is a single number (float or integer), return np.nan - if isinstance(base_pts, (np.floating, float, np.integer, int)): - return np.nan - - # If the base points array has shape `(2,)`, return the first element (x) - if base_pts.ndim == 1 and base_pts.shape[0] == 2: - return base_pts[0] - - # If the base points array doesn't have exactly 2 dimensions or - # the second dimension is not of size 2, raise an error - elif base_pts.ndim != 2 or base_pts.shape[1] != 2: + if base_pts.ndim not in (1, 2): raise ValueError( - "Array of base points must be 2-dimensional with shape (instances, 2)." + "Input array must be 2-dimensional (instances, 2) or 1-dimensional (2,)." ) + if base_pts.shape[-1] != 2: + raise ValueError("Last dimension must be (x, y).") - # If everything is fine, extract and return the x-coordinates of the base points - else: - base_xs = base_pts[:, 0] - return base_xs + base_xs = base_pts[..., 0] + return base_xs -def get_base_ys(base_pts: np.ndarray, monocots: bool = False) -> np.ndarray: +def get_base_ys(base_pts: np.ndarray) -> np.ndarray: """Get y coordinates of the base of each root. Args: base_pts: root bases as array of shape `(instances, 2)` or `(2)` when there is only one root, as is the case for primary roots. - monocots: Boolean value, where false is dicot (default), true is rice. Return: An array of the y-coordinates of bases (instances,). 
@@ -144,25 +116,19 @@ def get_base_ys(base_pts: np.ndarray, monocots: bool = False) -> np.ndarray: return base_ys -def get_base_length(lateral_base_ys: np.ndarray, monocots: bool = False) -> float: +def get_base_length(lateral_base_ys: np.ndarray) -> float: """Get the y-axis difference from the top lateral base to the bottom lateral base. Args: lateral_base_ys: y-coordinates of the base points of lateral roots of shape `(instances,)`. - monocots: Boolean value, where false is dicot (default), true is rice. Return: The distance between the top base y-coordinate and the deepest base y-coordinate. """ - # If the roots are monocots, return NaN - if monocots: - return np.nan - # Compute the difference between the maximum and minimum y-coordinates base_length = np.nanmax(lateral_base_ys) - np.nanmin(lateral_base_ys) - return base_length @@ -179,7 +145,7 @@ def get_base_ct_density( Return: Scalar of base count density. """ - # Check if the input is invalid + # Check if the input is valid for lateral_base_pts if ( isinstance(lateral_base_pts, (np.floating, float, np.integer, int)) or np.isnan(lateral_base_pts).all() @@ -204,22 +170,19 @@ def get_base_ct_density( return base_ct_density -def get_base_length_ratio( - primary_length: float, base_length: float, monocots: bool = False -) -> float: +def get_base_length_ratio(primary_length: float, base_length: float) -> float: """Calculate the ratio of the length of the bases to the primary root length. Args: primary_length (float): Length of the primary root. base_length (float): Length of the bases along the primary root. - monocots (bool): True if the roots are monocots, False if they are dicots. Returns: Ratio of the length of the bases along the primary root to the primary root length. 
""" - # If roots are monocots or either of the lengths are NaN, return NaN - if monocots or np.isnan(primary_length) or np.isnan(base_length): + # If either of the lengths are NaN, return NaN + if np.isnan(primary_length) or np.isnan(base_length): return np.nan # Handle case where primary length is zero to avoid division by zero @@ -231,24 +194,19 @@ def get_base_length_ratio( return base_length_ratio -def get_base_median_ratio(lateral_base_ys, primary_tip_pt_y, monocots: bool = False): +def get_base_median_ratio(lateral_base_ys, primary_tip_pt_y): """Get ratio of median value in all base points to tip of primary root in y axis. Args: - lateral_base_ys: y-coordinates of the base points of lateral roots of shape + lateral_base_ys: Y-coordinates of the base points of lateral roots of shape `(instances,)`. - primary_tip_pt_y: y-coordinate of the tip point of the primary root of shape + primary_tip_pt_y: Y-coordinate of the tip point of the primary root of shape `(1)`. - monocots: Boolean value, where false is dicot (default), true is rice. Return: Scalar of base median ratio. If all y-coordinates of the lateral root bases are - NaN, the function returns NaN. + NaN, the function returns NaN. """ - # Check if the roots are monocots, if so return NaN - if monocots: - return np.nan - # Check if all y-coordinates of lateral root bases are NaN, if so return NaN if np.isnan(lateral_base_ys).all(): return np.nan @@ -271,9 +229,8 @@ def get_root_widths( primary_max_length_pts: np.ndarray, lateral_pts: np.ndarray, tolerance: float = 0.02, - monocots: bool = False, return_inds: bool = False, -) -> (np.ndarray, list, np.ndarray, np.ndarray): +) -> Tuple[np.ndarray, list, np.ndarray, np.ndarray]: """Estimate root width using bases of lateral roots. Args: @@ -283,8 +240,6 @@ def get_root_widths( shape (n, nodes, 2). tolerance: Tolerance level for the projection difference between matched roots. Defaults to 0.02. - monocots: Indicates the type of plant. 
Set to False for dicots (default) or - True for monocots like rice. return_inds: Flag to indicate whether to return matched indices along with distances. Defaults to False. @@ -326,11 +281,10 @@ def get_root_widths( default_left_bases = np.empty((0, 2)) default_right_bases = np.empty((0, 2)) - # Check for minimum length, monocots, or all NaNs in arrays + # Check for minimum length, or all NaNs in arrays if ( len(primary_max_length_pts) < 2 or len(lateral_pts) < 2 - or monocots or np.isnan(primary_max_length_pts).all() or np.isnan(lateral_pts).all() ): diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py index 5ff3f92..c15a169 100644 --- a/sleap_roots/convhull.py +++ b/sleap_roots/convhull.py @@ -4,6 +4,8 @@ from scipy.spatial import ConvexHull from scipy.spatial.distance import pdist from typing import Tuple, Optional, Union +from sleap_roots.points import get_line_equation_from_points +from shapely import box, LineString, normalize, Polygon def get_convhull(pts: np.ndarray) -> Optional[ConvexHull]: @@ -15,7 +17,7 @@ def get_convhull(pts: np.ndarray) -> Optional[ConvexHull]: Returns: An object representing the convex hull or None if a hull can't be formed. 
""" - # Ensure the input is an array of shape (n, 2) + # Ensure the input is an array of shape (..., 2) if pts.ndim < 2 or pts.shape[-1] != 2: raise ValueError("Input points should be of shape (..., 2).") @@ -23,16 +25,23 @@ def get_convhull(pts: np.ndarray) -> Optional[ConvexHull]: pts = pts.reshape(-1, 2) pts = pts[~np.isnan(pts).any(axis=-1)] - # Check for NaNs or infinite values - if np.isnan(pts).any() or np.isinf(pts).any(): + # Check for infinite values + if np.isinf(pts).any(): + print("Cannot compute convex hull: input contains infinite values.") return None # Ensure there are at least 3 unique non-collinear points - if len(np.unique(pts, axis=0)) < 3: + unique_pts = np.unique(pts, axis=0) + if len(unique_pts) < 3: + print("Cannot compute convex hull: not enough unique points.") return None - # Compute and return the convex hull - return ConvexHull(pts) + try: + # Compute and return the convex hull + return ConvexHull(unique_pts) + except Exception as e: + print(f"Cannot compute convex hull: {e}") + return None def get_chull_perimeter(hull: Union[np.ndarray, ConvexHull, None]) -> float: @@ -190,3 +199,396 @@ def get_chull_line_lengths(hull: Union[np.ndarray, ConvexHull]) -> np.ndarray: chull_line_lengths = pdist(hull.points[hull.vertices], "euclidean") return chull_line_lengths + + +def get_chull_division_areas( + rn_pts: np.ndarray, pts: np.ndarray, hull: ConvexHull +) -> Tuple[float, float]: + """Get areas above and below the line formed by the leftmost and rightmost rn nodes. + + Args: + rn_pts: The nth root nodes when indexing from 0. Shape is (instances, 2). + pts: Numpy array of points with shape (instances, nodes, 2). + hull: A ConvexHull object computed from pts. + + Returns: + A tuple containing the areas of the convex hull of the points above and below + the line, respectively, where the line is formed by the leftmost and rightmost + rn nodes and the y-axis increases downward in image coordinates. 
Returns + (np.nan, np.nan) if the area cannot be calculated. + + Raises: + ValueError: If pts does not have the expected shape, or if hull is not a valid + ConvexHull object. + """ + if not isinstance(pts, np.ndarray) or pts.ndim != 3 or pts.shape[-1] != 2: + raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).") + if not isinstance(hull, ConvexHull): + raise ValueError("hull must be a ConvexHull object.") + + # There must be at least 3 unique non-collinear points to form a convex hull + # Flatten pts to 2D array and check for at least 3 unique points + flattened_pts = pts.reshape(-1, 2) + unique_pts = np.unique(flattened_pts, axis=0) + if len(unique_pts) < 3: + return np.nan, np.nan + + # Attempt to get the line equation between the leftmost and rightmost r1 nodes + try: + leftmost_rn = rn_pts[np.argmin(rn_pts[:, 0])] + rightmost_rn = rn_pts[np.argmax(rn_pts[:, 0])] + m, b = get_line_equation_from_points(leftmost_rn, rightmost_rn) + except Exception: + # If line equation cannot be found, return NaNs + return np.nan, np.nan + + # Initialize lists to hold points above/on and below the line + above_or_on_line = [] + below_line = [] + # Classify each point as being above or below the line + for point in flattened_pts: + if ( + point[1] <= m * point[0] + b + ): # y <= mx + b (y increases downward in image coordinates) + above_or_on_line.append(point) + else: + below_line.append(point) + + # Calculate areas using get_chull_area, return np.nan if no points satisfy the condition + area_above_line = ( + get_chull_area(np.array(above_or_on_line)) if above_or_on_line else np.nan + ) + area_below_line = get_chull_area(np.array(below_line)) if below_line else np.nan + + return area_above_line, area_below_line + + +def get_chull_division_areas_above(areas: Tuple[float, float]) -> float: + """Get the chull area of the points above the line from `get_chull_division_areas`. 
+ + Args: + areas: Tuple containing two float objects: + - The first is the area of the convex hull of the points above the line + formed by the leftmost and rightmost rn nodes. + - The second is the area of the convex hull of the points below the line + formed by the leftmost and rightmost rn nodes. + + Returns: + area_above_line: the area of the convex hull of the points above the line, + formed by the leftmost and rightmost rn nodes. + """ + return areas[0] + + +def get_chull_division_areas_below(areas: Tuple[float, float]) -> float: + """Get the chull area of the points below the line from `get_chull_division_areas`. + + Args: + areas: Tuple containing two float objects: + - The first is the area of the convex hull of the points above the line + formed by the leftmost and rightmost rn nodes. + - The second is the area of the convex hull of the points below the line + formed by the leftmost and rightmost rn nodes. + + Returns: + area_below_line: the area of the convex hull of the points below the line, + formed by the leftmost and rightmost rn nodes. + """ + return areas[1] + + +def get_chull_areas_via_intersection( + rn_pts: np.ndarray, pts: np.ndarray, hull: Optional[ConvexHull] +) -> Tuple[float, float]: + """Get convex hull areas above and below the intersecting line. + + Args: + rn_pts: The nth root nodes when indexing from 0. Shape is (instances, 2). + pts: Numpy array of points with shape (instances, nodes, 2). + hull: A ConvexHull object computed from pts, or None if a convex hull couldn't be formed. + + Returns: + A tuple containing the areas of the convex hull above and below + the line, respectively, where the line is formed by the leftmost and rightmost + rn nodes and the y-axis increases downward in image coordinates. Returns + (np.nan, np.nan) if the area cannot be calculated. + + Raises: + ValueError: If pts does not have the expected shape. 
+ """ + # Check for valid pts input + if not isinstance(pts, np.ndarray) or pts.ndim != 3 or pts.shape[-1] != 2: + raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).") + + # Flatten pts to 2D array and remove NaN values + flattened_pts = pts.reshape(-1, 2) + valid_pts = flattened_pts[~np.isnan(flattened_pts).any(axis=1)] + # Get unique points + unique_pts = np.unique(valid_pts, axis=0) + + # Check for a valid or existing convex hull + if hull is None or len(unique_pts) < 3: + return np.nan, np.nan + + # Ensure rn_pts does not contain NaN values + rn_pts_valid = rn_pts[~np.isnan(rn_pts).any(axis=1)] + # Need at least two points to define a line + if len(rn_pts_valid) < 2: + return np.nan, np.nan + + # Attempt to get the line equation between the leftmost and rightmost rn nodes + try: + leftmost_rn = rn_pts[np.argmin(rn_pts[:, 0])] + rightmost_rn = rn_pts[np.argmax(rn_pts[:, 0])] + m, b = get_line_equation_from_points(leftmost_rn, rightmost_rn) + except Exception: + # If line equation cannot be found, return NaNs + return np.nan, np.nan + + # Initialize lists to hold points above/on and below the line + above_line = [] + below_line = [] + # Classify each point as being above or below the line + for point in unique_pts: + if ( + point[1] <= m * point[0] + b + ): # y <= mx + b (y increases downward in image coordinates) + above_line.append(point) + if point[1] >= m * point[0] + b: + below_line.append(point) + + # Find the leftmost and rightmost points + leftmost_pt = np.nanmin(unique_pts[:, 0]) + rightmost_pt = np.nanmax(unique_pts[:, 0]) + + # Define how far to extend the line in terms of x + x_min_extended = leftmost_pt # Far left point + x_max_extended = rightmost_pt # Far right point + + # Calculate the corresponding y-values using the line equation + y_min_extended = m * x_min_extended + b + y_max_extended = m * x_max_extended + b + + # Create the extended line + extended_line = LineString( + [(x_min_extended, y_min_extended), 
(x_max_extended, y_max_extended)] + ) + + # Create a LineString that represents the perimeter of the convex hull + hull_perimeter = LineString( + hull.points[hull.vertices].tolist() + [hull.points[hull.vertices[0]].tolist()] + ) + + # Find the intersection between the hull perimeter and the extended line + intersection = extended_line.intersection(hull_perimeter) + + # Add intersection points to both lists + if not intersection.is_empty: + intersect_points = ( + np.array([[point.x, point.y] for point in intersection.geoms]) + if intersection.geom_type == "MultiPoint" + else np.array([[intersection.x, intersection.y]]) + ) + above_line.extend(intersect_points) + below_line.extend(intersect_points) + + # Calculate areas using get_chull_area + area_above_line = get_chull_area(np.array(above_line)) if above_line else 0.0 + area_below_line = get_chull_area(np.array(below_line)) if below_line else 0.0 + + return area_above_line, area_below_line + + +def get_chull_area_via_intersection_above(areas: Tuple[float, float]) -> float: + """Get the chull area above the line from `get_chull_area_via_intersection`. + + Args: + areas: Tuple containing two float objects: + - The first is the area of the convex hull above the line + formed by the leftmost and rightmost rn nodes. + - The second is the area of the convex hull below the line + formed by the leftmost and rightmost rn nodes. + + Returns: + area_above_line: the area of the convex hull above the line, + formed by the leftmost and rightmost rn nodes. + """ + return areas[0] + + +def get_chull_area_via_intersection_below(areas: Tuple[float, float]) -> float: + """Get the chull area below the line from `get_chull_area_via_intersection`. + + Args: + areas: Tuple containing two float objects: + - The first is the area of the convex hull above the line + formed by the leftmost and rightmost rn nodes. + - The second is the area of the convex hull below the line + formed by the leftmost and rightmost rn nodes. 
+ + Returns: + area_below_line: the area of the convex hull below the line, + formed by the leftmost and rightmost rn nodes. + """ + return areas[1] + + +def get_chull_intersection_vectors( + r0_pts: np.ndarray, rn_pts: np.ndarray, pts: np.ndarray, hull: Optional[ConvexHull] +) -> Tuple[np.ndarray, np.ndarray]: + """Get vectors from top left and top right to intersection on convex hull. + + Args: + r0_pts: The 0th root nodes when indexing from 0. Shape is (instances, 2). + rn_pts: The nth root nodes when indexing from 0. Shape is (instances, 2). + pts: Numpy array of points with shape (instances, nodes, 2). + hull: A ConvexHull object computed from pts, or None if a convex hull couldn't be formed. + + Returns: + A tuple containing vectors from the top left point to the left intersection point, and from + the top right point to the right intersection point with the convex hull. Returns two vectors + of NaNs if the vectors can't be calculated. Vectors are of shape (1, 2). + + Raises: + ValueError: If pts does not have the expected shape. 
+ """ + # Check for valid pts input + if not isinstance(pts, np.ndarray) or pts.ndim != 3 or pts.shape[-1] != 2: + raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).") + # Ensure rn_pts is a numpy array of shape (instances, 2) + if not isinstance(rn_pts, np.ndarray) or rn_pts.ndim != 2 or rn_pts.shape[-1] != 2: + raise ValueError("rn_pts must be a numpy array of shape (instances, 2).") + # Ensure r0_pts is a numpy array of shape (instances, 2) + if not isinstance(r0_pts, np.ndarray) or r0_pts.ndim != 2 or r0_pts.shape[-1] != 2: + raise ValueError("r0_pts must be a numpy array of shape (instances, 2).") + + # Flatten pts to 2D array and remove NaN values + flattened_pts = pts.reshape(-1, 2) + valid_pts = flattened_pts[~np.isnan(flattened_pts).any(axis=1)] + # Get unique points + unique_pts = np.unique(valid_pts, axis=0) + + # Check for a valid or existing convex hull + if hull is None or len(unique_pts) < 3: + # Return two vectors of NaNs if not valid hull + return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]])) + + # Ensure rn_pts does not contain NaN values + rn_pts_valid = rn_pts[~np.isnan(rn_pts).any(axis=1)] + # Need at least two points to define a line + if len(rn_pts_valid) < 2: + return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]])) + + # Ensuring r0_pts does not contain NaN values + r0_pts_valid = r0_pts[~np.isnan(r0_pts).any(axis=1)] + + # Get the vertices of the convex hull + hull_vertices = hull.points[hull.vertices] + + # Find the leftmost and rightmost r0 point + leftmost_r0 = r0_pts_valid[np.argmin(r0_pts_valid[:, 0])] + rightmost_r0 = r0_pts_valid[np.argmax(r0_pts_valid[:, 0])] + + # Check if these points are on the convex hull + is_leftmost_on_hull = any( + np.array_equal(leftmost_r0, vertex) for vertex in hull_vertices + ) + is_rightmost_on_hull = any( + np.array_equal(rightmost_r0, vertex) for vertex in hull_vertices + ) + + # Initialize vectors + leftmost_vector = np.array([[np.nan, np.nan]]) + 
rightmost_vector = np.array([[np.nan, np.nan]]) + if not is_leftmost_on_hull and not is_rightmost_on_hull: + # If leftmost and rightmost r0 points are not on the convex hull return NaNs + return leftmost_vector, rightmost_vector + + # Attempt to get the line equation between the leftmost and rightmost rn nodes + try: + leftmost_rn = rn_pts[np.argmin(rn_pts[:, 0])] + rightmost_rn = rn_pts[np.argmax(rn_pts[:, 0])] + m, b = get_line_equation_from_points(leftmost_rn, rightmost_rn) + except Exception: + # If line equation cannot be found, return NaNs + return leftmost_vector, rightmost_vector + + # Find the leftmost and rightmost points + leftmost_pt = np.nanmin(unique_pts[:, 0]) + rightmost_pt = np.nanmax(unique_pts[:, 0]) + + # Define how far to extend the line in terms of x + x_min_extended = leftmost_pt # Far left point + x_max_extended = rightmost_pt # Far right point + + # Calculate the corresponding y-values using the line equation + y_min_extended = m * x_min_extended + b + y_max_extended = m * x_max_extended + b + + # Create the extended line + extended_line = LineString( + [(x_min_extended, y_min_extended), (x_max_extended, y_max_extended)] + ) + + # Create a LineString that represents the perimeter of the convex hull + hull_perimeter = LineString( + hull.points[hull.vertices].tolist() + [hull.points[hull.vertices[0]].tolist()] + ) + + # Find the intersection between the hull perimeter and the extended line + intersection = extended_line.intersection(hull_perimeter) + + # Get the intersection points + if not intersection.is_empty: + intersect_points = ( + np.array([[point.x, point.y] for point in intersection.geoms]) + if intersection.geom_type == "MultiPoint" + else np.array([[intersection.x, intersection.y]]) + ) + else: + # Return two vectors of NaNs if there is no intersection + return leftmost_vector, rightmost_vector + + # Get the leftmost and rightmost intersection points + leftmost_intersect = intersect_points[np.argmin(intersect_points[:, 0])] + 
rightmost_intersect = intersect_points[np.argmax(intersect_points[:, 0])] + + # Make a vector from the leftmost r0 point to the leftmost intersection point + leftmost_vector = (leftmost_intersect - leftmost_r0).reshape(1, -1) + + # Make a vector from the rightmost r0 point to the rightmost intersection point + rightmost_vector = (rightmost_intersect - rightmost_r0).reshape(1, -1) + + return leftmost_vector, rightmost_vector + + +def get_chull_intersection_vectors_left( + vectors: Tuple[np.ndarray, np.ndarray] +) -> np.ndarray: + """Get the vector from the top left point to the left intersection point. + + Args: + vectors: Tuple containing two numpy arrays: + - The first is the vector from the top left point to the left intersection point. + - The second is the vector from the top right point to the right intersection point. + + Returns: + leftmost_vector: the vector from the top left point to the left intersection point. + """ + return vectors[0] + + +def get_chull_intersection_vectors_right( + vectors: Tuple[np.ndarray, np.ndarray] +) -> np.ndarray: + """Get the vector from the top right point to the right intersection point. + + Args: + vectors: Tuple containing two numpy arrays: + - The first is the vector from the top left point to the left intersection point. + - The second is the vector from the top right point to the right intersection point. + + Returns: + rightmost_vector: the vector from the top right point to the right intersection point. + """ + return vectors[1] diff --git a/sleap_roots/lengths.py b/sleap_roots/lengths.py index d623f30..127a845 100644 --- a/sleap_roots/lengths.py +++ b/sleap_roots/lengths.py @@ -1,4 +1,5 @@ """Get length-related traits.""" + import numpy as np from typing import Union @@ -7,7 +8,7 @@ def get_max_length_pts(pts: np.ndarray) -> np.ndarray: """Points of the root with maximum length (intended for primary root traits). Args: - pts (np.ndarray): Root landmarks as array of shape `(instances, nodes, 2)`. 
+ pts: Root landmarks as array of shape `(instances, nodes, 2)`. Returns: np.ndarray: Array of points with shape `(nodes, 2)` from the root with maximum @@ -15,7 +16,7 @@ def get_max_length_pts(pts: np.ndarray) -> np.ndarray: """ # Return NaN points if the input array is empty if len(pts) == 0: - return np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + return np.array([[np.nan, np.nan]]) # Check if pts has the correct shape, raise error if it does not if pts.ndim != 3 or pts.shape[2] != 2: @@ -68,7 +69,8 @@ def get_root_lengths(pts: np.ndarray) -> np.ndarray: segment_lengths = np.linalg.norm(segment_diffs, axis=-1) # Add the segments together to get the total length using nansum total_lengths = np.nansum(segment_lengths, axis=-1) - # Find the NaN segment lengths and record NaN in place of 0 when finding the total length + # Find the NaN segment lengths and record NaN in place of 0 when finding the total + # length total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan # If there is 1 instance, return a scalar instead of an array of length 1 @@ -78,39 +80,6 @@ def get_root_lengths(pts: np.ndarray) -> np.ndarray: return total_lengths -def get_root_lengths_max(pts: np.ndarray) -> np.ndarray: - """Return maximum root length for all roots in a frame. - - Args: - pts: root landmarks as array of shape `(instance, nodes, 2)` or lengths - `(instances)`. - - Returns: - Scalar of the maximum root length. - """ - # If the pts are NaNs, return NaN - if np.isnan(pts).all(): - return np.nan - - if pts.ndim not in (1, 3): - raise ValueError( - "Input array must be 1-dimensional (n_lengths) or " - "3-dimensional (n_roots, n_nodes, 2)." 
- ) - - # If the input array has 3 dimensions, calculate the root lengths, - # otherwise, assume the input array already contains the root lengths - if pts.ndim == 3: - root_lengths = get_root_lengths( - pts - ) # Assuming get_root_lengths returns an array of shape (instances) - max_length = np.nanmax(root_lengths) - else: - max_length = np.nanmax(pts) - - return max_length - - def get_curve_index( lengths: Union[float, np.ndarray], base_tip_dists: Union[float, np.ndarray] ) -> Union[float, np.ndarray]: diff --git a/sleap_roots/networklength.py b/sleap_roots/networklength.py index 90c0688..7db6ad5 100644 --- a/sleap_roots/networklength.py +++ b/sleap_roots/networklength.py @@ -3,7 +3,7 @@ import numpy as np from shapely import LineString, Polygon from sleap_roots.lengths import get_max_length_pts -from typing import Tuple, Union +from typing import Tuple, Union, List, Optional def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]: @@ -60,44 +60,42 @@ def get_network_width_depth_ratio( def get_network_length( - primary_length: float, - lateral_lengths: Union[float, np.ndarray], - monocots: bool = False, + lengths0: Union[float, np.ndarray], + *args: Optional[Union[float, np.ndarray]], ) -> float: """Return the total root network length given primary and lateral root lengths. Args: - primary_length: Primary root length. - lateral_lengths: Either a float representing the length of a single lateral - root or an array of lateral root lengths with shape `(instances,)`. - monocots: A boolean value, where True is rice. + lengths0: Either a float representing the length of a single + root or an array of root lengths with shape `(instances,)`. + *args: Additional optional floats representing the lengths of single + roots or arrays of root lengths with shape `(instances,)`. Returns: Total length of root network. 
""" - # Ensure primary_length is a scalar - if not isinstance(primary_length, (float, np.float64)): - raise ValueError("Input primary_length must be a scalar value.") - - # Ensure lateral_lengths is either a scalar or has the correct shape - if not ( - isinstance(lateral_lengths, (float, np.float64)) or lateral_lengths.ndim == 1 - ): - raise ValueError( - "Input lateral_lengths must be a scalar or have shape (instances,)." - ) - - # Calculate the total lateral root length using np.nansum - total_lateral_length = np.nansum(lateral_lengths) - - if monocots: - length = total_lateral_length - else: - # Calculate the total root network length using np.nansum so the total length - # will not be NaN if one of primary or lateral lengths are NaN - length = np.nansum([primary_length, total_lateral_length]) - - return length + # Initialize an empty list to store the lengths + all_lengths = [] + # Loop over the input arrays + for length in [lengths0] + list(args): + if length is None: + continue # Skip None values + # Ensure length is either a scalar or has the correct shape + if not (np.isscalar(length) or (hasattr(length, "ndim") and length.ndim == 1)): + raise ValueError( + "Input length must be a scalar or have shape (instances,)." + ) + # Add the length to the list + if np.isscalar(length): + all_lengths.append(length) + else: + all_lengths.extend(list(length)) + + # Calculate the total root network length using np.nansum so the total length + # will not be NaN if one of primary or lateral lengths are NaN + total_network_length = np.nansum(all_lengths) + + return total_network_length def get_network_solidity( @@ -121,60 +119,34 @@ def get_network_solidity( def get_network_distribution( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, + pts_list: List[np.ndarray], bounding_box: Tuple[float, float, float, float], fraction: float = 2 / 3, - monocots: bool = False, ) -> float: """Return the root length in the lower fraction of the plant. 
Args: - primary_pts: Array of primary root landmarks. Can have shape `(nodes, 2)` or - `(1, nodes, 2)`. - lateral_pts: Array of lateral root landmarks with shape `(instances, nodes, 2)`. + pts_list: A list of arrays, each having shape `(nodes, 2)`. bounding_box: Tuple in the form `(left_x, top_y, width, height)`. fraction: Lower fraction value. Defaults to 2/3. - monocots: A boolean value, where True indicates rice. Defaults to False. Returns: Root network length in the lower fraction of the plant. """ - # Input validation - if primary_pts.ndim not in [2, 3]: + # Input validation for pts_list + if any(pts.ndim != 2 or pts.shape[-1] != 2 for pts in pts_list): raise ValueError( - "primary_pts should have a shape of `(nodes, 2)` or `(1, nodes, 2)`." + "Each pts array in pts_list should have a shape of `(nodes, 2)`." ) - if primary_pts.ndim == 2 and primary_pts.shape[-1] != 2: - raise ValueError("primary_pts should have a shape of `(nodes, 2)`.") - - if primary_pts.ndim == 3 and primary_pts.shape[-1] != 2: - raise ValueError("primary_pts should have a shape of `(1, nodes, 2)`.") - - if lateral_pts.ndim != 3 or lateral_pts.shape[-1] != 2: - raise ValueError("lateral_pts should have a shape of `(instances, nodes, 2)`.") - + # Input validation for bounding_box if len(bounding_box) != 4: raise ValueError( - "bounding_box should be in the form `(left_x, top_y, width, height)`." + "bounding_box must contain exactly 4 elements: `(left_x, top_y, width, height)`." ) - # Make sure the longest primary root is used - if primary_pts.ndim == 3: - primary_pts = get_max_length_pts(primary_pts) # shape is (nodes, 2) - - # Make primary_pts and lateral_pts have the same dimension of 3 - primary_pts = ( - primary_pts[np.newaxis, :, :] if primary_pts.ndim == 2 else primary_pts - ) - # Filter out NaN values - primary_pts = [root[~np.isnan(root).any(axis=1)] for root in primary_pts] - lateral_pts = [root[~np.isnan(root).any(axis=1)] for root in lateral_pts] - - # Collate root points. 
- all_roots = primary_pts + lateral_pts if not monocots else lateral_pts + pts_list = [pts[~np.isnan(pts).any(axis=-1)] for pts in pts_list] # Get the vertices of the bounding box left_x, top_y, width, height = bounding_box @@ -185,7 +157,6 @@ def get_network_distribution( return np.nan # Convert lower bounding box to polygon - # Vertices are in counter-clockwise order lower_box = Polygon( [ [left_x, top_y + (height - lower_height)], @@ -197,7 +168,7 @@ def get_network_distribution( # Calculate length of roots within the lower bounding box network_length = 0 - for root in all_roots: + for root in pts_list: if len(root) > 1: # Ensure that root has more than one point root_poly = LineString(root) lower_intersection = root_poly.intersection(lower_box) @@ -208,53 +179,27 @@ def get_network_distribution( def get_network_distribution_ratio( - primary_length: float, - lateral_lengths: Union[float, np.ndarray], + network_length: float, network_length_lower: float, - fraction: float = 2 / 3, - monocots: bool = False, ) -> float: - """Return ratio of the root length in the lower fraction over all root length. + """Return ratio of the root length in the lower fraction to total root length. Args: - primary_length: Primary root length. - lateral_lengths: Lateral root lengths. Can be a single float (for one root) - or an array of floats (for multiple roots). network_length_lower: The root length in the lower network. - fraction: The fraction of the network considered as 'lower'. Defaults to 2/3. - monocots: A boolean value, where True indicates rice. Defaults to False. + network_length: Total root length of network. Returns: Float of ratio of the root network length in the lower fraction of the plant - over all root length. + over the total root length. 
""" # Ensure primary_length is a scalar - if not isinstance(primary_length, (float, np.float64)): - raise ValueError("Input primary_length must be a scalar value.") - - # Ensure lateral_lengths is either a scalar or a 1-dimensional array - if not isinstance(lateral_lengths, (float, np.float64, np.ndarray)): - raise ValueError( - "Input lateral_lengths must be a scalar or a 1-dimensional array." - ) - - # If lateral_lengths is an ndarray, it must be one-dimensional - if isinstance(lateral_lengths, np.ndarray) and lateral_lengths.ndim != 1: - raise ValueError("Input lateral_lengths array must have shape (instances,).") + if not isinstance(network_length, (float, np.float64)): + raise ValueError("Input network_length must be a scalar value.") # Ensure network_length_lower is a scalar if not isinstance(network_length_lower, (float, np.float64)): raise ValueError("Input network_length_lower must be a scalar value.") - # Calculate the total lateral root length - total_lateral_length = np.nansum(lateral_lengths) - - # Determine total root length based on monocots flag - if monocots: - total_root_length = total_lateral_length - else: - total_root_length = np.nansum([primary_length, total_lateral_length]) - # Calculate the ratio - ratio = network_length_lower / total_root_length + ratio = network_length_lower / network_length return ratio diff --git a/sleap_roots/points.py b/sleap_roots/points.py index 2620939..e5b8280 100644 --- a/sleap_roots/points.py +++ b/sleap_roots/points.py @@ -1,52 +1,287 @@ """Get traits related to the points.""" import numpy as np +from typing import List, Optional, Tuple -def get_all_pts_array( - primary_max_length_pts: np.ndarray, lateral_pts: np.ndarray, monocots: bool = False -) -> np.ndarray: +def get_count(pts: np.ndarray): + """Get number of roots. + + Args: + pts: Root landmarks as array of shape `(instances, nodes, 2)`. + + Return: + Scalar of number of roots. 
+ """ + # The number of roots is the number of instances + count = pts.shape[0] + return count + + +def join_pts(pts0: np.ndarray, *args: Optional[np.ndarray]) -> List[np.ndarray]: + """Join an arbitrary number of points arrays and return them as a list. + + Args: + pts0: The first array of points. Should have shape `(instances, nodes, 2)` + or `(nodes, 2)`. + *args: Additional optional arrays of points. Each should have shape + `(instances, nodes, 2)` or `(nodes, 2)`. + + Returns: + A list of arrays, each having shape `(nodes, 2)`. + """ + # Initialize an empty list to store the points + all_pts = [] + # Loop over the input arrays + for pts in [pts0] + list(args): + if pts is None: + continue # Skip None values + + # If an array has shape `(nodes, 2)`, expand dimensions to `(1, nodes, 2)` + if pts.ndim == 2 and pts.shape[-1] == 2: + pts = pts[np.newaxis, :, :] + + # Validate the shape of each array + if pts.ndim != 3 or pts.shape[-1] != 2: + raise ValueError( + "Points should have a shape of `(instances, nodes, 2)` or `(nodes, 2)`." + ) + + # Add the points to the list + all_pts.extend(list(pts)) + + return all_pts + + +def get_all_pts_array(pts0: np.ndarray, *args: Optional[np.ndarray]) -> np.ndarray: """Get all landmark points within a given frame as a flat array of coordinates. Args: - primary_max_length_pts: Points of the primary root with maximum length of shape - `(nodes, 2)`. - lateral_pts: Lateral root points of shape `(instances, nodes, 2)`. - monocots: If False (default), returns a combined array of primary and lateral - root points. If True, returns only lateral root points. + pts0: The first array of points. Should have shape `(instances, nodes, 2)` + or `(nodes, 2)`. + *args: Additional optional arrays of points. Each should have shape + `(instances, nodes, 2)` or `(nodes, 2)`. Returns: A 2D array of shape (n_points, 2), containing the coordinates of all extracted points. 
""" - # Check if the input arrays have the right number of dimensions - if primary_max_length_pts.ndim != 2 or lateral_pts.ndim != 3: + # Initialize an empty list to store the points + concatenated_pts = [] + + # Loop over the input arrays + for pts in [pts0] + list(args): + if pts is None: + continue + + # Check if the array has the right number of dimensions + if pts.ndim not in [2, 3]: + raise ValueError("Each input array should be 2D or 3D.") + + # Check if the last dimension of the array has size 2 + # (representing x and y coordinates) + if pts.shape[-1] != 2: + raise ValueError( + "The last dimension should have size 2, representing x and y coordinates." + ) + + # Flatten the array to 2D and append to list + flat_pts = pts.reshape(-1, 2) + concatenated_pts.append(flat_pts) + + # Concatenate all points into a single array + return np.concatenate(concatenated_pts, axis=0) + + +def get_nodes(pts: np.ndarray, node_index: int) -> np.ndarray: + """Extracts the (x, y) coordinates of a specified node. + + Args: + pts: An array of points. For multiple instances, the shape should be + (instances, nodes, 2). For a single instance,the shape should be (nodes, 2). + node_index: The index of the node for which to extract the coordinates, based on + the node's position in the sequence of connected nodes (0-based indexing). + + Returns: + np.ndarray: An array of (x, y) coordinates for the specified node. For multiple + instances, the shape will be (instances, 2). For a single instance, the + shape will be (2,). + + Raises: + ValueError: If node_index is out of bounds for the number of nodes. 
+ """ + # Adjust for a single instance with shape (nodes, 2) + if pts.ndim == 2: + if not 0 <= node_index < pts.shape[0]: + raise ValueError("node_index is out of bounds for the number of nodes.") + # Return a (2,) shape array for the node coordinates in a single instance + return pts[node_index, :] + + # Handle multiple instances with shape (instances, nodes, 2) + elif pts.ndim == 3: + if not 0 <= node_index < pts.shape[1]: + raise ValueError("node_index is out of bounds for the number of nodes.") + # Return (instances, 2) shape array for the node coordinates across instances + return pts[:, node_index, :] + + else: raise ValueError( - "Input arrays should have the correct number of dimensions:" - "primary_max_length_pts should be 2-dimensional and lateral_pts should be" - "3-dimensional." + "Input array should have shape (nodes, 2) for a single instance " + "or (instances, nodes, 2) for multiple instances." ) - # Check if the last dimension of the input arrays has size 2 - # (representing x and y coordinates) - if primary_max_length_pts.shape[-1] != 2 or lateral_pts.shape[-1] != 2: - raise ValueError( - "The last dimension of the input arrays should have size 2, representing x" - "and y coordinates." + +def get_root_vectors(start_nodes: np.ndarray, end_nodes: np.ndarray) -> np.ndarray: + """Calculate the vector from start to end for each instance in a set of points. + + Args: + start_nodes: array of points with shape (instances, 2) or (2,) representing the + start node in each instance. + end_nodes: array of points with shape (instances, 2) or (2,) representing the + end node in each instance. + + Returns: + An array of vectors with shape (instances, 2), representing the vector from start + to end for each instance. 
+ """ + # Ensure that the start and end nodes have the same shapes + if start_nodes.shape != end_nodes.shape: + raise ValueError("start_nodes and end_nodes should have the same shape.") + # Handle single instances with shape (2,) + if start_nodes.ndim == 1: + start_nodes = start_nodes[np.newaxis, :] + if end_nodes.ndim == 1: + end_nodes = end_nodes[np.newaxis, :] + # Calculate the vectors from start to end for each instance + vectors = start_nodes - end_nodes + return vectors + + +def get_left_right_normalized_vectors( + r0_pts: np.ndarray, r1_pts: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """Get the unit vectors formed from r0 to r1 on the left and right sides of a crown root system. + + Args: + r0_pts: An array of points representing the r0 nodes, with shape (instances, 2), + where instances are different observations of r0 points, and 2 represents + the x and y coordinates. + r1_pts: An array of points representing the r1 nodes, similar in structure to r0_pts. + + Returns: + A tuple containing two np.ndarray objects: + - The first is a normalized vector from r0 to r1 on the left side, or a vector + of NaNs if normalization fails. + - The second is a normalized vector from r0 to r1 on the right side, or a vector + of NaNs if normalization fails. 
+ """ + # Validate input shapes and ensure there are multiple instances for comparison + if ( + r0_pts.ndim == 2 + and r1_pts.ndim == 2 + and r0_pts.shape == r1_pts.shape + and r0_pts.shape[0] > 1 + ): + # Find indices of the leftmost and rightmost r0 and r1 points + leftmost_r0_index = np.nanargmin(r0_pts[:, 0]) + rightmost_r0_index = np.nanargmax(r0_pts[:, 0]) + leftmost_r1_index = np.nanargmin(r1_pts[:, 0]) + rightmost_r1_index = np.nanargmax(r1_pts[:, 0]) + + # Extract the corresponding r0 and r1 points for leftmost and rightmost nodes + r0_left = r0_pts[leftmost_r0_index] + r1_left = r1_pts[leftmost_r1_index] + r0_right = r0_pts[rightmost_r0_index] + r1_right = r1_pts[rightmost_r1_index] + + # Calculate the vectors from r0 to r1 for both the leftmost and rightmost points + vector_left = r1_left - r0_left + vector_right = r1_right - r0_right + + # Calculate norms of both vectors for normalization + norm_left = np.linalg.norm(vector_left) + norm_right = np.linalg.norm(vector_right) + + # Normalize the vectors if their norms are non-zero + # otherwise, return vectors filled with NaNs + norm_vector_left = ( + vector_left / norm_left if norm_left > 0 else np.array([np.nan, np.nan]) ) + norm_vector_right = ( + vector_right / norm_right if norm_right > 0 else np.array([np.nan, np.nan]) + ) + + return norm_vector_left, norm_vector_right + else: + # Return pairs of NaN vectors if inputs are invalid or do not meet the requirements + return np.array([np.nan, np.nan]), np.array([np.nan, np.nan]) + + +def get_left_normalized_vector( + normalized_vectors: Tuple[np.ndarray, np.ndarray] +) -> np.ndarray: + """Get the normalized vector from r0 to r1 on the left side of a crown root system. + + Args: + normalized_vectors: A tuple containing two np.ndarray objects: + - The first is a normalized vector from r0 to r1 on the left side, or a vector + of NaNs if normalization fails. 
+ - The second is a normalized vector from r0 to r1 on the right side, or a vector + of NaNs if normalization fails. + + Returns: + np.ndarray: A normalized vector from r0 to r1 on the left side, or a vector of NaNs + if normalization fails. + """ + return normalized_vectors[0] + + +def get_right_normalized_vector( + normalized_vectors: Tuple[np.ndarray, np.ndarray] +) -> np.ndarray: + """Get the normalized vector from r0 to r1 on the right side of a crown root system. + + Args: + normalized_vectors: A tuple containing two np.ndarray objects: + - The first is a normalized vector from r0 to r1 on the left side, or a vector + of NaNs if normalization fails. + - The second is a normalized vector from r0 to r1 on the right side, or a vector + of NaNs if normalization fails. + + Returns: + np.ndarray: A normalized vector from r0 to r1 on the right side, or a vector of NaNs + if normalization fails. + """ + return normalized_vectors[1] + + +def get_line_equation_from_points(pts1: np.ndarray, pts2: np.ndarray): + """Calculate the slope (m) and y-intercept (b) of the line connecting two points. + + Args: + pts1: First point as (x, y). 1D array of shape (2,). + pts2: Second point as (x, y). 1D array of shape (2,). + + Returns: + A tuple (m, b) representing the slope and y-intercept of the line. If the line is + vertical, NaNs are returned. 
+ """ + # Convert inputs to arrays if they're not already + pts1 = np.asarray(pts1) + pts2 = np.asarray(pts2) - # Flatten the arrays to 2D - primary_max_length_pts = primary_max_length_pts.reshape(-1, 2) - lateral_pts = lateral_pts.reshape(-1, 2) + # Validate input shapes + if pts1.ndim != 1 or pts1.shape[0] != 2 or pts2.ndim != 1 or pts2.shape[0] != 2: + raise ValueError("Each input point must be a 1D array of shape (2,).") - # Combine points - if monocots: - pts_all_array = lateral_pts + # If the line is vertical return NaNs + if pts1[0] == pts2[0]: + return np.nan, np.nan else: - # Check if the data types of the arrays are compatible - if primary_max_length_pts.dtype != lateral_pts.dtype: - raise ValueError("Input arrays should have the same data type.") + # Calculate the slope + m = (pts2[1] - pts1[1]) / (pts2[0] - pts1[0]) - pts_all_array = np.concatenate((primary_max_length_pts, lateral_pts), axis=0) + # Calculate the y-intercept + b = pts1[1] - m * pts1[0] - return pts_all_array + return m, b diff --git a/sleap_roots/scanline.py b/sleap_roots/scanline.py index b623fe9..e447829 100644 --- a/sleap_roots/scanline.py +++ b/sleap_roots/scanline.py @@ -1,16 +1,13 @@ """Get intersections between roots and horizontal scan lines.""" import numpy as np -import math +from typing import List def count_scanline_intersections( - primary_pts: np.ndarray, - lateral_pts: np.ndarray, + pts_list: List[np.ndarray], height: int = 1080, - width: int = 2048, n_line: int = 50, - monocots: bool = False, ) -> np.ndarray: """Count intersections of roots with a series of horizontal scanlines. @@ -19,33 +16,19 @@ def count_scanline_intersections( are equally spaced across the specified height. Args: - primary_pts: Array of primary root landmarks of shape `(nodes, 2)`. - Will be reshaped internally to `(1, nodes, 2)`. - lateral_pts: Array of lateral root landmarks with shape - `(instances, nodes, 2)`. + pts_list: A list of arrays, each having shape `(nodes, 2)`. 
height: The height of the image or cylinder. Defaults to 1080. - width: The width of the image or cylinder. Defaults to 2048. n_line: Number of scanlines to use. Defaults to 50. - monocots: If `True`, only uses lateral roots (e.g., for rice). - If `False`, uses both primary and lateral roots (e.g., for dicots). - Defaults to `False`. Returns: An array with shape `(n_line,)` representing the number of intersections of roots with each scanline. """ - # Input validation - if primary_pts.ndim != 2 or primary_pts.shape[-1] != 2: - raise ValueError("primary_pts should have a shape of `(nodes, 2)`.") - - if lateral_pts.ndim != 3 or lateral_pts.shape[-1] != 2: - raise ValueError("lateral_pts should have a shape of `(instances, nodes, 2)`.") - - # Reshape primary_pts to have three dimensions - primary_pts = primary_pts[np.newaxis, :, :] - - # Collate root points. - all_roots = list(primary_pts) + list(lateral_pts) if not monocots else lateral_pts + # Input validation for pts_list + if any(pts.ndim != 2 or pts.shape[-1] != 2 for pts in pts_list): + raise ValueError( + "Each pts array in pts_list should have a shape of `(nodes, 2)`." 
+ ) # Calculate the interval between two scanlines interval = height / (n_line - 1) @@ -57,7 +40,7 @@ def count_scanline_intersections( y_coord = interval * i line_intersections = 0 - for root_points in all_roots: + for root_points in pts_list: # Remove NaN values valid_points = root_points[(~np.isnan(root_points)).any(axis=1)] diff --git a/sleap_roots/series.py b/sleap_roots/series.py index 45ac8ad..c23d9d6 100644 --- a/sleap_roots/series.py +++ b/sleap_roots/series.py @@ -2,59 +2,120 @@ import attrs import numpy as np -from pathlib import Path import sleap_io as sio -from typing import Optional, Tuple, List, Union - import matplotlib import matplotlib.pyplot as plt import seaborn as sns +from typing import Dict, Optional, Tuple, List, Union +from pathlib import Path + @attrs.define class Series: """Data and predictions for a single image series. Attributes: - h5_path: Path to the HDF5-formatted image series. - primary_labels: A `sio.Labels` corresponding to the primary root predictions. - lateral_labels: A `sio.Labels` corresponding to the lateral root predictions. - video: A `sio.Video` corresponding to the image series. + h5_path: Optional path to the HDF5-formatted image series. + primary_labels: Optional `sio.Labels` corresponding to the primary root predictions. + lateral_labels: Optional `sio.Labels` corresponding to the lateral root predictions. + crown_labels: Optional `sio.Labels` corresponding to the crown predictions. + video: Optional `sio.Video` corresponding to the image series. + + Methods: + load: Load a set of predictions for this series. + __len__: Length of the series (number of images). + __getitem__: Return labeled frames for predictions. + __iter__: Iterator for looping through predictions. + get_frame: Return labeled frames for predictions. + plot: Plot predictions on top of the image. + get_primary_points: Get primary root points. + get_lateral_points: Get lateral root points. + get_crown_points: Get crown root points. 
+ + Properties: + series_name: Name of the series derived from the HDF5 filename. """ h5_path: Optional[str] = None primary_labels: Optional[sio.Labels] = None lateral_labels: Optional[sio.Labels] = None + crown_labels: Optional[sio.Labels] = None video: Optional[sio.Video] = None @classmethod def load( cls, h5_path: str, - primary_name: str = "primary_multi_day", - lateral_name: str = "lateral__nodes", - ): + primary_name: Optional[str] = None, + lateral_name: Optional[str] = None, + crown_name: Optional[str] = None, + ) -> "Series": """Load a set of predictions for this series. Args: h5_path: Path to the HDF5-formatted image series. - primary_name: Name of the primary root predictions. The predictions file is - expected to be named `"{h5_path}.{primary_name}.predictions.slp"`. - lateral_name: Name of the lateral root predictions. The predictions file is - expected to be named `"{h5_path}.{lateral_name}.predictions.slp"`. + primary_name: Optional name of the primary root predictions file. If provided, + the file is expected to be named "{h5_path}.{primary_name}.predictions.slp". + lateral_name: Optional name of the lateral root predictions file. If provided, + the file is expected to be named "{h5_path}.{lateral_name}.predictions.slp". + crown_name: Optional name of the crown predictions file. If provided, + the file is expected to be named "{h5_path}.{crown_name}.predictions.slp". + + Returns: + An instance of Series loaded with the specified predictions. 
""" - primary_path = ( - Path(h5_path).with_suffix(f".{primary_name}.predictions.slp").as_posix() - ) - lateral_path = ( - Path(h5_path).with_suffix(f".{lateral_name}.predictions.slp").as_posix() - ) + # Initialize the labels as None + primary_labels, lateral_labels, crown_labels = None, None, None + + # Attempt to load the predictions, with error handling + try: + if primary_name: + primary_path = ( + Path(h5_path) + .with_suffix(f".{primary_name}.predictions.slp") + .as_posix() + ) + if Path(primary_path).exists(): + primary_labels = sio.load_slp(primary_path) + else: + print(f"Primary prediction file not found: {primary_path}") + if lateral_name: + lateral_path = ( + Path(h5_path) + .with_suffix(f".{lateral_name}.predictions.slp") + .as_posix() + ) + if Path(lateral_path).exists(): + lateral_labels = sio.load_slp(lateral_path) + else: + print(f"Lateral prediction file not found: {lateral_path}") + if crown_name: + crown_path = ( + Path(h5_path) + .with_suffix(f".{crown_name}.predictions.slp") + .as_posix() + ) + if Path(crown_path).exists(): + crown_labels = sio.load_slp(crown_path) + else: + print(f"Crown prediction file not found: {crown_path}") + except Exception as e: + print(f"Error loading prediction files: {e}") + + # Attempt to load the video, with error handling + video = None + try: + video = sio.Video.from_filename(h5_path) if Path(h5_path).exists() else None + except Exception as e: + print(f"Error loading video file {h5_path}: {e}") return cls( - h5_path, - primary_labels=sio.load_slp(primary_path), - lateral_labels=sio.load_slp(lateral_path), - video=sio.Video.from_filename(h5_path), + h5_path=h5_path, + primary_labels=primary_labels, + lateral_labels=lateral_labels, + crown_labels=crown_labels, + video=video, ) @property @@ -66,8 +127,8 @@ def __len__(self) -> int: """Length of the series (number of images).""" return len(self.video) - def __getitem__(self, idx: int) -> Tuple[sio.LabeledFrame, sio.LabeledFrame]: - """Return labeled frames for 
primary and lateral predictions.""" + def __getitem__(self, idx: int) -> Dict[str, Optional[sio.LabeledFrame]]: + """Return labeled frames for primary and/or lateral and/or crown predictions.""" return self.get_frame(idx) def __iter__(self): @@ -75,23 +136,44 @@ def __iter__(self): for i in range(len(self)): yield self[i] - def get_frame(self, frame_idx: int) -> Tuple[sio.LabeledFrame, sio.LabeledFrame]: - """Return labeled frames for primary and lateral predictions. + def get_frame(self, frame_idx: int) -> dict: + """Return labeled frames for primary, lateral, and crown predictions. Args: frame_idx: Integer frame number. Returns: - Tuple of (primary_lf, lateral_lf) corresponding to the `sio.LabeledFrame` - from each set of predictions on the same frame. + Dictionary with keys 'primary', 'lateral', and 'crown', each corresponding + to the `sio.LabeledFrame` from each set of predictions on the same frame. If + any set of predictions is not available, its value will be None. """ - lf_primary = self.primary_labels.find( - self.primary_labels.video, frame_idx, return_new=True - )[0] - lf_lateral = self.lateral_labels.find( - self.lateral_labels.video, frame_idx, return_new=True - )[0] - return lf_primary, lf_lateral + frames = {} + + # For primary predictions + if self.primary_labels is not None: + frames["primary"] = self.primary_labels.find( + self.primary_labels.video, frame_idx, return_new=True + )[0] + else: + frames["primary"] = None + + # For lateral predictions + if self.lateral_labels is not None: + frames["lateral"] = self.lateral_labels.find( + self.lateral_labels.video, frame_idx, return_new=True + )[0] + else: + frames["lateral"] = None + + # For crown predictions + if self.crown_labels is not None: + frames["crown"] = self.crown_labels.find( + self.crown_labels.video, frame_idx, return_new=True + )[0] + else: + frames["crown"] = None + + return frames def plot(self, frame_idx: int, scale: float = 1.0, **kwargs): """Plot predictions on top of the image. 
@@ -101,10 +183,37 @@ def plot(self, frame_idx: int, scale: float = 1.0, **kwargs): scale: Relative size of the visualized image. Useful for plotting smaller images within notebooks. """ - primary_lf, lateral_lf = self.get_frame(frame_idx) - plot_img(primary_lf.image, scale=scale) - plot_instances(primary_lf.instances, cmap=["r"], **kwargs) - plot_instances(lateral_lf.instances, cmap=["g"], **kwargs) + # Retrieve all available frames + frames = self.get_frame(frame_idx) + + # Generate the color palette from seaborn + cmap = sns.color_palette("tab10") + + # Define the order of preference for the predictions for plotting the image + prediction_order = ["primary", "lateral", "crown"] + + # Variable to keep track if the image has been plotted + image_plotted = False + + # First, find the first available prediction to plot the image + for prediction in prediction_order: + labeled_frame = frames.get(prediction) + if labeled_frame is not None and not image_plotted: + # Plot the image + plot_img(labeled_frame.image, scale=scale) + # Set the flag to True to avoid plotting the image again + image_plotted = True + + # Then, iterate through all predictions to plot instances + for i, prediction in enumerate(prediction_order): + labeled_frame = frames.get(prediction) + if labeled_frame is not None: + # Use the color map index for each prediction type + # Modulo the length of the color map to avoid index out of range + color = cmap[i % len(cmap)] + + # Plot the instances + plot_instances(labeled_frame.instances, cmap=[color], **kwargs) def get_primary_points(self, frame_idx: int) -> np.ndarray: """Get primary root points. @@ -115,10 +224,16 @@ def get_primary_points(self, frame_idx: int) -> np.ndarray: Returns: Primary root points as array of shape `(n_instances, n_nodes, 2)`. 
""" - primary_lf, lateral_lf = self.get_frame(frame_idx) + # Retrieve all available frames + frames = self.get_frame(frame_idx) + # Get the primary labeled frame + primary_lf = frames.get("primary") + # Get the ground truth instances and unused predictions gt_instances_pr = primary_lf.user_instances + primary_lf.unused_predictions + # If there are no instances, return an empty array if len(gt_instances_pr) == 0: - return [] + primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + # Otherwise, stack the instances into an array else: primary_pts = np.stack([inst.numpy() for inst in gt_instances_pr], axis=0) return primary_pts @@ -132,14 +247,43 @@ def get_lateral_points(self, frame_idx: int) -> np.ndarray: Returns: Lateral root points as array of shape `(n_instances, n_nodes, 2)`. """ - primary_lf, lateral_lf = self.get_frame(frame_idx) + # Retrieve all available frames + frames = self.get_frame(frame_idx) + # Get the lateral labeled frame + lateral_lf = frames.get("lateral") + # Get the ground truth instances and unused predictions gt_instances_lr = lateral_lf.user_instances + lateral_lf.unused_predictions + # If there are no instances, return an empty array if len(gt_instances_lr) == 0: - return [] + lateral_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) + # Otherwise, stack the instances into an array else: lateral_pts = np.stack([inst.numpy() for inst in gt_instances_lr], axis=0) return lateral_pts + def get_crown_points(self, frame_idx: int) -> np.ndarray: + """Get crown root points. + + Args: + frame_idx: Frame index. + + Returns: + Crown root points as array of shape `(n_instances, n_nodes, 2)`. 
+        """
+        # Retrieve all available frames
+        frames = self.get_frame(frame_idx)
+        # Get the crown labeled frame
+        crown_lf = frames.get("crown")
+        # Get the ground truth instances and unused predictions
+        gt_instances_cr = crown_lf.user_instances + crown_lf.unused_predictions
+        # If there are no instances, return a single all-NaN placeholder instance
+        if len(gt_instances_cr) == 0:
+            crown_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]])
+        # Otherwise, stack the instances into an array
+        else:
+            crown_pts = np.stack([inst.numpy() for inst in gt_instances_cr], axis=0)
+        return crown_pts
+
 
 def find_all_series(data_folders: Union[str, List[str]]) -> List[str]:
     """Find all .h5 series from a list of folders.
diff --git a/sleap_roots/tips.py b/sleap_roots/tips.py
index feab7dc..c2588e8 100644
--- a/sleap_roots/tips.py
+++ b/sleap_roots/tips.py
@@ -29,67 +29,43 @@ def get_tips(pts: np.ndarray) -> np.ndarray:
     return tip_pts
 
 
-def get_tip_xs(tip_pts: np.ndarray, flatten: bool = False) -> np.ndarray:
+def get_tip_xs(tip_pts: np.ndarray) -> np.ndarray:
     """Get x coordinates of tip points.
 
     Args:
-        tip_pts: Root tips as array of shape `(instances, 2)` or `(2)` when there is
-            only one tip.
-        flatten: If `True`, return scalar (0-D) array if scalar (there is only 1 root).
-            Defaults to `False`.
+        tip_pts: Root tip points as array of shape `(instances, 2)` or `(2,)` when there
+            is only one tip.
 
-    Returns:
-        An array of the x-coordinates of tips (instances,) or () if `flatten` is `True`.
+    Return:
+        An array of tip x-coordinates (instances,) or (1,) when there is only one root.
""" - # If the input is a single number (float or integer), raise an error - if isinstance(tip_pts, (np.floating, float, np.integer, int)): - raise ValueError("Input must be an array of shape `(instances, 2)` or `(2, )`.") - - # Check for the 2D shape of the input array - if tip_pts.ndim == 1: - # If shape is `(2,)`, then reshape it to `(1, 2)` for consistency - tip_pts = tip_pts.reshape(1, 2) - elif tip_pts.ndim != 2: - raise ValueError("Input array must be of shape `(instances, 2)` or `(2, )`.") - - # At this point, `tip_pts` should be of shape `(instances, 2)`. - tip_xs = tip_pts[:, 0] - - if flatten: - tip_xs = tip_xs.squeeze() - if tip_xs.size == 1: - tip_xs = tip_xs[()] + if tip_pts.ndim not in (1, 2): + raise ValueError( + "Input array must be 2-dimensional (instances, 2) or 1-dimensional (2,)." + ) + if tip_pts.shape[-1] != 2: + raise ValueError("Last dimension must be (x, y).") + + tip_xs = tip_pts[..., 0] return tip_xs -def get_tip_ys(tip_pts: np.ndarray, flatten: bool = False) -> np.ndarray: +def get_tip_ys(tip_pts: np.ndarray) -> np.ndarray: """Get y coordinates of tip points. Args: - tip_pts: Root tips as array of shape `(instances, 2)` or `(2)` when there is - only one tip. - flatten: If `True`, return scalar (0-D) array if scalar (there is only 1 root). - Defaults to `False`. + tip_pts: Root tip points as array of shape `(instances, 2)` or `(2,)` when there + is only one tip. Return: - An array of the y-coordinates of tips (instances,) or () if `flatten` is `True`. + An array of tip y-coordinates (instances,) or (1,) when there is only one root. 
""" - # If the input is a single number (float or integer), raise an error - if isinstance(tip_pts, (np.floating, float, np.integer, int)): - raise ValueError("Input must be an array of shape `(instances, 2)` or `(2, )`.") - - # Check for the 2D shape of the input array - if tip_pts.ndim == 1: - # If shape is `(2,)`, then reshape it to `(1, 2)` for consistency - tip_pts = tip_pts.reshape(1, 2) - elif tip_pts.ndim != 2: - raise ValueError("Input array must be of shape `(instances, 2)` or `(2, )`.") - - # At this point, `tip_pts` should be of shape `(instances, 2)`. - tip_ys = tip_pts[:, 1] - - if flatten: - tip_ys = tip_ys.squeeze() - if tip_ys.size == 1: - tip_ys = tip_ys[()] + if tip_pts.ndim not in (1, 2): + raise ValueError( + "Input array must be 2-dimensional (instances, 2) or 1-dimensional (2,)." + ) + if tip_pts.shape[-1] != 2: + raise ValueError("Last dimension must be (x, y).") + + tip_ys = tip_pts[..., 1] return tip_ys diff --git a/sleap_roots/trait_pipelines.py b/sleap_roots/trait_pipelines.py index 3723f38..672246c 100644 --- a/sleap_roots/trait_pipelines.py +++ b/sleap_roots/trait_pipelines.py @@ -9,7 +9,11 @@ import numpy as np import pandas as pd -from sleap_roots.angle import get_node_ind, get_root_angle +from sleap_roots.angle import ( + get_node_ind, + get_root_angle, + get_vector_angles_from_gravity, +) from sleap_roots.bases import ( get_base_ct_density, get_base_length, @@ -19,16 +23,21 @@ get_base_xs, get_base_ys, get_bases, - get_lateral_count, get_root_widths, ) from sleap_roots.convhull import ( get_chull_area, + get_chull_intersection_vectors, + get_chull_intersection_vectors_left, + get_chull_intersection_vectors_right, get_chull_line_lengths, get_chull_max_height, get_chull_max_width, get_chull_perimeter, get_convhull, + get_chull_areas_via_intersection, + get_chull_area_via_intersection_below, + get_chull_area_via_intersection_above, ) from sleap_roots.ellipse import ( fit_ellipse, @@ -45,7 +54,7 @@ get_network_solidity, 
get_network_width_depth_ratio, ) -from sleap_roots.points import get_all_pts_array +from sleap_roots.points import get_all_pts_array, get_count, get_nodes, join_pts from sleap_roots.scanline import ( count_scanline_intersections, get_scanline_first_ind, @@ -379,18 +388,16 @@ def compute_batch_traits( @attrs.define class DicotPipeline(Pipeline): - """Pipeline for computing traits for dicot plants. + """Pipeline for computing traits for dicot plants (primary + lateral roots). Attributes: img_height: Image height. - img_width: Image width. root_width_tolerance: Difference in projection norm between right and left side. n_scanlines: Number of scan lines, np.nan for no interaction. network_fraction: Length found in the lower fraction value of the network. """ img_height: int = 1080 - img_width: int = 2048 root_width_tolerance: float = 0.02 n_scanlines: int = 50 network_fraction: float = 2 / 3 @@ -413,10 +420,19 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_max_length_pts", "lateral_pts"], scalar=False, include_in_csv=False, - kwargs={"monocots": False}, + kwargs={}, description="Landmark points within a given frame as a flat array" "of coordinates.", ), + TraitDef( + name="pts_list", + fn=join_pts, + input_traits=["primary_max_length_pts", "lateral_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="A list of instance arrays, each having shape `(nodes, 2)`.", + ), TraitDef( name="root_widths", fn=get_root_widths, @@ -425,14 +441,13 @@ def define_traits(self) -> List[TraitDef]: include_in_csv=True, kwargs={ "tolerance": self.root_width_tolerance, - "monocots": False, "return_inds": False, }, description="Estimate root width using bases of lateral roots.", ), TraitDef( name="lateral_count", - fn=get_lateral_count, + fn=get_count, input_traits=["lateral_pts"], scalar=True, include_in_csv=True, @@ -472,7 +487,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["lateral_pts"], scalar=False, include_in_csv=False, - 
kwargs={"monocots": False}, + kwargs={}, description="Array of lateral bases `(instances, (x, y))`.", ), TraitDef( @@ -487,14 +502,12 @@ def define_traits(self) -> List[TraitDef]: TraitDef( name="scanline_intersection_counts", fn=count_scanline_intersections, - input_traits=["primary_max_length_pts", "lateral_pts"], + input_traits=["pts_list"], scalar=False, include_in_csv=True, kwargs={ "height": self.img_height, - "width": self.img_width, "n_line": self.n_scanlines, - "monocots": False, }, description="Array of intersections of each scanline `(n_scanlines,)`.", ), @@ -608,7 +621,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_max_length_pts"], scalar=False, include_in_csv=False, - kwargs={"monocots": False}, + kwargs={}, description="Primary root base point.", ), TraitDef( @@ -623,10 +636,12 @@ def define_traits(self) -> List[TraitDef]: TraitDef( name="network_length_lower", fn=get_network_distribution, - input_traits=["primary_max_length_pts", "lateral_pts", "bounding_box"], + input_traits=["pts_list", "bounding_box"], scalar=True, include_in_csv=True, - kwargs={"fraction": self.network_fraction, "monocots": False}, + kwargs={ + "fraction": self.network_fraction, + }, description="Scalar of the root network length in the lower fraction " "of the plant.", ), @@ -645,7 +660,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["lateral_base_pts"], scalar=False, include_in_csv=True, - kwargs={"monocots": False}, + kwargs={}, description="Array of the y-coordinates of lateral bases " "`(instances,)`.", ), @@ -679,14 +694,10 @@ def define_traits(self) -> List[TraitDef]: TraitDef( name="network_distribution_ratio", fn=get_network_distribution_ratio, - input_traits=[ - "primary_length", - "lateral_lengths", - "network_length_lower", - ], + input_traits=["network_length", "network_length_lower"], scalar=True, include_in_csv=True, - kwargs={"fraction": self.network_fraction, "monocots": False}, + kwargs={}, description="Scalar of ratio of 
the root network length in the lower " "fraction of the plant over all root length.", ), @@ -696,7 +707,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_length", "lateral_lengths"], scalar=True, include_in_csv=True, - kwargs={"monocots": False}, + kwargs={}, description="Scalar of all roots network length.", ), TraitDef( @@ -705,7 +716,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_base_pt"], scalar=True, include_in_csv=False, - kwargs={"monocots": False}, + kwargs={}, description="Y-coordinate of the primary root base node.", ), TraitDef( @@ -714,7 +725,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_tip_pt"], scalar=True, include_in_csv=True, - kwargs={"flatten": True}, + kwargs={}, description="Y-coordinate of the primary root tip node.", ), TraitDef( @@ -807,7 +818,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["lateral_base_ys", "primary_tip_pt_y"], scalar=True, include_in_csv=True, - kwargs={"monocots": False}, + kwargs={}, description="Scalar of base median ratio.", ), TraitDef( @@ -825,7 +836,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_length", "base_length"], scalar=True, include_in_csv=True, - kwargs={"monocots": False}, + kwargs={}, description="Scalar of base length ratio.", ), TraitDef( @@ -882,38 +893,22 @@ def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, A - "primary_pts": Array of primary root points. - "lateral_pts": Array of lateral root points. """ - # Get the root instances. - primary, lateral = plant[frame_idx] - gt_instances_pr = primary.user_instances + primary.unused_predictions - gt_instances_lr = lateral.user_instances + lateral.unused_predictions - - # Convert the instances to numpy arrays. 
- if len(gt_instances_lr) == 0: - lateral_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - lateral_pts = np.stack([inst.numpy() for inst in gt_instances_lr], axis=0) - - if len(gt_instances_pr) == 0: - primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - primary_pts = np.stack([inst.numpy() for inst in gt_instances_pr], axis=0) - + primary_pts = plant.get_primary_points(frame_idx) + lateral_pts = plant.get_lateral_points(frame_idx) return {"primary_pts": primary_pts, "lateral_pts": lateral_pts} @attrs.define class YoungerMonocotPipeline(Pipeline): - """Pipeline for computing traits for young monocot plants (primary + seminal). + """Pipeline for computing traits for young monocot plants (primary + crown roots). Attributes: img_height: Image height. - img_width: Image width. n_scanlines: Number of scan lines, np.nan for no interaction. - network_fraction: Length found in the lower fraction value of the network. + network_fraction: Lower fraction value. Defaults to 2/3. 
""" img_height: int = 1080 - img_width: int = 2048 n_scanlines: int = 50 network_fraction: float = 2 / 3 @@ -932,119 +927,118 @@ def define_traits(self) -> List[TraitDef]: TraitDef( name="pts_all_array", fn=get_all_pts_array, - input_traits=["primary_max_length_pts", "main_pts"], + input_traits=["crown_pts"], scalar=False, include_in_csv=False, - kwargs={"monocots": True}, - description="Landmark points within a given frame as a flat array" + kwargs={}, + description="Crown root points within a given frame as a flat array" "of coordinates.", ), TraitDef( - name="main_count", - fn=get_lateral_count, - input_traits=["main_pts"], + name="crown_count", + fn=get_count, + input_traits=["crown_pts"], scalar=True, include_in_csv=True, kwargs={}, - description="Get the number of main roots.", + description="Get the number of crown roots.", ), TraitDef( - name="main_proximal_node_inds", + name="crown_proximal_node_inds", fn=get_node_ind, - input_traits=["main_pts"], + input_traits=["crown_pts"], scalar=False, include_in_csv=False, kwargs={"proximal": True}, - description="Get the indices of the proximal nodes of main roots.", + description="Get the indices of the proximal nodes of crown roots.", ), TraitDef( - name="main_distal_node_inds", + name="crown_distal_node_inds", fn=get_node_ind, - input_traits=["main_pts"], + input_traits=["crown_pts"], scalar=False, include_in_csv=False, kwargs={"proximal": False}, - description="Get the indices of the distal nodes of main roots.", + description="Get the indices of the distal nodes of crown roots.", ), TraitDef( - name="main_lengths", + name="crown_lengths", fn=get_root_lengths, - input_traits=["main_pts"], + input_traits=["crown_pts"], scalar=False, include_in_csv=True, kwargs={}, - description="Array of main root lengths of shape `(instances,)`.", + description="Array of crown root lengths of shape `(instances,)`.", ), TraitDef( - name="main_base_pts", + name="crown_base_pts", fn=get_bases, - input_traits=["main_pts"], + 
input_traits=["crown_pts"], scalar=False, include_in_csv=False, - kwargs={"monocots": False}, - description="Array of main bases `(instances, (x, y))`.", + kwargs={}, + description="Array of crown bases `(instances, (x, y))`.", ), TraitDef( - name="main_tip_pts", + name="crown_tip_pts", fn=get_tips, - input_traits=["main_pts"], + input_traits=["crown_pts"], scalar=False, include_in_csv=False, kwargs={}, - description="Array of main tips `(instances, (x, y))`.", + description="Array of crown tips `(instances, (x, y))`.", ), TraitDef( name="scanline_intersection_counts", fn=count_scanline_intersections, - input_traits=["primary_max_length_pts", "main_pts"], + input_traits=["crown_pts"], scalar=False, include_in_csv=True, kwargs={ "height": self.img_height, - "width": self.img_width, "n_line": self.n_scanlines, - "monocots": True, }, description="Array of intersections of each scanline" "`(n_scanlines,)`.", ), TraitDef( - name="main_angles_distal", + name="crown_angles_distal", fn=get_root_angle, - input_traits=["main_pts", "main_distal_node_inds"], + input_traits=["crown_pts", "crown_distal_node_inds"], scalar=False, include_in_csv=True, kwargs={"proximal": False, "base_ind": 0}, - description="Array of main distal angles in degrees `(instances,)`.", + description="Array of crown distal angles in degrees `(instances,)`.", ), TraitDef( - name="main_angles_proximal", + name="crown_angles_proximal", fn=get_root_angle, - input_traits=["main_pts", "main_proximal_node_inds"], + input_traits=["crown_pts", "crown_proximal_node_inds"], scalar=False, include_in_csv=True, kwargs={"proximal": True, "base_ind": 0}, - description="Array of main proximal angles in degrees " + description="Array of crown proximal angles in degrees " "`(instances,)`.", ), TraitDef( name="network_length_lower", fn=get_network_distribution, input_traits=[ - "primary_max_length_pts", - "main_pts", + "crown_pts", "bounding_box", ], scalar=True, include_in_csv=True, - kwargs={"fraction": 
self.network_fraction, "monocots": True}, + kwargs={ + "fraction": self.network_fraction, + }, description="Scalar of the root network length in the lower fraction " "of the plant.", ), TraitDef( name="ellipse", fn=fit_ellipse, - input_traits=["pts_all_array"], + input_traits=["crown_pts"], scalar=False, include_in_csv=False, kwargs={}, @@ -1055,7 +1049,7 @@ def define_traits(self) -> List[TraitDef]: TraitDef( name="bounding_box", fn=get_bbox, - input_traits=["pts_all_array"], + input_traits=["crown_pts"], scalar=False, include_in_csv=False, kwargs={}, @@ -1064,7 +1058,7 @@ def define_traits(self) -> List[TraitDef]: TraitDef( name="convex_hull", fn=get_convhull, - input_traits=["pts_all_array"], + input_traits=["crown_pts"], scalar=False, include_in_csv=False, kwargs={}, @@ -1125,7 +1119,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_max_length_pts"], scalar=False, include_in_csv=False, - kwargs={"monocots": False}, + kwargs={}, description="Primary root base point.", ), TraitDef( @@ -1138,64 +1132,63 @@ def define_traits(self) -> List[TraitDef]: description="Primary root tip point.", ), TraitDef( - name="main_tip_xs", + name="crown_tip_xs", fn=get_tip_xs, - input_traits=["main_tip_pts"], + input_traits=["crown_tip_pts"], scalar=False, include_in_csv=True, kwargs={}, - description="Array of the x-coordinates of main tips `(instance,)`.", + description="Array of the x-coordinates of crown tips `(instance,)`.", ), TraitDef( - name="main_tip_ys", + name="crown_tip_ys", fn=get_tip_ys, - input_traits=["main_tip_pts"], + input_traits=["crown_tip_pts"], scalar=False, include_in_csv=True, kwargs={}, - description="Array of the y-coordinates of main tips `(instance,)`.", + description="Array of the y-coordinates of crown tips `(instance,)`.", ), TraitDef( name="network_distribution_ratio", fn=get_network_distribution_ratio, input_traits=[ - "primary_length", - "main_lengths", + "network_length", "network_length_lower", ], scalar=True, 
include_in_csv=True, - kwargs={"fraction": self.network_fraction, "monocots": False}, + kwargs={}, description="Scalar of ratio of the root network length in the lower" "fraction of the plant over all root length.", ), TraitDef( name="network_length", fn=get_network_length, - input_traits=["primary_length", "main_lengths"], + input_traits=["crown_lengths"], scalar=True, include_in_csv=True, - kwargs={"monocots": True}, + kwargs={}, description="Scalar of all roots network length.", ), TraitDef( - name="main_base_tip_dists", + name="crown_base_tip_dists", fn=get_base_tip_dist, - input_traits=["main_base_pts", "main_tip_pts"], + input_traits=["crown_base_pts", "crown_tip_pts"], scalar=False, include_in_csv=True, kwargs={}, description="Straight-line distance(s) from the base(s) to the" - "tip(s) of the main root(s).", + "tip(s) of the crown root(s).", ), TraitDef( - name="main_curve_indices", + name="crown_curve_indices", fn=get_base_tip_dist, - input_traits=["main_base_pts", "main_tip_pts"], + input_traits=["crown_base_pts", "crown_tip_pts"], scalar=False, include_in_csv=True, kwargs={}, - description="Curvature index for each main root.", + description="Curvature index for each crown root.", ), TraitDef( name="network_solidity", @@ -1213,7 +1206,7 @@ def define_traits(self) -> List[TraitDef]: input_traits=["primary_tip_pt"], scalar=True, include_in_csv=True, - kwargs={"flatten": True}, + kwargs={}, description="Y-coordinate of the primary root tip node.", ), TraitDef( @@ -1351,22 +1344,447 @@ def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, A Returns: A dictionary of initial traits with keys: - "primary_pts": Array of primary root points. - - "main_pts": Array of main root points. + - "crown_pts": Array of crown root points. """ - # Get the root instances. 
- primary, main = plant[frame_idx] - gt_instances_pr = primary.user_instances + primary.unused_predictions - gt_instances_lr = main.user_instances + main.unused_predictions - - # Convert the instances to numpy arrays. - if len(gt_instances_lr) == 0: - main_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - main_pts = np.stack([inst.numpy() for inst in gt_instances_lr], axis=0) + primary_pts = plant.get_primary_points(frame_idx) + crown_pts = plant.get_crown_points(frame_idx) + return {"primary_pts": primary_pts, "crown_pts": crown_pts} - if len(gt_instances_pr) == 0: - primary_pts = np.array([[(np.nan, np.nan), (np.nan, np.nan)]]) - else: - primary_pts = np.stack([inst.numpy() for inst in gt_instances_pr], axis=0) - return {"primary_pts": primary_pts, "main_pts": main_pts} +@attrs.define +class OlderMonocotPipeline(Pipeline): + """Pipeline for computing traits for older monocot plants (crown roots only). + + Attributes: + img_height: Image height. + n_scanlines: Number of scan lines, np.nan for no interaction. + network_fraction: Lower fraction value. Defaults to 2/3. 
+ """ + + img_height: int = 1080 + n_scanlines: int = 50 + network_fraction: float = 2 / 3 + + def define_traits(self) -> List[TraitDef]: + """Define the trait computation pipeline for older monocot plants (crown roots).""" + trait_definitions = [ + TraitDef( + name="pts_all_array", + fn=get_all_pts_array, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Landmark points within a given frame as a flat array" + "of coordinates.", + ), + TraitDef( + name="crown_count", + fn=get_count, + input_traits=["crown_pts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Get the number of crown roots.", + ), + TraitDef( + name="crown_proximal_node_inds", + fn=get_node_ind, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={"proximal": True}, + description="Get the indices of the proximal nodes of crown roots.", + ), + TraitDef( + name="crown_distal_node_inds", + fn=get_node_ind, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={"proximal": False}, + description="Get the indices of the distal nodes of crown roots.", + ), + TraitDef( + name="crown_lengths", + fn=get_root_lengths, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of crown root lengths of shape `(instances,)`.", + ), + TraitDef( + name="crown_base_pts", + fn=get_bases, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Array of crown bases `(instances, (x, y))`.", + ), + TraitDef( + name="crown_tip_pts", + fn=get_tips, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Array of crown tips `(instances, (x, y))`.", + ), + TraitDef( + name="scanline_intersection_counts", + fn=count_scanline_intersections, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=True, + kwargs={ + "height": self.img_height, + "n_line": self.n_scanlines, + }, + 
description="Array of intersections of each scanline" + "`(n_scanlines,)`.", + ), + TraitDef( + name="crown_angles_distal", + fn=get_root_angle, + input_traits=["crown_pts", "crown_distal_node_inds"], + scalar=False, + include_in_csv=True, + kwargs={"proximal": False, "base_ind": 0}, + description="Array of crown distal angles in degrees `(instances,)`.", + ), + TraitDef( + name="crown_angles_proximal", + fn=get_root_angle, + input_traits=["crown_pts", "crown_proximal_node_inds"], + scalar=False, + include_in_csv=True, + kwargs={"proximal": True, "base_ind": 0}, + description="Array of crown proximal angles in degrees " + "`(instances,)`.", + ), + TraitDef( + name="network_length_lower", + fn=get_network_distribution, + input_traits=[ + "crown_pts", + "bounding_box", + ], + scalar=True, + include_in_csv=True, + kwargs={ + "fraction": self.network_fraction, + }, + description="Scalar of the root network length in the lower fraction " + "of the plant.", + ), + TraitDef( + name="ellipse", + fn=fit_ellipse, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Tuple of (a, b, ratio) containing the semi-major axis " + "length, semi-minor axis length, and the ratio of the major to minor " + "lengths.", + ), + TraitDef( + name="bounding_box", + fn=get_bbox, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Tuple of four parameters representing bounding box.", + ), + TraitDef( + name="convex_hull", + fn=get_convhull, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Convex hull of the crown points.", + ), + TraitDef( + name="crown_tip_xs", + fn=get_tip_xs, + input_traits=["crown_tip_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of the x-coordinates of crown tips `(instance,)`.", + ), + TraitDef( + name="crown_tip_ys", + fn=get_tip_ys, + input_traits=["crown_tip_pts"], + scalar=False, + 
include_in_csv=True, + kwargs={}, + description="Array of the y-coordinates of crown tips `(instance,)`.", + ), + TraitDef( + name="network_distribution_ratio", + fn=get_network_distribution_ratio, + input_traits=[ + "network_length", + "network_length_lower", + ], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of ratio of the root network length in the lower" + "fraction of the plant over all root length.", + ), + TraitDef( + name="network_length", + fn=get_network_length, + input_traits=["crown_lengths"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of all roots network length.", + ), + TraitDef( + name="crown_base_tip_dists", + fn=get_base_tip_dist, + input_traits=["crown_base_pts", "crown_tip_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Straight-line distance(s) from the base(s) to the" + "tip(s) of the crown root(s).", + ), + TraitDef( + name="crown_curve_indices", + fn=get_base_tip_dist, + input_traits=["crown_base_pts", "crown_tip_pts"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Curvature index for each crown root.", + ), + TraitDef( + name="network_solidity", + fn=get_network_solidity, + input_traits=["network_length", "chull_area"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of the total network length divided by the" + "network convex hull area.", + ), + TraitDef( + name="ellipse_a", + fn=get_ellipse_a, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of semi-major axis length.", + ), + TraitDef( + name="ellipse_b", + fn=get_ellipse_b, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of semi-minor axis length.", + ), + TraitDef( + name="network_width_depth_ratio", + fn=get_network_width_depth_ratio, + input_traits=["bounding_box"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of bounding 
box width to depth ratio of root " + "network.", + ), + TraitDef( + name="chull_perimeter", + fn=get_chull_perimeter, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull perimeter.", + ), + TraitDef( + name="chull_area", + fn=get_chull_area, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull area.", + ), + TraitDef( + name="chull_max_width", + fn=get_chull_max_width, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull maximum width.", + ), + TraitDef( + name="chull_max_height", + fn=get_chull_max_height, + input_traits=["convex_hull"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of convex hull maximum height.", + ), + TraitDef( + name="chull_line_lengths", + fn=get_chull_line_lengths, + input_traits=["convex_hull"], + scalar=False, + include_in_csv=True, + kwargs={}, + description="Array of line lengths connecting any two vertices on the" + "convex hull.", + ), + TraitDef( + name="ellipse_ratio", + fn=get_ellipse_ratio, + input_traits=["ellipse"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of ratio of the minor to major lengths.", + ), + TraitDef( + name="scanline_last_ind", + fn=get_scanline_last_ind, + input_traits=["scanline_intersection_counts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of count_scanline_interaction index for the last" + "interaction.", + ), + TraitDef( + name="scanline_first_ind", + fn=get_scanline_first_ind, + input_traits=["scanline_intersection_counts"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of count_scanline_interaction index for the first" + "interaction.", + ), + TraitDef( + name="crown_r1_pts", + fn=get_nodes, + input_traits=["crown_pts"], + scalar=False, + include_in_csv=False, + kwargs={"node_index": 1}, + 
description="Array of crown bases `(instances, (x, y))`.", + ), + TraitDef( + name="chull_r1_intersection_vectors", + fn=get_chull_intersection_vectors, + input_traits=[ + "crown_base_pts", + "crown_r1_pts", + "crown_pts", + "convex_hull", + ], + scalar=False, + include_in_csv=False, + kwargs={}, + description="A tuple containing vectors from the top left point to the" + "left intersection point, and from the top right point to the right" + "intersection point with the convex hull.", + ), + TraitDef( + name="chull_r1_left_intersection_vector", + fn=get_chull_intersection_vectors_left, + input_traits=["chull_r1_intersection_vectors"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Vector from the base point to the left" + "intersection point with the convex hull.", + ), + TraitDef( + name="chull_r1_right_intersection_vector", + fn=get_chull_intersection_vectors_right, + input_traits=["chull_r1_intersection_vectors"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Vector from the base point to the right" + "intersection point with the convex hull.", + ), + TraitDef( + name="angle_chull_r1_left_intersection_vector", + fn=get_vector_angles_from_gravity, + input_traits=["chull_r1_left_intersection_vector"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Angle of the left intersection vector from gravity.", + ), + TraitDef( + name="angle_chull_r1_right_intersection_vector", + fn=get_vector_angles_from_gravity, + input_traits=["chull_r1_right_intersection_vector"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Angle of the right intersection vector from gravity.", + ), + TraitDef( + name="chull_areas_r1_intersection", + fn=get_chull_areas_via_intersection, + input_traits=["crown_r1_pts", "crown_pts", "convex_hull"], + scalar=False, + include_in_csv=False, + kwargs={}, + description="Tuple of the convex hull areas above and below the r1" + "intersection.", + ), + TraitDef( + 
name="chull_area_above_r1_intersection", + fn=get_chull_area_via_intersection_above, + input_traits=["chull_areas_r1_intersection"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of the convex hull area above the r1 intersection.", + ), + TraitDef( + name="chull_area_below_r1_intersection", + fn=get_chull_area_via_intersection_below, + input_traits=["chull_areas_r1_intersection"], + scalar=True, + include_in_csv=True, + kwargs={}, + description="Scalar of the convex hull area below the r1 intersection.", + ), + ] + + return trait_definitions + + def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, Any]: + """Return initial traits for a plant frame. + + Args: + plant: The plant `Series` object. + frame_idx: The index of the current frame. + + Returns: + A dictionary of initial traits with keys: + - "crown_pts": Array of crown root points. + """ + crown_pts = plant.get_crown_points(frame_idx) + return {"crown_pts": crown_pts} diff --git a/tests/data/canola_7do/919QDUH.lateral_3_nodes.predictions.slp b/tests/data/canola_7do/919QDUH.lateral.predictions.slp old mode 100755 new mode 100644 similarity index 100% rename from tests/data/canola_7do/919QDUH.lateral_3_nodes.predictions.slp rename to tests/data/canola_7do/919QDUH.lateral.predictions.slp diff --git a/tests/data/canola_7do/919QDUH.primary_multi_day.predictions.slp b/tests/data/canola_7do/919QDUH.primary.predictions.slp old mode 100755 new mode 100644 similarity index 100% rename from tests/data/canola_7do/919QDUH.primary_multi_day.predictions.slp rename to tests/data/canola_7do/919QDUH.primary.predictions.slp diff --git a/tests/data/rice_10do/0K9E8BI.crown.predictions.slp b/tests/data/rice_10do/0K9E8BI.crown.predictions.slp new file mode 100644 index 0000000..0bffba8 Binary files /dev/null and b/tests/data/rice_10do/0K9E8BI.crown.predictions.slp differ diff --git a/tests/data/rice_10do/0K9E8BI.h5 b/tests/data/rice_10do/0K9E8BI.h5 new file mode 100644 index 
0000000..baa7763 Binary files /dev/null and b/tests/data/rice_10do/0K9E8BI.h5 differ diff --git a/tests/data/rice_3do/0K9E8BI.crown.predictions.slp b/tests/data/rice_3do/0K9E8BI.crown.predictions.slp new file mode 100644 index 0000000..9f19a8a Binary files /dev/null and b/tests/data/rice_3do/0K9E8BI.crown.predictions.slp differ diff --git a/tests/data/rice_3do/0K9E8BI.primary.predictions.slp b/tests/data/rice_3do/0K9E8BI.primary.predictions.slp new file mode 100644 index 0000000..bac1b56 Binary files /dev/null and b/tests/data/rice_3do/0K9E8BI.primary.predictions.slp differ diff --git a/tests/data/rice_3do/YR39SJX.main_3do_6nodes.predictions.slp b/tests/data/rice_3do/YR39SJX.crown.predictions.slp similarity index 100% rename from tests/data/rice_3do/YR39SJX.main_3do_6nodes.predictions.slp rename to tests/data/rice_3do/YR39SJX.crown.predictions.slp diff --git a/tests/data/rice_3do/YR39SJX.longest_3do_6nodes.predictions.slp b/tests/data/rice_3do/YR39SJX.primary.predictions.slp similarity index 100% rename from tests/data/rice_3do/YR39SJX.longest_3do_6nodes.predictions.slp rename to tests/data/rice_3do/YR39SJX.primary.predictions.slp diff --git a/tests/data/soy_6do/6PR6AA22JK.lateral__nodes.predictions.slp b/tests/data/soy_6do/6PR6AA22JK.lateral.predictions.slp similarity index 100% rename from tests/data/soy_6do/6PR6AA22JK.lateral__nodes.predictions.slp rename to tests/data/soy_6do/6PR6AA22JK.lateral.predictions.slp diff --git a/tests/data/soy_6do/6PR6AA22JK.primary_multi_day.predictions.slp b/tests/data/soy_6do/6PR6AA22JK.primary.predictions.slp similarity index 100% rename from tests/data/soy_6do/6PR6AA22JK.primary_multi_day.predictions.slp rename to tests/data/soy_6do/6PR6AA22JK.primary.predictions.slp diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py index cddb014..a295079 100644 --- a/tests/fixtures/data.py +++ b/tests/fixtures/data.py @@ -16,13 +16,13 @@ def canola_h5(): @pytest.fixture def canola_primary_slp(): """Path to primary root predictions 
for 7 day old canola.""" - return "tests/data/canola_7do/919QDUH.primary_multi_day.predictions.slp" + return "tests/data/canola_7do/919QDUH.primary.predictions.slp" @pytest.fixture def canola_lateral_slp(): """Path to lateral root predictions for 7 day old canola.""" - return "tests/data/canola_7do/919QDUH.lateral_3_nodes.predictions.slp" + return "tests/data/canola_7do/919QDUH.lateral.predictions.slp" @pytest.fixture @@ -40,13 +40,31 @@ def rice_h5(): @pytest.fixture def rice_long_slp(): """Path to longest root predictions for 3 day old rice.""" - return "tests/data/rice_3do/YR39SJX.longest_3do_6nodes.predictions.slp" + return "tests/data/rice_3do/YR39SJX.primary.predictions.slp" @pytest.fixture def rice_main_slp(): """Path to main root predictions for 3 day old rice.""" - return "tests/data/rice_3do/YR39SJX.main_3do_6nodes.predictions.slp" + return "tests/data/rice_3do/YR39SJX.crown.predictions.slp" + + +@pytest.fixture +def rice_10do_folder(): + """Path to a folder with the predictions for 10 day old rice.""" + return "tests/data/rice_10do" + + +@pytest.fixture +def rice_main_10do_h5(): + """Path to root image stack for 10 day old rice.""" + return "tests/data/rice_10do/0K9E8BI.h5" + + +@pytest.fixture +def rice_main_10do_slp(): + """Path to main root predictions for 10 day old rice.""" + return "tests/data/rice_10do/0K9E8BI.crown.predictions.slp" @pytest.fixture @@ -64,7 +82,7 @@ def soy_h5(): @pytest.fixture def soy_primary_slp(): """Path to primary root predictions for 6 day old soy.""" - return "tests/data/soy_6do/6PR6AA22JK.primary_multi_day.predictions.slp" + return "tests/data/soy_6do/6PR6AA22JK.primary.predictions.slp" @pytest.fixture diff --git a/tests/test_angle.py b/tests/test_angle.py index 4761bd4..0550a8b 100644 --- a/tests/test_angle.py +++ b/tests/test_angle.py @@ -1,7 +1,11 @@ import numpy as np import pytest from sleap_roots import Series -from sleap_roots.angle import get_node_ind, get_root_angle +from sleap_roots.angle import ( + 
get_node_ind, + get_root_angle, + get_vector_angles_from_gravity, +) @pytest.fixture @@ -114,15 +118,39 @@ def pts_nan32_5node(): ) +@pytest.mark.parametrize( + "vector, expected_angle", + [ + (np.array([[0, 1]]), 0), # Directly downwards (with gravity) + (np.array([[0, -1]]), 180), # Directly upwards (against gravity) + (np.array([[1, 0]]), 90), # Right, perpendicular to gravity + (np.array([[-1, 0]]), 90), # Left, perpendicular to gravity + (np.array([[1, 1]]), 45), # Diagonal right-down + (np.array([[1, -1]]), 135), # Diagonal right-up, against gravity + (np.array([[-1, 1]]), 45), # Diagonal left-down, aligned with gravity + (np.array([[-1, -1]]), 135), # Diagonal left-up, against gravity + ], +) +def test_get_vector_angle_from_gravity(vector, expected_angle): + """Test get_vector_angle_from_gravity function with vectors from various directions, + considering a coordinate system where positive y-direction is downwards. + """ + angle = get_vector_angles_from_gravity(vector) + np.testing.assert_almost_equal(angle, expected_angle, decimal=3) + + # test get_node_ind function def test_get_node_ind(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() + # Set the frame index to 0 + frame_index = 0 + # Load the series from the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary root points + primary_points = series.get_primary_points(frame_index) + # Set the proximal flag to True proximal = True - node_ind = get_node_ind(pts, proximal) + # Get the node index + node_ind = get_node_ind(primary_points, proximal) np.testing.assert_array_equal(node_ind, 1) @@ -177,30 +205,36 @@ def test_get_node_ind_5node_proximal(pts_nan32_5node): # test canola get_root_angle function (base node to distal node angle) def test_get_root_angle_distal(canola_h5): - series = Series.load( - canola_h5, 
primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() + # Set the frame index to 0 + frame_index = 0 + # Load the series from the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary root points + primary_points = series.get_primary_points(frame_index) + # Set the proximal flag to False proximal = False - node_ind = get_node_ind(pts, proximal) - angs = get_root_angle(pts, node_ind, proximal) - assert pts.shape == (1, 6, 2) + # Get the distal node index + node_ind = get_node_ind(primary_points, proximal) + angs = get_root_angle(primary_points, node_ind, proximal) + assert primary_points.shape == (1, 6, 2) np.testing.assert_almost_equal(angs, 7.7511306, decimal=3) # test rice get_root_angle function (base node to proximal node angle) def test_get_root_angle_proximal_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() + # Set the frame index to 0 + frame_index = 0 + # Load the series from the rice dataset + series = Series.load(rice_h5, primary_name="crown", lateral_name="primary") + # Get the primary root points + primary_points = series.get_primary_points(frame_index) + # Set the proximal flag to True proximal = True - node_ind = get_node_ind(pts, proximal) - angs = get_root_angle(pts, node_ind, proximal) + # Get the proximal node index + node_ind = get_node_ind(primary_points, proximal) + angs = get_root_angle(primary_points, node_ind, proximal) assert angs.shape == (2,) - assert pts.shape == (2, 6, 2) + assert primary_points.shape == (2, 6, 2) np.testing.assert_almost_equal(angs, [17.3180819, 3.2692877], decimal=3) @@ -228,3 +262,79 @@ def test_get_root_angle_proximal_allnan(pts_nanall): node_ind = get_node_ind(pts_nanall, proximal) angs = get_root_angle(pts_nanall, node_ind, proximal) 
     np.testing.assert_almost_equal(angs, np.nan, decimal=3)
+
+
+def test_get_root_angle_horizontal():
+    # Root pointing right, should be 90 degrees from the gravity vector
+    # Note: directions assume a y-up axis; the gravity reference vector is [0, 1]
+    pts = np.array(
+        [[[0, 0], [1, 0]]]
+    )  # Two nodes: base and end node horizontally aligned
+    node_ind = np.array([1])
+    expected_angles = np.array([90])
+    angles = get_root_angle(pts, node_ind)
+    assert np.allclose(angles, expected_angles), "Angle for horizontal root incorrect."
+
+
+def test_get_root_angle_vertical():
+    # Root pointing down, should be 0 degrees from the gravity vector
+    # Note: directions assume a y-up axis; the gravity reference vector is [0, 1]
+    pts = np.array(
+        [[[0, 0], [0, 1]]]
+    )  # Two nodes: base and end node vertically aligned downwards
+    node_ind = np.array([1])
+    expected_angles = np.array([0])
+    angles = get_root_angle(pts, node_ind)
+    assert np.allclose(angles, expected_angles), "Angle for vertical root incorrect."
+
+
+def test_get_root_angle_up_left():
+    # Root pointing up and to the left: should be 45 degrees from the gravity vector
+    # Note: directions assume a y-up axis; the gravity reference vector is [0, 1]
+    pts = np.array(
+        [[[0, 0], [-1, 1]]]
+    )  # Two nodes: base and end node diagonally upwards to the left
+    node_ind = np.array([1])
+    expected_angles = np.array([45])
+    angles = get_root_angle(pts, node_ind)
+    assert np.allclose(angles, expected_angles), "Angle for diagonal root incorrect."
+
+
+def test_get_root_angle_up_right():
+    # Root pointing up and to the right: should be 45 degrees from the gravity vector
+    # Note: directions assume a y-up axis; the gravity reference vector is [0, 1]
+    pts = np.array(
+        [[[0, 0], [1, 1]]]
+    )  # Two nodes: base and end node diagonally upwards to the right
+    node_ind = np.array([1])
+    expected_angles = np.array([45])
+    angles = get_root_angle(pts, node_ind)
+    assert np.allclose(
+        angles, expected_angles
+    ), "Angle for diagonally upwards root incorrect."
+
+
+def test_get_root_angle_down_left():
+    # Root pointing down and to the left: should be 135 degrees from the gravity vector
+    # Note: directions assume a y-up axis; the gravity reference vector is [0, 1]
+    pts = np.array(
+        [[[0, 0], [-1, -1]]]
+    )  # Two nodes: base and end node diagonally downwards to the left
+    node_ind = np.array([1])
+    expected_angles = np.array([135])
+    angles = get_root_angle(pts, node_ind)
+    assert np.allclose(angles, expected_angles), "Angle for diagonal root incorrect."
+
+
+def test_get_root_angle_down_right():
+    # Root pointing down and to the right: should be 135 degrees from the gravity vector
+    # Note: directions assume a y-up axis; the gravity reference vector is [0, 1]
+    pts = np.array(
+        [[[0, 0], [1, -1]]]
+    )  # Two nodes: base and end node diagonally downwards to the right
+    node_ind = np.array([1])
+    expected_angles = np.array([135])
+    angles = get_root_angle(pts, node_ind)
+    assert np.allclose(
+        angles, expected_angles
+    ), "Angle for diagonally downwards root incorrect."
diff --git a/tests/test_bases.py b/tests/test_bases.py
index e529d86..4e07d30 100644
--- a/tests/test_bases.py
+++ b/tests/test_bases.py
@@ -2,14 +2,13 @@
     get_bases,
     get_base_ct_density,
     get_base_tip_dist,
-    get_lateral_count,
     get_base_xs,
     get_base_ys,
     get_base_length,
     get_base_length_ratio,
     get_root_widths,
 )
-from sleap_roots.lengths import get_max_length_pts, get_root_lengths_max
+from sleap_roots.lengths import get_max_length_pts, get_root_lengths
 from sleap_roots.tips import get_tips
 from sleap_roots import Series
 import numpy as np
@@ -215,47 +214,27 @@ def test_get_base_tip_dist_no_roots(pts_no_roots):
     np.testing.assert_almost_equal(distance, [np.nan, np.nan], decimal=7)
 
 
-# test get_lateral_count function with canola
-def test_get_lateral_count(canola_h5):
-    series = Series.load(
-        canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
-    )
-    primary, lateral = series[0]
-    lateral_pts = lateral.numpy()
-    lateral_count = get_lateral_count(lateral_pts)
-    assert lateral_count == 5
- - # test get_base_xs with canola def test_get_base_xs_canola(canola_h5): - monocots = False - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - lateral = plant[0][1] # first frame, lateral labels - lateral_pts = lateral.numpy() # lateral points as numpy array - bases = get_bases(lateral_pts) + # Set the frame idx to 0 + frame_idx = 0 + # Load a series from a canola dataset + plant = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the labeled frame + lateral_points = plant.get_lateral_points(frame_idx) + # Get the lateral root bases + bases = get_bases(lateral_points) + # Get the base x-coordinates base_xs = get_base_xs(bases) assert base_xs.shape[0] == 5 np.testing.assert_almost_equal(base_xs[1], 1112.5506591796875, decimal=3) -# test get_base_xs with rice -def test_get_base_xs_rice(rice_h5): - monocots = True - plant = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - lateral = plant[0][1] # first frame, lateral labels - lateral_pts = lateral.numpy() # lateral points as numpy array - bases = get_bases(lateral_pts, monocots=monocots) - base_xs = get_base_xs(bases) - assert np.isnan(base_xs) - - # test get_base_xs with pts_standard def test_get_base_xs_standard(pts_standard): + # Get the base points bases = get_bases(pts_standard) + # Get the x-coordinates of the base points base_xs = get_base_xs(bases) assert base_xs.shape[0] == 2 np.testing.assert_almost_equal(base_xs[0], 1, decimal=3) @@ -264,39 +243,14 @@ def test_get_base_xs_standard(pts_standard): # test get_base_xs with pts_no_roots def test_get_base_xs_no_roots(pts_no_roots): + # Get the base points bases = get_bases(pts_no_roots) + # Get the x-coordinates of the base points base_xs = get_base_xs(bases) assert base_xs.shape[0] == 2 np.testing.assert_almost_equal(base_xs[0], np.nan, decimal=3) -# test get_base_ys with canola -def test_get_base_ys_canola(canola_h5): - monocots = 
False - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - lateral = plant[0][1] # first frame, lateral labels - lateral_pts = lateral.numpy() # lateral points as numpy array - base_pts = get_bases(lateral_pts) # get the bases of the lateral roots - base_ys = get_base_ys(base_pts, monocots) - assert base_ys.shape[0] == 5 - np.testing.assert_almost_equal(base_ys[1], 228.0966796875, decimal=3) - - -# test get_base_ys with rice -def test_get_base_ys_rice(rice_h5): - monocots = True - plant = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - lateral = plant[0][1] # first frame, lateral labels - lateral_pts = lateral.numpy() # lateral points as numpy array - base_pts = get_bases(lateral_pts, monocots) # get the bases of the lateral roots - base_ys = get_base_ys(base_pts, monocots) - assert np.isnan(base_ys) - - # test get_base_ys with pts_standard def test_get_base_ys_standard(pts_standard): bases = get_bases(pts_standard) @@ -316,28 +270,21 @@ def test_get_base_ys_no_roots(pts_no_roots): # test get_base_length with canola def test_get_base_length_canola(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - lateral = plant[0][1] # first frame, lateral labels - lateral_pts = lateral.numpy() # lateral points as numpy array - bases = get_bases(lateral_pts) # get bases of lateral roots - base_ys = get_base_ys(bases) # get y-coordinates of bases + # Set the frame index to 0 + frame_idx = 0 + # Load a series from a canola dataset + plant = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the lateral points + lateral_pts = plant.get_lateral_points(frame_idx) + # Get the bases of the lateral roots + bases = get_bases(lateral_pts) + # Get the y-coordinates of the bases + base_ys = get_base_ys(bases) + # Get the length of the bases of the lateral roots base_length = 
get_base_length(base_ys) np.testing.assert_almost_equal(base_length, 83.69914245605469, decimal=3) -# test get_base_length with rice -def test_get_base_length_rice(rice_h5): - plant = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - lateral = plant[0][1] # first frame, lateral labels - lateral_pts = lateral.numpy() # lateral points as numpy array - base_length = get_base_length(lateral_pts, monocots=True) - assert np.isnan(base_length) - - # test get_base_length with pts_standard def test_get_base_length_standard(pts_standard): bases = get_bases(pts_standard) # get bases of lateral roots @@ -354,7 +301,8 @@ def test_get_base_length_no_roots(pts_no_roots): # test get_base_ct_density function with defined primary and lateral points def test_get_base_ct_density(primary_pts, lateral_pts): - primary_length_max = get_root_lengths_max(primary_pts) + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length_max = get_root_lengths(primary_max_length_pts) lateral_base_pts = get_bases(lateral_pts) base_ct_density = get_base_ct_density(primary_length_max, lateral_base_pts) np.testing.assert_almost_equal(base_ct_density, 0.00334, decimal=5) @@ -362,58 +310,57 @@ def test_get_base_ct_density(primary_pts, lateral_pts): # test get_base_ct_density function with canola example def test_get_base_ct_density_canola(canola_h5): - monocots = False - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() - primary_length_max = get_root_lengths_max(primary_pts) + # Set the frame index to 0 + frame_idx = 0 + # Load a series from a canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary and lateral points + primary_pts = series.get_primary_points(frame_idx) + lateral_pts = series.get_lateral_points(frame_idx) + # Get the maximum 
length of the primary root + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length_max = get_root_lengths(primary_max_length_pts) + # Get the bases of the lateral roots lateral_base_pts = get_bases(lateral_pts) + # Get the CT density of the bases of the lateral roots base_ct_density = get_base_ct_density(primary_length_max, lateral_base_pts) np.testing.assert_almost_equal(base_ct_density, 0.004119, decimal=5) -# test get_base_ct_density function with rice example -def test_get_base_ct_density_rice(rice_h5): - monocots = True - series = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - primary_max_length = get_root_lengths_max(primary_pts) - lateral_pts = lateral.numpy() - bases = get_bases(lateral_pts, monocots=monocots) - base_ct_density = get_base_ct_density(primary_max_length, bases) - assert np.isnan(base_ct_density) - - # test get_base_length_ratio with canola def test_get_base_length_ratio(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() - primary_length_max = get_root_lengths_max(primary_pts) + # Set the frame index to 0 + frame_idx = 0 + # Load a series from a canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary and lateral points + primary_pts = series.get_primary_points(frame_idx) + lateral_pts = series.get_lateral_points(frame_idx) + # Get the maximum length of the primary root + primary_max_length_pts = get_max_length_pts(primary_pts) + primary_length_max = get_root_lengths(primary_max_length_pts) + # Get the bases of the lateral roots bases = get_bases(lateral_pts) + # Get the y-coordinates of the bases lateral_base_ys = get_base_ys(bases) + # Get the length of the bases of the lateral roots 
base_length = get_base_length(lateral_base_ys) + # Get the length ratio of the bases of the lateral roots base_length_ratio = get_base_length_ratio(primary_length_max, base_length) np.testing.assert_almost_equal(base_length_ratio, 0.086, decimal=3) def test_root_width_canola(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() + # Set the frame index to 0 + frame_idx = 0 + # Load a series from a canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary points + primary_pts = series.get_primary_points(frame_idx) + # Get the primary points with the maximum length primary_max_length_pts = get_max_length_pts(primary_pts) - lateral_pts = lateral.numpy() + # Get the lateral points + lateral_pts = series.get_lateral_points(frame_idx) assert primary_max_length_pts.shape == (6, 2) assert lateral_pts.shape == (5, 3, 2) @@ -421,43 +368,26 @@ def test_root_width_canola(canola_h5): np.testing.assert_almost_equal(root_widths[0], np.array([31.60323909]), decimal=7) -# Test get_root_widths with rice -def test_root_width_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - primary_max_length_pts = get_max_length_pts(primary_pts) - lateral_pts = lateral.numpy() - root_widths = get_root_widths( - primary_max_length_pts, lateral_pts, 0.02, monocots=True, return_inds=False - ) - assert np.allclose(root_widths, np.array([]), atol=1e-7) - - # Test for get_root_widths with return_inds=True @pytest.mark.parametrize( - "primary, lateral, tolerance, monocots, expected", + "primary, lateral, tolerance, expected", [ ( np.array([[0, 0], [1, 1]]), np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]), 0.02, - False, (np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))), ), 
( np.array([[np.nan, np.nan], [np.nan, np.nan]]), np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]), 0.02, - False, (np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))), ), ], ) -def test_get_root_widths(primary, lateral, tolerance, monocots, expected): - result = get_root_widths(primary, lateral, tolerance, monocots, return_inds=True) +def test_get_root_widths(primary, lateral, tolerance, expected): + result = get_root_widths(primary, lateral, tolerance, return_inds=True) np.testing.assert_array_almost_equal(result[0], expected[0]) assert result[1] == expected[1] np.testing.assert_array_almost_equal(result[2], expected[2]) diff --git a/tests/test_convhull.py b/tests/test_convhull.py index 90cb473..a33c279 100644 --- a/tests/test_convhull.py +++ b/tests/test_convhull.py @@ -7,13 +7,60 @@ get_chull_max_height, get_chull_max_width, get_chull_perimeter, + get_chull_division_areas, + get_chull_areas_via_intersection, + get_chull_intersection_vectors, ) from sleap_roots.lengths import get_max_length_pts -from sleap_roots.points import get_all_pts_array +from sleap_roots.points import get_all_pts_array, get_nodes import numpy as np import pytest +@pytest.fixture +def valid_input(): + # Example points forming a convex hull with points above and below a line + pts = np.array([[[0, 0], [2, 2], [4, 0], [2, -2], [0, -4], [4, -4]]]) + rn_pts = np.array([[0, 0], [4, 0]]) # Line from the leftmost to rightmost rn nodes + hull = ConvexHull(pts.reshape(-1, 2)) + expected_area_above = 16.0 + expected_area_below = 4.0 + return rn_pts, pts, hull, (expected_area_above, expected_area_below) + + +@pytest.fixture +def invalid_pts_shape(): + rn_pts = np.array([[0, 0], [1, 1]]) + pts = np.array([1, 2]) # Incorrect shape + return rn_pts, pts + + +@pytest.fixture +def nan_in_rn_pts(): + rn_pts = np.array([[np.nan, np.nan], [1, 1]]) + pts = np.array([[[0, 0], [1, 2], [2, 3]], [[3, 1], [4, 2], [5, 3]]]) + hull = ConvexHull(pts.reshape(-1, 2)) + return rn_pts, pts, hull + + 
+@pytest.fixture +def insufficient_unique_points_for_hull(): + rn_pts = np.array([[0, 0], [1, 1]]) + pts = np.array([[[0, 0], [0, 0], [0, 0]]]) # Only one unique point + return rn_pts, pts + + +@pytest.fixture +def pts_shape_3_6_2(): + return np.array( + [ + [[-1, 0], [-1, 1], [-1, 2], [-1, 3], [-2, 4], [-3, 5]], + [[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [0, 5]], + [[1, 0], [1, 1], [2, 2], [3, 3], [4, 4], [4, 5]], + ] + ) + + @pytest.fixture def pts_nan31_5node(): return np.array( @@ -75,13 +122,16 @@ def lateral_pts(): # test get_convhull function using canola def test_get_convhull_canola(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() + # Set frame index to 0 + frame_index = 0 + # Load the series from the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary and lateral root from the series + primary_pts = series.get_primary_points(frame_index) + lateral_pts = series.get_lateral_points(frame_index) + # Get the maximum length points from the primary root primary_max_length_pts = get_max_length_pts(primary_pts) + # Get all points from the primary and lateral roots pts = get_all_pts_array(primary_max_length_pts, lateral_pts) convex_hull = get_convhull(pts) assert type(convex_hull) == ConvexHull @@ -89,14 +139,17 @@ def test_get_convhull_canola(canola_h5): # test canola model def test_get_convhull_features_canola(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() + # Set the frame index to 0 + frame_index = 0 + # Load the series from the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary and lateral root from the 
series + primary_pts = series.get_primary_points(frame_index) + lateral_pts = series.get_lateral_points(frame_index) primary_max_length_pts = get_max_length_pts(primary_pts) + # Get all points from the primary and lateral roots pts = get_all_pts_array(primary_max_length_pts, lateral_pts) + # Get the convex hull from the points convex_hull = get_convhull(pts) perimeter = get_chull_perimeter(convex_hull) @@ -112,25 +165,23 @@ def test_get_convhull_features_canola(canola_h5): # test rice model def test_get_convhull_features_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() - primary_max_length_pts = get_max_length_pts(primary_pts) - pts = get_all_pts_array(primary_max_length_pts, lateral_pts) - convex_hull = get_convhull(pts) - + # Set the frame index to 0 + frame_index = 0 + # Load the series from the rice dataset + series = Series.load(rice_h5, primary_name="primary", crown_name="crown") + # Get the crown root from the series + crown_pts = series.get_crown_points(frame_index) + # Get the convex hull from the points + convex_hull = get_convhull(crown_pts) perimeter = get_chull_perimeter(convex_hull) area = get_chull_area(convex_hull) max_width = get_chull_max_width(convex_hull) max_height = get_chull_max_height(convex_hull) - np.testing.assert_almost_equal(perimeter, 1458.8585933576614, decimal=3) - np.testing.assert_almost_equal(area, 23878.72090798154, decimal=3) - np.testing.assert_almost_equal(max_width, 64.4229736328125, decimal=3) - np.testing.assert_almost_equal(max_height, 720.0375061035156, decimal=3) + np.testing.assert_almost_equal(perimeter, 1450.6365795858003, decimal=3) + np.testing.assert_almost_equal(area, 23722.883102604676, decimal=3) + np.testing.assert_almost_equal(max_width, 64.341064453125, decimal=3) + np.testing.assert_almost_equal(max_height, 715.6949920654297, decimal=3) # test 
plant with 2 roots/instances with nan nodes @@ -171,12 +222,14 @@ def test_get_chull_perimeter(lateral_pts): # test get_chull_line_lengths with canola def test_get_chull_line_lengths(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() + # Set the frame index to 0 + frame_index = 0 + # Load the series from the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary and lateral root from the series + primary_pts = series.get_primary_points(frame_index) + lateral_pts = series.get_lateral_points(frame_index) + # Get the maximum length points from the primary root primary_max_length_pts = get_max_length_pts(primary_pts) pts = get_all_pts_array(primary_max_length_pts, lateral_pts) convex_hull = get_convhull(pts) @@ -189,3 +242,102 @@ def test_get_chull_line_lengths(canola_h5): def test_get_chull_line_lengths_nonehull(pts_nan_5node): chull_line_lengths = get_chull_line_lengths(pts_nan_5node) np.testing.assert_almost_equal(chull_line_lengths, np.nan, decimal=3) + + +def test_get_chull_division_areas(pts_shape_3_6_2): + # Points arranged in a way that the line between the leftmost and rightmost + # r1 nodes has and area above it and below it + hull = get_convhull(pts_shape_3_6_2) + r1_pts = get_nodes(pts_shape_3_6_2, 1) + above, below = get_chull_division_areas(r1_pts, pts_shape_3_6_2, hull) + np.testing.assert_almost_equal(above, 2.0, decimal=3) + np.testing.assert_almost_equal(below, 16.0, decimal=3) + + +def test_get_chull_area_via_intersection_valid(valid_input): + rn_pts, pts, hull, expected_areas = valid_input + above, below = get_chull_areas_via_intersection(rn_pts, pts, hull) + np.testing.assert_almost_equal(above, expected_areas[0], decimal=3) + np.testing.assert_almost_equal(below, expected_areas[1], decimal=3) + + +def 
test_invalid_pts_shape_area_via_intersection(invalid_pts_shape): + rn_pts, pts = invalid_pts_shape + with pytest.raises(ValueError): + _ = get_chull_areas_via_intersection(rn_pts, pts, None) + + +def test_nan_in_rn_pts_area_via_intersection(nan_in_rn_pts): + rn_pts, pts, hull = nan_in_rn_pts + area_above_line, area_below_line = get_chull_areas_via_intersection( + rn_pts, pts, hull + ) + assert np.isnan(area_above_line) and np.isnan( + area_below_line + ), "Expected NaN areas when rn_pts contains NaN values" + + +def test_insufficient_unique_points_for_hull_area_via_intersection( + insufficient_unique_points_for_hull, +): + rn_pts, pts = insufficient_unique_points_for_hull + area_above_line, area_below_line = get_chull_areas_via_intersection( + rn_pts, pts, None + ) + assert np.isnan(area_above_line) and np.isnan( + area_below_line + ), "Expected NaN areas when there are insufficient unique points for a convex hull" + + +# Helper function to create a convex hull from points +def create_convex_hull_from_points(points): + return ConvexHull(points) + + +# Basic functionality test +def test_basic_functionality(pts_shape_3_6_2): + r0_pts = pts_shape_3_6_2[:, 0, :] + r1_pts = pts_shape_3_6_2[:, 1, :] + pts = pts_shape_3_6_2 + hull = create_convex_hull_from_points(pts.reshape(-1, 2)) + + left_vector, right_vector = get_chull_intersection_vectors( + r0_pts, r1_pts, pts, hull + ) + + # Assertions depend on the expected outcome, which you'll need to calculate based on your function's logic + assert not np.isnan(left_vector).any(), "Left vector should not contain NaNs" + assert not np.isnan(right_vector).any(), "Right vector should not contain NaNs" + + +# Test with invalid input shapes +@pytest.mark.parametrize( + "invalid_input", + [ + (np.array([1, 2]), np.array([3, 4]), np.array([[[1, 2], [3, 4]]]), None), + (np.array([[1, 2, 3]]), np.array([[3, 4]]), np.array([[[1, 2], [3, 4]]]), None), + # Add more invalid inputs as needed + ], +) +def 
test_invalid_input_shapes(invalid_input): + r0_pts, rn_pts, pts, hull = invalid_input + with pytest.raises(ValueError): + get_chull_intersection_vectors(r0_pts, rn_pts, pts, hull) + + +# Test with no convex hull +def test_no_convex_hull(): + r0_pts = np.array([[1, 1], [2, 2]]) + rn_pts = np.array([[3, 3], [4, 4]]) + pts = np.array([[[1, 1], [2, 2], [3, 3], [4, 4]]]) + + left_vector, right_vector = get_chull_intersection_vectors( + r0_pts, rn_pts, pts, None + ) + + assert np.isnan( + left_vector + ).all(), "Expected NaN vector for left_vector when hull is None" + assert np.isnan( + right_vector + ).all(), "Expected NaN vector for right_vector when hull is None" diff --git a/tests/test_ellipse.py b/tests/test_ellipse.py index 74828bd..fa9bfbc 100644 --- a/tests/test_ellipse.py +++ b/tests/test_ellipse.py @@ -12,13 +12,13 @@ def test_get_ellipse(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] + # Set the frame index = 0 + frame_index = 0 + # Load the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = series.get_primary_points(frame_index) # only test ellipse for primary root points - pts = primary.numpy() - a, b, ratio = fit_ellipse(pts) + a, b, ratio = fit_ellipse(primary_pts) np.testing.assert_almost_equal(a, 733.3038028507555, decimal=3) np.testing.assert_almost_equal(b, 146.47723651978848, decimal=3) np.testing.assert_almost_equal(ratio, 5.006264591916579, decimal=3) @@ -34,66 +34,20 @@ def test_get_ellipse(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): assert np.isnan(ratio) -def test_get_ellipse_a(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = plant[0] - primary_pts = primary.numpy() +def 
test_get_ellipse_all_points(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): + # Set the frame index = 0 + frame_index = 0 + # Load the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = series.get_primary_points(frame_index) primary_max_length_pts = get_max_length_pts(primary_pts) - lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=False - ) + lateral_pts = series.get_lateral_points(frame_index) + pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts) ellipse_a = get_ellipse_a(pts_all_array) - np.testing.assert_almost_equal(ellipse_a, 398.1275346610801, decimal=3) - - -def test_get_ellipse_b(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = plant[0] - primary_pts = primary.numpy() - primary_max_length_pts = get_max_length_pts(primary_pts) - lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=False - ) ellipse_b = get_ellipse_b(pts_all_array) - np.testing.assert_almost_equal(ellipse_b, 115.03734180292595, decimal=3) - - -def test_get_ellipse_ratio(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = plant[0] - primary_pts = primary.numpy() - primary_max_length_pts = get_max_length_pts(primary_pts) - lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=False - ) ellipse_ratio = get_ellipse_ratio(pts_all_array) - np.testing.assert_almost_equal(ellipse_ratio, 3.460854783511295, decimal=3) - - -def test_get_ellipse_ratio_ellipse( - canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"], -): - plant = Series.load( - canola_h5, 
primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = plant[0] - primary_pts = primary.numpy() - primary_max_length_pts = get_max_length_pts(primary_pts) - lateral_pts = lateral.numpy() - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=False - ) - ellipse = fit_ellipse(pts_all_array) - ellipse_ratio = get_ellipse_ratio(ellipse) + np.testing.assert_almost_equal(ellipse_a, 398.1275346610801, decimal=3) + np.testing.assert_almost_equal(ellipse_b, 115.03734180292595, decimal=3) np.testing.assert_almost_equal(ellipse_ratio, 3.460854783511295, decimal=3) @@ -102,10 +56,6 @@ def test_fit_ellipse(): pts = np.array([]) assert np.isnan(fit_ellipse(pts)).all() - # Test when pts has less than 5 points - pts = np.array([[0, 0], [1, 1], [2, 2], [3, 3]]) - assert np.isnan(fit_ellipse(pts)).all() - # Test when pts has NaNs only pts = np.array([[np.nan, np.nan], [np.nan, np.nan], [np.nan, np.nan]]) assert np.isnan(fit_ellipse(pts)).all() diff --git a/tests/test_lengths.py b/tests/test_lengths.py index 6df97ed..bd58299 100644 --- a/tests/test_lengths.py +++ b/tests/test_lengths.py @@ -1,7 +1,6 @@ from sleap_roots.lengths import ( get_curve_index, get_root_lengths, - get_root_lengths_max, get_max_length_pts, ) from sleap_roots.bases import get_base_tip_dist, get_bases @@ -148,16 +147,22 @@ def lengths_all_nan(): # tests for get_curve_index function def test_get_curve_index_canola(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary = series[0][0] # first frame, primary labels - primary_pts = primary.numpy() # primary points as numpy array - primary_length = get_root_lengths_max(primary_pts) - max_length_pts = get_max_length_pts(primary_pts) - bases = get_bases(max_length_pts) - tips = get_tips(max_length_pts) + # Set the frame index to 0 + frame_idx = 0 + # Load the series from the canola dataset + series = Series.load(canola_h5, 
primary_name="primary", lateral_name="lateral") + # Get the primary points from the first frame + primary_pts = series.get_primary_points(frame_idx) + # Get the points from the root with maximum length + primary_pts = get_max_length_pts(primary_pts) + # Get the length of the root with maximum length + primary_length = get_root_lengths(primary_pts) + # Get the bases and tips of the root + bases = get_bases(primary_pts) + tips = get_tips(primary_pts) + # Get the distance between the base and the tip of the root base_tip_dist = get_base_tip_dist(bases, tips) + # Get the curvature index of the root curve_index = get_curve_index(primary_length, base_tip_dist) np.testing.assert_almost_equal(curve_index, 0.08898137324716636) @@ -254,22 +259,49 @@ def test_invalid_scalar_values(): # tests for `get_root_lengths` -def test_get_root_lengths(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" +def test_curve_index_float(): + assert get_curve_index(10.0, 5.0) == 0.5 + + +def test_curve_index_float_invalid(): + assert np.isnan(get_curve_index(np.nan, 5.0)) + + +def test_curve_index_array(): + lengths = np.array([10, 20, 30, 0, np.nan]) + base_tip_dists = np.array([5, 15, 25, 0, np.nan]) + expected = np.array([0.5, 0.25, 0.16666667, np.nan, np.nan]) + np.testing.assert_allclose( + get_curve_index(lengths, base_tip_dists), expected, rtol=1e-6 ) - primary, lateral = series[0] - pts = primary.numpy() - assert pts.shape == (1, 6, 2) - root_lengths = get_root_lengths(pts) - assert np.isscalar(root_lengths) - np.testing.assert_array_almost_equal(root_lengths, [971.050417]) - pts = lateral.numpy() - assert pts.shape == (5, 3, 2) +def test_curve_index_mixed_invalid(): + lengths = np.array([10, np.nan, 0]) + base_tip_dists = np.array([5, 5, 5]) + expected = np.array([0.5, np.nan, np.nan]) + np.testing.assert_allclose( + get_curve_index(lengths, base_tip_dists), expected, rtol=1e-6 + ) + - root_lengths = get_root_lengths(pts) 
+def test_get_root_lengths(canola_h5): + # Set the frame index to 0 + frame_idx = 0 + # Load the series from the canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary points from the first frame + primary_pts = series.get_primary_points(frame_idx) + assert primary_pts.shape == (1, 6, 2) + # Get the root lengths + root_lengths = get_root_lengths(primary_pts) + assert np.isscalar(root_lengths) + np.testing.assert_array_almost_equal(root_lengths, [971.050417]) + # Get the lateral points from the first frame + lateral_pts = series.get_lateral_points(frame_idx) + assert lateral_pts.shape == (5, 3, 2) + # Get the root lengths + root_lengths = get_root_lengths(lateral_pts) assert root_lengths.shape == (5,) np.testing.assert_array_almost_equal( root_lengths, [20.129579, 62.782368, 80.268003, 34.925591, 3.89724] @@ -290,30 +322,14 @@ def test_get_root_lengths_one_point(pts_one_base): ) -# test get_root_lengths_max function with lengths_normal -def test_get_root_lengths_max_normal(lengths_normal): - max_length = get_root_lengths_max(lengths_normal) - np.testing.assert_array_almost_equal(max_length, 329.4) - - -# test get_root_lengths_max function with lengths_with_nan -def test_get_root_lengths_max_with_nan(lengths_with_nan): - max_length = get_root_lengths_max(lengths_with_nan) - np.testing.assert_array_almost_equal(max_length, 329.4) - - -# test get_root_lengths_max function with lengths_all_nan -def test_get_root_lengths_max_all_nan(lengths_all_nan): - max_length = get_root_lengths_max(lengths_all_nan) - np.testing.assert_array_almost_equal(max_length, np.nan) - - def test_get_max_length_pts(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary = series[0][0] # first frame, primary labels - primary_pts = primary.numpy() # primary points as numpy array + # Set the frame index to 0 + frame_idx = 0 + # Load the series from the canola 
dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary points from the first frame + primary_pts = series.get_primary_points(frame_idx) + # Get the points from the root with maximum length max_length_pts = get_max_length_pts(primary_pts) assert max_length_pts.shape == (6, 2) np.testing.assert_almost_equal( diff --git a/tests/test_networklength.py b/tests/test_networklength.py index f244cec..a98be95 100644 --- a/tests/test_networklength.py +++ b/tests/test_networklength.py @@ -10,7 +10,7 @@ from sleap_roots.networklength import get_network_length from sleap_roots.networklength import get_network_solidity from sleap_roots.networklength import get_network_width_depth_ratio -from sleap_roots.points import get_all_pts_array +from sleap_roots.points import get_all_pts_array, join_pts @pytest.fixture @@ -27,24 +27,26 @@ def pts_nan3(): def test_get_bbox(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - bbox = get_bbox(pts) + # Set the frame index = 0 + frame_index = 0 + # Load the series from canola + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = series.get_primary_points(frame_index) + lateral_pts = series.get_lateral_points(frame_index) + pts_all_array = get_all_pts_array(primary_pts, lateral_pts) + bbox = get_bbox(pts_all_array) np.testing.assert_almost_equal( - bbox, [1016.7844238, 144.4191589, 192.1080322, 876.5622253], decimal=7 + bbox, [1016.7844238, 144.4191589, 211.2792969, 876.5622253], decimal=7 ) def test_get_bbox_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - bbox = get_bbox(pts) + # Set the frame index = 0 + frame_index = 0 + # Load the series from rice + series = Series.load(rice_h5, crown_name="crown", 
primary_name="primary") + crown_pts = series.get_crown_points(frame_index) + bbox = get_bbox(crown_pts) np.testing.assert_almost_equal( bbox, [796.2611694, 248.6078033, 64.3410645, 715.6949921], decimal=7 ) @@ -59,22 +61,22 @@ def test_get_bbox_nan(pts_nan3): def test_get_network_width_depth_ratio(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - ratio = get_network_width_depth_ratio(pts) + # Set the frame index = 0 + frame_index = 0 + # Load the series from canola + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = series.get_primary_points(frame_index) + ratio = get_network_width_depth_ratio(primary_pts) np.testing.assert_almost_equal(ratio, 0.2191607471467916, decimal=7) def test_get_network_width_depth_ratio_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="main_3do_6nodes", lateral_name="longest_3do_6nodes" - ) - primary, lateral = series[0] - pts = primary.numpy() - ratio = get_network_width_depth_ratio(pts) + # Set the frame index = 0 + frame_index = 0 + # Load the series from rice + series = Series.load(rice_h5, crown_name="crown", primary_name="primary") + crown_pts = series.get_crown_points(frame_index) + ratio = get_network_width_depth_ratio(crown_pts) np.testing.assert_almost_equal(ratio, 0.0899001182996162, decimal=7) @@ -84,59 +86,103 @@ def test_get_network_width_depth_ratio_nan(pts_nan3): np.testing.assert_almost_equal(ratio, np.nan, decimal=7) +def test_get_network_distribution_basic_functionality(): + pts_list = [np.array([[0, 0], [4, 0]]), np.array([[0, 1], [4, 1]])] + bounding_box = (0, 0, 4, 1) + fraction = 2 / 3 + result = get_network_distribution(pts_list, bounding_box, fraction) + assert ( + result == 4 + ) # Only the first line segment is in the lower 2/3 of the bounding box + + +def test_get_network_distribution_invalid_shape(): + with 
pytest.raises(ValueError): + pts_list = [np.array([0, 1])] + bounding_box = (0, 0, 4, 4) + get_network_distribution(pts_list, bounding_box) + + +def test_get_network_distribution_invalid_bounding_box(): + with pytest.raises(ValueError): + pts_list = [np.array([[0, 0], [4, 0]])] + bounding_box = (0, 0, 4) + get_network_distribution(pts_list, bounding_box) + + +def test_get_network_distribution_with_nan(): + # NaNs should be filtered out + pts_list = [np.array([[0, 0], [4, 0]]), np.array([[0, 1], [4, np.nan]])] + bounding_box = (0, 0, 4, 1) + fraction = 2 / 3 + result = get_network_distribution(pts_list, bounding_box, fraction) + assert ( + result == 0.0 + ) # Given (0,0) is the top-left, the line segment is in the upper 1/3 + + +def test_get_network_distribution_with_nan_nonzero_length(): + # First line segment is at y = 2/3, which is in the lower 2/3 of the bounding box. + # Second line segment has a NaN value and will be filtered out. + pts_list = [np.array([[0, 2 / 3], [4, 2 / 3]]), np.array([[0, 1], [4, np.nan]])] + bounding_box = (0, 0, 4, 1) + fraction = 2 / 3 + result = get_network_distribution(pts_list, bounding_box, fraction) + assert ( + result == 4.0 + ) # Only the first line segment is in the lower 2/3 and its length is 4. 
+
+
+def test_get_network_distribution_full_fraction():
+    pts_list = [np.array([[0, 0], [4, 0]]), np.array([[0, 1], [4, 1]])]
+    bounding_box = (0, 0, 4, 1)
+    fraction = 1  # Cover the whole bounding box
+    result = get_network_distribution(pts_list, bounding_box, fraction)
+    assert result == 8  # Both line segments are in the lower part of the bounding box
+
+
 def test_get_network_length(canola_h5):
-    series = Series.load(
-        canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
-    )
-    primary, lateral = series[0]
-    primary_pts = primary.numpy()
+    # Set the frame index = 0
+    frame_index = 0
+    # Load the series from canola
+    series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral")
+    primary_pts = series.get_primary_points(frame_index)
     # get primary length
     primary_max_length_pts = get_max_length_pts(primary_pts)
     primary_length = get_root_lengths(primary_max_length_pts)
     # get lateral_lengths
-    lateral_pts = lateral.numpy()
+    lateral_pts = series.get_lateral_points(frame_index)
     lateral_lengths = get_root_lengths(lateral_pts)
-    monocots = False
-    length = get_network_length(primary_length, lateral_lengths, monocots)
+    length = get_network_length(primary_length, lateral_lengths)
     np.testing.assert_almost_equal(length, 1173.0531992388217, decimal=7)
 
 
 def test_get_network_length_rice(rice_h5):
-    series = Series.load(
-        rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes"
-    )
-    primary, lateral = series[0]
-    primary_pts = primary.numpy()
-    # get primary length
-    primary_max_length_pts = get_max_length_pts(primary_pts)
-    primary_length = get_root_lengths(primary_max_length_pts)
-    # get lateral_lengths
-    lateral_pts = lateral.numpy()
-    lateral_lengths = get_root_lengths(lateral_pts)
-    monocots = True
-    length = get_network_length(primary_length, lateral_lengths, monocots)
+    # Set the frame index = 0
+    frame_index = 0
+    series = Series.load(rice_h5, primary_name="primary", crown_name="crown")
+    crown_pts = 
series.get_crown_points(frame_index) + crown_lengths = get_root_lengths(crown_pts) + length = get_network_length(crown_lengths) np.testing.assert_almost_equal(length, 798.5726441151357, decimal=7) def test_get_network_solidity(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() + # Set the frame index = 0 + frame_index = 0 + # Load the series from canola + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = series.get_primary_points(frame_index) # get primary length primary_max_length_pts = get_max_length_pts(primary_pts) primary_length = get_root_lengths(primary_max_length_pts) # get lateral_lengths - lateral_pts = lateral.numpy() + lateral_pts = series.get_lateral_points(frame_index) lateral_lengths = get_root_lengths(lateral_pts) - monocots = False - network_length = get_network_length(primary_length, lateral_lengths, monocots) + network_length = get_network_length(primary_length, lateral_lengths) # get chull_area - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=monocots - ) + pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts) convex_hull = get_convhull(pts_all_array) chull_area = get_chull_area(convex_hull) @@ -145,27 +191,15 @@ def test_get_network_solidity(canola_h5): def test_get_network_solidity_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - # get primary length - primary_max_length_pts = get_max_length_pts(primary_pts) - primary_length = get_root_lengths(primary_max_length_pts) - # get lateral_lengths - lateral_pts = lateral.numpy() - lateral_lengths = get_root_lengths(lateral_pts) - monocots = True - network_length = get_network_length(primary_length, lateral_lengths, monocots) - - # get 
chull_area - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=monocots - ) - convex_hull = get_convhull(pts_all_array) + # Set the frame index = 0 + frame_index = 0 + # Load the series from rice + series = Series.load(rice_h5, primary_name="primary", crown_name="crown") + crown_pts = series.get_crown_points(frame_index) + crown_lengths = get_root_lengths(crown_pts) + network_length = get_network_length(crown_lengths) + convex_hull = get_convhull(crown_pts) chull_area = get_chull_area(convex_hull) - ratio = get_network_solidity(network_length, chull_area) np.testing.assert_almost_equal(ratio, 0.03366254601775008, decimal=7) @@ -178,12 +212,9 @@ def test_get_network_distribution_one_point(): ) # One of the roots has only one point bounding_box = (0, 0, 10, 10) fraction = 2 / 3 - monocots = False - + pts = join_pts(primary_pts, lateral_pts) # Call the function - network_length = get_network_distribution( - primary_pts, lateral_pts, bounding_box, fraction, monocots - ) + network_length = get_network_distribution(pts, bounding_box, fraction) # Define the expected result # Only the valid roots should be considered in the calculation @@ -203,8 +234,8 @@ def test_get_network_distribution_empty_arrays(): primary_pts = np.full((2, 2), np.nan) lateral_pts = np.full((2, 2, 2), np.nan) bounding_box = (0, 0, 10, 10) - - network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box) + pts = join_pts(primary_pts, lateral_pts) + network_length = get_network_distribution(pts, bounding_box) assert network_length == 0 @@ -212,9 +243,8 @@ def test_get_network_distribution_with_nans(): primary_pts = np.array([[1, 1], [2, 2], [np.nan, np.nan]]) lateral_pts = np.array([[[4, 4], [5, 5], [np.nan, np.nan]]]) bounding_box = (0, 0, 10, 10) - - network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box) - + pts = join_pts(primary_pts, lateral_pts) + network_length = get_network_distribution(pts, bounding_box) lower_box = 
Polygon(
        [(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
    )
@@ -222,19 +252,15 @@ def test_get_network_distribution_with_nans():
         LineString(primary_pts[:-1]).intersection(lower_box).length
         + LineString(lateral_pts[0, :-1]).intersection(lower_box).length
     )
-
     assert network_length == pytest.approx(expected_length)
 
 
-def test_get_network_distribution_monocots():
+def test_get_network_distribution_join_pts():
     primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
     lateral_pts = np.array([[[4, 4], [5, 5]]])
     bounding_box = (0, 0, 10, 10)
-    monocots = True
-
-    network_length = get_network_distribution(
-        primary_pts, lateral_pts, bounding_box, monocots=monocots
-    )
+    pts = join_pts(primary_pts, lateral_pts)
+    network_length = get_network_distribution(pts, bounding_box)
 
     lower_box = Polygon(
         [(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
@@ -251,10 +277,8 @@ def test_get_network_distribution_different_fraction():
     lateral_pts = np.array([[[4, 4], [5, 5]]])
     bounding_box = (0, 0, 10, 10)
     fraction = 0.5
-
-    network_length = get_network_distribution(
-        primary_pts, lateral_pts, bounding_box, fraction=fraction
-    )
+    pts = join_pts(primary_pts, lateral_pts)
+    network_length = get_network_distribution(pts, bounding_box, fraction)
 
     lower_box = Polygon(
         [(0, 10 - 10 * fraction), (0, 10), (10, 10), (10, 10 - 10 * fraction)]
@@ -268,106 +292,77 @@ def test_get_network_distribution(canola_h5):
-    series = Series.load(
-        canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
-    )
-    primary, lateral = series[0]
-    primary_pts = primary.numpy()
+    # Set the frame index = 0
+    frame_index = 0
+    # Load the series from canola
+    series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral")
+    primary_pts = series.get_primary_points(frame_index)
     primary_max_length_pts = get_max_length_pts(primary_pts)
-    lateral_pts = lateral.numpy()
-    monocots = False
-    
pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts, monocots) + lateral_pts = series.get_lateral_points(frame_index) + pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts) bbox = get_bbox(pts_all_array) + pts_all_list = join_pts(primary_max_length_pts, lateral_pts) fraction = 2 / 3 - monocots = False - root_length = get_network_distribution( - primary_max_length_pts, lateral_pts, bbox, fraction, monocots - ) + root_length = get_network_distribution(pts_all_list, bbox, fraction) np.testing.assert_almost_equal(root_length, 589.4322131363684, decimal=7) def test_get_network_distribution_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - primary_max_length_pts = get_max_length_pts(primary_pts) - lateral_pts = lateral.numpy() - monocots = True - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=monocots - ) - bbox = get_bbox(pts_all_array) + # Set the frame index = 0 + frame_index = 0 + # Load the series from rice + series = Series.load(rice_h5, primary_name="primary", crown_name="crown") + crown_pts = series.get_crown_points(frame_index) + bbox = get_bbox(crown_pts) fraction = 2 / 3 - root_length = get_network_distribution( - primary_max_length_pts, lateral_pts, bbox, fraction, monocots - ) + root_length = get_network_distribution(crown_pts, bbox, fraction) np.testing.assert_almost_equal(root_length, 477.77168597561507, decimal=7) def test_get_network_distribution_ratio(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - monocots = False - primary, lateral = series[0] - primary_pts = primary.numpy() + # Set the frame index = 0 + frame_index = 0 + # Load the series from canola + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = 
series.get_primary_points(frame_index) # get primary length primary_max_length_pts = get_max_length_pts(primary_pts) primary_length = get_root_lengths(primary_max_length_pts) # get lateral lengths - lateral_pts = lateral.numpy() + lateral_pts = series.get_lateral_points(frame_index) lateral_lengths = get_root_lengths(lateral_pts) # get pts_all_array - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=monocots - ) + pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts) bbox = get_bbox(pts_all_array) + pts_all_list = join_pts(primary_max_length_pts, lateral_pts) # get network_length_lower - network_length_lower = get_network_distribution( - primary_max_length_pts, lateral_pts, bbox - ) - fraction = 2 / 3 + network_length_lower = get_network_distribution(pts_all_list, bbox) + # get total network length + network_length = get_network_length(primary_length, lateral_lengths) + # get ratio of network length in lower 2/3 of bounding box to total network length ratio = get_network_distribution_ratio( - primary_length, - lateral_lengths, + network_length, network_length_lower, - fraction, - monocots, ) np.testing.assert_almost_equal(ratio, 0.5024769665338648, decimal=7) def test_get_network_distribution_ratio_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - monocots = True + # Set the frame index = 0 + frame_index = 0 + # Load the series from rice + series = Series.load(rice_h5, primary_name="primary", crown_name="crown") fraction = 2 / 3 - primary, lateral = series[0] - primary_pts = primary.numpy() - # get primary length - primary_max_length_pts = get_max_length_pts(primary_pts) - primary_length = get_root_lengths(primary_max_length_pts) - # get lateral lengths - lateral_pts = lateral.numpy() - lateral_lengths = get_root_lengths(lateral_pts) - # get pts_all_array - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, 
monocots=monocots - ) - bbox = get_bbox(pts_all_array) + crown_pts = series.get_crown_points(frame_index) + crown_lengths = get_root_lengths(crown_pts) + bbox = get_bbox(crown_pts) # get network_length_lower - network_length_lower = get_network_distribution( - primary_max_length_pts, lateral_pts, bbox, fraction=fraction, monocots=monocots - ) + network_length_lower = get_network_distribution(crown_pts, bbox, fraction=fraction) + # get total network length + network_length = get_network_length(crown_lengths) + # get ratio of network length in lower 2/3 of bounding box to total network length ratio = get_network_distribution_ratio( - primary_length, - lateral_lengths, + network_length, network_length_lower, - fraction, - monocots, ) - np.testing.assert_almost_equal(ratio, 0.5982820592421038, decimal=7) diff --git a/tests/test_points.py b/tests/test_points.py index 3fdf368..f9b4c3e 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -1,41 +1,357 @@ +import pytest +import numpy as np from sleap_roots import Series -from sleap_roots.lengths import get_max_length_pts, get_root_lengths +from sleap_roots.lengths import get_max_length_pts +from sleap_roots.points import get_count, join_pts from sleap_roots.points import ( get_all_pts_array, + get_nodes, + get_left_right_normalized_vectors, + get_left_normalized_vector, + get_right_normalized_vector, + get_line_equation_from_points, ) +# test get_count function with canola +def test_get_lateral_count(canola_h5): + # Set frame index to 0 + frame_idx = 0 + # Load the series + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the lateral points + lateral_points = series.get_lateral_points(frame_idx) + # Get the count of lateral roots + lateral_count = get_count(lateral_points) + assert lateral_count == 5 + + +def test_join_pts_basic_functionality(): + pts1 = np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]]) + pts2 = np.array([[8, 9], [10, 11]]) + result = join_pts(pts1, pts2) + + 
expected = [ + np.array([[0, 1], [2, 3]]), + np.array([[4, 5], [6, 7]]), + np.array([[8, 9], [10, 11]]), + ] + for r, e in zip(result, expected): + assert np.array_equal(r, e) + assert r.shape == (2, 2) + + +def test_join_pts_single_array_input(): + pts = np.array([[[0, 1], [2, 3]]]) + result = join_pts(pts) + + expected = [np.array([[0, 1], [2, 3]])] + for r, e in zip(result, expected): + assert np.array_equal(r, e) + assert r.shape == (2, 2) + + +def test_join_pts_none_input(): + pts1 = np.array([[[0, 1], [2, 3]]]) + pts2 = None + result = join_pts(pts1, pts2) + + expected = [np.array([[0, 1], [2, 3]])] + for r, e in zip(result, expected): + assert np.array_equal(r, e) + assert r.shape == (2, 2) + + +def test_join_pts_invalid_shape(): + # Test for array with last dimension not equal to 2 + with pytest.raises(ValueError): + pts = np.array([[[0, 1, 2]]]) + join_pts(pts) + + # Test for array with more than 3 dimensions + with pytest.raises(ValueError): + pts = np.array([[[[0, 1]]]]) + join_pts(pts) + + # Test for array with fewer than 2 dimensions + with pytest.raises(ValueError): + pts = np.array([0, 1]) + join_pts(pts) + + +def test_join_pts_mixed_shapes(): + pts1 = np.array([[0, 1], [2, 3]]) + pts2 = np.array([[[4, 5], [6, 7]], [[8, 9], [10, 11]]]) + result = join_pts(pts1, pts2) + + expected = [ + np.array([[0, 1], [2, 3]]), + np.array([[4, 5], [6, 7]]), + np.array([[8, 9], [10, 11]]), + ] + for r, e in zip(result, expected): + assert np.array_equal(r, e) + assert r.shape == (2, 2) + + # test get_all_pts_array function def test_get_all_pts_array(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = plant[0] - primary_pts = primary.numpy() - # get primary length + # Set frame index to 0 + frame_idx = 0 + # Load the series + plant = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary points + primary_pts = plant.get_primary_points(frame_idx) + 
# Get primary length primary_max_length_pts = get_max_length_pts(primary_pts) - # get lateral_lengths - lateral_pts = lateral.numpy() - monocots = False - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=monocots - ) + # Get lateral points + lateral_pts = plant.get_lateral_points(frame_idx) + pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts) assert pts_all_array.shape[0] == 21 # test get_all_pts_array function def test_get_all_pts_array_rice(rice_h5): - plant = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" + # Set frame index to 0 + frame_idx = 0 + # Load the series + plant = Series.load(rice_h5, primary_name="primary", lateral_name="crown") + # Get the lateral points + lateral_pts = plant.get_lateral_points(frame_idx) + # Get the flattened array with all of the points + pts_all_array = get_all_pts_array(lateral_pts) + assert pts_all_array.shape[0] == 12 + + +def test_single_instance(): + # Single instance with two nodes + pts = np.array([[1, 2], [3, 4]]) + node_index = 1 + expected_output = np.array([3, 4]) + assert np.array_equal(get_nodes(pts, node_index), expected_output) + + +def test_multiple_instances(): + # Multiple instances, each with two nodes + pts = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + node_index = 0 + expected_output = np.array([[1, 2], [5, 6]]) + assert np.array_equal(get_nodes(pts, node_index), expected_output) + + +def test_node_index_out_of_bounds(): + # Test with node_index out of bounds for the given points + pts = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + node_index = 2 # Out of bounds + with pytest.raises(ValueError): + get_nodes(pts, node_index) + + +def test_invalid_shape(): + # Test with invalid shape (not 2 or 3 dimensions) + pts = np.array([1, 2, 3]) # Invalid shape + node_index = 0 + with pytest.raises(ValueError): + get_nodes(pts, node_index) + + +def test_return_shape_single_instance(): + # Single instance input should 
return shape (2,) + pts = np.array([[1, 2], [3, 4]]) + node_index = 0 + output = get_nodes(pts, node_index) + assert output.shape == (2,) + + +def test_return_shape_multiple_instances(): + # Multiple instances input should return shape (instances, 2) + pts = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + node_index = 0 + output = get_nodes(pts, node_index) + assert output.shape == (2, 2) + + +def test_valid_input_vectors(): + """Test the get_left_right_normalized_vectors function with valid input arrays + where normalization is straightforward. + """ + r0_pts = np.array([[0, 1], [2, 3], [4, 5]]) + r1_pts = np.array([[1, 2], [3, 4], [5, 6]]) + expected_left_vector = np.array([1 / np.sqrt(2), 1 / np.sqrt(2)]) + expected_right_vector = np.array([1 / np.sqrt(2), 1 / np.sqrt(2)]) + + norm_vector_left, norm_vector_right = get_left_right_normalized_vectors( + r0_pts, r1_pts ) - primary, lateral = plant[0] - primary_pts = primary.numpy() - # get primary length - primary_max_length_pts = get_max_length_pts(primary_pts) - # get lateral_lengths - lateral_pts = lateral.numpy() - monocots = True - pts_all_array = get_all_pts_array( - primary_max_length_pts, lateral_pts, monocots=monocots + + assert np.allclose( + norm_vector_left, expected_left_vector + ), "Left vector normalization failed" + assert np.allclose( + norm_vector_right, expected_right_vector + ), "Right vector normalization failed" + + +def test_zero_length_vector_vectors(): + """Test the get_left_right_normalized_vectors function with inputs that result in + a zero-length vector, expecting vectors filled with NaNs. 
+ """ + r0_pts = np.array([[0, 0], [0, 0]]) + r1_pts = np.array([[0, 0], [0, 0]]) + + norm_vector_left, norm_vector_right = get_left_right_normalized_vectors( + r0_pts, r1_pts ) - assert pts_all_array.shape[0] == 12 + + assert np.isnan( + norm_vector_left + ).all(), "Left vector should be NaN for zero-length vector" + assert np.isnan( + norm_vector_right + ).all(), "Right vector should be NaN for zero-length vector" + + +def test_invalid_input_shapes_vectors(): + """Test the get_left_right_normalized_vectors function with inputs that have + mismatched shapes, expecting vectors filled with NaNs. + """ + r0_pts = np.array([[0, 1]]) + r1_pts = np.array([[1, 2], [3, 4]]) + + norm_vector_left, norm_vector_right = get_left_right_normalized_vectors( + r0_pts, r1_pts + ) + + assert np.isnan( + norm_vector_left + ).all(), "Left vector should be NaN for invalid input shapes" + assert np.isnan( + norm_vector_right + ).all(), "Right vector should be NaN for invalid input shapes" + + +def test_single_instance_input_vectors(): + """Test the get_left_right_normalized_vectors function with a single instance, + which should return vectors filled with NaNs since the function requires + more than one instance for comparison. + """ + r0_pts = np.array([[0, 1]]) + r1_pts = np.array([[1, 2]]) + + norm_vector_left, norm_vector_right = get_left_right_normalized_vectors( + r0_pts, r1_pts + ) + + assert np.isnan( + norm_vector_left + ).all(), "Left vector should be NaN for single instance input" + assert np.isnan( + norm_vector_right + ).all(), "Right vector should be NaN for single instance input" + + +def test_get_left_normalized_vector_with_valid_input(): + """ + Test get_left_normalized_vector with a valid pair of normalized vectors. 
+ """ + left_vector = np.array([1 / np.sqrt(2), 1 / np.sqrt(2)]) + right_vector = np.array([-1 / np.sqrt(2), 1 / np.sqrt(2)]) + normalized_vectors = (left_vector, right_vector) + + result = get_left_normalized_vector(normalized_vectors) + assert np.allclose( + result, left_vector + ), "The left vector was not returned correctly." + + +def test_get_right_normalized_vector_with_valid_input(): + """ + Test get_right_normalized_vector with a valid pair of normalized vectors. + """ + left_vector = np.array([1 / np.sqrt(2), 1 / np.sqrt(2)]) + right_vector = np.array([-1 / np.sqrt(2), 1 / np.sqrt(2)]) + normalized_vectors = (left_vector, right_vector) + + result = get_right_normalized_vector(normalized_vectors) + assert np.allclose( + result, right_vector + ), "The right vector was not returned correctly." + + +def test_get_left_normalized_vector_with_nan(): + """ + Test get_left_normalized_vector when the left vector is filled with NaNs. + """ + left_vector = np.array([np.nan, np.nan]) + right_vector = np.array([1, 0]) + normalized_vectors = (left_vector, right_vector) + + result = get_left_normalized_vector(normalized_vectors) + assert np.isnan(result).all(), "Expected a vector of NaNs for the left side." + + +def test_get_right_normalized_vector_with_nan(): + """ + Test get_right_normalized_vector when the right vector is filled with NaNs. + """ + left_vector = np.array([0, 1]) + right_vector = np.array([np.nan, np.nan]) + normalized_vectors = (left_vector, right_vector) + + result = get_right_normalized_vector(normalized_vectors) + assert np.isnan(result).all(), "Expected a vector of NaNs for the right side." + + +def test_normalized_vectors_with_empty_arrays(): + """ + Test get_left_normalized_vector and get_right_normalized_vector with empty arrays. 
+ """ + left_vector = np.array([]) + right_vector = np.array([]) + normalized_vectors = (left_vector, right_vector) + + left_result = get_left_normalized_vector(normalized_vectors) + right_result = get_right_normalized_vector(normalized_vectors) + + assert ( + left_result.size == 0 and right_result.size == 0 + ), "Expected empty arrays for both left and right vectors." + + +@pytest.mark.parametrize( + "pts1, pts2, expected", + [ + (np.array([0, 0]), np.array([1, 1]), (1, 0)), # Diagonal line, positive slope + (np.array([1, 1]), np.array([2, 2]), (1, 0)), # Diagonal line, positive slope + (np.array([0, 1]), np.array([1, 0]), (-1, 1)), # Diagonal line, negative slope + (np.array([1, 2]), np.array([3, 2]), (0, 2)), # Horizontal line + ( + np.array([2, 3]), + np.array([2, 5]), + (np.nan, np.nan), + ), # Vertical line should return NaNs + ( + np.array([0, 0]), + np.array([0, 0]), + (np.nan, np.nan), + ), # Identical points should return NaNs + ], +) +def test_get_line_equation_from_points(pts1, pts2, expected): + m, b = get_line_equation_from_points(pts1, pts2) + assert np.isclose(m, expected[0], equal_nan=True) and np.isclose( + b, expected[1], equal_nan=True + ), f"Expected slope {expected[0]} and intercept {expected[1]} but got slope {m} and intercept {b}" + + +@pytest.mark.parametrize( + "pts1, pts2", + [ + (np.array([1]), np.array([1, 2])), # Incorrect shape + (5, np.array([1, 2])), # Not an array + ("test", "test"), # Incorrect type + ], +) +def test_get_line_equation_input_errors(pts1, pts2): + with pytest.raises(ValueError): + get_line_equation_from_points(pts1, pts2) diff --git a/tests/test_scanline.py b/tests/test_scanline.py index a621b87..32526a1 100644 --- a/tests/test_scanline.py +++ b/tests/test_scanline.py @@ -1,6 +1,8 @@ import pytest import numpy as np +from typing import List from sleap_roots import Series +from sleap_roots.points import join_pts from sleap_roots.lengths import get_max_length_pts from sleap_roots.scanline import ( 
get_scanline_first_ind, @@ -55,58 +57,83 @@ def pts_nan3(): def test_count_scanline_intersections_canola(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() + # Set the frame number to 0 + frame = 0 + # Load the series from canola + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the primary and lateral roots + primary_pts = series.get_primary_points(frame) primary_pts = get_max_length_pts(primary_pts) + lateral_pts = series.get_lateral_points(frame) + pts_all_list = join_pts(primary_pts, lateral_pts) depth = 1080 - width = 2048 n_line = 50 - monocots = False - n_inter = count_scanline_intersections( - primary_pts, lateral_pts, depth, width, n_line, monocots - ) + n_inter = count_scanline_intersections(pts_all_list, depth, n_line) assert n_inter.shape == (50,) np.testing.assert_equal(n_inter[14], 1) def test_count_scanline_intersections_rice(rice_h5): - series = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) - primary, lateral = series[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() - primary_pts = get_max_length_pts(primary_pts) + # Set the frame number to 0 + frame = 0 + # Load the series from rice + series = Series.load(rice_h5, primary_name="primary", crown_name="crown") + crown_pts = series.get_crown_points(frame) depth = 1080 - width = 2048 n_line = 50 - monocots = True - n_inter = count_scanline_intersections( - primary_pts, lateral_pts, depth, width, n_line, monocots - ) + n_inter = count_scanline_intersections(crown_pts, depth, n_line) assert n_inter.shape == (50,) np.testing.assert_equal(n_inter[14], 2) +def test_count_scanline_intersections_basic(): + pts_list = [np.array([[0, 0], [4, 0]]), np.array([[0, 1], [4, 1]])] + height = 2 + n_line = 3 # y-values: 0, 1, 2 + result = 
count_scanline_intersections(pts_list, height, n_line) + assert np.all(result == np.array([1, 1, 0])) # Intersections at y = 0 and y = 1 + + +def test_count_scanline_intersections_invalid_shape(): + with pytest.raises(ValueError): + pts_list = [np.array([0, 1])] + count_scanline_intersections(pts_list) + + +def test_count_scanline_intersections_with_nan(): + pts_list = [np.array([[0, 0], [4, 0]]), np.array([[0, 1], [4, np.nan]])] + height = 2 + n_line = 3 # y-values: 0, 1, 2 + result = count_scanline_intersections(pts_list, height, n_line) + assert np.all(result == np.array([1, 0, 0])) # Only one valid intersection at y = 0 + + +def test_count_scanline_intersections_different_params(): + pts_list = [np.array([[0, 0], [4, 0]]), np.array([[0, 2], [4, 2]])] + height = 4 + n_line = 5 # y-values: 0, 1, 2, 3, 4 + result = count_scanline_intersections(pts_list, height, n_line) + assert np.all( + result == np.array([1, 0, 1, 0, 0]) + ) # Intersections at y = 0 and y = 2 + + # test get_scanline_first_ind with canola def test_get_scanline_first_ind(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = plant[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() + # Set the frame number to 0 + frame = 0 + # Load the series from canola + plant = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = plant.get_primary_points(frame) primary_pts = get_max_length_pts(primary_pts) + lateral_pts = plant.get_lateral_points(frame) depth = 1080 - width = 2048 n_line = 50 - monocots = False + pts_all_list = join_pts(primary_pts, lateral_pts) scanline_intersection_counts = count_scanline_intersections( - primary_pts, lateral_pts, depth, width, n_line, monocots + pts_all_list, + depth, + n_line, ) scanline_first_ind = get_scanline_first_ind(scanline_intersection_counts) np.testing.assert_equal(scanline_first_ind, 7) @@ -114,19 +141,18 @@ def 
test_get_scanline_first_ind(canola_h5): # test get_scanline_last_ind with canola def test_get_scanline_last_ind(canola_h5): - plant = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - primary, lateral = plant[0] - primary_pts = primary.numpy() - lateral_pts = lateral.numpy() + # Set the frame number to 0 + frame = 0 + # Load the series from canola + plant = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + primary_pts = plant.get_primary_points(frame) primary_pts = get_max_length_pts(primary_pts) + lateral_pts = plant.get_lateral_points(frame) depth = 1080 - width = 2048 n_line = 50 - monocots = True + pts_all_list = join_pts(primary_pts, lateral_pts) scanline_intersection_counts = count_scanline_intersections( - primary_pts, lateral_pts, depth, width, n_line, monocots + pts_all_list, depth, n_line ) scanline_last_ind = get_scanline_last_ind(scanline_intersection_counts) - np.testing.assert_equal(scanline_last_ind, 12) + np.testing.assert_equal(scanline_last_ind, 46) diff --git a/tests/test_series.py b/tests/test_series.py index 6c87113..861f198 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1,13 +1,121 @@ +import sleap_io as sio +import numpy as np +import pytest from sleap_roots.series import Series, find_all_series +from pathlib import Path +from typing import Literal -def test_series_load(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) +@pytest.fixture +def dummy_video_path(tmp_path): + video_path = tmp_path / "dummy_video.mp4" + video_path.write_text("This is a dummy video file.") + return str(video_path) + + +@pytest.fixture(params=["primary", "lateral", "crown"]) +def label_type(request): + """Yields label types for tests, one by one.""" + return request.param + + +@pytest.fixture +def dummy_labels_path(tmp_path, label_type): + labels_path = tmp_path / f"dummy.{label_type}.predictions.slp" + # Simulate 
the structure of a SLEAP labels file. + labels_path.write_text("Dummy SLEAP labels content.") + return str(labels_path) + + +@pytest.fixture +def dummy_series(dummy_video_path, dummy_labels_path): + # Assuming dummy_labels_path names are formatted as "{label_type}.predictions.slp" + # Extract the label type (primary, lateral, crown) from the filename + label_type = Path(dummy_labels_path).stem.split(".")[1] + + # Construct the keyword argument for Series.load() + kwargs = { + "h5_path": dummy_video_path, + f"{label_type}_name": dummy_labels_path, + } + return Series.load(**kwargs) + + +def test_series_name(dummy_series): + expected_name = "dummy_video" # Based on the dummy_video_path fixture + assert dummy_series.series_name == expected_name + + +def test_get_frame(dummy_series): + frame_idx = 0 + frames = dummy_series.get_frame(frame_idx) + assert isinstance(frames, dict) + assert "primary" in frames + assert "lateral" in frames + assert "crown" in frames + + +def test_series_name_property(): + series = Series(h5_path="mock_path/file_name.h5") + assert series.series_name == "file_name" + + +def test_len(): + series = Series(video=["frame1", "frame2"]) + assert len(series) == 2 + + +def test_series_load_canola(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]): + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") assert len(series) == 72 -def test_find_all_series(canola_folder): +def test_find_all_series_canola(canola_folder: Literal["tests/data/canola_7do"]): all_series_files = find_all_series(canola_folder) assert len(all_series_files) == 1 + + +def test_load_rice_10do( + rice_main_10do_h5: Literal["tests/data/rice_10do/0K9E8BI.h5"], +): + series = Series.load(rice_main_10do_h5, crown_name="crown") + expected_video = sio.Video.from_filename(rice_main_10do_h5) + + assert len(series) == 72 + assert series.h5_path == rice_main_10do_h5 + assert series.video.filename == expected_video.filename + + +def test_get_frame_rice_10do( + 
rice_main_10do_h5: Literal["tests/data/rice_10do/0K9E8BI.h5"], + rice_main_10do_slp: Literal["tests/data/rice_10do/0K9E8BI.crown.predictions.slp"], +): + # Set the frame index to 0 + frame_idx = 0 + + # Load the expected Labels object for comparison + expected_labels = sio.load_slp(rice_main_10do_slp) + # Get the first labeled frame + expected_labeled_frame = expected_labels[0] + + # Load the series + series = Series.load(rice_main_10do_h5, crown_name="crown") + # Retrieve all available frames + frames = series.get_frame(frame_idx) + # Get the crown labeled frame + crown_lf = frames.get("crown") + assert crown_lf == expected_labeled_frame + # Check the series name property + assert series.series_name == "0K9E8BI" + + +def test_find_all_series_rice_10do(rice_10do_folder: Literal["tests/data/rice_10do"]): + series_h5_path = Path(rice_10do_folder) / "0K9E8BI.h5" + all_series_files = find_all_series(rice_10do_folder) + assert len(all_series_files) == 1 + assert series_h5_path.as_posix() == "tests/data/rice_10do/0K9E8BI.h5" + + +def test_find_all_series_rice(rice_folder: Literal["tests/data/rice_3do"]): + all_series_files = find_all_series(rice_folder) + assert len(all_series_files) == 2 diff --git a/tests/test_tips.py b/tests/test_tips.py index 81c1816..65cd000 100644 --- a/tests/test_tips.py +++ b/tests/test_tips.py @@ -82,12 +82,15 @@ def test_tips_one_tip(pts_one_tip): # test get_tip_xs with canola def test_get_tip_xs_canola(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - lateral = series[0][1] # LabeledFrame - lateral_pts = lateral.numpy() # Lateral roots as a numpy array + # Set the frame index to 0 + frame_index = 0 + # Load the series with a canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the lateral roots from the series + lateral_pts = series.get_lateral_points(frame_index) + # Get the tips from the lateral roots tips = 
get_tips(lateral_pts) + # Get the tip x-coordinates tip_xs = get_tip_xs(tips) assert tip_xs.shape[0] == 5 np.testing.assert_almost_equal(tip_xs[1], 1072.6610107421875, decimal=3) @@ -95,6 +98,7 @@ def test_get_tip_xs_canola(canola_h5): # test get_tip_xs with standard points def test_get_tip_xs_standard(pts_standard): + # Get the tips from the standard points tips = get_tips(pts_standard) tip_xs = get_tip_xs(tips) assert tip_xs.shape[0] == 2 @@ -104,24 +108,28 @@ def test_get_tip_xs_standard(pts_standard): # test get_tip_xs with no tips def test_get_tip_xs_no_tip(pts_no_tips): + # Get the tips from the no tips points tips = get_tips(pts_no_tips) tip_xs = get_tip_xs(tips) assert tip_xs.shape[0] == 2 np.testing.assert_almost_equal(tip_xs[1], np.nan, decimal=3) assert type(tip_xs) == np.ndarray - tip_xs = get_tip_xs(tips[[0]], flatten=True) - assert type(tip_xs) == np.float64 + tip_xs = get_tip_xs(tips[[0]]) + assert type(tip_xs) == np.ndarray # test get_tip_ys with canola def test_get_tip_ys_canola(canola_h5): - series = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - lateral = series[0][1] # LabeledFrame - lateral_pts = lateral.numpy() # Lateral roots as a numpy array + # Set the frame index to 0 + frame_index = 0 + # Load the series with a canola dataset + series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + # Get the lateral root points from the series + lateral_pts = series.get_lateral_points(frame_index) + # Get the tips from the lateral roots tips = get_tips(lateral_pts) + # Get the tip y-coordinates tip_ys = get_tip_ys(tips) assert tip_ys.shape[0] == 5 np.testing.assert_almost_equal(tip_ys[1], 276.51275634765625, decimal=3) @@ -144,5 +152,5 @@ def test_get_tip_ys_no_tip(pts_no_tips): np.testing.assert_almost_equal(tip_ys[1], np.nan, decimal=3) assert type(tip_ys) == np.ndarray - tip_ys = get_tip_ys(tips[[0]], flatten=True) - assert type(tip_ys) == np.float64 + tip_ys = 
get_tip_ys(tips[[0]]) + assert type(tip_ys) == np.ndarray diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py index feec406..0b7435c 100644 --- a/tests/test_trait_pipelines.py +++ b/tests/test_trait_pipelines.py @@ -1,14 +1,15 @@ -from sleap_roots.trait_pipelines import DicotPipeline, YoungerMonocotPipeline +from sleap_roots.trait_pipelines import ( + DicotPipeline, + YoungerMonocotPipeline, + OlderMonocotPipeline, +) from sleap_roots.series import Series, find_all_series def test_dicot_pipeline(canola_h5, soy_h5): - canola = Series.load( - canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes" - ) - soy = Series.load( - soy_h5, primary_name="primary_multi_day", lateral_name="lateral__nodes" - ) + # Load the data + canola = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") + soy = Series.load(soy_h5, primary_name="primary", lateral_name="lateral") pipeline = DicotPipeline() canola_traits = pipeline.compute_plant_traits(canola) @@ -20,15 +21,20 @@ def test_dicot_pipeline(canola_h5, soy_h5): assert all_traits.shape == (2, 1036) +def test_OlderMonocot_pipeline(rice_main_10do_h5): + rice = Series.load(rice_main_10do_h5, crown_name="crown") + + pipeline = OlderMonocotPipeline() + rice_10dag_traits = pipeline.compute_plant_traits(rice) + + assert rice_10dag_traits.shape == (72, 102) + + def test_younger_monocot_pipeline(rice_h5, rice_folder): - rice = Series.load( - rice_h5, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) + rice = Series.load(rice_h5, primary_name="primary", crown_name="crown") rice_series_all = find_all_series(rice_folder) series_all = [ - Series.load( - series, primary_name="longest_3do_6nodes", lateral_name="main_3do_6nodes" - ) + Series.load(series, primary_name="primary", crown_name="crown") for series in rice_series_all ] @@ -43,12 +49,12 @@ def test_younger_monocot_pipeline(rice_h5, rice_folder): # Dataframe dtype assertions expected_rice_traits_dtypes = { 
"frame_idx": "int64", - "main_count": "int64", + "crown_count": "int64", } expected_all_traits_dtypes = { - "main_count_min": "int64", - "main_count_max": "int64", + "crown_count_min": "int64", + "crown_count_max": "int64", } for col, expected_dtype in expected_rice_traits_dtypes.items(): @@ -69,13 +75,61 @@ def test_younger_monocot_pipeline(rice_h5, rice_folder): all_traits["curve_index_median"] >= 0 ).all(), "curve_index in all_traits contains negative values" assert ( - all_traits["main_curve_indices_mean_median"] >= 0 - ).all(), "main_curve_indices_mean_median in all_traits contains negative values" + all_traits["crown_curve_indices_mean_median"] >= 0 + ).all(), "crown_curve_indices_mean_median in all_traits contains negative values" + assert ( + (0 <= rice_traits["crown_angles_proximal_p95"]) + & (rice_traits["crown_angles_proximal_p95"] <= 180) + ).all(), "angle_column in rice_traits contains values out of range [0, 180]" + assert ( + (0 <= all_traits["crown_angles_proximal_median_p95"]) + & (all_traits["crown_angles_proximal_median_p95"] <= 180) + ).all(), "angle_column in all_traits contains values out of range [0, 180]" + + +def test_older_monocot_pipeline(rice_main_10do_h5, rice_10do_folder): + rice = Series.load(rice_main_10do_h5, crown_name="crown") + rice_series_all = find_all_series(rice_10do_folder) + series_all = [Series.load(series, crown_name="crown") for series in rice_series_all] + + pipeline = OlderMonocotPipeline() + rice_traits = pipeline.compute_plant_traits(rice) + all_traits = pipeline.compute_batch_traits(series_all) + + # Dataframe shape assertions + assert rice_traits.shape == (72, 102) + assert all_traits.shape == (1, 901) + + # Dataframe dtype assertions + expected_rice_traits_dtypes = { + "frame_idx": "int64", + "crown_count": "int64", + } + + expected_all_traits_dtypes = { + "crown_count_min": "int64", + "crown_count_max": "int64", + } + + for col, expected_dtype in expected_rice_traits_dtypes.items(): + assert ( + 
rice_traits[col].dtype == expected_dtype + ), f"Unexpected dtype for column {col} in rice_traits" + + for col, expected_dtype in expected_all_traits_dtypes.items(): + assert ( + all_traits[col].dtype == expected_dtype + ), f"Unexpected dtype for column {col} in all_traits" + + # Value range assertions for traits + assert ( + all_traits["crown_curve_indices_mean_median"] >= 0 + ).all(), "crown_curve_indices_mean_median in all_traits contains negative values" assert ( - (0 <= rice_traits["main_angles_proximal_p95"]) - & (rice_traits["main_angles_proximal_p95"] <= 180) + (0 <= rice_traits["crown_angles_proximal_p95"]) + & (rice_traits["crown_angles_proximal_p95"] <= 180) ).all(), "angle_column in rice_traits contains values out of range [0, 180]" assert ( - (0 <= all_traits["main_angles_proximal_median_p95"]) - & (all_traits["main_angles_proximal_median_p95"] <= 180) + (0 <= all_traits["crown_angles_proximal_median_p95"]) + & (all_traits["crown_angles_proximal_median_p95"] <= 180) ).all(), "angle_column in all_traits contains values out of range [0, 180]"