Skip to content

Commit

Permalink
add tutorial for 3d ldm on brats (Project-MONAI#1301)
Browse files Browse the repository at this point in the history
Fixes # .

### Description
A few sentences describing the changes proposed in this pull request.

### Checks
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any
sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use
relative paths for tutorial repo files (3) put figure and graphs in the
`./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: Can-Zhao <[email protected]>
Signed-off-by: Can Zhao <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Can-Zhao and pre-commit-ci[bot] authored May 3, 2023
1 parent e5adbd0 commit e49afe3
Show file tree
Hide file tree
Showing 17 changed files with 1,163 additions and 0 deletions.
137 changes: 137 additions & 0 deletions generative/3d_ldm/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# 3D Latent Diffusion Example
This folder contains an example for training and validating a 3D Latent Diffusion Model on Brats data. The example includes support for multi-GPU training with distributed data parallelism based on a [tutorial designed for using single GPU](https://github.com/Project-MONAI/GenerativeModels/blob/main/tutorials/generative/3d_ldm/3d_ldm_tutorial.ipynb).

The workflow of the Latent Diffusion Model is depicted in the figure below. It begins by training an autoencoder in pixel space to encode images into latent features. Following that, it trains a diffusion model in the latent space to denoise the noisy latent features. During inference, it first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Finally, it decodes the denoised latent features into images using the trained autoencoder.
<p align="center">
<img src="./figs/ldm.png" alt="latent diffusion scheme")
</p>

MONAI latent diffusion model implementation is based on the following papers:

[**Latent Diffusion:** Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)

This network is designed as a demonstration to showcase the training process for this type of network using MONAI. To achieve optimal performance, it is recommended that users have a GPU with memory larger than 32G to accommodate larger networks and attention layers.

### 1. Data

The dataset we are experimenting with in this example is BraTS 2016 and 2017 data.

BraTS is a public dataset of brain MR images. Using these images, the goal is to generate images that look similar to the images in BraTS 2016 and 2017 dataset.

The data can be automatically downloaded from [Medical Decathlon](http://medicaldecathlon.com/) at the beginning of training.

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset! We acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study.

### 2. Installation
```
pip install lpips
pip install git+https://github.com/Project-MONAI/GenerativeModels.git#egg=Generative
```

Or install it from source:
```
pip install lpips
git clone https://github.com/Project-MONAI/GenerativeModels.git
cd GenerativeModels/
python setup.py install
cd ..
```
### 3. Run the example

#### [3.1 3D Autoencoder Training](./train_autoencoder.py)

The network configuration files are located in [./config/config_train_32g.json](./config/config_train_32g.json) for 32G GPU
and [./config/config_train_16g.json](./config/config_train_16g.json) for 16G GPU.
You can modify the hyperparameters in these files to suit your requirements.

The training script resamples the brain images based on the voxel spacing specified in the `"spacing"` field of the configuration files. For instance, `"spacing": [1.1, 1.1, 1.1]` resamples the images to a resolution of 1.1x1.1x1.1 mm. If you have a GPU with larger memory, you may consider changing the `"spacing"` field to `"spacing": [1.0, 1.0, 1.0]`.

The training script uses the batch size and patch size defined in the configuration files. If you have a different GPU memory size, you should adjust the `"batch_size"` and `"patch_size"` parameters in the `"autoencoder_train"` to match your GPU. Note that the `"patch_size"` needs to be divisible by 4.

Before you start training, please set the path in [./config/environment.json](./config/environment.json).

- `"model_dir"`: where it saves the trained models
- `"tfevent_path"`: where it saves the tensorboard events
- `"output_dir"`: where you store the generated images during inference.
- `"resume_ckpt"`: whether to resume training from existing checkpoints.
- `"data_base_dir"`: where you store the Brats dataset.

If the Brats dataset is not downloaded, please add `--download_data` in training command, the Brats data will be downloaded from [Medical Decathlon](http://medicaldecathlon.com/) and extracted to `$data_base_dir`. You will see a subfolder `Task01_BrainTumour` under `$data_base_dir`. By default, you will see `./Task01_BrainTumour`
For example, this command is for running the training script with one 32G gpu.
```bash
python train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1 --download_data
```
If `$data_base_dir/Task01_BrainTumour` already exists, you may omit the downloading.
```bash
python train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1
```

The training script also enables multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command:
```bash
export NUM_GPUS_PER_NODE=8
torchrun \
--nproc_per_node=${NUM_GPUS_PER_NODE} \
--nnodes=1 \
--master_addr=localhost --master_port=1234 \
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```

<p align="center">
<img src="./figs/train_recon.png" alt="autoencoder train curve" width="45%" >
&nbsp; &nbsp; &nbsp; &nbsp;
<img src="./figs/val_recon.png" alt="autoencoder validation curve" width="45%" >
</p>

With eight DGX1V 32G GPUs, it took around 55 hours to train 1000 epochs.

#### [3.2 3D Latent Diffusion Training](./train_diffusion.py)
The training script uses the batch size and patch size defined in the configuration files. If you have a different GPU memory size, you should adjust the `"batch_size"` and `"patch_size"` parameters in the `"diffusion_train"` to match your GPU. Note that the `"patch_size"` needs to be divisible by 16.

To train with single 32G GPU, please run:
```bash
python train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1
```

The training script also enables multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command:
```bash
export NUM_GPUS_PER_NODE=8
torchrun \
--nproc_per_node=${NUM_GPUS_PER_NODE} \
--nnodes=1 \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
<p align="center">
<img src="./figs/train_diffusion.png" alt="latent diffusion train curve" width="45%" >
&nbsp; &nbsp; &nbsp; &nbsp;
<img src="./figs/val_diffusion.png" alt="latent diffusion validation curve" width="45%" >
</p>

#### [3.3 Inference](./inference.py)
To generate one image during inference, please run the following command:
```bash
python inference.py -c ./config/config_train_32g.json -e ./config/environment.json --num 1
```
`--num` defines how many images it would generate.

An example output is shown below.
<p align="center">
<img src="./figs/syn_axial.png" width="30%" >
&nbsp; &nbsp; &nbsp; &nbsp;
<img src="./figs/syn_sag.png" width="30%" >
&nbsp; &nbsp; &nbsp; &nbsp;
<img src="./figs/syn_cor.png" width="30%" >
</p>

### 4. Questions and bugs

- For questions relating to the use of MONAI, please use our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).

