From ef059008e1d13ffdf63156c2cede6e6320564cb9 Mon Sep 17 00:00:00 2001 From: shaikh58 Date: Thu, 19 Dec 2024 16:33:37 +0400 Subject: [PATCH] Bug fix max filter dist (#105) --- dreem/inference/post_processing.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/dreem/inference/post_processing.py b/dreem/inference/post_processing.py index cae071c..9b78a95 100644 --- a/dreem/inference/post_processing.py +++ b/dreem/inference/post_processing.py @@ -158,8 +158,17 @@ def filter_max_center_dist( # norm_dist = dist / (k_s[:, None, :] + 1e-8) valid = dist.squeeze() < max_center_dist # n_k x n_nonk - valid_mult = valid.float().unsqueeze(-1) if valid.ndim == 1 else valid.float() - # print(dist.shape, valid_mult.shape, id_inds.shape) + # handle case where id_inds and valid is a single value + # handle this better + if valid.ndim == 0: valid = valid.unsqueeze(0) + if valid.ndim == 1: + if id_inds.shape[0] == 1: + valid_mult = valid.float().unsqueeze(-1) + else: + valid_mult = valid.float().unsqueeze(0) + else: + valid_mult = valid.float() + valid_assn = ( torch.mm(valid_mult, id_inds.to(valid.device)).clamp_(max=1.0).long().bool() ) # n_k x M