From 58faf9c34092c5b913fed2a6821b877db6253ce6 Mon Sep 17 00:00:00 2001 From: John Bradley Date: Thu, 6 Feb 2025 11:59:02 -0500 Subject: [PATCH] Fix subsetting range bug Fixes a RuntimeError when using the --subset option with a CSV that has too few species. This error only occurred when k was greater than the number of species. This change will return fewer than k entries when there are too few species. This matches the behavior of heapq.nlargest() used for non-species grouping. Fixes #87 --- src/bioclip/predict.py | 2 ++ tests/test_predict.py | 13 +++++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/bioclip/predict.py b/src/bioclip/predict.py index 0275b28..51765c1 100644 --- a/src/bioclip/predict.py +++ b/src/bioclip/predict.py @@ -544,6 +544,8 @@ def apply_filter(self, keep_labels_ary: List[bool]): self._subset_txt_names = names def format_species_probs(self, image_key: str, probs: torch.Tensor, k: int = 5) -> List[dict[str, float]]: + # Prevent error when probs is smaller than k + k = min(k, probs.shape[0]) topk = probs.topk(k) result = [] for i, prob in zip(topk.indices, topk.values): diff --git a/tests/test_predict.py b/tests/test_predict.py index 8b650f4..023295e 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -281,6 +281,19 @@ def test_create_taxa_filter_from_csv(self): def test_get_rank_labels(self): self.assertEqual(','.join(get_rank_labels()), 'kingdom,phylum,class,order,family,genus,species') + def test_format_species_probs_too_few_species(self): + classifier = TreeOfLifeClassifier() + + # test when k < number of probabilities + probs = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]) + top_probs = classifier.format_species_probs(EXAMPLE_CAT_IMAGE, probs, k=5) + self.assertEqual(len(top_probs), 5) + self.assertEqual(top_probs[0]['file_name'], EXAMPLE_CAT_IMAGE) + + # test when k > number of probabilities + probs = torch.tensor([0.1, 0.2, 0.3, 0.4]) + top_probs = classifier.format_species_probs(EXAMPLE_CAT_IMAGE, probs, k=5) + class 
TestEmbed(unittest.TestCase): def test_get_image_features(self):