diff --git a/sleap_roots/series.py b/sleap_roots/series.py index c769fe4..a3e97d3 100644 --- a/sleap_roots/series.py +++ b/sleap_roots/series.py @@ -17,7 +17,14 @@ class Series: """Data and predictions for a single image series. Attributes: + series_name: Unique identifier for the series. h5_path: Optional path to the HDF5-formatted image series. + primary_path: Optional path to the primary root predictions file. At least one + of the primary, lateral, or crown paths must be provided. + lateral_path: Optional path to the lateral root predictions file. At least one + of the primary, lateral, or crown paths must be provided. + crown_path: Optional path to the crown predictions file. At least one of the + primary, lateral, or crown paths must be provided. primary_labels: Optional `sio.Labels` corresponding to the primary root predictions. lateral_labels: Optional `sio.Labels` corresponding to the lateral root predictions. crown_labels: Optional `sio.Labels` corresponding to the crown predictions. @@ -36,13 +43,16 @@ class Series: get_crown_points: Get crown root points. Properties: - series_name: Name of the series derived from the HDF5 filename. expected_count: Fetch the expected plant count for this series from the CSV. group: Group name for the series from the CSV. qc_fail: Flag to indicate if the series failed QC from the CSV. """ + series_name: str h5_path: Optional[str] = None + primary_path: Optional[str] = None + lateral_path: Optional[str] = None + crown_path: Optional[str] = None primary_labels: Optional[sio.Labels] = None lateral_labels: Optional[sio.Labels] = None crown_labels: Optional[sio.Labels] = None @@ -52,22 +62,22 @@ class Series: @classmethod def load( cls, - h5_path: str, - primary_name: Optional[str] = None, - lateral_name: Optional[str] = None, - crown_name: Optional[str] = None, + series_name: str, + h5_path: Optional[str] = None, + primary_path: Optional[str] = None, + lateral_path: Optional[str] = None, + crown_path: Optional[str] = None, csv_path: Optional[str] = None, ) -> "Series": """Load a set of predictions for this series. Args: - h5_path: Path to the HDF5-formatted image series. - primary_name: Optional name of the primary root predictions file. If provided, - the file is expected to be named "{h5_path}.{primary_name}.predictions.slp". - lateral_name: Optional name of the lateral root predictions file. If provided, - the file is expected to be named "{h5_path}.{lateral_name}.predictions.slp". - crown_name: Optional name of the crown predictions file. If provided, - the file is expected to be named "{h5_path}.{crown_name}.predictions.slp". + series_name: Unique identifier for the series. + h5_path: Optional path to the HDF5-formatted image series, which will be + used to load the video. + primary_path: Optional path to the primary root '.slp' predictions file. + lateral_path: Optional path to the lateral root '.slp' predictions file. + crown_path: Optional path to the crown '.slp' predictions file. csv_path: Optional path to the CSV file containing the expected plant count. Returns: @@ -78,50 +88,68 @@ def load( # Attempt to load the predictions, with error handling try: - if primary_name: - primary_path = ( - Path(h5_path) - .with_suffix(f".{primary_name}.predictions.slp") - .as_posix() - ) - if Path(primary_path).exists(): - primary_labels = sio.load_slp(primary_path) + if primary_path: + # Make path object + primary_path = Path(primary_path) + # Check if the file exists + if primary_path.exists(): + # Load the primary predictions + primary_labels = sio.load_slp(primary_path.as_posix()) else: - print(f"Primary prediction file not found: {primary_path}") - if lateral_name: - lateral_path = ( - Path(h5_path) - .with_suffix(f".{lateral_name}.predictions.slp") - .as_posix() - ) - if Path(lateral_path).exists(): - lateral_labels = sio.load_slp(lateral_path) + print( + f"Primary prediction file not found: {primary_path.as_posix()}" + ) + if lateral_path: + # Make path object + lateral_path = Path(lateral_path) + # Check if the file exists + if lateral_path.exists(): + # Load the lateral predictions + lateral_labels = sio.load_slp(lateral_path.as_posix()) else: - print(f"Lateral prediction file not found: {lateral_path}") - if crown_name: - crown_path = ( - Path(h5_path) - .with_suffix(f".{crown_name}.predictions.slp") - .as_posix() - ) - if Path(crown_path).exists(): - crown_labels = sio.load_slp(crown_path) + print( + f"Lateral prediction file not found: {lateral_path.as_posix()}" + ) + if crown_path: + # Make path object + crown_path = Path(crown_path) + # Check if the file exists + if crown_path.exists(): + # Load the crown predictions + crown_labels = sio.load_slp(crown_path.as_posix()) else: - print(f"Crown prediction file not found: {crown_path}") + print(f"Crown prediction file not found: {crown_path.as_posix()}") except Exception as e: print(f"Error loading prediction files: {e}") # Attempt to load the video, with error handling video = None try: - if not Path(h5_path).exists(): - raise FileNotFoundError(f"File not found") - video = sio.Video.from_filename(h5_path) + if h5_path: + # Make path object + h5_path = Path(h5_path) + # Check if the file exists + if h5_path.exists(): + # Load the video + video = sio.Video.from_filename(h5_path.as_posix()) + else: + print(f"Video file not found: {h5_path.as_posix()}") except Exception as e: print(f"Error loading video file {h5_path}: {e}") + # Replace the filename in the labels with the h5_path if it is provided. + if h5_path: + for labels in [primary_labels, lateral_labels, crown_labels]: + if labels is not None: + if not labels.video.exists(): + labels.video.replace_filename(h5_path) + return cls( + series_name=series_name, h5_path=h5_path, + primary_path=primary_path, + lateral_path=lateral_path, + crown_path=crown_path, primary_labels=primary_labels, lateral_labels=lateral_labels, crown_labels=crown_labels, @@ -129,11 +157,6 @@ def load( csv_path=csv_path, ) - @property - def series_name(self) -> str: - """Name of the series derived from the HDF5 filename.""" - return Path(self.h5_path).name.split(".")[0] - @property def expected_count(self) -> Union[float, int]: """Fetch the expected plant count for this series from the CSV.""" @@ -142,7 +165,8 @@ def expected_count(self) -> Union[float, int]: return np.nan df = pd.read_csv(self.csv_path) try: - # Match the series_name (or plant_qr_code in the CSV) to fetch the expected count + # Match the series_name (or plant_qr_code in the CSV) to fetch the expected + # count return df[df["plant_qr_code"] == self.series_name][ "number_of_plants_cylinder" ].iloc[0] @@ -349,22 +373,22 @@ def get_crown_points(self, frame_idx: int) -> np.ndarray: return crown_pts -def find_all_series(data_folders: Union[str, List[str]]) -> List[str]: - """Find all .h5 series from a list of folders. +def find_all_h5_paths(data_folders: Union[str, List[str]]) -> List[str]: + """Find all .h5 paths from a list of folders. Args: - data_folders: Path or list of paths to folders containing .h5 series. + data_folders: Path or list of paths to folders containing .h5 paths. Returns: - A list of filenames to .h5 series. + A list of filenames to .h5 paths. """ if type(data_folders) != list: data_folders = [data_folders] - h5_series = [] + h5_paths = [] for data_folder in data_folders: - h5_series.extend([Path(p).as_posix() for p in Path(data_folder).glob("*.h5")]) - return h5_series + h5_paths.extend([Path(p).as_posix() for p in Path(data_folder).glob("*.h5")]) + return h5_paths def imgfig(