This repo contains training and evaluation code for distilling the SDXL VAE into two CNNs: an encoder and a decoder.
To evaluate the trained checkpoints, run main.py:
python3 main.py
main.py already contains the paths to the trained model checkpoints provided with the repo. The code lets you run different versions of the pipeline: only the encoder, only the decoder (with the other half replaced by the corresponding portion of the VAE), or both. These options are set through the arguments of eval_AE_dist:
eval_AE_dist(encoder_path, decoder_path, use_encoder=True, use_decoder=True, cifar=False, l=0)
eval_AE_dist("./checkpoints/run_dist_cifar_norm/model_best.pt", "./checkpoints/run_dist_dec_cifar_norm/model_best.pt", use_encoder=True, use_decoder=True, cifar=False, l=0)
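As a minimal sketch of the pipeline variants described above, the flag combinations can be kept in one place and expanded into keyword arguments. The names MODES and mode_kwargs are hypothetical and not part of this repo; only the parameter names use_encoder, use_decoder, cifar and l come from the eval_AE_dist signature.

```python
# Hypothetical helper, not part of the repo: maps a readable mode name
# to the use_encoder / use_decoder flags accepted by eval_AE_dist.
MODES = {
    "both": {"use_encoder": True, "use_decoder": True},
    "encoder_only": {"use_encoder": True, "use_decoder": False},
    "decoder_only": {"use_encoder": False, "use_decoder": True},
}


def mode_kwargs(mode, cifar=False, l=0):
    """Return keyword arguments for eval_AE_dist for the given mode name."""
    if mode not in MODES:
        raise ValueError(f"unknown mode: {mode!r}")
    # Copy the flags so the shared MODES table is never mutated.
    return {**MODES[mode], "cifar": cifar, "l": l}
```

With this in place, a call could look like eval_AE_dist(encoder_path, decoder_path, **mode_kwargs("encoder_only")), assuming eval_AE_dist accepts these as keyword arguments as its signature suggests.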
When using checkpoints from ./checkpoints/run_model_sizes/bestEncoder_l=*.pt, you must set 'l' to the value that appears in the checkpoint name:
eval_AE_dist("./checkpoints/run_model_sizes/bestEncoder_l=2.pt", "./checkpoints/run_dist_dec_cifar_norm/model_best.pt", use_encoder=True, use_decoder=True, cifar=False, l=2)
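To avoid a mismatch between the checkpoint file and the 'l' argument, the value can be read from the file name instead of typed by hand. The following is a small sketch; l_from_checkpoint is a hypothetical helper that is not part of this repo, and it assumes 'l' is a non-negative integer written as "_l=<number>.pt" at the end of the file name.

```python
import re
from pathlib import Path


def l_from_checkpoint(path):
    """Extract the integer after '_l=' from a checkpoint file name.

    Hypothetical helper; assumes names like 'bestEncoder_l=2.pt'.
    """
    name = Path(path).name
    match = re.search(r"_l=(\d+)\.pt$", name)
    if match is None:
        raise ValueError(f"no 'l' value found in checkpoint name: {name!r}")
    return int(match.group(1))
```

The result can then be passed on directly, for example l=l_from_checkpoint(encoder_path) in the call to eval_AE_dist.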
To test on a different dataset, set "cifar=True" to use CIFAR-10 or "cifar=False" to stream the CC12M dataset.