COLT: Cyclic Overlapping Lottery Tickets for Faster Pruning of Convolutional Neural Networks in Pytorch
This repository is for our IEEE Transactions on Artificial Intelligence (TAI) paper [accepted]:
COLT: Cyclic Overlapping Lottery Tickets for Faster Pruning of Convolutional Neural Networks by Md. Ismail Hossain, Mohammed Rakib, MM Lutfe Elahi, Nabeel Mohammed, and Shafin Rahman
COLT aims to generate winning lottery tickets from a set of lottery tickets that can achieve accuracy similar to that of the original unpruned network. We introduce a novel winning ticket called the Cyclic Overlapping Lottery Ticket (COLT), obtained by data splitting and cyclic retraining of the pruned network from scratch. We apply a cyclic pruning algorithm that keeps only the overlapping weights of different pruned models trained on different data segments. Our results demonstrate that COLT can achieve accuracies similar to those of the unpruned model while maintaining high sparsity. We show that the accuracy of COLT is on par with the winning tickets of the Lottery Ticket Hypothesis (LTH) and, at times, is better. Moreover, COLTs can be generated using fewer iterations than tickets generated by the popular Iterative Magnitude Pruning (IMP) method. In addition, we observe that COLTs generated on large datasets can be transferred to small ones without compromising performance, demonstrating their generalization capability.
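For intuition, here is a minimal sketch of the overlap step described above, assuming two copies of the same architecture that have already been trained on the two data partitions. The helper names (`magnitude_mask`, `colt_overlap_masks`), the toy models, and the pruning ratio are illustrative only, not the repository's API:

```python
import torch
import torch.nn as nn

def magnitude_mask(weight: torch.Tensor, prune_ratio: float) -> torch.Tensor:
    """Binary mask keeping the largest-magnitude (1 - prune_ratio) fraction of weights."""
    k = int(weight.numel() * prune_ratio)          # number of weights to drop
    if k == 0:
        return torch.ones_like(weight)
    threshold = weight.abs().flatten().kthvalue(k).values
    return (weight.abs() > threshold).float()

def colt_overlap_masks(model_a: nn.Module, model_b: nn.Module, prune_ratio: float) -> dict:
    """Keep only the weights that survive magnitude pruning in BOTH partition models."""
    masks = {}
    params_b = dict(model_b.named_parameters())
    for name, param_a in model_a.named_parameters():
        if "weight" not in name:                   # biases are left unpruned (cf. --bias 0)
            continue
        mask_a = magnitude_mask(param_a.data, prune_ratio)
        mask_b = magnitude_mask(params_b[name].data, prune_ratio)
        masks[name] = mask_a * mask_b              # overlap = elementwise AND of the two masks
    return masks

# Example: two toy "partition" models of the same architecture
model_a = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
model_b = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
masks = colt_overlap_masks(model_a, model_b, prune_ratio=0.15)
```

In the cyclic procedure, such an overlap mask would be applied to the network's original initialization, the masked network retrained from scratch on the partitions in turn, and the process repeated until the target sparsity is reached.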
This repository also includes the implementation for the following papers:
- The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks by Jonathan Frankle and Michael Carbin
- One ticket to win them all: generalizing lottery ticket initializations across datasets and optimizers by Ari S. Morcos, Haonan Yu, Michela Paganini, and Yuandong Tian
- python>=3.6
- Install the required libraries with `pip install -r requirements.txt`.
- Put all datasets in the `data` folder in the root directory.
- CIFAR-10, CIFAR-100, and FashionMNIST are downloaded automatically via `torchvision.datasets` (see the sketch after this list).
- For TinyImageNet, download the dataset from Kaggle and use `tinyimagenet.py` to process it.
- For ImageNet, follow the instructions here.
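For reference, the automatic downloads mentioned above use `torchvision.datasets`; a minimal sketch follows (the transform is a placeholder, not the repository's actual augmentation pipeline):

```python
from torchvision import datasets, transforms

to_tensor = transforms.ToTensor()  # placeholder transform; the repo's augmentations differ

# Each dataset is downloaded into ./data on first use and reused afterwards
cifar10  = datasets.CIFAR10("data", train=True, download=True, transform=to_tensor)
cifar100 = datasets.CIFAR100("data", train=True, download=True, transform=to_tensor)
fmnist   = datasets.FashionMNIST("data", train=True, download=True, transform=to_tensor)
```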
Example usage:

```bash
# COLT pruning of ResNet-18 on CIFAR-100
python3 main.py --prune_type=colt --arch_type=resnet18 --dataset=cifar100

# LTH pruning of ResNet-18 on the full CIFAR-100 dataset
python3 main-full.py --prune_type=lth --arch_type=resnet18 --dataset=cifar100 --output_class=100
```
| short | long | default | help |
|---|---|---|---|
| -h | --help | | show this help message and exit |
| | --resume | 0 | resume training or not |
| | --lr | 0.1 | Learning rate |
| | --warmup | 1 | 1 means to apply warmup to first epoch, 0 means no warmup for first epoch |
| | --batch_size | 256 | None |
| | --start_iter | 0 | start epoch |
| | --end_iter | 50 | end epoch |
| | --prune_type | colt | lth \| colt |
| | --augmentations | yes | yes \| no |
| | --prune_strategy | global | global \| local |
| | --bias | 0 | prune bias or not |
| | --dataset | cifar100 | mnist \| cifar10 \| fashionmnist \| cifar100 \| imagenet \| tinyimagenet |
| | --arch_type | conv3 | fc1 \| lenet5 \| alexnet \| conv3 \| resnet18 \| densenet121 \| mobilenetv2 \| shufflenetv2 |
| | --output_class | 50 | Output classes for the model |
| | --linear_prune_ratio | 0.0 | Linear layer pruning proportion |
| | --output_prune_ratio | 0.0 | Output layer pruning proportion |
| | --conv_prune_ratio | 0.15 | Conv layer pruning proportion |
| | --batchnorm_prune_ratio | 0.0 | Batchnorm layer pruning proportion |
| | --final_prune_rate | 99.1 | Final prune rate before pruning stops |
| | --patience | 50 | Patience level of epochs before ending training (based on validation loss) |
| | --gpu | 0 | None |
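To make the pruning-related flags concrete, here is a hedged sketch of an iterative magnitude-pruning loop in the spirit of `--conv_prune_ratio` and `--final_prune_rate`. The helper names and the toy model are illustrative, not the repository's implementation, and the retraining that happens between pruning iterations is omitted:

```python
import torch
import torch.nn as nn

def sparsity(model: nn.Module) -> float:
    """Percentage of zeroed weight entries across conv layers."""
    zeros, total = 0, 0
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            zeros += (m.weight == 0).sum().item()
            total += m.weight.numel()
    return 100.0 * zeros / max(total, 1)

def prune_step(model: nn.Module, conv_prune_ratio: float = 0.15):
    """Zero out the smallest-magnitude fraction of the remaining conv weights."""
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            w = m.weight.data
            alive = w[w != 0].abs()
            k = int(alive.numel() * conv_prune_ratio)
            if k == 0:
                continue
            threshold = alive.kthvalue(k).values
            w[w.abs() <= threshold] = 0.0

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 32, 3))
final_prune_rate = 99.1
while sparsity(model) < final_prune_rate:
    prune_step(model)   # prune 15% of the surviving conv weights per iteration
    # ...retrain the masked network here before the next pruning iteration...
print(f"final sparsity: {sparsity(model):.1f}%")
```

The loop only illustrates how the per-iteration ratio compounds toward the final prune rate; in the actual scripts, the surviving weights are rewound and retrained between iterations as described above.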
Example usage for transferring tickets across datasets:

```bash
python3 transfer.py --prune_type=colt --arch_type=resnet18 --dataset=cifar100 --part2=full
python3 transfer.py --prune_type=lth --arch_type=resnet18 --dataset=cifar100 --part1=full --part2=full
```
| short | long | default | help |
|---|---|---|---|
| -h | --help | | show this help message and exit |
| | --resume | 0 | resume training or not |
| | --augmentations | yes | yes \| no |
| | --arch_type | mobilenetv2 | fc1 \| lenet5 \| alexnet \| conv3 \| resnet18 \| densenet121 \| mobilenetv2 \| shufflenetv2 |
| | --dataset1 | tinyimagenet | mnist \| cifar10 \| fashionmnist \| cifar100 \| imagenet \| tinyimagenet - dataset on which the model is already trained |
| | --dataset2 | cifar10 | mnist \| cifar10 \| fashionmnist \| cifar100 \| imagenet \| tinyimagenet - dataset on which the model will be trained based on weights from dataset1 |
| | --prune_type | lth | lth \| colt |
| | --part1 | '' | A \| B \| '' - dataset partition whose weights will be transferred. '' for transferring COLT weights, as COLT weights are generated leveraging both the A and B partitions |
| | --part2 | full | A \| B \| full - dataset partition that will be trained on transferred weights from part1. 'full' means trained on the entire dataset and not on a partition (A/B) |
| | --bias | 0 | prune bias or not |
| | --output_class | 10 | Output classes for the model (based on dataset2) |
| | --start_iter | 0 | start epoch |
| | --end_iter | 50 | end epoch |
| | --batch_size | 256 | None |
| | --lr | 0.1 | Learning rate |
| | --warmup | 1 | Initial no. of epochs to apply learning rate warmup. 3 means apply warmup for first 3 epochs. 0 means no warmup |
| | --patience | 50 | Patience level of epochs before ending training (based on validation loss) |
| | --gpu | 0 | None |
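As a rough mental model of the transfer setting (not the actual `transfer.py` code; the model choice and class counts below are illustrative), transferring a ticket amounts to reusing the sparse backbone weights from the dataset1 model while re-sizing the output layer for dataset2 and keeping the pruned weights at zero during retraining:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

# Illustrative stand-in: a model already pruned and trained on dataset1 (e.g. TinyImageNet, 200 classes)
source_model = resnet18(num_classes=200)

# Start the dataset2 model from the dataset1 weights
target_model = resnet18(num_classes=200)
target_model.load_state_dict(source_model.state_dict())

# Swap in a fresh output layer sized for dataset2 (e.g. CIFAR-10), since the class counts differ
num_classes_dataset2 = 10
target_model.fc = nn.Linear(target_model.fc.in_features, num_classes_dataset2)

# Record which backbone weights are zero so retraining on dataset2 can keep them pruned
masks = {name: (param.data != 0).float()
         for name, param in target_model.named_parameters()
         if not name.startswith("fc")}
```

With `--part2=full`, the transferred ticket would then be retrained on the whole of dataset2 rather than on one of its partitions.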
We would like to thank rahulvigneswaran for his repository on the Lottery Ticket Hypothesis.
If you have any questions, please reach out to Md. Ismail Hossain ([email protected])
If our work aids your research, please cite our article:
```bibtex
@ARTICLE{10855806,
  author={Hossain, Md. Ismail and Rakib, Mohammed and Elahi, M. M. Lutfe and Mohammed, Nabeel and Rahman, Shafin},
  journal={IEEE Transactions on Artificial Intelligence},
  title={COLT: Cyclic Overlapping Lottery Tickets for Faster Pruning of Convolutional Neural Networks},
  year={2025},
  volume={},
  number={},
  pages={1-15},
  keywords={Accuracy;Training;Computational modeling;Artificial intelligence;Neural networks;Iterative algorithms;Convolutional neural networks;Convergence;Partitioning algorithms;Object detection;Model Compression;Pruning;Sparse Networks},
  doi={10.1109/TAI.2025.3534745}}
```