diff --git a/sleap_roots/trait_pipelines.py b/sleap_roots/trait_pipelines.py
index 50a2029..b153880 100644
--- a/sleap_roots/trait_pipelines.py
+++ b/sleap_roots/trait_pipelines.py
@@ -498,6 +498,7 @@ def compute_multiple_dicots_traits(
     def compute_multiple_dicots_traits_for_groups(
         self,
         series_list: List[Series],
+        output_dir: str = "grouped_traits",
         write_json: bool = False,
         json_suffix: str = ".grouped_traits.json",
         write_csv: bool = False,
@@ -509,6 +510,7 @@ def compute_multiple_dicots_traits_for_groups(
         Args:
             series_list: A list of Series objects containing the primary and
                 lateral root points for each sample.
+            output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
             write_json: Whether to write the aggregated traits to a JSON file. Default is False.
             json_suffix: The suffix to append to the JSON file name. Default is ".grouped_traits.json".
             write_csv: Whether to write the summary statistics to a CSV file. Default is False.
@@ -523,8 +525,6 @@ def compute_multiple_dicots_traits_for_groups(
         ):
             raise ValueError("series_list must be a list of Series objects.")
 
-        grouped_results = []
-
         # Group series by their group property
         series_groups = {}
         for series in series_list:
@@ -535,6 +535,9 @@ def compute_multiple_dicots_traits_for_groups(
             series_groups[group_name]["names"].append(str(series.series_name))
             series_groups[group_name]["series"].append(series)  # Store Series objects
 
+        # Initialize the list to hold the results for each group
+        grouped_results = []
+        # Iterate over each group of series
         for group_name, group_data in series_groups.items():
             # Initialize the return structure with the group name
             group_result = {
@@ -545,24 +548,38 @@ def compute_multiple_dicots_traits_for_groups(
 
             # Aggregate traits over all samples in the group
             aggregated_traits = {}
+            # Iterate over each series in the group
             for series in group_data["series"]:
                 print(f"Processing series '{series.series_name}'")
+                # Get the trait results for each series in the group
                 result = self.compute_multiple_dicots_traits(
                     series=series, write_json=False, write_csv=False
                 )
+                # Aggregate the series traits into the group traits
                 for trait, values in result["traits"].items():
+                    # Ensure values are at least 1D
+                    values = np.atleast_1d(values)
                     if trait not in aggregated_traits:
-                        aggregated_traits[trait] = [np.atleast_1d(values)]
+                        aggregated_traits[trait] = values
                     else:
-                        aggregated_traits[trait].append([np.atleast_1d(values)])
+                        # Concatenate the current values with the existing array
+                        aggregated_traits[trait] = np.concatenate(
+                            (aggregated_traits[trait], values)
+                        )
+
             group_result["traits"] = aggregated_traits
             print(f"Finished processing group '{group_name}'")
 
             # Write to JSON if requested
             if write_json:
+                # Make the output directory if it doesn't exist
+                Path(output_dir).mkdir(parents=True, exist_ok=True)
+                # Construct the JSON file name
                 json_name = f"{group_name}{json_suffix}"
+                # Join the output directory with the JSON file name
+                json_path = Path(output_dir) / json_name
                 try:
-                    with open(json_name, "w") as f:
+                    with open(json_path, "w") as f:
                         json.dump(
                             group_result,
                             f,
@@ -571,10 +588,10 @@ def compute_multiple_dicots_traits_for_groups(
                             indent=4,
                         )
                     print(
-                        f"Aggregated traits for group {group_name} saved to {json_name}"
+                        f"Aggregated traits for group {group_name} saved to {str(json_path)}"
                     )
                 except IOError as e:
-                    print(f"Error writing JSON file '{json_name}': {e}")
+                    print(f"Error writing JSON file '{str(json_path)}': {e}")
 
             # Compute summary statistics
             summary_stats = {}
@@ -584,21 +601,26 @@ def compute_multiple_dicots_traits_for_groups(
 
             group_result["summary_stats"] = summary_stats
 
-            grouped_results.append(group_result)
-            print(f"Finished processing group '{group_name}'")
-
             # Write summary stats to CSV if requested
             if write_csv:
+                # Make the output directory if it doesn't exist
+                Path(output_dir).mkdir(parents=True, exist_ok=True)
+                # Construct the CSV file name
                 csv_name = f"{group_name}{csv_suffix}"
+                # Join the output directory with the CSV file name
+                csv_path = Path(output_dir) / csv_name
                 try:
                     summary_df = pd.DataFrame([summary_stats])
                     summary_df.insert(0, "genotype", group_name)
-                    summary_df.to_csv(csv_name, index=False)
+                    summary_df.to_csv(csv_path, index=False)
                     print(
-                        f"Summary statistics for group {group_name} saved to {csv_name}"
+                        f"Summary statistics for group {group_name} saved to {str(csv_path)}"
                     )
                 except IOError as e:
-                    print(f"Failed to write CSV file '{csv_name}': {e}")
+                    print(f"Failed to write CSV file '{str(csv_path)}': {e}")
+
+            # Append the group result to the list of results
+            grouped_results.append(group_result)
 
         return grouped_results
 
@@ -693,6 +715,7 @@ def compute_batch_multiple_dicots_traits(
     def compute_batch_multiple_dicots_traits_for_groups(
         self,
         all_series: List[Series],
+        output_dir: str = "grouped_traits",
         write_json: bool = False,
         write_csv: bool = False,
         csv_path: str = "group_summarized_traits.csv",
@@ -701,6 +724,7 @@ def compute_batch_multiple_dicots_traits_for_groups(
 
         Args:
             all_series: List of `Series` objects.
+            output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
            write_json: If `True`, write each set of group traits to a JSON file.
             write_csv: If `True`, write the computed traits to a CSV file.
             csv_path: Path to write the CSV file to.
@@ -720,7 +744,10 @@ def compute_batch_multiple_dicots_traits_for_groups(
         try:
             # Compute traits for each group of series
            grouped_results = self.compute_multiple_dicots_traits_for_groups(
-                all_series, write_json=write_json, write_csv=False
+                all_series,
+                output_dir=output_dir,
+                write_json=write_json,
+                write_csv=False,
             )
         except Exception as e:
             raise RuntimeError(f"Error computing traits for groups: {e}")
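
Usage sketch for the new `output_dir` parameter. This is a minimal example, not part of the patch: the `MultipleDicotPipeline` class name and the pre-loaded `all_series` list are assumptions; only the method names, parameters, and defaults come from the diff above.

# Minimal usage sketch, assuming MultipleDicotPipeline is the pipeline class
# patched above and that `all_series` is populated with sleap_roots.Series
# objects whose `group` attribute identifies the group each sample belongs to.
from sleap_roots.trait_pipelines import MultipleDicotPipeline

all_series = []  # fill with loaded sleap_roots.Series objects

pipeline = MultipleDicotPipeline()

# Per-group JSON files are now written into `output_dir` (created if missing);
# the batch-level summary CSV still goes to `csv_path`.
grouped_results = pipeline.compute_batch_multiple_dicots_traits_for_groups(
    all_series,
    output_dir="grouped_traits",  # default value, shown explicitly
    write_json=True,
    write_csv=True,
    csv_path="group_summarized_traits.csv",
)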