-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement older monocot pipeline #54
Changes from 13 commits
3e841ec
92dae9e
9325a1b
1621efa
cb0f13b
e23af0f
4bc568f
ff1fb10
57b8fed
25391b4
dfabf07
890df7f
4843f08
0c42ecb
f33f0e0
797561b
aa59e2b
bd2621b
045a7b4
2364b89
a92be32
1fceb01
1ab5a39
11df884
606af51
3d1109f
57cfbd1
c063a9f
29cdb7a
113ee3e
534d9e6
b372567
93cfe87
7464680
dd06872
2279ac0
655d9cc
ef271da
a9054e3
d175680
5efb4fb
2416fd8
9838314
0c4b4df
1e692b5
10eb050
264dcca
e38b289
880817d
83b6ddd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
import numpy as np | ||
from shapely import LineString, Polygon | ||
from sleap_roots.lengths import get_root_lengths, get_max_length_pts | ||
from typing import Optional, Tuple, Union | ||
from typing import Optional, Tuple, Union, List | ||
|
||
|
||
def get_bbox(pts: np.ndarray) -> Tuple[float, float, float, float]: | ||
|
@@ -60,44 +60,42 @@ def get_network_width_depth_ratio( | |
|
||
|
||
def get_network_length( | ||
primary_length: float, | ||
lateral_lengths: Union[float, np.ndarray], | ||
monocots: bool = False, | ||
lengths0: Union[float, np.ndarray], | ||
*args: Optional[Union[float, np.ndarray]], | ||
) -> float: | ||
"""Return the total root network length given primary and lateral root lengths. | ||
|
||
Args: | ||
primary_length: Primary root length. | ||
lateral_lengths: Either a float representing the length of a single lateral | ||
root or an array of lateral root lengths with shape `(instances,)`. | ||
monocots: A boolean value, where True is rice. | ||
lengths0: Either a float representing the length of a single | ||
root or an array of root lengths with shape `(instances,)`. | ||
*args: Additional optional floats representing the lengths of single | ||
roots or arrays of root lengths with shape `(instances,)`. | ||
|
||
Returns: | ||
Total length of root network. | ||
""" | ||
# Ensure primary_length is a scalar | ||
if not isinstance(primary_length, (float, np.float64)): | ||
raise ValueError("Input primary_length must be a scalar value.") | ||
|
||
# Ensure lateral_lengths is either a scalar or has the correct shape | ||
if not ( | ||
isinstance(lateral_lengths, (float, np.float64)) or lateral_lengths.ndim == 1 | ||
): | ||
raise ValueError( | ||
"Input lateral_lengths must be a scalar or have shape (instances,)." | ||
) | ||
|
||
# Calculate the total lateral root length using np.nansum | ||
total_lateral_length = np.nansum(lateral_lengths) | ||
|
||
if monocots: | ||
length = total_lateral_length | ||
else: | ||
# Calculate the total root network length using np.nansum so the total length | ||
# will not be NaN if one of primary or lateral lengths are NaN | ||
length = np.nansum([primary_length, total_lateral_length]) | ||
|
||
return length | ||
# Initialize an empty list to store the lengths | ||
all_lengths = [] | ||
# Loop over the input arrays | ||
for length in [lengths0] + list(args): | ||
if length is None: | ||
continue # Skip None values | ||
# Ensure length is either a scalar or has the correct shape | ||
if not (np.isscalar(length) or (hasattr(length, "ndim") and length.ndim == 1)): | ||
raise ValueError( | ||
"Input length must be a scalar or have shape (instances,)." | ||
) | ||
# Add the length to the list | ||
if np.isscalar(length): | ||
all_lengths.append(length) | ||
else: | ||
all_lengths.extend(list(length)) | ||
|
||
# Calculate the total root network length using np.nansum so the total length | ||
# will not be NaN if one of primary or lateral lengths are NaN | ||
total_network_length = np.nansum(all_lengths) | ||
|
||
return total_network_length | ||
|
||
|
||
def get_network_solidity( | ||
|
@@ -121,60 +119,34 @@ def get_network_solidity( | |
|
||
|
||
def get_network_distribution( | ||
primary_pts: np.ndarray, | ||
lateral_pts: np.ndarray, | ||
pts_list: List[np.ndarray], | ||
bounding_box: Tuple[float, float, float, float], | ||
fraction: float = 2 / 3, | ||
monocots: bool = False, | ||
) -> float: | ||
"""Return the root length in the lower fraction of the plant. | ||
|
||
Args: | ||
primary_pts: Array of primary root landmarks. Can have shape `(nodes, 2)` or | ||
`(1, nodes, 2)`. | ||
lateral_pts: Array of lateral root landmarks with shape `(instances, nodes, 2)`. | ||
pts_list: A list of arrays, each having shape `(nodes, 2)`. | ||
bounding_box: Tuple in the form `(left_x, top_y, width, height)`. | ||
fraction: Lower fraction value. Defaults to 2/3. | ||
monocots: A boolean value, where True indicates rice. Defaults to False. | ||
|
||
Returns: | ||
Root network length in the lower fraction of the plant. | ||
""" | ||
# Input validation | ||
if primary_pts.ndim not in [2, 3]: | ||
# Input validation for pts_list | ||
if any(pts.ndim != 2 or pts.shape[-1] != 2 for pts in pts_list): | ||
raise ValueError( | ||
"primary_pts should have a shape of `(nodes, 2)` or `(1, nodes, 2)`." | ||
"Each pts array in pts_list should have a shape of `(nodes, 2)`." | ||
) | ||
|
||
if primary_pts.ndim == 2 and primary_pts.shape[-1] != 2: | ||
raise ValueError("primary_pts should have a shape of `(nodes, 2)`.") | ||
|
||
if primary_pts.ndim == 3 and primary_pts.shape[-1] != 2: | ||
raise ValueError("primary_pts should have a shape of `(1, nodes, 2)`.") | ||
|
||
if lateral_pts.ndim != 3 or lateral_pts.shape[-1] != 2: | ||
raise ValueError("lateral_pts should have a shape of `(instances, nodes, 2)`.") | ||
|
||
# Input validation for bounding_box | ||
if len(bounding_box) != 4: | ||
raise ValueError( | ||
"bounding_box should be in the form `(left_x, top_y, width, height)`." | ||
"bounding_box must contain exactly 4 elements: `(left_x, top_y, width, height)`." | ||
) | ||
|
||
# Make sure the longest primary root is used | ||
if primary_pts.ndim == 3: | ||
primary_pts = get_max_length_pts(primary_pts) # shape is (nodes, 2) | ||
|
||
# Make primary_pts and lateral_pts have the same dimension of 3 | ||
primary_pts = ( | ||
primary_pts[np.newaxis, :, :] if primary_pts.ndim == 2 else primary_pts | ||
) | ||
|
||
# Filter out NaN values | ||
primary_pts = [root[~np.isnan(root).any(axis=1)] for root in primary_pts] | ||
lateral_pts = [root[~np.isnan(root).any(axis=1)] for root in lateral_pts] | ||
|
||
# Collate root points. | ||
all_roots = primary_pts + lateral_pts if not monocots else lateral_pts | ||
pts_list = [pts[~np.isnan(pts).any(axis=-1)] for pts in pts_list] | ||
|
||
# Get the vertices of the bounding box | ||
left_x, top_y, width, height = bounding_box | ||
Comment on lines
119
to
152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The + pre_filtered_pts_list = [pts for pts in pts_list if not LineString(pts).disjoint(lower_box)]
+ for root in pre_filtered_pts_list: |
||
|
@@ -185,7 +157,6 @@ def get_network_distribution( | |
return np.nan | ||
|
||
# Convert lower bounding box to polygon | ||
# Vertices are in counter-clockwise order | ||
lower_box = Polygon( | ||
[ | ||
[left_x, top_y + (height - lower_height)], | ||
|
@@ -197,7 +168,9 @@ def get_network_distribution( | |
|
||
# Calculate length of roots within the lower bounding box | ||
network_length = 0 | ||
for root in all_roots: | ||
for root in pts_list: | ||
if root.shape[0] < 2: # Skip if fewer than two points | ||
continue | ||
root_poly = LineString(root) | ||
lower_intersection = root_poly.intersection(lower_box) | ||
root_length = lower_intersection.length | ||
|
@@ -207,53 +180,27 @@ def get_network_distribution( | |
|
||
|
||
def get_network_distribution_ratio( | ||
primary_length: float, | ||
lateral_lengths: Union[float, np.ndarray], | ||
network_length: float, | ||
network_length_lower: float, | ||
fraction: float = 2 / 3, | ||
monocots: bool = False, | ||
) -> float: | ||
"""Return ratio of the root length in the lower fraction over all root length. | ||
"""Return ratio of the root length in the lower fraction to total root length. | ||
|
||
Args: | ||
primary_length: Primary root length. | ||
lateral_lengths: Lateral root lengths. Can be a single float (for one root) | ||
or an array of floats (for multiple roots). | ||
network_length_lower: The root length in the lower network. | ||
fraction: The fraction of the network considered as 'lower'. Defaults to 2/3. | ||
monocots: A boolean value, where True indicates rice. Defaults to False. | ||
network_length: Total root length of network. | ||
|
||
Returns: | ||
Float of ratio of the root network length in the lower fraction of the plant | ||
over all root length. | ||
over the total root length. | ||
""" | ||
# Ensure primary_length is a scalar | ||
if not isinstance(primary_length, (float, np.float64)): | ||
raise ValueError("Input primary_length must be a scalar value.") | ||
|
||
# Ensure lateral_lengths is either a scalar or a 1-dimensional array | ||
if not isinstance(lateral_lengths, (float, np.float64, np.ndarray)): | ||
raise ValueError( | ||
"Input lateral_lengths must be a scalar or a 1-dimensional array." | ||
) | ||
|
||
# If lateral_lengths is an ndarray, it must be one-dimensional | ||
if isinstance(lateral_lengths, np.ndarray) and lateral_lengths.ndim != 1: | ||
raise ValueError("Input lateral_lengths array must have shape (instances,).") | ||
if not isinstance(network_length, (float, np.float64)): | ||
raise ValueError("Input network_length must be a scalar value.") | ||
|
||
# Ensure network_length_lower is a scalar | ||
if not isinstance(network_length_lower, (float, np.float64)): | ||
raise ValueError("Input network_length_lower must be a scalar value.") | ||
|
||
# Calculate the total lateral root length | ||
total_lateral_length = np.nansum(lateral_lengths) | ||
|
||
# Determine total root length based on monocots flag | ||
if monocots: | ||
total_root_length = total_lateral_length | ||
else: | ||
total_root_length = np.nansum([primary_length, total_lateral_length]) | ||
|
||
# Calculate the ratio | ||
ratio = network_length_lower / total_root_length | ||
ratio = network_length_lower / network_length | ||
return ratio |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
get_network_length
function has been updated to accept a variable number of root lengths, enhancing its flexibility. This change is well-implemented, correctly handling both scalar and array inputs for root lengths. However, consider adding a unit test to cover scenarios with different types and shapes of inputs to ensure robustness.Would you like me to help with creating these unit tests?