Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct crown-curve-indices definition in trait pipeline #83

Merged
merged 10 commits into from
Aug 26, 2024
88 changes: 52 additions & 36 deletions notebooks/DicotPipeline.ipynb

Large diffs are not rendered by default.

451 changes: 451 additions & 0 deletions notebooks/Pipeline_mermaid_diagrams.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion sleap_roots/lengths.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_curve_index(
& (~np.isnan(base_tip_dists))
& (lengths > 0)
& (lengths >= base_tip_dists),
(lengths - base_tip_dists) / lengths,
(lengths - base_tip_dists) / np.where(lengths != 0, lengths, np.nan),
np.nan,
)

Expand Down
9 changes: 4 additions & 5 deletions sleap_roots/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,11 +473,10 @@ def load_series_from_slps(
) -> List[Series]:
"""Load a list of Series from a list of .slp paths.

To load the `Series`, the files must be named with the following convention:
To load the `Series`, the files must be named with the following convention.
The `slp_paths` are expeted to have the `series_name` in the filename and "primary",
"lateral", or "crown" in the filename to differentiate the predictions.
h5_path: '/path/to/scan/series_name.h5'
primary_path: '/path/to/scan/series_name.model{model_id}.rootprimary.slp'
lateral_path: '/path/to/scan/series_name.model{model_id}.rootlateral.slp'
crown_path: '/path/to/scan/series_name.model{model_id}.rootcrown.slp'
Note that everything is expected to be in the same folder.

Our pipeline outputs prediction files with this format:
Expand All @@ -500,7 +499,7 @@ def load_series_from_slps(
if h5s:
# Get directory of the h5s
h5_dir = Path(slp_paths[0]).parent
# Generate the path to the .h5 file
# Create path to the .h5 file
h5_path = h5_dir / f"{series_name}.h5"
else:
h5_path = None
Expand Down
8 changes: 4 additions & 4 deletions sleap_roots/trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,8 +1597,8 @@ def define_traits(self) -> List[TraitDef]:
),
TraitDef(
name="crown_curve_indices",
fn=get_base_tip_dist,
input_traits=["crown_base_pts", "crown_tip_pts"],
fn=get_curve_index,
input_traits=["crown_lengths", "crown_base_tip_dists"],
scalar=False,
include_in_csv=True,
kwargs={},
Expand Down Expand Up @@ -1974,8 +1974,8 @@ def define_traits(self) -> List[TraitDef]:
),
TraitDef(
name="crown_curve_indices",
fn=get_base_tip_dist,
input_traits=["crown_base_pts", "crown_tip_pts"],
fn=get_curve_index,
input_traits=["crown_lengths", "crown_base_tip_dists"],
scalar=False,
include_in_csv=True,
kwargs={},
Expand Down
1 change: 0 additions & 1 deletion tests/test_lengths.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,6 @@ def test_invalid_scalar_values():
assert np.isnan(get_curve_index(0, 8))


# tests for `get_root_lengths`
def test_curve_index_float():
assert get_curve_index(10.0, 5.0) == 0.5

Expand Down
14 changes: 12 additions & 2 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,12 @@ def test_get_frame_rice_10do(
# Get the crown labeled frame
crown_lf = frames.get("crown")

assert crown_lf == expected_labeled_frame
# Compare the attributes of the labeled frames
assert crown_lf.frame_idx == expected_labeled_frame.frame_idx
assert crown_lf.instances == expected_labeled_frame.instances
assert crown_lf.video.filename == expected_labeled_frame.video.filename
assert crown_lf.video.shape == expected_labeled_frame.video.shape
assert crown_lf.video.backend == expected_labeled_frame.video.backend
assert series.series_name == "0K9E8BI"


Expand All @@ -302,7 +307,12 @@ def test_get_frame_rice_10do_no_video(
# Get the crown labeled frame
crown_lf = frames.get("crown")

assert crown_lf == expected_labeled_frame
# Compare the attributes of the labeled frames
assert crown_lf.frame_idx == expected_labeled_frame.frame_idx
assert crown_lf.instances == expected_labeled_frame.instances
assert crown_lf.video.filename == expected_labeled_frame.video.filename
assert crown_lf.video.shape == expected_labeled_frame.video.shape
assert crown_lf.video.backend == expected_labeled_frame.video.backend
assert series.series_name == "0K9E8BI"


Expand Down
53 changes: 53 additions & 0 deletions tests/test_trait_pipelines.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import numpy as np
import pandas as pd
import json
import pytest

from sleap_roots.trait_pipelines import (
DicotPipeline,
YoungerMonocotPipeline,
OlderMonocotPipeline,
MultipleDicotPipeline,
NumpyArrayEncoder,
)
from sleap_roots.series import (
Series,
Expand All @@ -15,6 +19,47 @@
)


def test_numpy_array_serialization():
array = np.array([1, 2, 3])
expected = [1, 2, 3]
json_str = json.dumps(array, cls=NumpyArrayEncoder)
assert json.loads(json_str) == expected


def test_numpy_int64_serialization():
int64_value = np.int64(42)
expected = 42
json_str = json.dumps(int64_value, cls=NumpyArrayEncoder)
assert json.loads(json_str) == expected


def test_unsupported_type_serialization():
class UnsupportedType:
pass

with pytest.raises(TypeError):
json.dumps(UnsupportedType(), cls=NumpyArrayEncoder)


def test_mixed_data_serialization():
data = {
"array": np.array([1, 2, 3]),
"int64": np.int64(42),
"regular_int": 99,
"list": [4, 5, 6],
"dict": {"key": "value"},
}
expected = {
"array": [1, 2, 3],
"int64": 42,
"regular_int": 99,
"list": [4, 5, 6],
"dict": {"key": "value"},
}
json_str = json.dumps(data, cls=NumpyArrayEncoder)
assert json.loads(json_str) == expected


def test_dicot_pipeline(
canola_h5,
soy_h5,
Expand Down Expand Up @@ -107,12 +152,17 @@ def test_younger_monocot_pipeline(rice_pipeline_output_folder):
assert (
rice_traits["curve_index"].fillna(0) >= 0
).all(), "curve_index in rice_traits contains negative values"
assert rice_traits["curve_index"].fillna(0).max() <= 1, "curve_index in rice_traits contains values greater than 1"
assert (
all_traits["curve_index_median"] >= 0
).all(), "curve_index in all_traits contains negative values"
assert all_traits["curve_index_median"].max() <= 1, "curve_index in all_traits contains values greater than 1"
assert (
all_traits["crown_curve_indices_mean_median"] >= 0
).all(), "crown_curve_indices_mean_median in all_traits contains negative values"
assert (
all_traits["crown_curve_indices_mean_median"] <= 1
).all(), "crown_curve_indices_mean_median in all_traits contains values greater than 1"
assert (
(0 <= rice_traits["crown_angles_proximal_p95"])
& (rice_traits["crown_angles_proximal_p95"] <= 180)
Expand Down Expand Up @@ -169,6 +219,9 @@ def test_older_monocot_pipeline(rice_10do_pipeline_output_folder):
assert (
all_traits["crown_curve_indices_mean_median"] >= 0
).all(), "crown_curve_indices_mean_median in all_traits contains negative values"
assert (
all_traits["crown_curve_indices_mean_median"] <= 1
).all(), "crown_curve_indices_mean_median in all_traits contains values greater than 1"
assert (
(0 <= rice_traits["crown_angles_proximal_p95"])
& (rice_traits["crown_angles_proximal_p95"] <= 180)
Expand Down
Loading