diff --git a/tests/test_convhull.py b/tests/test_convhull.py index 94e6a35..f7b44c0 100644 --- a/tests/test_convhull.py +++ b/tests/test_convhull.py @@ -14,10 +14,15 @@ get_chull_division_areas, get_chull_areas_via_intersection, get_chull_intersection_vectors, + get_chull_intersection_vectors_left, + get_chull_intersection_vectors_right, + get_chull_area_via_intersection_above, + get_chull_area_via_intersection_below, ) from sleap_roots.lengths import get_max_length_pts from sleap_roots.points import get_all_pts_array, get_nodes from sleap_roots.bases import get_bases +from sleap_roots.angle import get_vector_angles_from_gravity @pytest.fixture @@ -210,6 +215,11 @@ def test_get_convhull_features_rice(rice_h5, rice_long_slp, rice_main_slp): ) # Get the crown root from the series crown_pts = series.get_crown_points(frame_index) + + # Get the r0 and r1 nodes from the crown root + r0_pts = get_bases(crown_pts) + r1_pts = get_nodes(crown_pts, 1) + # Get the convex hull from the points convex_hull = get_convhull(crown_pts) perimeter = get_chull_perimeter(convex_hull) @@ -217,10 +227,28 @@ def test_get_convhull_features_rice(rice_h5, rice_long_slp, rice_main_slp): max_width = get_chull_max_width(convex_hull) max_height = get_chull_max_height(convex_hull) + # Get the intersection vectors + left_vector = get_chull_intersection_vectors_left(get_chull_intersection_vectors(r0_pts, r1_pts, crown_pts, convex_hull)) + right_vector = get_chull_intersection_vectors_right(get_chull_intersection_vectors(r0_pts, r1_pts, crown_pts, convex_hull)) + # Get angles from gravity + left_angle = get_vector_angles_from_gravity(left_vector) + right_angle = get_vector_angles_from_gravity(right_vector) + + # Get the intersection areas + area_above = get_chull_area_via_intersection_above(get_chull_areas_via_intersection(r1_pts, crown_pts, convex_hull)) + area_below = get_chull_area_via_intersection_below(get_chull_areas_via_intersection(r1_pts, crown_pts, convex_hull)) + + assert left_vector.shape == (1, 2) + assert right_vector.shape == (1, 2) + np.testing.assert_almost_equal(perimeter, 1450.6365795858003, decimal=3) np.testing.assert_almost_equal(area, 23722.883102604676, decimal=3) np.testing.assert_almost_equal(max_width, 64.341064453125, decimal=3) np.testing.assert_almost_equal(max_height, 715.6949920654297, decimal=3) + np.testing.assert_almost_equal(left_angle, 166.08852115310046, decimal=3) + np.testing.assert_almost_equal(right_angle, 0.04343543279020469, decimal=3) + np.testing.assert_almost_equal(area_above, 1903.040353098403, decimal=3) + np.testing.assert_almost_equal(area_below, 21819.842749506273, decimal=3) # test plant with 2 roots/instances with nan nodes diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py index 2ac7288..8f6464c 100644 --- a/tests/test_trait_pipelines.py +++ b/tests/test_trait_pipelines.py @@ -237,6 +237,22 @@ def test_older_monocot_pipeline(rice_10do_pipeline_output_folder): (0 <= all_traits["crown_angles_proximal_median_p95"].dropna()) & (all_traits["crown_angles_proximal_median_p95"].dropna() <= 180) ).all(), "angle_column in all_traits contains values out of range [0, 180]" + assert ( + (0 <= rice_traits["angle_chull_r1_left_intersection_vector"].dropna()) + & (rice_traits["angle_chull_r1_left_intersection_vector"].dropna() <= 180) + ).all(), "angle column in rice_traits contains values out of range [0, 180]" + assert ( + (0 <= all_traits["angle_chull_r1_left_intersection_vector_p95"].dropna()) + & (all_traits["angle_chull_r1_left_intersection_vector_p95"].dropna() <= 180) + ).all(), "angle column in all_traits contains values out of range [0, 180]" + assert ( + (0 <= rice_traits["angle_chull_r1_right_intersection_vector"].dropna()) + & (rice_traits["angle_chull_r1_right_intersection_vector"].dropna() <= 180) + ).all(), "angle column in rice_traits contains values out of range [0, 180]" + assert ( + (0 <= all_traits["angle_chull_r1_right_intersection_vector_p95"].dropna()) + & (all_traits["angle_chull_r1_right_intersection_vector_p95"].dropna() <= 180) + ).all(), "angle column in all_traits contains values out of range [0, 180]" def test_multiple_dicot_pipeline(