Skip to content

Commit

Permalink
Fix tips and bases and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Mar 2, 2024
1 parent 10eb050 commit 264dcca
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 235 deletions.
36 changes: 12 additions & 24 deletions sleap_roots/tips.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,34 +50,22 @@ def get_tip_xs(tip_pts: np.ndarray) -> np.ndarray:
return tip_xs


def get_tip_ys(tip_pts: np.ndarray, flatten: bool = False) -> np.ndarray:
def get_tip_ys(tip_pts: np.ndarray) -> np.ndarray:
"""Get y coordinates of tip points.
Args:
tip_pts: Root tips as array of shape `(instances, 2)` or `(2)` when there is
only one tip.
flatten: If `True`, return scalar (0-D) array if scalar (there is only 1 root).
Defaults to `False`.
tip_pts: Root tip points as array of shape `(instances, 2)` or `(2,)` when there
is only one tip.
Return:
An array of the y-coordinates of tips (instances,) or () if `flatten` is `True`.
An array of tip y-coordinates (instances,) or (1,) when there is only one root.
"""
# If the input is a single number (float or integer), raise an error
if isinstance(tip_pts, (np.floating, float, np.integer, int)):
raise ValueError("Input must be an array of shape `(instances, 2)` or `(2, )`.")

# Check for the 2D shape of the input array
if tip_pts.ndim == 1:
# If shape is `(2,)`, then reshape it to `(1, 2)` for consistency
tip_pts = tip_pts.reshape(1, 2)
elif tip_pts.ndim != 2:
raise ValueError("Input array must be of shape `(instances, 2)` or `(2, )`.")

# At this point, `tip_pts` should be of shape `(instances, 2)`.
tip_ys = tip_pts[:, 1]

if flatten:
tip_ys = tip_ys.squeeze()
if tip_ys.size == 1:
tip_ys = tip_ys[()]
if tip_pts.ndim not in (1, 2):
raise ValueError(
"Input array must be 2-dimensional (instances, 2) or 1-dimensional (2,)."
)
if tip_pts.shape[-1] != 2:
raise ValueError("Last dimension must be (x, y).")

tip_ys = tip_pts[..., 1]
return tip_ys
201 changes: 0 additions & 201 deletions tests/test_networklength.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,207 +211,6 @@ def test_get_network_distribution_one_point():
[[[4, 4], [5, 5]], [[6, 6], [np.nan, np.nan]]]
) # One of the roots has only one point
bounding_box = (0, 0, 10, 10)
<<<<<<< HEAD
=======
fraction = 2 / 3
monocots = False

# Call the function
network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction, monocots
)

# Define the expected result
# Only the valid roots should be considered in the calculation
lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[0]).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

# Assert that the result is as expected
assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_empty_arrays():
primary_pts = np.full((2, 2), np.nan)
lateral_pts = np.full((2, 2, 2), np.nan)
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)
assert network_length == 0


def test_get_network_distribution_with_nans():
primary_pts = np.array([[1, 1], [2, 2], [np.nan, np.nan]])
lateral_pts = np.array([[[4, 4], [5, 5], [np.nan, np.nan]]])
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[:-1]).intersection(lower_box).length
+ LineString(lateral_pts[0, :-1]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_monocots():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
monocots = True

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, monocots=monocots
)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(lateral_pts[0]).intersection(lower_box).length
) # Only lateral_pts are considered

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_different_fraction():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
fraction = 0.5

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction=fraction
)

lower_box = Polygon(
[(0, 10 - 10 * fraction), (0, 10), (10, 10), (10, 10 - 10 * fraction)]
)
expected_length = (
LineString(primary_pts).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_one_point():
# Define inputs
primary_pts = np.array([[[1, 1], [2, 2], [3, 3]]])
lateral_pts = np.array(
[[[4, 4], [5, 5]], [[6, 6], [np.nan, np.nan]]]
) # One of the roots has only one point
bounding_box = (0, 0, 10, 10)
fraction = 2 / 3
monocots = False

# Call the function
network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction, monocots
)

# Define the expected result
# Only the valid roots should be considered in the calculation
lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[0]).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

