Bugfix in collector batch_to_tensor when top_k > number of elements
Laouen committed Jan 6, 2025
1 parent 7b64458 commit 5c9e468
Showing 2 changed files with 5 additions and 81 deletions.
7 changes: 5 additions & 2 deletions thoi/collectors.py
@@ -376,9 +376,12 @@ def batch_to_tensor(nplets_idxs: torch.Tensor,
                                         metric,
                                         largest)
 
+    metric_func = partial(_get_string_metric, metric=metric) if isinstance(metric, str) else metric
+
     # If not top_k or len(nplets_measuresa) > top_k return the original values
-    # |k x D x 4|, |k x N|
-    return (nplets_measures, nplets_idxs, None)
+    # |k x D x 4|, |k x N|, |k|
+    values = metric_func(nplets_measures).to(nplets_measures.device)
+    return (nplets_measures, nplets_idxs, values)
 
 
 def concat_batched_tensors(batched_tensors: List[Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
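
The fix computes the per-n-plet metric values even when the top-k selection path is skipped (no top_k given, or fewer n-plets in the batch than top_k), instead of returning None in the third slot of the tuple. Below is a minimal sketch of that behaviour, not the library's actual implementation: the reducer o_info_metric and the simplified control flow are assumptions standing in for thoi's _get_string_metric and the real top-k branch.

# Minimal sketch (not thoi's actual code) of the fixed branch in batch_to_tensor.
# Shapes follow the comments in the diff: nplets_measures is |k x D x 4|,
# nplets_idxs is |k x N|, values is |k|. The reducer o_info_metric is a
# hypothetical stand-in for the string-metric resolution done by _get_string_metric.
from functools import partial
from typing import Callable, Optional, Union

import torch


def o_info_metric(nplets_measures: torch.Tensor, metric: str = "o") -> torch.Tensor:
    # Hypothetical reducer: collapse the |k x D x 4| measures to one value per n-plet (|k|).
    return nplets_measures.mean(dim=(1, 2))


def batch_to_tensor_sketch(nplets_idxs: torch.Tensor,
                           nplets_measures: torch.Tensor,
                           metric: Union[str, Callable] = "o",
                           top_k: Optional[int] = None,
                           largest: bool = True):
    metric_func = partial(o_info_metric, metric=metric) if isinstance(metric, str) else metric

    # Per-n-plet values, shape |k|. Before the fix these were only computed on the
    # top-k path, and None was returned whenever that path was skipped.
    values = metric_func(nplets_measures).to(nplets_measures.device)

    if top_k is not None and len(nplets_measures) > top_k:
        # Top-k selection (unchanged by this commit): keep the best top_k n-plets.
        values, order = torch.topk(values, top_k, largest=largest)
        return nplets_measures[order], nplets_idxs[order], values

    # top_k not set, or top_k > number of n-plets in the batch: return everything,
    # now with the computed values instead of None, so downstream merging/sorting
    # of batches (e.g. in concat_batched_tensors) still has values to work with.
    return nplets_measures, nplets_idxs, values


# Example: top_k larger than the number of n-plets takes the fixed branch.
k, D, N = 3, 2, 5
measures = torch.rand(k, D, 4)
idxs = torch.randint(0, 10, (k, N))
_, _, vals = batch_to_tensor_sketch(idxs, measures, metric="o", top_k=10)
print(vals.shape)  # torch.Size([3]) -- previously this slot was None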
79 changes: 0 additions & 79 deletions thoi/graph.py

This file was deleted.
