forked from Project-MONAI/tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add tutorial for 3d ldm on brats (Project-MONAI#1301)
Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` --------- Signed-off-by: Can-Zhao <[email protected]> Signed-off-by: Can Zhao <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
e5adbd0
commit e49afe3
Showing
17 changed files
with
1,163 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# 3D Latent Diffusion Example | ||
This folder contains an example for training and validating a 3D Latent Diffusion Model on Brats data. The example includes support for multi-GPU training with distributed data parallelism based on a [tutorial designed for using single GPU](https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb). | ||
|
||
The workflow of the Latent Diffusion Model is depicted in the figure below. It begins by training an autoencoder in pixel space to encode images into latent features. Following that, it trains a diffusion model in the latent space to denoise the noisy latent features. During inference, it first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Finally, it decodes the denoised latent features into images using the trained autoencoder. | ||
<p align="center"> | ||
<img src="./figs/ldm.png" alt="latent diffusion scheme") | ||
</p> | ||
|
||
MONAI latent diffusion model implementation is based on the following papers: | ||
|
||
[**Latent Diffusion:** Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf) | ||
|
||
This network is designed as a demonstration to showcase the training process for this type of network using MONAI. To achieve optimal performance, it is recommended that users have a GPU with memory larger than 32G to accommodate larger networks and attention layers. | ||
|
||
### 1. Data | ||
|
||
The dataset we are experimenting with in this example is BraTS 2016 and 2017 data. | ||
|
||
BraTS is a public dataset of brain MR images. Using these images, the goal is to generate images that look similar to the images in BraTS 2016 and 2017 dataset. | ||
|
||
The data can be automatically downloaded from [Medical Decathlon](http://medicaldecathlon.com/) at the beginning of training. | ||
|
||
Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset! We acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study. | ||
|
||
### 2. Installation | ||
``` | ||
pip install lpips | ||
pip install git+https://github.com/Project-MONAI/GenerativeModels.git#egg=Generative | ||
``` | ||
|
||
Or install it from source: | ||
``` | ||
pip install lpips | ||
git clone https://github.com/Project-MONAI/GenerativeModels.git | ||
cd GenerativeModels/ | ||
python setup.py install | ||
cd .. | ||
``` | ||
### 3. Run the example | ||
|
||
#### [3.1 3D Autoencoder Training](./train_autoencoder.py) | ||
|
||
The network configuration files are located in [./config/config_train_32g.json](./config/config_train_32g.json) for 32G GPU | ||
and [./config/config_train_16g.json](./config/config_train_16g.json) for 16G GPU. | ||
You can modify the hyperparameters in these files to suit your requirements. | ||
|
||
The training script resamples the brain images based on the voxel spacing specified in the `"spacing"` field of the configuration files. For instance, `"spacing": [1.1, 1.1, 1.1]` resamples the images to a resolution of 1.1x1.1x1.1 mm. If you have a GPU with larger memory, you may consider changing the `"spacing"` field to `"spacing": [1.0, 1.0, 1.0]`. | ||
|
||
The training script uses the batch size and patch size defined in the configuration files. If you have a different GPU memory size, you should adjust the `"batch_size"` and `"patch_size"` parameters in the `"autoencoder_train"` to match your GPU. Note that the `"patch_size"` needs to be divisible by 4. | ||
|
||
Before you start training, please set the path in [./config/environment.json](./config/environment.json). | ||
|
||
- `"model_dir"`: where it saves the trained models | ||
- `"tfevent_path"`: where it saves the tensorboard events | ||
- `"output_dir"`: where you store the generated images during inference. | ||
- `"resume_ckpt"`: whether to resume training from existing checkpoints. | ||
- `"data_base_dir"`: where you store the Brats dataset. | ||
|
||
If the Brats dataset is not downloaded, please add `--download_data` in training command, the Brats data will be downloaded from [Medical Decathlon](http://medicaldecathlon.com/) and extracted to `$data_base_dir`. You will see a subfolder `Task01_BrainTumour` under `$data_base_dir`. By default, you will see `./Task01_BrainTumour` | ||
For example, this command is for running the training script with one 32G gpu. | ||
```bash | ||
python train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1 --download_data | ||
``` | ||
If `$data_base_dir/Task01_BrainTumour` already exists, you may omit the downloading. | ||
```bash | ||
python train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1 | ||
``` | ||
|
||
The training script also enables multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command: | ||
```bash | ||
export NUM_GPUS_PER_NODE=8 | ||
torchrun \ | ||
--nproc_per_node=${NUM_GPUS_PER_NODE} \ | ||
--nnodes=1 \ | ||
--master_addr=localhost --master_port=1234 \ | ||
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE} | ||
``` | ||
|
||
<p align="center"> | ||
<img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" > | ||
| ||
<img src="./figs/val_recon.png" alt="autoencoder validation curve" width="45%" > | ||
</p> | ||
|
||
With eight DGX1V 32G GPUs, it took around 55 hours to train 1000 epochs. | ||
|
||
#### [3.2 3D Latent Diffusion Training](./train_diffusion.py) | ||
The training script uses the batch size and patch size defined in the configuration files. If you have a different GPU memory size, you should adjust the `"batch_size"` and `"patch_size"` parameters in the `"diffusion_train"` to match your GPU. Note that the `"patch_size"` needs to be divisible by 16. | ||
|
||
To train with single 32G GPU, please run: | ||
```bash | ||
python train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1 | ||
``` | ||
|
||
The training script also enables multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command: | ||
```bash | ||
export NUM_GPUS_PER_NODE=8 | ||
torchrun \ | ||
--nproc_per_node=${NUM_GPUS_PER_NODE} \ | ||
--nnodes=1 \ | ||
--master_addr=localhost --master_port=1234 \ | ||
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE} | ||
``` | ||
<p align="center"> | ||
<img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" > | ||
| ||
<img src="./figs/val_diffusion.png" alt="latent diffusion validation curve" width="45%" > | ||
</p> | ||
|
||
#### [3.3 Inference](./inference.py) | ||
To generate one image during inference, please run the following command: | ||
```bash | ||
python inference.py -c ./config/config_train_32g.json -e ./config/environment.json --num 1 | ||
``` | ||
`--num` defines how many images it would generate. | ||
|
||
An example output is shown below. | ||
<p align="center"> | ||
<img src="./figs/syn_axial.png" width="30%" > | ||
| ||
<img src="./figs/syn_sag.png" width="30%" > | ||
| ||
<img src="./figs/syn_cor.png" width="30%" > | ||
</p> | ||
|
||
### 4. Questions and bugs | ||
|
||
- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI. | ||
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues). | ||
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues). | ||
|
||
### Reference | ||
[1] [Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf) | ||
|
||
[2] [Menze, Bjoern H., et al. "The multimodal brain tumor image segmentation benchmark (BRATS)." IEEE transactions on medical imaging 34.10 (2014): 1993-2024.](https://ieeexplore.ieee.org/document/6975210) | ||
|
||
[3] [Pinaya et al. "Brain imaging generation with latent diffusion models"](https://arxiv.org/abs/2209.07162) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
{ | ||
"channel": 0, | ||
"spacing": [1.1, 1.1, 1.1], | ||
"spatial_dims": 3, | ||
"image_channels": 1, | ||
"latent_channels": 8, | ||
"autoencoder_def": { | ||
"_target_": "generative.networks.nets.AutoencoderKL", | ||
"spatial_dims": "@spatial_dims", | ||
"in_channels": "$@image_channels", | ||
"out_channels": "@image_channels", | ||
"latent_channels": "@latent_channels", | ||
"num_channels": [ | ||
64, | ||
128, | ||
256 | ||
], | ||
"num_res_blocks": 2, | ||
"norm_num_groups": 32, | ||
"norm_eps": 1e-06, | ||
"attention_levels": [ | ||
false, | ||
false, | ||
false | ||
], | ||
"with_encoder_nonlocal_attn": false, | ||
"with_decoder_nonlocal_attn": false | ||
}, | ||
"autoencoder_train": { | ||
"batch_size": 1, | ||
"patch_size": [112,128,80], | ||
"lr": 5e-6, | ||
"perceptual_weight": 0.001, | ||
"kl_weight": 1e-7, | ||
"recon_loss": "l1", | ||
"n_epochs": 1000, | ||
"val_interval": 10 | ||
}, | ||
"diffusion_def": { | ||
"_target_": "generative.networks.nets.DiffusionModelUNet", | ||
"spatial_dims": "@spatial_dims", | ||
"in_channels": "@latent_channels", | ||
"out_channels": "@latent_channels", | ||
"num_channels":[256, 256, 512], | ||
"attention_levels":[false, true, true], | ||
"num_head_channels":[0, 64, 64], | ||
"num_res_blocks": 2 | ||
}, | ||
"diffusion_train": { | ||
"batch_size": 2, | ||
"patch_size": [144,176,112], | ||
"lr": 5e-6, | ||
"n_epochs": 10000, | ||
"val_interval": 2 | ||
}, | ||
"NoiseScheduler": { | ||
"num_train_timesteps": 1000, | ||
"beta_start": 0.0015, | ||
"beta_end": 0.0195 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
{ | ||
"channel": 0, | ||
"spacing": [1.1, 1.1, 1.1], | ||
"spatial_dims": 3, | ||
"image_channels": 1, | ||
"latent_channels": 8, | ||
"autoencoder_def": { | ||
"_target_": "generative.networks.nets.AutoencoderKL", | ||
"spatial_dims": "@spatial_dims", | ||
"in_channels": "$@image_channels", | ||
"out_channels": "@image_channels", | ||
"latent_channels": "@latent_channels", | ||
"num_channels": [ | ||
64, | ||
128, | ||
256 | ||
], | ||
"num_res_blocks": 2, | ||
"norm_num_groups": 32, | ||
"norm_eps": 1e-06, | ||
"attention_levels": [ | ||
false, | ||
false, | ||
false | ||
], | ||
"with_encoder_nonlocal_attn": false, | ||
"with_decoder_nonlocal_attn": false | ||
}, | ||
"autoencoder_train": { | ||
"batch_size": 2, | ||
"patch_size": [112,128,80], | ||
"lr": 1e-5, | ||
"perceptual_weight": 0.001, | ||
"kl_weight": 1e-7, | ||
"recon_loss": "l1", | ||
"n_epochs": 1000, | ||
"val_interval": 10 | ||
}, | ||
"diffusion_def": { | ||
"_target_": "generative.networks.nets.DiffusionModelUNet", | ||
"spatial_dims": "@spatial_dims", | ||
"in_channels": "@latent_channels", | ||
"out_channels": "@latent_channels", | ||
"num_channels":[256, 256, 512], | ||
"attention_levels":[false, true, true], | ||
"num_head_channels":[0, 64, 64], | ||
"num_res_blocks": 2 | ||
}, | ||
"diffusion_train": { | ||
"batch_size": 3, | ||
"patch_size": [144,176,112], | ||
"lr": 1e-5, | ||
"n_epochs": 10000, | ||
"val_interval": 2 | ||
}, | ||
"NoiseScheduler": { | ||
"num_train_timesteps": 1000, | ||
"beta_start": 0.0015, | ||
"beta_end": 0.0195 | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{ | ||
"data_base_dir": "./dataset", | ||
"model_dir": "./trained_weights/diffusion_3d", | ||
"tfevent_path": "./tfevent/diffusion_3d", | ||
"output_dir": "./output", | ||
"resume_ckpt": false | ||
} |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import argparse | ||
import json | ||
import logging | ||
import os | ||
import sys | ||
from datetime import datetime | ||
from pathlib import Path | ||
|
||
import nibabel as nib | ||
import numpy as np | ||
import torch | ||
from generative.inferers import LatentDiffusionInferer | ||
from generative.networks.schedulers import DDPMScheduler | ||
from monai.config import print_config | ||
from monai.utils import set_determinism | ||
|
||
from utils import define_instance | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description="PyTorch Latent Diffusion Model Inference") | ||
parser.add_argument( | ||
"-e", | ||
"--environment-file", | ||
default="./config/environment.json", | ||
help="environment json file that stores environment path", | ||
) | ||
parser.add_argument( | ||
"-c", | ||
"--config-file", | ||
default="./config/config_train_48g.json", | ||
help="config json file that stores hyper-parameters", | ||
) | ||
parser.add_argument( | ||
"-n", | ||
"--num", | ||
type=int, | ||
default=1, | ||
help="number of generated images", | ||
) | ||
args = parser.parse_args() | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
|
||
print_config() | ||
torch.backends.cudnn.benchmark = True | ||
torch.set_num_threads(4) | ||
|
||
env_dict = json.load(open(args.environment_file, "r")) | ||
config_dict = json.load(open(args.config_file, "r")) | ||
|
||
for k, v in env_dict.items(): | ||
setattr(args, k, v) | ||
for k, v in config_dict.items(): | ||
setattr(args, k, v) | ||
|
||
set_determinism(42) | ||
|
||
# load trained networks | ||
autoencoder = define_instance(args, "autoencoder_def").to(device) | ||
trained_g_path = os.path.join(args.model_dir, "autoencoder.pt") | ||
autoencoder.load_state_dict(torch.load(trained_g_path)) | ||
|
||
diffusion_model = define_instance(args, "diffusion_def").to(device) | ||
trained_diffusion_path = os.path.join(args.model_dir, "diffusion_unet.pt") | ||
diffusion_model.load_state_dict(torch.load(trained_diffusion_path)) | ||
|
||
scheduler = DDPMScheduler( | ||
num_train_timesteps=args.NoiseScheduler["num_train_timesteps"], | ||
beta_schedule="scaled_linear", | ||
beta_start=args.NoiseScheduler["beta_start"], | ||
beta_end=args.NoiseScheduler["beta_end"], | ||
) | ||
inferer = LatentDiffusionInferer(scheduler, scale_factor=1.0) | ||
|
||
Path(args.output_dir).mkdir(parents=True, exist_ok=True) | ||
latent_shape = [p // 4 for p in args.diffusion_train["patch_size"]] | ||
noise_shape = [1, args.latent_channels] + latent_shape | ||
|
||
for _ in range(args.num): | ||
noise = torch.randn(noise_shape, dtype=torch.float32).to(device) | ||
with torch.no_grad(): | ||
synthetic_images = inferer.sample( | ||
input_noise=noise, | ||
autoencoder_model=autoencoder, | ||
diffusion_model=diffusion_model, | ||
scheduler=scheduler, | ||
) | ||
filename = os.path.join(args.output_dir, datetime.now().strftime("synimg_%Y%m%d_%H%M%S")) | ||
final_img = nib.Nifti1Image(synthetic_images[0, 0, ...].unsqueeze(-1).cpu().numpy(), np.eye(4)) | ||
nib.save(final_img, filename) | ||
|
||
|
||
if __name__ == "__main__": | ||
logging.basicConfig( | ||
stream=sys.stdout, | ||
level=logging.INFO, | ||
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s", | ||
datefmt="%Y-%m-%d %H:%M:%S", | ||
) | ||
main() |
Oops, something went wrong.