refactor: rename vits prediction option to --use-vits
danellecline committed Nov 21, 2024
1 parent 5c996b2 commit a1e1301
Showing 3 changed files with 20 additions and 19 deletions.
19 changes: 10 additions & 9 deletions README.md
@@ -68,11 +68,12 @@ The algorithm workflow looks like this:
| google/vit-base-patch16-224 (default) | 16 block size trained on ImageNet21k with 21k classes |
| facebook/dino-vits8 | trained on ImageNet which contains 1.3 M images with labels from 1000 classes |
| facebook/dino-vits16 | trained on ImageNet which contains 1.3 M images with labels from 1000 classes |
| MBARI-org/mbari-uav-vit-b-16 | MBARI UAV ViT-B/16 model trained on 10425 UAV images with labels from 21 classes |

Smaller block_size means more patches and more accurate fine-grained clustering on smaller objects, so
ViTS models with 8 block size are recommended for fine-grained clustering on small objects, and 16 is recommended for coarser clustering on
larger objects. We recommend running with multiple models to see which model works best for your data,
- and to experiment with the --min_samples and --min-cluster-size options to get good clustering results.
+ and to experiment with the --min-samples and --min-cluster-size options to get good clustering results.
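A minimal sketch of such an experiment, assuming the default ViT model; the paths and the numeric values below are placeholders, not tuned recommendations:

```shell
# Hypothetical parameter sweep: run clustering twice with different
# --min-samples / --min-cluster-size settings and compare the resulting clusters.
sdcat cluster --det-dir <det-dir> --save-dir <save-dir>/run1 --min-samples 1 --min-cluster-size 2
sdcat cluster --det-dir <det-dir> --save-dir <save-dir>/run2 --min-samples 5 --min-cluster-size 10
```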

# Installation

@@ -145,7 +146,7 @@ Commands:
## File organization
- The sdcat toolkit generates data in the following folders. Here, we assume both detection and clustering is output to the same root folder.:
+ The sdcat toolkit generates data in the following folders. Here, we assume both detection and clustering results are stored in the same root folder:
```
/data/20230504-MBARI/
@@ -173,23 +174,23 @@ The sdcat toolkit generates data in the following folders. Here, we assume both
```
- ## Process images creating bounding box detections with the YOLOv5 model.
- The YOLOv5s model is not as accurate as other models, but is fast and good for detecting larger objects in images,
+ ## Process images creating bounding box detections with the YOLOv8s model.
+ The YOLOv8s model is not as accurate as other models, but is fast and good for detecting larger objects in images,
and good for experiments and quick results.
**Slice size** is the size of the detection window. The default is to allow the SAHI algorithm to determine the slice size;
a smaller slice size will take longer to process.
```shell
- sdcat detect --image-dir <image-dir> --save-dir <save-dir> --model yolov5s --slice-size-width 900 --slice-size-height 900
+ sdcat detect --image-dir <image-dir> --save-dir <save-dir> --model yolov8s --slice-size-width 900 --slice-size-height 900
```
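If the slice options are omitted, the default described above applies and SAHI chooses the slice size itself; a minimal sketch of that variant, with placeholder paths:

```shell
# Same detection run, but without explicit slice sizes so SAHI determines them automatically
sdcat detect --image-dir <image-dir> --save-dir <save-dir> --model yolov8s
```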
- ## Cluster detections from the YOLOv5 model
+ ## Cluster detections from the YOLOv8s model, but use the classifications from the ViT model.
- Cluster the detections from the YOLOv5 model. The detections are clustered using cosine similarity and embedding
- features from a FaceBook Vision Transformer (ViT) model.
+ Cluster the detections from the YOLOv8s model. The detections are clustered using cosine similarity and embedding
+ features from the default Vision Transformer (ViT) model `google/vit-base-patch16-224`.
```shell
- sdcat cluster --det-dir <det-dir> --save-dir <save-dir> --model yolov5s
+ sdcat cluster --det-dir <det-dir>/yolov8s/det_filtered --save-dir <save-dir> --use-vits
```
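Because `--det-dir` accepts multiple folders (it is declared with `multiple=True` in `commands.py` below), detections from several runs can be clustered together. A sketch with placeholder paths; `cuda:0` assumes a GPU is available:

```shell
# Hypothetical example: cluster detections from two runs in one pass and
# assign classes from the ViT predictions via the renamed --use-vits flag
sdcat cluster --det-dir <det-dir-1>/yolov8s/det_filtered --det-dir <det-dir-2>/yolov8s/det_filtered \
    --save-dir <save-dir> --use-vits --device cuda:0
```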
8 changes: 4 additions & 4 deletions sdcat/cluster/cluster.py
@@ -323,7 +323,7 @@ def cluster_vits(
min_cluster_size: int,
min_samples: int,
device: str = "cpu",
- use_predictions: bool = False,
+ use_vits: bool = False,
use_tsne: bool = False,
skip_visualization: bool = False,
roi: bool = False) -> pd.DataFrame:
@@ -340,7 +340,7 @@ def cluster_vits(
:param min_cluster_size: The minimum number of samples in a cluster
:param min_samples: The number of samples in a neighborhood for a point
:param device: The device to use for clustering, 'cpu' or 'cuda'
- :param use_predictions: Whether to use the predictions from the model used for clustering to assign classes
+ :param use_vits: Whether to use the predictions from the ViTS model to assign a class to each detection
:param skip_visualization: Whether to skip the visualization of the clusters
:param use_tsne: Whether to use t-SNE for dimensionality reduction
:return: a dataframe with the assigned cluster indexes, or -1 for non-assigned."""
@@ -452,8 +452,8 @@ def cluster_vits(
debug(f'Adding {images[idx]} to cluster id {cluster_id} ')
df_dets.loc[df_dets['crop_path'] == images[idx], 'cluster'] = cluster_id

- # If use_predictions is true, then assign the class to each detection
- if use_predictions:
+ # If use_vits is true, then assign the class to each detection
+ if use_vits:
for idx, row in df_dets.iterrows():
predictions, scores = image_predictions[idx], image_scores[idx]
df_dets.loc[idx, 'class'] = predictions[0] # Use the top prediction
12 changes: 6 additions & 6 deletions sdcat/cluster/commands.py
@@ -35,8 +35,8 @@
@click.option('--det-dir', help='Input folder(s) with raw detection results', multiple=True, required=True)
@click.option('--save-dir', help='Output directory to save clustered detection results', required=True)
@click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str, default='cpu')
- @click.option('--use-predictions', help='Set to using the cluster model for prediction', is_flag=True)
- def run_cluster_det(det_dir, save_dir, device, use_predictions, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, start_image, end_image, use_tsne, skip_visualization):
+ @click.option('--use-vits', help='Use the predictions from the ViTS cluster model to assign classes', is_flag=True)
+ def run_cluster_det(det_dir, save_dir, device, use_vits, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, start_image, end_image, use_tsne, skip_visualization):
config = cfg.Config(config_ini)
max_area = int(config('cluster', 'max_area'))
min_area = int(config('cluster', 'min_area'))
@@ -259,7 +259,7 @@ def is_day(utc_dt):
# Cluster the detections
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, cluster_selection_method,
min_similarity, min_cluster_size, min_samples, device, use_tsne=use_tsne,
- skip_visualization=skip_visualization, roi=False, use_predictions=use_predictions)
+ skip_visualization=skip_visualization, roi=False, use_vits=use_vits)

# Merge the results with the original DataFrame
df.update(df_cluster)
@@ -281,8 +281,8 @@ def is_day(utc_dt):
@click.option('--roi-dir', help='Input folder(s) with raw ROI images', multiple=True, required=True)
@click.option('--save-dir', help='Output directory to save clustered detection results', required=True)
@click.option('--device', help='Device to use, e.g. cpu or cuda:0', type=str)
- @click.option('--use-predictions', help='Set to using the cluster model for prediction', is_flag=True)
- def run_cluster_roi(roi_dir, save_dir, device, use_predictions, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, use_tsne, skip_visualization):
+ @click.option('--use-vits', help='Use the predictions from the ViTS cluster model to assign classes', is_flag=True)
+ def run_cluster_roi(roi_dir, save_dir, device, use_vits, config_ini, alpha, cluster_selection_epsilon, cluster_selection_method, min_cluster_size, use_tsne, skip_visualization):
config = cfg.Config(config_ini)
min_samples = int(config('cluster', 'min_samples'))
alpha = alpha if alpha else float(config('cluster', 'alpha'))
@@ -363,7 +363,7 @@ def run_cluster_roi(roi_dir, save_dir, device, use_predictions, config_ini, alph
# Cluster the detections
df_cluster = cluster_vits(prefix, model, df, save_dir, alpha, cluster_selection_epsilon, cluster_selection_method,
min_similarity, min_cluster_size, min_samples, device, use_tsne,
- skip_visualization=skip_visualization, use_predictions=use_predictions, roi=True)
+ skip_visualization=skip_visualization, use_vits=use_vits, roi=True)

# Merge the results with the original DataFrame
df.update(df_cluster)
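For reference, the ROI path accepts the renamed flag in the same way. A sketch of a possible invocation, assuming the ROI command is exposed on the `sdcat cluster` CLI in the same style as the README examples (the exact subcommand name may differ) and using placeholder paths:

```shell
# Cluster raw ROI images rather than detections, assigning classes from the ViT predictions
sdcat cluster --roi-dir <roi-dir> --save-dir <save-dir> --use-vits --device cpu
```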

