🔥 [NeurIPS 2024] Flow Priors for Linear Inverse Problems via Iterative Corrupted Trajectory Matching
This repository hosts the code and resources for our paper on using flow priors to solve linear inverse problems.
Generative models based on flow matching have attracted significant attention for their simplicity and superior performance in high-resolution image synthesis. By leveraging the instantaneous change-of-variables formula, one can directly compute image likelihoods from a learned flow, making them enticing candidates as priors for downstream tasks such as inverse problems. In particular, a natural approach would be to incorporate such image probabilities in a maximum-a-posteriori (MAP) estimation problem. A major obstacle, however, lies in the slow computation of the log-likelihood, as it requires backpropagating through an ODE solver, which can be prohibitively slow for high-dimensional problems. In this work, we propose an iterative algorithm to approximate the MAP estimator efficiently to solve a variety of linear inverse problems. Our algorithm is mathematically justified by the observation that the MAP objective can be approximated by a sum of
Clone this repository and create a conda environment:
git clone git@github.com:YasminZhang/ICTM.git
conda create -n ictm python=3.9 -y
conda activate ictm
Install the following packages:
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install tensorflow==2.9.0 tensorflow-probability==0.12.2 tensorflow-gan==2.0.0 tensorflow-datasets==4.6.0
pip install jax==0.3.4 jaxlib==0.3.2
pip install numpy==1.21.6 ninja==1.11.1 matplotlib==3.7.0 ml_collections==0.1.1
pip install tensorflow-io==0.26.0 # https://stackoverflow.com/questions/65623468/unable-to-open-file-libtensorflow-io-so-caused-by-undefined-symbol
If the jax or jaxlib installation fails, please use:
pip install jax==VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install jaxlib==VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html
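Because the packages above are pinned to exact versions, it can help to confirm what actually got installed before running anything. The following is a minimal sketch, not part of the repository: `PINS` and `check_pins` are hypothetical names, and the pins are copied from the pip commands above (only a subset is listed).

```python
from importlib import metadata

# Pins copied from the pip commands above; extend or adjust as needed.
PINS = {
    "torch": "1.11.0+cu113",
    "torchvision": "0.12.0+cu113",
    "tensorflow": "2.9.0",
    "jax": "0.3.4",
    "jaxlib": "0.3.2",
    "numpy": "1.21.6",
}


def check_pins(pins):
    """Return {package: (expected, installed)} for every package that does not match.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    # Print one line per package whose installed version differs from the pin.
    for name, (expected, installed) in check_pins(PINS).items():
        print(f"{name}: expected {expected}, found {installed}")
```

An empty output means every listed package matches its pin.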
- flow-checkpoint of CelebAHQ images: please check out the Rectified Flow repo or directly use the link here
- flow-checkpoint of MRI images: checkpoint
- CelebAHQ: we randomly select 100 images from the test set of the CelebAHQ dataset
- MRI: we randomly select 200 images from the HCP T2w test dataset (link)
To run our method, please use the following command:
CUDA_VISIBLE_DEVICES=1 python main.py --config ./configs/rectified_flow/celeba_hq_pytorch_rf_gaussian_inverse.py --eval_folder <eval_folder> --mode eval_inverse --workdir ./logs/celebahq_ckpt --config.eval.method ours --config.eval.task super_resolution --config.sampling.sample_N 100 --config.eval.eta 1.0e-02 --config.eval.k 1 --config.eval.lamda 1.0e+04
- task: super_resolution, inpainting_box, gaussian, inpainting, cs (compressed sensing)
- workdir: where you save the checkpoints
- config.eval.lamda: guidance weight
- config.eval.eta: step size
- config.eval.k: iteration number, default = 1
- config.sampling.sample_N: number of sampling steps, default = 100
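To evaluate several tasks in a row, the command above can be assembled programmatically. The sketch below only builds and prints the commands. `build_command` and the `eval_<task>` folder names are hypothetical, and the default values are the ones from the example command; whether the same `eta` and `lamda` suit every task is an assumption you should check against the paper.

```python
import shlex

# Task names as listed above.
TASKS = ["super_resolution", "inpainting_box", "gaussian", "inpainting", "cs"]


def build_command(task, eval_folder, workdir="./logs/celebahq_ckpt",
                  sample_n=100, eta=1.0e-02, k=1, lamda=1.0e+04):
    """Assemble the argument list for one evaluation run (hypothetical helper)."""
    if task not in TASKS:
        raise ValueError(f"unknown task: {task!r}")
    return [
        "python", "main.py",
        "--config", "./configs/rectified_flow/celeba_hq_pytorch_rf_gaussian_inverse.py",
        "--eval_folder", eval_folder,
        "--mode", "eval_inverse",
        "--workdir", workdir,
        "--config.eval.method", "ours",
        "--config.eval.task", task,
        "--config.sampling.sample_N", str(sample_n),
        "--config.eval.eta", str(eta),
        "--config.eval.k", str(k),
        "--config.eval.lamda", str(lamda),
    ]


if __name__ == "__main__":
    for task in TASKS:
        # Print each command for review; to execute one, pass the list to
        # subprocess.run(cmd, check=True).
        print(shlex.join(build_command(task, eval_folder=f"eval_{task}")))
```

Building the arguments as a list rather than a single string avoids shell-quoting problems if a path contains spaces.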
We mainly use the following metrics to evaluate the generated images:
- PSNR
- SSIM
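For reference, PSNR is commonly defined as $10 \log_{10}(\mathrm{MAX}^2 / \mathrm{MSE})$, where MAX is the largest possible pixel value. The sketch below is a plain NumPy version of that textbook definition, assuming 8-bit images (`data_range=255`). It is not taken from `get_metric.py`, which may use a different data range or color handling.

```python
import numpy as np


def psnr(recon, label, data_range=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images."""
    # Cast to float first so uint8 inputs do not wrap around on subtraction.
    recon = np.asarray(recon, dtype=np.float64)
    label = np.asarray(label, dtype=np.float64)
    if recon.shape != label.shape:
        raise ValueError("images must have the same shape")
    mse = np.mean((recon - label) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range ** 2 / mse))


if __name__ == "__main__":
    label = np.zeros((4, 4, 3), dtype=np.uint8)
    recon = label + 1  # every pixel off by one, so MSE = 1
    print(f"{psnr(recon, label):.2f} dB")
```

Higher values indicate a reconstruction closer to the ground truth.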
Please make sure that your eval_dir folder structure is as follows:
-recon
--000001.png
...
-label
--000001.png
...
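Before computing metrics, a quick check that the two folders line up can save a failed run. This is a minimal sketch under the assumption that each reconstruction in `recon` is paired with the ground-truth image of the same file name in `label`, as the layout above suggests. `find_unpaired` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path


def find_unpaired(eval_dir):
    """Return (recon_only, label_only): sorted PNG names found in only one folder."""
    root = Path(eval_dir)
    for sub in ("recon", "label"):
        if not (root / sub).is_dir():
            raise FileNotFoundError(f"missing folder: {root / sub}")
    recon = {p.name for p in (root / "recon").glob("*.png")}
    label = {p.name for p in (root / "label").glob("*.png")}
    return sorted(recon - label), sorted(label - recon)
```

Both returned lists are empty when every image in one folder has a same-named counterpart in the other.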
To get the PSNR and SSIM scores (and LPIPS/FID), run the following command:
python get_metric.py --eval_dir=<path/to/evaldir> --enable_fid=<True/False> --gpu=<gpu_id>
- For Wavelet and TV norms, please check out the DeepInverse package
- For RedDiff and $\Pi$GDM, please check out the RED-Diff repo
- 🔥 For new baselines in flow matching, we refer to this great repo PnP-flow
If you find the code or our results useful, please cite as:
@article{zhang2024flow,
title={Flow Priors for Linear Inverse Problems via Iterative Corrupted Trajectory Matching},
author={Zhang, Yasi and Yu, Peiyu and Zhu, Yaxuan and Chang, Yingshan and Gao, Feng and Wu, Ying Nian and Leong, Oscar},
journal={arXiv preprint arXiv:2405.18816},
year={2024}
}