diff --git a/sleap_roots/trait_pipelines.py b/sleap_roots/trait_pipelines.py index 02143ce..50a2029 100644 --- a/sleap_roots/trait_pipelines.py +++ b/sleap_roots/trait_pipelines.py @@ -405,8 +405,8 @@ def compute_multiple_dicots_traits( """ # Initialize the return structure with the series name and group result = { - "series": series.series_name, - "group": series.group, + "series": str(series.series_name), + "group": str(series.group), "traits": {}, "summary_stats": {}, } @@ -528,7 +528,6 @@ def compute_multiple_dicots_traits_for_groups( # Group series by their group property series_groups = {} for series in series_list: - print(f"Grouping series '{series.series_name}'") group_name = str(series.group) if group_name not in series_groups: series_groups[group_name] = {"names": [], "series": []} @@ -537,7 +536,6 @@ def compute_multiple_dicots_traits_for_groups( series_groups[group_name]["series"].append(series) # Store Series objects for group_name, group_data in series_groups.items(): - print(f"Initializing group '{group_name}'") # Initialize the return structure with the group name group_result = { "group": group_name, @@ -558,7 +556,7 @@ def compute_multiple_dicots_traits_for_groups( else: aggregated_traits[trait].append([np.atleast_1d(values)]) group_result["traits"] = aggregated_traits - print(f"Group results: {group_result}") + print(f"Finished processing group '{group_name}'") # Write to JSON if requested if write_json: @@ -594,7 +592,7 @@ def compute_multiple_dicots_traits_for_groups( csv_name = f"{group_name}{csv_suffix}" try: summary_df = pd.DataFrame([summary_stats]) - summary_df.insert(0, "group", group_name) + summary_df.insert(0, "genotype", group_name) summary_df.to_csv(csv_name, index=False) print( f"Summary statistics for group {group_name} saved to {csv_name}" @@ -607,8 +605,6 @@ def compute_multiple_dicots_traits_for_groups( def compute_batch_traits( self, plants: List[Series], - write_json_per_series: bool = False, - json_suffix: str = ".all_frames_traits.json", write_csv: bool = False, csv_path: str = "traits.csv", ) -> pd.DataFrame: @@ -616,10 +612,6 @@ def compute_batch_traits( Args: plants: List of `Series` objects. - write_json_per_series: If `True`, write the computed traits to a JSON file - for each series. - json_suffix: The suffix to append to the JSON file name. Default is - ".all_frames_traits.json". write_csv: If `True`, write the computed traits to a CSV file. csv_path: Path to write the CSV file to. @@ -698,6 +690,78 @@ def compute_batch_multiple_dicots_traits( return all_series_summaries_df + def compute_batch_multiple_dicots_traits_for_groups( + self, + all_series: List[Series], + write_json: bool = False, + write_csv: bool = False, + csv_path: str = "group_summarized_traits.csv", + ) -> pd.DataFrame: + """Compute traits for a batch of grouped series with multiple dicots. + + Args: + all_series: List of `Series` objects. + write_json: If `True`, write each set of group traits to a JSON file. + write_csv: If `True`, write the computed traits to a CSV file. + csv_path: Path to write the CSV file to. + + Returns: + A pandas DataFrame of computed traits summarized over all frames of each + series. The resulting dataframe will have a row for each series and a column + for each series-level summarized trait. + + Summarized traits are prefixed with the trait name and an underscore, + followed by the summary statistic. + """ + # Check if the input list is empty + if not all_series: + raise ValueError("The input list 'all_series' is empty.") + + try: + # Compute traits for each group of series + grouped_results = self.compute_multiple_dicots_traits_for_groups( + all_series, write_json=write_json, write_csv=False + ) + except Exception as e: + raise RuntimeError(f"Error computing traits for groups: {e}") + + # Prepare the list of dictionaries for the DataFrame + all_group_summaries = [] + for group_result in grouped_results: + # Validate the expected key exists in the result + if "summary_stats" not in group_result: + raise KeyError( + "Expected key 'summary_stats' not found in group result." + ) + + # Assuming 'group' key exists in group_result and it indicates the genotype + genotype = group_result.get( + "group", "Unknown Genotype" + ) # Default to "Unknown Genotype" if not found + + # Start with a dictionary containing the genotype + group_summary = {"genotype": genotype} + + # Add each trait statistic from the summary_stats dictionary to the group_summary + # This assumes summary_stats is a dictionary where keys are trait names and values are the statistics + for trait, statistic in group_result["summary_stats"].items(): + group_summary[trait] = statistic + + all_group_summaries.append(group_summary) + + # Create a DataFrame from the list of dictionaries + all_group_summaries_df = pd.DataFrame(all_group_summaries) + + # Write to CSV if requested + if write_csv: + try: + all_group_summaries_df.to_csv(csv_path, index=False) + print(f"Computed traits for all groups saved to {csv_path}") + except Exception as e: + raise IOError(f"Failed to write computed traits to CSV: {e}") + + return all_group_summaries_df + @attrs.define class DicotPipeline(Pipeline): diff --git a/tests/test_bases.py b/tests/test_bases.py index 4e07d30..c887e59 100644 --- a/tests/test_bases.py +++ b/tests/test_bases.py @@ -376,13 +376,23 @@ def test_root_width_canola(canola_h5): np.array([[0, 0], [1, 1]]), np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]), 0.02, - (np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))), + ( + np.nan, + [(np.nan, np.nan)], + np.full((1, 2), np.nan), + np.full((1, 2), np.nan), + ), ), ( np.array([[np.nan, np.nan], [np.nan, np.nan]]), np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]), 0.02, - (np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))), + ( + np.nan, + [(np.nan, np.nan)], + np.full((1, 2), np.nan), + np.full((1, 2), np.nan), + ), ), ], ) @@ -416,27 +426,27 @@ def test_get_root_widths_invalid_cases(): # Minimum length result = get_root_widths(np.array([[0, 0]]), np.array([[[0, 0]]])) - assert np.array_equal(result, np.array([])) + assert np.isnan(result) # Return default values with return_inds=True result = get_root_widths(np.array([[0, 0]]), np.array([[[0, 0]]]), return_inds=True) # Checks if both arrays are exactly the same - assert np.array_equal(result[0], np.array([])) + assert np.isnan(result[0]) # Continue to check the other parts of the tuple assert result[1] == [(np.nan, np.nan)] # Check the other NumPy arrays in the tuple - assert np.array_equal(result[2], np.empty((0, 2))) - assert np.array_equal(result[3], np.empty((0, 2))) + assert np.all(np.isnan(result[2])) + assert np.all(np.isnan(result[3])) # All NaNs in input arrays result = get_root_widths( np.array([[np.nan, np.nan], [np.nan, np.nan]]), np.array([[[np.nan, np.nan], [np.nan, np.nan]]]), ) - assert np.array_equal(result, np.array([])) + assert np.isnan(result) # All lateral roots on the same side result = get_root_widths( np.array([[0, 0], [1, 1]]), np.array([[[0, 0], [1, 1]], [[0, 0], [1, 1]]]) ) - assert np.array_equal(result, np.array([])) + assert np.isnan(result) diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py index 102188a..c16e814 100644 --- a/tests/test_trait_pipelines.py +++ b/tests/test_trait_pipelines.py @@ -165,7 +165,7 @@ def test_multiple_dicot_pipeline( all_traits = pipeline.compute_batch_multiple_dicots_traits(series_all) # Dataframe shape assertions - assert pd.DataFrame(arabidopsis_traits["summary_stats"]).shape == (1, 316) + assert pd.DataFrame([arabidopsis_traits["summary_stats"]]).shape == (1, 315) assert all_traits.shape == (4, 316) # Dataframe dtype assertions