Forked from the PyTorch implementation of the paper Score-Based Generative Modeling through Stochastic Differential Equations by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole.
- Create conda environment:
conda env create -f environment.lock.yml
- Clone and install https://github.com/henryaddison/mlde_utils into the environment: e.g.
pip install -e ../mlde_utils
- Install ml_downscaling_emulator locally:
pip install -e .
- Install unet code:
git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet
- Configure necessary environment variables: DERIVED_DATA and ...
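For example (the path is a placeholder):
export DERIVED_DATA=/path/to/derived/data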
- Run a local smoke test of training: bin/local-test-train
Train models through main.py.
main.py:
  --config: Training configuration.
    (default: 'None')
  --mode: <train|eval>: Running mode: train or eval
  --workdir: Working directory
- config is the path to the config file. Our prescribed config files are provided in configs/. They are formatted according to ml_collections and should be quite self-explanatory.
  Naming conventions of config files: the path of a config file is a combination of the following dimensions:
  - dataset: One of cifar10, celeba, celebahq, celebahq_256, ffhq_256, ffhq.
  - model: One of ncsn, ncsnv2, ncsnpp, ddpm, ddpmpp.
  - continuous: train the model with continuously sampled time steps.
- workdir is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.
- mode is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming a run after pre-emption in a cloud environment) exist in workdir/checkpoints-meta.
These functionalities can be configured through config files, or more conveniently, through the command-line support of the ml_collections package. For example, to generate samples and evaluate sample quality, supply the --config.eval.enable_sampling flag; to compute log-likelihoods, supply the --config.eval.enable_bpd flag, and specify --config.eval.dataset=train/test to indicate whether to compute the likelihoods on the training or test dataset.
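For example, a training run might be launched with a command along the following lines (the config and workdir paths are placeholders; the exact invocation may differ in this fork):
python main.py --config <path to config file> --mode train --workdir <path to workdir>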
TODO
- New SDEs: inherit the sde_lib.SDE abstract class and implement all abstract methods. The discretize() method is optional and the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for this new SDE.
- New predictors: inherit the sampling.Predictor abstract class, implement the update_fn abstract method, and register its name with @register_predictor. The new predictor can be directly used in sampling.get_pc_sampler for Predictor-Corrector sampling, and all other controllable generation methods in controllable_generation.py (see the sketch after this list).
- New correctors: inherit the sampling.Corrector abstract class, implement the update_fn abstract method, and register its name with @register_corrector. The new corrector can be directly used in sampling.get_pc_sampler, and all other controllable generation methods in controllable_generation.py.
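As a concrete illustration of the predictor extension point, here is a minimal sketch of a custom predictor. It assumes the upstream score_sde_pytorch layout (a sampling module exposing Predictor and @register_predictor); the import path, the class name, and the registered name my_euler_maruyama are illustrative only, and the update simply mirrors the built-in Euler-Maruyama step. New correctors (and new SDEs) follow the same subclass-and-implement pattern.

```python
import numpy as np
import torch

import sampling  # adjust the import path to wherever this fork keeps the sampling module


@sampling.register_predictor(name='my_euler_maruyama')  # hypothetical name
class MyEulerMaruyamaPredictor(sampling.Predictor):
  """Toy predictor reproducing an Euler-Maruyama step of the reverse-time SDE."""

  def update_fn(self, x, t):
    dt = -1. / self.rsde.N                  # negative step: integrate backwards in time
    z = torch.randn_like(x)
    drift, diffusion = self.rsde.sde(x, t)  # reverse-time drift and diffusion
    x_mean = x + drift * dt                 # deterministic part of the update
    x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z  # add noise per sample
    return x, x_mean                        # update_fn must return (x, x_mean)
```

In the upstream codebase a predictor registered this way can then be selected by its registered name through the sampling configuration (e.g. config.sampling.predictor).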
- When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via config.training.n_jitted_steps. For CIFAR-10, we recommend using config.training.n_jitted_steps=5 when your GPU/TPU has sufficient memory; otherwise we recommend using config.training.n_jitted_steps=1. Our current implementation requires config.training.log_freq to be divisible by n_jitted_steps for logging and checkpointing to work normally.
- The snr (signal-to-noise ratio) parameter of LangevinCorrector behaves somewhat like a temperature parameter. Larger snr typically results in smoother samples, while smaller snr gives more diverse but lower-quality samples. Typical values of snr are 0.05 - 0.2, and it requires tuning to strike the sweet spot.
- For VE SDEs, we recommend choosing config.model.sigma_max to be the maximum pairwise distance between data samples in the training dataset (a sketch for estimating this follows the list).
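The following is a rough, standalone sketch (not part of this repository) of how one might estimate that maximum pairwise distance; for large datasets, computing it on a random subset keeps the distance matrix manageable.

```python
import torch


def max_pairwise_distance(samples: torch.Tensor) -> float:
    """Maximum Euclidean distance between any two samples.

    samples: tensor of shape (N, ...); each sample is flattened to a vector.
    """
    flat = samples.reshape(samples.shape[0], -1).float()
    dists = torch.cdist(flat, flat)  # (N, N) pairwise Euclidean distances
    return dists.max().item()        # candidate value for config.model.sigma_max
```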
This code is based on the following work:
@inproceedings{
song2021scorebased,
title={Score-Based Generative Modeling through Stochastic Differential Equations},
author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=PxTIG12RRHS}
}
This work is built upon some previous papers which might also interest you:
- Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." Proceedings of the 33rd Annual Conference on Neural Information Processing Systems. 2019.
- Song, Yang, and Stefano Ermon. "Improved techniques for training score-based generative models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.
- Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Proceedings of the 34th Annual Conference on Neural Information Processing Systems. 2020.