From 4bfc8e97661d44cc13553721e5025005757bad91 Mon Sep 17 00:00:00 2001
From: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
Date: Tue, 2 May 2023 14:50:17 +0000
Subject: [PATCH] Auto3DSeg ensemble tests images depend on #GPUs + docstring
 improvement (#6457)

Fixes #6456.

### Description

- Make the test images also depend on GPUs
- Improve docstring for infer

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mingxin Zheng <18563433+mingxin-zheng@users.noreply.github.com>
---
 docs/source/apps.rst                     |  1 +
 monai/apps/auto3dseg/ensemble_builder.py | 20 ++++++++++++++------
 tests/test_auto3dseg_ensemble.py         |  2 +-
 tests/test_integration_autorunner.py     |  2 +-
 4 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/docs/source/apps.rst b/docs/source/apps.rst
index 8c422b16c5..e543859b7c 100644
--- a/docs/source/apps.rst
+++ b/docs/source/apps.rst
@@ -252,6 +252,7 @@ FastMRIReader
 -----------
 .. automodule:: monai.apps.auto3dseg
   :members:
+  :special-members: __call__
   :imported-members:
 
 `nnUNet`
diff --git a/monai/apps/auto3dseg/ensemble_builder.py b/monai/apps/auto3dseg/ensemble_builder.py
index db304a6860..d95222946c 100644
--- a/monai/apps/auto3dseg/ensemble_builder.py
+++ b/monai/apps/auto3dseg/ensemble_builder.py
@@ -136,14 +136,22 @@ def __call__(self, pred_param: dict | None = None) -> list:
                 in this function, and the second group will be passed to the `InferClass` to override the
                 parameters of the class functions.
                 The first group contains:
-                'files_slices': a value type of `slice`. The files_slices will slice the infer_files and only
-                    make prediction on the infer_files[file_slices].
-                'mode': ensemble mode. Currently "mean" and "vote" (majority voting) schemes are supported.
-                'sigmoid': use the sigmoid function (e.g. x>0.5) to convert the prediction probability map to
-                    the label class prediction, otherwise argmax(x) is used.
+
+                    - ``"infer_files"``: file paths to the images to read in a list.
+                    - ``"files_slices"``: a value type of `slice`. The files_slices will slice the ``"infer_files"`` and
+                      only make prediction on the infer_files[file_slices].
+                    - ``"mode"``: ensemble mode. Currently "mean" and "vote" (majority voting) schemes are supported.
+                    - ``"image_save_func"``: a dictionary used to instantiate the ``SaveImage`` transform. When specified,
+                      the ensemble prediction will save the prediciton files, instead of keeping the files in the memory.
+                      Example: `{"_target_": "SaveImage", "output_dir": "./"}`
+                    - ``"sigmoid"``: use the sigmoid function (e.g. x > 0.5) to convert the prediction probability map
+                      to the label class prediction, otherwise argmax(x) is used.
+
+                The parameters in the second group is defined in the ``config`` of each Algo templates. Please check:
+                https://github.com/Project-MONAI/research-contributions/tree/main/auto3dseg/algorithm_templates
 
         Returns:
-            A list of tensors.
+            A list of tensors or file paths, depending on whether ``"image_save_func"`` is set.
         """
         param = {} if pred_param is None else deepcopy(pred_param)
         files = self.infer_files
diff --git a/tests/test_auto3dseg_ensemble.py b/tests/test_auto3dseg_ensemble.py
index a95f30cfb4..e5eb957b1c 100644
--- a/tests/test_auto3dseg_ensemble.py
+++ b/tests/test_auto3dseg_ensemble.py
@@ -46,7 +46,7 @@
 num_images_per_batch = 2
 
 fake_datalist: dict[str, list[dict]] = {
-    "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
+    "testing": [{"image": f"ts_image_{idx:03d}.nii.gz"} for idx in range(num_images_perfold)],
     "training": [
         {
             "fold": f,
diff --git a/tests/test_integration_autorunner.py b/tests/test_integration_autorunner.py
index 62951ea5d6..7110db568d 100644
--- a/tests/test_integration_autorunner.py
+++ b/tests/test_integration_autorunner.py
@@ -38,7 +38,7 @@
 num_images_per_batch = 2
 
 sim_datalist: dict[str, list[dict]] = {
-    "testing": [{"image": "val_001.fake.nii.gz"}, {"image": "val_002.fake.nii.gz"}],
+    "testing": [{"image": f"ts_image__{idx:03d}.nii.gz"} for idx in range(num_images_perfold)],
     "training": [
         {
             "fold": f,