Skip to content

Commit

Permalink
Merge branch 'main' into mustafa-rope
Browse files Browse the repository at this point in the history
  • Loading branch information
shaikh58 committed Dec 13, 2024
2 parents 1998f6f + 4341167 commit 2792d13
Show file tree
Hide file tree
Showing 20 changed files with 720 additions and 148 deletions.
2 changes: 1 addition & 1 deletion docs/configs/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# DREEM Config API

We utilize `.yaml` based configs with `hydra` and `omegaconf` for config parsing.
We utilize `.yaml` based configs with [`hydra`](https://hydra.cc) and [`omegaconf`](https://omegaconf.readthedocs.io/en/2.3_branch/) for config parsing.
2 changes: 1 addition & 1 deletion docs/configs/training.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

Here, we describe the hyperparameters used for setting up training. Please see [here](./training.md#example-config) for an example training config.

> Note: for using defaults, simply leave the field blank or don't include the key. Using `null` will initialize the value to `None` e.g
> Note: for using defaults, simply leave the field blank or don't include the key. Using `null` will initialize the value to `None` which we use to represent turning off certain features such as logging, early stopping etc. e.g
> ```YAML
> model:
> d_model: #defaults to 1024
Expand Down
3 changes: 2 additions & 1 deletion dreem/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dreem.datasets import data_utils
from dreem.io import Frame
from torch.utils.data import Dataset
from typing import Union
import numpy as np
import torch

Expand All @@ -15,7 +16,7 @@ def __init__(
label_files: list[str],
vid_files: list[str],
padding: int,
crop_size: int,
crop_size: Union[int, list[int]],
chunk: bool,
clip_length: int,
mode: str,
Expand Down
61 changes: 54 additions & 7 deletions dreem/datasets/sleap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import numpy as np
import sleap_io as sio
import random
from pathlib import Path
import logging
from typing import Union, Optional
from dreem.io import Instance, Frame
from dreem.datasets import data_utils, BaseDataset
from torchvision.transforms import functional as tvf
Expand All @@ -21,8 +23,9 @@ def __init__(
self,
slp_files: list[str],
video_files: list[str],
data_dirs: Optional[list[str]] = None,
padding: int = 5,
crop_size: int = 128,
crop_size: Union[int, list[int]] = 128,
anchors: int | list[str] | str = "",
chunk: bool = True,
clip_length: int = 500,
Expand All @@ -32,14 +35,19 @@ def __init__(
n_chunks: int | float = 1.0,
seed: int | None = None,
verbose: bool = False,
normalize_image: bool = True,
):
"""Initialize SleapDataset.
Args:
slp_files: a list of .slp files storing tracking annotations
video_files: a list of paths to video files
data_dirs: a path, or a list of paths to data directories. If provided, crop_size should be a list of integers
with the same length as data_dirs.
padding: amount of padding around object crops
crop_size: the size of the object crops
crop_size: the size of the object crops. Can be either:
- An integer specifying a single crop size for all objects
- A list of integers specifying different crop sizes for different data directories
anchors: One of:
* a string indicating a single node to center crops around
* a list of skeleton node names to be used as the center of crops
Expand All @@ -64,6 +72,7 @@ def __init__(
Can either a fraction of the dataset (ie (0,1.0]) or number of chunks
seed: set a seed for reproducibility
verbose: boolean representing whether to print
normalize_image: whether to normalize the image to [0, 1]
"""
super().__init__(
slp_files,
Expand All @@ -79,6 +88,9 @@ def __init__(
)

self.slp_files = slp_files
self.data_dirs = (
data_dirs # empty list, list of paths, or string of single path
)
self.video_files = video_files
self.padding = padding
self.crop_size = crop_size
Expand All @@ -88,14 +100,33 @@ def __init__(
self.handle_missing = handle_missing.lower()
self.n_chunks = n_chunks
self.seed = seed

self.normalize_image = normalize_image
if self.data_dirs is None:
self.data_dirs = []
if isinstance(anchors, int):
self.anchors = anchors
elif isinstance(anchors, str):
self.anchors = [anchors]
else:
self.anchors = anchors

if not isinstance(self.data_dirs, list):
self.data_dirs = [self.data_dirs]

if not isinstance(self.crop_size, list):
# make a list so its handled consistently if multiple crops are used
if len(self.data_dirs) > 0: # for test mode, data_dirs is []
self.crop_size = [self.crop_size] * len(self.data_dirs)
else:
self.crop_size = [self.crop_size]

if len(self.data_dirs) > 0 and len(self.crop_size) != len(self.data_dirs):
raise ValueError(
f"If a list of crop sizes or data directories are given,"
f"they must have the same length but got {len(self.crop_size)} "
f"and {len(self.data_dirs)}"
)

if (
isinstance(self.anchors, list) and len(self.anchors) == 0
) or self.anchors == 0:
Expand Down Expand Up @@ -139,6 +170,20 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
video = self.labels[label_idx]

video_name = self.video_files[label_idx]

# get the correct crop size based on the video
video_par_path = Path(video_name).parent
if len(self.data_dirs) > 0:
crop_size = self.crop_size[0]
for j, data_dir in enumerate(self.data_dirs):
if Path(data_dir) == video_par_path:
crop_size = self.crop_size[j]
break
else:
crop_size = self.crop_size[0]

vid_reader = self.videos[label_idx]

# img = vid_reader.get_data(0)

skeleton = video.skeletons[-1]
Expand Down Expand Up @@ -176,7 +221,9 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram
) # convert to grayscale to rgb

if np.issubdtype(img.dtype, np.integer): # convert int to float
img = img.astype(np.float32) / 255
img = img.astype(np.float32)
if self.normalize_image:
img = img / 255

n_instances_dropped = 0

Expand Down Expand Up @@ -313,15 +360,15 @@ def get_instances(self, label_idx: list[int], frame_idx: list[int]) -> list[Fram

else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
data_utils.get_bbox(centroid, crop_size),
padding=self.padding,
)

if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
crop_size + 2 * self.padding,
crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
Expand Down
8 changes: 3 additions & 5 deletions dreem/inference/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,9 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
else:
hparams = {}

checkpoint = eval_cfg.cfg.ckpt_path

logger.info(f"Testing model saved at {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)
logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)
model.tracker_cfg = eval_cfg.cfg.tracker
model.tracker = Tracker(**model.tracker_cfg)

Expand All @@ -61,7 +59,7 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
"persistent_tracking", False
)
logger.info(f"Computing the following metrics:")
logger.info(model.metrics.test)
logger.info(model.metrics["test"])
model.test_results["save_path"] = eval_cfg.cfg.runner.save_path
logger.info(f"Saving results to {model.test_results['save_path']}")

Expand Down
24 changes: 18 additions & 6 deletions dreem/inference/track.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from dreem.models import GTRRunner
from omegaconf import DictConfig
from pathlib import Path
from datetime import datetime

import hydra
import os
Expand All @@ -14,9 +15,20 @@
import sleap_io as sio
import logging


logger = logging.getLogger("dreem.inference")


def get_timestamp() -> str:
"""Get current timestamp.
Returns:
the current timestamp in /m/d/y-H:M:S format
"""
date_time = datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
return date_time


def export_trajectories(
frames_pred: list["dreem.io.Frame"], save_path: str | None = None
) -> pd.DataFrame:
Expand Down Expand Up @@ -115,11 +127,9 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
else:
hparams = {}

checkpoint = pred_cfg.cfg.ckpt_path

logger.info(f"Running inference with model from {checkpoint}")
model = GTRRunner.load_from_checkpoint(checkpoint)
logging.getLogger().setLevel(level=cfg.get("log_level", "INFO").upper())

model = GTRRunner.load_from_checkpoint(checkpoint, strict=False)
tracker_cfg = pred_cfg.get_tracker_cfg()

model.tracker_cfg = tracker_cfg
Expand All @@ -140,8 +150,10 @@ def run(cfg: DictConfig) -> dict[int, sio.Labels]:
)
dataloader = pred_cfg.get_dataloader(dataset, mode="test")
preds = track(model, trainer, dataloader)
outpath = os.path.join(outdir, f"{Path(label_file).stem}.dreem_inference.slp")
logger.info(f"Saving results to {outpath}...")
outpath = os.path.join(
outdir, f"{Path(label_file).stem}.dreem_inference.{get_timestamp()}.slp"
)

preds.save(outpath)

return preds
Expand Down
Loading

0 comments on commit 2792d13

Please sign in to comment.