Skip to content

Commit

Permalink
Add output directory for group trait files
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Mar 31, 2024
1 parent bc12818 commit 40fc672
Showing 1 changed file with 41 additions and 14 deletions.
55 changes: 41 additions & 14 deletions sleap_roots/trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ def compute_multiple_dicots_traits(
def compute_multiple_dicots_traits_for_groups(
self,
series_list: List[Series],
output_dir: str = "grouped_traits",
write_json: bool = False,
json_suffix: str = ".grouped_traits.json",
write_csv: bool = False,
Expand All @@ -509,6 +510,7 @@ def compute_multiple_dicots_traits_for_groups(
Args:
series_list: A list of Series objects containing the primary and lateral root points for each sample.
output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
write_json: Whether to write the aggregated traits to a JSON file. Default is False.
json_suffix: The suffix to append to the JSON file name. Default is ".grouped_traits.json".
write_csv: Whether to write the summary statistics to a CSV file. Default is False.
Expand All @@ -523,8 +525,6 @@ def compute_multiple_dicots_traits_for_groups(
):
raise ValueError("series_list must be a list of Series objects.")

Check warning on line 526 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L526

Added line #L526 was not covered by tests

grouped_results = []

# Group series by their group property
series_groups = {}
for series in series_list:
Expand All @@ -535,6 +535,9 @@ def compute_multiple_dicots_traits_for_groups(
series_groups[group_name]["names"].append(str(series.series_name))
series_groups[group_name]["series"].append(series) # Store Series objects

Check warning on line 536 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L535-L536

Added lines #L535 - L536 were not covered by tests

# Initialize the list to hold the results for each group
grouped_results = []

Check warning on line 539 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L539

Added line #L539 was not covered by tests
# Iterate over each group of series
for group_name, group_data in series_groups.items():

Check warning on line 541 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L541

Added line #L541 was not covered by tests
# Initialize the return structure with the group name
group_result = {

Check warning on line 543 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L543

Added line #L543 was not covered by tests
Expand All @@ -545,24 +548,38 @@ def compute_multiple_dicots_traits_for_groups(

# Aggregate traits over all samples in the group
aggregated_traits = {}

Check warning on line 550 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L550

Added line #L550 was not covered by tests
# Iterate over each series in the group
for series in group_data["series"]:
print(f"Processing series '{series.series_name}'")

Check warning on line 553 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L552-L553

Added lines #L552 - L553 were not covered by tests
# Get the trait results for each series in the group
result = self.compute_multiple_dicots_traits(

Check warning on line 555 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L555

Added line #L555 was not covered by tests
series=series, write_json=False, write_csv=False
)
# Aggregate the series traits into the group traits
for trait, values in result["traits"].items():

Check warning on line 559 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L559

Added line #L559 was not covered by tests
# Ensure values are at least 1D
values = np.atleast_1d(values)
if trait not in aggregated_traits:
aggregated_traits[trait] = [np.atleast_1d(values)]
aggregated_traits[trait] = values

Check warning on line 563 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L561-L563

Added lines #L561 - L563 were not covered by tests
else:
aggregated_traits[trait].append([np.atleast_1d(values)])
# Concatenate the current values with the existing array
aggregated_traits[trait] = np.concatenate(

Check warning on line 566 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L566

Added line #L566 was not covered by tests
(aggregated_traits[trait], values)
)

group_result["traits"] = aggregated_traits
print(f"Finished processing group '{group_name}'")

Check warning on line 571 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L570-L571

Added lines #L570 - L571 were not covered by tests

# Write to JSON if requested
if write_json:

Check warning on line 574 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L574

Added line #L574 was not covered by tests
# Make the output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)

Check warning on line 576 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L576

Added line #L576 was not covered by tests
# Construct the JSON file name
json_name = f"{group_name}{json_suffix}"

Check warning on line 578 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L578

Added line #L578 was not covered by tests
# Join the output directory with the JSON file name
json_path = Path(output_dir) / json_name
try:
with open(json_name, "w") as f:
with open(json_path, "w") as f:
json.dump(

Check warning on line 583 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L580-L583

Added lines #L580 - L583 were not covered by tests
group_result,
f,
Expand All @@ -571,10 +588,10 @@ def compute_multiple_dicots_traits_for_groups(
indent=4,
)
print(

Check warning on line 590 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L590

Added line #L590 was not covered by tests
f"Aggregated traits for group {group_name} saved to {json_name}"
f"Aggregated traits for group {group_name} saved to {str(json_path)}"
)
except IOError as e:
print(f"Error writing JSON file '{json_name}': {e}")
print(f"Error writing JSON file '{str(json_path)}': {e}")

Check warning on line 594 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L593-L594

Added lines #L593 - L594 were not covered by tests

# Compute summary statistics
summary_stats = {}
Expand All @@ -584,21 +601,26 @@ def compute_multiple_dicots_traits_for_groups(

group_result["summary_stats"] = summary_stats

Check warning on line 602 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L602

Added line #L602 was not covered by tests

grouped_results.append(group_result)
print(f"Finished processing group '{group_name}'")

# Write summary stats to CSV if requested
if write_csv:

Check warning on line 605 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L605

Added line #L605 was not covered by tests
# Make the output directory if it doesn't exist
Path(output_dir).mkdir(parents=True, exist_ok=True)

Check warning on line 607 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L607

Added line #L607 was not covered by tests
# Construct the CSV file name
csv_name = f"{group_name}{csv_suffix}"

Check warning on line 609 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L609

Added line #L609 was not covered by tests
# Join the output directory with the CSV file name
csv_path = Path(output_dir) / csv_name
try:
summary_df = pd.DataFrame([summary_stats])
summary_df.insert(0, "genotype", group_name)
summary_df.to_csv(csv_name, index=False)
summary_df.to_csv(csv_path, index=False)
print(

Check warning on line 616 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L611-L616

Added lines #L611 - L616 were not covered by tests
f"Summary statistics for group {group_name} saved to {csv_name}"
f"Summary statistics for group {group_name} saved to {str(csv_path)}"
)
except IOError as e:
print(f"Failed to write CSV file '{csv_name}': {e}")
print(f"Failed to write CSV file '{str(csv_path)}': {e}")

Check warning on line 620 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L619-L620

Added lines #L619 - L620 were not covered by tests

# Append the group result to the list of results
grouped_results.append(group_result)

Check warning on line 623 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L623

Added line #L623 was not covered by tests

return grouped_results

Check warning on line 625 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L625

Added line #L625 was not covered by tests

Expand Down Expand Up @@ -693,6 +715,7 @@ def compute_batch_multiple_dicots_traits(
def compute_batch_multiple_dicots_traits_for_groups(
self,
all_series: List[Series],
output_dir: str = "grouped_traits",
write_json: bool = False,
write_csv: bool = False,
csv_path: str = "group_summarized_traits.csv",
Expand All @@ -701,6 +724,7 @@ def compute_batch_multiple_dicots_traits_for_groups(
Args:
all_series: List of `Series` objects.
output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
write_json: If `True`, write each set of group traits to a JSON file.
write_csv: If `True`, write the computed traits to a CSV file.
csv_path: Path to write the CSV file to.
Expand All @@ -720,7 +744,10 @@ def compute_batch_multiple_dicots_traits_for_groups(
try:

Check warning on line 744 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L744

Added line #L744 was not covered by tests
# Compute traits for each group of series
grouped_results = self.compute_multiple_dicots_traits_for_groups(

Check warning on line 746 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L746

Added line #L746 was not covered by tests
all_series, write_json=write_json, write_csv=False
all_series,
output_dir=output_dir,
write_json=write_json,
write_csv=False,
)
except Exception as e:
raise RuntimeError(f"Error computing traits for groups: {e}")

Check warning on line 753 in sleap_roots/trait_pipelines.py

View check run for this annotation

Codecov / codecov/patch

sleap_roots/trait_pipelines.py#L752-L753

Added lines #L752 - L753 were not covered by tests
Expand Down

0 comments on commit 40fc672

Please sign in to comment.