Skip to content

Commit

Permalink
Merge pull request #83 from Imageomics/80-progress-bar
Browse files Browse the repository at this point in the history
Add batch_size parameter for predict() and cli
  • Loading branch information
johnbradley authored Feb 5, 2025
2 parents e07bd69 + 43547d3 commit 306a0be
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 81 deletions.
3 changes: 3 additions & 0 deletions docs/command-line-help.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ usage: bioclip predict [-h] [--format {table,csv}] [--output OUTPUT]
[--rank {kingdom,phylum,class,order,family,genus,species} |
--cls CLS | --bins BINS | --subset SUBSET] [--k K]
[--device DEVICE] [--model MODEL] [--pretrained PRETRAINED]
[--batch-size BATCH_SIZE]
image_file [image_file ...]
positional arguments:
Expand Down Expand Up @@ -40,6 +41,8 @@ options:
pretrained model checkpoint as tag or file, depends on
model; needed only if more than one is available
(see command list-models)
--batch-size BATCH_SIZE
Number of images to process in a batch, default: 10
```

## bioclip embed
Expand Down
66 changes: 34 additions & 32 deletions examples/PredictImages.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,7 @@
"execution_count": 1,
"id": "d20352a2-2ac2-4a6f-bcd5-9b12cf622176",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.0.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.2\u001b[0m\n",
"\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n"
]
}
],
"outputs": [],
"source": [
"!pip install pybioclip ipywidgets --quiet"
]
Expand Down Expand Up @@ -116,6 +106,13 @@
"id": "92cae978-6565-46f1-94f7-e4e2fad528a9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 6.40images/s]\n"
]
},
{
"data": {
"text/html": [
Expand Down Expand Up @@ -177,7 +174,7 @@
" <td>arctos syriacus</td>\n",
" <td>Ursus arctos syriacus</td>\n",
" <td>syrian brown bear</td>\n",
" <td>0.056171</td>\n",
" <td>0.056170</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
Expand Down Expand Up @@ -233,7 +230,7 @@
" <td>eques</td>\n",
" <td>Urticina eques</td>\n",
" <td>Horseman anemone</td>\n",
" <td>0.648790</td>\n",
" <td>0.648787</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
Expand All @@ -247,7 +244,7 @@
" <td>felina</td>\n",
" <td>Urticina felina</td>\n",
" <td>Northern Red Anemone</td>\n",
" <td>0.128011</td>\n",
" <td>0.128014</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
Expand Down Expand Up @@ -289,7 +286,7 @@
" <td>quadricolor</td>\n",
" <td>Entacmaea quadricolor</td>\n",
" <td>Bubble-tip Anemone</td>\n",
" <td>0.022781</td>\n",
" <td>0.022782</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
Expand All @@ -303,7 +300,7 @@
" <td>silvestris</td>\n",
" <td>Felis silvestris</td>\n",
" <td>European Wildcat</td>\n",
" <td>0.722105</td>\n",
" <td>0.722103</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
Expand All @@ -317,7 +314,7 @@
" <td>catus</td>\n",
" <td>Felis catus</td>\n",
" <td>Domestic Cat</td>\n",
" <td>0.198107</td>\n",
" <td>0.198109</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
Expand All @@ -331,7 +328,7 @@
" <td>margarita</td>\n",
" <td>Felis margarita</td>\n",
" <td>Sand Cat</td>\n",
" <td>0.027984</td>\n",
" <td>0.027985</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
Expand Down Expand Up @@ -402,18 +399,18 @@
"\n",
" species common_name score \n",
"0 Ursus arctos Kodiak bear 0.935603 \n",
"1 Ursus arctos syriacus syrian brown bear 0.056171 \n",
"1 Ursus arctos syriacus syrian brown bear 0.056170 \n",
"2 Ursus arctos bruinosus 0.004126 \n",
"3 Ursus arctus 0.002496 \n",
"4 Ursus americanus Louisiana black bear 0.000501 \n",
"5 Urticina eques Horseman anemone 0.648790 \n",
"6 Urticina felina Northern Red Anemone 0.128011 \n",
"5 Urticina eques Horseman anemone 0.648787 \n",
"6 Urticina felina Northern Red Anemone 0.128014 \n",
"7 Polyphyllia novaehiberniae Slipper Coral 0.062747 \n",
"8 Heteractis magnifica Magnificent Sea Anemone 0.026706 \n",
"9 Entacmaea quadricolor Bubble-tip Anemone 0.022781 \n",
"10 Felis silvestris European Wildcat 0.722105 \n",
"11 Felis catus Domestic Cat 0.198107 \n",
"12 Felis margarita Sand Cat 0.027984 \n",
"9 Entacmaea quadricolor Bubble-tip Anemone 0.022782 \n",
"10 Felis silvestris European Wildcat 0.722103 \n",
"11 Felis catus Domestic Cat 0.198109 \n",
"12 Felis margarita Sand Cat 0.027985 \n",
"13 Lynx felis 0.021830 \n",
"14 Felis bieti Chinese desert cat 0.010979 "
]
Expand All @@ -425,10 +422,7 @@
],
"source": [
"classifier = TreeOfLifeClassifier(device='cpu')\n",
"prediction_ary = []\n",
"for image_path in image_paths:\n",
" prediction_ary.extend(classifier.predict(image_path, rank=Rank.SPECIES))\n",
"\n",
"prediction_ary = classifier.predict(image_paths, rank=Rank.SPECIES)\n",
"df = pd.DataFrame(prediction_ary)\n",
"df"
]
Expand Down Expand Up @@ -538,7 +532,7 @@
" <td>arctos syriacus</td>\n",
" <td>Ursus arctos syriacus</td>\n",
" <td>syrian brown bear</td>\n",
" <td>0.056171</td>\n",
" <td>0.056170</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
Expand Down Expand Up @@ -603,7 +597,7 @@
"\n",
" score \n",
"0 0.935603 \n",
"1 0.056171 \n",
"1 0.056170 \n",
"2 0.004126 \n",
"3 0.002496 \n",
"4 0.000501 "
Expand All @@ -621,6 +615,14 @@
"display(image)\n",
"df[df['file_name'] == example_image_path]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1cbe3a25-9f57-4c51-b2f8-ad985acc88af",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -639,7 +641,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
53 changes: 24 additions & 29 deletions examples/iNaturalistPredict.ipynb

Large diffs are not rendered by default.

13 changes: 9 additions & 4 deletions src/bioclip/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,23 @@ def predict(image_file: list[str],
bins_path: str,
k: int,
subset: str,
batch_size: int,
**kwargs):
if cls_str:
classifier = CustomLabelsClassifier(cls_ary=cls_str.split(','), **kwargs)
predictions = classifier.predict(images=image_file, k=k)
predictions = classifier.predict(images=image_file, k=k, batch_size=batch_size)
write_results(predictions, format, output)
elif bins_path:
cls_to_bin = parse_bins_csv(bins_path)
classifier = CustomLabelsBinningClassifier(cls_to_bin=cls_to_bin, **kwargs)
predictions = classifier.predict(images=image_file, k=k)
predictions = classifier.predict(images=image_file, k=k, batch_size=batch_size)
write_results(predictions, format, output)
else:
classifier = TreeOfLifeClassifier(**kwargs)
if subset:
filter = classifier.create_taxa_filter_from_csv(subset)
classifier.apply_filter(filter)
predictions = classifier.predict(images=image_file, rank=rank, k=k)
predictions = classifier.predict(images=image_file, rank=rank, k=k, batch_size=batch_size)
write_results(predictions, format, output)


Expand Down Expand Up @@ -94,6 +95,8 @@ def create_parser():
model_arg = {'help': f'model identifier (see command list-models); default: {BIOCLIP_MODEL_STR}'}
pretrained_arg = {'help': 'pretrained model checkpoint as tag or file, depends on model; '
'needed only if more than one is available (see command list-models)'}
batch_size_arg = {'default': 10, 'type': int,
'help': 'Number of images to process in a batch, default: 10'}

# Predict command
predict_parser = subparsers.add_parser('predict', help='Use BioCLIP to generate predictions for image files.')
Expand All @@ -114,6 +117,7 @@ def create_parser():
predict_parser.add_argument('--device', **device_arg)
predict_parser.add_argument('--model', **model_arg)
predict_parser.add_argument('--pretrained', **pretrained_arg)
predict_parser.add_argument('--batch-size', **batch_size_arg)

# Embed command
embed_parser = subparsers.add_parser('embed', help='Use BioCLIP to generate embeddings for image files.')
Expand Down Expand Up @@ -186,7 +190,8 @@ def main():
device=args.device,
model_str=args.model,
pretrained_str=args.pretrained,
subset=args.subset)
subset=args.subset,
batch_size=args.batch_size)
elif args.command == 'list-models':
if args.model:
for tag in oc.list_pretrained_tags_by_model(args.model):
Expand Down
31 changes: 26 additions & 5 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import json
from tqdm import tqdm
import torch
from torchvision import transforms
import open_clip as oc
Expand Down Expand Up @@ -128,7 +129,6 @@ def get_txt_names():
txt_names = json.load(fd)
return txt_names


class Rank(Enum):
"""Rank for the Tree of Life classification."""
KINGDOM = 0
Expand Down Expand Up @@ -253,6 +253,21 @@ def create_probabilities_for_images(self, images: List[str] | List[PIL.Image.Ima
result[key] = probs[i]
return result

def create_batched_probabilities_for_images(self, images: List[str] | List[PIL.Image.Image],
                                            txt_features: torch.Tensor,
                                            batch_size: int | None) -> dict[str, torch.Tensor]:
    """
    Compute classification probabilities for the given images, processing them
    in batches and reporting progress with a tqdm progress bar.

    Parameters:
        images (List[str] | List[PIL.Image.Image]): Image file paths or PIL images.
        txt_features (torch.Tensor): Text embeddings to score the images against.
        batch_size (int | None): Number of images to process per batch; a falsy
            value (None or 0) processes all images in a single batch.

    Returns:
        dict[str, torch.Tensor]: Mapping from each image key to its probability
        tensor, merged across all batches.
    """
    result: dict[str, torch.Tensor] = {}
    if not images:
        # Guard: with an empty image list and a falsy batch_size, the fallback
        # below would set batch_size to 0 and range(0, 0, 0) raises ValueError.
        return result
    if not batch_size:
        batch_size = len(images)
    with tqdm(total=len(images), unit="images") as progress_bar:
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size]
            result.update(self.create_probabilities_for_images(batch, txt_features))
            progress_bar.update(len(batch))
    return result

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Given an input tensor representing multiple images, return probabilities for each class for each image.
Expand Down Expand Up @@ -297,20 +312,23 @@ def _get_txt_embeddings(self, classnames):
return all_features

@torch.no_grad()
def predict(self, images: List[str] | str | List[PIL.Image.Image], k: int = None) -> dict[str, float]:
def predict(self, images: List[str] | str | List[PIL.Image.Image], k: int = None,
batch_size: int = 10) -> dict[str, float]:
"""
Predicts the probabilities for the given images.
Parameters:
images (List[str] | str | List[PIL.Image.Image]): A list of image file paths, a single image file path, or a list of PIL Image objects.
k (int, optional): The number of top probabilities to return. If not specified or if greater than the number of classes, all probabilities are returned.
batch_size (int, optional): The number of images to process in a batch.
Returns:
List[dict]: A list of dicts with keys "file_name" and the custom class labels.
"""
if isinstance(images, str):
images = [images]
probs = self.create_probabilities_for_images(images, self.txt_embeddings)
probs = self.create_batched_probabilities_for_images(images, self.txt_embeddings,
batch_size=batch_size)
result = []
for i, image in enumerate(images):
key = self.make_key(image, i)
Expand Down Expand Up @@ -555,7 +573,8 @@ def format_grouped_probs(self, image_key: str, probs: torch.Tensor, rank: Rank,
return prediction_ary

@torch.no_grad()
def predict(self, images: List[str] | str | List[PIL.Image.Image], rank: Rank, min_prob: float = 1e-9, k: int = 5) -> dict[str, dict[str, float]]:
def predict(self, images: List[str] | str | List[PIL.Image.Image], rank: Rank,
min_prob: float = 1e-9, k: int = 5, batch_size: int = 10) -> dict[str, dict[str, float]]:
"""
Predicts probabilities for supplied taxa rank for given images using the Tree of Life embeddings.
Expand All @@ -564,14 +583,16 @@ def predict(self, images: List[str] | str | List[PIL.Image.Image], rank: Rank, m
rank (Rank): The rank at which to make predictions (e.g., species, genus).
min_prob (float, optional): The minimum probability threshold for predictions.
k (int, optional): The number of top predictions to return.
batch_size (int, optional): The number of images to process in a batch.
Returns:
List[dict]: A list of dicts with keys "file_name", taxon ranks, "common_name", and "score".
"""

if isinstance(images, str):
images = [images]
probs = self.create_probabilities_for_images(images, self.get_txt_embeddings())
probs = self.create_batched_probabilities_for_images(images, self.get_txt_embeddings(),
batch_size=batch_size)
result = []
for i, image in enumerate(images):
key = self.make_key(image, i)
Expand Down
Loading

0 comments on commit 306a0be

Please sign in to comment.