Skip to content

Commit

Permalink
add prediction paths as attributes in Series class
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed May 9, 2024
1 parent 151ddb3 commit f7d66d5
Showing 1 changed file with 79 additions and 55 deletions.
134 changes: 79 additions & 55 deletions sleap_roots/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@ class Series:
"""Data and predictions for a single image series.
Attributes:
series_name: Unique identifier for the series.
h5_path: Optional path to the HDF5-formatted image series.
primary_path: Optional path to the primary root predictions file. At least one
of the primary, lateral, or crown paths must be provided.
lateral_path: Optional path to the lateral root predictions file. At least one
of the primary, lateral, or crown paths must be provided.
crown_path: Optional path to the crown predictions file. At least one of the
primary, lateral, or crown paths must be provided.
primary_labels: Optional `sio.Labels` corresponding to the primary root predictions.
lateral_labels: Optional `sio.Labels` corresponding to the lateral root predictions.
crown_labels: Optional `sio.Labels` corresponding to the crown predictions.
Expand All @@ -36,13 +43,16 @@ class Series:
get_crown_points: Get crown root points.
Properties:
series_name: Name of the series derived from the HDF5 filename.
expected_count: Fetch the expected plant count for this series from the CSV.
group: Group name for the series from the CSV.
qc_fail: Flag to indicate if the series failed QC from the CSV.
"""

