Skip to content

Commit

Permalink
Filter max dist (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaikh58 authored Dec 19, 2024
1 parent 4341167 commit d79f2f7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 30 deletions.
47 changes: 24 additions & 23 deletions dreem/inference/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,44 +123,45 @@ def weight_iou(
def filter_max_center_dist(
asso_output: torch.Tensor,
max_center_dist: float = 0,
k_boxes: torch.Tensor | None = None,
nonk_boxes: torch.Tensor | None = None,
id_inds: torch.Tensor | None = None,
curr_frame_boxes: torch.Tensor | None = None,
prev_frame_boxes: torch.Tensor | None = None,
) -> torch.Tensor:
"""Filter trajectory score by distances between objects across frames.
Args:
asso_output: An N_t x N association matrix
max_center_dist: The euclidean distance threshold between bboxes
k_boxes: The bounding boxes in the current frame
nonk_boxes: the boxes not in the current frame
id_inds: track ids
curr_frame_boxes: the raw bbox coords of the current frame instances
prev_frame_boxes: the raw bbox coords of the previous frame instances
Returns:
An N_t x N association matrix
"""
if max_center_dist is not None and max_center_dist > 0:
assert (
k_boxes is not None and nonk_boxes is not None and id_inds is not None
), "Need `k_boxes`, `nonk_boxes`, and `id_ind` to filter by `max_center_dist`"
k_ct = (k_boxes[:, :, :2] + k_boxes[:, :, 2:]) / 2
k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k

nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2

dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(
dim=-1
) # n_k x Np

norm_dist = dist / (k_s[:, None, :] + 1e-8)
norm_dist = dist.mean(axis=-1) # n_k x Np

valid = norm_dist < max_center_dist # n_k x Np
curr_frame_boxes is not None
and prev_frame_boxes is not None
and id_inds is not None
), "Need `curr_frame_boxes`, `prev_frame_boxes`, and `id_ind` to filter by `max_center_dist`"

k_ct = (curr_frame_boxes[:, :, :2] + curr_frame_boxes[:, :, 2:]) / 2
# k_s = ((curr_frame_boxes[:, :, 2:] - curr_frame_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k
# nonk boxes are only from previous frame rather than entire window
nonk_ct = (prev_frame_boxes[:, :, :2] + prev_frame_boxes[:, :, 2:]) / 2

# pairwise euclidean distance in units of pixels
dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(dim=-1) ** (
1 / 2
) # n_k x n_nonk
# norm_dist = dist / (k_s[:, None, :] + 1e-8)

valid = dist.squeeze() < max_center_dist # n_k x n_nonk
valid_mult = valid.float().unsqueeze(-1) if valid.ndim == 1 else valid.float()
print(dist.shape, valid_mult.shape, id_inds.shape)
valid_assn = (
torch.mm(valid.float(), id_inds.to(valid.device))
.clamp_(max=1.0)
.long()
.bool()
torch.mm(valid_mult, id_inds.to(valid.device)).clamp_(max=1.0).long().bool()
) # n_k x M
asso_output_filtered = asso_output.clone()
asso_output_filtered[~valid_assn] = 0 # n_k x M
Expand Down
24 changes: 23 additions & 1 deletion dreem/inference/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def _run_global_tracker(
sum(instances_per_frame[: query_ind + 1]),
)
]

nonquery_inds = [i for i in range(total_instances) if i not in query_inds]

# instead should we do model(nonquery_instances, query_instances)?
Expand All @@ -343,6 +344,19 @@ def _run_global_tracker(

query_frame.add_traj_score("asso_nonquery", asso_nonquery_df)

# get raw bbox coords of prev frame instances from frame.instances_per_frame
prev_frame_ind = query_ind - 1
prev_frame_instances = frames[prev_frame_ind].instances
prev_frame_instance_ids = torch.cat(
[instance.pred_track_id for instance in prev_frame_instances], dim=0
)
prev_frame_boxes = torch.cat(
[instance.bbox for instance in prev_frame_instances], dim=0
)
curr_frame_boxes = torch.cat(
[instance.bbox for instance in query_frame.instances], dim=0
)

pred_boxes = model_utils.get_boxes(all_instances)
query_boxes = pred_boxes[query_inds] # n_k x 4
nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4
Expand All @@ -356,6 +370,10 @@ def _run_global_tracker(
unique_ids[None, :] == instance_ids[:, None]
).float() # (n_nonquery, n_traj)

prev_frame_id_inds = (
unique_ids[None, :] == prev_frame_instance_ids[:, None]
).float() # (n_prev_frame_instances, n_traj)

################################################################################

# reweighting hyper-parameters for association -> they use 0.9
Expand Down Expand Up @@ -425,7 +443,11 @@ def _run_global_tracker(
# threshold for continuing a tracking or starting a new track -> they use 1.0
# todo -> should also work without pos_embed
traj_score = post_processing.filter_max_center_dist(
traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds
traj_score,
self.max_center_dist,
prev_frame_id_inds,
curr_frame_boxes,
prev_frame_boxes,
)

if self.max_center_dist is not None and self.max_center_dist > 0:
Expand Down
12 changes: 6 additions & 6 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,18 @@ def test_post_processing(): # set_default_device
).all()

im_size = 128
k_boxes = torch.rand((N_t, 3, 4)) * im_size
nonk_boxes = torch.rand((N_p, 3, 4)) * im_size
k_boxes = torch.rand((N_t, 1, 4)) * im_size
nonk_boxes = torch.rand((N_p, 1, 4)) * im_size
id_inds = torch.tile(torch.cat((torch.zeros(M - 1), torch.ones(1))), (N_p, 1))

assert (
asso_output
== post_processing.filter_max_center_dist(
asso_output=asso_output,
max_center_dist=0,
k_boxes=k_boxes,
nonk_boxes=nonk_boxes,
id_inds=id_inds,
curr_frame_boxes=k_boxes,
prev_frame_boxes=nonk_boxes,
)
).all()

Expand All @@ -223,9 +223,9 @@ def test_post_processing(): # set_default_device
== post_processing.filter_max_center_dist(
asso_output=asso_output,
max_center_dist=1e-9,
k_boxes=k_boxes,
nonk_boxes=nonk_boxes,
id_inds=id_inds,
curr_frame_boxes=k_boxes,
prev_frame_boxes=nonk_boxes,
)
).all()

Expand Down

0 comments on commit d79f2f7

Please sign in to comment.