diff --git a/auto3dseg/README.md b/auto3dseg/README.md
index 6a873a5e..f7adada0 100644
--- a/auto3dseg/README.md
+++ b/auto3dseg/README.md
@@ -7,6 +7,7 @@ A unit test script is provided to evaluate the integrity of all algorithm templa
 
 ```
 python auto3dseg/tests/test_algo_templates.py
+python auto3dseg/tests/test_gpu_customization.py
 ```
 
 ## Adding new templates
diff --git a/auto3dseg/algorithm_templates/dints/scripts/algo.py b/auto3dseg/algorithm_templates/dints/scripts/algo.py
index 95b32366..38587c27 100644
--- a/auto3dseg/algorithm_templates/dints/scripts/algo.py
+++ b/auto3dseg/algorithm_templates/dints/scripts/algo.py
@@ -294,13 +294,15 @@ def objective(trial):
             cmd += f"--num_sw_batch_size {num_sw_batch_size} "
             cmd += f"--validation_data_device {validation_data_device}"
             _ = subprocess.run(cmd.split(), check=True)
-        except:
-            print("[error] OOM")
-            return (
-                float(num_images_per_batch)
-                * float(num_sw_batch_size)
-                * device_factor
-            )
+        except RuntimeError as e:
+            if "out of memory" in str(e):
+                return (
+                    float(num_images_per_batch)
+                    * float(num_sw_batch_size)
+                    * device_factor
+                )
+            else:
+                raise
 
         value = (
             -1.0
diff --git a/auto3dseg/algorithm_templates/dints/scripts/dummy_runner.py b/auto3dseg/algorithm_templates/dints/scripts/dummy_runner.py
index 1b0d92c1..675dc36c 100644
--- a/auto3dseg/algorithm_templates/dints/scripts/dummy_runner.py
+++ b/auto3dseg/algorithm_templates/dints/scripts/dummy_runner.py
@@ -69,22 +69,14 @@ def __init__(self, output_path, data_stats_file, device_id: int = 0):
         pixdim = parser.get_parsed_content("transforms_train#transforms#3#pixdim")
         pixdim = [np.abs(pixdim[_i]) for _i in range(3)]
 
-        self.max_shape = [0, 0, 0]
-        for _k in range(len(data_stat["stats_by_cases"])):
-            image_shape = data_stat["stats_by_cases"][_k]["image_stats"]["shape"]
-            image_shape = np.squeeze(image_shape)
-            image_spacing = data_stat["stats_by_cases"][_k]["image_stats"]["spacing"]
-            image_spacing = np.squeeze(image_spacing)
-            image_spacing = [np.abs(image_spacing[_i]) for _i in range(3)]
-
-            new_shape = [
-                int(
-                    np.ceil(float(image_shape[_l]) * image_spacing[_l] / pixdim[_l])
-                )
-                for _l in range(3)
-            ]
-            if np.prod(new_shape) > np.prod(self.max_shape):
-                self.max_shape = new_shape
+        if "sizemm" not in data_stat["stats_summary"]["image_stats"]:
+            raise ValueError("The data stats file was generated by an older version of MONAI. "
+                             "Please update MONAI to version >= 1.2 and re-run the data analyzer on your dataset.")
+
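+        # Estimate the largest image shape after resampling: the 99.5th-percentile
+        # physical image size (mm) divided by the target spacing, rounded up per axis.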
" + "Please update MONAI >= 1.2 and re-run the data analyzer on your dataset.") + + image_size_mm = data_stat["stats_summary"]["image_stats"]["sizemm"]["percentile_99_5"] + self.max_shape = [int(np.ceil(image_size_mm[_l] / pixdim[_l])) for _l in range(3)] print("max_shape", self.max_shape) def run( diff --git a/auto3dseg/algorithm_templates/segresnet2d/scripts/algo.py b/auto3dseg/algorithm_templates/segresnet2d/scripts/algo.py index dbebb893..b3bd2195 100644 --- a/auto3dseg/algorithm_templates/segresnet2d/scripts/algo.py +++ b/auto3dseg/algorithm_templates/segresnet2d/scripts/algo.py @@ -281,13 +281,15 @@ def objective(trial): cmd += f"--num_sw_batch_size {num_sw_batch_size} " cmd += f"--validation_data_device {validation_data_device}" _ = subprocess.run(cmd.split(), check=True) - except: - print("[error] OOM") - return ( - float(num_images_per_batch) - * float(num_sw_batch_size) - * device_factor - ) + except RuntimeError as e: + if "out of memory" in str(e): + return ( + float(num_images_per_batch) + * float(num_sw_batch_size) + * device_factor + ) + else: + raise(e) value = ( -1.0 diff --git a/auto3dseg/algorithm_templates/segresnet2d/scripts/dummy_runner.py b/auto3dseg/algorithm_templates/segresnet2d/scripts/dummy_runner.py index 257085c3..4767af09 100644 --- a/auto3dseg/algorithm_templates/segresnet2d/scripts/dummy_runner.py +++ b/auto3dseg/algorithm_templates/segresnet2d/scripts/dummy_runner.py @@ -69,23 +69,15 @@ def __init__(self, output_path, data_stats_file, device_id: int = 0): pixdim = parser.get_parsed_content("transforms_train#transforms#3#pixdim") pixdim = [np.abs(pixdim[_i]) for _i in range(3)] - self.max_shape = [0, 0, 0] - for _k in range(len(data_stat["stats_by_cases"])): - image_shape = data_stat["stats_by_cases"][_k]["image_stats"]["shape"] - image_shape = np.squeeze(image_shape) - image_spacing = data_stat["stats_by_cases"][_k]["image_stats"]["spacing"] - image_spacing = np.squeeze(image_spacing) - image_spacing = [np.abs(image_spacing[_i]) for _i in range(3)] - - new_shape = [ - int( - np.ceil(float(image_shape[_l]) * image_spacing[_l] / pixdim[_l]) - ) - for _l in range(2) - ] - new_shape += [image_shape[-1]] - if np.prod(new_shape) > np.prod(self.max_shape): - self.max_shape = new_shape + if "sizemm" not in data_stat["stats_summary"]["image_stats"]: + raise ValueError("The data stats file is generated by older version of MONAI. " + "Please update MONAI >= 1.2 and re-run the data analyzer on your dataset.") + + image_size_mm = data_stat["stats_summary"]["image_stats"]["sizemm"]["percentile_99_5"] + # the spacing pixdim[2] is -1. 
+        image_size_mm = data_stat["stats_summary"]["image_stats"]["sizemm"]["percentile_99_5"]
+        self.max_shape = [int(np.ceil(image_size_mm[_l] / pixdim[_l])) for _l in range(3)]
 
         print("max_shape", self.max_shape)
 
     def run(
diff --git a/auto3dseg/tests/test_gpu_customization.py b/auto3dseg/tests/test_gpu_customization.py
new file mode 100644
index 00000000..afdc5728
--- /dev/null
+++ b/auto3dseg/tests/test_gpu_customization.py
@@ -0,0 +1,175 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
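+#
+# This test exercises the Auto3DSeg GPU customization workflow end to end on a
+# simulated dataset: data analysis, bundle generation with a GPU-based
+# hyperparameter search, short training runs, and ensemble inference.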
+
+import os
+import shutil
+import sys
+import unittest
+
+import nibabel as nib
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.apps.auto3dseg import AlgoEnsembleBestN, AlgoEnsembleBuilder, BundleGen, DataAnalyzer
+from monai.bundle.config_parser import ConfigParser
+from monai.data import create_test_image_3d
+
+sim_datalist = {
+    "testing": [
+        {"image": "val_image_001.nii.gz", "label": "val_label_001.nii.gz"},
+        {"image": "val_image_002.nii.gz", "label": "val_label_002.nii.gz"},
+    ],
+    "training": [
+        {"fold": 0, "image": "tr_image_001.nii.gz", "label": "tr_label_001.nii.gz"},
+        {"fold": 0, "image": "tr_image_002.nii.gz", "label": "tr_label_002.nii.gz"},
+        {"fold": 0, "image": "tr_image_003.nii.gz", "label": "tr_label_003.nii.gz"},
+        {"fold": 0, "image": "tr_image_004.nii.gz", "label": "tr_label_004.nii.gz"},
+        {"fold": 1, "image": "tr_image_005.nii.gz", "label": "tr_label_005.nii.gz"},
+        {"fold": 1, "image": "tr_image_006.nii.gz", "label": "tr_label_006.nii.gz"},
+        {"fold": 1, "image": "tr_image_007.nii.gz", "label": "tr_label_007.nii.gz"},
+        {"fold": 1, "image": "tr_image_008.nii.gz", "label": "tr_label_008.nii.gz"},
+        {"fold": 2, "image": "tr_image_009.nii.gz", "label": "tr_label_009.nii.gz"},
+        {"fold": 2, "image": "tr_image_010.nii.gz", "label": "tr_label_010.nii.gz"},
+        {"fold": 2, "image": "tr_image_011.nii.gz", "label": "tr_label_011.nii.gz"},
+        {"fold": 2, "image": "tr_image_012.nii.gz", "label": "tr_label_012.nii.gz"},
+    ],
+}
+
+algo_templates = os.path.join("auto3dseg", "algorithm_templates")
+
+sys.path.insert(0, algo_templates)
+
+num_gpus = min(torch.cuda.device_count(), 4)
+num_images_per_batch = 2
+num_epochs = 2
+num_epochs_per_validation = 1
+num_warmup_epochs = 1
+
+train_param = {
+    "CUDA_VISIBLE_DEVICES": list(range(num_gpus)),
+    "num_epochs_per_validation": num_epochs_per_validation,
+    "num_images_per_batch": num_images_per_batch,
+    "num_epochs": num_epochs,
+    "num_warmup_epochs": num_warmup_epochs,
+    "use_pretrain": False,
+    "pretrained_path": "",
+}
+
+pred_param = {"files_slices": slice(0, 1), "mode": "mean", "sigmoid": True}
+
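+# GPU customization search space: the "universal" entry applies to every algorithm;
+# each trial samples num_images_per_batch and num_sw_batch_size from these ranges.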
+gpu_customization_specs = {
+    "universal": {"num_trials": 1, "range_num_images_per_batch": [1, 2], "range_num_sw_batch_size": [1, 2]}
+}
+
+SIM_TEST_CASES = [
+    [{"sim_dim": (320, 320, 15), "modality": "MRI"}],
+]
+
+
+def create_sim_data(dataroot, sim_datalist, sim_dim, **kwargs):
+    """
+    Create simulated data using create_test_image_3d.
+
+    Args:
+        dataroot: data directory path that hosts the "nii.gz" image files.
+        sim_datalist: the datalist dictionary describing the images and labels to create.
+        sim_dim: the image sizes, e.g. a tuple of (64, 64, 64).
+        kwargs: additional keyword arguments passed to create_test_image_3d.
+    """
+    if not os.path.isdir(dataroot):
+        os.makedirs(dataroot)
+
+    # Generate a fake dataset
+    for d in sim_datalist["testing"] + sim_datalist["training"]:
+        im, seg = create_test_image_3d(sim_dim[0], sim_dim[1], sim_dim[2], **kwargs)
+        nib_image = nib.Nifti1Image(im, affine=np.eye(4))
+        image_fpath = os.path.join(dataroot, d["image"])
+        nib.save(nib_image, image_fpath)
+
+        if "label" in d:
+            nib_image = nib.Nifti1Image(seg, affine=np.eye(4))
+            label_fpath = os.path.join(dataroot, d["label"])
+            nib.save(nib_image, label_fpath)
+
+
+def auto_run(work_dir, data_src_cfg, algos):
+    """
+    Similar to the Auto3DSeg AutoRunner, this function runs the data analyzer, bundle generation,
+    training, and ensemble inference.
+
+    Args:
+        work_dir: working directory path.
+        data_src_cfg: a dictionary with the dataroot, datalist, and modality keys.
+        algos: the names of the algorithm templates to use, e.g. ["dints", "swinunetr"].
+
+    Returns:
+        A list of predictions produced by the ensemble inference.
+    """
+
+    data_src_cfg_file = os.path.join(work_dir, "input.yaml")
+    ConfigParser.export_config_file(data_src_cfg, data_src_cfg_file, fmt="yaml")
+
+    datastats_file = os.path.join(work_dir, "datastats.yaml")
+    analyser = DataAnalyzer(data_src_cfg["datalist"], data_src_cfg["dataroot"], output_path=datastats_file)
+    analyser.get_all_case_stats()
+
+    bundle_generator = BundleGen(
+        algo_path=work_dir,
+        templates_path_or_url=algo_templates,
+        algos=algos,
+        data_stats_filename=datastats_file,
+        data_src_cfg_name=data_src_cfg_file,
+    )
+    # With gpu_customization enabled, bundle generation runs a short trial-based search
+    # on the available GPUs to pick batch sizes that fit in memory.
+    bundle_generator.generate(
+        work_dir, num_fold=1, gpu_customization=True, gpu_customization_specs=gpu_customization_specs
+    )
+    history = bundle_generator.get_history()
+
+    for h in history:
+        for name, algo in h.items():
+            algo.train(train_param)
+
+    builder = AlgoEnsembleBuilder(history, data_src_cfg_file)
+    builder.set_ensemble_method(AlgoEnsembleBestN(n_best=len(history)))  # run inference with all trained models
+    preds = builder.get_ensemble()(pred_param)
+    return preds
+
+
+class TestGpuCustomization(unittest.TestCase):
+    @parameterized.expand(SIM_TEST_CASES)
+    def test_sim(self, input_params) -> None:
+        work_dir = "./tmp_sim_work_dir"
+        if not os.path.isdir(work_dir):
+            os.makedirs(work_dir)
+
+        dataroot_dir = os.path.join(work_dir, "sim_dataroot")
+        datalist_file = os.path.join(work_dir, "sim_datalist.json")
+        ConfigParser.export_config_file(sim_datalist, datalist_file)
+
+        sim_dim = input_params["sim_dim"]
+        create_sim_data(
+            dataroot_dir, sim_datalist, sim_dim, rad_max=max(int(min(sim_dim) / 4), 1), rad_min=1, num_seg_classes=1
+        )
+
+        data_src_cfg = {"modality": input_params["modality"], "datalist": datalist_file, "dataroot": dataroot_dir}
+        preds = auto_run(work_dir, data_src_cfg, ["dints", "segresnet", "segresnet2d", "swinunetr"])
+        self.assertTupleEqual(preds[0].shape, (2, sim_dim[0], sim_dim[1], sim_dim[2]))
+
+        shutil.rmtree(work_dir)
+
+
+if __name__ == "__main__":
+    unittest.main()