Skip to content

Commit

Permalink
Bug fix max filter dist (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
shaikh58 authored Dec 19, 2024
1 parent 08256ae commit ef05900
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions dreem/inference/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,17 @@ def filter_max_center_dist(
# norm_dist = dist / (k_s[:, None, :] + 1e-8)

valid = dist.squeeze() < max_center_dist # n_k x n_nonk
valid_mult = valid.float().unsqueeze(-1) if valid.ndim == 1 else valid.float()
# print(dist.shape, valid_mult.shape, id_inds.shape)
# handle case where id_inds and valid is a single value
# handle this better
if valid.ndim == 0: valid = valid.unsqueeze(0)
if valid.ndim == 1:
if id_inds.shape[0] == 1:
valid_mult = valid.float().unsqueeze(-1)
else:
valid_mult = valid.float().unsqueeze(0)
else:
valid_mult = valid.float()

valid_assn = (
torch.mm(valid_mult, id_inds.to(valid.device)).clamp_(max=1.0).long().bool()
) # n_k x M
Expand Down

0 comments on commit ef05900

Please sign in to comment.