# Assert that the result is as expected
assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_empty_arrays():
primary_pts = np.full((2, 2), np.nan)
lateral_pts = np.full((2, 2, 2), np.nan)
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)
assert network_length == 0


def test_get_network_distribution_with_nans():
primary_pts = np.array([[1, 1], [2, 2], [np.nan, np.nan]])
lateral_pts = np.array([[[4, 4], [5, 5], [np.nan, np.nan]]])
bounding_box = (0, 0, 10, 10)

network_length = get_network_distribution(primary_pts, lateral_pts, bounding_box)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(primary_pts[:-1]).intersection(lower_box).length
+ LineString(lateral_pts[0, :-1]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_monocots():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
monocots = True

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, monocots=monocots
)

lower_box = Polygon(
[(0, 10 - 10 * (2 / 3)), (0, 10), (10, 10), (10, 10 - 10 * (2 / 3))]
)
expected_length = (
LineString(lateral_pts[0]).intersection(lower_box).length
) # Only lateral_pts are considered

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution_different_fraction():
primary_pts = np.array([[1, 1], [2, 2], [3, 3]])
lateral_pts = np.array([[[4, 4], [5, 5]]])
bounding_box = (0, 0, 10, 10)
fraction = 0.5

network_length = get_network_distribution(
primary_pts, lateral_pts, bounding_box, fraction=fraction
)

lower_box = Polygon(
[(0, 10 - 10 * fraction), (0, 10), (10, 10), (10, 10 - 10 * fraction)]
)
expected_length = (
LineString(primary_pts).intersection(lower_box).length
+ LineString(lateral_pts[0]).intersection(lower_box).length
)

assert network_length == pytest.approx(expected_length)


def test_get_network_distribution(canola_h5):
series = Series.load(
canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
)
primary, lateral = series[0]
primary_pts = primary.numpy()
primary_max_length_pts = get_max_length_pts(primary_pts)
lateral_pts = lateral.numpy()
monocots = False
pts_all_array = get_all_pts_array(primary_max_length_pts, lateral_pts, monocots)
bbox = get_bbox(pts_all_array)
>>>>>>> main
fraction = 2 / 3
pts = join_pts(primary_pts, lateral_pts)
# Call the function
Expand Down
7 changes: 1 addition & 6 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test_len():


def test_series_load_canola(canola_h5: Literal["tests/data/canola_7do/919QDUH.h5"]):
series = Series.load(canola_h5, ["primary", "lateral"])
series = Series.load(canola_h5, primary_name="primary", lateral_name="lateral")
assert len(series) == 72


Expand All @@ -75,11 +75,6 @@ def test_find_all_series_canola(canola_folder: Literal["tests/data/canola_7do"])
assert len(all_series_files) == 1


def test_find_all_series_rice_10do(rice_10do_folder: Literal["tests/data/rice_10do"]):
all_series_files = find_all_series(rice_10do_folder)
assert len(all_series_files) == 1


def test_load_rice_10do(
rice_main_10do_h5: Literal["tests/data/rice_10do/0K9E8BI.h5"],
):
Expand Down
8 changes: 4 additions & 4 deletions tests/test_tips.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ def test_get_tip_xs_no_tip(pts_no_tips):
np.testing.assert_almost_equal(tip_xs[1], np.nan, decimal=3)
assert type(tip_xs) == np.ndarray

tip_xs = get_tip_xs(tips[[0]], flatten=True)
assert type(tip_xs) == np.float64
tip_xs = get_tip_xs(tips[[0]])
assert type(tip_xs) == np.ndarray


# test get_tip_ys with canola
Expand Down Expand Up @@ -152,5 +152,5 @@ def test_get_tip_ys_no_tip(pts_no_tips):
np.testing.assert_almost_equal(tip_ys[1], np.nan, decimal=3)
assert type(tip_ys) == np.ndarray

tip_ys = get_tip_ys(tips[[0]], flatten=True)
assert type(tip_ys) == np.float64
tip_ys = get_tip_ys(tips[[0]])
assert type(tip_ys) == np.ndarray

0 comments on commit 264dcca

Please sign in to comment.