From bf7ff2aa81f93f5c7314941a7f230fd6d7b3a093 Mon Sep 17 00:00:00 2001 From: John Bradley Date: Wed, 5 Feb 2025 09:13:16 -0500 Subject: [PATCH] Fix grouping performance bug Ensures that probabilities are moved to the cpu before appling grouping logic to avoid performance issues moving data from CUDA/MPS -> CPU again and again. Fixes #84 --- src/bioclip/predict.py | 5 +++-- tests/test_predict.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/bioclip/predict.py b/src/bioclip/predict.py index 1e87e8d..bd9abe4 100644 --- a/src/bioclip/predict.py +++ b/src/bioclip/predict.py @@ -575,10 +575,11 @@ def predict(self, images: List[str] | str | List[PIL.Image.Image], rank: Rank, m result = [] for i, image in enumerate(images): key = self.make_key(image, i) + image_probs = probs[key].cpu() if rank == Rank.SPECIES: - result.extend(self.format_species_probs(key, probs[key], k)) + result.extend(self.format_species_probs(key, image_probs, k)) else: - result.extend(self.format_grouped_probs(key, probs[key], rank, min_prob, k)) + result.extend(self.format_grouped_probs(key, image_probs, rank, min_prob, k)) return result diff --git a/tests/test_predict.py b/tests/test_predict.py index ad57aab..7ecc16e 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch, mock_open +from unittest.mock import patch, mock_open, Mock, ANY from bioclip.predict import TreeOfLifeClassifier, Rank, get_rank_labels from bioclip.predict import CustomLabelsClassifier from bioclip.predict import CustomLabelsBinningClassifier @@ -67,6 +67,34 @@ def test_tree_of_life_classifier_family(self): } self.assertEqual(prediction_ary[0], prediction_dict) + def test_tree_of_life_classifier_groups_probs_on_cpu(self): + # Ensure that the probabilities are moved to the cpu + # before grouping to avoid performance issues + classifier = TreeOfLifeClassifier() + + # Have create_probabilities_for_images return mock probs + # with values returned from cpu() + probs = Mock() + probs.cpu.return_value = torch.Tensor([0.1, 0.2, 0.3]) + classifier.create_probabilities_for_images = Mock() + classifier.create_probabilities_for_images.return_value = { + EXAMPLE_CAT_IMAGE: probs + } + + # Mock format_grouped_probs so we can check the parameters + classifier.format_grouped_probs = Mock() + classifier.format_grouped_probs.return_value = [] + + classifier.predict(images=[EXAMPLE_CAT_IMAGE], rank=Rank.CLASS, k=2) + + # Ensure that the probabilities were moved to the cpu + classifier.format_grouped_probs.assert_called_with( + EXAMPLE_CAT_IMAGE, + probs.cpu.return_value, + Rank.CLASS, + ANY, 2 + ) + def test_custom_labels_classifier(self): classifier = CustomLabelsClassifier(cls_ary=['cat', 'dog']) prediction_ary = classifier.predict(images=EXAMPLE_CAT_IMAGE)