series_name: str
h5_path: Optional[str] = None
primary_path: Optional[str] = None
lateral_path: Optional[str] = None
crown_path: Optional[str] = None
primary_labels: Optional[sio.Labels] = None
lateral_labels: Optional[sio.Labels] = None
crown_labels: Optional[sio.Labels] = None
Expand All @@ -52,22 +62,22 @@ class Series:
@classmethod
def load(
cls,
h5_path: str,
primary_name: Optional[str] = None,
lateral_name: Optional[str] = None,
crown_name: Optional[str] = None,
series_name: str,
h5_path: Optional[str] = None,
primary_path: Optional[str] = None,
lateral_path: Optional[str] = None,
crown_path: Optional[str] = None,
csv_path: Optional[str] = None,
) -> "Series":
"""Load a set of predictions for this series.
Args:
h5_path: Path to the HDF5-formatted image series.
primary_name: Optional name of the primary root predictions file. If provided,
the file is expected to be named "{h5_path}.{primary_name}.predictions.slp".
lateral_name: Optional name of the lateral root predictions file. If provided,
the file is expected to be named "{h5_path}.{lateral_name}.predictions.slp".
crown_name: Optional name of the crown predictions file. If provided,
the file is expected to be named "{h5_path}.{crown_name}.predictions.slp".
series_name: Unique identifier for the series.
h5_path: Optional path to the HDF5-formatted image series, which will be
used to load the video.
primary_path: Optional path to the primary root '.slp' predictions file.
lateral_path: Optional path to the lateral root '.slp' predictions file.
crown_path: Optional path to the crown '.slp' predictions file.
csv_path: Optional path to the CSV file containing the expected plant count.
Returns:
Expand All @@ -78,62 +88,75 @@ def load(

# Attempt to load the predictions, with error handling
try:
if primary_name:
primary_path = (
Path(h5_path)
.with_suffix(f".{primary_name}.predictions.slp")
.as_posix()
)
if Path(primary_path).exists():
primary_labels = sio.load_slp(primary_path)
if primary_path:
# Make path object
primary_path = Path(primary_path)
# Check if the file exists
if primary_path.exists():
# Load the primary predictions
primary_labels = sio.load_slp(primary_path.as_posix())
else:
print(f"Primary prediction file not found: {primary_path}")
if lateral_name:
lateral_path = (
Path(h5_path)
.with_suffix(f".{lateral_name}.predictions.slp")
.as_posix()
)
if Path(lateral_path).exists():
lateral_labels = sio.load_slp(lateral_path)
print(
f"Primary prediction file not found: {primary_path.as_posix()}"
)
if lateral_path:
# Make path object
lateral_path = Path(lateral_path)
# Check if the file exists
if lateral_path.exists():
# Load the lateral predictions
lateral_labels = sio.load_slp(lateral_path.as_posix())
else:
print(f"Lateral prediction file not found: {lateral_path}")
if crown_name:
crown_path = (
Path(h5_path)
.with_suffix(f".{crown_name}.predictions.slp")
.as_posix()
)
if Path(crown_path).exists():
crown_labels = sio.load_slp(crown_path)
print(
f"Lateral prediction file not found: {lateral_path.as_posix()}"
)
if crown_path:
# Make path object
crown_path = Path(crown_path)
# Check if the file exists
if crown_path.exists():
# Load the crown predictions
crown_labels = sio.load_slp(crown_path.as_posix())
else:
print(f"Crown prediction file not found: {crown_path}")
print(f"Crown prediction file not found: {crown_path.as_posix()}")
except Exception as e:
print(f"Error loading prediction files: {e}")

# Attempt to load the video, with error handling
video = None
try:
if not Path(h5_path).exists():
raise FileNotFoundError(f"File not found")
video = sio.Video.from_filename(h5_path)
if h5_path:
# Make path object
h5_path = Path(h5_path)
# Check if the file exists
if h5_path.exists():
# Load the video
video = sio.Video.from_filename(h5_path.as_posix())
else:
print(f"Video file not found: {h5_path.as_posix()}")
except Exception as e:
print(f"Error loading video file {h5_path}: {e}")

# Replace the filename in the labels with the h5_path if it is provided.
if h5_path:
for labels in [primary_labels, lateral_labels, crown_labels]:
if labels is not None:
if not labels.video.exists():
labels.video.replace_filename(h5_path)

return cls(
series_name=series_name,
h5_path=h5_path,
primary_path=primary_path,
lateral_path=lateral_path,
crown_path=crown_path,
primary_labels=primary_labels,
lateral_labels=lateral_labels,
crown_labels=crown_labels,
video=video,
csv_path=csv_path,
)

@property
def series_name(self) -> str:
"""Name of the series derived from the HDF5 filename."""
return Path(self.h5_path).name.split(".")[0]

@property
def expected_count(self) -> Union[float, int]:
"""Fetch the expected plant count for this series from the CSV."""
Expand All @@ -142,7 +165,8 @@ def expected_count(self) -> Union[float, int]:
return np.nan
df = pd.read_csv(self.csv_path)
try:
# Match the series_name (or plant_qr_code in the CSV) to fetch the expected count
# Match the series_name (or plant_qr_code in the CSV) to fetch the expected
# count
return df[df["plant_qr_code"] == self.series_name][
"number_of_plants_cylinder"
].iloc[0]
Expand Down Expand Up @@ -349,22 +373,22 @@ def get_crown_points(self, frame_idx: int) -> np.ndarray:
return crown_pts


def find_all_series(data_folders: Union[str, List[str]]) -> List[str]:
"""Find all .h5 series from a list of folders.
def find_all_h5_paths(data_folders: Union[str, List[str]]) -> List[str]:
"""Find all .h5 paths from a list of folders.
Args:
data_folders: Path or list of paths to folders containing .h5 series.
data_folders: Path or list of paths to folders containing .h5 paths.
Returns:
A list of filenames to .h5 series.
A list of filenames to .h5 paths.
"""
if type(data_folders) != list:
data_folders = [data_folders]

h5_series = []
h5_paths = []
for data_folder in data_folders:
h5_series.extend([Path(p).as_posix() for p in Path(data_folder).glob("*.h5")])
return h5_series
h5_paths.extend([Path(p).as_posix() for p in Path(data_folder).glob("*.h5")])
return h5_paths


def imgfig(
Expand Down

0 comments on commit f7d66d5

Please sign in to comment.