diff --git a/dreem/inference/post_processing.py b/dreem/inference/post_processing.py index 09fd8ff..d87e8a0 100644 --- a/dreem/inference/post_processing.py +++ b/dreem/inference/post_processing.py @@ -123,44 +123,45 @@ def weight_iou( def filter_max_center_dist( asso_output: torch.Tensor, max_center_dist: float = 0, - k_boxes: torch.Tensor | None = None, - nonk_boxes: torch.Tensor | None = None, id_inds: torch.Tensor | None = None, + curr_frame_boxes: torch.Tensor | None = None, + prev_frame_boxes: torch.Tensor | None = None, ) -> torch.Tensor: """Filter trajectory score by distances between objects across frames. Args: asso_output: An N_t x N association matrix max_center_dist: The euclidean distance threshold between bboxes - k_boxes: The bounding boxes in the current frame - nonk_boxes: the boxes not in the current frame id_inds: track ids + curr_frame_boxes: the raw bbox coords of the current frame instances + prev_frame_boxes: the raw bbox coords of the previous frame instances Returns: An N_t x N association matrix """ if max_center_dist is not None and max_center_dist > 0: assert ( - k_boxes is not None and nonk_boxes is not None and id_inds is not None - ), "Need `k_boxes`, `nonk_boxes`, and `id_ind` to filter by `max_center_dist`" - k_ct = (k_boxes[:, :, :2] + k_boxes[:, :, 2:]) / 2 - k_s = ((k_boxes[:, :, 2:] - k_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k - - nonk_ct = (nonk_boxes[:, :, :2] + nonk_boxes[:, :, 2:]) / 2 - - dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum( - dim=-1 - ) # n_k x Np - - norm_dist = dist / (k_s[:, None, :] + 1e-8) - norm_dist = dist.mean(axis=-1) # n_k x Np - - valid = norm_dist < max_center_dist # n_k x Np + curr_frame_boxes is not None + and prev_frame_boxes is not None + and id_inds is not None + ), "Need `curr_frame_boxes`, `prev_frame_boxes`, and `id_ind` to filter by `max_center_dist`" + + k_ct = (curr_frame_boxes[:, :, :2] + curr_frame_boxes[:, :, 2:]) / 2 + # k_s = ((curr_frame_boxes[:, :, 2:] - curr_frame_boxes[:, :, :2]) ** 2).sum(dim=2) # n_k + # nonk boxes are only from previous frame rather than entire window + nonk_ct = (prev_frame_boxes[:, :, :2] + prev_frame_boxes[:, :, 2:]) / 2 + + # pairwise euclidean distance in units of pixels + dist = ((k_ct[:, None, :, :] - nonk_ct[None, :, :, :]) ** 2).sum(dim=-1) ** ( + 1 / 2 + ) # n_k x n_nonk + # norm_dist = dist / (k_s[:, None, :] + 1e-8) + + valid = dist.squeeze() < max_center_dist # n_k x n_nonk + valid_mult = valid.float().unsqueeze(-1) if valid.ndim == 1 else valid.float() + print(dist.shape, valid_mult.shape, id_inds.shape) valid_assn = ( - torch.mm(valid.float(), id_inds.to(valid.device)) - .clamp_(max=1.0) - .long() - .bool() + torch.mm(valid_mult, id_inds.to(valid.device)).clamp_(max=1.0).long().bool() ) # n_k x M asso_output_filtered = asso_output.clone() asso_output_filtered[~valid_assn] = 0 # n_k x M diff --git a/dreem/inference/tracker.py b/dreem/inference/tracker.py index 9f2e713..cddf7ba 100644 --- a/dreem/inference/tracker.py +++ b/dreem/inference/tracker.py @@ -329,6 +329,7 @@ def _run_global_tracker( sum(instances_per_frame[: query_ind + 1]), ) ] + nonquery_inds = [i for i in range(total_instances) if i not in query_inds] # instead should we do model(nonquery_instances, query_instances)? @@ -343,6 +344,19 @@ def _run_global_tracker( query_frame.add_traj_score("asso_nonquery", asso_nonquery_df) + # get raw bbox coords of prev frame instances from frame.instances_per_frame + prev_frame_ind = query_ind - 1 + prev_frame_instances = frames[prev_frame_ind].instances + prev_frame_instance_ids = torch.cat( + [instance.pred_track_id for instance in prev_frame_instances], dim=0 + ) + prev_frame_boxes = torch.cat( + [instance.bbox for instance in prev_frame_instances], dim=0 + ) + curr_frame_boxes = torch.cat( + [instance.bbox for instance in query_frame.instances], dim=0 + ) + pred_boxes = model_utils.get_boxes(all_instances) query_boxes = pred_boxes[query_inds] # n_k x 4 nonquery_boxes = pred_boxes[nonquery_inds] # n_nonquery x 4 @@ -356,6 +370,10 @@ def _run_global_tracker( unique_ids[None, :] == instance_ids[:, None] ).float() # (n_nonquery, n_traj) + prev_frame_id_inds = ( + unique_ids[None, :] == prev_frame_instance_ids[:, None] + ).float() # (n_prev_frame_instances, n_traj) + ################################################################################ # reweighting hyper-parameters for association -> they use 0.9 @@ -425,7 +443,11 @@ def _run_global_tracker( # threshold for continuing a tracking or starting a new track -> they use 1.0 # todo -> should also work without pos_embed traj_score = post_processing.filter_max_center_dist( - traj_score, self.max_center_dist, query_boxes, nonquery_boxes, id_inds + traj_score, + self.max_center_dist, + prev_frame_id_inds, + curr_frame_boxes, + prev_frame_boxes, ) if self.max_center_dist is not None and self.max_center_dist > 0: diff --git a/tests/test_inference.py b/tests/test_inference.py index 2b55484..7198f8c 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -203,8 +203,8 @@ def test_post_processing(): # set_default_device ).all() im_size = 128 - k_boxes = torch.rand((N_t, 3, 4)) * im_size - nonk_boxes = torch.rand((N_p, 3, 4)) * im_size + k_boxes = torch.rand((N_t, 1, 4)) * im_size + nonk_boxes = torch.rand((N_p, 1, 4)) * im_size id_inds = torch.tile(torch.cat((torch.zeros(M - 1), torch.ones(1))), (N_p, 1)) assert ( @@ -212,9 +212,9 @@ def test_post_processing(): # set_default_device == post_processing.filter_max_center_dist( asso_output=asso_output, max_center_dist=0, - k_boxes=k_boxes, - nonk_boxes=nonk_boxes, id_inds=id_inds, + curr_frame_boxes=k_boxes, + prev_frame_boxes=nonk_boxes, ) ).all() @@ -223,9 +223,9 @@ def test_post_processing(): # set_default_device == post_processing.filter_max_center_dist( asso_output=asso_output, max_center_dist=1e-9, - k_boxes=k_boxes, - nonk_boxes=nonk_boxes, id_inds=id_inds, + curr_frame_boxes=k_boxes, + prev_frame_boxes=nonk_boxes, ) ).all()