### Reference
[1] [Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)

[2] [Menze, Bjoern H., et al. "The multimodal brain tumor image segmentation benchmark (BRATS)." IEEE transactions on medical imaging 34.10 (2014): 1993-2024.](https://ieeexplore.ieee.org/document/6975210)

[3] [Pinaya et al. "Brain imaging generation with latent diffusion models"](https://arxiv.org/abs/2209.07162)
61 changes: 61 additions & 0 deletions generative/3d_ldm/config/config_train_16g.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{
"channel": 0,
"spacing": [1.1, 1.1, 1.1],
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 8,
"autoencoder_def": {
"_target_": "generative.networks.nets.AutoencoderKL",
"spatial_dims": "@spatial_dims",
"in_channels": "$@image_channels",
"out_channels": "@image_channels",
"latent_channels": "@latent_channels",
"num_channels": [
64,
128,
256
],
"num_res_blocks": 2,
"norm_num_groups": 32,
"norm_eps": 1e-06,
"attention_levels": [
false,
false,
false
],
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false
},
"autoencoder_train": {
"batch_size": 1,
"patch_size": [112,128,80],
"lr": 5e-6,
"perceptual_weight": 0.001,
"kl_weight": 1e-7,
"recon_loss": "l1",
"n_epochs": 1000,
"val_interval": 10
},
"diffusion_def": {
"_target_": "generative.networks.nets.DiffusionModelUNet",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels":[256, 256, 512],
"attention_levels":[false, true, true],
"num_head_channels":[0, 64, 64],
"num_res_blocks": 2
},
"diffusion_train": {
"batch_size": 2,
"patch_size": [144,176,112],
"lr": 5e-6,
"n_epochs": 10000,
"val_interval": 2
},
"NoiseScheduler": {
"num_train_timesteps": 1000,
"beta_start": 0.0015,
"beta_end": 0.0195
}
}
61 changes: 61 additions & 0 deletions generative/3d_ldm/config/config_train_32g.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{
"channel": 0,
"spacing": [1.1, 1.1, 1.1],
"spatial_dims": 3,
"image_channels": 1,
"latent_channels": 8,
"autoencoder_def": {
"_target_": "generative.networks.nets.AutoencoderKL",
"spatial_dims": "@spatial_dims",
"in_channels": "$@image_channels",
"out_channels": "@image_channels",
"latent_channels": "@latent_channels",
"num_channels": [
64,
128,
256
],
"num_res_blocks": 2,
"norm_num_groups": 32,
"norm_eps": 1e-06,
"attention_levels": [
false,
false,
false
],
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false
},
"autoencoder_train": {
"batch_size": 2,
"patch_size": [112,128,80],
"lr": 1e-5,
"perceptual_weight": 0.001,
"kl_weight": 1e-7,
"recon_loss": "l1",
"n_epochs": 1000,
"val_interval": 10
},
"diffusion_def": {
"_target_": "generative.networks.nets.DiffusionModelUNet",
"spatial_dims": "@spatial_dims",
"in_channels": "@latent_channels",
"out_channels": "@latent_channels",
"num_channels":[256, 256, 512],
"attention_levels":[false, true, true],
"num_head_channels":[0, 64, 64],
"num_res_blocks": 2
},
"diffusion_train": {
"batch_size": 3,
"patch_size": [144,176,112],
"lr": 1e-5,
"n_epochs": 10000,
"val_interval": 2
},
"NoiseScheduler": {
"num_train_timesteps": 1000,
"beta_start": 0.0015,
"beta_end": 0.0195
}
}
7 changes: 7 additions & 0 deletions generative/3d_ldm/config/environment.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"data_base_dir": "./dataset",
"model_dir": "./trained_weights/diffusion_3d",
"tfevent_path": "./tfevent/diffusion_3d",
"output_dir": "./output",
"resume_ckpt": false
}
Binary file added generative/3d_ldm/figs/ldm.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added generative/3d_ldm/figs/syn_axial.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added generative/3d_ldm/figs/syn_cor.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added generative/3d_ldm/figs/syn_sag.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added generative/3d_ldm/figs/train_diffusion.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added generative/3d_ldm/figs/train_recon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added generative/3d_ldm/figs/val_diffusion.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added generative/3d_ldm/figs/val_recon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
112 changes: 112 additions & 0 deletions generative/3d_ldm/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import logging
import os
import sys
from datetime import datetime
from pathlib import Path

import nibabel as nib
import numpy as np
import torch
from generative.inferers import LatentDiffusionInferer
from generative.networks.schedulers import DDPMScheduler
from monai.config import print_config
from monai.utils import set_determinism

from utils import define_instance


def main():
parser = argparse.ArgumentParser(description="PyTorch Latent Diffusion Model Inference")
parser.add_argument(
"-e",
"--environment-file",
default="./config/environment.json",
help="environment json file that stores environment path",
)
parser.add_argument(
"-c",
"--config-file",
default="./config/config_train_48g.json",
help="config json file that stores hyper-parameters",
)
parser.add_argument(
"-n",
"--num",
type=int,
default=1,
help="number of generated images",
)
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print_config()
torch.backends.cudnn.benchmark = True
torch.set_num_threads(4)

env_dict = json.load(open(args.environment_file, "r"))
config_dict = json.load(open(args.config_file, "r"))

for k, v in env_dict.items():
setattr(args, k, v)
for k, v in config_dict.items():
setattr(args, k, v)

set_determinism(42)

# load trained networks
autoencoder = define_instance(args, "autoencoder_def").to(device)
trained_g_path = os.path.join(args.model_dir, "autoencoder.pt")
autoencoder.load_state_dict(torch.load(trained_g_path))

diffusion_model = define_instance(args, "diffusion_def").to(device)
trained_diffusion_path = os.path.join(args.model_dir, "diffusion_unet.pt")
diffusion_model.load_state_dict(torch.load(trained_diffusion_path))

scheduler = DDPMScheduler(
num_train_timesteps=args.NoiseScheduler["num_train_timesteps"],
beta_schedule="scaled_linear",
beta_start=args.NoiseScheduler["beta_start"],
beta_end=args.NoiseScheduler["beta_end"],
)
inferer = LatentDiffusionInferer(scheduler, scale_factor=1.0)

Path(args.output_dir).mkdir(parents=True, exist_ok=True)
latent_shape = [p // 4 for p in args.diffusion_train["patch_size"]]
noise_shape = [1, args.latent_channels] + latent_shape

for _ in range(args.num):
noise = torch.randn(noise_shape, dtype=torch.float32).to(device)
with torch.no_grad():
synthetic_images = inferer.sample(
input_noise=noise,
autoencoder_model=autoencoder,
diffusion_model=diffusion_model,
scheduler=scheduler,
)
filename = os.path.join(args.output_dir, datetime.now().strftime("synimg_%Y%m%d_%H%M%S"))
final_img = nib.Nifti1Image(synthetic_images[0, 0, ...].unsqueeze(-1).cpu().numpy(), np.eye(4))
nib.save(final_img, filename)


if __name__ == "__main__":
logging.basicConfig(
stream=sys.stdout,
level=logging.INFO,
format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
main()
Loading

0 comments on commit e49afe3

Please sign in to comment.