diff --git a/sleap_io/io/coco.py b/sleap_io/io/coco.py
new file mode 100644
index 00000000..c9d8c596
--- /dev/null
+++ b/sleap_io/io/coco.py
@@ -0,0 +1,303 @@
+"""This module implements routines for reading and writing COCO-formatted datasets."""
+
+from __future__ import annotations
+import numpy as np
+import simplejson as json
+from pathlib import Path
+from collections import defaultdict
+from sleap_io import (
+    Video,
+    Skeleton,
+    Track,
+    Instance,
+    LabeledFrame,
+    Labels,
+)
+
+import sys
+import imageio.v3 as iio
+
+try:
+    import cv2
+except ImportError:
+    pass
+
+
+def read_ann(ann_json_path: str | Path) -> dict:
+    """Read annotations JSON file.
+
+    Args:
+        ann_json_path: Path to a JSON file with the annotations.
+
+    Returns:
+        A dictionary with the parsed data.
+    """
+    with open(ann_json_path, "r") as f:
+        ann = json.load(f)
+    return ann
+
+
+def make_skeleton(ann: dict) -> Skeleton:
+    """Parse skeleton metadata.
+
+    Args:
+        ann: Dictionary with decoded JSON data. Must contain a key named "categories".
+            This key must contain sub-keys "keypoints" (node names), "skeleton" (edges),
+            and optionally "name".
+
+    Returns:
+        The `Skeleton` object.
+
+    Notes:
+        This assumes that `skeleton` (edge indices) are 1-based, and only reads the
+        first category.
+    """
+    return Skeleton(
+        nodes=ann["categories"][0]["keypoints"],
+        edges=(np.array(ann["categories"][0]["skeleton"]) - 1).tolist(),
+        name=ann["categories"][0].get("name", None),
+    )
+
+
+def make_videos(
+    ann: dict, imgs_prefix: str | Path | None = None
+) -> tuple[list[Video], dict[int, tuple[int, int]]]:
+    """Make videos and return mapping to indices.
+
+    Args:
+        ann: Dictionary with decoded JSON data. Must contain a key named "images".
+        imgs_prefix: Optional path specifying a prefix to prepend to image filenames.
+
+    Returns:
+        A tuple of `videos, video_id_map`.
+
+        `videos` is a list of `Video`s.
+
+        `video_id_map` is a dictionary that maps an image ID to a tuple of
+        `(video_ind, frame_ind)`, corresponding to the order in `videos`.
+
+    Notes:
+        This function will group images that have the same shape together into a single
+        logical video.
+    """
+    if isinstance(imgs_prefix, str):
+        imgs_prefix = Path(imgs_prefix)
+    imgs_by_shape = defaultdict(list)
+    video_id_map = {}
+    for img in ann["images"]:
+        shape = img["height"], img["width"]
+        img_filename = img["file_name"]
+        if imgs_prefix is not None:
+            img_filename = (imgs_prefix / img_filename).as_posix()
+        imgs_by_shape[shape].append(img_filename)
+        # Fix: key on the image's own ID ("id" lives on the image entry, not on
+        # `ann`), and `dict.keys()` has no `.index()` -- materialize the keys to
+        # find this shape group's position (dicts preserve insertion order).
+        video_id_map[img["id"]] = (
+            list(imgs_by_shape.keys()).index(shape),
+            len(imgs_by_shape[shape]) - 1,
+        )
+
+    videos = []
+    for shape, imgs in imgs_by_shape.items():
+        # Channels are assumed to be 3 (RGB) -- confirm against actual images.
+        videos.append(
+            Video.from_filename(imgs, backend_metadata={"shape": shape + (3,)})
+        )
+
+    return videos, video_id_map
+
+
+def make_labels(
+    ann: dict,
+    videos: list[Video],
+    video_id_map: dict[int, tuple[int, int]],
+    skeleton: Skeleton,
+) -> Labels:
+    """Make a `Labels` object from annotations.
+
+    Args:
+        ann: Dictionary with decoded JSON data. Must contain a key named "annotations".
+        videos: A list of `Video`s.
+        video_id_map: A dictionary that maps an image ID to a tuple of
+            `(video_ind, frame_ind)`, corresponding to the order in `videos`.
+        skeleton: A `Skeleton`.
+
+    Returns:
+        A `Labels` file with parsed data.
+    """
+    tracks_by_id = {}
+
+    lfs_by_ind = defaultdict(list)
+    for an in ann["annotations"]:
+        pts = np.array(an["keypoints"]).reshape(-1, 3)
+        # Fix: the visibility flag is column 2 (0-based), not 3. Any point that
+        # is not labeled-and-visible (v != 2) is blanked out to NaN.
+        pts[pts[:, 2] != 2] = np.nan
+        pts = pts[:, :2]
+
+        video_ind, frame_ind = video_id_map[an["image_id"]]
+
+        if "track_id" in an:
+            track_id = an["track_id"]
+            # Fix: create the track only when it has NOT been seen before
+            # (the original condition was inverted and raised KeyError).
+            if track_id not in tracks_by_id:
+                tracks_by_id[track_id] = Track(name=f"{track_id}")
+            track = tracks_by_id[track_id]
+        else:
+            track = None
+
+        lfs_by_ind[(video_ind, frame_ind)].append(
+            Instance.from_numpy(pts, skeleton=skeleton, track=track)
+        )
+
+    lfs = []
+    for (video_ind, frame_ind), insts in lfs_by_ind.items():
+        lfs.append(
+            LabeledFrame(video=videos[video_ind], frame_idx=frame_ind, instances=insts)
+        )
+    labels = Labels(lfs)
+    labels.provenance["info"] = ann.get("info", None)
+
+    return labels
+
+
+def read_labels(
+    ann_json_path: str | Path, imgs_prefix: str | Path | None = None
+) -> Labels:
+    """Read and parse COCO annotations.
+
+    Args:
+        ann_json_path: Path to a JSON file with the annotations.
+        imgs_prefix: Optional path specifying a prefix to prepend to image filenames.
+            This is typically a path to the folder containing the images. If not
+            provided, assumes that there exists an "images" folder in the parent
+            directory of the folder containing the annotations.
+
+    Returns:
+        `Labels` with the parsed data.
+    """
+    ann = read_ann(ann_json_path)
+    if imgs_prefix is None:
+        imgs_prefix = Path(ann_json_path).parent / "images"
+    videos, video_id_map = make_videos(ann, imgs_prefix=imgs_prefix)
+    skeleton = make_skeleton(ann)
+    labels = make_labels(ann, videos, video_id_map, skeleton)
+    return labels
+
+
+def write_labels(
+    labels: Labels,
+    dataset_folder: str | Path,
+    split: str | None = None,
+    img_format: str = "png",
+):
+    """Save a `Labels` to COCO format.
+
+    Args:
+        labels: A `Labels` object.
+        dataset_folder: Path to a folder to save data to.
+        split: Optional string specifying the split name.
+        img_format: Format to save images to. Formats: "png" (default) or "jpg".
+
+    Notes:
+        If `split` was not provided, the annotations will be saved to
+        `{dataset_folder}/annotations/ann.json` and images will be saved to
+        `{dataset_folder}/images`.
+
+        If `split` was provided, the annotations will be saved to
+        `{dataset_folder}/annotations/ann_{split}.json` and images will be saved to
+        `{dataset_folder}/images/{split}`.
+
+        Calling this multiple times with the same dataset folder may overwrite previous
+        data if `split` is not provided.
+    """
+    # Fix: accept `str` paths as documented in the signature.
+    dataset_folder = Path(dataset_folder)
+    if split is None:
+        ann_path = dataset_folder / "annotations" / "ann.json"
+        imgs_folder = dataset_folder / "images"
+    else:
+        ann_path = dataset_folder / "annotations" / f"ann_{split}.json"
+        imgs_folder = dataset_folder / "images" / split
+
+    ann_path.parent.mkdir(parents=True, exist_ok=True)
+    imgs_folder.mkdir(parents=True, exist_ok=True)
+
+    lfs = labels.user_labeled_frames
+
+    imgs = []
+    img_filename_map = {}
+    for img_id, lf in enumerate(lfs):
+        img_filename = f"{img_id}.{img_format}"
+        # Fix: the video lives on the labeled frame (`video` was undefined), and
+        # tuple shapes can't be fancy-indexed. Shape is assumed to be
+        # (frames, height, width, channels) -- confirm against Video.shape.
+        img_height, img_width = lf.video.shape[1], lf.video.shape[2]
+        imgs.append(
+            {
+                "id": img_id,
+                # Fix: `img_filename` is a plain str; it has no `.as_posix()`.
+                "file_name": img_filename,
+                "height": img_height,
+                "width": img_width,
+            }
+        )
+        img_filename_map[(lf.video, lf.frame_idx)] = img_filename
+
+    for (video, frame_idx), img_filename in img_filename_map.items():
+        img = video[frame_idx]
+        img_path = (imgs_folder / img_filename).as_posix()
+        if "cv2" in sys.modules:
+            # cv2.imwrite expects BGR channel order; frames are RGB.
+            if img.ndim == 3 and img.shape[-1] == 3:
+                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+            cv2.imwrite(img_path, img)
+        else:
+            iio.imwrite(img_path, img)
+
+    inst_id = 0
+    annotations = []
+    for img_id, lf in enumerate(lfs):
+        for inst in lf:
+            ann = {}
+
+            pts = inst.numpy()
+            # Visibility flags: 2 = labeled and visible, 0 = not labeled.
+            # (Flag 1, labeled-but-not-visible, is never produced here.)
+            missing = np.isnan(pts).any(axis=1, keepdims=True)
+            vis = np.where(missing, 0, 2)
+
+            # Fix: compute the bounding box from the valid points BEFORE
+            # replacing NaNs, so the -1 placeholders don't pollute the min/max.
+            x, y = np.nanmin(pts, axis=0)
+            w, h = np.nanmax(pts, axis=0) - np.nanmin(pts, axis=0)
+
+            pts = np.where(np.isnan(pts), -1.0, pts)
+            kps = np.concatenate([pts, vis], axis=1).reshape(-1).tolist()
+            ann["keypoints"] = kps
+            ann["id"] = inst_id
+            ann["image_id"] = img_id
+            # Fix: COCO's num_keypoints counts labeled keypoints (v > 0), not
+            # the total number of skeleton nodes.
+            ann["num_keypoints"] = int((vis > 0).sum())
+
+            # Cast numpy scalars to builtins so the JSON encoder can serialize.
+            ann["bbox"] = [float(x), float(y), float(w), float(h)]
+            ann["iscrowd"] = 0
+            ann["area"] = float(w * h)
+            # NOTE: category/track IDs are 0-based here; the reader in this
+            # module only uses categories[0], so this is internally consistent
+            # (the COCO convention is 1-based).
+            ann["category_id"] = labels.skeletons.index(inst.skeleton)
+
+            if inst.track is not None:
+                ann["track_id"] = labels.tracks.index(inst.track)
+
+            annotations.append(ann)
+            inst_id += 1
+
+    categories = []
+    for skel_ind, skel in enumerate(labels.skeletons):
+        category = {}
+        category["supercategory"] = "animal"
+        category["id"] = skel_ind
+        category["name"] = skel.name
+        category["keypoints"] = skel.node_names
+        # Edge indices are written 1-based to mirror make_skeleton's parsing.
+        category["skeleton"] = (np.array(skel.edge_inds) + 1).tolist()
+        categories.append(category)
+
+    ann = {
+        "info": labels.provenance.get("info", {}),
+        "images": imgs,
+        "annotations": annotations,
+        "categories": categories,
+    }
+
+    with open(ann_path, "w") as f:
+        json.dump(ann, f)
diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py
index 7fd702f7..38044ce7 100644
--- a/sleap_io/io/main.py
+++ b/sleap_io/io/main.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 from sleap_io import Labels, Skeleton, Video
-from sleap_io.io import slp, nwb, labelstudio, jabs
+from sleap_io.io import slp, nwb, labelstudio, jabs, coco
 from typing import Optional, Union
 from pathlib import Path
