diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..a059721
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,3 @@
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.slp filter=lfs diff=lfs merge=lfs -text
+*.type filter=lfs diff=lfs merge=lfs -text
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ae2a82c..b56aa62 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -66,6 +66,8 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v3
+ with:
+ lfs: true # Fetch large files with Git LFS
- name: Setup Micromamba
# https://github.com/mamba-org/setup-micromamba
diff --git a/MultiDicotPipeline.ipynb b/MultiDicotPipeline.ipynb
new file mode 100644
index 0000000..7939493
--- /dev/null
+++ b/MultiDicotPipeline.ipynb
@@ -0,0 +1,733 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sleap_roots import Series, find_all_series\n",
+ "from sleap_roots import MultipleDicotPipeline\n",
+ "from sleap_roots.trait_pipelines import Pipeline\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import json\n",
+ "\n",
+ "from pathlib import Path"
+ ]
+ },
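+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This notebook walks through the `MultipleDicotPipeline`: find the HDF5 series in a folder, load each `Series` with its primary and lateral root predictions and a CSV of expected plant counts and genotypes, compute traits per sample, and aggregate the traits by genotype group."
+ ]
+ },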
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "csv_path = \"tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv\" # For sample information (count, group)\n",
+ "folder_path = \"tests/data/multiple_arabidopsis_11do\" # Location of h5 files and predictions\n",
+ "primary_name = \"primary\" # For loading primary root predictions\n",
+ "lateral_name = \"lateral\" # For loading lateral root predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['tests/data/multiple_arabidopsis_11do/6039_1.h5',\n",
+ " 'tests/data/multiple_arabidopsis_11do/7327_2.h5',\n",
+ " 'tests/data/multiple_arabidopsis_11do/9535_1.h5',\n",
+ " 'tests/data/multiple_arabidopsis_11do/997_1.h5']"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Find all h5 files in the folder\n",
+ "all_h5s = find_all_series(folder_path)\n",
+ "all_h5s"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Series(h5_path='tests/data/multiple_arabidopsis_11do/6039_1.h5', primary_labels=Labels(labeled_frames=67, videos=1, skeletons=1, tracks=0), lateral_labels=Labels(labeled_frames=68, videos=1, skeletons=1, tracks=0), crown_labels=None, video=Video(filename=\"tests/data/multiple_arabidopsis_11do/6039_1.h5\", shape=(72, 1088, 2048, 1), dataset=vol, backend=HDF5Video), csv_path='tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv'),\n",
+ " Series(h5_path='tests/data/multiple_arabidopsis_11do/7327_2.h5', primary_labels=Labels(labeled_frames=43, videos=1, skeletons=1, tracks=0), lateral_labels=Labels(labeled_frames=31, videos=1, skeletons=1, tracks=0), crown_labels=None, video=Video(filename=\"tests/data/multiple_arabidopsis_11do/7327_2.h5\", shape=(72, 1088, 2048, 1), dataset=vol, backend=HDF5Video), csv_path='tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv'),\n",
+ " Series(h5_path='tests/data/multiple_arabidopsis_11do/9535_1.h5', primary_labels=Labels(labeled_frames=42, videos=1, skeletons=1, tracks=0), lateral_labels=Labels(labeled_frames=36, videos=1, skeletons=1, tracks=0), crown_labels=None, video=Video(filename=\"tests/data/multiple_arabidopsis_11do/9535_1.h5\", shape=(72, 1088, 2048, 1), dataset=vol, backend=HDF5Video), csv_path='tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv'),\n",
+ " Series(h5_path='tests/data/multiple_arabidopsis_11do/997_1.h5', primary_labels=Labels(labeled_frames=72, videos=1, skeletons=1, tracks=0), lateral_labels=Labels(labeled_frames=72, videos=1, skeletons=1, tracks=0), crown_labels=None, video=Video(filename=\"tests/data/multiple_arabidopsis_11do/997_1.h5\", shape=(72, 1088, 2048, 1), dataset=vol, backend=HDF5Video), csv_path='tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv')]"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Load the cylinder series (one per h5 file)\n",
+ "all_series = [Series.load(h5_path=h5, primary_name=primary_name, lateral_name=lateral_name, csv_path=csv_path) for h5 in all_h5s]\n",
+ "all_series"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Get the first series in the list\n",
+ "series = all_series[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "First sample has name 6039_1\n",
+ "First sample has genotype 6039\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"First sample has name {series.series_name}\")\n",
+ "print(f\"First sample has genotype {series.group}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize the pipeline\n",
+ "pipeline = MultipleDicotPipeline()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Aggregated traits saved to 6039_1.all_frames_traits.json\n",
+ "Summary statistics saved to 6039_1.all_frames_summary.csv\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Get the traits of the first sample\n",
+ "first_sample_traits = pipeline.compute_multiple_dicots_traits(series=series, write_json=True, write_csv=True)"
+ ]
+ },
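+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`compute_multiple_dicots_traits` returns a dictionary with the keys `series`, `group`, `traits`, and `summary_stats`; the summary statistics are displayed as a single-row DataFrame below."
+ ]
+ },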
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " lateral_count_min | \n",
+ " lateral_count_max | \n",
+ " lateral_count_mean | \n",
+ " lateral_count_median | \n",
+ " lateral_count_std | \n",
+ " lateral_count_p5 | \n",
+ " lateral_count_p25 | \n",
+ " lateral_count_p75 | \n",
+ " lateral_count_p95 | \n",
+ " lateral_lengths_min | \n",
+ " ... | \n",
+ " network_distribution_ratio_p95 | \n",
+ " network_solidity_min | \n",
+ " network_solidity_max | \n",
+ " network_solidity_mean | \n",
+ " network_solidity_median | \n",
+ " network_solidity_std | \n",
+ " network_solidity_p5 | \n",
+ " network_solidity_p25 | \n",
+ " network_solidity_p75 | \n",
+ " network_solidity_p95 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 7 | \n",
+ " 5.08209 | \n",
+ " 5.0 | \n",
+ " 1.240176 | \n",
+ " 3.0 | \n",
+ " 4.0 | \n",
+ " 6.0 | \n",
+ " 7.0 | \n",
+ " 3.777593 | \n",
+ " ... | \n",
+ " 0.757133 | \n",
+ " 0.041121 | \n",
+ " 0.150504 | \n",
+ " 0.062255 | \n",
+ " 0.057276 | \n",
+ " 0.01982 | \n",
+ " 0.042815 | \n",
+ " 0.048231 | \n",
+ " 0.070095 | \n",
+ " 0.098175 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1 rows × 315 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " lateral_count_min lateral_count_max lateral_count_mean \\\n",
+ "0 1 7 5.08209 \n",
+ "\n",
+ " lateral_count_median lateral_count_std lateral_count_p5 \\\n",
+ "0 5.0 1.240176 3.0 \n",
+ "\n",
+ " lateral_count_p25 lateral_count_p75 lateral_count_p95 \\\n",
+ "0 4.0 6.0 7.0 \n",
+ "\n",
+ " lateral_lengths_min ... network_distribution_ratio_p95 \\\n",
+ "0 3.777593 ... 0.757133 \n",
+ "\n",
+ " network_solidity_min network_solidity_max network_solidity_mean \\\n",
+ "0 0.041121 0.150504 0.062255 \n",
+ "\n",
+ " network_solidity_median network_solidity_std network_solidity_p5 \\\n",
+ "0 0.057276 0.01982 0.042815 \n",
+ "\n",
+ " network_solidity_p25 network_solidity_p75 network_solidity_p95 \n",
+ "0 0.048231 0.070095 0.098175 \n",
+ "\n",
+ "[1 rows x 315 columns]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame([first_sample_traits[\"summary_stats\"]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing series '6039_1'\n",
+ "Finished processing group '6039'\n",
+ "Aggregated traits for group 6039 saved to 6039.grouped_traits.json\n",
+ "Finished processing group '6039'\n",
+ "Summary statistics for group 6039 saved to 6039.grouped_summary.csv\n",
+ "Processing series '7327_2'\n",
+ "Finished processing group '7327'\n",
+ "Aggregated traits for group 7327 saved to 7327.grouped_traits.json\n",
+ "Finished processing group '7327'\n",
+ "Summary statistics for group 7327 saved to 7327.grouped_summary.csv\n",
+ "Processing series '9535_1'\n",
+ "Finished processing group '9535'\n",
+ "Aggregated traits for group 9535 saved to 9535.grouped_traits.json\n",
+ "Finished processing group '9535'\n",
+ "Summary statistics for group 9535 saved to 9535.grouped_summary.csv\n",
+ "Processing series '997_1'\n",
+ "Finished processing group '997'\n",
+ "Aggregated traits for group 997 saved to 997.grouped_traits.json\n",
+ "Finished processing group '997'\n",
+ "Summary statistics for group 997 saved to 997.grouped_summary.csv\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Get the traits grouped by genotype\n",
+ "grouped_traits = pipeline.compute_multiple_dicots_traits_for_groups(series_list=list(all_series), write_json=True, write_csv=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "4"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(grouped_traits)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " lateral_count_min | \n",
+ " lateral_count_max | \n",
+ " lateral_count_mean | \n",
+ " lateral_count_median | \n",
+ " lateral_count_std | \n",
+ " lateral_count_p5 | \n",
+ " lateral_count_p25 | \n",
+ " lateral_count_p75 | \n",
+ " lateral_count_p95 | \n",
+ " lateral_lengths_min | \n",
+ " ... | \n",
+ " network_distribution_ratio_p95 | \n",
+ " network_solidity_min | \n",
+ " network_solidity_max | \n",
+ " network_solidity_mean | \n",
+ " network_solidity_median | \n",
+ " network_solidity_std | \n",
+ " network_solidity_p5 | \n",
+ " network_solidity_p25 | \n",
+ " network_solidity_p75 | \n",
+ " network_solidity_p95 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 7 | \n",
+ " 5.08209 | \n",
+ " 5.0 | \n",
+ " 1.240176 | \n",
+ " 3.0 | \n",
+ " 4.0 | \n",
+ " 6.0 | \n",
+ " 7.0 | \n",
+ " 3.777593 | \n",
+ " ... | \n",
+ " 0.757133 | \n",
+ " 0.041121 | \n",
+ " 0.150504 | \n",
+ " 0.062255 | \n",
+ " 0.057276 | \n",
+ " 0.01982 | \n",
+ " 0.042815 | \n",
+ " 0.048231 | \n",
+ " 0.070095 | \n",
+ " 0.098175 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
1 rows × 315 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " lateral_count_min lateral_count_max lateral_count_mean \\\n",
+ "0 1 7 5.08209 \n",
+ "\n",
+ " lateral_count_median lateral_count_std lateral_count_p5 \\\n",
+ "0 5.0 1.240176 3.0 \n",
+ "\n",
+ " lateral_count_p25 lateral_count_p75 lateral_count_p95 \\\n",
+ "0 4.0 6.0 7.0 \n",
+ "\n",
+ " lateral_lengths_min ... network_distribution_ratio_p95 \\\n",
+ "0 3.777593 ... 0.757133 \n",
+ "\n",
+ " network_solidity_min network_solidity_max network_solidity_mean \\\n",
+ "0 0.041121 0.150504 0.062255 \n",
+ "\n",
+ " network_solidity_median network_solidity_std network_solidity_p5 \\\n",
+ "0 0.057276 0.01982 0.042815 \n",
+ "\n",
+ " network_solidity_p25 network_solidity_p75 network_solidity_p95 \n",
+ "0 0.048231 0.070095 0.098175 \n",
+ "\n",
+ "[1 rows x 315 columns]"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "pd.DataFrame([grouped_traits[0][\"summary_stats\"]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "grouped_summary_df = pd.DataFrame([grouped_trait[\"summary_stats\"] for grouped_trait in grouped_traits])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " lateral_count_min | \n",
+ " lateral_count_max | \n",
+ " lateral_count_mean | \n",
+ " lateral_count_median | \n",
+ " lateral_count_std | \n",
+ " lateral_count_p5 | \n",
+ " lateral_count_p25 | \n",
+ " lateral_count_p75 | \n",
+ " lateral_count_p95 | \n",
+ " lateral_lengths_min | \n",
+ " ... | \n",
+ " network_distribution_ratio_p95 | \n",
+ " network_solidity_min | \n",
+ " network_solidity_max | \n",
+ " network_solidity_mean | \n",
+ " network_solidity_median | \n",
+ " network_solidity_std | \n",
+ " network_solidity_p5 | \n",
+ " network_solidity_p25 | \n",
+ " network_solidity_p75 | \n",
+ " network_solidity_p95 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " 1 | \n",
+ " 7 | \n",
+ " 5.082090 | \n",
+ " 5.0 | \n",
+ " 1.240176 | \n",
+ " 3.00 | \n",
+ " 4.0 | \n",
+ " 6.00 | \n",
+ " 7.0 | \n",
+ " 3.777593 | \n",
+ " ... | \n",
+ " 0.757133 | \n",
+ " 0.041121 | \n",
+ " 0.150504 | \n",
+ " 0.062255 | \n",
+ " 0.057276 | \n",
+ " 0.019820 | \n",
+ " 0.042815 | \n",
+ " 0.048231 | \n",
+ " 0.070095 | \n",
+ " 0.098175 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " 1 | \n",
+ " 9 | \n",
+ " 3.434109 | \n",
+ " 1.0 | \n",
+ " 2.825260 | \n",
+ " 1.00 | \n",
+ " 1.0 | \n",
+ " 6.00 | \n",
+ " 8.0 | \n",
+ " 4.345694 | \n",
+ " ... | \n",
+ " 0.679840 | \n",
+ " 0.024168 | \n",
+ " 0.293489 | \n",
+ " 0.092920 | \n",
+ " 0.087395 | \n",
+ " 0.062009 | \n",
+ " 0.030521 | \n",
+ " 0.041196 | \n",
+ " 0.125539 | \n",
+ " 0.214581 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1 | \n",
+ " 13 | \n",
+ " 6.007937 | \n",
+ " 6.0 | \n",
+ " 3.027640 | \n",
+ " 1.00 | \n",
+ " 4.0 | \n",
+ " 8.00 | \n",
+ " 11.0 | \n",
+ " 4.431438 | \n",
+ " ... | \n",
+ " 0.677514 | \n",
+ " 0.032377 | \n",
+ " 0.166538 | \n",
+ " 0.055098 | \n",
+ " 0.048888 | \n",
+ " 0.023023 | \n",
+ " 0.033393 | \n",
+ " 0.038470 | \n",
+ " 0.065840 | \n",
+ " 0.092981 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 4 | \n",
+ " 9 | \n",
+ " 7.000000 | \n",
+ " 7.5 | \n",
+ " 1.914854 | \n",
+ " 4.25 | \n",
+ " 5.5 | \n",
+ " 8.75 | \n",
+ " 9.0 | \n",
+ " 17.140351 | \n",
+ " ... | \n",
+ " 0.550392 | \n",
+ " 0.017635 | \n",
+ " 0.028867 | \n",
+ " 0.021103 | \n",
+ " 0.019285 | \n",
+ " 0.004037 | \n",
+ " 0.017699 | \n",
+ " 0.017987 | \n",
+ " 0.022816 | \n",
+ " 0.027564 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
4 rows × 315 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " lateral_count_min lateral_count_max lateral_count_mean \\\n",
+ "0 1 7 5.082090 \n",
+ "1 1 9 3.434109 \n",
+ "2 1 13 6.007937 \n",
+ "3 4 9 7.000000 \n",
+ "\n",
+ " lateral_count_median lateral_count_std lateral_count_p5 \\\n",
+ "0 5.0 1.240176 3.00 \n",
+ "1 1.0 2.825260 1.00 \n",
+ "2 6.0 3.027640 1.00 \n",
+ "3 7.5 1.914854 4.25 \n",
+ "\n",
+ " lateral_count_p25 lateral_count_p75 lateral_count_p95 \\\n",
+ "0 4.0 6.00 7.0 \n",
+ "1 1.0 6.00 8.0 \n",
+ "2 4.0 8.00 11.0 \n",
+ "3 5.5 8.75 9.0 \n",
+ "\n",
+ " lateral_lengths_min ... network_distribution_ratio_p95 \\\n",
+ "0 3.777593 ... 0.757133 \n",
+ "1 4.345694 ... 0.679840 \n",
+ "2 4.431438 ... 0.677514 \n",
+ "3 17.140351 ... 0.550392 \n",
+ "\n",
+ " network_solidity_min network_solidity_max network_solidity_mean \\\n",
+ "0 0.041121 0.150504 0.062255 \n",
+ "1 0.024168 0.293489 0.092920 \n",
+ "2 0.032377 0.166538 0.055098 \n",
+ "3 0.017635 0.028867 0.021103 \n",
+ "\n",
+ " network_solidity_median network_solidity_std network_solidity_p5 \\\n",
+ "0 0.057276 0.019820 0.042815 \n",
+ "1 0.087395 0.062009 0.030521 \n",
+ "2 0.048888 0.023023 0.033393 \n",
+ "3 0.019285 0.004037 0.017699 \n",
+ "\n",
+ " network_solidity_p25 network_solidity_p75 network_solidity_p95 \n",
+ "0 0.048231 0.070095 0.098175 \n",
+ "1 0.041196 0.125539 0.214581 \n",
+ "2 0.038470 0.065840 0.092981 \n",
+ "3 0.017987 0.022816 0.027564 \n",
+ "\n",
+ "[4 rows x 315 columns]"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "grouped_summary_df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Processing series '6039_1'\n",
+ "Finished processing group '6039'\n",
+ "Aggregated traits for group 6039 saved to 6039.grouped_traits.json\n",
+ "Finished processing group '6039'\n",
+ "Processing series '7327_2'\n",
+ "Finished processing group '7327'\n",
+ "Aggregated traits for group 7327 saved to 7327.grouped_traits.json\n",
+ "Finished processing group '7327'\n",
+ "Processing series '9535_1'\n",
+ "Finished processing group '9535'\n",
+ "Aggregated traits for group 9535 saved to 9535.grouped_traits.json\n",
+ "Finished processing group '9535'\n",
+ "Processing series '997_1'\n",
+ "Finished processing group '997'\n",
+ "Aggregated traits for group 997 saved to 997.grouped_traits.json\n",
+ "Finished processing group '997'\n",
+ "Computed traits for all groups saved to group_summarized_traits.csv\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouped_summary_df = pipeline.compute_batch_multiple_dicots_traits_for_groups(all_series=list(all_series), write_json=True, write_csv=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'6039'"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "grouped_traits[0][\"group\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(4, 316)"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "grouped_summary_df.shape"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "sleap_roots",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.18"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/sleap_roots/__init__.py b/sleap_roots/__init__.py
index aa3a417..e758455 100644
--- a/sleap_roots/__init__.py
+++ b/sleap_roots/__init__.py
@@ -17,9 +17,10 @@
TraitDef,
YoungerMonocotPipeline,
OlderMonocotPipeline,
+ MultipleDicotPipeline,
)
from sleap_roots.series import Series, find_all_series
# Define package version.
# This is read dynamically by setuptools in pyproject.toml to determine the release version.
-__version__ = "0.0.6"
+__version__ = "0.0.7"
diff --git a/sleap_roots/bases.py b/sleap_roots/bases.py
index 790092e..ac7aafc 100644
--- a/sleap_roots/bases.py
+++ b/sleap_roots/bases.py
@@ -245,23 +245,22 @@ def get_root_widths(
Returns:
- If `return_inds` is False (default):
- Returns an array of distances between the bases of matched roots. An empty
- array is returned if no matching indices are found.
+ Returns an array of distances between the bases of matched roots. If no
+ matched indices are found, NaN is returned.
- If `return_inds` is True:
Returns a tuple containing the following four elements:
- - matched_dists: Distances between the bases of matched roots. An empty
- array is returned if no matched indices
- are found.
+ - matched_dists: Distances between the bases of matched roots. If no
+ matched indices are found, NaN is returned.
- matched_indices: List of tuples, each containing the indices
of matched roots on the left and right sides. A list containing a
tuple of NaNs is returned if no matched indices are found.
- left_bases_final: (n, 2) array containing the (x, y)
- coordinates of the left bases of the matched roots. An empty array
- of shape (0, 2) is returned if no matched indices are found.
+ coordinates of the left bases of the matched roots. An array of
+ NaNs is returned if no matched indices are found.
- right_bases_final: (n, 2) array containing the (x, y)
- coordinates of the right bases of the matched roots. An empty array
- of shape (0, 2) is returned if no matched indices are found.
+ coordinates of the right bases of the matched roots. An array of
+ NaNs is returned if no matched indices are found.
"""
# Validate tolerance
if tolerance <= 0:
@@ -275,11 +274,11 @@ def get_root_widths(
if primary_max_length_pts.shape[1] != 2 or lateral_pts.shape[2] != 2:
raise ValueError("The last dimension should contain x and y coordinates")
- # Initialize default return values with shapes that match the expected output
- default_dists = np.array([])
- default_indices = [(np.nan, np.nan)]
- default_left_bases = np.empty((0, 2))
- default_right_bases = np.empty((0, 2))
+ # Initialize default return values
+ default_dists = np.nan
+ default_indices = [(np.nan, np.nan)] # List of tuples with NaN values
+ default_left_bases = np.full((1, 2), np.nan) # 2D array filled with NaN values
+ default_right_bases = np.full((1, 2), np.nan) # 2D array filled with NaN values
# Check for minimum length, or all NaNs in arrays
if (
diff --git a/sleap_roots/lengths.py b/sleap_roots/lengths.py
index 127a845..a7938b6 100644
--- a/sleap_roots/lengths.py
+++ b/sleap_roots/lengths.py
@@ -2,38 +2,44 @@
import numpy as np
from typing import Union
+from shapely.geometry import LineString
def get_max_length_pts(pts: np.ndarray) -> np.ndarray:
"""Points of the root with maximum length (intended for primary root traits).
Args:
- pts: Root landmarks as array of shape `(instances, nodes, 2)`.
+ pts: Root landmarks as array of shape `(instances, nodes, 2)` or `(nodes, 2)`.
Returns:
np.ndarray: Array of points with shape `(nodes, 2)` from the root with maximum
- length.
+ length, or the input array unchanged if its shape is `(nodes, 2)`.
"""
+ # Return the input array unchanged if its shape is (nodes, 2)
+ if pts.ndim == 2 and pts.shape[1] == 2:
+ return pts
+
# Return NaN points if the input array is empty
if len(pts) == 0:
return np.array([[np.nan, np.nan]])
- # Check if pts has the correct shape, raise error if it does not
+ # Check if pts has the correct shape for processing multiple instances
if pts.ndim != 3 or pts.shape[2] != 2:
- raise ValueError("Input array should have shape (instances, nodes, 2)")
+ raise ValueError(
+ "Input array should have shape (instances, nodes, 2) for multiple instances"
+ )
# Calculate the differences between consecutive points in each root
segment_diffs = np.diff(pts, axis=1)
- # Calculate the length of each segment (the Euclidean distance between consecutive
- # points)
+ # Calculate the length of each segment
segment_lengths = np.linalg.norm(segment_diffs, axis=-1)
# Sum the lengths of the segments for each root
total_lengths = np.nansum(segment_lengths, axis=-1)
- # Handle roots where all segment lengths are NaN, recording NaN in place of the
- # total length for these roots
+ # Handle roots where all segment lengths are NaN,
+ # recording NaN in place of the total length for these roots
total_lengths[np.isnan(segment_lengths).all(axis=-1)] = np.nan
# Return NaN points if all total lengths are NaN
@@ -128,3 +134,27 @@ def get_curve_index(
return curve_index.item()
else:
return curve_index
+
+
+def get_min_distance_line_to_line(line1: LineString, line2: LineString) -> float:
+ """Calculate the minimum distance between two LineString objects.
+
+ This function computes the shortest distance between any two points on the first
+ line segment and the second line segment. If the lines intersect, the minimum
+ distance is zero. The distance is calculated in the same units as the coordinates
+ of the LineStrings.
+
+ Args:
+ line1: The first LineString object representing a line segment.
+ line2: The second LineString object representing a line segment.
+
+ Returns:
+ The minimum distance between the two line segments.
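+
+ Example (illustrative values only):
+ >>> line_a = LineString([(0, 0), (1, 0)])
+ >>> line_b = LineString([(0, 1), (1, 1)])
+ >>> get_min_distance_line_to_line(line_a, line_b)
+ 1.0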
+ """
+ # Check if the inputs are LineString instances
+ if not isinstance(line1, LineString):
+ raise TypeError("The first argument must be a LineString object.")
+ if not isinstance(line2, LineString):
+ raise TypeError("The second argument must be a LineString object.")
+
+ return line1.distance(line2)
diff --git a/sleap_roots/points.py b/sleap_roots/points.py
index e5b8280..479f564 100644
--- a/sleap_roots/points.py
+++ b/sleap_roots/points.py
@@ -1,6 +1,10 @@
"""Get traits related to the points."""
import numpy as np
+from matplotlib import pyplot as plt
+from matplotlib.lines import Line2D
+from shapely.geometry import LineString
+from shapely.ops import nearest_points
from typing import List, Optional, Tuple
@@ -285,3 +289,308 @@ def get_line_equation_from_points(pts1: np.ndarray, pts2: np.ndarray):
b = pts1[1] - m * pts1[0]
return m, b
+
+
+def filter_roots_with_nans(pts: np.ndarray) -> np.ndarray:
+ """Remove roots with NaN values from an array of root points.
+
+ Args:
+ pts: An array of points representing roots, with shape (instances, nodes, 2),
+ where 'instances' is the number of roots, 'nodes' is the number of points in
+ each root, and '2' corresponds to the x and y coordinates.
+
+ Returns:
+ np.ndarray: An array of shape (instances, nodes, 2) with NaN-containing roots
+ removed. If all roots contain NaN values, an empty array of shape
+ (0, nodes, 2) is returned.
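+
+ Example (illustrative; the second root is dropped because it contains a NaN):
+ >>> pts = np.array([[[0.0, 0.0], [1.0, 1.0]], [[np.nan, 0.0], [1.0, 1.0]]])
+ >>> filter_roots_with_nans(pts).shape
+ (1, 2, 2)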
+ """
+ if not isinstance(pts, np.ndarray):
+ raise TypeError("Input must be a numpy array.")
+ if pts.ndim != 3 or pts.shape[2] != 2:
+ raise ValueError("Input array must have a shape of (instances, nodes, 2).")
+
+ cleaned_pts = np.array([root for root in pts if not np.isnan(root).any()])
+
+ if cleaned_pts.size == 0:
+ return np.empty((0, pts.shape[1], 2))
+
+ return cleaned_pts
+
+
+def filter_plants_with_unexpected_ct(
+ primary_pts: np.ndarray, lateral_pts: np.ndarray, expected_count: float
+) -> Tuple[np.ndarray, np.ndarray]:
+ """Filter out primary and lateral roots with an unexpected number of plants.
+
+ Args:
+ primary_pts: A numpy array of primary root points with shape
+ (instances, nodes, 2), where 'instances' is the number of primary roots,
+ 'nodes' is the number of points in each root, and '2' corresponds to the x and y
+ coordinates.
+ lateral_pts: A numpy array of lateral root points with a shape similar
+ to primary_pts, representing the lateral roots.
+ expected_count: The expected number of primary roots as a float or NaN. If NaN,
+ no filtering is applied based on count. If a number, it will be rounded to
+ the nearest integer for comparison.
+
+ Returns:
+ A tuple containing the filtered primary and lateral root points arrays. If the
+ input types are incorrect, the function will raise a ValueError.
+
+ Raises:
+ ValueError: If input types are incorrect.
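+
+ Example (illustrative; two primary roots detected but three plants expected, so
+ both arrays are emptied):
+ >>> primary = np.zeros((2, 3, 2))
+ >>> lateral = np.zeros((5, 3, 2))
+ >>> filtered = filter_plants_with_unexpected_ct(primary, lateral, 3)
+ >>> filtered[0].shape, filtered[1].shape
+ ((0, 3, 2), (0, 3, 2))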
+ """
+ # Type checking
+ if not isinstance(primary_pts, np.ndarray) or not isinstance(
+ lateral_pts, np.ndarray
+ ):
+ raise ValueError("primary_pts and lateral_pts must be numpy arrays.")
+ if not np.issubdtype(type(expected_count), np.number):
+ raise ValueError("expected_count must be a numeric type.")
+
+ # Handle NaN expected_count: Skip filtering if expected_count is NaN
+ if not np.isnan(expected_count):
+ # Rounding expected_count to the nearest integer for comparison
+ expected_count_rounded = round(expected_count)
+
+ if len(primary_pts) != expected_count_rounded:
+ # Adjusting primary and lateral roots to empty arrays of the same shape
+ primary_pts = np.empty((0, primary_pts.shape[1], 2))
+ lateral_pts = np.empty((0, lateral_pts.shape[1], 2))
+
+ return primary_pts, lateral_pts
+
+
+def get_filtered_primary_pts(filtered_pts: Tuple[np.ndarray, np.ndarray]) -> np.ndarray:
+ """Get the filtered primary root points from a tuple of filtered primary and lateral roots.
+
+ Args:
+ filtered_pts: A tuple containing the filtered primary and lateral root points arrays.
+
+ Returns:
+ np.ndarray: The filtered primary root points array.
+ """
+ return filtered_pts[0]
+
+
+def get_filtered_lateral_pts(filtered_pts: Tuple[np.ndarray, np.ndarray]) -> np.ndarray:
+ """Get the filtered lateral root points from a tuple of filtered primary and lateral roots.
+
+ Args:
+ filtered_pts: A tuple containing the filtered primary and lateral root points arrays.
+
+ Returns:
+ np.ndarray: The filtered lateral root points array.
+ """
+ return filtered_pts[1]
+
+
+def is_line_valid(line: np.ndarray) -> bool:
+ """Check if a line (numpy array of points) does not contain NaN values, indicating it is valid.
+
+ Args:
+ line: A numpy array representing a line with shape (nodes, 2), where 'nodes' is
+ the number of points in the line.
+
+ Returns:
+ True if the line does not contain any NaN values, False otherwise.
+ """
+ return not np.isnan(line).any()
+
+
+def clean_points(points: np.ndarray) -> np.ndarray:
+ """Remove NaN points from root points.
+
+ Args:
+ points: An array of points representing a root, with shape (nodes, 2).
+
+ Returns:
+ np.ndarray: An array of the same points with NaN values removed.
+ """
+ # Filter out points with NaN values and return the cleaned array
+ return np.array([pt for pt in points if not np.isnan(pt).any()])
+
+
+def associate_lateral_to_primary(
+ primary_pts: np.ndarray, lateral_pts: np.ndarray
+) -> dict:
+ """Associates each lateral root with the closest primary root.
+
+ Args:
+ primary_pts: A numpy array of primary root points with shape
+ (instances, nodes, 2), where 'instances' is the number of primary roots,
+ 'nodes' is the number of points in each root, and '2' corresponds to the x and y
+ coordinates. Points cannot have NaN values.
+ lateral_pts: A numpy array of lateral root points with a shape similar
+ to primary_pts, representing the lateral roots. Points cannot have NaN values.
+
+ Returns:
+ dict: A dictionary where each key is an index of a primary root (from the primary_pts
+ array) and each value is a dictionary containing 'primary_points' as the points of
+ the primary root (1, nodes, 2) and 'lateral_points' as an array of
+ lateral root points that are closest to that primary root. The shape of
+ 'lateral_points' is (instances, nodes, 2), where instances is the number of
+ lateral roots associated with the primary root.
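+
+ Example (illustrative; one primary root along x=0 and one nearby lateral root):
+ >>> primary = np.array([[[0.0, 0.0], [0.0, 10.0]]])
+ >>> lateral = np.array([[[1.0, 5.0], [3.0, 5.0]]])
+ >>> assoc = associate_lateral_to_primary(primary, lateral)
+ >>> assoc[0]["lateral_points"].shape
+ (1, 2, 2)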
+ """
+ # Basic input validation
+ if not isinstance(primary_pts, np.ndarray) or not isinstance(
+ lateral_pts, np.ndarray
+ ):
+ raise ValueError("Both primary_pts and lateral_pts must be numpy arrays.")
+ if len(primary_pts.shape) != 3 or len(lateral_pts.shape) != 3:
+ raise ValueError("Input arrays must have a shape of (instances, nodes, 2).")
+ if primary_pts.shape[2] != 2 or lateral_pts.shape[2] != 2:
+ raise ValueError(
+ "The last dimension of input arrays must be 2, representing x and y coordinates."
+ )
+
+ plant_associations = {}
+
+ # Initialize plant associations dictionary
+ for i, primary_root in enumerate(primary_pts):
+ if not is_line_valid(primary_root):
+ continue # Skip primary roots containing NaN values
+ plant_associations[i] = {
+ "primary_points": primary_root,
+ "lateral_points": [],
+ }
+
+ # Associate each lateral root with the closest primary root
+ for lateral_root in lateral_pts:
+ if not is_line_valid(lateral_root):
+ continue # Skip lateral roots containing NaN values
+
+ lateral_line = LineString(lateral_root)
+ min_distance = float("inf")
+ closest_primary_index = None
+
+ for primary_index, primary_data in plant_associations.items():
+ primary_root = primary_data["primary_points"]
+ try:
+ primary_line = LineString(primary_root)
+ distance = primary_line.distance(lateral_line)
+ except Exception as e:
+ print(f"Error computing distance: {e}")
+ continue
+
+ if distance < min_distance:
+ min_distance = distance
+ closest_primary_index = primary_index
+
+ if closest_primary_index is not None:
+ plant_associations[closest_primary_index]["lateral_points"].append(
+ lateral_root
+ )
+
+ # Convert lateral points lists into arrays
+ for primary_index, data in plant_associations.items():
+ lateral_points_list = data["lateral_points"]
+ if lateral_points_list: # Check if there are any lateral points to convert
+ lateral_points_array = np.array(lateral_points_list)
+ plant_associations[primary_index]["lateral_points"] = lateral_points_array
+ else:
+ # Create an array of NaNs if there are no lateral points
+ shape = (1, lateral_pts.shape[1], 2) # Shape of lateral points array
+ plant_associations[primary_index]["lateral_points"] = np.full(shape, np.nan)
+
+ return plant_associations
+
+
+def flatten_associated_points(associations: dict) -> dict:
+ """Creates a dictionary of flattened arrays containing primary and lateral root points.
+
+ Args:
+ associations: A dictionary where each key is an index of a primary root and each value
+ is a dictionary containing 'primary_points' as the points of the primary root
+ and 'lateral_points' as an array of lateral root points that are closest to
+ that primary root.
+
+ Returns:
+ A dictionary with the same keys as associations. Each key corresponds to a flattened
+ array containing all the primary and lateral root points for that plant.
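+
+ Example (illustrative; one primary root with two nodes plus one lateral root with
+ two nodes flatten to 4 points x 2 coordinates = 8 values):
+ >>> assoc = {0: {"primary_points": np.zeros((2, 2)), "lateral_points": np.ones((1, 2, 2))}}
+ >>> flatten_associated_points(assoc)[0].shape
+ (8,)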
+ """
+ flattened_points = {}
+
+ for key, data in associations.items():
+ # Get the primary root points for the current key
+ primary_root_points = data["primary_points"]
+
+ # Get the lateral root points array
+ lateral_root_points = data["lateral_points"]
+
+ # Initialize an array with the primary root points
+ all_points = [primary_root_points]
+
+ # Check if there are lateral points and extend the array if so
+ if lateral_root_points.size > 0 and not np.isnan(lateral_root_points[0][0][0]):
+ all_points.extend(lateral_root_points)
+
+ # Concatenate all the points into a single array
+ all_points_array = np.vstack(all_points)
+
+ # Flatten the array and add to the dictionary
+ flattened_points[key] = all_points_array.flatten()
+
+ return flattened_points
+
+
+def plot_root_associations(associations: dict):
+ """Plots the associations between primary and lateral roots.
+
+ Plots each primary root and its associated lateral roots in a shared color, along
+ with a red dashed line connecting the closest points between each lateral root and
+ its primary root (the color map avoids red so these lines stand out). A legend
+ explains the line styles, and the y-axis is inverted to match the image coordinate
+ system.
+
+ Args:
+ associations: The output dictionary from associate_lateral_to_primary function.
+ """
+ plt.figure(figsize=(12, 10))
+
+ # Generate a color map for primary roots
+ cmap = plt.cm.viridis # Using viridis which doesn't contain red
+ colors = cmap(np.linspace(0, 1, len(associations)))
+
+ for idx, data in enumerate(associations.values()):
+ primary_points = data["primary_points"]
+ lateral_points_list = data["lateral_points"]
+ # Use the enumeration index so skipped (NaN) primary roots cannot push
+ # the dictionary key past the length of the color map
+ color = colors[idx]
+
+ # Convert primary points to LineString
+ primary_line = LineString(primary_points)
+
+ # Plot primary root
+ plt.plot(primary_points[:, 0], primary_points[:, 1], color=color, linewidth=2)
+
+ # Plot each associated lateral root
+ for lateral_points in lateral_points_list:
+ # Convert lateral points to LineString
+ lateral_line = LineString(lateral_points)
+ plt.plot(
+ lateral_points[:, 0],
+ lateral_points[:, 1],
+ color=color,
+ linestyle="--",
+ linewidth=1,
+ )
+
+ # Use nearest_points to find the closest points between the two lines
+ p1, p2 = nearest_points(primary_line, lateral_line)
+ plt.plot([p1.x, p2.x], [p1.y, p2.y], "r--", linewidth=1)
+
+ # Invert y-axis
+ plt.gca().invert_yaxis()
+
+ # Custom legend
+ custom_lines = [
+ Line2D([0], [0], color="black", lw=2),
+ Line2D([0], [0], color="black", lw=2, linestyle="--"),
+ Line2D([0], [0], color="red", lw=1, linestyle="--"),
+ ]
+ plt.legend(custom_lines, ["Primary Root", "Lateral Root", "Minimum Distance"])
+
+ plt.xlabel("X Coordinate")
+ plt.ylabel("Y Coordinate")
+ plt.title("Primary and Lateral Root Associations with Minimum Distances")
+ plt.axis("equal") # Ensure equal aspect ratio for x and y axes
+ plt.show()
diff --git a/sleap_roots/series.py b/sleap_roots/series.py
index c23d9d6..f5d3bb7 100644
--- a/sleap_roots/series.py
+++ b/sleap_roots/series.py
@@ -6,6 +6,7 @@
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
+import pandas as pd
from typing import Dict, Optional, Tuple, List, Union
from pathlib import Path
@@ -21,6 +22,7 @@ class Series:
lateral_labels: Optional `sio.Labels` corresponding to the lateral root predictions.
crown_labels: Optional `sio.Labels` corresponding to the crown predictions.
video: Optional `sio.Video` corresponding to the image series.
+ csv_path: Optional path to the CSV file containing the expected plant count.
Methods:
load: Load a set of predictions for this series.
@@ -35,6 +37,7 @@ class Series:
Properties:
series_name: Name of the series derived from the HDF5 filename.
+ expected_count: Fetch the expected plant count for this series from the CSV.
"""
h5_path: Optional[str] = None
@@ -42,6 +45,7 @@ class Series:
lateral_labels: Optional[sio.Labels] = None
crown_labels: Optional[sio.Labels] = None
video: Optional[sio.Video] = None
+ csv_path: Optional[str] = None
@classmethod
def load(
@@ -50,6 +54,7 @@ def load(
primary_name: Optional[str] = None,
lateral_name: Optional[str] = None,
crown_name: Optional[str] = None,
+ csv_path: Optional[str] = None,
) -> "Series":
"""Load a set of predictions for this series.
@@ -61,6 +66,7 @@ def load(
the file is expected to be named "{h5_path}.{lateral_name}.predictions.slp".
crown_name: Optional name of the crown predictions file. If provided,
the file is expected to be named "{h5_path}.{crown_name}.predictions.slp".
+ csv_path: Optional path to the CSV file containing the expected plant count.
Returns:
An instance of Series loaded with the specified predictions.
@@ -116,6 +122,7 @@ def load(
lateral_labels=lateral_labels,
crown_labels=crown_labels,
video=video,
+ csv_path=csv_path,
)
@property
@@ -123,6 +130,36 @@ def series_name(self) -> str:
"""Name of the series derived from the HDF5 filename."""
return Path(self.h5_path).name.split(".")[0]
+ @property
+ def expected_count(self) -> Union[float, int]:
+ """Fetch the expected plant count for this series from the CSV."""
+ if not self.csv_path or not Path(self.csv_path).exists():
+ print("CSV path is not set or the file does not exist.")
+ return np.nan
+ df = pd.read_csv(self.csv_path)
+ try:
+ # Match the series_name (or plant_qr_code in the CSV) to fetch the expected count
+ return df[df["plant_qr_code"] == self.series_name][
+ "number_of_plants_cylinder"
+ ].iloc[0]
+ except IndexError:
+ print(f"No expected count found for series {self.series_name} in CSV.")
+ return np.nan
+
+ @property
+ def group(self) -> Union[str, float]:
+ """Group (genotype) name for the series from the CSV, or NaN if unavailable."""
+ if not self.csv_path or not Path(self.csv_path).exists():
+ print("CSV path is not set or the file does not exist.")
+ return np.nan
+ df = pd.read_csv(self.csv_path)
+ try:
+ # Match the series_name (or plant_qr_code in the CSV) to fetch the group
+ return df[df["plant_qr_code"] == self.series_name]["genotype"].iloc[0]
+ except IndexError:
+ print(f"No group found for series {self.series_name} in CSV.")
+ return np.nan
+
def __len__(self) -> int:
"""Length of the series (number of images)."""
return len(self.video)
@@ -224,6 +261,9 @@ def get_primary_points(self, frame_idx: int) -> np.ndarray:
Returns:
Primary root points as array of shape `(n_instances, n_nodes, 2)`.
"""
+ # Check that self.primary_labels is not None
+ if self.primary_labels is None:
+ raise ValueError("Primary labels are not available.")
# Retrieve all available frames
frames = self.get_frame(frame_idx)
# Get the primary labeled frame
@@ -247,6 +287,9 @@ def get_lateral_points(self, frame_idx: int) -> np.ndarray:
Returns:
Lateral root points as array of shape `(n_instances, n_nodes, 2)`.
"""
+ # Check that self.lateral_labels is not None
+ if self.lateral_labels is None:
+ raise ValueError("Lateral labels are not available.")
# Retrieve all available frames
frames = self.get_frame(frame_idx)
# Get the lateral labeled frame
@@ -270,6 +313,9 @@ def get_crown_points(self, frame_idx: int) -> np.ndarray:
Returns:
Crown root points as array of shape `(n_instances, n_nodes, 2)`.
"""
+ # Check that self.crown_labels is not None
+ if self.crown_labels is None:
+ raise ValueError("Crown labels are not available.")
# Retrieve all available frames
frames = self.get_frame(frame_idx)
# Get the crown labeled frame
diff --git a/sleap_roots/trait_pipelines.py b/sleap_roots/trait_pipelines.py
index 672246c..b153880 100644
--- a/sleap_roots/trait_pipelines.py
+++ b/sleap_roots/trait_pipelines.py
@@ -1,8 +1,9 @@
"""Extract traits in a pipeline based on a trait graph."""
+import json
import warnings
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
import attrs
import networkx as nx
@@ -54,7 +55,17 @@
get_network_solidity,
get_network_width_depth_ratio,
)
-from sleap_roots.points import get_all_pts_array, get_count, get_nodes, join_pts
+from sleap_roots.points import (
+ associate_lateral_to_primary,
+ filter_plants_with_unexpected_ct,
+ filter_roots_with_nans,
+ get_all_pts_array,
+ get_count,
+ get_filtered_lateral_pts,
+ get_filtered_primary_pts,
+ get_nodes,
+ join_pts,
+)
from sleap_roots.scanline import (
count_scanline_intersections,
get_scanline_first_ind,
@@ -104,6 +115,26 @@
)
+class NumpyArrayEncoder(json.JSONEncoder):
+ """Custom encoder for NumPy array types."""
+
+ def default(self, obj):
+ """Serialize NumPy arrays to lists.
+
+ Args:
+ obj: The object to serialize.
+
+ Returns:
+ A list representation of the NumPy array.
+ """
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, np.int64):
+ return int(obj)
+ # Let the base class default method raise the TypeError
+ return json.JSONEncoder.default(self, obj)
+
+
@attrs.define
class TraitDef:
"""Definition of how to compute a trait.
@@ -237,6 +268,15 @@ def csv_traits(self) -> List[str]:
)
return csv_traits
+ @property
+ def csv_traits_multiple_plants(self) -> List[str]:
+ """List of frame-level traits to include in the CSV for multiple plants."""
+ csv_traits = []
+ for trait in self.traits:
+ if trait.include_in_csv:
+ csv_traits.append(trait.name)
+ return csv_traits
+
def compute_frame_traits(self, traits: Dict[str, Any]) -> Dict[str, Any]:
"""Compute traits based on the pipeline.
@@ -343,6 +383,247 @@ def compute_plant_traits(
else:
return traits[["plant_name", "frame_idx"] + self.csv_traits]
+ def compute_multiple_dicots_traits(
+ self,
+ series: Series,
+ write_json: bool = False,
+ json_suffix: str = ".all_frames_traits.json",
+ write_csv: bool = False,
+ csv_suffix: str = ".all_frames_summary.csv",
+ ):
+ """Computes plant traits for pipelines with multiple plants over all frames in a series.
+
+ Args:
+ series: The Series object containing the primary and lateral root points.
+ write_json: Whether to write the aggregated traits to a JSON file. Default is False.
+ json_suffix: The suffix to append to the JSON file name. Default is ".all_frames_traits.json".
+ write_csv: Whether to write the summary statistics to a CSV file. Default is False.
+ csv_suffix: The suffix to append to the CSV file name. Default is ".all_frames_summary.csv".
+
+ Returns:
+ A dictionary containing the series name, group, aggregated traits, and summary statistics.
+ """
+ # Initialize the return structure with the series name and group
+ result = {
+ "series": str(series.series_name),
+ "group": str(series.group),
+ "traits": {},
+ "summary_stats": {},
+ }
+
+ # Check if the series has frames to process
+ if len(series) == 0:
+ print(f"Series '{series.series_name}' contains no frames to process.")
+ # Return early with the initialized structure
+ return result
+
+ # Initialize a separate dictionary to hold the aggregated traits across all frames
+ aggregated_traits = {}
+
+ # Iterate over frames in series
+ for frame in range(len(series)):
+ # Get initial points and number of plants per frame
+ initial_frame_traits = self.get_initial_frame_traits(series, frame)
+ # Compute initial associations and perform filter operations
+ frame_traits = self.compute_frame_traits(initial_frame_traits)
+
+ # Instantiate DicotPipeline
+ dicot_pipeline = DicotPipeline()
+
+ # Extract the plant associations for this frame
+ associations = frame_traits["plant_associations_dict"]
+
+ for primary_idx, assoc in associations.items():
+ primary_pts = assoc["primary_points"]
+ lateral_pts = assoc["lateral_points"]
+ # Get the initial frame traits for this plant using the primary and lateral points
+ initial_frame_traits = {
+ "primary_pts": primary_pts,
+ "lateral_pts": lateral_pts,
+ }
+ # Use the dicot pipeline to compute the plant traits on this frame
+ plant_traits = dicot_pipeline.compute_frame_traits(initial_frame_traits)
+
+ # For each plant's traits in the frame
+ for trait_name, trait_value in plant_traits.items():
+ # Not all traits are added to the aggregated traits dictionary
+ if trait_name in dicot_pipeline.csv_traits_multiple_plants:
+ if trait_name not in aggregated_traits:
+ # Initialize the trait array if it's the first frame
+ aggregated_traits[trait_name] = [np.atleast_1d(trait_value)]
+ else:
+ # Append new trait values for subsequent frames
+ aggregated_traits[trait_name].append(
+ np.atleast_1d(trait_value)
+ )
+
+ # After processing, update the result dictionary with computed traits
+ for trait, arrays in aggregated_traits.items():
+ aggregated_traits[trait] = np.concatenate(arrays, axis=0)
+ result["traits"] = aggregated_traits
+
+ # Write to JSON if requested
+ if write_json:
+ json_name = f"{series.series_name}{json_suffix}"
+ try:
+ with open(json_name, "w") as f:
+ json.dump(
+ result, f, cls=NumpyArrayEncoder, ensure_ascii=False, indent=4
+ )
+ print(f"Aggregated traits saved to {json_name}")
+ except IOError as e:
+ print(f"Error writing JSON file '{json_name}': {e}")
+
+ # Compute summary statistics and update result
+ summary_stats = {}
+ for trait_name, trait_values in aggregated_traits.items():
+ trait_stats = get_summary(trait_values, prefix=f"{trait_name}_")
+ summary_stats.update(trait_stats)
+ result["summary_stats"] = summary_stats
+
+ # Optionally write summary stats to CSV
+ if write_csv:
+ csv_name = f"{series.series_name}{csv_suffix}"
+ try:
+ summary_df = pd.DataFrame([summary_stats])
+ summary_df.insert(0, "series", series.series_name)
+ summary_df.to_csv(csv_name, index=False)
+ print(f"Summary statistics saved to {csv_name}")
+ except IOError as e:
+ print(f"Failed to write CSV file '{csv_name}': {e}")
+
+ # Return the final result structure
+ return result
+
+ def compute_multiple_dicots_traits_for_groups(
+ self,
+ series_list: List[Series],
+ output_dir: str = "grouped_traits",
+ write_json: bool = False,
+ json_suffix: str = ".grouped_traits.json",
+ write_csv: bool = False,
+ csv_suffix: str = ".grouped_summary.csv",
+ ) -> List[
+ Dict[str, Union[str, List[str], Dict[str, Union[List[float], np.ndarray]]]]
+ ]:
+ """Aggregates plant traits over groups of samples.
+
+ Args:
+ series_list: A list of Series objects containing the primary and lateral root points for each sample.
+ output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
+ write_json: Whether to write the aggregated traits to a JSON file. Default is False.
+ json_suffix: The suffix to append to the JSON file name. Default is ".grouped_traits.json".
+ write_csv: Whether to write the summary statistics to a CSV file. Default is False.
+ csv_suffix: The suffix to append to the CSV file name. Default is ".grouped_summary.csv".
+
+ Returns:
+ A list of dictionaries containing the aggregated traits and summary statistics for each group.
+ """
+ # Input Validation
+ if not isinstance(series_list, list) or not all(
+ isinstance(series, Series) for series in series_list
+ ):
+ raise ValueError("series_list must be a list of Series objects.")
+
+ # Group series by their group property
+ series_groups = {}
+ for series in series_list:
+ group_name = str(series.group)
+ if group_name not in series_groups:
+ series_groups[group_name] = {"names": [], "series": []}
+ # Store series names and objects in the dictionary
+ series_groups[group_name]["names"].append(str(series.series_name))
+ series_groups[group_name]["series"].append(series) # Store Series objects
+
+ # Initialize the list to hold the results for each group
+ grouped_results = []
+ # Iterate over each group of series
+ for group_name, group_data in series_groups.items():
+ # Initialize the return structure with the group name
+ group_result = {
+ "group": group_name,
+ "series": group_data["names"], # Use series names
+ "traits": {},
+ }
+
+ # Aggregate traits over all samples in the group
+ aggregated_traits = {}
+ # Iterate over each series in the group
+ for series in group_data["series"]:
+ print(f"Processing series '{series.series_name}'")
+ # Get the trait results for each series in the group
+ result = self.compute_multiple_dicots_traits(
+ series=series, write_json=False, write_csv=False
+ )
+ # Aggregate the series traits into the group traits
+ for trait, values in result["traits"].items():
+ # Ensure values are at least 1D
+ values = np.atleast_1d(values)
+ if trait not in aggregated_traits:
+ aggregated_traits[trait] = values
+ else:
+ # Concatenate the current values with the existing array
+ aggregated_traits[trait] = np.concatenate(
+ (aggregated_traits[trait], values)
+ )
+
+ group_result["traits"] = aggregated_traits
+ print(f"Finished processing group '{group_name}'")
+
+ # Write to JSON if requested
+ if write_json:
+ # Make the output directory if it doesn't exist
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
+ # Construct the JSON file name
+ json_name = f"{group_name}{json_suffix}"
+ # Join the output directory with the JSON file name
+ json_path = Path(output_dir) / json_name
+ try:
+ with open(json_path, "w") as f:
+ json.dump(
+ group_result,
+ f,
+ cls=NumpyArrayEncoder,
+ ensure_ascii=False,
+ indent=4,
+ )
+ print(
+ f"Aggregated traits for group {group_name} saved to {str(json_path)}"
+ )
+ except IOError as e:
+ print(f"Error writing JSON file '{str(json_path)}': {e}")
+
+ # Compute summary statistics
+ summary_stats = {}
+ for trait, trait_values in aggregated_traits.items():
+ trait_stats = get_summary(trait_values, prefix=f"{trait}_")
+ summary_stats.update(trait_stats)
+
+ group_result["summary_stats"] = summary_stats
+
+ # Write summary stats to CSV if requested
+ if write_csv:
+ # Make the output directory if it doesn't exist
+ Path(output_dir).mkdir(parents=True, exist_ok=True)
+ # Construct the CSV file name
+ csv_name = f"{group_name}{csv_suffix}"
+ # Join the output directory with the CSV file name
+ csv_path = Path(output_dir) / csv_name
+ try:
+ summary_df = pd.DataFrame([summary_stats])
+ summary_df.insert(0, "genotype", group_name)
+ summary_df.to_csv(csv_path, index=False)
+ print(
+ f"Summary statistics for group {group_name} saved to {str(csv_path)}"
+ )
+ except IOError as e:
+ print(f"Failed to write CSV file '{str(csv_path)}': {e}")
+
+ # Append the group result to the list of results
+ grouped_results.append(group_result)
+
+ return grouped_results
+
def compute_batch_traits(
self,
plants: List[Series],
@@ -385,6 +666,129 @@ def compute_batch_traits(
all_traits.to_csv(csv_path, index=False)
return all_traits
+ def compute_batch_multiple_dicots_traits(
+ self,
+ all_series: List[Series],
+ write_csv: bool = False,
+ csv_path: str = "traits.csv",
+ ) -> pd.DataFrame:
+ """Compute traits for a batch of series with multiple dicots.
+
+ Args:
+ all_series: List of `Series` objects.
+ write_csv: If `True`, write the computed traits to a CSV file.
+ csv_path: Path to write the CSV file to.
+
+ Returns:
+ A pandas DataFrame of computed traits summarized over all frames of each
+ series. The resulting dataframe will have a row for each series and a column
+ for each series-level summarized trait.
+
+ Summarized traits are prefixed with the trait name and an underscore,
+ followed by the summary statistic.
+ """
+ all_series_summaries = []
+
+ for series in all_series:
+ print(f"Processing series '{series.series_name}'")
+ # Use the updated function and access its return value
+ series_result = self.compute_multiple_dicots_traits(
+ series, write_json=False, write_csv=False
+ )
+ # Prepare the series-level summary.
+ series_summary = {
+ "series_name": series_result["series"],
+ **series_result["summary_stats"], # Unpack summary_stats
+ }
+ all_series_summaries.append(series_summary)
+
+ # Convert list of dictionaries to a DataFrame
+ all_series_summaries_df = pd.DataFrame(all_series_summaries)
+
+ # Write to CSV if requested
+ if write_csv:
+ all_series_summaries_df.to_csv(csv_path, index=False)
+ print(f"Computed traits for all series saved to {csv_path}")
+
+ return all_series_summaries_df
+
+ def compute_batch_multiple_dicots_traits_for_groups(
+ self,
+ all_series: List[Series],
+ output_dir: str = "grouped_traits",
+ write_json: bool = False,
+ write_csv: bool = False,
+ csv_path: str = "group_summarized_traits.csv",
+ ) -> pd.DataFrame:
+ """Compute traits for a batch of grouped series with multiple dicots.
+
+ Args:
+ all_series: List of `Series` objects.
+ output_dir: The directory to write the JSON and CSV files to. Default is "grouped_traits".
+ write_json: If `True`, write each set of group traits to a JSON file.
+ write_csv: If `True`, write the computed traits to a CSV file.
+ csv_path: Path to write the CSV file to.
+
+ Returns:
+ A pandas DataFrame of computed traits summarized over all frames of each
+ series. The resulting dataframe will have a row for each series and a column
+ for each series-level summarized trait.
+
+ Summarized traits are prefixed with the trait name and an underscore,
+ followed by the summary statistic.
+ """
+ # Check if the input list is empty
+ if not all_series:
+ raise ValueError("The input list 'all_series' is empty.")
+
+ try:
+ # Compute traits for each group of series
+ grouped_results = self.compute_multiple_dicots_traits_for_groups(
+ all_series,
+ output_dir=output_dir,
+ write_json=write_json,
+ write_csv=False,
+ )
+ except Exception as e:
+ raise RuntimeError(f"Error computing traits for groups: {e}")
+
+ # Prepare the list of dictionaries for the DataFrame
+ all_group_summaries = []
+ for group_result in grouped_results:
+ # Validate the expected key exists in the result
+ if "summary_stats" not in group_result:
+ raise KeyError(
+ "Expected key 'summary_stats' not found in group result."
+ )
+
+ # Assuming 'group' key exists in group_result and it indicates the genotype
+ genotype = group_result.get(
+ "group", "Unknown Genotype"
+ ) # Default to "Unknown Genotype" if not found
+
+ # Start with a dictionary containing the genotype
+ group_summary = {"genotype": genotype}
+
+ # Add each trait statistic from the summary_stats dictionary to the group_summary
+ # This assumes summary_stats is a dictionary where keys are trait names and values are the statistics
+ for trait, statistic in group_result["summary_stats"].items():
+ group_summary[trait] = statistic
+
+ all_group_summaries.append(group_summary)
+
+ # Create a DataFrame from the list of dictionaries
+ all_group_summaries_df = pd.DataFrame(all_group_summaries)
+
+ # Write to CSV if requested
+ if write_csv:
+ try:
+ all_group_summaries_df.to_csv(csv_path, index=False)
+ print(f"Computed traits for all groups saved to {csv_path}")
+ except Exception as e:
+ raise IOError(f"Failed to write computed traits to CSV: {e}")
+
+ return all_group_summaries_df
+
@attrs.define
class DicotPipeline(Pipeline):
@@ -1788,3 +2192,98 @@ def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, A
"""
crown_pts = plant.get_crown_points(frame_idx)
return {"crown_pts": crown_pts}
+
+
+@attrs.define
+class MultipleDicotPipeline(Pipeline):
+ """Pipeline for computing traits for multiple dicot plants."""
+
+ def define_traits(self) -> List[TraitDef]:
+ """Define the trait computation pipeline for primary roots."""
+ trait_definitions = [
+ TraitDef(
+ name="primary_pts_no_nans",
+ fn=filter_roots_with_nans,
+ input_traits=["primary_pts"],
+ scalar=False,
+ include_in_csv=False,
+ kwargs={},
+ description="Primary roots without any NaNs.",
+ ),
+ TraitDef(
+ name="lateral_pts_no_nans",
+ fn=filter_roots_with_nans,
+ input_traits=["lateral_pts"],
+ scalar=False,
+ include_in_csv=False,
+ kwargs={},
+ description="Lateral roots without any NaNs.",
+ ),
+ TraitDef(
+ name="filtered_pts_expected_plant_ct",
+ fn=filter_plants_with_unexpected_ct,
+ input_traits=[
+ "primary_pts_no_nans",
+ "lateral_pts_no_nans",
+ "expected_plant_ct",
+ ],
+ scalar=False,
+ include_in_csv=False,
+ kwargs={},
+ description="Tuple of filtered points with expected plant count.",
+ ),
+ TraitDef(
+ name="primary_pts_expected_plant_ct",
+ fn=get_filtered_primary_pts,
+ input_traits=["filtered_pts_expected_plant_ct"],
+ scalar=False,
+ include_in_csv=False,
+ kwargs={},
+ description="Filtered primary root points with expected plant count.",
+ ),
+ TraitDef(
+ name="lateral_pts_expected_plant_ct",
+ fn=get_filtered_lateral_pts,
+ input_traits=["filtered_pts_expected_plant_ct"],
+ scalar=False,
+ include_in_csv=False,
+ kwargs={},
+ description="Filtered lateral root points with expected plant count.",
+ ),
+ TraitDef(
+ name="plant_associations_dict",
+ fn=associate_lateral_to_primary,
+ input_traits=[
+ "primary_pts_expected_plant_ct",
+ "lateral_pts_expected_plant_ct",
+ ],
+ scalar=False,
+ include_in_csv=False,
+ kwargs={},
+ description="Dictionary of plant associations.",
+ ),
+ ]
+
+ return trait_definitions
+
+ def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, Any]:
+ """Return initial traits for a plant frame.
+
+ Args:
+ plant: The plant `Series` object.
+ frame_idx: The index of the current frame.
+
+ Returns:
+ A dictionary of initial traits with keys:
+ - "primary_pts": Array of primary root points.
+ - "lateral_pts": Array of lateral root points.
+ - "expected_ct": Expected number of plants as a float.
+ """
+ primary_pts = plant.get_primary_points(frame_idx)
+ lateral_pts = plant.get_lateral_points(frame_idx)
+ expected_plant_ct = plant.expected_count
+ return {
+ "primary_pts": primary_pts,
+ "lateral_pts": lateral_pts,
+ "expected_plant_ct": expected_plant_ct,
+ }
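
The trait chain defined in MultipleDicotPipeline.define_traits reduces, per frame, to three point-level helpers: drop roots containing NaNs, keep the points only when the primary count matches the expected plant count, then associate laterals with primaries. A minimal sketch of that flow on toy coordinates (the same values used by the new tests in tests/test_points.py; illustration only, not part of the change):

import numpy as np
from sleap_roots.points import (
    associate_lateral_to_primary,
    filter_plants_with_unexpected_ct,
    filter_roots_with_nans,
)

# Two primary and two lateral roots, shaped (instances, nodes, 2).
primary_pts = np.array([[[0, 0], [0, 10]], [[10, 0], [10, 10]]], dtype=float)
lateral_pts = np.array([[[5, 5], [5, 6]], [[11, 0], [11, 1]]], dtype=float)

# 1. Drop any root instance that contains NaNs.
primary_clean = filter_roots_with_nans(primary_pts)
lateral_clean = filter_roots_with_nans(lateral_pts)

# 2. Keep the points only if the primary count matches the expected plant count.
primary_ok, lateral_ok = filter_plants_with_unexpected_ct(primary_clean, lateral_clean, 2.0)

# 3. Associate each lateral root with a primary root; keys are primary indices.
associations = associate_lateral_to_primary(primary_ok, lateral_ok)
# associations[0]["lateral_points"] holds the laterals assigned to primary root 0.
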
diff --git a/tests/data/canola_7do/919QDUH.h5 b/tests/data/canola_7do/919QDUH.h5
index 73359a3..5df6733 100755
Binary files a/tests/data/canola_7do/919QDUH.h5 and b/tests/data/canola_7do/919QDUH.h5 differ
diff --git a/tests/data/canola_7do/919QDUH.lateral.predictions.slp b/tests/data/canola_7do/919QDUH.lateral.predictions.slp
index f4aabf3..d1507d2 100644
Binary files a/tests/data/canola_7do/919QDUH.lateral.predictions.slp and b/tests/data/canola_7do/919QDUH.lateral.predictions.slp differ
diff --git a/tests/data/canola_7do/919QDUH.primary.predictions.slp b/tests/data/canola_7do/919QDUH.primary.predictions.slp
index d6433a1..b481021 100644
Binary files a/tests/data/canola_7do/919QDUH.primary.predictions.slp and b/tests/data/canola_7do/919QDUH.primary.predictions.slp differ
diff --git a/tests/data/multiple_arabidopsis_11do/6039_1.h5 b/tests/data/multiple_arabidopsis_11do/6039_1.h5
new file mode 100644
index 0000000..0d0b007
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/6039_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b38c2840b54215cb7a204992bc161e57f32173fb74e7a6b78331fa39b9a303c6
+size 110368660
diff --git a/tests/data/multiple_arabidopsis_11do/6039_1.lateral.predictions.slp b/tests/data/multiple_arabidopsis_11do/6039_1.lateral.predictions.slp
new file mode 100644
index 0000000..5f25e84
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/6039_1.lateral.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:53876c1d685717d09477545d00a3fdf92aa09245e2c09b2f0aaa7c6fa21e7a36
+size 107860
diff --git a/tests/data/multiple_arabidopsis_11do/6039_1.primary.predictions.slp b/tests/data/multiple_arabidopsis_11do/6039_1.primary.predictions.slp
new file mode 100644
index 0000000..b765111
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/6039_1.primary.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3aa2488feeca137f297cd6ea3d28a474280100278ce220a79d1cc86497e87b92
+size 40882
diff --git a/tests/data/multiple_arabidopsis_11do/7327_2.h5 b/tests/data/multiple_arabidopsis_11do/7327_2.h5
new file mode 100644
index 0000000..17773da
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/7327_2.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a196840fff5404f19915d8785c1ab6c0d9365c71fd5377eef093892f07016984
+size 110454621
diff --git a/tests/data/multiple_arabidopsis_11do/7327_2.lateral.predictions.slp b/tests/data/multiple_arabidopsis_11do/7327_2.lateral.predictions.slp
new file mode 100644
index 0000000..f27f5c7
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/7327_2.lateral.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fdd3c81db5683245eb18f4d7a2f04bc8b3ef93d0f77df9700aea054d6395b4af
+size 69994
diff --git a/tests/data/multiple_arabidopsis_11do/7327_2.primary.predictions.slp b/tests/data/multiple_arabidopsis_11do/7327_2.primary.predictions.slp
new file mode 100644
index 0000000..6a11c67
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/7327_2.primary.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:31422995734a3a31060c8323dd4f20aaabfbdc84aadc86526aecf9489ff76524
+size 37645
diff --git a/tests/data/multiple_arabidopsis_11do/9535_1.h5 b/tests/data/multiple_arabidopsis_11do/9535_1.h5
new file mode 100644
index 0000000..0f18948
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/9535_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1aaa0dd30b585b187304a7185a5422056bb543f933dc4e90012ef53743768f6b
+size 110849195
diff --git a/tests/data/multiple_arabidopsis_11do/9535_1.lateral.predictions.slp b/tests/data/multiple_arabidopsis_11do/9535_1.lateral.predictions.slp
new file mode 100644
index 0000000..41cd4d2
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/9535_1.lateral.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e471ecacadadf2ca06a181d17743ac58e6c86777194b76cb7c7e3ecc47ff6db6
+size 112120
diff --git a/tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_lateral.predictions.slp b/tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_lateral.predictions.slp
new file mode 100644
index 0000000..6c15549
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_lateral.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2c922bcff90b869229d74f551c5761aa1795122b65e069553174bfb222936d6e
+size 167936
diff --git a/tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_primary.predictions.slp b/tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_primary.predictions.slp
new file mode 100644
index 0000000..b334664
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_primary.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6be50e215281cc9bffd68e61fdf546d4ab23a133868b356169fd3871c764e77f
+size 77210
diff --git a/tests/data/multiple_arabidopsis_11do/9535_1.primary.predictions.slp b/tests/data/multiple_arabidopsis_11do/9535_1.primary.predictions.slp
new file mode 100644
index 0000000..4c7dbc4
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/9535_1.primary.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:103177e8f48a9b4c3eafd8d56925d1b832f49b6e8740e8e8cd209f0d05b4ce1b
+size 37150
diff --git a/tests/data/multiple_arabidopsis_11do/997_1.h5 b/tests/data/multiple_arabidopsis_11do/997_1.h5
new file mode 100644
index 0000000..253412d
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/997_1.h5
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b3b7d4f6a4774ff48300f4c802aef3f31bca1e075558d81359d3ecb364d21533
+size 110317849
diff --git a/tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp b/tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp
new file mode 100644
index 0000000..fcf144a
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:214b2de9f7defb9707431a3d68b1a232ac99439e4438f4f8e3757aaee50d2456
+size 424912
diff --git a/tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp b/tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp
new file mode 100644
index 0000000..06a8473
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac196f7ea781a79c4687f54b3a74311f067487c82d6f8f2bffd1c9c39e8074ef
+size 76888
diff --git a/tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv b/tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv
new file mode 100644
index 0000000..a7a9ae7
--- /dev/null
+++ b/tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv
@@ -0,0 +1,351 @@
+plant_qr_code,genotype,replicate,path,qc_cylinder,qc_code,number_of_plants_cylinder,primary_root_proofread,lateral_root_proofread,Unnamed: 9,Unnamed: 10,Unnamed: 11,Unnamed: 12,Instructions
+1002_1,1002,1,h5s_predictions\ES\1002_1.h5,0,,3,,,,,,,
+5830_1,5830,1,h5s_predictions\ES\5830_1.h5,0,,2,,,,,,,
+5830_2,5830,2,h5s_predictions\ES\5830_2.h5,0,,3,,,,,,,
+5867_1,5867,1,h5s_predictions\ES\5867_1.h5,0,,3,,,,,,,
+5867_2,5867,2,h5s_predictions\ES\5867_2.h5,0,,3,,,,,,,
+6016_1,6016,1,h5s_predictions\ES\6016_1.h5,0,,3,,,,,,,
+6016_2,6016,2,h5s_predictions\ES\6016_2.h5,0,,3,,,,,,,
+6019_1,6019,1,h5s_predictions\ES\6019_1.h5,0,,3,,,,,,,
+6021_1,6021,1,h5s_predictions\ES\6021_1.h5,0,,3,,,,,,,
+6021_2,6021,2,h5s_predictions\ES\6021_2.h5,0,,1,,,,,,,
+6025_1,6025,1,h5s_predictions\ES\6025_1.h5,0,,3,,,,,,,
+6025_2,6025,2,h5s_predictions\ES\6025_2.h5,0,,3,,,,,,,
+6030_1,6030,1,h5s_predictions\ES\6030_1.h5,0,,3,,,,,,,
+6030_2,6030,2,h5s_predictions\ES\6030_2.h5,0,,3,,,,,,,
+6035_1,6035,1,h5s_predictions\ES\6035_1.h5,0,,3,,,,,,,
+6035_2,6035,2,h5s_predictions\ES\6035_2.h5,0,,3,,,,,,,
+6039_1,6039,1,h5s_predictions\ES\6039_1.h5,0,,2,1,1,,,,,
+6039_2,6039,2,h5s_predictions\ES\6039_2.h5,0,,3,,,,,,,
+6042_1,6042,1,h5s_predictions\ES\6042_1.h5,0,,3,,,,,,,
+6042_2,6042,2,h5s_predictions\ES\6042_2.h5,0,,3,,,,,,,
+6071_1,6071,1,h5s_predictions\ES\6071_1.h5,0,,2,,,,,,,
+6071_2,6071,2,h5s_predictions\ES\6071_2.h5,0,,2,,,,,,,
+6073_1,6073,1,h5s_predictions\ES\6073_1.h5,0,,3,,,,,,,
+6073_2,6073,2,h5s_predictions\ES\6073_2.h5,0,,3,,,,,,,
+6074_1,6074,1,h5s_predictions\ES\6074_1.h5,0,,3,,,,,,,
+6074_2,6074,2,h5s_predictions\ES\6074_2.h5,0,,3,,,,,,,
+6086_1,6086,1,h5s_predictions\ES\6086_1.h5,0,,3,,,,,,,
+6086_2,6086,2,h5s_predictions\ES\6086_2.h5,0,,3,,,,,,,
+6094_1,6094,1,h5s_predictions\ES\6094_1.h5,0,,3,,,,,,,
+6094_2,6094,2,h5s_predictions\ES\6094_2.h5,0,,3,,,,,,,
+6106_1,6106,1,h5s_predictions\ES\6106_1.h5,0,,3,,,,,,,
+6106_2,6106,2,h5s_predictions\ES\6106_2.h5,0,,3,,,,,,,
+6133_1,6133,1,h5s_predictions\ES\6133_1.h5,0,,3,,,,,,,
+6133_2,6133,2,h5s_predictions\ES\6133_2.h5,0,,3,,,,,,,
+6169_1,6169,1,h5s_predictions\ES\6169_1.h5,0,,3,,,,,,,
+6169_2,6169,2,h5s_predictions\ES\6169_2.h5,0,,3,,,,,,,
+6177_1,6177,1,h5s_predictions\ES\6177_1.h5,0,,3,,,,,,,
+6177_2,6177,2,h5s_predictions\ES\6177_2.h5,0,,2,,,,,,,
+6184_1,6184,1,h5s_predictions\ES\6184_1.h5,0,,2,,,,,,,
+6184_2,6184,2,h5s_predictions\ES\6184_2.h5,0,,1,,,,,,,
+6203_1,6203,1,h5s_predictions\ES\6203_1.h5,0,,3,,,,,,,
+6203_2,6203,2,h5s_predictions\ES\6203_2.h5,0,,3,,,,,,,
+6210_1,6210,1,h5s_predictions\ES\6210_1.h5,0,,3,,,,,,,
+6210_2,6210,2,h5s_predictions\ES\6210_2.h5,0,,3,,,,,,,
+6220_1,6220,1,h5s_predictions\ES\6220_1.h5,0,,2,,,,,,,
+6238_1,6238,1,h5s_predictions\ES\6238_1.h5,0,,3,,,,,,,
+6243_1,6243,1,h5s_predictions\ES\6243_1.h5,0,,2,,,,,,,
+6243_2,6243,2,h5s_predictions\ES\6243_2.h5,0,,3,,,,,,,
+6244_1,6244,1,h5s_predictions\ES\6244_1.h5,0,,3,,,,,,,
+6244_2,6244,2,h5s_predictions\ES\6244_2.h5,0,,3,,,,,,,
+6413_1,6413,1,h5s_predictions\ES\6413_1.h5,0,,3,,,,,,,
+6413_2,6413,2,h5s_predictions\ES\6413_2.h5,0,,3,,,,,,,
+6909_1,6909,1,h5s_predictions\ES\6909_1.h5,0,,3,,,,,,,
+6909_2,6909,2,h5s_predictions\ES\6909_2.h5,0,,3,,,,,,,
+6913_1,6913,1,h5s_predictions\ES\6913_1.h5,0,,3,,,,,,,
+6913_2,6913,2,h5s_predictions\ES\6913_2.h5,0,,1,,,,,,,
+6933_1,6933,1,h5s_predictions\ES\6933_1.h5,0,,3,,,,,,,
+6961_1,6961,1,h5s_predictions\ES\6961_1.h5,0,,3,,,,,,,
+6961_2,6961,2,h5s_predictions\ES\6961_2.h5,0,,3,,,,,,,
+6973_1,6973,1,h5s_predictions\ES\6973_1.h5,0,,3,,,,,,,
+6973_2,6973,2,h5s_predictions\ES\6973_2.h5,0,,3,,,,,,,
+7081_1,7081,1,h5s_predictions\ES\7081_1.h5,0,,3,,,,,,,
+7327_1,7327,1,h5s_predictions\ES\7327_1.h5,0,,3,,,,,,,
+7327_2,7327,2,h5s_predictions\ES\7327_2.h5,0,,3,1,1,,,,,
+7328_1,7328,1,h5s_predictions\ES\7328_1.h5,0,,3,,,,,,,
+7328_2,7328,2,h5s_predictions\ES\7328_2.h5,0,,2,,,,,,,
+8222_1,8222,1,h5s_predictions\ES\8222_1.h5,0,,3,,,,,,,
+8231_1,8231,1,h5s_predictions\ES\8231_1.h5,0,,3,,,,,,,
+8231_2,8231,2,h5s_predictions\ES\8231_2.h5,0,,3,,,,,,,
+8241_1,8241,1,h5s_predictions\ES\8241_1.h5,0,,3,,,,,,,
+8247_1,8247,1,h5s_predictions\ES\8247_1.h5,0,,3,,,,,,,
+8247_2,8247,2,h5s_predictions\ES\8247_2.h5,0,,3,,,,,,,
+8249_1,8249,1,h5s_predictions\ES\8249_1.h5,0,,3,,,,,,,
+8249_2,8249,2,h5s_predictions\ES\8249_2.h5,0,,3,,,,,,,
+8256_1,8256,1,h5s_predictions\ES\8256_1.h5,0,,3,,,,,,,
+8256_2,8256,2,h5s_predictions\ES\8256_2.h5,0,,3,,,,,,,
+8259_1,8259,1,h5s_predictions\ES\8259_1.h5,0,,3,,,,,,,
+8334_2,8334,2,h5s_predictions\ES\8334_2.h5,0,,3,,,,,,,
+8357_1,8357,1,h5s_predictions\ES\8357_1.h5,0,,3,,,,,,,
+8357_2,8357,2,h5s_predictions\ES\8357_2.h5,0,,2,,,,,,,
+8369_1,8369,1,h5s_predictions\ES\8369_1.h5,0,,3,,,,,,,
+8369_2,8369,2,h5s_predictions\ES\8369_2.h5,0,,2,,,,,,,
+8426_1,8426,1,h5s_predictions\ES\8426_1.h5,0,,3,1,,,,,,
+9057_1,9057,1,h5s_predictions\ES\9057_1.h5,0,,3,,,,,,,
+9321_1,9321,1,h5s_predictions\ES\9321_1.h5,0,,2,,,,,,,
+9323_1,9323,1,h5s_predictions\ES\9323_1.h5,1,cyl,3,,,,,,,
+9323_2,9323,2,h5s_predictions\ES\9323_2.h5,0,,3,,,,,,,
+9332_1,9332,1,h5s_predictions\ES\9332_1.h5,0,,3,,,,,,,
+9332_2,9332,2,h5s_predictions\ES\9332_2.h5,0,,2,,,,,,,
+9339_1,9339,1,h5s_predictions\ES\9339_1.h5,0,,3,,,,,,,
+9343_1,9343,1,h5s_predictions\ES\9343_1.h5,0,,3,,,,,,,
+9343_2,9343,2,h5s_predictions\ES\9343_2.h5,0,,3,,,,,,,
+9369_1,9369,1,h5s_predictions\ES\9369_1.h5,0,,3,,,,,,,
+9369_2,9369,2,h5s_predictions\ES\9369_2.h5,0,,2,,,,,,,
+9380_1,9380,1,h5s_predictions\ES\9380_1.h5,0,,3,,,,,,,
+9383_1,9383,1,h5s_predictions\ES\9383_1.h5,0,,3,,,,,,,
+9390_1,9390,1,h5s_predictions\ES\9390_1.h5,0,,3,,,,,,,
+9390_2,9390,2,h5s_predictions\ES\9390_2.h5,0,,3,,,,,,,
+9391_1,9391,1,h5s_predictions\ES\9391_1.h5,0,,3,,,,,,,
+9391_2,9391,2,h5s_predictions\ES\9391_2.h5,0,,3,,,,,,,
+9399_1,9399,1,h5s_predictions\ES\9399_1.h5,0,,3,,,,,,,
+9399_2,9399,2,h5s_predictions\ES\9399_2.h5,0,,3,,,,,,,
+9402_1,9402,1,h5s_predictions\ES\9402_1.h5,0,,3,,,,,,,
+9402_2,9402,2,h5s_predictions\ES\9402_2.h5,0,,3,,,,,,,
+9404_1,9404,1,h5s_predictions\ES\9404_1.h5,0,,2,,,,,,,
+9404_2,9404,2,h5s_predictions\ES\9404_2.h5,0,,3,,,,,,,
+9407_1,9407,1,h5s_predictions\ES\9407_1.h5,0,,3,,,,,,,
+9407_2,9407,2,h5s_predictions\ES\9407_2.h5,0,,3,,,,,,,
+9412_1,9412,1,h5s_predictions\ES\9412_1.h5,0,,3,,,,,,,
+9412_2,9412,2,h5s_predictions\ES\9412_2.h5,0,,2,,,,,,,
+9413_1,9413,1,h5s_predictions\ES\9413_1.h5,0,,3,,,,,,,
+9413_2,9413,2,h5s_predictions\ES\9413_2.h5,0,,3,,,,,,,
+9421_1,9421,1,h5s_predictions\ES\9421_1.h5,0,,3,,,,,,,
+9421_2,9421,2,h5s_predictions\ES\9421_2.h5,0,,3,,,,,,,
+9436_1,9436,1,h5s_predictions\ES\9436_1.h5,0,,3,,,,,,,
+9436_2,9436,2,h5s_predictions\ES\9436_2.h5,0,,3,,,,,,,
+9450_1,9450,1,h5s_predictions\ES\9450_1.h5,0,,3,,,,,,,
+9450_2,9450,2,h5s_predictions\ES\9450_2.h5,0,,2,,,,,,,
+9471_1,9471,1,h5s_predictions\ES\9471_1.h5,0,,3,,,,,,,
+9471_2,9471,2,h5s_predictions\ES\9471_2.h5,0,,3,,,,,,,
+9506_1,9506,1,h5s_predictions\ES\9506_1.h5,0,,3,,,,,,,
+9506_2,9506,2,h5s_predictions\ES\9506_2.h5,0,,3,,,,,,,
+9515_1,9515,1,h5s_predictions\ES\9515_1.h5,0,,3,,,,,,,
+9517_1,9517,1,h5s_predictions\ES\9517_1.h5,0,,3,,,,,,,
+9517_2,9517,2,h5s_predictions\ES\9517_2.h5,0,,3,,,,,,,
+9518_1,9518,1,h5s_predictions\ES\9518_1.h5,0,,3,,,,,,,
+9519_1,9519,1,h5s_predictions\ES\9519_1.h5,0,,3,,,,,,,
+9519_2,9519,2,h5s_predictions\ES\9519_2.h5,0,,3,,,,,,,
+9522_1,9522,1,h5s_predictions\ES\9522_1.h5,0,,3,,,,,,,
+9522_2,9522,2,h5s_predictions\ES\9522_2.h5,0,,3,,,,,,,
+9523_1,9523,1,h5s_predictions\ES\9523_1.h5,0,,3,,,,,,,
+9523_2,9523,2,h5s_predictions\ES\9523_2.h5,0,,1,,,,,,,
+9525_1,9525,1,h5s_predictions\ES\9525_1.h5,0,,3,,,,,,,
+9525_2,9525,2,h5s_predictions\ES\9525_2.h5,0,,2,,,,,,,
+9527_1,9527,1,h5s_predictions\ES\9527_1.h5,0,,3,,,,,,,
+9527_2,9527,2,h5s_predictions\ES\9527_2.h5,0,,3,,,,,,,
+9529_1,9529,1,h5s_predictions\ES\9529_1.h5,0,,3,,,,,,,
+9529_2,9529,2,h5s_predictions\ES\9529_2.h5,0,,3,,,,,,,
+9530_1,9530,1,h5s_predictions\ES\9530_1.h5,0,,3,,,,,,,
+9530_2,9530,2,h5s_predictions\ES\9530_2.h5,0,,3,,,,,,,
+9532_1,9532,1,h5s_predictions\ES\9532_1.h5,0,,3,,,,,,,
+9532_2,9532,2,h5s_predictions\ES\9532_2.h5,0,,3,,,,,,,
+9533_1,9533,1,h5s_predictions\ES\9533_1.h5,0,,3,,,,,,,
+9533_2,9533,2,h5s_predictions\ES\9533_2.h5,0,,2,,,,,,,
+9535_1,9535,1,h5s_predictions\ES\9535_1.h5,0,,3,1,1,,,,,
+9535_2,9535,2,h5s_predictions\ES\9535_2.h5,0,,3,,,,,,,
+9536_1,9536,1,h5s_predictions\ES\9536_1.h5,0,,3,,,,,,,
+9536_2,9536,2,h5s_predictions\ES\9536_2.h5,0,,3,,,,,,,
+9537_1,9537,1,h5s_predictions\ES\9537_1.h5,0,,3,,,,,,,
+9537_2,9537,2,h5s_predictions\ES\9537_2.h5,0,,3,,,,,,,
+9539_1,9539,1,h5s_predictions\ES\9539_1.h5,0,,3,,,,,,,
+9539_2,9539,2,h5s_predictions\ES\9539_2.h5,1,cont,3,,,,,,,
+9540_1,9540,1,h5s_predictions\ES\9540_1.h5,0,,3,,,,,,,
+9540_2,9540,2,h5s_predictions\ES\9540_2.h5,0,,3,,,,,,,
+9541_1,9541,1,h5s_predictions\ES\9541_1.h5,0,,3,,,,,,,
+9541_2,9541,2,h5s_predictions\ES\9541_2.h5,0,,3,,,,,,,
+9543_1,9543,1,h5s_predictions\ES\9543_1.h5,0,,3,,,,,,,
+9543_2,9543,2,h5s_predictions\ES\9543_2.h5,0,,2,,,,,,,
+9544_1,9544,1,h5s_predictions\ES\9544_1.h5,0,,3,,,,,,,
+9545_1,9545,1,h5s_predictions\ES\9545_1.h5,0,,3,,,,,,,
+9545_2,9545,2,h5s_predictions\ES\9545_2.h5,0,,3,,,,,,,
+9547_1,9547,1,h5s_predictions\ES\9547_1.h5,0,,2,,,,,,,
+9549_1,9549,1,h5s_predictions\ES\9549_1.h5,0,,2,,,,,,,
+9549_2,9549,2,h5s_predictions\ES\9549_2.h5,0,,3,,,,,,,
+9555_1,9555,1,h5s_predictions\ES\9555_1.h5,0,,2,,,,,,,
+9556_1,9556,1,h5s_predictions\ES\9556_1.h5,0,,3,,,,,,,
+9556_2,9556,2,h5s_predictions\ES\9556_2.h5,0,,3,,,,,,,
+9559_1,9559,1,h5s_predictions\ES\9559_1.h5,0,,3,,,,,,,
+9559_2,9559,2,h5s_predictions\ES\9559_2.h5,0,,3,,,,,,,
+9561_1,9561,1,h5s_predictions\ES\9561_1.h5,0,,3,,,,,,,
+9561_2,9561,2,h5s_predictions\ES\9561_2.h5,0,,3,,,,,,,
+9565_1,9565,1,h5s_predictions\ES\9565_1.h5,0,,3,,,,,,,
+9565_2,9565,2,h5s_predictions\ES\9565_2.h5,0,,3,,,,,,,
+9569_1,9569,1,h5s_predictions\ES\9569_1.h5,0,,2,,,,,,,
+9569_2,9569,2,h5s_predictions\ES\9569_2.h5,0,,3,,,,,,,
+9571_1,9571,1,h5s_predictions\ES\9571_1.h5,0,,3,,,,,,,
+9571_2,9571,2,h5s_predictions\ES\9571_2.h5,0,,3,,,,,,,
+9582_1,9582,1,h5s_predictions\ES\9582_1.h5,0,,3,,,,,,,
+9583_1,9583,1,h5s_predictions\ES\9583_1.h5,0,,3,,,,,,,
+9583_2,9583,2,h5s_predictions\ES\9583_2.h5,0,,1,,,,,,,
+9584_1,9584,1,h5s_predictions\ES\9584_1.h5,0,,3,,,,,,,
+9584_2,9584,2,h5s_predictions\ES\9584_2.h5,0,,3,,,,,,,
+9585_1,9585,1,h5s_predictions\ES\9585_1.h5,0,,3,,,,,,,
+9585_2,9585,2,h5s_predictions\ES\9585_2.h5,0,,2,,,,,,,
+9587_1,9587,1,h5s_predictions\ES\9587_1.h5,0,,3,,,,,,,
+9587_2,9587,2,h5s_predictions\ES\9587_2.h5,0,,3,,,,,,,
+9589_1,9589,1,h5s_predictions\ES\9589_1.h5,0,,3,,,,,,,
+9589_2,9589,2,h5s_predictions\ES\9589_2.h5,0,,3,,,,,,,
+9597_1,9597,1,h5s_predictions\ES\9597_1.h5,0,,2,,,,,,,
+9599_1,9599,1,h5s_predictions\ES\9599_1.h5,0,,3,,,,,,,
+9601_1,9601,1,h5s_predictions\ES\9601_1.h5,0,,3,,,,,,,
+9601_2,9601,2,h5s_predictions\ES\9601_2.h5,0,,3,,,,,,,
+9822_1,9822,1,h5s_predictions\ES\9822_1.h5,0,,3,,,,,,,
+9822_2,9822,2,h5s_predictions\ES\9822_2.h5,0,,3,,,,,,,
+9825_1,9825,1,h5s_predictions\ES\9825_1.h5,0,,2,,,,,,,
+9825_2,9825,2,h5s_predictions\ES\9825_2.h5,0,,2,,,,,,,
+9828_1,9828,1,h5s_predictions\ES\9828_1.h5,0,,3,,,,,,,
+9828_2,9828,2,h5s_predictions\ES\9828_2.h5,0,,3,,,,,,,
+9830_1,9830,1,h5s_predictions\ES\9830_1.h5,0,,3,,,,,,,
+9830_2,9830,2,h5s_predictions\ES\9830_2.h5,0,,3,,,,,,,
+9832_1,9832,1,h5s_predictions\ES\9832_1.h5,0,,3,,,,,,,
+9832_2,9832,2,h5s_predictions\ES\9832_2.h5,0,,3,,,,,,,
+9835_1,9835,1,h5s_predictions\ES\9835_1.h5,0,,3,,,,,,,
+9835_2,9835,2,h5s_predictions\ES\9835_2.h5,0,,3,,,,,,,
+9838_1,9838,1,h5s_predictions\ES\9838_1.h5,0,,3,,,,,,,
+9838_2,9838,2,h5s_predictions\ES\9838_2.h5,0,,3,,,,,,,
+9840_1,9840,1,h5s_predictions\ES\9840_1.h5,0,,3,,,,,,,
+9840_2,9840,2,h5s_predictions\ES\9840_2.h5,0,,3,,,,,,,
+9841_1,9841,1,h5s_predictions\ES\9841_1.h5,0,,3,,,,,,,
+9841_2,9841,2,h5s_predictions\ES\9841_2.h5,0,,2,,,,,,,
+9843_1,9843,1,h5s_predictions\ES\9843_1.h5,0,,3,,,,,,,
+9843_2,9843,2,h5s_predictions\ES\9843_2.h5,0,,3,,,,,,,
+9844_1,9844,1,h5s_predictions\ES\9844_1.h5,0,,3,,,,,,,
+9844_2,9844,2,h5s_predictions\ES\9844_2.h5,0,,2,,,,,,,
+9845_1,9845,1,h5s_predictions\ES\9845_1.h5,0,,3,,,,,,,
+9845_2,9845,2,h5s_predictions\ES\9845_2.h5,0,,3,,,,,,,
+9847_1,9847,1,h5s_predictions\ES\9847_1.h5,0,,3,,,,,,,
+9847_2,9847,2,h5s_predictions\ES\9847_2.h5,0,,2,,,,,,,
+9848_1,9848,1,h5s_predictions\ES\9848_1.h5,0,,3,,,,,,,
+9850_1,9850,1,h5s_predictions\ES\9850_1.h5,0,,3,,,,,,,
+9850_2,9850,2,h5s_predictions\ES\9850_2.h5,0,,3,,,,,,,
+9853_1,9853,1,h5s_predictions\ES\9853_1.h5,0,,3,,,,,,,
+9855_1,9855,1,h5s_predictions\ES\9855_1.h5,0,,3,,,,,,,
+9855_2,9855,2,h5s_predictions\ES\9855_2.h5,0,,3,,,,,,,
+9857_1,9857,1,h5s_predictions\ES\9857_1.h5,0,,3,,,,,,,
+9857_2,9857,2,h5s_predictions\ES\9857_2.h5,0,,3,,,,,,,
+9861_1,9861,1,h5s_predictions\ES\9861_1.h5,0,,3,,,,,,,
+9861_2,9861,2,h5s_predictions\ES\9861_2.h5,0,,3,,,,,,,
+9873_1,9873,1,h5s_predictions\ES\9873_1.h5,0,,3,,,,,,,
+9879_1,9879,1,h5s_predictions\ES\9879_1.h5,0,,3,,,,,,,
+9881_1,9881,1,h5s_predictions\ES\9881_1.h5,0,,3,,,,,,,
+9881_2,9881,2,h5s_predictions\ES\9881_2.h5,0,,3,,,,,,,
+9882_1,9882,1,h5s_predictions\ES\9882_1.h5,0,,3,,,,,,,
+9883_1,9883,1,h5s_predictions\ES\9883_1.h5,0,,3,,,,,,,
+9883_2,9883,2,h5s_predictions\ES\9883_2.h5,0,,3,,,,,,,
+9885_1,9885,1,h5s_predictions\ES\9885_1.h5,0,,2,,,,,,,
+9885_2,9885,2,h5s_predictions\ES\9885_2.h5,0,,3,,,,,,,
+9894_1,9894,1,h5s_predictions\ES\9894_1.h5,0,,3,,,,,,,
+9895_1,9895,1,h5s_predictions\ES\9895_1.h5,0,,3,,,,,,,
+9895_2,9895,2,h5s_predictions\ES\9895_2.h5,0,,3,,,,,,,
+9897_1,9897,1,h5s_predictions\ES\9897_1.h5,0,,3,,,,,,,
+9897_2,9897,2,h5s_predictions\ES\9897_2.h5,0,,2,,,,,,,
+9899_1,9899,1,h5s_predictions\ES\9899_1.h5,0,,3,,,,,,,
+9900_1,9900,1,h5s_predictions\ES\9900_1.h5,0,,3,,,,,,,
+9900_2,9900,2,h5s_predictions\ES\9900_2.h5,0,,3,,,,,,,
+9901_1,9901,1,h5s_predictions\ES\9901_1.h5,0,,3,,,,,,,
+9902_1,9902,1,h5s_predictions\ES\9902_1.h5,0,,3,,,,,,,
+9902_2,9902,2,h5s_predictions\ES\9902_2.h5,0,,3,1,1,,,,,
+9904_1,9904,1,h5s_predictions\ES\9904_1.h5,0,,3,,,,,,,
+9941_1,9941,1,h5s_predictions\ES\9941_1.h5,0,,3,,,,,,,
+9941_2,9941,2,h5s_predictions\ES\9941_2.h5,0,,3,,,,,,,
+9942_1,9942,1,h5s_predictions\ES\9942_1.h5,0,,3,,,,,,,
+9942_2,9942,2,h5s_predictions\ES\9942_2.h5,0,,3,,,,,,,
+9943_1,9943,1,h5s_predictions\ES\9943_1.h5,0,,3,,,,,,,
+9945_1,9945,1,h5s_predictions\ES\9945_1.h5,0,,3,,,,,,,
+9945_2,9945,2,h5s_predictions\ES\9945_2.h5,0,,3,,,,,,,
+9948_1,9948,1,h5s_predictions\ES\9948_1.h5,0,,3,,,,,,,
+1006_1,1006,1,h5s_predictions\ES\1006_1.h5,0,,2,,,,,,,
+1006_2,1006,2,h5s_predictions\ES\1006_2.h5,0,,3,,,,,,,
+1063_1,1063,1,h5s_predictions\ES\1063_1.h5,0,,3,,,,,,,
+1158_1,1158,1,h5s_predictions\ES\1158_1.h5,0,,3,,,,,,,
+1166_1,1166,1,h5s_predictions\ES\1166_1.h5,0,,2,,,,,,,
+1166_2,1166,2,h5s_predictions\ES\1166_2.h5,0,,3,,,,,,,
+1257_1,1257,1,h5s_predictions\ES\1257_1.h5,0,,2,,,,,,,
+1257_2,1257,2,h5s_predictions\ES\1257_2.h5,0,,3,,,,,,,
+1313_1,1313,1,h5s_predictions\ES\1313_1.h5,0,,3,,,,,,,
+1313_2,1313,2,h5s_predictions\ES\1313_2.h5,0,,3,,,,,,,
+1317_1,1317,1,h5s_predictions\ES\1317_1.h5,0,,3,,,,,,,
+1552_1,1552,1,h5s_predictions\ES\1552_1.h5,0,,2,,,,,,,
+1552_2,1552,2,h5s_predictions\ES\1552_2.h5,0,,3,,,,,,,
+5860_1,5860,1,h5s_predictions\ES\5860_1.h5,0,,3,,,,,,,
+5865_1,5865,1,h5s_predictions\ES\5865_1.h5,0,,3,,,,,,,
+6009_1,6009,1,h5s_predictions\ES\6009_1.h5,0,,3,,,,,,,
+6009_2,6009,2,h5s_predictions\ES\6009_2.h5,0,,2,,,,,,,
+6010_1,6010,1,h5s_predictions\ES\6010_1.h5,0,,3,,,,,,,
+6011_1,6011,1,h5s_predictions\ES\6011_1.h5,0,,3,,,,,,,
+6012_1,6012,1,h5s_predictions\ES\6012_1.h5,0,,3,,,,,,,
+6013_1,6013,1,h5s_predictions\ES\6013_1.h5,0,,3,,,,,,,
+6013_2,6013,2,h5s_predictions\ES\6013_2.h5,0,,3,,,,,,,
+6017_1,6017,1,h5s_predictions\ES\6017_1.h5,0,,3,,,,,,,
+6017_2,6017,2,h5s_predictions\ES\6017_2.h5,0,,3,,,,,,,
+6020_1,6020,1,h5s_predictions\ES\6020_1.h5,0,,2,,,,,,,
+6020_2,6020,2,h5s_predictions\ES\6020_2.h5,0,,3,,,,,,,
+6022_1,6022,1,h5s_predictions\ES\6022_1.h5,0,,2,,,,,,,
+6022_2,6022,2,h5s_predictions\ES\6022_2.h5,0,,3,,,,,,,
+6023_1,6023,1,h5s_predictions\ES\6023_1.h5,0,,2,,,,,,,
+6024_1,6024,1,h5s_predictions\ES\6024_1.h5,0,,3,,,,,,,
+6024_2,6024,2,h5s_predictions\ES\6024_2.h5,0,,3,,,,,,,
+6034_1,6034,1,h5s_predictions\ES\6034_1.h5,0,,3,,,,,,,
+6034_2,6034,2,h5s_predictions\ES\6034_2.h5,0,,3,,,,,,,
+6038_1,6038,1,h5s_predictions\ES\6038_1.h5,0,,2,,,,,,,
+6038_2,6038,2,h5s_predictions\ES\6038_2.h5,0,,3,,,,,,,
+6041_1,6041,1,h5s_predictions\ES\6041_1.h5,0,,3,,,,,,,
+6041_2,6041,2,h5s_predictions\ES\6041_2.h5,0,,3,,,,,,,
+6069_1,6069,1,h5s_predictions\ES\6069_1.h5,0,,3,,,,,,,
+6069_2,6069,2,h5s_predictions\ES\6069_2.h5,0,,3,,,,,,,
+6070_1,6070,1,h5s_predictions\ES\6070_1.h5,0,,3,,,,,,,
+6077_1,6077,1,h5s_predictions\ES\6077_1.h5,0,,3,,,,,,,
+6087_1,6087,1,h5s_predictions\ES\6087_1.h5,0,,3,,,,,,,
+6088_1,6088,1,h5s_predictions\ES\6088_1.h5,0,,3,,,,,,,
+6096_1,6096,1,h5s_predictions\ES\6096_1.h5,0,,3,,,,,,,
+6096_2,6096,2,h5s_predictions\ES\6096_2.h5,0,,3,,,,,,,
+6097_1,6097,1,h5s_predictions\ES\6097_1.h5,0,,3,,,,,,,
+6097_2,6097,2,h5s_predictions\ES\6097_2.h5,0,,3,,,,,,,
+6100_1,6100,1,h5s_predictions\ES\6100_1.h5,0,,3,,,,,,,
+6100_2,6100,2,h5s_predictions\ES\6100_2.h5,0,,3,,,,,,,
+6105_1,6105,1,h5s_predictions\ES\6105_1.h5,0,,3,,,,,,,
+6118_1,6118,1,h5s_predictions\ES\6118_1.h5,0,,3,,,,,,,
+6122_1,6122,1,h5s_predictions\ES\6122_1.h5,0,,3,,,,,,,
+6132_1,6132,1,h5s_predictions\ES\6132_1.h5,0,,3,,,,,,,
+6132_2,6132,2,h5s_predictions\ES\6132_2.h5,0,,3,,,,,,,
+6137_1,6137,1,h5s_predictions\ES\6137_1.h5,0,,3,,,,,,,
+6137_2,6137,2,h5s_predictions\ES\6137_2.h5,0,,3,,,,,,,
+6138_1,6138,1,h5s_predictions\ES\6138_1.h5,0,,3,,,,,,,
+6138_2,6138,2,h5s_predictions\ES\6138_2.h5,0,,3,,,,,,,
+6141_1,6141,1,h5s_predictions\ES\6141_1.h5,0,,3,,,,,,,
+6142_1,6142,1,h5s_predictions\ES\6142_1.h5,0,,3,,,,,,,
+6148_1,6148,1,h5s_predictions\ES\6148_1.h5,0,,3,,,,,,,
+6150_1,6150,1,h5s_predictions\ES\6150_1.h5,0,,3,,,,,,,
+6150_2,6150,2,h5s_predictions\ES\6150_2.h5,0,,3,,,,,,,
+6172_1,6172,1,h5s_predictions\ES\6172_1.h5,0,,2,,,,,,,
+6172_2,6172,2,h5s_predictions\ES\6172_2.h5,0,,3,,,,,,,
+6173_1,6173,1,h5s_predictions\ES\6173_1.h5,0,,3,,,,,,,
+6173_2,6173,2,h5s_predictions\ES\6173_2.h5,0,,3,,,,,,,
+6174_1,6174,1,h5s_predictions\ES\6174_1.h5,0,,2,,,,,,,
+6174_2,6174,2,h5s_predictions\ES\6174_2.h5,0,,2,,,,,,,
+6194_1,6194,1,h5s_predictions\ES\6194_1.h5,0,,1,,,,,,,
+6194_2,6194,2,h5s_predictions\ES\6194_2.h5,0,,1,,,,,,,
+6201_1,6201,1,h5s_predictions\ES\6201_1.h5,0,,3,,,,,,,
+6201_2,6201,2,h5s_predictions\ES\6201_2.h5,0,,3,,,,,,,
+6202_1,6202,1,h5s_predictions\ES\6202_1.h5,0,,3,,,,,,,
+6202_2,6202,2,h5s_predictions\ES\6202_2.h5,0,,3,,,,,,,
+6209_1,6209,1,h5s_predictions\ES\6209_1.h5,0,,3,,,,,,,
+6209_2,6209,2,h5s_predictions\ES\6209_2.h5,0,,3,,,,,,,
+6217_1,6217,1,h5s_predictions\ES\6217_1.h5,0,,3,,,,,,,
+6217_2,6217,2,h5s_predictions\ES\6217_2.h5,0,,3,,,,,,,
+6218_1,6218,1,h5s_predictions\ES\6218_1.h5,0,,3,,,,,,,
+6218_2,6218,2,h5s_predictions\ES\6218_2.h5,0,,2,,,,,,,
+6221_1,6221,1,h5s_predictions\ES\6221_1.h5,0,,3,,,,,,,
+6221_2,6221,2,h5s_predictions\ES\6221_2.h5,0,,3,,,,,,,
+6231_2,6231,2,h5s_predictions\ES\6231_2.h5,0,,3,,,,,,,
+6237_1,6237,1,h5s_predictions\ES\6237_1.h5,0,,3,,,,,,,
+6240_1,6240,1,h5s_predictions\ES\6240_1.h5,0,,2,,,,,,,
+6240_2,6240,2,h5s_predictions\ES\6240_2.h5,0,,3,,,,,,,
+6241_1,6241,1,h5s_predictions\ES\6241_1.h5,0,,3,,,,,,,
+6241_2,6241,2,h5s_predictions\ES\6241_2.h5,0,,3,,,,,,,
+6258_1,6258,1,h5s_predictions\ES\6258_1.h5,0,,2,,,,,,,
+6258_2,6258,2,h5s_predictions\ES\6258_2.h5,0,,3,,,,,,,
+997_1,997,1,h5s_predictions\ES\997_1.h5,0,,3,,,,,,,
+997_2,997,2,h5s_predictions\ES\997_2.h5,0,,3,,,,,,,
diff --git a/tests/data/rice_10do/0K9E8BI.crown.predictions.slp b/tests/data/rice_10do/0K9E8BI.crown.predictions.slp
index 0bffba8..9ccf240 100644
Binary files a/tests/data/rice_10do/0K9E8BI.crown.predictions.slp and b/tests/data/rice_10do/0K9E8BI.crown.predictions.slp differ
diff --git a/tests/data/rice_10do/0K9E8BI.h5 b/tests/data/rice_10do/0K9E8BI.h5
index baa7763..e452ef4 100644
Binary files a/tests/data/rice_10do/0K9E8BI.h5 and b/tests/data/rice_10do/0K9E8BI.h5 differ
diff --git a/tests/data/rice_3do/0K9E8BI.crown.predictions.slp b/tests/data/rice_3do/0K9E8BI.crown.predictions.slp
index 9f19a8a..5d81a51 100644
Binary files a/tests/data/rice_3do/0K9E8BI.crown.predictions.slp and b/tests/data/rice_3do/0K9E8BI.crown.predictions.slp differ
diff --git a/tests/data/rice_3do/0K9E8BI.h5 b/tests/data/rice_3do/0K9E8BI.h5
index da83e3b..d82ffe3 100644
Binary files a/tests/data/rice_3do/0K9E8BI.h5 and b/tests/data/rice_3do/0K9E8BI.h5 differ
diff --git a/tests/data/rice_3do/0K9E8BI.longest_3do_6nodes.predictions.slp b/tests/data/rice_3do/0K9E8BI.longest_3do_6nodes.predictions.slp
index bac1b56..b7262f4 100644
Binary files a/tests/data/rice_3do/0K9E8BI.longest_3do_6nodes.predictions.slp and b/tests/data/rice_3do/0K9E8BI.longest_3do_6nodes.predictions.slp differ
diff --git a/tests/data/rice_3do/0K9E8BI.main_3do_6nodes.predictions.slp b/tests/data/rice_3do/0K9E8BI.main_3do_6nodes.predictions.slp
index 9f19a8a..5d81a51 100644
Binary files a/tests/data/rice_3do/0K9E8BI.main_3do_6nodes.predictions.slp and b/tests/data/rice_3do/0K9E8BI.main_3do_6nodes.predictions.slp differ
diff --git a/tests/data/rice_3do/0K9E8BI.primary.predictions.slp b/tests/data/rice_3do/0K9E8BI.primary.predictions.slp
index bac1b56..b7262f4 100644
Binary files a/tests/data/rice_3do/0K9E8BI.primary.predictions.slp and b/tests/data/rice_3do/0K9E8BI.primary.predictions.slp differ
diff --git a/tests/data/rice_3do/YR39SJX.crown.predictions.slp b/tests/data/rice_3do/YR39SJX.crown.predictions.slp
index 30a0ad3..6dd13d3 100644
Binary files a/tests/data/rice_3do/YR39SJX.crown.predictions.slp and b/tests/data/rice_3do/YR39SJX.crown.predictions.slp differ
diff --git a/tests/data/rice_3do/YR39SJX.h5 b/tests/data/rice_3do/YR39SJX.h5
index b8d60c9..2d8c4db 100644
Binary files a/tests/data/rice_3do/YR39SJX.h5 and b/tests/data/rice_3do/YR39SJX.h5 differ
diff --git a/tests/data/rice_3do/YR39SJX.primary.predictions.slp b/tests/data/rice_3do/YR39SJX.primary.predictions.slp
index d40d4d7..8214dfa 100644
Binary files a/tests/data/rice_3do/YR39SJX.primary.predictions.slp and b/tests/data/rice_3do/YR39SJX.primary.predictions.slp differ
diff --git a/tests/data/soy_6do/6PR6AA22JK.h5 b/tests/data/soy_6do/6PR6AA22JK.h5
index 8458a04..64005dc 100644
Binary files a/tests/data/soy_6do/6PR6AA22JK.h5 and b/tests/data/soy_6do/6PR6AA22JK.h5 differ
diff --git a/tests/data/soy_6do/6PR6AA22JK.lateral.predictions.slp b/tests/data/soy_6do/6PR6AA22JK.lateral.predictions.slp
index 7319128..7a9e6db 100644
Binary files a/tests/data/soy_6do/6PR6AA22JK.lateral.predictions.slp and b/tests/data/soy_6do/6PR6AA22JK.lateral.predictions.slp differ
diff --git a/tests/data/soy_6do/6PR6AA22JK.primary.predictions.slp b/tests/data/soy_6do/6PR6AA22JK.primary.predictions.slp
index 516d009..94dba54 100644
Binary files a/tests/data/soy_6do/6PR6AA22JK.primary.predictions.slp and b/tests/data/soy_6do/6PR6AA22JK.primary.predictions.slp differ
diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py
index a295079..f30e3c5 100644
--- a/tests/fixtures/data.py
+++ b/tests/fixtures/data.py
@@ -89,3 +89,33 @@ def soy_primary_slp():
def soy_lateral_slp():
"""Path to lateral root predictions for 6 day old soy."""
return "tests/data/soy_6do/6PR6AA22JK.lateral__nodes.predictions.slp"
+
+
+@pytest.fixture
+def multiple_arabidopsis_11do_folder():
+ """Path to a folder with the predictions for 3, 11 day old arabidopsis."""
+ return "tests/data/multiple_arabidopsis_11do"
+
+
+@pytest.fixture
+def multiple_arabidopsis_11do_h5():
+ """Path to image stack for 11 day old arabidopsis."""
+ return "tests/data/multiple_arabidopsis_11do/997_1.h5"
+
+
+@pytest.fixture
+def multiple_arabidopsis_11do_primary_slp():
+ """Path to primary root predictions for 11 day old arabidopsis."""
+ return "tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp"
+
+
+@pytest.fixture
+def multiple_arabidopsis_11do_lateral_slp():
+ """Path to lateral root predictions for 11 day old arabidopsis."""
+ return "tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp"
+
+
+@pytest.fixture
+def multiple_arabidopsis_11do_csv():
+ """Path to the CSV file with expected count and group information."""
+ return "tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv"
diff --git a/tests/test_bases.py b/tests/test_bases.py
index 4e07d30..c887e59 100644
--- a/tests/test_bases.py
+++ b/tests/test_bases.py
@@ -376,13 +376,23 @@ def test_root_width_canola(canola_h5):
np.array([[0, 0], [1, 1]]),
np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]),
0.02,
- (np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))),
+ (
+ np.nan,
+ [(np.nan, np.nan)],
+ np.full((1, 2), np.nan),
+ np.full((1, 2), np.nan),
+ ),
),
(
np.array([[np.nan, np.nan], [np.nan, np.nan]]),
np.array([[[0, 0], [1, 1]], [[1, 1], [2, 2]]]),
0.02,
- (np.array([]), [(np.nan, np.nan)], np.empty((0, 2)), np.empty((0, 2))),
+ (
+ np.nan,
+ [(np.nan, np.nan)],
+ np.full((1, 2), np.nan),
+ np.full((1, 2), np.nan),
+ ),
),
],
)
@@ -416,27 +426,27 @@ def test_get_root_widths_invalid_cases():
# Minimum length
result = get_root_widths(np.array([[0, 0]]), np.array([[[0, 0]]]))
- assert np.array_equal(result, np.array([]))
+ assert np.isnan(result)
# Return default values with return_inds=True
result = get_root_widths(np.array([[0, 0]]), np.array([[[0, 0]]]), return_inds=True)
# Checks if both arrays are exactly the same
- assert np.array_equal(result[0], np.array([]))
+ assert np.isnan(result[0])
# Continue to check the other parts of the tuple
assert result[1] == [(np.nan, np.nan)]
# Check the other NumPy arrays in the tuple
- assert np.array_equal(result[2], np.empty((0, 2)))
- assert np.array_equal(result[3], np.empty((0, 2)))
+ assert np.all(np.isnan(result[2]))
+ assert np.all(np.isnan(result[3]))
# All NaNs in input arrays
result = get_root_widths(
np.array([[np.nan, np.nan], [np.nan, np.nan]]),
np.array([[[np.nan, np.nan], [np.nan, np.nan]]]),
)
- assert np.array_equal(result, np.array([]))
+ assert np.isnan(result)
# All lateral roots on the same side
result = get_root_widths(
np.array([[0, 0], [1, 1]]), np.array([[[0, 0], [1, 1]], [[0, 0], [1, 1]]])
)
- assert np.array_equal(result, np.array([]))
+ assert np.isnan(result)
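
The updated expectations reflect that get_root_widths now reports degenerate input with NaN instead of an empty array, so callers should check with np.isnan rather than array length. A short sketch of that caller-side check, assuming get_root_widths is exposed from sleap_roots.bases as these tests suggest (illustration only):

import numpy as np
from sleap_roots.bases import get_root_widths

# Degenerate input (fewer than the minimum number of points) now yields NaN.
width = get_root_widths(np.array([[0, 0]]), np.array([[[0, 0]]]))
if np.isnan(width):  # previously: width was an empty array
    print("Root width could not be measured for this frame.")
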
diff --git a/tests/test_lengths.py b/tests/test_lengths.py
index bd58299..0021710 100644
--- a/tests/test_lengths.py
+++ b/tests/test_lengths.py
@@ -2,10 +2,12 @@
get_curve_index,
get_root_lengths,
get_max_length_pts,
+ get_min_distance_line_to_line,
)
from sleap_roots.bases import get_base_tip_dist, get_bases
from sleap_roots.tips import get_tips
from sleap_roots import Series
+from shapely.geometry import LineString
import numpy as np
import pytest
@@ -145,6 +147,29 @@ def lengths_all_nan():
return np.array([np.nan, np.nan, np.nan])
+def test_min_distance_line_to_line():
+ # Test with non-intersecting lines
+ line1 = LineString([(0, 0), (1, 1)])
+ line2 = LineString([(1, 0), (2, 0)])
+ assert get_min_distance_line_to_line(line1, line2) == np.sqrt(2) / 2
+
+ # Test with intersecting lines (expect 0 distance)
+ line1 = LineString([(0, 0), (1, 1)])
+ line2 = LineString([(0, 1), (1, 0)])
+ assert get_min_distance_line_to_line(line1, line2) == 0
+
+ # Test with parallel lines
+ line1 = LineString([(0, 0), (1, 0)])
+ line2 = LineString([(0, 1), (1, 1)])
+ assert get_min_distance_line_to_line(line1, line2) == 1
+
+ # Test with invalid input types
+ with pytest.raises(TypeError):
+ get_min_distance_line_to_line("not a linestring", LineString([(0, 0), (1, 1)]))
+ with pytest.raises(TypeError):
+ get_min_distance_line_to_line(LineString([(0, 0), (1, 1)]), "not a linestring")
+
+
# tests for get_curve_index function
def test_get_curve_index_canola(canola_h5):
# Set the frame index to 0
diff --git a/tests/test_points.py b/tests/test_points.py
index f9b4c3e..6ac3a1d 100644
--- a/tests/test_points.py
+++ b/tests/test_points.py
@@ -1,8 +1,9 @@
-import pytest
import numpy as np
+import pytest
+from shapely.geometry import LineString
from sleap_roots import Series
from sleap_roots.lengths import get_max_length_pts
-from sleap_roots.points import get_count, join_pts
+from sleap_roots.points import filter_plants_with_unexpected_ct, get_count, join_pts
from sleap_roots.points import (
get_all_pts_array,
get_nodes,
@@ -10,6 +11,9 @@
get_left_normalized_vector,
get_right_normalized_vector,
get_line_equation_from_points,
+ associate_lateral_to_primary,
+ flatten_associated_points,
+ filter_roots_with_nans,
)
@@ -355,3 +359,382 @@ def test_get_line_equation_from_points(pts1, pts2, expected):
def test_get_line_equation_input_errors(pts1, pts2):
with pytest.raises(ValueError):
get_line_equation_from_points(pts1, pts2)
+
+
+def test_associate_basic():
+ # Tests basic association between one primary and one lateral root.
+ primary_pts = np.array([[[0, 0], [0, 1]]])
+ lateral_pts = np.array([[[0, 1], [0, 2]]])
+
+ expected = {0: {"primary_points": primary_pts[0], "lateral_points": lateral_pts}}
+ result = associate_lateral_to_primary(primary_pts, lateral_pts)
+
+ # Ensure the keys match
+ assert set(result.keys()) == set(expected.keys())
+
+ # Loop through the result and the expected dictionary to compare the numpy arrays within
+ for key in expected:
+ # Ensure both dictionaries have the same keys (e.g., 'primary_points', 'lateral_points')
+ assert set(result[key].keys()) == set(expected[key].keys())
+
+ # Now compare the NumPy arrays for each key within the dictionaries
+ for sub_key in expected[key]:
+ np.testing.assert_array_equal(result[key][sub_key], expected[key][sub_key])
+
+
+def test_associate_no_primary():
+ # Tests that an empty dictionary is returned when there are no primary roots.
+ primary_pts = np.empty((0, 6, 2)) # Empty array representing no primary roots
+ lateral_pts = np.array([[[0, 1], [0, 2]]]) # Some lateral roots for the test
+
+ expected = {} # Expect an empty dictionary when there are no primary roots
+ result = associate_lateral_to_primary(primary_pts, lateral_pts)
+
+ assert result == expected
+
+
+def test_associate_no_lateral():
+ # Tests that correct association is made when there are no lateral roots.
+ primary_pts = np.array([[[0, 0], [0, 1]]])
+ lateral_pts = np.empty((0, 2, 2)) # No lateral roots
+
+ expected = {
+ 0: {
+ "primary_points": primary_pts[0],
+ "lateral_points": np.full((1, 2, 2), np.nan),
+ }
+ }
+ result = associate_lateral_to_primary(primary_pts, lateral_pts)
+
+ # Ensure the keys match
+ assert set(result.keys()) == set(expected.keys())
+
+ # Loop through the result and the expected dictionary to compare the numpy arrays within
+ for key in expected:
+ # Ensure both dictionaries have the same keys (e.g., 'primary_points', 'lateral_points')
+ assert set(result[key].keys()) == set(expected[key].keys())
+
+ # Now compare the NumPy arrays for each key within the dictionaries
+ for sub_key in expected[key]:
+ np.testing.assert_array_equal(result[key][sub_key], expected[key][sub_key])
+
+
+def test_associate_invalid_input_type():
+ # Tests that the function raises a ValueError with invalid input types.
+ primary_pts = [[[0, 0], [0, 1]]]
+ lateral_pts = [[[0, 1], [0, 2]]]
+
+ with pytest.raises(ValueError):
+ associate_lateral_to_primary(primary_pts, lateral_pts)
+
+
+def test_associate_incorrect_dimensions():
+ # Tests the function raises a ValueError when input dimensions are incorrect.
+ primary_pts = np.array([[0, 0], [0, 1]]) # Missing a dimension
+ lateral_pts = np.array([[[0, 1], [0, 2]]])
+
+ with pytest.raises(ValueError):
+ associate_lateral_to_primary(primary_pts, lateral_pts)
+
+
+def test_associate_incorrect_coordinate_dimensions():
+ # Tests that the function handles incorrect coordinate dimensions.
+ primary_pts = np.array([[[0, 0, 0], [0, 1, 1]]])
+ lateral_pts = np.array([[[0, 1, 1], [0, 2, 2]]])
+
+ with pytest.raises(ValueError):
+ associate_lateral_to_primary(primary_pts, lateral_pts)
+
+
+def test_associate_lateral_to_primary_valid_input():
+ """Ensures correct associations are made with valid input."""
+ primary_pts = np.array([[[0, 0], [0, 10]], [[10, 0], [10, 10]]])
+ lateral_pts = np.array([[[5, 5], [5, 6]], [[11, 0], [11, 1]]])
+ filtered_primary = filter_roots_with_nans(primary_pts)
+ filtered_lateral = filter_roots_with_nans(lateral_pts)
+ associations = associate_lateral_to_primary(filtered_primary, filtered_lateral)
+ assert len(associations) == 2
+ # Check that the first lateral root is associated with the first primary root
+ assert np.array_equal(
+ associations[0]["lateral_points"], np.array([[[5, 5], [5, 6]]])
+ )
+ # Check that the second lateral root is associated with the second primary root
+ assert np.array_equal(
+ associations[1]["lateral_points"], np.array([[[11, 0], [11, 1]]])
+ )
+
+
+def test_associate_lateral_to_primary_all_nan_laterals():
+ """Ensures lateral roots with NaNs are ignored."""
+ primary_pts = np.array([[[0, 0], [0, 10]]])
+ lateral_pts = np.array([[[np.nan, np.nan], [np.nan, np.nan]]])
+ filtered_primary = filter_roots_with_nans(primary_pts)
+ filtered_lateral = filter_roots_with_nans(lateral_pts)
+ associations = associate_lateral_to_primary(filtered_primary, filtered_lateral)
+    # Expect NaN-filled lateral points since every lateral root was filtered out
+ assert np.isnan(associations[0]["lateral_points"]).all()
+
+
+def test_flatten_associated_points_single_primary_no_lateral():
+ # Given a single primary root with no lateral roots,
+ # the function should return a dictionary with a flattened array of the primary points.
+ associations = {
+ 0: {
+ "primary_points": np.array([[1, 2], [3, 4]]),
+ "lateral_points": np.full(
+ (1, 2, 2), np.nan
+ ), # Assuming this represents no lateral points
+ }
+ }
+ expected = {0: np.array([1, 2, 3, 4])}
+ # When
+ result = flatten_associated_points(associations)
+ # Then
+ np.testing.assert_array_equal(result[0], expected[0])
+
+
+def test_flatten_associated_points_single_primary_single_lateral():
+ # Given a single primary root with one lateral root,
+ # the function should return a flattened array combining both primary and lateral points.
+ associations = {
+ 0: {
+ "primary_points": np.array([[1, 2], [3, 4]]),
+ "lateral_points": np.array([[[5, 6], [7, 8]]]),
+ }
+ }
+ expected = {0: np.array([1, 2, 3, 4, 5, 6, 7, 8])}
+ # When
+ result = flatten_associated_points(associations)
+ # Then
+ np.testing.assert_array_equal(result[0], expected[0])
+
+
+def test_associate_lateral_to_primary_valid_input_unfiltered():
+    """Test associate_lateral_to_primary with valid inputs, without NaN filtering first."""
+ primary_pts = np.array([[[0, 0], [0, 10]], [[10, 0], [10, 10]]])
+ lateral_pts = np.array([[[5, 5], [5, 6]], [[11, 0], [11, 1]]])
+ associations = associate_lateral_to_primary(primary_pts, lateral_pts)
+ assert len(associations) == 2
+ assert len(associations[0]["lateral_points"]) == 1
+ assert len(associations[1]["lateral_points"]) == 1
+ assert np.array_equal(associations[0]["lateral_points"], [[[5, 5], [5, 6]]])
+ assert np.array_equal(associations[1]["lateral_points"], [[[11, 0], [11, 1]]])
+
+
+def test_associate_lateral_to_primary_nan_values():
+ """Test associate_lateral_to_primary with NaN values in lateral roots."""
+ primary_pts = np.array([[[0, 0], [0, 10]]])
+ lateral_pts = np.array([[[np.nan, np.nan], [1, 1]]])
+ associations = associate_lateral_to_primary(primary_pts, lateral_pts)
+ assert len(associations) == 1
+ assert len(associations[0]["lateral_points"]) == 1
+
+
+def test_associate_lateral_to_primary_invalid_input_type():
+ """Test associate_lateral_to_primary with invalid input types."""
+ with pytest.raises(ValueError):
+ associate_lateral_to_primary(None, None)
+
+
+def test_associate_lateral_to_primary_invalid_input_shape():
+ """Test associate_lateral_to_primary with invalid input shapes."""
+ primary_pts = np.array([0, 0]) # Invalid shape
+ lateral_pts = np.array([[[1, 1], [2, 2]]])
+ with pytest.raises(ValueError):
+ associate_lateral_to_primary(primary_pts, lateral_pts)
+
+
+def test_associate_lateral_to_primary_large_dataset():
+ """Test associate_lateral_to_primary with a larger dataset to check performance and correctness."""
+ np.random.seed(0)
+ primary_pts = np.random.randint(0, 100, (10, 5, 2))
+ lateral_pts = np.random.randint(0, 100, (20, 5, 2))
+ associations = associate_lateral_to_primary(primary_pts, lateral_pts)
+    assert (
+        len(associations) == 10
+    )  # Every primary root should have an entry in the associations dict
+
+
+def test_flatten_associated_points_multiple_primaries_multiple_laterals():
+ # Given multiple primary roots, each with one or more lateral roots,
+ # the function should return a dictionary with keys as primary root indices
+ # and values as flattened arrays of their associated primary and lateral points.
+ associations = {
+ 0: {
+ "primary_points": np.array([[1, 2], [3, 4]]),
+ "lateral_points": np.array([[[5, 6], [7, 8]]]),
+ },
+ 1: {
+ "primary_points": np.array([[17, 18], [19, 20]]),
+ "lateral_points": np.concatenate(
+ ([[[9, 10], [11, 12]]], [[[13, 14], [15, 16]]])
+ ),
+ },
+ }
+ expected = {
+ 0: np.array([1, 2, 3, 4, 5, 6, 7, 8]),
+ 1: np.array([17, 18, 19, 20, 9, 10, 11, 12, 13, 14, 15, 16]),
+ }
+ # When
+ result = flatten_associated_points(associations)
+ # Then
+ for key in expected:
+ np.testing.assert_array_equal(result[key], expected[key])
+
+
+def test_flatten_associated_points_empty_input():
+ # Given an empty dictionary for associations,
+ # the function should return an empty dictionary.
+ associations = {}
+ expected = {}
+ # When
+ result = flatten_associated_points(associations)
+ # Then
+ assert result == expected
+
+
+@pytest.mark.parametrize(
+ "associations, expected",
+ [
+ (
+ {
+ 0: {
+ "primary_points": np.array([[1, 2]]),
+ "lateral_points": np.array([[[5, 6]]]),
+ }
+ },
+ {0: np.array([1, 2, 5, 6])},
+ ),
+ ({}, {}),
+ ],
+)
+def test_flatten_associated_points_parametrized(associations, expected):
+ # This parametrized test checks the function with various combinations
+ # of associations.
+ # When
+ result = flatten_associated_points(associations)
+ # Then
+ for key in expected:
+ np.testing.assert_array_equal(result[key], expected[key])
+
+
+def test_filter_roots_with_nans_no_nans():
+ """Test with an array that contains no NaN values."""
+ pts = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+ expected = pts
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_roots_with_nans_nan_in_one_instance():
+ """Test with an array where one instance contains NaN values."""
+ pts = np.array([[[1, 2], [3, 4]], [[np.nan, 6], [7, 8]]])
+ expected = np.array([[[1, 2], [3, 4]]])
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_roots_with_nans_all_nans_in_one_instance():
+ """Test with an array where one instance is entirely NaN."""
+ pts = np.array([[[np.nan, np.nan], [np.nan, np.nan]], [[5, 6], [7, 8]]])
+ expected = np.array([[[5, 6], [7, 8]]])
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_roots_with_nans_nan_across_multiple_instances():
+ """Test with NaN values scattered across multiple instances."""
+ pts = np.array([[[1, np.nan], [3, 4]], [[5, 6], [np.nan, 8]], [[9, 10], [11, 12]]])
+ expected = np.array([[[9, 10], [11, 12]]])
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_roots_with_nans_all_instances_contain_nans():
+ """Test with an array where all instances contain at least one NaN value."""
+ pts = np.array(
+ [[[np.nan, 2], [3, 4]], [[5, np.nan], [7, 8]], [[9, 10], [np.nan, 12]]]
+ )
+ expected = np.empty((0, pts.shape[1], 2))
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_roots_with_nans_empty_array():
+ """Test with an empty array."""
+ pts = np.empty((0, 0, 2))
+ expected = np.empty((0, 0, 2))
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_roots_with_nans_single_instance_with_nans():
+ """Test with a single instance that contains NaN values."""
+ pts = np.array([[[np.nan, np.nan], [np.nan, np.nan]]])
+ expected = np.empty((0, pts.shape[1], 2))
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_roots_with_nans_single_instance_without_nans():
+ """Test with a single instance that does not contain NaN values."""
+ pts = np.array([[[1, 2], [3, 4]]])
+ expected = pts
+ result = filter_roots_with_nans(pts)
+ np.testing.assert_array_equal(result, expected)
+
+
+def test_filter_plants_with_unexpected_ct_valid_input_matching_count():
+ """Test with valid input where the number of primary roots matches the expected count."""
+ primary_pts = np.random.rand(5, 10, 2)
+ lateral_pts = np.random.rand(5, 10, 2)
+ expected_count = 5.0
+ filtered_primary, filtered_lateral = filter_plants_with_unexpected_ct(
+ primary_pts, lateral_pts, expected_count
+ )
+ assert np.array_equal(filtered_primary, primary_pts)
+ assert np.array_equal(filtered_lateral, lateral_pts)
+
+
+def test_filter_plants_with_unexpected_ct_valid_input_non_matching_count():
+ """Test with valid input where the number of primary roots does not match the expected count."""
+ primary_pts = np.random.rand(5, 10, 2)
+ lateral_pts = np.random.rand(5, 10, 2)
+ expected_count = 3.0 # Non-matching count
+ filtered_primary, filtered_lateral = filter_plants_with_unexpected_ct(
+ primary_pts, lateral_pts, expected_count
+ )
+ assert filtered_primary.shape == (0, primary_pts.shape[1], 2)
+ assert filtered_lateral.shape == (0, lateral_pts.shape[1], 2)
+
+
+def test_filter_plants_with_unexpected_ct_nan_expected_count():
+ """Test with NaN as the expected count, which should skip filtering."""
+ primary_pts = np.random.rand(5, 10, 2)
+ lateral_pts = np.random.rand(5, 10, 2)
+ expected_count = np.nan
+ filtered_primary, filtered_lateral = filter_plants_with_unexpected_ct(
+ primary_pts, lateral_pts, expected_count
+ )
+ assert np.array_equal(filtered_primary, primary_pts)
+ assert np.array_equal(filtered_lateral, lateral_pts)
+
+
+def test_filter_plants_with_unexpected_ct_incorrect_input_types():
+ """Test with incorrect input types to ensure ValueError is raised."""
+ primary_pts = "not a numpy array"
+ lateral_pts = np.random.rand(5, 10, 2)
+ expected_count = 5.0
+ with pytest.raises(ValueError):
+ filter_plants_with_unexpected_ct(primary_pts, lateral_pts, expected_count)
+
+ primary_pts = np.random.rand(5, 10, 2)
+ lateral_pts = "not a numpy array"
+ with pytest.raises(ValueError):
+ filter_plants_with_unexpected_ct(primary_pts, lateral_pts, expected_count)
+
+ primary_pts = np.random.rand(5, 10, 2)
+ lateral_pts = np.random.rand(5, 10, 2)
+ expected_count = "not a float"
+ with pytest.raises(ValueError):
+ filter_plants_with_unexpected_ct(primary_pts, lateral_pts, expected_count)
diff --git a/tests/test_series.py b/tests/test_series.py
index 861f198..9c528a1 100644
--- a/tests/test_series.py
+++ b/tests/test_series.py
@@ -6,6 +6,12 @@
from typing import Literal
+@pytest.fixture
+def series_instance():
+ # Create a Series instance with dummy data
+ return Series(h5_path="dummy.h5")
+
+
@pytest.fixture
def dummy_video_path(tmp_path):
video_path = tmp_path / "dummy_video.mp4"
@@ -41,6 +47,16 @@ def dummy_series(dummy_video_path, dummy_labels_path):
return Series.load(**kwargs)
+@pytest.fixture
+def csv_path(tmp_path):
+ # Create a dummy CSV file
+ csv_path = tmp_path / "dummy.csv"
+ csv_path.write_text(
+ "plant_qr_code,number_of_plants_cylinder,genotype\ndummy,10,1100\nseries2,15,Kitaake-X\n"
+ )
+ return csv_path
+
+
def test_series_name(dummy_series):
expected_name = "dummy_video" # Based on the dummy_video_path fixture
assert dummy_series.series_name == expected_name
@@ -60,6 +76,15 @@ def test_series_name_property():
assert series.series_name == "file_name"
+def test_series_name_from_h5_path(series_instance):
+ assert series_instance.series_name == "dummy"
+
+
+def test_expected_count(series_instance, csv_path):
+ series_instance.csv_path = csv_path
+ assert series_instance.expected_count == 10
+
+
def test_len():
series = Series(video=["frame1", "frame2"])
assert len(series) == 2
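
The csv_path fixture and test_expected_count above imply that Series.expected_count looks up the row whose plant_qr_code matches the series name and returns its number_of_plants_cylinder. A rough sketch of that lookup under that assumption (hypothetical helper, not the library's implementation):

import pandas as pd

def lookup_expected_count(csv_path: str, series_name: str) -> float:
    # Hypothetical helper mirroring what Series.expected_count is assumed to do.
    samples = pd.read_csv(csv_path)
    match = samples[samples["plant_qr_code"] == series_name]
    if match.empty:
        return float("nan")  # unknown sample, so no expected count
    return float(match["number_of_plants_cylinder"].iloc[0])

# With the dummy CSV above: lookup_expected_count("dummy.csv", "dummy") == 10.0
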
diff --git a/tests/test_trait_pipelines.py b/tests/test_trait_pipelines.py
index 0b7435c..c16e814 100644
--- a/tests/test_trait_pipelines.py
+++ b/tests/test_trait_pipelines.py
@@ -1,7 +1,10 @@
+import numpy as np
+import pandas as pd
from sleap_roots.trait_pipelines import (
DicotPipeline,
YoungerMonocotPipeline,
OlderMonocotPipeline,
+ MultipleDicotPipeline,
)
from sleap_roots.series import Series, find_all_series
@@ -133,3 +136,55 @@ def test_older_monocot_pipeline(rice_main_10do_h5, rice_10do_folder):
(0 <= all_traits["crown_angles_proximal_median_p95"])
& (all_traits["crown_angles_proximal_median_p95"] <= 180)
).all(), "angle_column in all_traits contains values out of range [0, 180]"
+
+
+def test_multiple_dicot_pipeline(
+ multiple_arabidopsis_11do_h5,
+ multiple_arabidopsis_11do_folder,
+ multiple_arabidopsis_11do_csv,
+):
+ arabidopsis = Series.load(
+ multiple_arabidopsis_11do_h5,
+ primary_name="primary",
+ lateral_name="lateral",
+ csv_path=multiple_arabidopsis_11do_csv,
+ )
+ arabidopsis_series_all = find_all_series(multiple_arabidopsis_11do_folder)
+ series_all = [
+ Series.load(
+ series,
+ primary_name="primary",
+ lateral_name="lateral",
+ csv_path=multiple_arabidopsis_11do_csv,
+ )
+ for series in arabidopsis_series_all
+ ]
+
+ pipeline = MultipleDicotPipeline()
+ arabidopsis_traits = pipeline.compute_multiple_dicots_traits(arabidopsis)
+ all_traits = pipeline.compute_batch_multiple_dicots_traits(series_all)
+
+ # Dataframe shape assertions
+ assert pd.DataFrame([arabidopsis_traits["summary_stats"]]).shape == (1, 315)
+ assert all_traits.shape == (4, 316)
+
+ # Dataframe dtype assertions
+ expected_all_traits_dtypes = {
+ "lateral_count_min": "int64",
+ "lateral_count_max": "int64",
+ }
+
+ for col, expected_dtype in expected_all_traits_dtypes.items():
+ assert np.issubdtype(
+ all_traits[col].dtype, np.integer
+ ), f"Unexpected dtype for column {col} in all_traits. Expected integer, got {all_traits[col].dtype}"
+
+ # Value range assertions for traits
+ assert (
+ all_traits["curve_index_median"] >= 0
+ ).all(), "curve_index in all_traits contains negative values"
+
+    # Check the series-level traits dictionary
+ assert isinstance(arabidopsis_traits, dict)
+ assert arabidopsis_traits["series"] == "997_1"
+ assert arabidopsis_traits["group"] == "997"