
Code for the experiments described in the paper "Sanity checks for patch visualisation in prototype-based image classification" (XAI4CV 2023)


romain-xu-darme/prototype_sanity_checks


This repository is no longer maintained. For an updated version, please check out the CaBRNet library :)

Sanity checks for patch visualisation in prototype-based image classification

This repository contains the code developed for the experiments presented in the paper "Sanity checks for patch visualisation in prototype-based image classification", accepted at XAI4CV (2nd Explainable AI for Computer Vision Workshop at CVPR'23). It includes our modifications to the original code of ProtoPNet and ProtoTree.

Dependencies

All dependencies are listed in the file requirements.txt. Simply run:

python -m pip install -r requirements.txt

Setup

  1. Download and preprocess the CUB200 and Stanford Cars datasets, annotations and segmentation masks:
python prototree/preprocess_data/download_birds.py
python prototree/preprocess_data/cub.py
python prototree/preprocess_data/download_cars.py
python prototree/preprocess_data/cars.py
python protopnet/img_aug.py
  2. Download a ResNet50 model pretrained on the iNaturalist dataset:
python prototree/features/get_state_dict.py
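The setup scripts are meant to be run in the order listed above. A minimal Python sketch, assuming it is launched from the repository root, that runs each script with the same interpreter used for `pip install` and stops at the first failure. `SETUP_SCRIPTS` and `run_setup` are hypothetical names, not part of the repository:

```python
import subprocess
import sys

# Scripts in the order given in the Setup section (paths relative to the repo root).
SETUP_SCRIPTS = [
    "prototree/preprocess_data/download_birds.py",
    "prototree/preprocess_data/cub.py",
    "prototree/preprocess_data/download_cars.py",
    "prototree/preprocess_data/cars.py",
    "protopnet/img_aug.py",
    "prototree/features/get_state_dict.py",
]


def run_setup(scripts=SETUP_SCRIPTS, dry_run=False):
    """Run each setup script with the current interpreter, stopping at the first failure.

    Returns the list of commands that were (or, with dry_run=True, would be) executed.
    """
    commands = []
    for script in scripts:
        cmd = [sys.executable, script]
        commands.append(cmd)
        if not dry_run:
            # check=True raises CalledProcessError on a non-zero exit code,
            # so later steps never run on incomplete data.
            subprocess.run(cmd, check=True)
    return commands
```

Calling `run_setup(dry_run=True)` only returns the commands, which is a convenient way to check the order before downloading anything.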

Experiments on ProtoTree

On CUB200

  1. Train a ProtoTree as follows:
cd prototree
python main_tree.py \
	--num_features 256 --depth 9 --net resnet50_inat --init_mode pretrained --dataset CUB-200-2011 \
	--epochs 100 --lr 0.001 --lr_block 0.001 --lr_net 1e-5 --device cuda:0 \
	--freeze_epochs 10 --milestones 60,70,80,90,100 --batch_size 64 --random_seed 42 \
	--root_dir runs/prototree_cub  \
	--proj_dir proj_corners_sm --upsample_mode smoothgrads --upsample_threshold 0.3 --projection_mode corners

At the end of training, this command projects the prototypes using the augmented CUB dataset (images cropped to the four corners and the center) and visualises them using SmoothGrad.

Supported projection modes:

  • raw: Use the raw training dataset
  • corners: Training dataset augmented using 4 corners + center crop (CUB only)
  • cropped: Training dataset cropped to object bounding box (CUB only)
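The corners projection mode augments each training image with five crops: the four corners and the center. A minimal pure-Python sketch of the crop geometry, assuming a fixed crop size (the crop size actually used by the preprocessing scripts is not stated here). Boxes follow the (left, upper, right, lower) convention of PIL's `Image.crop`, and `torchvision.transforms.FiveCrop` performs the same operation directly on images. `five_crop_boxes` is a hypothetical helper:

```python
def five_crop_boxes(width, height, crop_w, crop_h):
    """Return (left, upper, right, lower) boxes for the four corners and the center.

    The center offset is rounded down, which may differ by one pixel from
    implementations that round to nearest.
    """
    if crop_w > width or crop_h > height:
        raise ValueError("crop size must not exceed image size")
    left_c = (width - crop_w) // 2
    top_c = (height - crop_h) // 2
    return {
        "top_left": (0, 0, crop_w, crop_h),
        "top_right": (width - crop_w, 0, width, crop_h),
        "bottom_left": (0, height - crop_h, crop_w, height),
        "bottom_right": (width - crop_w, height - crop_h, width, height),
        "center": (left_c, top_c, left_c + crop_w, top_c + crop_h),
    }
```

For a 500×375 image and a 400×300 crop, the center box is `(50, 37, 450, 337)`.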

Supported upsample modes:

  • vanilla: Original upsampling with cubic interpolation
  • smoothgrads: Use SmoothGrad × input
  • prp: Use Prototype Relevance Propagation (PRP)
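Whatever the upsample mode, the resulting pixel-level map is turned into a rectangular patch. The sketch below illustrates only that last step on a map assumed to be already at image resolution: min-max normalise, keep pixels at or above a cut-off, and take their bounding box. `high_activation_box` is hypothetical, and how `--upsample_threshold` maps onto the cut-off in this repository is an assumption, not something this README specifies:

```python
def normalise(act_map):
    """Min-max normalise a 2D list of activations to [0, 1]."""
    flat = [v for row in act_map for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        # Constant map: nothing stands out.
        return [[0.0 for _ in row] for row in act_map]
    return [[(v - lo) / span for v in row] for row in act_map]


def high_activation_box(act_map, cutoff):
    """Smallest box (y0, y1, x0, x1), end-exclusive, containing every pixel whose
    normalised value is >= cutoff. Returns None if no pixel qualifies."""
    norm = normalise(act_map)
    ys = [y for y, row in enumerate(norm) if max(row) >= cutoff]
    if not ys:
        return None
    xs = [x for x in range(len(norm[0])) if max(row[x] for row in norm) >= cutoff]
    return (ys[0], ys[-1] + 1, xs[0], xs[-1] + 1)
```

On `[[0, 0, 0, 0], [0, 5, 9, 0], [0, 4, 10, 0], [0, 0, 0, 1]]`, a cut-off of 0.45 gives the box `(1, 3, 1, 3)`, while 0.95 keeps only the peak, `(2, 3, 2, 3)`.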
  2. To perform inference and generate explanations on a test image, use:
python main_explain_local.py \
	--root_dir runs/prototree_cub/ --tree runs/prototree_cub/proj_raw_prp/model/ \
	--proj_dir proj_raw_prp/ --dataset CUB-200-2011 --device cuda:0 \
	--upsample_mode vanilla  \
	--sample_dir ../data/CUB_200_2011/dataset/test_full/054.Blue_Grosbeak/Blue_Grosbeak_0078_36655.jpg  \
	--results_dir explanations_vanilla
  3. To generate fidelity and relevance statistics on prototypes, use:
python get_prototype_stats.py \
	--tree_dir runs/prototree_cub/checkpoints/latest/ \
	--base_arch resnet50  \
	--dataset CUB-200-2011 --use-segmentation \
	--output prototree_birds_r50_proto_stats.csv \
	--target_areas 0.001 0.02 0.001 --random_seed 0 \
	--projection_mode corners   \
	--proj_dir runs/prototree_cub/proj_corners
  4. To generate fidelity and relevance statistics on test images, use:
python get_inference_stats.py \
	--tree_dir runs/prototree_cub/proj_corners_sm/model/ \
	--base_arch resnet50 \
	--dataset CUB-200-2011 --use-segmentation \
	--device cuda:0 \
	--output runs/prototree_cub/proj_corners/prototree_birds_r50_inference_stats.csv \
	--target_areas 0.001 0.02 0.001 --random_seed 0
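Both statistics scripts take `--target_areas 0.001 0.02 0.001`. A sketch under two assumptions this README does not confirm: the three values are a start, stop and step expressed as fractions of the image area (stop included), and for each target area the patch is made of the most important pixels of a visualisation map. `overlap_with_object` then measures which fraction of that patch lies on the object segmentation mask; it only illustrates a segmentation-based check and is not the paper's definition of fidelity or relevance. All function names are hypothetical:

```python
def target_area_range(start, stop, step):
    """Expand (start, stop, step) into a list of area fractions, stop included (assumption)."""
    n_steps = int(round((stop - start) / step))
    # Rounding removes floating-point noise such as 0.020000000000000004.
    return [round(start + i * step, 10) for i in range(n_steps + 1)]


def top_area_pixels(importance, area):
    """(y, x) coordinates of the most important pixels covering `area` of the map."""
    height, width = len(importance), len(importance[0])
    k = max(1, round(area * height * width))
    ranked = sorted(
        ((value, y, x) for y, row in enumerate(importance) for x, value in enumerate(row)),
        reverse=True,
    )
    return {(y, x) for _, y, x in ranked[:k]}


def overlap_with_object(importance, segmentation, area):
    """Fraction of the selected patch pixels that fall inside the binary segmentation mask."""
    patch = top_area_pixels(importance, area)
    inside = sum(1 for y, x in patch if segmentation[y][x])
    return inside / len(patch)
```

Under these assumptions, `target_area_range(0.001, 0.02, 0.001)` yields 20 areas from 0.001 to 0.02.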

On Stanford Cars

  1. Train a ProtoTree as follows:
cd prototree
python main_tree.py \
	--num_features 128 --depth 9 --net resnet50 --init_mode pretrained \
	--dataset CARS \
	--epochs 500 --lr 0.001 --lr_block 0.001 --lr_net 2e-4 --device cuda:0 \
	--freeze_epochs 30 --milestones 250,350,400,425,450,475,500 \
	--batch_size 64 --random_seed 42 \
	--root_dir runs/prototree_cars \
	--proj_dir proj_prp --upsample_mode prp \
	--upsample_threshold 0.3
  2. To generate explanations and statistics, replace --dataset CUB-200-2011 with --dataset CARS and update the run directories accordingly. Note that --use-segmentation is not available for Stanford Cars.

Restart a training sequence from a checkpoint

There are two ways to restart a training sequence.

  1. Relaunch main_tree.py with exactly the same options. If the --root_dir directory exists, training automatically restarts from the checkpoint located in <root_dir>/checkpoints/latest.
  2. Explicitly specify the checkpoint directory using the --tree_dir option.
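The two restart options amount to a simple resolution order. A minimal sketch, not the actual code of main_tree.py: `resolve_checkpoint` is a hypothetical helper that prefers an explicit --tree_dir, otherwise falls back to <root_dir>/checkpoints/latest when it exists, and returns None for a fresh training run:

```python
from pathlib import Path
from typing import Optional


def resolve_checkpoint(root_dir: str, tree_dir: Optional[str] = None) -> Optional[Path]:
    """Return the checkpoint directory to resume from, or None to start from scratch."""
    if tree_dir is not None:
        # Option 2: an explicitly given checkpoint directory must exist.
        path = Path(tree_dir)
        if not path.is_dir():
            raise FileNotFoundError(f"checkpoint directory not found: {path}")
        return path
    # Option 1: relaunching with the same --root_dir resumes from the latest checkpoint.
    latest = Path(root_dir) / "checkpoints" / "latest"
    return latest if latest.is_dir() else None
```

Making an explicit but missing --tree_dir an error, rather than silently starting over, avoids accidentally retraining from scratch after a typo in the path.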

Perform different projection methods on a pretrained ProtoTree

Different projection methods and projection datasets can be tested on a trained ProtoTree without retraining it. For example:

python finalize_tree.py --tree_dir ./runs/prototree/checkpoints/latest/ \
	--root_dir runs/prototree \
	--dataset CUB-200-2011 \
	--device cuda:0 \
	--proj_dir proj_raw_vanilla \
	--upsample_threshold 0.98 --upsample_mode vanilla --projection_mode raw 
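To compare several settings, the command above can be generated for every combination of projection and upsample modes. A sketch that only builds and prints the command lines, using the flags shown above; the per-mode thresholds are taken from the examples in this README (0.98 for vanilla, 0.3 for smoothgrads and prp), and the `proj_<projection>_<upsample>` directory naming is a hypothetical convention. Remember that corners and cropped are CUB only:

```python
import itertools
import shlex

# Thresholds copied from the examples in this README (assumed suitable defaults).
THRESHOLDS = {"vanilla": 0.98, "smoothgrads": 0.3, "prp": 0.3}


def finalize_commands(tree_dir, root_dir, dataset, device="cuda:0",
                      projection_modes=("raw", "corners", "cropped"),
                      upsample_modes=("vanilla", "smoothgrads", "prp")):
    """Build one finalize_tree.py argument list per (projection, upsample) pair."""
    commands = []
    for proj, up in itertools.product(projection_modes, upsample_modes):
        commands.append([
            "python", "finalize_tree.py",
            "--tree_dir", tree_dir,
            "--root_dir", root_dir,
            "--dataset", dataset,
            "--device", device,
            "--proj_dir", f"proj_{proj}_{up}",  # hypothetical naming scheme
            "--upsample_threshold", str(THRESHOLDS[up]),
            "--upsample_mode", up,
            "--projection_mode", proj,
        ])
    return commands


if __name__ == "__main__":
    for cmd in finalize_commands("./runs/prototree/checkpoints/latest/",
                                 "runs/prototree", "CUB-200-2011"):
        # Print only; pass each list to subprocess.run(cmd, check=True) to launch it.
        print(shlex.join(cmd))
```

The first generated command is identical to the example above. `shlex.join` requires Python 3.8 or later.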

Experiments on ProtoPNet

Training a model

To train a ProtoPNet, follow the instructions from the original code and edit the file protopnet/settings.py (see protopnet/README.txt).

Compute prototype/test image statistics

cd protopnet
python get_prototype_stats.py \
  --model saved_models/resnet50/birds/trained_model.pth \
  --base_arch resnet50 \
  --dataset ../data/CUB_200_2011/dataset/train_crop/ \
  --segm_dir ../data/CUB_200_2011/dataset/train_crop_seg/ \
  --proj_dir saved_models/resnet50/birds/proj \
  --device cuda:0 \
  --output protopnet_birds_r50_proto_stats.csv \
  --target_areas 0.001 0.02 0.001

python get_inference_stats.py \
  --model saved_models/resnet50/birds/trained_model.pth \
  --base_arch resnet50 \
  --dataset ../data/CUB_200_2011/dataset/test_full/ \
  --segm_dir ../data/CUB_200_2011/dataset/test_full_seg/ \
  --device cuda:0 \
  --output saved_models/resnet50/birds/protopnet_birds_r50_inference_stats.csv \
  --target_areas 0.001 0.02 0.001

To generate statistics for Stanford Cars, replace the dataset paths accordingly. Note that --segm_dir is not available for Stanford Cars.
