Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Sep 30, 2024
1 parent 1ad4540 commit 405291f
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 40 deletions.
81 changes: 62 additions & 19 deletions generation/maisi/maisi_diff_unet_training_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,47 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "e3bf0346",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MONAI version: 1.4.0rc10\n",
"Numpy version: 1.24.4\n",
"Pytorch version: 2.5.0a0+872d972e41.nv24.08.01\n",
"MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False\n",
"MONAI rev id: cac21f6936a2e8d6e4e57e4e958f8e32aae1585e\n",
"MONAI __file__: /usr/local/lib/python3.10/dist-packages/monai/__init__.py\n",
"\n",
"Optional dependencies:\n",
"Pytorch Ignite version: 0.4.11\n",
"ITK version: 5.4.0\n",
"Nibabel version: 5.2.1\n",
"scikit-image version: 0.23.2\n",
"scipy version: 1.13.1\n",
"Pillow version: 10.4.0\n",
"Tensorboard version: 2.17.0\n",
"gdown version: 5.2.0\n",
"TorchVision version: 0.20.0a0\n",
"tqdm version: 4.66.4\n",
"lmdb version: 1.5.1\n",
"psutil version: 5.9.8\n",
"pandas version: 2.2.2\n",
"einops version: 0.7.0\n",
"transformers version: 4.40.2\n",
"mlflow version: 2.16.0\n",
"pynrrd version: 1.0.0\n",
"clearml version: 1.16.3\n",
"\n",
"For details about installing the optional dependencies, please visit:\n",
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
"\n"
]
}
],
"source": [
"from scripts.diff_model_setting import setup_logging\n",
"import copy\n",
Expand Down Expand Up @@ -336,6 +373,8 @@
" model_config_filepath,\n",
" \"--model_def\",\n",
" model_def_filepath,\n",
" \"--num_gpus\",\n",
" str(num_gpus),\n",
"]\n",
"\n",
"run_torchrun(module, module_args, num_gpus=num_gpus)"
Expand Down Expand Up @@ -457,17 +496,17 @@
"INFO:training:[config] num_train_timesteps -> 1000.\n",
"INFO:training:num_files_train: 2\n",
"INFO:training:Training from scratch.\n",
"INFO:training:Scaling factor set to 0.89132159948349.\n",
"INFO:training:scale_factor -> 0.89132159948349.\n",
"INFO:training:Scaling factor set to 0.8903454542160034.\n",
"INFO:training:scale_factor -> 0.8903454542160034.\n",
"INFO:training:torch.set_float32_matmul_precision -> highest.\n",
"INFO:training:Epoch 1, lr 0.0001.\n",
"INFO:training:[2024-09-24 03:46:57] epoch 1, iter 1/2, loss: 0.7984, lr: 0.000100000000.\n",
"INFO:training:[2024-09-24 03:46:58] epoch 1, iter 2/2, loss: 0.7911, lr: 0.000056250000.\n",
"INFO:training:epoch 1 average loss: 0.7947.\n",
"INFO:training:[2024-09-30 06:30:33] epoch 1, iter 1/2, loss: 0.7974, lr: 0.000100000000.\n",
"INFO:training:[2024-09-30 06:30:33] epoch 1, iter 2/2, loss: 0.7939, lr: 0.000056250000.\n",
"INFO:training:epoch 1 average loss: 0.7957.\n",
"INFO:training:Epoch 2, lr 2.5e-05.\n",
"INFO:training:[2024-09-24 03:46:59] epoch 2, iter 1/2, loss: 0.7910, lr: 0.000025000000.\n",
"INFO:training:[2024-09-24 03:46:59] epoch 2, iter 2/2, loss: 0.7897, lr: 0.000006250000.\n",
"INFO:training:epoch 2 average loss: 0.7903.\n",
"INFO:training:[2024-09-30 06:30:35] epoch 2, iter 1/2, loss: 0.7902, lr: 0.000025000000.\n",
"INFO:training:[2024-09-30 06:30:35] epoch 2, iter 2/2, loss: 0.7889, lr: 0.000006250000.\n",
"INFO:training:epoch 2 average loss: 0.7895.\n",
"\n"
]
}
Expand All @@ -484,6 +523,8 @@
" model_config_filepath,\n",
" \"--model_def\",\n",
" model_def_filepath,\n",
" \"--num_gpus\",\n",
" str(num_gpus),\n",
"]\n",
"\n",
"run_torchrun(module, module_args, num_gpus=num_gpus)"
Expand Down Expand Up @@ -518,24 +559,24 @@
"output_type": "stream",
"text": [
"\n",
"INFO:inference:Using cuda:0 of 1 with random seed: 62801\n",
"INFO:inference:Using cuda:0 of 1 with random seed: 93612\n",
"INFO:inference:[config] ckpt_filepath -> ./temp_work_dir/./models/diff_unet_ckpt.pt.\n",
"INFO:inference:[config] random_seed -> 62801.\n",
"INFO:inference:[config] random_seed -> 93612.\n",
"INFO:inference:[config] output_prefix -> unet_3d.\n",
"INFO:inference:[config] output_size -> (256, 256, 128).\n",
"INFO:inference:[config] out_spacing -> (1.0, 1.0, 0.75).\n",
"INFO:root:`controllable_anatomy_size` is not provided.\n",
"INFO:inference:checkpoints ./temp_work_dir/./models/diff_unet_ckpt.pt loaded.\n",
"INFO:inference:scale_factor -> 0.89132159948349.\n",
"INFO:inference:scale_factor -> 0.8903454542160034.\n",
"INFO:inference:num_downsample_level -> 4, divisor -> 4.\n",
"INFO:inference:noise: cuda:0, torch.float32, <class 'torch.Tensor'>\n",
"\n",
" 0%| | 0/10 [00:00<?, ?it/s]\n",
" 10%|███████▍ | 1/10 [00:00<00:02, 3.62it/s]\n",
" 40%|█████████████████████████████▌ | 4/10 [00:00<00:00, 12.53it/s]\n",
" 80%|███████████████████████████████████████████████████████████▏ | 8/10 [00:00<00:00, 19.54it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 18.16it/s]\n",
"INFO:inference:Saved ./temp_work_dir/./predictions/unet_3d_seed62801_size256x256x128_spacing1.00x1.00x0.75_20240924034721.nii.gz.\n",
" 10%|███████▍ | 1/10 [00:00<00:02, 3.48it/s]\n",
" 40%|█████████████████████████████▌ | 4/10 [00:00<00:00, 12.23it/s]\n",
" 80%|███████████████████████████████████████████████████████████▏ | 8/10 [00:00<00:00, 19.26it/s]\n",
"100%|█████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 17.80it/s]\n",
"INFO:inference:Saved ./temp_work_dir/./predictions/unet_3d_seed93612_size256x256x128_spacing1.00x1.00x0.75_20240930063144_rank0.nii.gz.\n",
"\n"
]
}
Expand All @@ -552,6 +593,8 @@
" model_config_filepath,\n",
" \"--model_def\",\n",
" model_def_filepath,\n",
" \"--num_gpus\",\n",
" str(num_gpus),\n",
"]\n",
"\n",
"run_torchrun(module, module_args, num_gpus=num_gpus)\n",
Expand All @@ -562,7 +605,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
Expand Down
9 changes: 6 additions & 3 deletions generation/maisi/scripts/diff_model_create_training_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def process_file(


@torch.inference_mode()
def diff_model_create_training_data(env_config_path: str, model_config_path: str, model_def_path: str) -> None:
def diff_model_create_training_data(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
"""
Create training data for the diffusion model.
Expand All @@ -170,7 +170,7 @@ def diff_model_create_training_data(env_config_path: str, model_config_path: str
model_def_path (str): Path to the model definition file.
"""
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed()
local_rank, world_size, device = initialize_distributed(num_gpus=num_gpus)
logger = setup_logging("creating training data")
logger.info(f"Using device {device}")

Expand Down Expand Up @@ -224,6 +224,9 @@ def diff_model_create_training_data(env_config_path: str, model_config_path: str
parser.add_argument(
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument(
"--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training"
)

args = parser.parse_args()
diff_model_create_training_data(args.env_config, args.model_config, args.model_def)
diff_model_create_training_data(args.env_config, args.model_config, args.model_def, args.num_gpus)
12 changes: 9 additions & 3 deletions generation/maisi/scripts/diff_model_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def save_image(


@torch.inference_mode()
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str) -> None:
def diff_model_infer(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
"""
Main function to run the diffusion model inference.
Expand All @@ -221,7 +221,7 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
model_def_path (str): Path to the model definition file.
"""
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed()
local_rank, world_size, device = initialize_distributed(num_gpus)
logger = setup_logging("inference")
random_seed = set_random_seed(
args.diffusion_unet_inference["random_seed"] + local_rank
Expand Down Expand Up @@ -311,6 +311,12 @@ def diff_model_infer(env_config_path: str, model_config_path: str, model_def_pat
default="./configs/config_maisi.json",
help="Path to model definition file",
)
parser.add_argument(
"--num_gpus",
type=int,
default=1,
help="Number of GPUs to use for distributed inference",
)

args = parser.parse_args()
diff_model_infer(args.env_config, args.model_config, args.model_def)
diff_model_infer(args.env_config, args.model_config, args.model_def, args.num_gpus)
4 changes: 2 additions & 2 deletions generation/maisi/scripts/diff_model_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def load_config(env_config_path: str, model_config_path: str, model_def_path: st
return args


def initialize_distributed() -> tuple:
def initialize_distributed(num_gpus) -> tuple:
"""
Initialize distributed training.
Returns:
tuple: local_rank, world_size, and device.
"""
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
if torch.cuda.is_available() and num_gpus > 1:
dist.init_process_group(backend="nccl", init_method="env://")
local_rank = dist.get_rank()
world_size = dist.get_world_size()
Expand Down
30 changes: 17 additions & 13 deletions generation/maisi/scripts/diff_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ def load_unet(args: argparse.Namespace, device: torch.device, logger: logging.Lo
unet = define_instance(args, "diffusion_unet_def").to(device)
unet = torch.nn.SyncBatchNorm.convert_sync_batchnorm(unet)

if torch.cuda.device_count() > 1:
if dist.is_initialized():
unet = DistributedDataParallel(unet, device_ids=[device], find_unused_parameters=True)

if args.existing_ckpt_filepath is None:
logger.info("Training from scratch.")
else:
checkpoint_unet = torch.load(f"{args.existing_ckpt_filepath}", map_location=device)
if torch.cuda.device_count() > 1:
if dist.is_initialized():
unet.module.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
else:
unet.load_state_dict(checkpoint_unet["unet_state_dict"], strict=True)
Expand Down Expand Up @@ -143,8 +143,9 @@ def calculate_scale_factor(
scale_factor = 1 / torch.std(z)
logger.info(f"Scaling factor set to {scale_factor}.")

dist.barrier()
dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
if dist.is_initialized():
dist.barrier()
dist.all_reduce(scale_factor, op=torch.distributed.ReduceOp.AVG)
logger.info(f"scale_factor -> {scale_factor}.")
return scale_factor

Expand Down Expand Up @@ -271,7 +272,7 @@ def train_one_epoch(
)
)

if torch.cuda.device_count() > 1:
if dist.is_initialized():
dist.all_reduce(loss_torch, op=torch.distributed.ReduceOp.SUM)

return loss_torch
Expand All @@ -298,7 +299,7 @@ def save_checkpoint(
ckpt_folder (str): Checkpoint folder path.
args (argparse.Namespace): Configuration arguments.
"""
unet_state_dict = unet.module.state_dict() if torch.cuda.device_count() > 1 else unet.state_dict()
unet_state_dict = unet.module.state_dict() if dist.is_initialized() else unet.state_dict()
torch.save(
{
"epoch": epoch + 1,
Expand All @@ -311,7 +312,7 @@ def save_checkpoint(
)


def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str) -> None:
def diff_model_train(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
"""
Main function to train a diffusion model.
Expand All @@ -321,7 +322,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
model_def_path (str): Path to the model definition file.
"""
args = load_config(env_config_path, model_config_path, model_def_path)
local_rank, world_size, device = initialize_distributed()
local_rank, world_size, device = initialize_distributed(num_gpus)
logger = setup_logging("training")

logger.info(f"Using {device} of {world_size}")
Expand Down Expand Up @@ -350,10 +351,10 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
train_files.append(
{"image": str_img, "top_region_index": str_info, "bottom_region_index": str_info, "spacing": str_info}
)

train_files = partition_dataset(
data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True
)[local_rank]
if dist.is_initialized():
train_files = partition_dataset(
data=train_files, shuffle=True, num_partitions=dist.get_world_size(), even_divisible=True
)[local_rank]

train_loader = prepare_data(
train_files, device, args.diffusion_unet_train["cache_rate"], args.diffusion_unet_train["batch_size"]
Expand Down Expand Up @@ -429,6 +430,9 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
parser.add_argument(
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
)
parser.add_argument(
"--num_gpus", type=int, default=1, help="Number of GPUs to use for training"
)

args = parser.parse_args()
diff_model_train(args.env_config, args.model_config, args.model_def)
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)

0 comments on commit 405291f

Please sign in to comment.