Skip to content

Commit

Permalink
- Started implementation for post processing fixes; no logic changes
Browse files Browse the repository at this point in the history
- added tracker debugging script
  • Loading branch information
shaikh58 committed Oct 9, 2024
1 parent 3bc9fef commit 1998f6f
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 9 deletions.
12 changes: 9 additions & 3 deletions dreem/inference/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def filter_max_center_dist(
k_boxes: torch.Tensor | None = None,
nonk_boxes: torch.Tensor | None = None,
id_inds: torch.Tensor | None = None,
h: int = None,
w: int = None
) -> torch.Tensor:
"""Filter trajectory score by distances between objects across frames.
Expand All @@ -135,6 +137,8 @@ def filter_max_center_dist(
k_boxes: The bounding boxes in the current frame
nonk_boxes: the boxes not in the current frame
id_inds: track ids
h: height of image
w: width of image
Returns:
An N_t x N association matrix
Expand All @@ -147,13 +151,15 @@ def filter_max_center_dist(
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k

nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2

# TODO: nonk_boxes should be only from previous frame rather than entire window
dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(
dim=-1
) # n_k x Np

norm_dist = dist / (k_s[:, None, :] + 1e-8)
# TODO: note that dist is in units of fraction of the height and width of the image;
# TODO: need to scale it by the original image size so that its in units of pixels
# norm_dist = dist / (k_s[:, None, :] + 1e-8)
norm_dist = dist.mean(axis=-1) # n_k x Np
# norm_dist =

valid = norm_dist < max_center_dist # n_k x Np
valid_assn = (
Expand Down
6 changes: 4 additions & 2 deletions dreem/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def sliding_inference(
# if no track ids, then assign new ones
for i, instance in enumerate(frames[batch_idx].instances):
if instance.pred_track_id == -1:
curr_track += 1
curr_track_id += 1
instance.pred_track_id = curr_track_id

else:
Expand Down Expand Up @@ -351,6 +351,8 @@ def _run_global_tracker(

query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)

# need frame height and width to scale boxes during post-processing
_, h, w = query_frame.img_shape.flatten()
pred_boxes = model_utils.get_boxes(all_instances)
query_boxes = pred_boxes[query_inds] # n_k x 4
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4
Expand Down Expand Up @@ -435,7 +437,7 @@ def _run_global_tracker(
# threshold for continuing a tracking or starting a new track -> they use 1.0
# todo -> should also work without pos_embed
traj_score = post_processing.filter_max_center_dist(
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds, h, w
)

if self.max_center_dist is not None and self.max_center_dist > 0:
Expand Down
12 changes: 12 additions & 0 deletions scripts/run_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from dreem.inference import track
from omegaconf import OmegaConf
import os

# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/inference.yaml"

cfg = OmegaConf.load(config)

track.run(cfg)
10 changes: 6 additions & 4 deletions scripts/run_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

# /Users/mustafashaikh/dreem/dreem/training
# /Users/main/Documents/GitHub/dreem/dreem/training
os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
base_config = "./configs/base.yaml"
# params_config = "./configs/override.yaml"
# os.chdir("/Users/main/Documents/GitHub/dreem/dreem/training")
base_config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/base-updated.yaml"
params_config = "/root/vast/mustafa/dreem-experiments/run/lysosome-baselines/debug/configs/override-updated.yaml"

cfg = OmegaConf.load(base_config)
# cfg["params_config"] = params_config
# Load and merge override config
override_cfg = OmegaConf.load(params_config)
cfg = OmegaConf.merge(cfg, override_cfg)

train.run(cfg)
4 changes: 4 additions & 0 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ def test_post_processing(): # set_default_device
k_boxes=k_boxes,
nonk_boxes=nonk_boxes,
id_inds=id_inds,
h=im_size,
w=im_size
)
).all()

Expand All @@ -226,6 +228,8 @@ def test_post_processing(): # set_default_device
k_boxes=k_boxes,
nonk_boxes=nonk_boxes,
id_inds=id_inds,
h=im_size,
w=im_size
)
).all()

Expand Down

0 comments on commit 1998f6f

Please sign in to comment.