Initial version for multinode auto_runner and ensembler (#6272)
Fixes #6191 and #6259.

### Description
Big changes to AutoRunner to enable multi-node training and a multi-node, multi-GPU ensembler.
Multiple changes:
1. Add set_device_info() to create a self.device_setting dict that defines device
information (CUDA_VISIBLE_DEVICES, NUM_NODES, etc.) for all parts of
AutoRunner, including the data analyzer, trainer, and ensembler. No global env
variable is set; all device info comes from self.device_setting. Corresponding
changes are made to BundleGen.
2. To enable multi-GPU/multi-node ensembling (invoked via subprocess), the
ensembler is separated from AutoRunner so that AutoRunner can launch it as a
subprocess. A new EnsembleRunner class (similar to BundleGen) is created, and all
ensemble-related functions are moved from AutoRunner into this class. Local
multi-GPU ensembling passed. A usage sketch is shown after this list.
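
A minimal usage sketch based on the APIs added in this PR (the work_dir path, input.yaml, GPU ids, and fold count are illustrative assumptions, not part of the commit):

from monai.apps.auto3dseg import AutoRunner, EnsembleRunner

# configure devices once; the data analyzer, trainer, and ensembler all read the stored
# device settings, and no global environment variable is modified
runner = AutoRunner(work_dir="./work_dir", input="./work_dir/input.yaml")
runner.set_device_info(cuda_visible_devices=[0, 1], num_nodes=1, mn_start_method="bcprun")
runner.run()

# or run only the ensembling step; EnsembleRunner launches torchrun/bcprun subprocesses as needed
ensemble_runner = EnsembleRunner(
    data_src_cfg_name="./work_dir/input.yaml",
    work_dir="./work_dir",
    num_fold=5,
    ensemble_method_name="AlgoEnsembleBestByFold",
    mgpu=True,
)
ensemble_runner.run(runner.device_setting)

For multi-GPU/multi-node runs, EnsembleRunner internally relaunches itself as a subprocess of the form
"torchrun ... -m monai.apps.auto3dseg EnsembleRunner ensemble --data_src_cfg_name ... --mgpu True"
(wrapped with bcprun when NUM_NODES > 1 and MN_START_METHOD is "bcprun").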

Passed some quick local testing; details still need to be fixed and tests added.
This PR is created for an initial design-pattern discussion. Slack me if there
is any major concern about the change.
@mingxin-zheng @wyli

---------

Signed-off-by: heyufan1995 <heyufan1995@gmail.com>
heyufan1995 authored Apr 14, 2023
1 parent 3633b1c commit 825b8db
Showing 7 changed files with 485 additions and 164 deletions.
8 changes: 7 additions & 1 deletion monai/apps/auto3dseg/__init__.py
@@ -14,6 +14,12 @@
from .auto_runner import AutoRunner
from .bundle_gen import BundleAlgo, BundleGen
from .data_analyzer import DataAnalyzer
from .ensemble_builder import AlgoEnsemble, AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder
from .ensemble_builder import (
AlgoEnsemble,
AlgoEnsembleBestByFold,
AlgoEnsembleBestN,
AlgoEnsembleBuilder,
EnsembleRunner,
)
from .hpo_gen import NNIGen, OptunaGen
from .utils import export_bundle_algo_history, import_bundle_algo_history
3 changes: 2 additions & 1 deletion monai/apps/auto3dseg/__main__.py
@@ -14,7 +14,7 @@
from monai.apps.auto3dseg.auto_runner import AutoRunner
from monai.apps.auto3dseg.bundle_gen import BundleAlgo, BundleGen
from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
from monai.apps.auto3dseg.ensemble_builder import AlgoEnsembleBuilder
from monai.apps.auto3dseg.ensemble_builder import AlgoEnsembleBuilder, EnsembleRunner
from monai.apps.auto3dseg.hpo_gen import NNIGen, OptunaGen

if __name__ == "__main__":
@@ -27,6 +27,7 @@
"BundleGen": BundleGen,
"BundleAlgo": BundleAlgo,
"AlgoEnsembleBuilder": AlgoEnsembleBuilder,
"EnsembleRunner": EnsembleRunner,
"AutoRunner": AutoRunner,
"NNIGen": NNIGen,
"OptunaGen": OptunaGen,
222 changes: 110 additions & 112 deletions monai/apps/auto3dseg/auto_runner.py
@@ -19,25 +19,17 @@
from time import sleep
from typing import Any, cast

import numpy as np
import torch

from monai.apps.auto3dseg.bundle_gen import BundleGen
from monai.apps.auto3dseg.data_analyzer import DataAnalyzer
from monai.apps.auto3dseg.ensemble_builder import (
AlgoEnsemble,
AlgoEnsembleBestByFold,
AlgoEnsembleBestN,
AlgoEnsembleBuilder,
)
from monai.apps.auto3dseg.ensemble_builder import EnsembleRunner
from monai.apps.auto3dseg.hpo_gen import NNIGen
from monai.apps.auto3dseg.utils import export_bundle_algo_history, import_bundle_algo_history
from monai.apps.utils import get_logger
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle import ConfigParser
from monai.transforms import SaveImage
from monai.utils.enums import AlgoKeys
from monai.utils.module import look_up_option, optional_import
from monai.utils import AlgoKeys, has_option, look_up_option, optional_import

logger = get_logger(module_name=__name__)

@@ -232,6 +224,7 @@ def __init__(
self.data_src_cfg_name = os.path.join(self.work_dir, "input.yaml")
self.algos = algos
self.templates_path_or_url = templates_path_or_url
self.kwargs = deepcopy(kwargs)

if input is None and os.path.isfile(self.data_src_cfg_name):
input = self.data_src_cfg_name
@@ -285,16 +278,11 @@ def __init__(
self.ensemble = ensemble # last step, no need to check

self.set_training_params()
self.set_device_info()
self.set_prediction_params()
self.set_analyze_params()

self.save_image = self.set_image_save_transform(kwargs)

self.ensemble_method: AlgoEnsemble
self.ensemble_method_name: str | None = None

self.set_ensemble_method()
self.set_num_fold(num_fold=num_fold)
self.set_ensemble_method("AlgoEnsembleBestByFold")

self.gpu_customization = False
self.gpu_customization_specs: dict[str, Any] = {}
@@ -461,18 +449,11 @@ def set_num_fold(self, num_fold: int = 5) -> None:
Args:
num_fold: a positive integer to define the number of folds.
Notes:
If the ensemble method is ``AlgoEnsembleBestByFold``, this function automatically updates the ``n_fold``
parameter in the ``ensemble_method`` to avoid inconsistency between the training and the ensemble.
"""

if num_fold <= 0:
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")

self.num_fold = num_fold
if self.ensemble_method_name == "AlgoEnsembleBestByFold":
self.ensemble_method.n_fold = self.num_fold # type: ignore

def set_training_params(self, params: dict[str, Any] | None = None) -> None:
"""
@@ -488,6 +469,95 @@ def set_training_params(self, params: dict[str, Any] | None = None) -> None:
"""
self.train_params = deepcopy(params) if params is not None else {}
if "CUDA_VISIBLE_DEVICES" in self.train_params:
warnings.warn(
"CUDA_VISIBLE_DEVICES is deprecated from 'set_training_params'. Use 'set_device_info' intead.",
DeprecationWarning,
)

def set_device_info(
self,
cuda_visible_devices: list[int] | str | None = None,
num_nodes: int | None = None,
mn_start_method: str | None = None,
cmd_prefix: str | None = None,
) -> None:
"""
Set the device-related info.
Args:
cuda_visible_devices: define GPU ids for the data analyzer, training, and ensembling.
A list of GPU ids [0,1,2,3] or a string "0,1,2,3".
Defaults to env "CUDA_VISIBLE_DEVICES" or all available devices.
num_nodes: number of nodes for training and ensembling.
Defaults to env "NUM_NODES" or 1 if "NUM_NODES" is unset.
mn_start_method: multi-node start method. AutoRunner will use this method to start multi-node processes.
Defaults to env "MN_START_METHOD" or 'bcprun' if "MN_START_METHOD" is unset.
cmd_prefix: command line prefix for subprocesses run by BundleAlgo and EnsembleRunner.
Defaults to env "CMD_PREFIX" or None; examples are:
- single GPU/CPU or multi-node bcprun: "python " or "/opt/conda/bin/python3.8 ",
- single-node multi-GPU: "torchrun --nnodes=1 --nproc_per_node=2 "
If the user defines this prefix, please make sure --nproc_per_node matches cuda_visible_devices or
os.environ['CUDA_VISIBLE_DEVICES']. Also always set --nnodes=1; set num_nodes for multi-node runs.
"""
self.device_setting: dict[str, Any] = {}
if cuda_visible_devices is None:
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
if cuda_visible_devices is None: # still None after reading the environ
self.device_setting["CUDA_VISIBLE_DEVICES"] = ",".join([str(x) for x in range(torch.cuda.device_count())])
self.device_setting["n_devices"] = torch.cuda.device_count()
elif isinstance(cuda_visible_devices, str):
self.device_setting["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices
self.device_setting["n_devices"] = len(cuda_visible_devices.split(","))
elif isinstance(cuda_visible_devices, (list, tuple)):
self.device_setting["CUDA_VISIBLE_DEVICES"] = ",".join([str(x) for x in cuda_visible_devices])
self.device_setting["n_devices"] = len(cuda_visible_devices)
else:
logger.warn(f"Wrong format of cuda_visible_devices {cuda_visible_devices}, devices not set")

if num_nodes is None:
num_nodes = int(os.environ.get("NUM_NODES", 1))
self.device_setting["NUM_NODES"] = num_nodes

if mn_start_method is None:
mn_start_method = os.environ.get("MN_START_METHOD", "bcprun")
self.device_setting["MN_START_METHOD"] = mn_start_method

if cmd_prefix is None:
cmd_prefix = os.environ.get("CMD_PREFIX")
self.device_setting["CMD_PREFIX"] = cmd_prefix

if cmd_prefix is not None:
logger.info(f"Using user defined command running prefix {cmd_prefix}, will overide other settings")

def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None:
"""
Set the bundle ensemble method name and its parameters.
Args:
ensemble_method_name: the name of the ensemble method. Only two methods are supported: "AlgoEnsembleBestN"
and "AlgoEnsembleBestByFold".
kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
``AlgoEnsembleBestN`` is supported.
"""
self.ensemble_method_name = look_up_option(
ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"]
)
self.kwargs.update(kwargs)

def set_image_save_transform(self, **kwargs: Any) -> None:
"""
Set the ensemble output transform.
Args:
kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage.
"""

self.kwargs.update(kwargs)

def set_prediction_params(self, params: dict[str, Any] | None = None) -> None:
"""
@@ -547,10 +617,7 @@ def set_hpo_params(self, params: dict[str, Any] | None = None) -> None:
Users can set ``nni_dry_run`` to ``True`` in the ``params`` to enable the dry-run mode for the NNI backend.
"""
if params is None:
self.hpo_params = self.train_params
else:
self.hpo_params = params
self.hpo_params = self.train_params if params is None else params

def set_nni_search_space(self, search_space):
"""
@@ -569,58 +636,6 @@ def set_nni_search_space(self, search_space):
self.search_space = search_space
self.hpo_tasks = value_combinations

def set_image_save_transform(self, kwargs):
"""
Set the ensemble output transform.
Args:
kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .
"""

if "output_dir" in kwargs:
output_dir = kwargs.pop("output_dir")
else:
output_dir = os.path.join(self.work_dir, "ensemble_output")
logger.info(f"The output_dir is not specified. {output_dir} will be used to save ensemble predictions")

if not os.path.isdir(output_dir):
os.makedirs(output_dir)
logger.info(f"Directory {output_dir} is created to save ensemble predictions")

self.output_dir = output_dir
output_postfix = kwargs.pop("output_postfix", "ensemble")
output_dtype = kwargs.pop("output_dtype", np.uint8)
resample = kwargs.pop("resample", False)

return SaveImage(
output_dir=output_dir, output_postfix=output_postfix, output_dtype=output_dtype, resample=resample, **kwargs
)

def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None:
"""
Set the bundle ensemble method
Args:
ensemble_method_name: the name of the ensemble method. Only two methods are supported "AlgoEnsembleBestN"
and "AlgoEnsembleBestByFold".
kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
``AlgoEnsembleBestN`` is supported.
"""
self.ensemble_method_name = look_up_option(
ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"]
)
if self.ensemble_method_name == "AlgoEnsembleBestN":
n_best = kwargs.pop("n_best", False)
n_best = 2 if not n_best else n_best
self.ensemble_method = AlgoEnsembleBestN(n_best=n_best)
elif self.ensemble_method_name == "AlgoEnsembleBestByFold":
self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold)
else:
raise NotImplementedError(f"Ensemble method {self.ensemble_method_name} is not implemented.")

def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
"""
Train the Algos in a sequential scheme. The order of training is randomized.
@@ -637,7 +652,10 @@ def _train_algo_in_sequence(self, history: list[dict[str, Any]]) -> None:
"""
for algo_dict in history:
algo = algo_dict[AlgoKeys.ALGO]
algo.train(self.train_params)
if has_option(algo.train, "device_setting"):
algo.train(self.train_params, self.device_setting)
else:
algo.train(self.train_params)
acc = algo.get_score()

algo_meta_data = {str(AlgoKeys.SCORE): acc}
@@ -773,7 +791,7 @@ def run(self):

if auto_train_choice:
skip_algos = [h[AlgoKeys.ID] for h in history if h[AlgoKeys.IS_TRAINED]]
if len(skip_algos) > 0:
if skip_algos:
logger.info(
f"Skipping already trained algos {skip_algos}."
"Set option train=True to always retrain all algos."
@@ -792,34 +810,14 @@ def run(self):

# step 4: model ensemble and write the prediction to disks.
if self.ensemble:
history = import_bundle_algo_history(self.work_dir, only_trained=False)

history_untrained = [h for h in history if not h[AlgoKeys.IS_TRAINED]]
if len(history_untrained) > 0:
warnings.warn(
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos."
"Generally it means these algos did not complete training."
)
history = [h for h in history if h[AlgoKeys.IS_TRAINED]]

if len(history) == 0:
raise ValueError(
f"Could not find any trained algos in {self.work_dir}. "
"Possibly the required training step was not completed."
)

builder = AlgoEnsembleBuilder(history, self.data_src_cfg_name)
builder.set_ensemble_method(self.ensemble_method)

ensembler = builder.get_ensemble()
preds = ensembler(pred_param=self.pred_params)
if len(preds) > 0:
logger.info("Auto3Dseg picked the following networks to ensemble:")
for algo in ensembler.get_algo_ensemble():
logger.info(algo[AlgoKeys.ID])

for pred in preds:
self.save_image(pred)
logger.info(f"Auto3Dseg ensemble prediction outputs are saved in {self.output_dir}.")

ensemble_runner = EnsembleRunner(
data_src_cfg_name=self.data_src_cfg_name,
work_dir=self.work_dir,
num_fold=self.num_fold,
ensemble_method_name=self.ensemble_method_name,
mgpu=int(self.device_setting["n_devices"]) > 1,
**self.kwargs, # for set_image_save_transform
**self.pred_params,
) # for inference
ensemble_runner.run(self.device_setting)
logger.info("Auto3Dseg pipeline is completed successfully.")
110 changes: 73 additions & 37 deletions monai/apps/auto3dseg/bundle_gen.py
@@ -80,6 +80,14 @@ def __init__(self, template_path: str):
self.best_metric = None
# track records when filling template config: {"<config name>": {"<placeholder key>": value, ...}, ...}
self.fill_records: dict = {}
# device_setting has defaults and a sanity check, in case device_setting does not come from AutoRunner
self.device_setting: dict[str, int | str] = {
"CUDA_VISIBLE_DEVICES": ",".join([str(x) for x in range(torch.cuda.device_count())]),
"n_devices": int(torch.cuda.device_count()),
"NUM_NODES": int(os.environ.get("NUM_NODES", 1)),
"MN_START_METHOD": os.environ.get("MN_START_METHOD", "bcprun"),
"CMD_PREFIX": os.environ.get("CMD_PREFIX"), # type: ignore
}

def pre_check_skip_algo(self, skip_bundlegen: bool = False, skip_info: str = "") -> tuple[bool, str]:
"""
@@ -150,15 +158,16 @@ def export_to_disk(self, output_path: str, algo_name: str, **kwargs: Any) -> Non
self.output_path = self.template_path
if kwargs.pop("fill_template", True):
self.fill_records = self.fill_template_config(self.data_stats_files, self.output_path, **kwargs)
logger.info(self.output_path)
logger.info(f"Generated:{self.output_path}")

def _create_cmd(self, train_params=None):
def _create_cmd(self, train_params: None | dict = None) -> tuple[str, str]:
"""
Create the command to execute training.
"""
if train_params is not None:
params = deepcopy(train_params)
if train_params is None:
train_params = {}
params = deepcopy(train_params)

train_py = os.path.join(self.output_path, "scripts", "train.py")
config_dir = os.path.join(self.output_path, "configs")
@@ -168,53 +177,85 @@ def _create_cmd(self, train_params=None):
for file in os.listdir(config_dir):
if not (file.endswith("yaml") or file.endswith("json")):
continue
if len(base_cmd) == 0:
base_cmd += f"{train_py} run --config_file="
else:
base_cmd += "," # Python Fire does not accept space
base_cmd += f"{train_py} run --config_file=" if len(base_cmd) == 0 else ","
# Python Fire may be confused by single-quoted WindowsPath
config_yaml = Path(os.path.join(config_dir, file)).as_posix()
base_cmd += f"'{config_yaml}'"

if "CUDA_VISIBLE_DEVICES" in params:
devices = params.pop("CUDA_VISIBLE_DEVICES")
n_devices, devices_info = len(devices), ",".join([str(x) for x in devices])
else:
n_devices, devices_info = torch.cuda.device_count(), ""
if n_devices > 1:
cmd = f"torchrun --nnodes={1:d} --nproc_per_node={n_devices:d} "
cmd: str | None = self.device_setting["CMD_PREFIX"] # type: ignore
# make sure cmd end with a space
if cmd is not None and not cmd.endswith(" "):
cmd += " "
if (int(self.device_setting["NUM_NODES"]) > 1 and self.device_setting["MN_START_METHOD"] == "bcprun") or (
int(self.device_setting["NUM_NODES"]) <= 1 and int(self.device_setting["n_devices"]) <= 1
):
cmd = "python " if cmd is None else cmd
elif int(self.device_setting["NUM_NODES"]) > 1:
raise NotImplementedError(
f"{self.device_setting['MN_START_METHOD']} is not supported yet."
"Try modify BundleAlgo._create_cmd for your cluster."
)
else:
cmd = "python " # TODO: which system python?
if cmd is None:
cmd = f"torchrun --nnodes={1:d} --nproc_per_node={self.device_setting['n_devices']:d} "
cmd += base_cmd
if params and isinstance(params, Mapping):
for k, v in params.items():
cmd += f" --{k}={v}"
return cmd, devices_info
return cmd, ""

def _run_cmd(self, cmd: str, devices_info: str) -> subprocess.CompletedProcess:
def _run_cmd(self, cmd: str, devices_info: str = "") -> subprocess.CompletedProcess:
"""
Execute the training command with target devices information.
"""
if devices_info:
warnings.warn(f"input devices_info {devices_info} is deprecated and ignored.")

logger.info(f"Launching: {cmd}")
ps_environ = os.environ.copy()
if devices_info:
ps_environ["CUDA_VISIBLE_DEVICES"] = devices_info
normal_out = subprocess.run(cmd.split(), env=ps_environ, check=True)

ps_environ["CUDA_VISIBLE_DEVICES"] = str(self.device_setting["CUDA_VISIBLE_DEVICES"])
if int(self.device_setting["NUM_NODES"]) > 1:
if self.device_setting["MN_START_METHOD"] == "bcprun":
normal_out = subprocess.run(
[
"bcprun",
"-n",
str(self.device_setting["NUM_NODES"]),
"-p",
str(self.device_setting["n_devices"]),
"-c",
cmd,
],
env=ps_environ,
check=True,
)
else:
raise NotImplementedError(
f"{self.device_setting['MN_START_METHOD']} is not supported yet. "
"Try modify BundleAlgo._run_cmd for your cluster."
)
else:
normal_out = subprocess.run(cmd.split(), env=ps_environ, check=True)
return normal_out

def train(self, train_params=None):
def train(
self, train_params: None | dict = None, device_setting: None | dict = None
) -> subprocess.CompletedProcess:
"""
Load the run function in the training script of each model. Training parameter is predefined by the
algo_config.yaml file, which is pre-filled by the fill_template_config function in the same instance.
Args:
train_params: to specify the devices using a list of integers: ``{"CUDA_VISIBLE_DEVICES": [1,2,3]}``.
train_params: training parameters.
device_setting: device-related settings; should follow the device_setting in auto_runner.set_device_info.
'CUDA_VISIBLE_DEVICES' should be a string, e.g. '0,1,2,3'.
"""
cmd, devices_info = self._create_cmd(train_params)
return self._run_cmd(cmd, devices_info)
if device_setting is not None:
self.device_setting.update(device_setting)
self.device_setting["n_devices"] = len(str(self.device_setting["CUDA_VISIBLE_DEVICES"]).split(","))

cmd, _unused_return = self._create_cmd(train_params)
return self._run_cmd(cmd)

def get_score(self, *args, **kwargs):
"""
@@ -276,11 +317,7 @@ def predict(self, predict_files: list, predict_params: dict | None = None) -> li
predict_params: a dict to override the parameters in the bundle config (including the files to predict).
"""
if predict_params is None:
params = {}
else:
params = deepcopy(predict_params)

params = {} if predict_params is None else deepcopy(predict_params)
inferer = self.get_inferer(**params)
return [inferer.infer(f) for f in ensure_tuple(predict_files)]

@@ -355,7 +392,7 @@ def _copy_algos_folder(folder, at_path):
algos_all[name] = dict(
_target_=f"{name}.scripts.algo.{name.capitalize()}Algo", template_path=os.path.join(at_path, name)
)
logger.info(f"{name} -- {algos_all[name]}")
logger.info(f"Copying template: {name} -- {algos_all[name]}")
if not algos_all:
raise ValueError(f"Unable to find any algos in {folder}")

@@ -373,11 +410,10 @@ class BundleGen(AlgoGen):
by templates_path_or_url. Defaults to None - to use all available algorithms.
templates_path_or_url: the folder with the algorithm templates or a url. If None provided, the default template
zip url will be downloaded and extracted into the algo_path. The current default options are released at:
https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg
data_stats_filename: the path to the data stats file (generated by DataAnalyzer)
https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg.
data_stats_filename: the path to the data stats file (generated by DataAnalyzer).
data_src_cfg_name: the path to the data source config YAML file. The config will be in a form of
{"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}
{"modality": "ct", "datalist": "path_to_json_datalist", "dataroot": "path_dir_data"}.
.. code-block:: bash
python -m monai.apps.auto3dseg BundleGen generate --data_stats_filename="../algorithms/datastats.yaml"
277 changes: 265 additions & 12 deletions monai/apps/auto3dseg/ensemble_builder.py
@@ -12,24 +12,30 @@
from __future__ import annotations

import os
import subprocess
from abc import ABC, abstractmethod
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from copy import deepcopy
from typing import Any, cast
from warnings import warn

import numpy as np
import torch
import torch.distributed as dist

from monai.apps.auto3dseg.bundle_gen import BundleAlgo
from monai.apps.auto3dseg.utils import import_bundle_algo_history
from monai.apps.utils import get_logger
from monai.auto3dseg import concat_val_to_np
from monai.auto3dseg.utils import datafold_read
from monai.bundle import ConfigParser
from monai.data import partition_dataset
from monai.transforms import MeanEnsemble, VoteEnsemble
from monai.utils.enums import AlgoKeys
from monai.utils.misc import prob2class
from monai.utils.module import look_up_option
from monai.utils.module import look_up_option, optional_import

tqdm, has_tqdm = optional_import("tqdm", name="tqdm")

logger = get_logger(module_name=__name__)

@@ -88,7 +94,7 @@ def set_infer_files(self, dataroot: str, data_list_or_path: str | list, data_key
datalist = ConfigParser.load_config_file(data_list_or_path)
if data_key in datalist:
self.infer_files, _ = datafold_read(datalist=datalist, basedir=dataroot, fold=-1, key=data_key)
else:
elif hasattr(self, "rank") and self.rank == 0: # type: ignore
logger.info(f"Datalist file has no testing key - {data_key}. No data for inference is specified")

else:
@@ -117,7 +123,7 @@ def ensemble_pred(self, preds, sigmoid=False):
else:
return VoteEnsemble(num_classes=preds[0].shape[0])(classes)

def __call__(self, pred_param: dict[str, Any] | None = None) -> list[torch.Tensor]:
def __call__(self, pred_param: dict | None = None) -> list:
"""
Use the ensembled model to predict result.
@@ -135,11 +141,7 @@ def __call__(self, pred_param: dict[str, Any] | None = None) -> list[torch.Tenso
Returns:
A list of tensors.
"""
if pred_param is None:
param = {}
else:
param = deepcopy(pred_param)

param = {} if pred_param is None else deepcopy(pred_param)
files = self.infer_files

if "infer_files" in param:
@@ -155,15 +157,25 @@ def __call__(self, pred_param: dict[str, Any] | None = None) -> list[torch.Tenso

sigmoid = param.pop("sigmoid", False)

if "image_save_func" in param:
img_saver = ConfigParser(param["image_save_func"]).get_parsed_content()

outputs = []
for i, file in enumerate(files):
print(i)
for _, file in (
enumerate(tqdm(files, desc="Ensembling (rank 0)..."))
if has_tqdm and pred_param and pred_param.get("rank", 0) == 0
else enumerate(files)
):
preds = []
for algo in self.algo_ensemble:
infer_instance = algo[AlgoKeys.ALGO]
pred = infer_instance.predict(predict_files=[file], predict_params=param)
preds.append(pred[0])
outputs.append(self.ensemble_pred(preds, sigmoid=sigmoid))
if "image_save_func" in param:
res = img_saver(self.ensemble_pred(preds, sigmoid=sigmoid))
else:
res = self.ensemble_pred(preds, sigmoid=sigmoid)
outputs.append(res)
return outputs

@abstractmethod
@@ -327,3 +339,244 @@ def get_ensemble(self):
"""Get the ensemble"""

return self.ensemble


class EnsembleRunner:
"""
The runner for the ensembler.
Args:
work_dir: working directory to save the intermediate and final results.
data_src_cfg_name: filename of the data source config.
num_fold: number of folds.
ensemble_method_name: method to ensemble predictions from different models.
Supported methods: ["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"].
mgpu: whether to use multi-GPU.
kwargs: additional image writing, ensembling, and prediction parameters for the ensemble inference.
Examples:
.. code-block:: python
ensemble_runner = EnsembleRunner(data_src_cfg_name,
work_dir,
ensemble_method_name,
mgpu=device_setting['n_devices']>1,
**kwargs,
**pred_params)
ensemble_runner.run(device_setting)
"""

def __init__(
self,
data_src_cfg_name: str = "./work_dir/input.yaml",
work_dir: str = "./work_dir",
num_fold: int = 5,
ensemble_method_name: str = "AlgoEnsembleBestByFold",
mgpu: bool = True,
**kwargs: Any,
) -> None:
self.data_src_cfg_name = data_src_cfg_name
self.work_dir = work_dir
self.num_fold = num_fold
self.ensemble_method_name = ensemble_method_name
self.mgpu = mgpu
self.kwargs = kwargs
self.rank = 0
self.world_size = 1
self.device_setting: dict[str, int | str] = {
"CUDA_VISIBLE_DEVICES": ",".join([str(x) for x in range(torch.cuda.device_count())]),
"n_devices": torch.cuda.device_count(),
"NUM_NODES": int(os.environ.get("NUM_NODES", 1)),
"MN_START_METHOD": os.environ.get("MN_START_METHOD", "bcprun"),
"CMD_PREFIX": os.environ.get("CMD_PREFIX"), # type: ignore
}

def set_ensemble_method(self, ensemble_method_name: str = "AlgoEnsembleBestByFold", **kwargs: Any) -> None:
"""
Set the bundle ensemble method
Args:
ensemble_method_name: the name of the ensemble method. Only two methods are supported "AlgoEnsembleBestN"
and "AlgoEnsembleBestByFold".
kwargs: the keyword arguments used to define the ensemble method. Currently only ``n_best`` for
``AlgoEnsembleBestN`` is supported.
"""
self.ensemble_method_name = look_up_option(
ensemble_method_name, supported=["AlgoEnsembleBestN", "AlgoEnsembleBestByFold"]
)
if self.ensemble_method_name == "AlgoEnsembleBestN":
n_best = kwargs.pop("n_best", False) or 2
self.ensemble_method = AlgoEnsembleBestN(n_best=n_best)
elif self.ensemble_method_name == "AlgoEnsembleBestByFold":
self.ensemble_method = AlgoEnsembleBestByFold(n_fold=self.num_fold) # type: ignore
else:
raise NotImplementedError(f"Ensemble method {self.ensemble_method_name} is not implemented.")

def set_image_save_transform(self, **kwargs):
"""
Set the ensemble output transform.
Args:
kwargs: image writing parameters for the ensemble inference. The kwargs format follows SaveImage
transform. For more information, check https://docs.monai.io/en/stable/transforms.html#saveimage .
"""

if "output_dir" in kwargs:
output_dir = kwargs.pop("output_dir")
else:
output_dir = os.path.join(self.work_dir, "ensemble_output")
if self.rank == 0:
logger.info(f"The output_dir is not specified. {output_dir} will be used to save ensemble predictions")

if not os.path.isdir(output_dir):
os.makedirs(output_dir)
if self.rank == 0:
logger.info(f"Directory {output_dir} is created to save ensemble predictions")

self.output_dir = output_dir
output_postfix = kwargs.pop("output_postfix", "ensemble")
output_dtype = kwargs.pop("output_dtype", "$np.uint8")
resample = kwargs.pop("resample", False)

self.save_image = {
"_target_": "SaveImage",
"output_dir": output_dir,
"output_postfix": output_postfix,
"output_dtype": output_dtype,
"resample": resample,
"print_log": False,
}
if kwargs:
self.save_image.update(kwargs)

def set_num_fold(self, num_fold: int = 5) -> None:
"""
Set the number of cross validation folds for all algos.
Args:
num_fold: a positive integer to define the number of folds.
"""

if num_fold <= 0:
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
self.num_fold = num_fold

def ensemble(self):
if self.mgpu:  # torch.cuda.device_count() is not used because env is not set by autorunner
# init multiprocessing and update infer_files
dist.init_process_group(backend="nccl", init_method="env://")
self.world_size = dist.get_world_size()
self.rank = dist.get_rank()
# set params after init_process_group to know the rank
self.set_num_fold(num_fold=self.num_fold)
self.set_image_save_transform(**self.kwargs)
self.set_ensemble_method(self.ensemble_method_name, **self.kwargs)

history = import_bundle_algo_history(self.work_dir, only_trained=False)
history_untrained = [h for h in history if not h[AlgoKeys.IS_TRAINED]]
if history_untrained:
if self.rank == 0:
warn(
f"Ensembling step will skip {[h['name'] for h in history_untrained]} untrained algos."
"Generally it means these algos did not complete training."
)
history = [h for h in history if h[AlgoKeys.IS_TRAINED]]
if len(history) == 0:
raise ValueError(
f"Could not find the trained results in {self.work_dir}. "
"Possibly the required training step was not completed."
)

builder = AlgoEnsembleBuilder(history, self.data_src_cfg_name)
builder.set_ensemble_method(self.ensemble_method)
self.ensembler = builder.get_ensemble()
infer_files = self.ensembler.infer_files
infer_files = partition_dataset(data=infer_files, shuffle=False, num_partitions=self.world_size)[self.rank]
# TODO: add a function in the ensembler to update infer_files?
self.ensembler.infer_files = infer_files
# self.kwargs has already popped the args consumed by set_image_save_transform
# add rank to pred_params
self.kwargs["rank"] = self.rank
self.kwargs["image_save_func"] = self.save_image
if self.rank == 0:
logger.info("Auto3Dseg picked the following networks to ensemble:")
for algo in self.ensembler.get_algo_ensemble():
logger.info(algo[AlgoKeys.ID])
logger.info(f"Auto3Dseg ensemble prediction outputs will be saved in {self.output_dir}.")
self.ensembler(pred_param=self.kwargs)

if self.mgpu:
dist.destroy_process_group()

def run(self, device_setting: dict | None = None) -> None:
"""
Run the ensemble step, launching multi-GPU or multi-node subprocesses when required by the device settings.
Args:
device_setting: device-related settings; should follow the device_setting in auto_runner.set_device_info.
'CUDA_VISIBLE_DEVICES' should be a string, e.g. '0,1,2,3'.
"""
# device_setting has defaults and a sanity check, in case device_setting does not come from AutoRunner
if device_setting is not None:
self.device_setting.update(device_setting)
self.device_setting["n_devices"] = len(str(self.device_setting["CUDA_VISIBLE_DEVICES"]).split(","))
self._create_cmd()

def _create_cmd(self) -> None:
if int(self.device_setting["NUM_NODES"]) <= 1 and int(self.device_setting["n_devices"]) <= 1:
# if single GPU
logger.info("Ensembling using single GPU!")
self.ensemble()
return

# define base cmd for subprocess
base_cmd = f"monai.apps.auto3dseg EnsembleRunner ensemble \
--data_src_cfg_name {self.data_src_cfg_name} \
--work_dir {self.work_dir} \
--num_fold {self.num_fold} \
--ensemble_method_name {self.ensemble_method_name} \
--mgpu True"

if self.kwargs and isinstance(self.kwargs, Mapping):
for k, v in self.kwargs.items():
base_cmd += f" --{k}={v}"
# define env for subprocess
ps_environ = os.environ.copy()
ps_environ["CUDA_VISIBLE_DEVICES"] = str(self.device_setting["CUDA_VISIBLE_DEVICES"])
cmd: str | None = self.device_setting["CMD_PREFIX"] # type: ignore
if cmd is not None and not str(cmd).endswith(" "):
cmd += " "
if int(self.device_setting["NUM_NODES"]) > 1:
if self.device_setting["MN_START_METHOD"] != "bcprun":
raise NotImplementedError(
f"{self.device_setting['MN_START_METHOD']} is not supported yet. "
"Try modify EnsembleRunner._create_cmd for your cluster."
)
logger.info(f"Ensembling on {self.device_setting['NUM_NODES']} nodes!")
cmd = "python " if cmd is None else cmd
cmd = f"{cmd} -m {base_cmd}"
_ = subprocess.run(
[
"bcprun",
"-n",
str(self.device_setting["NUM_NODES"]),
"-p",
str(self.device_setting["n_devices"]),
"-c",
cmd,
],
env=ps_environ,
check=True,
)
else:
logger.info(f"Ensembling using {self.device_setting['n_devices']} GPU!")
if cmd is None:
cmd = f"torchrun --nnodes={1:d} --nproc_per_node={self.device_setting['n_devices']:d} "
cmd = f"{cmd} -m {base_cmd}"
_ = subprocess.run(cmd.split(), env=ps_environ, check=True)
return
1 change: 1 addition & 0 deletions monai/utils/__init__.py
@@ -17,6 +17,7 @@
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
from .enums import (
AlgoKeys,
Average,
BlendMode,
BoxModeName,
28 changes: 27 additions & 1 deletion tests/test_auto3dseg_ensemble.py
@@ -19,9 +19,17 @@
import numpy as np
import torch

from monai.apps.auto3dseg import AlgoEnsembleBestByFold, AlgoEnsembleBestN, AlgoEnsembleBuilder, BundleGen, DataAnalyzer
from monai.apps.auto3dseg import (
AlgoEnsembleBestByFold,
AlgoEnsembleBestN,
AlgoEnsembleBuilder,
BundleGen,
DataAnalyzer,
EnsembleRunner,
)
from monai.bundle.config_parser import ConfigParser
from monai.data import create_test_image_3d
from monai.transforms import SaveImage
from monai.utils import optional_import, set_determinism
from monai.utils.enums import AlgoKeys
from tests.utils import (
@@ -159,6 +167,24 @@ def test_ensemble(self) -> None:
for algo in ensemble.get_algo_ensemble():
print(algo[AlgoKeys.ID])

def test_ensemble_runner(self) -> None:
runner = EnsembleRunner()
runner.set_num_fold(3)
self.assertTrue(runner.num_fold == 3)
runner.set_ensemble_method(ensemble_method_name="AlgoEnsembleBestByFold")
self.assertIsInstance(runner.ensemble_method, AlgoEnsembleBestByFold)
self.assertTrue(runner.ensemble_method.n_fold == 3) # type: ignore

runner.set_ensemble_method(ensemble_method_name="AlgoEnsembleBestN", n_best=3)
self.assertIsInstance(runner.ensemble_method, AlgoEnsembleBestN)
self.assertTrue(runner.ensemble_method.n_best == 3)

save_output = os.path.join(self.test_dir.name, "workdir")
runner.set_image_save_transform(
output_dir=save_output, output_postfix="test_ensemble", output_dtype=float, resample=True
)
self.assertIsInstance(ConfigParser(runner.save_image).get_parsed_content(), SaveImage)

def tearDown(self) -> None:
set_determinism(None)
self.test_dir.cleanup()
