diff --git a/generation/2d_ldm/README.md b/generation/2d_ldm/README.md
index 18b8b522c..41af57e4d 100644
--- a/generation/2d_ldm/README.md
+++ b/generation/2d_ldm/README.md
@@ -57,6 +57,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
+Please note that multi-GPU training may require additional GPU memory. You may need to reduce `batch_size` based on your available resources to ensure smooth training.
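+
+For example, you can lower the `batch_size` entry of the corresponding training section in `./config/config_train_32g.json` (an illustrative snippet; the section name and values in your config may differ):
+
+```json
+"autoencoder_train": {
+    "batch_size": 2,
+    ...
+}
+```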
@@ -88,6 +89,8 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
+Please note that multi-GPU training may require additional GPU memory. You may need to reduce `batch_size` based on your available resources to ensure smooth training.
+
diff --git a/generation/3d_ldm/README.md b/generation/3d_ldm/README.md
index 3bb741757..0d928d2e9 100644
--- a/generation/3d_ldm/README.md
+++ b/generation/3d_ldm/README.md
@@ -61,6 +61,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
+Please note that multi-GPU training may require additional GPU memory. You may need to reduce `batch_size` based on your available resources to ensure smooth training.
@@ -87,6 +88,7 @@ torchrun \
--master_addr=localhost --master_port=1234 \
train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
```
+Please note that multi-GPU training may require additional GPU memory. You may need to reduce `batch_size` based on your available resources to ensure smooth training.
diff --git a/generation/maisi/maisi_diff_unet_training_tutorial.ipynb b/generation/maisi/maisi_diff_unet_training_tutorial.ipynb
index 883fdf685..770474960 100644
--- a/generation/maisi/maisi_diff_unet_training_tutorial.ipynb
+++ b/generation/maisi/maisi_diff_unet_training_tutorial.ipynb
@@ -429,7 +429,9 @@
"\n",
"After all latent features have been created, we will initiate the multi-GPU script to train the latent diffusion model.\n",
"\n",
- "The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1."
+ "The image generation process utilizes the [DDPM scheduler](https://arxiv.org/pdf/2006.11239) with 1,000 iterative steps. The diffusion model is optimized using L1 loss and a decayed learning rate scheduler. The batch size for this process is set to 1.\n",
+ "\n",
+ "Please be aware that using the H100 GPU may occasionally result in random segmentation faults. To avoid this issue, you can disable AMP by setting the `--no_amp` flag."
]
},
{
diff --git a/generation/maisi/scripts/diff_model_train.py b/generation/maisi/scripts/diff_model_train.py
index e6bfcdd7c..abe36c4be 100644
--- a/generation/maisi/scripts/diff_model_train.py
+++ b/generation/maisi/scripts/diff_model_train.py
@@ -24,7 +24,7 @@
from torch.nn.parallel import DistributedDataParallel
import monai
-from monai.data import ThreadDataLoader, partition_dataset
+from monai.data import DataLoader, partition_dataset
from monai.transforms import Compose
from monai.utils import first
@@ -50,7 +50,7 @@ def load_filenames(data_list_path: str) -> list:
def prepare_data(
train_files: list, device: torch.device, cache_rate: float, num_workers: int = 2, batch_size: int = 1
-) -> ThreadDataLoader:
+) -> DataLoader:
"""
Prepare training data.
@@ -62,7 +62,7 @@ def prepare_data(
batch_size (int): Mini-batch size.
Returns:
- ThreadDataLoader: Data loader for training.
+ DataLoader: Data loader for training.
"""
def _load_data_from_file(file_path, key):
@@ -90,7 +90,7 @@ def _load_data_from_file(file_path, key):
data=train_files, transform=train_transforms, cache_rate=cache_rate, num_workers=num_workers
)
- return ThreadDataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
+ return DataLoader(train_ds, num_workers=6, batch_size=batch_size, shuffle=True)
def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Logger) -> torch.nn.Module:
@@ -124,14 +124,12 @@ def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Lo
return unet
-def calculate_scale_factor(
- train_loader: ThreadDataLoader, device: torch.device, logger: logging.Logger
-) -> torch.Tensor:
+def calculate_scale_factor(train_loader: DataLoader, device: torch.device, logger: logging.Logger) -> torch.Tensor:
"""
Calculate the scaling factor for the dataset.
Args:
- train_loader (ThreadDataLoader): Data loader for training.
+ train_loader (DataLoader): Data loader for training.
device (torch.device): Device to use for calculation.
logger (logging.Logger): Logger for logging information.
@@ -181,7 +179,7 @@ def create_lr_scheduler(optimizer: torch.optim.Optimizer, total_steps: int) -> t
def train_one_epoch(
epoch: int,
unet: torch.nn.Module,
- train_loader: ThreadDataLoader,
+ train_loader: DataLoader,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler.PolynomialLR,
loss_pt: torch.nn.L1Loss,
@@ -193,6 +191,7 @@ def train_one_epoch(
device: torch.device,
logger: logging.Logger,
local_rank: int,
+ amp: bool = True,
) -> torch.Tensor:
"""
Train the model for one epoch.
@@ -200,7 +199,7 @@ def train_one_epoch(
Args:
epoch (int): Current epoch number.
unet (torch.nn.Module): UNet model.
- train_loader (ThreadDataLoader): Data loader for training.
+ train_loader (DataLoader): Data loader for training.
optimizer (torch.optim.Optimizer): Optimizer.
lr_scheduler (torch.optim.lr_scheduler.PolynomialLR): Learning rate scheduler.
loss_pt (torch.nn.L1Loss): Loss function.
@@ -212,6 +211,7 @@ def train_one_epoch(
device (torch.device): Device to use for training.
logger (logging.Logger): Logger for logging information.
local_rank (int): Local rank for distributed training.
+ amp (bool): Use automatic mixed precision training.
Returns:
torch.Tensor: Training loss for the epoch.
@@ -237,7 +237,7 @@ def train_one_epoch(
optimizer.zero_grad(set_to_none=True)
- with autocast("cuda", enabled=True):
+ with autocast("cuda", enabled=amp):
noise = torch.randn(
(num_images_per_batch, 4, images.size(-3), images.size(-2), images.size(-1)), device=device
)
@@ -256,9 +256,13 @@ def train_one_epoch(
loss = loss_pt(noise_pred.float(), noise.float())
- scaler.scale(loss).backward()
- scaler.step(optimizer)
- scaler.update()
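+ # When AMP is enabled, scale the loss with GradScaler so low-precision gradients do not underflow;
+ # otherwise run a plain full-precision backward pass and optimizer step.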
+ if amp:
+ scaler.scale(loss).backward()
+ scaler.step(optimizer)
+ scaler.update()
+ else:
+ loss.backward()
+ optimizer.step()
lr_scheduler.step()
@@ -312,7 +316,9 @@ def save_checkpoint(
)
-def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
+def diff_model_train(
+ env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int, amp: bool = True
+) -> None:
"""
Main function to train a diffusion model.
@@ -320,6 +326,8 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
env_config_path (str): Path to the environment configuration file.
model_config_path (str): Path to the model configuration file.
model_def_path (str): Path to the model definition file.
+ num_gpus (int): Number of GPUs to use for training.
+ amp (bool): Use automatic mixed precision training.
"""
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed(num_gpus)
@@ -357,7 +365,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
)[local_rank]
train_loader = prepare_data(
- train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
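+ # Pass batch_size by keyword so it is not bound to the num_workers positional parameter of prepare_data.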
+ train_files, device, args.diffusion_unet_train["cache_rate"], batch_size=args.diffusion_unet_train["batch_size"]
)
unet = load_unet(args, device, logger)
@@ -392,6 +400,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
device,
logger,
local_rank,
+ amp=amp,
)
loss_torch = loss_torch.tolist()
@@ -431,6 +440,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
+ parser.add_argument("--no_amp", dest="amp", action="store_false", help="Disable automatic mixed precision training")
args = parser.parse_args()
- diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)
+ diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus, args.amp)