@@ -131,6 +131,43 @@ def save_jabs(labels: Labels, pose_version: int, root_folder: Optional[str] = No
     jabs.write_labels(labels, pose_version, root_folder)
 
 
+def load_coco(
+    filename: str | Path, imgs_prefix: str | Path | None = None
+) -> Labels | list[Labels]:
+    """Load a COCO dataset.
+
+    Args:
+        filename: Path to a JSON file with the annotations or a directory containing
+            folders named "annotations" and "images".
+        imgs_prefix: Optional path specifying a prefix to prepend to image filenames.
+            This is typically a path to the folder containing the images. If not
+            provided, assumes that there exists an "images" folder in the parent
+            directory of the folder containing the annotations.
+
+    Returns:
+        The parsed `Labels`.
+
+    Notes:
+        If a directory is provided, the first JSON annotation file will be loaded. To
+        load a specific split when multiple JSON files are present, specify a direct
+        path to the JSON annotation file.
+    """
+    filename = Path(filename)
+    if filename.is_dir():
+        ann_dir = filename / "annotations"
+        jsons = list(ann_dir.glob("*.json"))
+        if len(jsons) == 0:
+            # Fix: the exception was constructed but never raised.
+            raise FileNotFoundError(
+                f"Could not find any JSON files in {ann_dir}. "
+                "Provide a path to an annotation JSON file or verify that the "
+                "annotations directory contains JSON files."
+            )
+        filename = jsons[0]
+        # TODO: Recursively call if multiple annotations (usually splits) are found?
+
+    return coco.read_labels(filename, imgs_prefix=imgs_prefix)
+
+
 def load_video(filename: str, **kwargs) -> Video:
     """Load a video file.
 
@@ -155,7 +192,7 @@ def load_file(
         filename: Path to a file.
         format: Optional format to load as. If not provided, will be inferred from the
-            file extension. Available formats are: "slp", "nwb", "labelstudio", "jabs"
-            and "video".
+            file extension. Available formats are: "slp", "nwb", "labelstudio", "jabs",
+            "coco", and "video".
 
     Returns:
         A `Labels` or `Video` object.
@@ -169,7 +206,7 @@ def load_file(
     elif filename.endswith(".nwb"):
         format = "nwb"
     elif filename.endswith(".json"):
-        format = "json"
+        format = "labelstudio"
     elif filename.endswith(".h5"):
         format = "jabs"
     else:
@@ -180,16 +217,20 @@ def load_file(
 
     if format is None:
-        raise ValueError(f"Could not infer format from filename: '(unknown)'.")
+        raise ValueError(f"Could not infer format from filename: '{filename}'.")
 
-    if filename.endswith(".slp"):
+    if format == "slp":
         return load_slp(filename, **kwargs)
-    elif filename.endswith(".nwb"):
+    elif format == "nwb":
         return load_nwb(filename, **kwargs)
-    elif filename.endswith(".json"):
+    elif format == "labelstudio":
         return load_labelstudio(filename, **kwargs)
-    elif filename.endswith(".h5"):
+    elif format == "jabs":
         return load_jabs(filename, **kwargs)
+    elif format == "coco":
+        return load_coco(filename, **kwargs)
     elif format == "video":
         return load_video(filename, **kwargs)
+    else:
+        raise ValueError(f"Unknown format: {format}")
 
 
 def save_file(
diff --git a/sleap_io/io/video.py b/sleap_io/io/video.py
index 25ccc0d0..aeb7070f 100644
--- a/sleap_io/io/video.py
+++ b/sleap_io/io/video.py
@@ -737,7 +737,14 @@ def _read_frame(self, frame_idx: int) -> np.ndarray:
         This does not apply grayscale conversion. It is recommended to use the
         `get_frame` method of the `VideoBackend` class instead.
         """
-        img = iio.imread(self.filename[frame_idx])
+        if "cv2" in sys.modules:
+            img = cv2.imread(self.filename[frame_idx])
+            # cv2 decodes to BGR; convert to RGB so both backends agree.
+            if img is not None and img.ndim == 3:
+                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        else:
+            img = iio.imread(self.filename[frame_idx])
+
         if img.ndim == 2:
             img = np.expand_dims(img, axis=-1)
         return img
diff --git a/sleap_io/model/video.py b/sleap_io/model/video.py
index edd49489..189b3b8d 100644
--- a/sleap_io/model/video.py
+++ b/sleap_io/model/video.py
@@ -63,6 +63,7 @@ def from_filename(
         grayscale: Optional[bool] = None,
         keep_open: bool = True,
         source_video: Optional[Video] = None,
+        backend_metadata: Optional[dict[str, any]] = None,
         **kwargs,
     ) -> VideoBackend:
         """Create a Video from a filename.
@@ -82,6 +83,9 @@ def from_filename(
             source_video: The source video object if this is a proxy video. This is
                 present when the video contains an embedded subset of frames from
                 another video.
+            backend_metadata: A dictionary of metadata specific to the backend. This is
+                useful for storing metadata that requires an open backend (e.g., shape
+                information) without having access to the video file itself.
 
         Returns:
             Video instance with the appropriate backend instantiated.
@@ -96,6 +100,7 @@ def from_filename(
                 **kwargs,
             ),
             source_video=source_video,
+            backend_metadata=backend_metadata,
         )
 
     @property