This repository is no longer maintained. For an updated version, please check out the CaBRNet library :)
This repository contains the code developed for the experiments presented in the paper "Sanity checks for patch visualisation in prototype-based image classification", accepted at XAI4CV (2nd Explainable AI for Computer Vision Workshop at CVPR'23), including our modifications to the original code of:
All dependencies are listed in the file requirements.txt. To install them, simply run:
python -m pip install -r requirements.txt
- Download and preprocess the CUB-200-2011 and Stanford Cars datasets, annotations and segmentation masks:
python prototree/preprocess_data/
python prototree/preprocess_data/
python prototree/preprocess_data/
python prototree/preprocess_data/
python protopnet/
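The `cropped` projection mode described below relies on the CUB-200-2011 object bounding boxes. The following is a minimal sketch of cropping an image to its box. It assumes the standard `bounding_boxes.txt` layout of CUB-200-2011 (one `<image_id> <x> <y> <width> <height>` entry per line). The helper names are hypothetical, and this is not the preprocessing code of this repository:

```python
import numpy as np

def parse_bounding_boxes(lines):
    """Parse CUB-style bounding box lines into {image_id: (x, y, width, height)}."""
    boxes = {}
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        image_id, x, y, w, h = line.split()
        boxes[int(image_id)] = (float(x), float(y), float(w), float(h))
    return boxes

def crop_to_box(image, box):
    """Crop an H x W (x C) image array to an (x, y, width, height) box in pixels."""
    x, y, w, h = (int(round(v)) for v in box)
    return image[y:y + h, x:x + w]
```

Typical usage would be `with open("bounding_boxes.txt") as f: boxes = parse_bounding_boxes(f)`, then `crop_to_box(image, boxes[image_id])` for each image loaded as a NumPy array.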
- Download a ResNet50 model pretrained on the iNaturalist dataset:
python prototree/features/
- Train a ProtoTree on CUB-200-2011 as follows:
cd prototree
python \
--num_features 256 --depth 9 --net resnet50_inat --init_mode pretrained --dataset CUB-200-2011 \
--epochs 100 --lr 0.001 --lr_block 0.001 --lr_net 1e-5 --device cuda:0 \
--freeze_epochs 10 --milestones 60,70,80,90,100 --batch_size 64 --random_seed 42 \
--root_dir runs/prototree_cub \
--proj_dir proj_corners_sm --upsample_mode smoothgrads --upsample_threshold 0.3 --projection_mode corners
At the end of training, this command projects the prototypes using the augmented CUB dataset (images cropped to the four corners + center) and visualises them using Smoothgrads.
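With `--depth 9`, the ProtoTree is a complete binary tree. As described in the ProtoTree paper, each internal node holds one prototype and each leaf holds a class distribution. The number of prototypes before any pruning therefore follows directly from the depth. The following is a quick check; the helper is hypothetical and not part of this repository:

```python
def tree_size(depth):
    """Return (num_prototypes, num_leaves) for a complete binary ProtoTree of the given depth."""
    num_leaves = 2 ** depth
    num_prototypes = num_leaves - 1  # one prototype per internal node
    return num_prototypes, num_leaves

print(tree_size(9))  # (511, 512)
```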
Supported projection modes:
- raw: Use the raw training dataset
- corners: Training dataset augmented using 4 corners + center crop (CUB only)
- cropped: Training dataset cropped to object bounding box (CUB only)
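The `corners` mode can be pictured as a five-crop augmentation: each training image contributes its four corner crops plus a center crop. The following is a minimal NumPy sketch that assumes square crops of a hypothetical `crop_size`. The crop size and implementation actually used by this repository may differ:

```python
import numpy as np

def five_crops(image, crop_size):
    """Return the top-left, top-right, bottom-left, bottom-right and center square crops."""
    h, w = image.shape[:2]
    if crop_size > h or crop_size > w:
        raise ValueError("crop_size is larger than the image")
    top = (h - crop_size) // 2
    left = (w - crop_size) // 2
    return [
        image[:crop_size, :crop_size],                       # top-left
        image[:crop_size, w - crop_size:],                   # top-right
        image[h - crop_size:, :crop_size],                   # bottom-left
        image[h - crop_size:, w - crop_size:],               # bottom-right
        image[top:top + crop_size, left:left + crop_size],   # center
    ]
```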
Supported upsample modes:
- vanilla: Original upsampling with cubic interpolation
- smoothgrads: Use Smoothgrads x input
- prp: Use Prototype Relevance Propagation (PRP)
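The `smoothgrads` mode combines SmoothGrad with a gradient × input attribution. Gradients are averaged over several noisy copies of the input, then multiplied element-wise by the input. The following is a minimal NumPy sketch on a toy function with an analytical gradient. `grad_fn`, `n_samples` and `sigma` are illustrative assumptions. In practice, the gradient would be taken through the model with respect to the input image:

```python
import numpy as np

def smoothgrad_x_input(x, grad_fn, n_samples=16, sigma=0.1, seed=0):
    """Average grad_fn over Gaussian-noised copies of x (SmoothGrad), then multiply by x."""
    rng = np.random.default_rng(seed)
    grads = np.zeros_like(x, dtype=float)
    for _ in range(n_samples):
        noisy = x + rng.normal(0.0, sigma, size=x.shape)
        grads += grad_fn(noisy)
    return x * grads / n_samples

# Toy score f(x) = sum(w * x**2), whose gradient is 2 * w * x.
w = np.array([1.0, -2.0, 0.5])
grad_f = lambda v: 2 * w * v
x = np.array([1.0, 2.0, 3.0])
attribution = smoothgrad_x_input(x, grad_f, sigma=0.0)  # with no noise: x * (2 * w * x)
```

Setting `sigma=0.0` reduces the method to plain gradient × input, which makes the toy result easy to check by hand.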
- To perform inference and generate explanations on a test image, use:
python \
--root_dir runs/prototree_cub/ --tree runs/prototree_cub/proj_raw_prp/model/ \
--proj_dir proj_raw_prp/ --dataset CUB-200-2011 --device cuda:0 \
--upsample_mode vanilla \
--sample_dir ../data/CUB_200_2011/dataset/test_full/054.Blue_Grosbeak/Blue_Grosbeak_0078_36655.jpg \
--results_dir explanations_vanilla
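The following gives some intuition for what inference computes. In the ProtoTree paper, each internal node routes a sample to its right child with a probability given by the similarity between its prototype and the closest latent patch, `exp(-distance)`. The soft prediction then mixes the leaf class distributions, weighted by the probability of reaching each leaf. The sketch below is a toy NumPy version. The breadth-first node ordering and array layout are assumptions and do not mirror this repository's classes:

```python
import numpy as np

def node_similarities(patches, prototypes):
    """exp(-min L2 distance) between each prototype and its closest latent patch.

    patches: (num_patches, D); prototypes: (num_nodes, D).
    """
    dists = np.linalg.norm(patches[None, :, :] - prototypes[:, None, :], axis=-1)
    return np.exp(-dists.min(axis=1))

def soft_predict(sims, leaf_dists):
    """Soft routing through a complete binary tree stored in breadth-first order.

    sims[i] is the probability of going right at internal node i;
    leaf_dists has shape (num_leaves, num_classes).
    """
    num_leaves = leaf_dists.shape[0]
    depth = int(np.log2(num_leaves))
    out = np.zeros(leaf_dists.shape[1])
    for leaf in range(num_leaves):
        prob, node = 1.0, 0
        for level in reversed(range(depth)):  # read the leaf index bits from the root down
            go_right = (leaf >> level) & 1
            prob *= sims[node] if go_right else 1.0 - sims[node]
            node = 2 * node + 1 + go_right  # children of node i are 2i+1 and 2i+2
        out += prob * leaf_dists[leaf]
    return out
```

Since the path probabilities sum to one, the output is itself a class distribution whenever each leaf distribution is.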
- To generate fidelity and relevance statistics on prototypes, use:
python \
--tree_dir runs/prototree_cub/checkpoints/latest/ \
--base_arch resnet50 \
--dataset CUB-200-2011 --use-segmentation \
--output prototree_birds_r50_proto_stats.csv \
--target_areas 0.001 0.02 0.001 --random_seed 0 \
--projection_mode corners \
--proj_dir runs/prototree_cub/proj_corners
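One way to read `--use-segmentation` and `--target_areas` is the following. Keep the most relevant pixels of a visualisation until they cover a given fraction of the image, then measure how much of that region falls on the object mask. The sketch below is a minimal NumPy version of such an overlap score. It is an illustrative assumption, not necessarily the exact relevance metric computed by the script. Here the three `--target_areas` values are assumed to be the start, stop and step of the sweep over area fractions:

```python
import numpy as np

def top_area_mask(saliency, area):
    """Boolean mask of the highest-saliency pixels covering `area` (a fraction) of the image."""
    k = max(1, int(round(area * saliency.size)))
    order = np.argsort(saliency, axis=None)[::-1][:k]  # flat indices, by descending saliency
    mask = np.zeros(saliency.size, dtype=bool)
    mask[order] = True
    return mask.reshape(saliency.shape)

def overlap_with_object(saliency, segmentation, area):
    """Fraction of the selected pixels that lie on the object (segmentation > 0)."""
    selected = top_area_mask(saliency, area)
    return float((segmentation[selected] > 0).mean())
```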
- To generate fidelity and relevance statistics on test images, use:
python \
--tree_dir runs/prototree_cub/proj_corners_sm/model/ \
--base_arch resnet50 \
--dataset CUB-200-2011 --use-segmentation \
--device cuda:0 \
--output runs/prototree_cub/proj_corners/prototree_birds_r50_inference_stats.csv \
--target_areas 0.001 0.02 0.001 --random_seed 0
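A complementary check is fidelity. If a visualisation really points at the part of the image that drives the prototype similarity, perturbing that part should lower the score more than perturbing an equally sized random region. The following is a hedged NumPy sketch of that comparison on a single-channel image with a toy score function. `score_fn`, the constant fill value and the random baseline are assumptions for illustration, not the exact protocol implemented by the statistics script:

```python
import numpy as np

def score_drop(image, region, score_fn, fill=0.0):
    """Decrease of score_fn when the pixels in the boolean `region` are replaced by `fill`."""
    perturbed = image.copy()
    perturbed[region] = fill
    return score_fn(image) - score_fn(perturbed)

def fidelity_gap(image, saliency, score_fn, area, seed=0):
    """Score drop on the top-`area` salient pixels minus the drop on as many random pixels."""
    k = max(1, int(round(area * image.size)))
    top = np.zeros(image.size, dtype=bool)
    top[np.argsort(saliency, axis=None)[::-1][:k]] = True
    rng = np.random.default_rng(seed)
    rand = np.zeros(image.size, dtype=bool)
    rand[rng.choice(image.size, size=k, replace=False)] = True
    return (score_drop(image, top.reshape(image.shape), score_fn)
            - score_drop(image, rand.reshape(image.shape), score_fn))
```

A positive gap means the highlighted region matters more to the score than a random region of the same size.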
- Train a ProtoTree on Stanford Cars as follows:
cd prototree
python \
--num_features 128 --depth 9 --net resnet50 --init_mode pretrained \
--dataset CARS \
--epochs 500 --lr 0.001 --lr_block 0.001 --lr_net 2e-4 --device cuda:0 \
--freeze_epochs 30 --milestones 250,350,400,425,450,475,500 \
--batch_size 64 --random_seed 42 \
--root_dir runs/prototree_cub \
--proj_dir proj_prp --upsample_mode prp \
--upsample_threshold 0.3
- To generate explanations and statistics, simply replace CUB-200-2011 with CARS in the --dataset option.
Note that
is not available for Stanford Cars.
There are two ways to restart a training sequence.
- Relaunch the training command with exactly the same options. If the --root_dir directory exists, training automatically restarts from the checkpoint located in <root_dir>/checkpoints/latest.
- Explicitly specify the path to the checkpoint directory using the corresponding option.
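The automatic restart can be summarised as follows: if a `latest` checkpoint directory already exists under `--root_dir`, resume from it; otherwise, start from scratch. The following is a minimal `pathlib` sketch of that decision. The function name, the `explicit_checkpoint` argument and the precedence given to it are hypothetical. Only the `<root_dir>/checkpoints/latest` layout comes from this README:

```python
from pathlib import Path

def resolve_checkpoint(root_dir, explicit_checkpoint=None):
    """Return the checkpoint directory to resume from, or None to train from scratch."""
    if explicit_checkpoint is not None:
        path = Path(explicit_checkpoint)
        if not path.is_dir():
            raise FileNotFoundError(f"checkpoint directory not found: {path}")
        return path  # assumption: an explicitly given path takes precedence
    latest = Path(root_dir) / "checkpoints" / "latest"
    return latest if latest.is_dir() else None
```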
It is possible to test different projection methods with different projection datasets without retraining the entire ProtoTree. For example:
--tree_dir ./runs/prototree/checkpoints/latest/ \
--root_dir runs/prototree \
--dataset CUB-200-2011 \
--device cuda:0 \
--proj_dir proj_raw_vanilla \
--upsample_threshold 0.98 --upsample_mode vanilla --projection_mode raw
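Projection replaces each learned prototype with its nearest latent patch from the projection dataset. Changing `--projection_mode` therefore only changes the set of candidate patches, so the tree does not need to be retrained. The following is a minimal NumPy sketch of that nearest-patch search. It assumes the latent patches have already been extracted into an array, and the names are hypothetical:

```python
import numpy as np

def project_prototypes(prototypes, patches):
    """Replace each prototype by its closest (L2) latent patch.

    prototypes: (P, D); patches: (N, D), gathered from the projection dataset.
    Returns the projected prototypes and, for each one, the index of the chosen patch.
    """
    # Squared distances via ||a||^2 - 2ab + ||b||^2, shape (P, N).
    d2 = (
        (prototypes ** 2).sum(axis=1, keepdims=True)
        - 2.0 * prototypes @ patches.T
        + (patches ** 2).sum(axis=1)
    )
    nearest = d2.argmin(axis=1)
    return patches[nearest].copy(), nearest
```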
To train a ProtoPNet, follow the instructions from the original code and edit the file (see protopnet/README.txt). To generate fidelity and relevance statistics on prototypes and on test images, use:
cd protopnet
python \
--model saved_models/resnet50/birds/trained_model.pth \
--base_arch resnet50 \
--dataset ../data/CUB_200_2011/dataset/train_crop/ \
--segm_dir ../data/CUB_200_2011/dataset/train_crop_seg/ \
--proj_dir saved_models/resnet50/birds/proj \
--device cuda:0 \
--output protopnet_birds_r50_proto_stats.csv \
--target_areas 0.001 0.02 0.001
python \
--model saved_models/resnet50/birds/trained_model.pth \
--base_arch resnet50 \
--dataset ../data/CUB_200_2011/dataset/test_full/ \
--segm_dir ../data/CUB_200_2011/dataset/test_full_seg/ \
--device cuda:0 \
--output saved_models/resnet50/birds/protopnet_birds_r50_inference_stats.csv \
--target_areas 0.001 0.02 0.001
To generate statistics for Stanford Cars, simply replace the path to the dataset.
Note that --segm_dir is not available for Stanford Cars.
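In the original ProtoPNet formulation, the squared L2 distance d between a latent patch and a prototype is turned into a similarity score log((d + 1) / (d + ε)). The prototype's activation on an image is the maximum of this score over all patches. The following is a NumPy sketch of that formula; the default value of `eps` is an assumption:

```python
import numpy as np

def protopnet_activation(patches, prototype, eps=1e-4):
    """Max over patches of log((d + 1) / (d + eps)), d being the squared L2 distance."""
    d = ((patches - prototype) ** 2).sum(axis=1)
    return float(np.log((d + 1.0) / (d + eps)).max())
```

The score is large for a patch that nearly matches the prototype (d close to 0) and decays towards 0 as d grows, so `eps` caps the maximum activation.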