From 892d12da5dc23661d7101a3c9345d0775ee8f889 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Sun, 12 May 2024 17:45:09 -0700 Subject: [PATCH] fix tests for pipelines to take in slp paths --- tests/test_trait_pipelines.py | 104 ++++++++++++++++++++++------------ 1 file changed, 69 insertions(+), 35 deletions(-) diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py index 41b9c18..4055ecc 100644 --- a/tests/test_trait_pipelines.py +++ b/tests/test_trait_pipelines.py @@ -15,10 +15,27 @@ ) -def test_dicot_pipeline(canola_h5, soy_h5): +def test_dicot_pipeline( + canola_h5, + soy_h5, + canola_primary_slp, + canola_lateral_slp, + soy_primary_slp, + soy_lateral_slp, +): # Load the data - canola = Series.load(canola_h5, primary_name="primary", lateral_name="lateral") - soy = Series.load(soy_h5, primary_name="primary", lateral_name="lateral") + canola = Series.load( + series_name="canola", + h5_path=canola_h5, + primary_path=canola_primary_slp, + lateral_path=canola_lateral_slp, + ) + soy = Series.load( + series_name="soy", + h5_path=soy_h5, + primary_path=soy_primary_slp, + lateral_path=soy_lateral_slp, + ) pipeline = DicotPipeline() canola_traits = pipeline.compute_plant_traits(canola) @@ -30,8 +47,12 @@ def test_dicot_pipeline(canola_h5, soy_h5): assert all_traits.shape == (2, 1036) -def test_OlderMonocot_pipeline(rice_main_10do_h5): - rice = Series.load(rice_main_10do_h5, crown_name="crown") +def test_OlderMonocot_pipeline(rice_main_10do_h5, rice_main_10do_slp): + rice = Series.load( + series_name="rice_10do", + h5_path=rice_main_10do_h5, + crown_path=rice_main_10do_slp, + ) pipeline = OlderMonocotPipeline() rice_10dag_traits = pipeline.compute_plant_traits(rice) @@ -39,17 +60,23 @@ def test_OlderMonocot_pipeline(rice_main_10do_h5): assert rice_10dag_traits.shape == (72, 102) -def test_younger_monocot_pipeline(rice_h5, rice_folder): - rice = Series.load(rice_h5, primary_name="primary", crown_name="crown") - rice_series_all = find_all_series(rice_folder) - series_all = [ - Series.load(series, primary_name="primary", crown_name="crown") - for series in rice_series_all - ] - +def test_younger_monocot_pipeline(rice_pipeline_output_folder): + # Find slp paths in folder + slp_paths = find_all_slp_paths(rice_pipeline_output_folder) + assert len(slp_paths) == 4 + # Load series from slps + rice_series_all = load_series_from_slps( + slp_paths=slp_paths, h5s=False, csv_path=None + ) + assert len(rice_series_all) == 2 + # Get first series + rice_series = rice_series_all[0] + # Initialize pipeline for younger monocot pipeline = YoungerMonocotPipeline() - rice_traits = pipeline.compute_plant_traits(rice) - all_traits = pipeline.compute_batch_traits(series_all) + # Get traits for the first series using the pipeline + rice_traits = pipeline.compute_plant_traits(rice_series) + # Get all traits for all series using the pipeline + all_traits = pipeline.compute_batch_traits(rice_series_all) # Dataframe shape assertions assert rice_traits.shape == (72, 104) @@ -96,14 +123,22 @@ def test_younger_monocot_pipeline(rice_h5, rice_folder): ).all(), "angle_column in all_traits contains values out of range [0, 180]" -def test_older_monocot_pipeline(rice_main_10do_h5, rice_10do_folder): - rice = Series.load(rice_main_10do_h5, crown_name="crown") - rice_series_all = find_all_series(rice_10do_folder) - series_all = [Series.load(series, crown_name="crown") for series in rice_series_all] +def test_older_monocot_pipeline(rice_10do_pipeline_output_folder): + # Find slp paths in folder + slp_paths = find_all_slp_paths(rice_10do_pipeline_output_folder) + assert len(slp_paths) == 1 + # Load series from slps + rice_series_all = load_series_from_slps( + slp_paths=slp_paths, h5s=False, csv_path=None + ) + assert len(rice_series_all) == 1 + # Get first series + rice_series = rice_series_all[0] + assert rice_series.series_name == "scan0K9E8BI" pipeline = OlderMonocotPipeline() - rice_traits = pipeline.compute_plant_traits(rice) - all_traits = pipeline.compute_batch_traits(series_all) + all_traits = pipeline.compute_batch_traits(rice_series_all) + rice_traits = pipeline.compute_plant_traits(rice_series) # Dataframe shape assertions assert rice_traits.shape == (72, 102) @@ -148,27 +183,26 @@ def test_multiple_dicot_pipeline( multiple_arabidopsis_11do_h5, multiple_arabidopsis_11do_folder, multiple_arabidopsis_11do_csv, + multiple_arabidopsis_11do_primary_slp, + multiple_arabidopsis_11do_lateral_slp, ): arabidopsis = Series.load( - multiple_arabidopsis_11do_h5, - primary_name="primary", - lateral_name="lateral", + series_name="997_1", + h5_path=multiple_arabidopsis_11do_h5, + primary_path=multiple_arabidopsis_11do_primary_slp, + lateral_path=multiple_arabidopsis_11do_lateral_slp, + csv_path=multiple_arabidopsis_11do_csv, + ) + arabidopsis_slp_paths = find_all_slp_paths(multiple_arabidopsis_11do_folder) + arabidopsis_series_all = load_series_from_slps( + slp_paths=arabidopsis_slp_paths, + h5s=True, csv_path=multiple_arabidopsis_11do_csv, ) - arabidopsis_series_all = find_all_series(multiple_arabidopsis_11do_folder) - series_all = [ - Series.load( - series, - primary_name="primary", - lateral_name="lateral", - csv_path=multiple_arabidopsis_11do_csv, - ) - for series in arabidopsis_series_all - ] pipeline = MultipleDicotPipeline() arabidopsis_traits = pipeline.compute_multiple_dicots_traits(arabidopsis) - all_traits = pipeline.compute_batch_multiple_dicots_traits(series_all) + all_traits = pipeline.compute_batch_multiple_dicots_traits(arabidopsis_series_all) # Dataframe shape assertions assert pd.DataFrame([arabidopsis_traits["summary_stats"]]).shape == (1, 315)