From bfc713226aa822852c77c51492b97fb99b4b7e2d Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Mon, 5 Aug 2024 17:05:08 +0800 Subject: [PATCH] Add maisi bundle (#612) Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Status **Ready** ### Please ensure all the checkboxes: - [x] Codeformat tests passed locally by running `./runtests.sh --codeformat`. - [ ] In-line docstrings updated. - [ ] Update `version` and `changelog` in `metadata.json` if changing an existing bundle. - [ ] Please ensure the naming rules in config files meet our requirements (please refer to: `CONTRIBUTING.md`). - [ ] Ensure versions of packages such as `monai`, `pytorch` and `numpy` are correct in `metadata.json`. - [ ] Descriptions should be consistent with the content, such as `eval_metrics` of the provided weights and TorchScript modules. - [ ] Files larger than 25MB are excluded and replaced by providing download links in `large_file.yml`. - [ ] Avoid using path that contains personal information within config files (such as use `/home/your_name/` for `"bundle_root"`). --------- Signed-off-by: Yiheng Wang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .github/workflows/code-format-check.yml | 4 +- .github/workflows/premerge-cpu.yml | 4 +- ci/bundle_custom_data.py | 6 +- ci/get_bundle_requirements.py | 16 +- .../install_maisi_ct_generative_dependency.sh | 1 + ci/run_premerge_cpu.sh | 42 +- ci/run_premerge_gpu.sh | 145 +- ci/run_premerge_multi_gpu.sh | 108 - ci/unit_tests/test_maisi_ct_generative.py | 330 +++ .../test_maisi_ct_generative_dist.py | 83 + ci/unit_tests/test_vista3d.py | 18 +- ...t_vista3d_mgpu.py => test_vista3d_dist.py} | 0 ci/verify_bundle.py | 10 +- models/maisi_ct_generative/LICENSE | 247 ++ .../configs/inference.json | 292 +++ .../configs/integration_test_masks.json | 98 + .../configs/label_dict.json | 134 ++ .../configs/label_dict_124_to_132.json | 502 +++++ .../maisi_ct_generative/configs/logging.conf | 21 + .../maisi_ct_generative/configs/metadata.json | 263 +++ .../configs/multi_gpu_train.json | 34 + models/maisi_ct_generative/configs/train.json | 270 +++ models/maisi_ct_generative/docs/README.md | 103 + .../maisi_ct_generative/docs/data_license.txt | 49 + models/maisi_ct_generative/large_files.yml | 23 + .../maisi_ct_generative/scripts/__init__.py | 12 + .../scripts/augmentation.py | 366 +++ .../scripts/custom_network_controlnet.py | 177 ++ .../scripts/custom_network_diffusion.py | 1993 +++++++++++++++++ .../scripts/custom_network_tp.py | 1053 +++++++++ .../maisi_ct_generative/scripts/find_masks.py | 120 + models/maisi_ct_generative/scripts/sample.py | 699 ++++++ models/maisi_ct_generative/scripts/trainer.py | 247 ++ models/maisi_ct_generative/scripts/utils.py | 429 ++++ 34 files changed, 7699 insertions(+), 200 deletions(-) create mode 100644 ci/install_scripts/install_maisi_ct_generative_dependency.sh delete mode 100755 ci/run_premerge_multi_gpu.sh create mode 100644 ci/unit_tests/test_maisi_ct_generative.py create mode 100644 ci/unit_tests/test_maisi_ct_generative_dist.py rename ci/unit_tests/{test_vista3d_mgpu.py => test_vista3d_dist.py} (100%) create mode 100644 models/maisi_ct_generative/LICENSE create mode 100644 models/maisi_ct_generative/configs/inference.json create mode 100644 models/maisi_ct_generative/configs/integration_test_masks.json create mode 100644 models/maisi_ct_generative/configs/label_dict.json create mode 100644 
models/maisi_ct_generative/configs/label_dict_124_to_132.json create mode 100644 models/maisi_ct_generative/configs/logging.conf create mode 100644 models/maisi_ct_generative/configs/metadata.json create mode 100644 models/maisi_ct_generative/configs/multi_gpu_train.json create mode 100644 models/maisi_ct_generative/configs/train.json create mode 100644 models/maisi_ct_generative/docs/README.md create mode 100644 models/maisi_ct_generative/docs/data_license.txt create mode 100644 models/maisi_ct_generative/large_files.yml create mode 100644 models/maisi_ct_generative/scripts/__init__.py create mode 100644 models/maisi_ct_generative/scripts/augmentation.py create mode 100644 models/maisi_ct_generative/scripts/custom_network_controlnet.py create mode 100644 models/maisi_ct_generative/scripts/custom_network_diffusion.py create mode 100644 models/maisi_ct_generative/scripts/custom_network_tp.py create mode 100644 models/maisi_ct_generative/scripts/find_masks.py create mode 100644 models/maisi_ct_generative/scripts/sample.py create mode 100644 models/maisi_ct_generative/scripts/trainer.py create mode 100644 models/maisi_ct_generative/scripts/utils.py diff --git a/.github/workflows/code-format-check.yml b/.github/workflows/code-format-check.yml index 6080ae6f..920cb6a1 100644 --- a/.github/workflows/code-format-check.yml +++ b/.github/workflows/code-format-check.yml @@ -17,10 +17,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.10.14 - name: cache weekly timestamp id: pip-cache run: | diff --git a/.github/workflows/premerge-cpu.yml b/.github/workflows/premerge-cpu.yml index e46de1ca..f94b5204 100644 --- a/.github/workflows/premerge-cpu.yml +++ b/.github/workflows/premerge-cpu.yml @@ -17,10 +17,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - - name: Set up Python 3.9 + - name: Set up Python 3.10 uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: 3.10.14 - name: cache weekly timestamp id: pip-cache run: | diff --git a/ci/bundle_custom_data.py b/ci/bundle_custom_data.py index 2965350c..3f5aed33 100644 --- a/ci/bundle_custom_data.py +++ b/ci/bundle_custom_data.py @@ -19,11 +19,12 @@ "pathology_nuclei_segmentation_classification", "brats_mri_generative_diffusion", "brats_mri_axial_slices_generative_diffusion", + "maisi_ct_generative", ] # This list is used for our CI tests to determine whether a bundle contains the preferred files. # If a bundle does not have any of the preferred files, please add the bundle name into the list. -exclude_verify_preferred_files_list = [] +exclude_verify_preferred_files_list = ["maisi_ct_generative"] # This list is used for our CI tests to determine whether a bundle needs to be tested with # the `verify_export_torchscript` function in `verify_bundle.py`. @@ -37,12 +38,13 @@ "mednist_reg", "brats_mri_axial_slices_generative_diffusion", "vista3d", + "maisi_ct_generative", ] # This dict is used for our CI tests to install required dependencies that cannot be installed by `pip install` directly. # If a bundle has this kind of dependencies, please add the bundle name (key), and the path of the install script (value) # into the dict. -install_dependency_dict = {} +install_dependency_dict = {"maisi_ct_generative": "ci/install_scripts/install_maisi_ct_generative_dependency.sh"} # This list is used for our CI tests to determine whether a bundle supports TensorRT export. 
Related # test will be employed for bundles in the dict. diff --git a/ci/get_bundle_requirements.py b/ci/get_bundle_requirements.py index 9306e9ca..71e69b1f 100644 --- a/ci/get_bundle_requirements.py +++ b/ci/get_bundle_requirements.py @@ -19,6 +19,8 @@ ALLOW_MONAI_RC = os.environ.get("ALLOW_MONAI_RC", "false").lower() in ("true", "1", "t", "y", "yes") +SPECIAL_LIB_LIST = ["xformers"] + def increment_version(version): """ @@ -75,10 +77,16 @@ def get_requirements(bundle, models_path): if "numpy_version" in metadata.keys(): numpy_version = metadata["numpy_version"] libs.append(f"numpy=={numpy_version}") - if "optional_packages_version" in metadata.keys(): - optional_dict = metadata["optional_packages_version"] - for name, version in optional_dict.items(): - libs.append(f"{name}=={version}") + for package_key in ["optional_packages_version", "required_packages_version"]: + if package_key in metadata.keys(): + optional_dict = metadata[package_key] + for name, version in optional_dict.items(): + if name not in SPECIAL_LIB_LIST: + libs.append(f"{name}=={version}") + else: + if "pytorch_version" in metadata.keys(): + # remove torch from libs + libs = [lib for lib in libs if "torch" not in lib] if len(libs) > 0: requirements_file_name = f"requirements_{bundle}.txt" diff --git a/ci/install_scripts/install_maisi_ct_generative_dependency.sh b/ci/install_scripts/install_maisi_ct_generative_dependency.sh new file mode 100644 index 00000000..9eea394c --- /dev/null +++ b/ci/install_scripts/install_maisi_ct_generative_dependency.sh @@ -0,0 +1 @@ +pip install --extra-index-url https://urm.nvidia.com/artifactory/api/pypi/sw-dlmed-pypi-local/simple xformers==0.0.26+622595c.d20240617 diff --git a/ci/run_premerge_cpu.sh b/ci/run_premerge_cpu.sh index 59c71916..f7326adc 100755 --- a/ci/run_premerge_cpu.sh +++ b/ci/run_premerge_cpu.sh @@ -30,6 +30,17 @@ elif [[ $# -gt 1 ]]; then exit 1 fi +# Usually, CPU test is required, but for some bundles that are too large to run in GitHub Actions, we can exclude them. +exclude_test_list=("maisi_ct_generative") +is_excluded() { + for item in "${exclude_test_list[@]}"; do + if [ "$1" == "$item" ]; then + return 0 # Return true (0) if excluded + fi + done + return 1 # Return false (1) if not excluded +} + verify_bundle() { for dir in /opt/hostedtoolcache/*; do if [[ $dir != "/opt/hostedtoolcache/Python" ]]; then @@ -52,21 +63,25 @@ verify_bundle() { echo $bundle_list for bundle in $bundle_list; do - pip install -r requirements-dev.txt - # get required libraries according to the bundle's metadata file - requirements=$(python $(pwd)/ci/get_bundle_requirements.py --b "$bundle") - # check if ALLOW_MONAI_RC is set to 1, if so, append --pre to the pip install command - if [ $ALLOW_MONAI_RC = true ]; then - include_pre_release="--pre" + if is_excluded "$bundle"; then + echo "skip '$bundle' cpu premerge tests." else - include_pre_release="" - fi - if [ ! -z "$requirements" ]; then - echo "install required libraries for bundle: $bundle" - pip install $include_pre_release -r "$requirements" + pip install -r requirements-dev.txt + # get required libraries according to the bundle's metadata file + requirements=$(python $(pwd)/ci/get_bundle_requirements.py --b "$bundle") + # check if ALLOW_MONAI_RC is set to 1, if so, append --pre to the pip install command + if [ $ALLOW_MONAI_RC = true ]; then + include_pre_release="--pre" + else + include_pre_release="" + fi + if [ !
-z "$requirements" ]; then + echo "install required libraries for bundle: $bundle" + pip install $include_pre_release -r "$requirements" + fi + # verify bundle + python $(pwd)/ci/verify_bundle.py -b "$bundle" -m "min" # min tests on cpu fi - # verify bundle - python $(pwd)/ci/verify_bundle.py -b "$bundle" -m "min" # min tests on cpu done else echo "this pull request does not change any bundles, skip verify." @@ -81,6 +96,7 @@ case $BUILD_TYPE in all) echo "Run all tests..." + verify_bundle ;; changed) echo "Run changed tests..." diff --git a/ci/run_premerge_gpu.sh b/ci/run_premerge_gpu.sh index d5ae73c8..a55b624b 100755 --- a/ci/run_premerge_gpu.sh +++ b/ci/run_premerge_gpu.sh @@ -16,36 +16,52 @@ # # Argument(s): -# BUILD_TYPE: all/specific_test_name, tests to execute +# $1 - Dist flag (True/False) + +dist_flag=$1 set -ex -BUILD_TYPE=all -export ALLOW_MONAI_RC=true -if [[ $# -eq 1 ]]; then - BUILD_TYPE=$1 +export ALLOW_MONAI_RC=true -elif [[ $# -gt 1 ]]; then +if [[ $# -gt 1 ]]; then echo "ERROR: too many parameters are provided" exit 1 fi -init_pipenv() { - echo "initializing pip environment: $1" - pipenv install update pip wheel - pipenv install --python=3.9 -r $1 - export PYTHONPATH=$PWD +init_venv() { + if [ ! -d "model_zoo_venv" ]; then # Check if the venv directory does not exist + echo "initializing pip environment: $1" + python -m venv model_zoo_venv + source model_zoo_venv/bin/activate + pip install --upgrade pip wheel + pip install -r $1 + export PYTHONPATH=$PWD + else + echo "Virtual environment model_zoo_venv already exists. Activating..." + source model_zoo_venv/bin/activate + fi +} + +remove_venv() { + if [ -d "model_zoo_venv" ]; then # Check if the venv directory exists + echo "Removing virtual environment..." + deactivate 2>/dev/null || true # Deactivate venv, ignore errors if not activated + rm -rf model_zoo_venv # Remove the venv directory + else + echo "Virtual environment not found. Skipping removal." + fi } -remove_pipenv() { - echo "removing pip environment" - pipenv --rm - rm Pipfile Pipfile.lock +set_local_env() { + echo "set local pip environment: $1" + pip install --upgrade pip wheel + pip install -r $1 + export PYTHONPATH=$PWD } verify_bundle() { echo 'Run verify bundle...' - init_pipenv requirements-dev.txt head_ref=$(git rev-parse HEAD) git fetch origin dev $head_ref # achieve all changed files in 'models' @@ -53,58 +69,69 @@ verify_bundle() { if [ ! -z "$changes" ] then # get all changed bundles - bundle_list=$(pipenv run python $(pwd)/ci/get_changed_bundle.py --f "$changes") + bundle_list=$(python $(pwd)/ci/get_changed_bundle.py --f "$changes") if [ ! 
-z "$bundle_list" ] then - pipenv run python $(pwd)/ci/prepare_schema.py --l "$bundle_list" - for bundle in $bundle_list; - do - init_pipenv requirements-dev.txt - # get required libraries according to the bundle's metadata file - requirements=$(pipenv run python $(pwd)/ci/get_bundle_requirements.py --b "$bundle") - # check if ALLOW_MONAI_RC is set to 1, if so, append --pre to the pip install command - if [ $ALLOW_MONAI_RC = true ]; then - include_pre_release="--pre" + python $(pwd)/ci/prepare_schema.py --l "$bundle_list" + for bundle in $bundle_list; + do + # Check if the bundle is "maisi_ct_generative", if so, set local environment (venv cannot work with xformers) + if [ "$bundle" == "maisi_ct_generative" ]; then + echo "Special handling for maisi_ct_generative bundle" + set_local_env requirements-dev.txt + else + init_venv requirements-dev.txt + fi + # get required libraries according to the bundle's metadata file + requirements=$(python $(pwd)/ci/get_bundle_requirements.py --b "$bundle") + # check if ALLOW_MONAI_RC is set to 1, if so, append --pre to the pip install command + if [ $ALLOW_MONAI_RC = true ]; then + include_pre_release="--pre" + else + include_pre_release="" + fi + if [ ! -z "$requirements" ]; then + echo "install required libraries for bundle: $bundle" + pip install $include_pre_release -r "$requirements" + fi + # get extra install script if exists + extra_script=$(python $(pwd)/ci/get_bundle_requirements.py --b "$bundle" --get_script True) + if [ ! -z "$extra_script" ]; then + echo "install extra libraries with script: $extra_script" + bash $extra_script + fi + # verify bundle + python $(pwd)/ci/verify_bundle.py --b "$bundle" + # unzip data and do unit tests + DATA_DIR="$(pwd)/models/maisi_ct_generative/datasets" + ZIP_FILE="$DATA_DIR/all_masks_flexible_size_and_spacing_3000.zip" + UNZIP_DIR="$DATA_DIR/all_masks_flexible_size_and_spacing_3000" + if [ -f "$ZIP_FILE" ]; then + if [ ! -d "$UNZIP_DIR" ]; then + echo "Unzipping files for MAISI Bundle..." + unzip $ZIP_FILE -d $DATA_DIR + echo "Unzipping complete." else - include_pre_release="" - fi - if [ ! -z "$requirements" ]; then - echo "install required libraries for bundle: $bundle" - pipenv install $include_pre_release -r "$requirements" + echo "Unzipped content already exists, continuing..." fi - # get extra install script if exists - extra_script=$(pipenv run python $(pwd)/ci/get_bundle_requirements.py --b "$bundle" --get_script True) - if [ ! -z "$extra_script" ]; then - echo "install extra libraries with script: $extra_script" - bash $extra_script - fi - # verify bundle - pipenv run python $(pwd)/ci/verify_bundle.py --b "$bundle" - # do unit tests - pipenv run python $(pwd)/ci/unit_tests/runner.py --b "$bundle" - remove_pipenv - done + fi + test_cmd="python $(pwd)/ci/unit_tests/runner.py --b \"$bundle\"" + if [ "$dist_flag" = "True" ]; then + test_cmd="$test_cmd --dist True" + fi + eval $test_cmd + # if not maisi_ct_generative, remove venv + if [ "$bundle" != "maisi_ct_generative" ]; then + remove_venv + fi + done else echo "this pull request does not change any bundles, skip verify." fi else echo "this pull request does not change any files in 'models', skip verify." - remove_pipenv + remove_venv fi } -case $BUILD_TYPE in - - all) - echo "Run all tests..." 
- verify_bundle - ;; - - verify_bundle) - verify_bundle - ;; - - *) - echo "ERROR: unknown parameter: $BUILD_TYPE" - ;; -esac +verify_bundle diff --git a/ci/run_premerge_multi_gpu.sh b/ci/run_premerge_multi_gpu.sh deleted file mode 100755 index c00ae8e2..00000000 --- a/ci/run_premerge_multi_gpu.sh +++ /dev/null @@ -1,108 +0,0 @@ -#!/bin/bash -# -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Argument(s): -# BUILD_TYPE: all/specific_test_name, tests to execute - -set -ex -BUILD_TYPE=all -export ALLOW_MONAI_RC=true - -if [[ $# -eq 1 ]]; then - BUILD_TYPE=$1 - -elif [[ $# -gt 1 ]]; then - echo "ERROR: too many parameters are provided" - exit 1 -fi - -init_pipenv() { - echo "initializing pip environment: $1" - pipenv install update pip wheel - pipenv install --python=3.9 -r $1 - export PYTHONPATH=$PWD -} - -remove_pipenv() { - echo "removing pip environment" - pipenv --rm - rm Pipfile Pipfile.lock -} - -verify_bundle() { - echo 'Run verify bundle...' - init_pipenv requirements-dev.txt - head_ref=$(git rev-parse HEAD) - git fetch origin dev $head_ref - # achieve all changed files in 'models' - changes=$(git diff --name-only $head_ref origin/dev -- models) - if [ ! -z "$changes" ] - then - # get all changed bundles - bundle_list=$(pipenv run python $(pwd)/ci/get_changed_bundle.py --f "$changes") - if [ ! -z "$bundle_list" ] - then - pipenv run python $(pwd)/ci/prepare_schema.py --l "$bundle_list" - for bundle in $bundle_list; - do - init_pipenv requirements-dev.txt - # get required libraries according to the bundle's metadata file - requirements=$(pipenv run python $(pwd)/ci/get_bundle_requirements.py --b "$bundle") - # check if ALLOW_MONAI_RC is set to 1, if so, append --pre to the pip install command - if [ $ALLOW_MONAI_RC = true ]; then - include_pre_release="--pre" - else - include_pre_release="" - fi - if [ ! -z "$requirements" ]; then - echo "install required libraries for bundle: $bundle" - pipenv install $include_pre_release -r "$requirements" - fi - # get extra install script if exists - extra_script=$(pipenv run python $(pwd)/ci/get_bundle_requirements.py --b "$bundle" --get_script True) - if [ ! -z "$extra_script" ]; then - echo "install extra libraries with script: $extra_script" - bash $extra_script - fi - # do multi gpu based unit tests - pipenv run torchrun $(pwd)/ci/unit_tests/runner.py --b "$bundle" --dist True - remove_pipenv - done - else - echo "this pull request does not change any bundles, skip verify." - fi - else - echo "this pull request does not change any files in 'models', skip verify." - remove_pipenv - fi -} - -case $BUILD_TYPE in - - all) - echo "Run all tests..." 
- verify_bundle - ;; - - verify_bundle) - verify_bundle - ;; - - *) - echo "ERROR: unknown parameter: $BUILD_TYPE" - ;; -esac diff --git a/ci/unit_tests/test_maisi_ct_generative.py b/ci/unit_tests/test_maisi_ct_generative.py new file mode 100644 index 00000000..3a3028c7 --- /dev/null +++ b/ci/unit_tests/test_maisi_ct_generative.py @@ -0,0 +1,330 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import sys +import tempfile +import unittest + +import nibabel as nib +import numpy as np +from monai.bundle import create_workflow +from monai.transforms import LoadImage +from parameterized import parameterized + +TEST_CASE_INFER_1 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "num_inference_steps": 2, + "output_size": [256, 256, 256], + "body_region": ["abdomen"], + "anatomy_list": ["liver"], + } +] + +# This case will definitely trigger find_closest_masks func +TEST_CASE_INFER_2 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "num_inference_steps": 2, + "output_size": [256, 256, 128], + "spacing": [1.5, 1.5, 1.5], + "body_region": ["abdomen"], + "anatomy_list": ["liver"], + } +] + +# This case will definitely trigger data augmentation for tumors +TEST_CASE_INFER_3 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "num_inference_steps": 2, + "output_size": [256, 256, 128], + "spacing": [1.5, 1.5, 1.5], + "body_region": ["abdomen"], + "anatomy_list": ["bone lesion"], + } +] + +TEST_CASE_INFER_WITH_MASK_GENERATION = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "num_inference_steps": 2, + "mask_generation_num_inference_steps": 2, + "output_size": [256, 256, 256], + "spacing": [1.5, 1.5, 2.0], + "body_region": ["chest"], + "anatomy_list": ["liver"], + "controllable_anatomy_size": [["hepatic tumor", 0.3], ["liver", 0.5]], + } +] + +TEST_CASE_INFER_DIFFERENT_OUTPUT_TYPE = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "num_inference_steps": 2, + "output_size": [256, 256, 256], + "body_region": ["abdomen"], + "anatomy_list": ["liver"], + "image_output_ext": ".dcm", + "label_output_ext": ".nrrd", + } +] + +TEST_CASE_INFER_ERROR = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "output_size": [256, 256, 256], + "body_region": ["head"], + "anatomy_list": ["colon cancer primaries"], + }, + "Cannot find body region with given organ list.", +] + +TEST_CASE_INFER_ERROR_2 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "output_size": [256, 256, 256], + "body_region": ["head_typo"], + "anatomy_list": ["brain"], + } +] + +TEST_CASE_INFER_ERROR_3 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "output_size": [256, 256, 256], + "body_region": ["head"], + "anatomy_list": ["brain_typo"], + } +] + +TEST_CASE_INFER_ERROR_4 = [ + { + "bundle_root": "models/maisi_ct_generative", + 
"num_output_samples": 1, + "output_size": [256, 256, 177], + "body_region": ["head"], + "anatomy_list": ["brain"], + } +] + +TEST_CASE_INFER_ERROR_5 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "output_size": [256, 256, 256], + "body_region": ["head"], + "anatomy_list": ["brain"], + "controllable_anatomy_size": [["hepatic tumor", 0.3], ["bone lesion", 0.5]], + } +] + +TEST_CASE_INFER_ERROR_6 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "output_size": [256, 128, 256], + "body_region": ["head"], + "anatomy_list": ["brain"], + "controllable_anatomy_size": [["hepatic tumor", 0.3], ["bone lesion", 0.5]], + } +] + +TEST_CASE_INFER_ERROR_7 = [ + { + "bundle_root": "models/maisi_ct_generative", + "num_output_samples": 1, + "output_size": [256, 256, 256], + "body_region": ["chest"], + "anatomy_list": ["colon", "spleen", "trachea", "left humerus", "sacrum", "heart"], + }, + "Cannot find body region with given organ list.", +] + +TEST_CASE_TRAIN = [ + {"bundle_root": "models/maisi_ct_generative", "epochs": 2, "initialize": ["$monai.utils.set_determinism(seed=123)"]} +] + + +TEST_CASE_TRAIN = [ + {"bundle_root": "models/maisi_ct_generative", "epochs": 2, "initialize": ["$monai.utils.set_determinism(seed=123)"]} +] + + +def check_workflow(workflow, check_properties: bool = False): + if check_properties is True: + check_result = workflow.check_properties() + if check_result is not None and len(check_result) > 0: + raise ValueError(f"check properties for workflow failed: {check_result}") + workflow.run() + workflow.finalize() + + +class TestMAISI(unittest.TestCase): + def setUp(self): + self.output_dir = tempfile.mkdtemp() + self.dataset_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.output_dir) + shutil.rmtree(self.dataset_dir) + + def create_train_dataset(self): + self.dataset_size = 5 + input_shape = (32, 32, 32, 4) + mask_shape = (128, 128, 128) + for s in range(self.dataset_size): + test_image = np.random.randint(low=0, high=2, size=input_shape).astype(np.int8) + test_label = np.random.randint(low=0, high=2, size=mask_shape).astype(np.int8) + image_filename = os.path.join(self.dataset_dir, f"image_{s}.nii.gz") + label_filename = os.path.join(self.dataset_dir, f"label_{s}.nii.gz") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), image_filename) + nib.save(nib.Nifti1Image(test_label, np.eye(4)), label_filename) + + @parameterized.expand([TEST_CASE_TRAIN]) + def test_train_config(self, override): + self.create_train_dataset() + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + "dim": [128, 128, 128], + "spacing": [1.0, 1.0, 1.0], + "top_region_index": [0, 1, 0, 0], + "bottom_region_index": [0, 0, 0, 1], + "fold": 0, + } + for i in range(train_size) + ] + override["train_datalist"] = train_datalist + + bundle_root = override["bundle_root"] + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path + trainer = create_workflow( + workflow_type="train", + config_file=os.path.join(bundle_root, "configs/train.json"), + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **override, + ) + check_workflow(trainer, check_properties=False) + + @parameterized.expand( + [ + TEST_CASE_INFER_1, + TEST_CASE_INFER_2, + TEST_CASE_INFER_3, + TEST_CASE_INFER_WITH_MASK_GENERATION, + 
TEST_CASE_INFER_DIFFERENT_OUTPUT_TYPE, + ] + ) + def test_infer_config(self, override): + # update override + override["output_dir"] = self.output_dir + bundle_root = override["bundle_root"] + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path + workflow = create_workflow( + workflow_type="infer", + config_file=os.path.join(bundle_root, "configs/inference.json"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + **override, + ) + + # check_properties=False, need to add monai service properties later + check_workflow(workflow, check_properties=False) + + # check expected output + output_files = os.listdir(self.output_dir) + output_labels = [f for f in output_files if "label" in f] + output_images = [f for f in output_files if "image" in f] + self.assertEqual(len(output_labels), override["num_output_samples"]) + self.assertEqual(len(output_images), override["num_output_samples"]) + + # check output type and shape + loader = LoadImage(image_only=True) + for output_file in output_files: + output_file_path = os.path.join(self.output_dir, output_file) + data = loader(output_file_path) + self.assertEqual(data.shape, tuple(override["output_size"])) + if "image_output_ext" in override: + if "image" in output_file: + self.assertTrue(output_file.endswith(override["image_output_ext"])) + elif "label_output_ext" in override: + if "label" in output_file: + self.assertTrue(output_file.endswith(override["label_output_ext"])) + else: + self.assertTrue(output_file.endswith(".nii.gz")) + + @parameterized.expand([TEST_CASE_INFER_ERROR, TEST_CASE_INFER_ERROR_7]) + def test_infer_config_error_input(self, override, expected_error): + # update override + override["output_dir"] = self.output_dir + bundle_root = override["bundle_root"] + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path + workflow = create_workflow( + workflow_type="infer", + config_file=os.path.join(bundle_root, "configs/inference.json"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + **override, + ) + with self.assertRaises(RuntimeError) as context: + workflow.run() + runtime_error = context.exception + original_exception = runtime_error.__cause__ + self.assertEqual(str(original_exception), expected_error) + + @parameterized.expand( + [ + TEST_CASE_INFER_ERROR_2, + TEST_CASE_INFER_ERROR_3, + TEST_CASE_INFER_ERROR_4, + TEST_CASE_INFER_ERROR_5, + TEST_CASE_INFER_ERROR_6, + ] + ) + def test_infer_config_valueerror_input(self, override): + # update override + override["output_dir"] = self.output_dir + bundle_root = override["bundle_root"] + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path + workflow = create_workflow( + workflow_type="infer", + config_file=os.path.join(bundle_root, "configs/inference.json"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + **override, + ) + with self.assertRaises(RuntimeError) as context: + workflow.run() + runtime_error = context.exception + original_exception = runtime_error.__cause__ + self.assertIsInstance(original_exception, ValueError) + + +if __name__ == "__main__": + unittest.main() diff --git a/ci/unit_tests/test_maisi_ct_generative_dist.py b/ci/unit_tests/test_maisi_ct_generative_dist.py new file mode 100644 index 00000000..081b33c1 --- /dev/null +++ b/ci/unit_tests/test_maisi_ct_generative_dist.py @@ 
-0,0 +1,83 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import sys +import tempfile +import unittest + +import nibabel as nib +import numpy as np +import torch +from parameterized import parameterized +from utils import export_config_and_run_mgpu_cmd + +TEST_CASE_TRAIN_MGPU = [{"bundle_root": "models/maisi_ct_generative", "epochs": 2}] + + +class TestMAISI(unittest.TestCase): + def setUp(self): + self.output_dir = tempfile.mkdtemp() + self.dataset_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.output_dir) + shutil.rmtree(self.dataset_dir) + + def create_train_dataset(self): + self.dataset_size = 5 + input_shape = (32, 32, 32, 4) + mask_shape = (128, 128, 128) + for s in range(self.dataset_size): + test_image = np.random.randint(low=0, high=2, size=input_shape).astype(np.int8) + test_label = np.random.randint(low=0, high=2, size=mask_shape).astype(np.int8) + image_filename = os.path.join(self.dataset_dir, f"image_{s}.nii.gz") + label_filename = os.path.join(self.dataset_dir, f"label_{s}.nii.gz") + nib.save(nib.Nifti1Image(test_image, np.eye(4)), image_filename) + nib.save(nib.Nifti1Image(test_label, np.eye(4)), label_filename) + + @parameterized.expand([TEST_CASE_TRAIN_MGPU]) + def test_train_mgpu_config(self, override): + self.create_train_dataset() + train_size = self.dataset_size // 2 + train_datalist = [ + { + "image": os.path.join(self.dataset_dir, f"image_{i}.nii.gz"), + "label": os.path.join(self.dataset_dir, f"label_{i}.nii.gz"), + "dim": [128, 128, 128], + "spacing": [1.0, 1.0, 1.0], + "top_region_index": [0, 1, 0, 0], + "bottom_region_index": [0, 0, 0, 1], + "fold": 0, + } + for i in range(train_size) + ] + override["train_datalist"] = train_datalist + + bundle_root = override["bundle_root"] + sys.path = [bundle_root] + sys.path + train_file = os.path.join(bundle_root, "configs/train.json") + mgpu_train_file = os.path.join(bundle_root, "configs/multi_gpu_train.json") + output_path = os.path.join(bundle_root, "configs/train_override.json") + n_gpu = torch.cuda.device_count() + export_config_and_run_mgpu_cmd( + config_file=[train_file, mgpu_train_file], + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + override_dict=override, + output_path=output_path, + ngpu=n_gpu, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/ci/unit_tests/test_vista3d.py b/ci/unit_tests/test_vista3d.py index b12d3e69..1f00eb31 100644 --- a/ci/unit_tests/test_vista3d.py +++ b/ci/unit_tests/test_vista3d.py @@ -275,7 +275,8 @@ def test_train_config(self, override): override["val_datalist"] = val_datalist bundle_root = override["bundle_root"] - sys.path = [bundle_root] + sys.path + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path trainer = ConfigWorkflow( workflow_type="train", config_file=os.path.join(bundle_root, "configs/train.json"), @@ -305,7 +306,8 @@ def test_eval_config(self, override): 
override["train_datalist"] = train_datalist override["val_datalist"] = val_datalist bundle_root = override["bundle_root"] - sys.path = [bundle_root] + sys.path + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path config_files = [ os.path.join(bundle_root, "configs/train.json"), os.path.join(bundle_root, "configs/train_continual.json"), @@ -342,7 +344,8 @@ def test_train_continual_config(self, override): override["val_datalist"] = val_datalist bundle_root = override["bundle_root"] - sys.path = [bundle_root] + sys.path + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path trainer = ConfigWorkflow( workflow_type="train", config_file=[ @@ -373,7 +376,8 @@ def test_infer_config(self, override): override["input_dict"] = input_dict bundle_root = override["bundle_root"] - sys.path = [bundle_root] + sys.path + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path inferrer = ConfigWorkflow( workflow_type="infer", @@ -396,7 +400,8 @@ def test_batch_infer_config(self, override): params["input_suffix"] = "image_*.nii.gz" bundle_root = override["bundle_root"] - sys.path = [bundle_root] + sys.path + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path config_files = [ os.path.join(bundle_root, "configs/inference.json"), os.path.join(bundle_root, "configs/batch_inference.json"), @@ -419,7 +424,8 @@ def test_error_prompt_infer_config(self, override): override["input_dict"] = input_dict bundle_root = override["bundle_root"] - sys.path = [bundle_root] + sys.path + if bundle_root not in sys.path: + sys.path = [bundle_root] + sys.path inferrer = ConfigWorkflow( workflow_type="infer", diff --git a/ci/unit_tests/test_vista3d_mgpu.py b/ci/unit_tests/test_vista3d_dist.py similarity index 100% rename from ci/unit_tests/test_vista3d_mgpu.py rename to ci/unit_tests/test_vista3d_dist.py diff --git a/ci/verify_bundle.py b/ci/verify_bundle.py index 076eb487..3780ff25 100644 --- a/ci/verify_bundle.py +++ b/ci/verify_bundle.py @@ -262,7 +262,7 @@ def verify_bundle_properties(model_path: str, bundle: str): supported_apps = metadata["supported_apps"] all_properties = [] for app, version in supported_apps.items(): - if app in ["vista3d-nim"]: + if app in ["vista3d-nim", "maisi-nim"]: # skip check continue properties_path = get_app_properties(app, version) @@ -282,9 +282,6 @@ def verify(bundle, models_path="models", mode="full"): # verify bundle directory verify_bundle_directory(models_path, bundle) print("directory is verified correctly.") - # verify bundle properties - verify_bundle_properties(models_path, bundle) - print("properties are verified correctly.") if mode != "regular": # verify version, changelog verify_version_changes(models_path, bundle) @@ -298,6 +295,11 @@ def verify(bundle, models_path="models", mode="full"): return # The following are optional tests and require GPU + + # verify bundle properties + verify_bundle_properties(models_path, bundle) + print("properties are verified correctly.") + net_id = _get_net_id(bundle) inference_file_name = _find_bundle_file(os.path.join(bundle_path, "configs"), "inference") config_file = os.path.join("configs", inference_file_name) diff --git a/models/maisi_ct_generative/LICENSE b/models/maisi_ct_generative/LICENSE new file mode 100644 index 00000000..7b84e73d --- /dev/null +++ b/models/maisi_ct_generative/LICENSE @@ -0,0 +1,247 @@ +Code License + +This license applies to all files except the model weights in the directory. 
+ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------ + +Model Weights License + +This license applies to model weights in the directory. + +NVIDIA License + +1. Definitions + +“Licensor” means any person or entity that distributes its Work. +“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. +The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. +Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. + +2. License Grant + +2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. + +3. Limitations + +3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. + +3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. + +3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only. + +3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. + +3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. + +3.6 Termination. 
If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. + +4. Disclaimer of Warranty. + +THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. + +5. Limitation of Liability. + +EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. diff --git a/models/maisi_ct_generative/configs/inference.json b/models/maisi_ct_generative/configs/inference.json new file mode 100644 index 00000000..d4268fc7 --- /dev/null +++ b/models/maisi_ct_generative/configs/inference.json @@ -0,0 +1,292 @@ +{ + "imports": [ + "$import torch", + "$from pathlib import Path", + "$import scripts" + ], + "bundle_root": ".", + "model_dir": "$@bundle_root + '/models'", + "output_dir": "$@bundle_root + '/output'", + "create_output_dir": "$Path(@output_dir).mkdir(exist_ok=True)", + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "trained_autoencoder_path": "$@model_dir + '/autoencoder_epoch273.pt'", + "trained_diffusion_path": "$@model_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'", + "trained_controlnet_path": "$@model_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'", + "trained_mask_generation_autoencoder_path": "$@model_dir + '/mask_generation_autoencoder.pt'", + "trained_mask_generation_diffusion_path": "$@model_dir + '/mask_generation_diffusion_unet.pt'", + "all_mask_files_base_dir": "$@bundle_root + '/datasets/all_masks_flexible_size_and_spacing_3000'", + "all_mask_files_json": "$@bundle_root + '/configs/candidate_masks_flexible_size_and_spacing_3000.json'", + "all_anatomy_size_condtions_json": "$@bundle_root + '/configs/all_anatomy_size_condtions.json'", + "label_dict_json": "$@bundle_root + '/configs/label_dict.json'", + "label_dict_remap_json": "$@bundle_root + '/configs/label_dict_124_to_132.json'", + "quality_check_args": null, + "num_output_samples": 1, + "body_region": [ + "abdomen" + ], + "anatomy_list": [ + "liver" + ], + "controllable_anatomy_size": [], + "num_inference_steps": 1000, + "mask_generation_num_inference_steps": 1000, + "random_seed": null, + "spatial_dims": 3, + "image_channels": 1, + "latent_channels": 4, + "output_size_xy": 512, + "output_size_z": 512, + "output_size": [ + "@output_size_xy", + "@output_size_xy", + "@output_size_z" + ], + "image_output_ext": ".nii.gz", + "label_output_ext": ".nii.gz", + "spacing_xy": 1.0, + "spacing_z": 1.0, + "spacing": [ + "@spacing_xy", + "@spacing_xy", + "@spacing_z" + ], + "latent_shape": [ + "@latent_channels", + "$@output_size[0]//4", + "$@output_size[1]//4", + "$@output_size[2]//4" + ], + "mask_generation_latent_shape": [ + 4, + 64, + 64, + 64 + ], + "autoencoder_def": { + "_target_": "scripts.custom_network_tp.AutoencoderKlckModifiedTp", + 
"spatial_dims": "@spatial_dims", + "in_channels": "@image_channels", + "out_channels": "@image_channels", + "latent_channels": "@latent_channels", + "num_channels": [ + 64, + 128, + 256 + ], + "num_res_blocks": [ + 2, + 2, + 2 + ], + "norm_num_groups": 32, + "norm_eps": 1e-06, + "attention_levels": [ + false, + false, + false + ], + "with_encoder_nonlocal_attn": false, + "with_decoder_nonlocal_attn": false, + "use_checkpointing": false, + "use_convtranspose": false + }, + "difusion_unet_def": { + "_target_": "scripts.custom_network_diffusion.CustomDiffusionModelUNet", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "out_channels": "@latent_channels", + "num_channels": [ + 64, + 128, + 256, + 512 + ], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], + "num_res_blocks": 2, + "use_flash_attention": true, + "input_top_region_index": true, + "input_bottom_region_index": true, + "input_spacing": true + }, + "controlnet_def": { + "_target_": "scripts.custom_network_controlnet.CustomControlNet", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "num_channels": [ + 64, + 128, + 256, + 512 + ], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], + "num_res_blocks": 2, + "use_flash_attention": true, + "conditioning_embedding_in_channels": 8, + "conditioning_embedding_num_channels": [ + 8, + 32, + 64 + ] + }, + "mask_generation_autoencoder_def": { + "_target_": "generative.networks.nets.AutoencoderKL", + "spatial_dims": 3, + "in_channels": 8, + "out_channels": 125, + "latent_channels": 4, + "num_channels": [ + 32, + 64, + 128 + ], + "num_res_blocks": [ + 1, + 2, + 2 + ], + "norm_num_groups": 32, + "norm_eps": 1e-06, + "attention_levels": [ + false, + false, + false + ], + "with_encoder_nonlocal_attn": false, + "with_decoder_nonlocal_attn": false, + "use_flash_attention": false, + "use_checkpointing": true, + "use_convtranspose": true + }, + "mask_generation_diffusion_def": { + "_target_": "generative.networks.nets.DiffusionModelUNet", + "spatial_dims": 3, + "in_channels": 4, + "out_channels": 4, + "num_channels": [ + 64, + 128, + 256, + 512 + ], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], + "num_res_blocks": 2, + "use_flash_attention": true, + "with_conditioning": true, + "upcast_attention": true, + "cross_attention_dim": 10 + }, + "autoencoder": "$@autoencoder_def.to(@device)", + "checkpoint_autoencoder": "$scripts.utils.load_autoencoder_ckpt(@trained_autoencoder_path)", + "load_autoencoder": "$@autoencoder.load_state_dict(@checkpoint_autoencoder)", + "difusion_unet": "$@difusion_unet_def.to(@device)", + "checkpoint_difusion_unet": "$torch.load(@trained_diffusion_path)", + "load_diffusion": "$@difusion_unet.load_state_dict(@checkpoint_difusion_unet['unet_state_dict'])", + "controlnet": "$@controlnet_def.to(@device)", + "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @difusion_unet.state_dict())", + "checkpoint_controlnet": "$torch.load(@trained_controlnet_path)", + "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)", + "scale_factor": "$@checkpoint_difusion_unet['scale_factor'].to(@device)", + "mask_generation_autoencoder": "$@mask_generation_autoencoder_def.to(@device)", + "checkpoint_mask_generation_autoencoder": "$torch.load(@trained_mask_generation_autoencoder_path)", 
+ "load_mask_generation_autoencoder": "$@mask_generation_autoencoder.load_state_dict(@checkpoint_mask_generation_autoencoder, strict=True)", + "mask_generation_difusion_unet": "$@mask_generation_diffusion_def.to(@device)", + "checkpoint_mask_generation_difusion_unet": "$torch.load(@trained_mask_generation_diffusion_path)", + "load_mask_generation_diffusion": "$@mask_generation_difusion_unet.load_state_dict(@checkpoint_mask_generation_difusion_unet, strict=True)", + "mask_generation_scale_factor": 1.0055984258651733, + "noise_scheduler": { + "_target_": "generative.networks.schedulers.DDPMScheduler", + "num_train_timesteps": 1000, + "beta_start": 0.0015, + "beta_end": 0.0195, + "schedule": "scaled_linear_beta", + "clip_sample": false + }, + "mask_generation_noise_scheduler": { + "_target_": "generative.networks.schedulers.DDPMScheduler", + "num_train_timesteps": 1000, + "beta_start": 0.0015, + "beta_end": 0.0195, + "schedule": "scaled_linear_beta", + "clip_sample": false + }, + "check_input": "$scripts.sample.check_input(@body_region,@anatomy_list,@label_dict_json,@output_size,@spacing,@controllable_anatomy_size)", + "ldm_sampler": { + "_target_": "scripts.sample.LDMSampler", + "_requires_": [ + "@create_output_dir", + "@load_diffusion", + "@load_autoencoder", + "@copy_controlnet_state", + "@load_controlnet", + "@load_mask_generation_autoencoder", + "@load_mask_generation_diffusion", + "@check_input" + ], + "body_region": "@body_region", + "anatomy_list": "@anatomy_list", + "all_mask_files_json": "@all_mask_files_json", + "all_anatomy_size_condtions_json": "@all_anatomy_size_condtions_json", + "all_mask_files_base_dir": "@all_mask_files_base_dir", + "label_dict_json": "@label_dict_json", + "label_dict_remap_json": "@label_dict_remap_json", + "autoencoder": "@autoencoder", + "difusion_unet": "@difusion_unet", + "controlnet": "@controlnet", + "scale_factor": "@scale_factor", + "noise_scheduler": "@noise_scheduler", + "mask_generation_autoencoder": "@mask_generation_autoencoder", + "mask_generation_difusion_unet": "@mask_generation_difusion_unet", + "mask_generation_scale_factor": "@mask_generation_scale_factor", + "mask_generation_noise_scheduler": "@mask_generation_noise_scheduler", + "controllable_anatomy_size": "@controllable_anatomy_size", + "image_output_ext": "@image_output_ext", + "label_output_ext": "@label_output_ext", + "device": "@device", + "latent_shape": "@latent_shape", + "mask_generation_latent_shape": "@mask_generation_latent_shape", + "output_size": "@output_size", + "quality_check_args": "@quality_check_args", + "spacing": "@spacing", + "output_dir": "@output_dir", + "num_inference_steps": "@num_inference_steps", + "mask_generation_num_inference_steps": "@mask_generation_num_inference_steps", + "random_seed": "@random_seed" + }, + "run": [ + "$@ldm_sampler.sample_multiple_images(@num_output_samples)" + ], + "evaluator": null +} diff --git a/models/maisi_ct_generative/configs/integration_test_masks.json b/models/maisi_ct_generative/configs/integration_test_masks.json new file mode 100644 index 00000000..44e9cae7 --- /dev/null +++ b/models/maisi_ct_generative/configs/integration_test_masks.json @@ -0,0 +1,98 @@ +[ + { + "bottom_region_index": [ + 0, + 0, + 0, + 1 + ], + "dim": [ + 512, + 512, + 512 + ], + "label_list": [ + 1, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 17, + 19, + 25, + 28, + 29, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 58, + 59, + 60, + 61, + 62, + 69, + 70, + 71, + 72, + 73, + 74, + 81, + 82, + 
83, + 84, + 85, + 86, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 114, + 115, + 118, + 121, + 122, + 127 + ], + "pseudo_label_filename": "./IntegrationTest-AbdomenCT.nii.gz", + "spacing": [ + 1.0, + 1.0, + 1.0 + ], + "top_region_index": [ + 0, + 1, + 0, + 0 + ] + } +] diff --git a/models/maisi_ct_generative/configs/label_dict.json b/models/maisi_ct_generative/configs/label_dict.json new file mode 100644 index 00000000..44bbb645 --- /dev/null +++ b/models/maisi_ct_generative/configs/label_dict.json @@ -0,0 +1,134 @@ +{ + "liver": 1, + "dummy1": 2, + "spleen": 3, + "pancreas": 4, + "right kidney": 5, + "aorta": 6, + "inferior vena cava": 7, + "right adrenal gland": 8, + "left adrenal gland": 9, + "gallbladder": 10, + "esophagus": 11, + "stomach": 12, + "duodenum": 13, + "left kidney": 14, + "bladder": 15, + "dummy2": 16, + "portal vein and splenic vein": 17, + "dummy3": 18, + "small bowel": 19, + "dummy4": 20, + "dummy5": 21, + "brain": 22, + "lung tumor": 23, + "pancreatic tumor": 24, + "hepatic vessel": 25, + "hepatic tumor": 26, + "colon cancer primaries": 27, + "left lung upper lobe": 28, + "left lung lower lobe": 29, + "right lung upper lobe": 30, + "right lung middle lobe": 31, + "right lung lower lobe": 32, + "vertebrae L5": 33, + "vertebrae L4": 34, + "vertebrae L3": 35, + "vertebrae L2": 36, + "vertebrae L1": 37, + "vertebrae T12": 38, + "vertebrae T11": 39, + "vertebrae T10": 40, + "vertebrae T9": 41, + "vertebrae T8": 42, + "vertebrae T7": 43, + "vertebrae T6": 44, + "vertebrae T5": 45, + "vertebrae T4": 46, + "vertebrae T3": 47, + "vertebrae T2": 48, + "vertebrae T1": 49, + "vertebrae C7": 50, + "vertebrae C6": 51, + "vertebrae C5": 52, + "vertebrae C4": 53, + "vertebrae C3": 54, + "vertebrae C2": 55, + "vertebrae C1": 56, + "trachea": 57, + "left iliac artery": 58, + "right iliac artery": 59, + "left iliac vena": 60, + "right iliac vena": 61, + "colon": 62, + "left rib 1": 63, + "left rib 2": 64, + "left rib 3": 65, + "left rib 4": 66, + "left rib 5": 67, + "left rib 6": 68, + "left rib 7": 69, + "left rib 8": 70, + "left rib 9": 71, + "left rib 10": 72, + "left rib 11": 73, + "left rib 12": 74, + "right rib 1": 75, + "right rib 2": 76, + "right rib 3": 77, + "right rib 4": 78, + "right rib 5": 79, + "right rib 6": 80, + "right rib 7": 81, + "right rib 8": 82, + "right rib 9": 83, + "right rib 10": 84, + "right rib 11": 85, + "right rib 12": 86, + "left humerus": 87, + "right humerus": 88, + "left scapula": 89, + "right scapula": 90, + "left clavicula": 91, + "right clavicula": 92, + "left femur": 93, + "right femur": 94, + "left hip": 95, + "right hip": 96, + "sacrum": 97, + "left gluteus maximus": 98, + "right gluteus maximus": 99, + "left gluteus medius": 100, + "right gluteus medius": 101, + "left gluteus minimus": 102, + "right gluteus minimus": 103, + "left autochthon": 104, + "right autochthon": 105, + "left iliopsoas": 106, + "right iliopsoas": 107, + "left atrial appendage": 108, + "brachiocephalic trunk": 109, + "left brachiocephalic vein": 110, + "right brachiocephalic vein": 111, + "left common carotid artery": 112, + "right common carotid artery": 113, + "costal cartilages": 114, + "heart": 115, + "left kidney cyst": 116, + "right kidney cyst": 117, + "prostate": 118, + "pulmonary vein": 119, + "skull": 120, + "spinal cord": 121, + "sternum": 122, + "left subclavian artery": 123, + "right subclavian artery": 124, + "superior vena cava": 125, + "thyroid gland": 126, + "vertebrae S1": 127, + "bone lesion": 128, + 
"dummy6": 129, + "dummy7": 130, + "dummy8": 131, + "airway": 132 +} diff --git a/models/maisi_ct_generative/configs/label_dict_124_to_132.json b/models/maisi_ct_generative/configs/label_dict_124_to_132.json new file mode 100644 index 00000000..96ff311e --- /dev/null +++ b/models/maisi_ct_generative/configs/label_dict_124_to_132.json @@ -0,0 +1,502 @@ +{ + "background": [ + 0, + 0 + ], + "liver": [ + 1, + 1 + ], + "spleen": [ + 2, + 3 + ], + "pancreas": [ + 3, + 4 + ], + "right kidney": [ + 4, + 5 + ], + "aorta": [ + 5, + 6 + ], + "inferior vena cava": [ + 6, + 7 + ], + "right adrenal gland": [ + 7, + 8 + ], + "left adrenal gland": [ + 8, + 9 + ], + "gallbladder": [ + 9, + 10 + ], + "esophagus": [ + 10, + 11 + ], + "stomach": [ + 11, + 12 + ], + "duodenum": [ + 12, + 13 + ], + "left kidney": [ + 13, + 14 + ], + "bladder": [ + 14, + 15 + ], + "portal vein and splenic vein": [ + 15, + 17 + ], + "small bowel": [ + 16, + 19 + ], + "brain": [ + 17, + 22 + ], + "lung tumor": [ + 18, + 23 + ], + "pancreatic tumor": [ + 19, + 24 + ], + "hepatic vessel": [ + 20, + 25 + ], + "hepatic tumor": [ + 21, + 26 + ], + "colon cancer primaries": [ + 22, + 27 + ], + "left lung upper lobe": [ + 23, + 28 + ], + "left lung lower lobe": [ + 24, + 29 + ], + "right lung upper lobe": [ + 25, + 30 + ], + "right lung middle lobe": [ + 26, + 31 + ], + "right lung lower lobe": [ + 27, + 32 + ], + "vertebrae L5": [ + 28, + 33 + ], + "vertebrae L4": [ + 29, + 34 + ], + "vertebrae L3": [ + 30, + 35 + ], + "vertebrae L2": [ + 31, + 36 + ], + "vertebrae L1": [ + 32, + 37 + ], + "vertebrae T12": [ + 33, + 38 + ], + "vertebrae T11": [ + 34, + 39 + ], + "vertebrae T10": [ + 35, + 40 + ], + "vertebrae T9": [ + 36, + 41 + ], + "vertebrae T8": [ + 37, + 42 + ], + "vertebrae T7": [ + 38, + 43 + ], + "vertebrae T6": [ + 39, + 44 + ], + "vertebrae T5": [ + 40, + 45 + ], + "vertebrae T4": [ + 41, + 46 + ], + "vertebrae T3": [ + 42, + 47 + ], + "vertebrae T2": [ + 43, + 48 + ], + "vertebrae T1": [ + 44, + 49 + ], + "vertebrae C7": [ + 45, + 50 + ], + "vertebrae C6": [ + 46, + 51 + ], + "vertebrae C5": [ + 47, + 52 + ], + "vertebrae C4": [ + 48, + 53 + ], + "vertebrae C3": [ + 49, + 54 + ], + "vertebrae C2": [ + 50, + 55 + ], + "vertebrae C1": [ + 51, + 56 + ], + "trachea": [ + 52, + 57 + ], + "left iliac artery": [ + 53, + 58 + ], + "right iliac artery": [ + 54, + 59 + ], + "left iliac vena": [ + 55, + 60 + ], + "right iliac vena": [ + 56, + 61 + ], + "colon": [ + 57, + 62 + ], + "left rib 1": [ + 58, + 63 + ], + "left rib 2": [ + 59, + 64 + ], + "left rib 3": [ + 60, + 65 + ], + "left rib 4": [ + 61, + 66 + ], + "left rib 5": [ + 62, + 67 + ], + "left rib 6": [ + 63, + 68 + ], + "left rib 7": [ + 64, + 69 + ], + "left rib 8": [ + 65, + 70 + ], + "left rib 9": [ + 66, + 71 + ], + "left rib 10": [ + 67, + 72 + ], + "left rib 11": [ + 68, + 73 + ], + "left rib 12": [ + 69, + 74 + ], + "right rib 1": [ + 70, + 75 + ], + "right rib 2": [ + 71, + 76 + ], + "right rib 3": [ + 72, + 77 + ], + "right rib 4": [ + 73, + 78 + ], + "right rib 5": [ + 74, + 79 + ], + "right rib 6": [ + 75, + 80 + ], + "right rib 7": [ + 76, + 81 + ], + "right rib 8": [ + 77, + 82 + ], + "right rib 9": [ + 78, + 83 + ], + "right rib 10": [ + 79, + 84 + ], + "right rib 11": [ + 80, + 85 + ], + "right rib 12": [ + 81, + 86 + ], + "left humerus": [ + 82, + 87 + ], + "right humerus": [ + 83, + 88 + ], + "left scapula": [ + 84, + 89 + ], + "right scapula": [ + 85, + 90 + ], + "left clavicula": [ + 86, + 91 + ], + "right clavicula": [ + 87, + 92 + ], + "left femur": [ + 
88, + 93 + ], + "right femur": [ + 89, + 94 + ], + "left hip": [ + 90, + 95 + ], + "right hip": [ + 91, + 96 + ], + "sacrum": [ + 92, + 97 + ], + "left gluteus maximus": [ + 93, + 98 + ], + "right gluteus maximus": [ + 94, + 99 + ], + "left gluteus medius": [ + 95, + 100 + ], + "right gluteus medius": [ + 96, + 101 + ], + "left gluteus minimus": [ + 97, + 102 + ], + "right gluteus minimus": [ + 98, + 103 + ], + "left autochthon": [ + 99, + 104 + ], + "right autochthon": [ + 100, + 105 + ], + "left iliopsoas": [ + 101, + 106 + ], + "right iliopsoas": [ + 102, + 107 + ], + "left atrial appendage": [ + 103, + 108 + ], + "brachiocephalic trunk": [ + 104, + 109 + ], + "left brachiocephalic vein": [ + 105, + 110 + ], + "right brachiocephalic vein": [ + 106, + 111 + ], + "left common carotid artery": [ + 107, + 112 + ], + "right common carotid artery": [ + 108, + 113 + ], + "costal cartilages": [ + 109, + 114 + ], + "heart": [ + 110, + 115 + ], + "prostate": [ + 111, + 118 + ], + "pulmonary vein": [ + 112, + 119 + ], + "skull": [ + 113, + 120 + ], + "spinal cord": [ + 114, + 121 + ], + "sternum": [ + 115, + 122 + ], + "left subclavian artery": [ + 116, + 123 + ], + "right subclavian artery": [ + 117, + 124 + ], + "superior vena cava": [ + 118, + 125 + ], + "thyroid gland": [ + 119, + 126 + ], + "vertebrae S1": [ + 120, + 127 + ], + "bone lesion": [ + 121, + 128 + ], + "kidney mass": [ + 122, + 129 + ], + "airway": [ + 123, + 132 + ], + "body": [ + 124, + 200 + ] +} diff --git a/models/maisi_ct_generative/configs/logging.conf b/models/maisi_ct_generative/configs/logging.conf new file mode 100644 index 00000000..91c1a21c --- /dev/null +++ b/models/maisi_ct_generative/configs/logging.conf @@ -0,0 +1,21 @@ +[loggers] +keys=root + +[handlers] +keys=consoleHandler + +[formatters] +keys=fullFormatter + +[logger_root] +level=INFO +handlers=consoleHandler + +[handler_consoleHandler] +class=StreamHandler +level=INFO +formatter=fullFormatter +args=(sys.stdout,) + +[formatter_fullFormatter] +format=%(asctime)s - %(name)s - %(levelname)s - %(message)s diff --git a/models/maisi_ct_generative/configs/metadata.json b/models/maisi_ct_generative/configs/metadata.json new file mode 100644 index 00000000..87c0eff2 --- /dev/null +++ b/models/maisi_ct_generative/configs/metadata.json @@ -0,0 +1,263 @@ +{ + "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_generator_ldm_20240318.json", + "version": "0.3.6", + "changelog": { + "0.3.6": "first oss version" + }, + "monai_version": "1.3.1", + "pytorch_version": "2.2.2", + "numpy_version": "1.24.4", + "optional_packages_version": { + "fire": "0.6.0", + "nibabel": "5.2.1", + "monai-generative": "0.2.3", + "tqdm": "4.66.2", + "xformers": "0.0.26" + }, + "supported_apps": { + "maisi-nim": "" + }, + "name": "CT image latent diffusion generation", + "task": "CT image synthesis", + "description": "A generative model for creating 3D CT from Gaussian noise", + "authors": "MONAI team", + "copyright": "Copyright (c) MONAI Consortium", + "data_source": "http://medicaldecathlon.com/", + "data_type": "nibabel", + "image_classes": "Flair brain MRI with 1.1x1.1x1.1 mm voxel size", + "eval_metrics": {}, + "intended_use": "This is a research tool/prototype and not to be used clinically", + "references": [], + "autoencoder_data_format": { + "inputs": { + "image": { + "type": "feature", + "format": "image", + "num_channels": 4, + "spatial_shape": [ + 128, + 128, + 128 + ], + "dtype": "float16", + "value_range": [ + 0, + 1 + ], + 
"is_patch_data": true + }, + "body_region": { + "type": "array", + "value_range": [ + "head", + "abdomen", + "chest/thorax", + "pelvis/lower" + ] + }, + "anatomy_list": { + "type": "array", + "value_range": [ + "liver", + "spleen", + "pancreas", + "right kidney", + "aorta", + "inferior vena cava", + "right adrenal gland", + "left adrenal gland", + "gallbladder", + "esophagus", + "stomach", + "duodenum", + "left kidney", + "bladder", + "portal vein and splenic vein", + "small bowel", + "brain", + "lung tumor", + "pancreatic tumor", + "hepatic vessel", + "hepatic tumor", + "colon cancer primaries", + "left lung upper lobe", + "left lung lower lobe", + "right lung upper lobe", + "right lung middle lobe", + "right lung lower lobe", + "vertebrae L5", + "vertebrae L4", + "vertebrae L3", + "vertebrae L2", + "vertebrae L1", + "vertebrae T12", + "vertebrae T11", + "vertebrae T10", + "vertebrae T9", + "vertebrae T8", + "vertebrae T7", + "vertebrae T6", + "vertebrae T5", + "vertebrae T4", + "vertebrae T3", + "vertebrae T2", + "vertebrae T1", + "vertebrae C7", + "vertebrae C6", + "vertebrae C5", + "vertebrae C4", + "vertebrae C3", + "vertebrae C2", + "vertebrae C1", + "trachea", + "left iliac artery", + "right iliac artery", + "left iliac vena", + "right iliac vena", + "colon", + "left rib 1", + "left rib 2", + "left rib 3", + "left rib 4", + "left rib 5", + "left rib 6", + "left rib 7", + "left rib 8", + "left rib 9", + "left rib 10", + "left rib 11", + "left rib 12", + "right rib 1", + "right rib 2", + "right rib 3", + "right rib 4", + "right rib 5", + "right rib 6", + "right rib 7", + "right rib 8", + "right rib 9", + "right rib 10", + "right rib 11", + "right rib 12", + "left humerus", + "right humerus", + "left scapula", + "right scapula", + "left clavicula", + "right clavicula", + "left femur", + "right femur", + "left hip", + "right hip", + "sacrum", + "left gluteus maximus", + "right gluteus maximus", + "left gluteus medius", + "right gluteus medius", + "left gluteus minimus", + "right gluteus minimus", + "left autochthon", + "right autochthon", + "left iliopsoas", + "right iliopsoas", + "left atrial appendage", + "brachiocephalic trunk", + "left brachiocephalic vein", + "right brachiocephalic vein", + "left common carotid artery", + "right common carotid artery", + "costal cartilages", + "heart", + "left kidney cyst", + "right kidney cyst", + "prostate", + "pulmonary vein", + "skull", + "spinal cord", + "sternum", + "left subclavian artery", + "right subclavian artery", + "superior vena cava", + "thyroid gland", + "vertebrae S1", + "bone lesion", + "airway" + ] + } + }, + "outputs": { + "pred": { + "type": "image", + "format": "image", + "num_channels": 1, + "spatial_shape": [ + 512, + 512, + 512 + ], + "dtype": "float16", + "value_range": [ + 0, + 1 + ], + "is_patch_data": true, + "channel_def": { + "0": "image" + } + } + } + }, + "generator_data_format": { + "inputs": { + "latent": { + "type": "noise", + "format": "image", + "num_channels": 4, + "spatial_shape": [ + 128, + 128, + 128 + ], + "dtype": "float16", + "value_range": [ + 0, + 1 + ], + "is_patch_data": true + }, + "condition": { + "type": "timesteps", + "format": "timesteps", + "num_channels": 1, + "spatial_shape": [], + "dtype": "long", + "value_range": [ + 0, + 1000 + ], + "is_patch_data": false + } + }, + "outputs": { + "pred": { + "type": "feature", + "format": "image", + "num_channels": 4, + "spatial_shape": [ + 128, + 128, + 128 + ], + "dtype": "float16", + "value_range": [ + 0, + 1 + ], + "is_patch_data": true, + 
"channel_def": { + "0": "image" + } + } + } + } +} diff --git a/models/maisi_ct_generative/configs/multi_gpu_train.json b/models/maisi_ct_generative/configs/multi_gpu_train.json new file mode 100644 index 00000000..2d0e1929 --- /dev/null +++ b/models/maisi_ct_generative/configs/multi_gpu_train.json @@ -0,0 +1,34 @@ +{ + "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])", + "use_tensorboard": "$dist.get_rank() == 0", + "controlnet": { + "_target_": "torch.nn.parallel.DistributedDataParallel", + "module": "$@controlnet_def.to(@device)", + "find_unused_parameters": true, + "device_ids": [ + "@device" + ] + }, + "load_controlnet": "$@controlnet.module.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)", + "train#sampler": { + "_target_": "DistributedSampler", + "dataset": "@train#dataset", + "even_divisible": true, + "shuffle": true + }, + "train#dataloader#sampler": "@train#sampler", + "train#dataloader#shuffle": false, + "train#trainer#train_handlers": "$@train#handlers[: -1 if dist.get_rank() > 0 else None]", + "initialize": [ + "$import torch.distributed as dist", + "$dist.is_initialized() or dist.init_process_group(backend='nccl')", + "$torch.cuda.set_device(@device)", + "$monai.utils.set_determinism(seed=123)" + ], + "run": [ + "$@train#trainer.run()" + ], + "finalize": [ + "$dist.is_initialized() and dist.destroy_process_group()" + ] +} diff --git a/models/maisi_ct_generative/configs/train.json b/models/maisi_ct_generative/configs/train.json new file mode 100644 index 00000000..33ba088e --- /dev/null +++ b/models/maisi_ct_generative/configs/train.json @@ -0,0 +1,270 @@ +{ + "imports": [ + "$import glob", + "$import os", + "$import scripts", + "$import ignite" + ], + "bundle_root": ".", + "ckpt_dir": "$@bundle_root + '/models'", + "output_dir": "$@bundle_root + '/output'", + "data_list_file_path": "$@bundle_root + '/datasets/C4KC-KiTS_subset.json'", + "dataset_dir": "$@bundle_root + '/datasets/C4KC-KiTS_subset'", + "trained_diffusion_path": "$@ckpt_dir + '/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt'", + "trained_controlnet_path": "$@ckpt_dir + '/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt'", + "use_tensorboard": true, + "fold": 0, + "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "epochs": 100, + "batch_size": 1, + "val_at_start": false, + "learning_rate": 0.0001, + "weighted_loss_label": [ + 129 + ], + "weighted_loss": 100, + "amp": true, + "train_datalist": "$scripts.utils.maisi_datafold_read(json_list=@data_list_file_path, data_base_dir=@dataset_dir, fold=@fold)[0]", + "spatial_dims": 3, + "image_channels": 1, + "latent_channels": 4, + "difusion_unet_def": { + "_target_": "scripts.custom_network_diffusion.CustomDiffusionModelUNet", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "out_channels": "@latent_channels", + "num_channels": [ + 64, + 128, + 256, + 512 + ], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], + "num_res_blocks": 2, + "use_flash_attention": true, + "input_top_region_index": true, + "input_bottom_region_index": true, + "input_spacing": true + }, + "controlnet_def": { + "_target_": "scripts.custom_network_controlnet.CustomControlNet", + "spatial_dims": "@spatial_dims", + "in_channels": "@latent_channels", + "num_channels": [ + 64, + 128, + 256, + 512 + ], + "attention_levels": [ + false, + false, + true, + true + ], + "num_head_channels": [ + 0, + 0, + 32, + 32 + ], 
+ "num_res_blocks": 2, + "use_flash_attention": true, + "conditioning_embedding_in_channels": 8, + "conditioning_embedding_num_channels": [ + 8, + 32, + 64 + ] + }, + "noise_scheduler": { + "_target_": "generative.networks.schedulers.DDPMScheduler", + "num_train_timesteps": 1000, + "beta_start": 0.0015, + "beta_end": 0.0195, + "schedule": "scaled_linear_beta", + "clip_sample": false + }, + "unzip_dataset": "scripts.utils.unzip_dataset(@dataset_dir)", + "difusion_unet": "$@difusion_unet_def.to(@device)", + "checkpoint_difusion_unet": "$torch.load(@trained_diffusion_path)", + "load_diffusion": "$@difusion_unet.load_state_dict(@checkpoint_difusion_unet['unet_state_dict'])", + "controlnet": "$@controlnet_def.to(@device)", + "copy_controlnet_state": "$monai.networks.utils.copy_model_state(@controlnet, @difusion_unet.state_dict())", + "checkpoint_controlnet": "$torch.load(@trained_controlnet_path)", + "load_controlnet": "$@controlnet.load_state_dict(@checkpoint_controlnet['controlnet_state_dict'], strict=True)", + "scale_factor": "$@checkpoint_controlnet['scale_factor'].to(@device)", + "loss": { + "_target_": "torch.nn.L1Loss", + "reduction": "none" + }, + "optimizer": { + "_target_": "torch.optim.AdamW", + "params": "$@controlnet.parameters()", + "lr": "@learning_rate", + "weight_decay": 1e-05 + }, + "lr_schedule": { + "activate": true, + "lr_scheduler": { + "_target_": "torch.optim.lr_scheduler.PolynomialLR", + "optimizer": "@optimizer", + "total_iters": "$(@epochs * len(@train#dataloader.dataset)) / @batch_size", + "power": 2.0 + } + }, + "train": { + "deterministic_transforms": [ + { + "_target_": "LoadImaged", + "keys": [ + "image", + "label" + ], + "image_only": true, + "ensure_channel_first": true + }, + { + "_target_": "Orientationd", + "keys": [ + "label" + ], + "axcodes": "RAS" + }, + { + "_target_": "EnsureTyped", + "keys": [ + "label" + ], + "dtype": "$torch.uint8", + "track_meta": true + }, + { + "_target_": "Lambdad", + "keys": "top_region_index", + "func": "$lambda x: torch.FloatTensor(x)" + }, + { + "_target_": "Lambdad", + "keys": "bottom_region_index", + "func": "$lambda x: torch.FloatTensor(x)" + }, + { + "_target_": "Lambdad", + "keys": "spacing", + "func": "$lambda x: torch.FloatTensor(x)" + }, + { + "_target_": "Lambdad", + "keys": "top_region_index", + "func": "$lambda x: x * 1e2" + }, + { + "_target_": "Lambdad", + "keys": "bottom_region_index", + "func": "$lambda x: x * 1e2" + }, + { + "_target_": "Lambdad", + "keys": "spacing", + "func": "$lambda x: x * 1e2" + } + ], + "inferer": { + "_target_": "SimpleInferer" + }, + "preprocessing": { + "_target_": "Compose", + "transforms": "$@train#deterministic_transforms" + }, + "dataset": { + "_target_": "Dataset", + "data": "@train_datalist", + "transform": "@train#preprocessing" + }, + "dataloader": { + "_target_": "DataLoader", + "dataset": "@train#dataset", + "batch_size": "@batch_size", + "shuffle": true, + "num_workers": 4, + "pin_memory": true, + "persistent_workers": true + }, + "handlers": [ + { + "_target_": "LrScheduleHandler", + "_disabled_": "$not @lr_schedule#activate", + "lr_scheduler": "@lr_schedule#lr_scheduler", + "epoch_level": false, + "print_lr": true + }, + { + "_target_": "CheckpointSaver", + "save_dir": "@ckpt_dir", + "save_dict": { + "controlnet_state_dict": "@controlnet" + }, + "save_interval": 1, + "n_saved": 5 + }, + { + "_target_": "TensorBoardStatsHandler", + "_disabled_": "$not @use_tensorboard", + "log_dir": "@output_dir", + "tag_name": "train_loss", + "output_transform": 
"$monai.handlers.from_engine(['loss'], first=True)" + }, + { + "_target_": "StatsHandler", + "tag_name": "train_loss", + "name": "StatsHandler", + "output_transform": "$monai.handlers.from_engine(['loss'], first=True)" + } + ], + "trainer": { + "_target_": "scripts.trainer.MAISIControlNetTrainer", + "_requires_": [ + "@load_diffusion", + "@copy_controlnet_state", + "@load_controlnet", + "@unzip_dataset" + ], + "max_epochs": "@epochs", + "device": "@device", + "train_data_loader": "@train#dataloader", + "difusion_unet": "@difusion_unet", + "controlnet": "@controlnet", + "noise_scheduler": "@noise_scheduler", + "loss_function": "@loss", + "optimizer": "@optimizer", + "inferer": "@train#inferer", + "key_train_metric": null, + "train_handlers": "@train#handlers", + "amp": "@amp", + "hyper_kwargs": { + "weighted_loss": "@weighted_loss", + "weighted_loss_label": "@weighted_loss_label", + "scale_factor": "@scale_factor" + } + } + }, + "initialize": [ + "$monai.utils.set_determinism(seed=0)" + ], + "run": [ + "$@train#trainer.add_event_handler(ignite.engine.Events.ITERATION_COMPLETED, ignite.handlers.TerminateOnNan())", + "$@train#trainer.run()" + ] +} diff --git a/models/maisi_ct_generative/docs/README.md b/models/maisi_ct_generative/docs/README.md new file mode 100644 index 00000000..bcb7c05f --- /dev/null +++ b/models/maisi_ct_generative/docs/README.md @@ -0,0 +1,103 @@ +# Model Overview +This bundle is for Nvidia MAISI (Medical AI for Synthetic Imaging), a 3D Latent Diffusion Model that can generate large CT images with paired segmentation masks, variable volume size and voxel size, as well as controllable organ/tumor size. + +The inference workflow of MAISI is depicted in the figure below. It first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Then it decodes the denoised latent features into images using the trained autoencoder. + +

+ *Figure: MAISI inference scheme*
+

+
+MAISI is based on the following papers:
+
+[**Latent Diffusion:** Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.](https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf)
+
+[**ControlNet:** Lvmin Zhang, Anyi Rao, Maneesh Agrawala; “Adding Conditional Control to Text-to-Image Diffusion Models.” ICCV 2023.](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhang_Adding_Conditional_Control_to_Text-to-Image_Diffusion_Models_ICCV_2023_paper.pdf)
+
+#### Example synthetic image
+An example result from inference is shown below:
+![Example synthetic image](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_synthetic_data.png)
+
+### Inference configuration
+The inference requires:
+- GPU: at least 58GB GPU memory for a 512 x 512 x 512 output volume
+- Disk: at least 21GB of free disk space
+
+### Execute inference
+The following command generates a synthetic image from randomly sampled noise.
+```
+python -m monai.bundle run --config_file configs/inference.json
+```
+
+## Execute Finetuning
+
+### Training configuration
+The training was performed with the following:
+- GPU: at least 60GB GPU memory for a 512 x 512 x 512 volume
+- Actual Model Input (the size of the image embedding in latent space): 128 x 128 x 128
+- AMP: True
+
+### Run finetuning
+This config executes finetuning of the pretrained ControlNet with a new class (i.e., Kidney Tumor). When finetuning with new class names, please update `weighted_loss_label` in `configs/train.json` and `configs/label_dict.json` accordingly. There are 8 dummy labels as placeholders in the default `configs/label_dict.json` that can be used for finetuning.
+```
+python -m monai.bundle run --config_file configs/train.json
+```
+
+### Override the `train` config to execute multi-GPU training
+
+```
+torchrun --standalone --nnodes=1 --nproc_per_node=2 -m monai.bundle run --config_file "['configs/train.json','configs/multi_gpu_train.json']"
+```
+
+### Data
+The preprocessed subset of the [C4KC-KiTS](https://www.cancerimagingarchive.net/collection/c4kc-kits/) dataset used in this finetuning config is provided in `./datasets/C4KC-KiTS_subset`.
+```
+            |-*arterial*.nii.gz           # original image
+            |-*arterial_emb*.nii.gz       # encoded image embedding
+KiTS-000* --|-mask*.nii.gz                # original labels
+            |-mask_pseudo_label*.nii.gz   # pseudo labels
+            |-mask_combined_label*.nii.gz # combined mask of original and pseudo labels
+
+```
+An example combined mask of original and pseudo labels is shown below:
+![example_combined_mask](https://developer.download.nvidia.com/assets/Clara/Images/monai_maisi_ct_generative_example_combined_mask.png)
+
+Please note that the label of Kidney Tumor is mapped to index `129` in this preprocessed dataset. The encoded image embedding is generated by the provided `Autoencoder` (`./models/autoencoder_epoch273.pt`) during preprocessing to save memory during training. The pseudo labels are generated by [VISTA 3D](https://github.com/Project-MONAI/VISTA). In addition, the dimensions of each volume and its corresponding pseudo label are resampled to the closest multiple of 128 (e.g., 128, 256, 384, 512, ...).
+
+The training workflow requires one JSON file to specify the image embedding and segmentation pairs. An example file is located at `./datasets/C4KC-KiTS_subset.json`.
+ +The JSON file has the following structure: +```python +{ + "training": [ + { + "image": "*/*arterial_emb*.nii.gz", # relative path to the image embedding file + "label": "*/mask_combined_label*.nii.gz", # relative path to the combined label file + "dim": [512, 512, 512], # the dimension of image + "spacing": [1.0, 1.0, 1.0], # the spacing of image + "top_region_index": [0, 1, 0, 0], # the top region index of the image + "bottom_region_index": [0, 0, 0, 1], # the bottom region index of the image + "fold": 0 # fold index for cross validation, fold 0 is used for training + }, + + ... + ] +} +``` + +# References +[1] Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022. https://openaccess.thecvf.com/content/CVPR2022/papers/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.pdf + +# License + +## Code License + +This project includes code licensed under the Apache License 2.0. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +## Model Weights License + +The model weights included in this project are licensed under the NCLS v1 License. + +Both licenses' full texts have been combined into a single `LICENSE` file. Please refer to this `LICENSE` file for more details about the terms and conditions of both licenses. diff --git a/models/maisi_ct_generative/docs/data_license.txt b/models/maisi_ct_generative/docs/data_license.txt new file mode 100644 index 00000000..d3d7e227 --- /dev/null +++ b/models/maisi_ct_generative/docs/data_license.txt @@ -0,0 +1,49 @@ +Third Party Licenses +----------------------------------------------------------------------- + +/*********************************************************************/ +i. Multimodal Brain Tumor Segmentation Challenge 2018 + https://www.med.upenn.edu/sbia/brats2018/data.html +/*********************************************************************/ + +Data Usage Agreement / Citations + +You are free to use and/or refer to the BraTS datasets in your own +research, provided that you always cite the following two manuscripts: + +[1] Menze BH, Jakab A, Bauer S, Kalpathy-Cramer J, Farahani K, Kirby +[J, Burren Y, Porz N, Slotboom J, Wiest R, Lanczi L, Gerstner E, Weber +[MA, Arbel T, Avants BB, Ayache N, Buendia P, Collins DL, Cordier N, +[Corso JJ, Criminisi A, Das T, Delingette H, Demiralp Γ, Durst CR, +[Dojat M, Doyle S, Festa J, Forbes F, Geremia E, Glocker B, Golland P, +[Guo X, Hamamci A, Iftekharuddin KM, Jena R, John NM, Konukoglu E, +[Lashkari D, Mariz JA, Meier R, Pereira S, Precup D, Price SJ, Raviv +[TR, Reza SM, Ryan M, Sarikaya D, Schwartz L, Shin HC, Shotton J, +[Silva CA, Sousa N, Subbanna NK, Szekely G, Taylor TJ, Thomas OM, +[Tustison NJ, Unal G, Vasseur F, Wintermark M, Ye DH, Zhao L, Zhao B, +[Zikic D, Prastawa M, Reyes M, Van Leemput K. "The Multimodal Brain +[Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on +[Medical Imaging 34(10), 1993-2024 (2015) DOI: +[10.1109/TMI.2014.2377694 + +[2] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby JS, +[Freymann JB, Farahani K, Davatzikos C. 
"Advancing The Cancer Genome +[Atlas glioma MRI collections with expert segmentation labels and +[radiomic features", Nature Scientific Data, 4:170117 (2017) DOI: +[10.1038/sdata.2017.117 + +In addition, if there are no restrictions imposed from the +journal/conference you submit your paper about citing "Data +Citations", please be specific and also cite the following: + +[3] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J, +[Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and +[Radiomic Features for the Pre-operative Scans of the TCGA-GBM +[collection", The Cancer Imaging Archive, 2017. DOI: +[10.7937/K9/TCIA.2017.KLXWJJ1Q + +[4] Bakas S, Akbari H, Sotiras A, Bilello M, Rozycki M, Kirby J, +[Freymann J, Farahani K, Davatzikos C. "Segmentation Labels and +[Radiomic Features for the Pre-operative Scans of the TCGA-LGG +[collection", The Cancer Imaging Archive, 2017. DOI: +[10.7937/K9/TCIA.2017.GJQ7R0EF diff --git a/models/maisi_ct_generative/large_files.yml b/models/maisi_ct_generative/large_files.yml new file mode 100644 index 00000000..9111cee2 --- /dev/null +++ b/models/maisi_ct_generative/large_files.yml @@ -0,0 +1,23 @@ +large_files: + - path: "models/autoencoder_epoch273.pt" + url: "https://drive.google.com/file/d/1jQefG0yJPzSvTG5rIJVHNqDReBTvVmZ0/view?usp=drive_link" + - path: "models/input_unet3d_data-all_steps1000size512ddpm_random_current_inputx_v1.pt" + url: "https://drive.google.com/file/d/1FtOHBGUF5dLZNHtiuhf5EH448EQGGs-_/view?usp=sharing" + - path: "models/controlnet-20datasets-e20wl100fold0bc_noi_dia_fsize_current.pt" + url: "https://drive.google.com/file/d/1izr52Whkk56OevNTk2QzI86eJV9TTaLk/view?usp=sharing" + - path: "models/mask_generation_autoencoder.pt" + url: "https://drive.google.com/file/d/1FzWrpv6ornYUaPiAWGOOxhRx2P9Wnynm/view?usp=drive_link" + - path: "models/mask_generation_diffusion_unet.pt" + url: "https://drive.google.com/file/d/11SA9RUZ6XmCOJr5v6w6UW1kDzr6hlymw/view?usp=drive_link" + - path: "configs/candidate_masks_flexible_size_and_spacing_3000.json" + url: "https://drive.google.com/file/d/1yMkH-lrAsn2YUGoTuVKNMpicziUmU-1J/view?usp=sharing" + - path: "configs/all_anatomy_size_condtions.json" + url: "https://drive.google.com/file/d/1AJyt1DSoUd2x2AOQOgM7IxeSyo4MXNX0/view?usp=sharing" + - path: "datasets/all_masks_flexible_size_and_spacing_3000.zip" + url: "https://drive.google.com/file/d/16MKsDKkHvDyF2lEir4dzlxwex_GHStUf/view?usp=sharing" + - path: "datasets/IntegrationTest-AbdomenCT.nii.gz" + url: "https://drive.google.com/file/d/1OTgt_dyBgvP52krKRXWXD3u0L5Zbj5JR/view?usp=share_link" + - path: "datasets/C4KC-KiTS_subset.zip" + url: "https://drive.google.com/file/d/1r62pLTowfrHhKW5YPl5pWygIDZSOI-VT/view?usp=sharing" + - path: "datasets/C4KC-KiTS_subset.json" + url: "https://drive.google.com/file/d/1tzpglihyZwlJcuEYJQeuB4zW8UrXyNO3/view?usp=sharing" diff --git a/models/maisi_ct_generative/scripts/__init__.py b/models/maisi_ct_generative/scripts/__init__.py new file mode 100644 index 00000000..41d37723 --- /dev/null +++ b/models/maisi_ct_generative/scripts/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import custom_network_diffusion, custom_network_tp, sample, utils diff --git a/models/maisi_ct_generative/scripts/augmentation.py b/models/maisi_ct_generative/scripts/augmentation.py new file mode 100644 index 00000000..6317781f --- /dev/null +++ b/models/maisi_ct_generative/scripts/augmentation.py @@ -0,0 +1,366 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +import torch.nn.functional as F +from monai.transforms import Rand3DElastic, RandAffine, RandZoom +from monai.utils import ensure_tuple_rep + + +def erode3d(input_tensor, erosion=3): + # Define the structuring element + erosion = ensure_tuple_rep(erosion, 3) + structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device) + + # Pad the input tensor to handle border pixels + input_padded = F.pad( + input_tensor.float().unsqueeze(0).unsqueeze(0), + (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2), + mode="constant", + value=1.0, + ) + + # Apply erosion operation + output = F.conv3d(input_padded, structuring_element, padding=0) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + + +def dilate3d(input_tensor, erosion=3): + # Define the structuring element + erosion = ensure_tuple_rep(erosion, 3) + structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device) + + # Pad the input tensor to handle border pixels + input_padded = F.pad( + input_tensor.float().unsqueeze(0).unsqueeze(0), + (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2), + mode="constant", + value=1.0, + ) + + # Apply erosion operation + output = F.conv3d(input_padded, structuring_element, padding=0) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output > 0, 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + + +def augmentation_tumor_bone(pt_nda, output_size): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 128] = 1 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = RandAffine( + mode="nearest", + prob=1.0, + translate_range=(5, 5, 0), + rotate_range=(0, 0, 0.1), + scale_range=(0.15, 0.15, 0), + padding_mode="zeros", + ) + + tumor_szie = torch.sum((real_l_volume_ > 0).float()) + ########################### + # remove pred in pseudo_label in real lesion region + 
volume[real_l_volume_ > 0] = 200 + ########################### + if tumor_szie > 0: + # get organ mask + organ_mask = ( + torch.logical_and(33 <= volume, volume <= 56).float() + + torch.logical_and(63 <= volume, volume <= 97).float() + + (volume == 127).float() + + (volume == 114).float() + + real_l_volume_ + ) + organ_mask = (organ_mask > 0).float() + cnt = 0 + while True: + threshold = 0.8 if cnt < 40 else 0.75 + real_l_volume = real_l_volume_ + # random distor mask + distored_mask = elastic((real_l_volume > 0).cuda(), spatial_size=tuple(output_size)).as_tensor() + real_l_volume = distored_mask * organ_mask + cnt += 1 + print(torch.sum(real_l_volume), "|", tumor_szie * threshold) + if torch.sum(real_l_volume) >= tumor_szie * threshold: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8) + break + else: + real_l_volume = real_l_volume_ + + volume[real_l_volume == 1] = 128 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_liver(pt_nda, output_size): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 1] = 1 + real_l_volume_[volume == 26] = 2 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(10, 10, 10), + rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36), + scale_range=(0.2, 0.2, 0.2), + padding_mode="zeros", + ) + + tumor_szie = torch.sum(real_l_volume_ == 2) + ########################### + # remove pred organ labels + volume[volume == 1] = 0 + volume[volume == 26] = 0 + # before move tumor maks, full the original location by organ labels + volume[real_l_volume_ == 1] = 1 + volume[real_l_volume_ == 2] = 1 + ########################### + while True: + real_l_volume = real_l_volume_ + # random distor mask + real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor() + # get organ mask + organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float() + + organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5) + organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0) + real_l_volume = real_l_volume * organ_mask + print(torch.sum(real_l_volume), "|", tumor_szie * 0.80) + if torch.sum(real_l_volume) >= tumor_szie * 0.80: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0) + break + + volume[real_l_volume == 1] = 26 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_lung(pt_nda, output_size): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 23] = 1 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(20, 20, 20), + rotate_range=(np.pi / 36, np.pi / 36, np.pi), + scale_range=(0.15, 0.15, 0.15), + padding_mode="zeros", + ) + + tumor_szie = torch.sum(real_l_volume_) + # before move lung tumor maks, full the original location by lung labels + new_real_l_volume_ = dilate3d(real_l_volume_.squeeze(0), erosion=3) + new_real_l_volume_ = new_real_l_volume_.unsqueeze(0) + new_real_l_volume_[real_l_volume_ > 0] = 0 + new_real_l_volume_[volume < 28] = 0 + new_real_l_volume_[volume > 32] = 0 + tmp = volume[(volume * new_real_l_volume_).nonzero(as_tuple=True)].view(-1) + + mode = 
torch.mode(tmp, 0)[0].item() + print(mode) + assert 28 <= mode <= 32 + volume[real_l_volume_.bool()] = mode + ########################### + if tumor_szie > 0: + # aug + while True: + real_l_volume = real_l_volume_ + # random distor mask + real_l_volume = elastic(real_l_volume, spatial_size=tuple(output_size)).as_tensor() + # get lung mask v2 (133 order) + lung_mask = ( + (volume == 28).float() + + (volume == 29).float() + + (volume == 30).float() + + (volume == 31).float() + + (volume == 32).float() + ) + + lung_mask = dilate3d(lung_mask.squeeze(0), erosion=5) + lung_mask = erode3d(lung_mask, erosion=5).unsqueeze(0) + real_l_volume = real_l_volume * lung_mask + print(torch.sum(real_l_volume), "|", tumor_szie * 0.85) + if torch.sum(real_l_volume) >= tumor_szie * 0.85: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8) + break + else: + real_l_volume = real_l_volume_ + + volume[real_l_volume == 1] = 23 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_pancreas(pt_nda, output_size): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 4] = 1 + real_l_volume_[volume == 24] = 2 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(15, 15, 15), + rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36), + scale_range=(0.1, 0.1, 0.1), + padding_mode="zeros", + ) + + tumor_szie = torch.sum(real_l_volume_ == 2) + ########################### + # remove pred organ labels + volume[volume == 24] = 0 + volume[volume == 4] = 0 + # before move tumor maks, full the original location by organ labels + volume[real_l_volume_ == 1] = 4 + volume[real_l_volume_ == 2] = 4 + ########################### + while True: + real_l_volume = real_l_volume_ + # random distor mask + real_l_volume = elastic((real_l_volume == 2).cuda(), spatial_size=tuple(output_size)).as_tensor() + # get organ mask + organ_mask = (real_l_volume_ == 1).float() + (real_l_volume_ == 2).float() + + organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5) + organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0) + real_l_volume = real_l_volume * organ_mask + print(torch.sum(real_l_volume), "|", tumor_szie * 0.80) + if torch.sum(real_l_volume) >= tumor_szie * 0.80: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0) + break + + volume[real_l_volume == 1] = 24 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_tumor_colon(pt_nda, output_size): + volume = pt_nda.squeeze(0) + real_l_volume_ = torch.zeros_like(volume) + real_l_volume_[volume == 27] = 1 + real_l_volume_ = real_l_volume_.to(torch.uint8) + + elastic = Rand3DElastic( + mode="nearest", + prob=1.0, + sigma_range=(5, 8), + magnitude_range=(100, 200), + translate_range=(5, 5, 5), + rotate_range=(np.pi / 36, np.pi / 36, np.pi / 36), + scale_range=(0.1, 0.1, 0.1), + padding_mode="zeros", + ) + + tumor_szie = torch.sum(real_l_volume_) + ########################### + # before move tumor maks, full the original location by organ labels + volume[real_l_volume_.bool()] = 62 + ########################### + if tumor_szie > 0: + # get organ mask + organ_mask = (volume == 62).float() + organ_mask = dilate3d(organ_mask.squeeze(0), erosion=5) + organ_mask = erode3d(organ_mask, erosion=5).unsqueeze(0) + # cnt = 0 + cnt = 0 + 
while True: + threshold = 0.8 + real_l_volume = real_l_volume_ + if cnt < 20: + # random distor mask + distored_mask = elastic((real_l_volume == 1).cuda(), spatial_size=tuple(output_size)).as_tensor() + real_l_volume = distored_mask * organ_mask + elif 20 <= cnt < 40: + threshold = 0.75 + else: + break + + real_l_volume = real_l_volume * organ_mask + print(torch.sum(real_l_volume), "|", tumor_szie * threshold) + cnt += 1 + if torch.sum(real_l_volume) >= tumor_szie * threshold: + real_l_volume = dilate3d(real_l_volume.squeeze(0), erosion=5) + real_l_volume = erode3d(real_l_volume, erosion=5).unsqueeze(0).to(torch.uint8) + break + else: + real_l_volume = real_l_volume_ + # break + volume[real_l_volume == 1] = 27 + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation_body(pt_nda): + volume = pt_nda.squeeze(0) + + zoom = RandZoom(min_zoom=0.99, max_zoom=1.01, mode="nearest", align_corners=None, prob=1.0) + volume = zoom(volume) + + pt_nda = volume.unsqueeze(0) + return pt_nda + + +def augmentation(pt_nda, output_size): + label_list = torch.unique(pt_nda) + label_list = list(label_list.cpu().numpy()) + + if 128 in label_list: + print("augmenting bone lesion/tumor") + pt_nda = augmentation_tumor_bone(pt_nda, output_size) + elif 26 in label_list: + print("augmenting liver tumor") + pt_nda = augmentation_tumor_liver(pt_nda, output_size) + elif 23 in label_list: + print("augmenting lung tumor") + pt_nda = augmentation_tumor_lung(pt_nda, output_size) + elif 24 in label_list: + print("augmenting pancreas tumor") + pt_nda = augmentation_tumor_pancreas(pt_nda, output_size) + elif 27 in label_list: + print("augmenting colon tumor") + pt_nda = augmentation_tumor_colon(pt_nda, output_size) + else: + print("augmenting body") + pt_nda = augmentation_body(pt_nda) + + return pt_nda diff --git a/models/maisi_ct_generative/scripts/custom_network_controlnet.py b/models/maisi_ct_generative/scripts/custom_network_controlnet.py new file mode 100644 index 00000000..ad36c5b6 --- /dev/null +++ b/models/maisi_ct_generative/scripts/custom_network_controlnet.py @@ -0,0 +1,177 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ========================================================================= + +from typing import Sequence + +import torch +from generative.networks.nets.controlnet import ControlNet +from generative.networks.nets.diffusion_model_unet import get_timestep_embedding + + +class CustomControlNet(ControlNet): + """ + Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image + Diffusion Models" (https://arxiv.org/abs/2302.05543) + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + conditioning_embedding_in_channels: number of input channels for the conditioning embedding. + conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + conditioning_embedding_in_channels: int = 1, + conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256), + ) -> None: + super().__init__( + spatial_dims, + in_channels, + num_res_blocks, + num_channels, + attention_levels, + norm_num_groups, + norm_eps, + resblock_updown, + num_head_channels, + with_conditioning, + transformer_num_layers, + cross_attention_dim, + num_class_embeds, + upcast_attention, + use_flash_attention, + conditioning_embedding_in_channels, + conditioning_embedding_num_channels, + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + ) -> tuple[list[torch.Tensor], torch.Tensor]: + """ + Args: + x: input tensor (N, C, SpatialDims). + timesteps: timestep tensor (N,). + controlnet_cond: controlnet conditioning tensor (N, C, SpatialDims). + conditioning_scale: conditioning scale. + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + """ + # 1. 
time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. initial convolution + h = self.conv_in(x) + + # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + controlnet_cond = torch.utils.checkpoint.checkpoint( + self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False + ) + + h += controlnet_cond + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # 6. Control net blocks + controlnet_down_block_res_samples = () + + for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): + down_block_res_sample = controlnet_block(down_block_res_sample) + controlnet_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = controlnet_down_block_res_samples + + mid_block_res_sample = self.controlnet_mid_block(h) + + # 6. scaling + down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples] + mid_block_res_sample *= conditioning_scale + + return down_block_res_samples, mid_block_res_sample diff --git a/models/maisi_ct_generative/scripts/custom_network_diffusion.py b/models/maisi_ct_generative/scripts/custom_network_diffusion.py new file mode 100644 index 00000000..9e4cbf60 --- /dev/null +++ b/models/maisi_ct_generative/scripts/custom_network_diffusion.py @@ -0,0 +1,1993 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ========================================================================= +# Adapted from https://github.com/huggingface/diffusers +# which has the following license: +# https://github.com/huggingface/diffusers/blob/main/LICENSE +# +# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========================================================================= + +from __future__ import annotations + +import importlib.util +import math +from collections.abc import Sequence + +import torch +import torch.nn.functional as F +from monai.networks.blocks import Convolution, MLPBlock +from monai.networks.layers.factories import Pool +from monai.utils import ensure_tuple_rep +from torch import nn + +if importlib.util.find_spec("xformers") is not None: + import xformers + import xformers.ops + + has_xformers = True +else: + xformers = None + has_xformers = False + + +# TODO: Use MONAI's optional_import +# from monai.utils import optional_import +# xformers, has_xformers = optional_import("xformers.ops", name="xformers") + +__all__ = ["CustomDiffusionModelUNet"] + + +def zero_module(module: nn.Module) -> nn.Module: + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class CrossAttention(nn.Module): + """ + A cross attention layer. + + Args: + query_dim: number of channels in the query. + cross_attention_dim: number of channels in the context. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each head. + dropout: dropout probability to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: int | None = None, + num_attention_heads: int = 8, + num_head_channels: int = 64, + dropout: float = 0.0, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + inner_dim = num_head_channels * num_attention_heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + + self.scale = 1 / math.sqrt(num_head_channels) + self.num_heads = num_attention_heads + + self.upcast_attention = upcast_attention + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + """ + Divide hidden state dimension to the multiple attention heads and reshape their input as instances in the batch. 
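+
+        A minimal shape sketch (the sizes below are arbitrary)::
+
+            import torch
+
+            attn = CrossAttention(query_dim=6, num_attention_heads=2, num_head_channels=3)
+            x = torch.zeros(1, 4, 6)                # (batch, seq_len, dim)
+            y = attn.reshape_heads_to_batch_dim(x)  # (batch * num_heads, seq_len, dim // num_heads)
+            assert y.shape == (2, 4, 3)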
+ """ + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + """Combine the output of the attention heads back into the hidden state dimension.""" + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(dtype=dtype) + + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + query = self.to_q(x) + context = context if context is not None else x + key = self.to_k(context) + value = self.to_v(context) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + return self.to_out(x) + + +class BasicTransformerBlock(nn.Module): + """ + A basic Transformer block. + + Args: + num_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + dropout: dropout probability to use. + cross_attention_dim: size of the context vector for cross attention. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
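+
+    Example (illustrative; the sizes below are arbitrary)::
+
+        import torch
+
+        block = BasicTransformerBlock(num_channels=64, num_attention_heads=8, num_head_channels=8)
+        x = torch.randn(2, 100, 64)  # (batch, tokens, channels)
+        y = block(x)                 # no context given, so both attention layers act as self-attention
+        assert y.shape == x.shape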
+ """ + + def __init__( + self, + num_channels: int, + num_attention_heads: int, + num_head_channels: int, + dropout: float = 0.0, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attn1 = CrossAttention( + query_dim=num_channels, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention + self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout) + self.attn2 = CrossAttention( + query_dim=num_channels, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) # is a self-attention if context is None + self.norm1 = nn.LayerNorm(num_channels) + self.norm2 = nn.LayerNorm(num_channels) + self.norm3 = nn.LayerNorm(num_channels) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # 1. Self-Attention + x = self.attn1(self.norm1(x)) + x + + # 2. Cross-Attention + x = self.attn2(self.norm2(x), context=context) + x + + # 3. Feed-forward + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply + standard transformer action. Finally, reshape to image. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of channels in the input and output. + num_attention_heads: number of heads to use for multi-head attention. + num_head_channels: number of channels in each attention head. + num_layers: number of layers of Transformer blocks to use. + dropout: dropout probability to use. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
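+
+    Example (illustrative; the sizes below are arbitrary)::
+
+        import torch
+
+        st = SpatialTransformer(
+            spatial_dims=3, in_channels=32, num_attention_heads=4, num_head_channels=8, norm_num_groups=8
+        )
+        x = torch.randn(1, 32, 8, 8, 8)
+        y = st(x)  # projected to tokens, passed through the transformer blocks, reshaped back with a residual
+        assert y.shape == x.shape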
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_attention_heads: int, + num_head_channels: int, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + inner_dim = num_attention_heads * num_head_channels + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + + self.proj_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=inner_dim, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + num_channels=inner_dim, + num_attention_heads=num_attention_heads, + num_head_channels=num_head_channels, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + ) + for _ in range(num_layers) + ] + ) + + self.proj_out = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=inner_dim, + out_channels=in_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + ) + + def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor: + # note: if no context is given, cross-attention defaults to self-attention + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + residual = x + x = self.norm(x) + x = self.proj_in(x) + + inner_dim = x.shape[1] + + if self.spatial_dims == 2: + x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + if self.spatial_dims == 3: + x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim) + + for block in self.transformer_blocks: + x = block(x, context=context) + + if self.spatial_dims == 2: + x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + if self.spatial_dims == 3: + x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous() + + x = self.proj_out(x) + return x + residual + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to + compute attention. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + num_head_channels: number of channels in each attention head. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon value to use for the normalisation. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
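+
+    Example (illustrative; the sizes below are arbitrary)::
+
+        import torch
+
+        attn = AttentionBlock(spatial_dims=3, num_channels=64, num_head_channels=8)
+        x = torch.randn(1, 64, 4, 4, 4)
+        y = attn(x)  # spatial self-attention plus a residual connection
+        assert y.shape == x.shape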
+ """ + + def __init__( + self, + spatial_dims: int, + num_channels: int, + num_head_channels: int | None = None, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.use_flash_attention = use_flash_attention + self.spatial_dims = spatial_dims + self.num_channels = num_channels + + self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1 + self.scale = 1 / math.sqrt(num_channels / self.num_heads) + + self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True) + + self.to_q = nn.Linear(num_channels, num_channels) + self.to_k = nn.Linear(num_channels, num_channels) + self.to_v = nn.Linear(num_channels, num_channels) + + self.proj_attn = nn.Linear(num_channels, num_channels) + + def reshape_heads_to_batch_dim(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size, seq_len, self.num_heads, dim // self.num_heads) + x = x.permute(0, 2, 1, 3).reshape(batch_size * self.num_heads, seq_len, dim // self.num_heads) + return x + + def reshape_batch_dim_to_heads(self, x: torch.Tensor) -> torch.Tensor: + batch_size, seq_len, dim = x.shape + x = x.reshape(batch_size // self.num_heads, self.num_heads, seq_len, dim) + x = x.permute(0, 2, 1, 3).reshape(batch_size // self.num_heads, seq_len, dim * self.num_heads) + return x + + def _memory_efficient_attention_xformers( + self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + x = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None) + return x + + def _attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + x = torch.bmm(attention_probs, value) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + + batch = channel = height = width = depth = -1 + if self.spatial_dims == 2: + batch, channel, height, width = x.shape + if self.spatial_dims == 3: + batch, channel, height, width, depth = x.shape + + # norm + x = self.norm(x) + + if self.spatial_dims == 2: + x = x.view(batch, channel, height * width).transpose(1, 2) + if self.spatial_dims == 3: + x = x.view(batch, channel, height * width * depth).transpose(1, 2) + + # proj to q, k, v + query = self.to_q(x) + key = self.to_k(x) + value = self.to_v(x) + + # Multi-Head Attention + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if self.use_flash_attention: + x = self._memory_efficient_attention_xformers(query, key, value) + else: + x = self._attention(query, key, value) + + x = self.reshape_batch_dim_to_heads(x) + x = x.to(query.dtype) + + if self.spatial_dims == 2: + x = x.transpose(-1, -2).reshape(batch, channel, height, width) + if self.spatial_dims == 3: + x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth) + + return x + residual + + +def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor: + """ + Create sinusoidal timestep embeddings following the implementation in Ho et al. 
"Denoising Diffusion Probabilistic + Models" https://arxiv.org/abs/2006.11239. + + Args: + timesteps: a 1-D Tensor of N indices, one per batch element. + embedding_dim: the dimension of the output. + max_period: controls the minimum frequency of the embeddings. + """ + # print(f'max_period: {max_period}; timesteps: {torch.norm(timesteps.float(), p=2)}; embedding_dim: {embedding_dim}') + + if timesteps.ndim != 1: + raise ValueError("Timesteps should be a 1d-array") + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + freqs = torch.exp(exponent / half_dim) + + args = timesteps[:, None].float() * freqs[None, :] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0)) + + return embedding + + +class Downsample(nn.Module): + """ + Downsampling layer. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is + False, the number of output channels must be the same as the number of input channels. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points + for each dimension. + """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.op = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=2, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + if self.num_channels != self.out_channels: + raise ValueError("num_channels and out_channels must be equal when use_conv=False") + self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2) + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError( + f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels " + f"({self.num_channels})" + ) + return self.op(x) + + +class Upsample(nn.Module): + """ + Upsampling layer with an optional convolution. + + Args: + spatial_dims: number of spatial dimensions. + num_channels: number of input channels. + use_conv: if True uses Convolution instead of Pool average to perform downsampling. + out_channels: number of output channels. + padding: controls the amount of implicit zero-paddings on both sides for padding number of points for each + dimension. 
+ """ + + def __init__( + self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1 + ) -> None: + super().__init__() + self.num_channels = num_channels + self.out_channels = out_channels or num_channels + self.use_conv = use_conv + if use_conv: + self.conv = Convolution( + spatial_dims=spatial_dims, + in_channels=self.num_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=padding, + conv_only=True, + ) + else: + self.conv = None + + def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor: + del emb + if x.shape[1] != self.num_channels: + raise ValueError("Input channels should be equal to num_channels") + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + dtype = x.dtype + if dtype == torch.bfloat16: + x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + x = x.to(dtype) + + if self.use_conv: + x = self.conv(x) + return x + + +class ResnetBlock(nn.Module): + """ + Residual block with timestep conditioning. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + out_channels: number of output channels. + up: if True, performs upsampling. + down: if True, performs downsampling. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + out_channels: int | None = None, + up: bool = False, + down: bool = False, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.channels = in_channels + self.emb_channels = temb_channels + self.out_channels = out_channels or in_channels + self.up = up + self.down = down + + self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.nonlinearity = nn.SiLU() + self.conv1 = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + self.upsample = self.downsample = None + if self.up: + self.upsample = Upsample(spatial_dims, in_channels, use_conv=False) + elif down: + self.downsample = Downsample(spatial_dims, in_channels, use_conv=False) + + self.time_emb_proj = nn.Linear(temb_channels, self.out_channels) + + self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True) + self.conv2 = zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + if self.out_channels == in_channels: + self.skip_connection = nn.Identity() + else: + self.skip_connection = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: + h = x + h = self.norm1(h) + h = self.nonlinearity(h) + + if self.upsample is not None: + if h.shape[0] >= 64: + x = x.contiguous() + h = h.contiguous() + x = self.upsample(x) + h = 
self.upsample(h) + elif self.downsample is not None: + x = self.downsample(x) + h = self.downsample(h) + + h = self.conv1(h) + + if self.spatial_dims == 2: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None] + else: + temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None] + h = h + temb + + h = self.norm2(h) + h = self.nonlinearity(h) + h = self.conv2(h) + + return self.skip_connection(x) + h + + +class DownBlock(nn.Module): + """ + Unet's down block containing resnet and downsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. 
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + del context + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class CrossAttnDownBlock(nn.Module): + """ + Unet's down block containing resnet, downsamplers and cross-attention blocks. + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_downsample: if True add downsample block. + resblock_updown: if True use residual blocks for downsampling. + downsample_padding: padding used in the downsampling block. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_downsample: bool = True, + resblock_updown: bool = False, + downsample_padding: int = 1, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + if resblock_updown: + self.downsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + down=True, + ) + else: + self.downsampler = Downsample( + spatial_dims=spatial_dims, + num_channels=out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsampler = None + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + output_states = [] + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + output_states.append(hidden_states) + + if self.downsampler is not None: + hidden_states = self.downsampler(hidden_states, temb) + output_states.append(hidden_states) + + return hidden_states, output_states + + +class AttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
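+
+    Example (illustrative; the sizes below are arbitrary)::
+
+        import torch
+
+        mid = AttnMidBlock(spatial_dims=3, in_channels=64, temb_channels=128, num_head_channels=8)
+        h = torch.randn(1, 64, 4, 4, 4)
+        temb = torch.randn(1, 128)
+        out = mid(h, temb)  # resnet -> self-attention -> resnet; the shape is preserved
+        assert out.shape == h.shape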
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = AttentionBlock( + spatial_dims=spatial_dims, + num_channels=in_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + + self.resnet_2 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + del context + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class CrossAttnMidBlock(nn.Module): + """ + Unet's mid block containing resnet and cross-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + temb_channels: number of timestep embedding channels + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
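+
+    Example (illustrative; the sizes below are arbitrary)::
+
+        import torch
+
+        mid = CrossAttnMidBlock(
+            spatial_dims=3, in_channels=64, temb_channels=128, num_head_channels=8, cross_attention_dim=16
+        )
+        h = torch.randn(1, 64, 4, 4, 4)
+        temb = torch.randn(1, 128)
+        context = torch.randn(1, 1, 16)  # (batch, sequence length, cross_attention_dim)
+        out = mid(h, temb, context=context)
+        assert out.shape == h.shape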
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.attention = None + + self.resnet_1 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + self.attention = SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_attention_heads=in_channels // num_head_channels, + num_head_channels=num_head_channels, + num_layers=transformer_num_layers, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + self.resnet_2 = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + + def forward( + self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None + ) -> torch.Tensor: + hidden_states = self.resnet_1(hidden_states, temb) + hidden_states = self.attention(hidden_states, context=context) + hidden_states = self.resnet_2(hidden_states, temb) + + return hidden_states + + +class UpBlock(nn.Module): + """ + Unet's up block containing resnet and upsamplers blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. 
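+
+    Example (illustrative; the sizes below are arbitrary)::
+
+        import torch
+
+        up = UpBlock(spatial_dims=3, in_channels=32, prev_output_channel=64, out_channels=64, temb_channels=128)
+        h = torch.randn(1, 64, 4, 4, 4)        # current hidden states
+        skips = [torch.randn(1, 32, 4, 4, 4)]  # one skip tensor per residual block (num_res_blocks=1 here)
+        temb = torch.randn(1, 128)
+        out = up(h, skips, temb)               # concatenate the skip, apply the resnet, then upsample by 2
+        assert out.shape == (1, 64, 8, 8, 8)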
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + resnets = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class AttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=out_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + del context + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +class CrossAttnUpBlock(nn.Module): + """ + Unet's up block containing resnet, upsamplers, and self-attention blocks. + + Args: + spatial_dims: The number of spatial dimensions. + in_channels: number of input channels. + prev_output_channel: number of channels from residual connection. + out_channels: number of output channels. + temb_channels: number of timestep embedding channels. + num_res_blocks: number of residual blocks. + norm_num_groups: number of groups for the group normalization. + norm_eps: epsilon for the group normalization. + add_upsample: if True add downsample block. + resblock_updown: if True use residual blocks for upsampling. + num_head_channels: number of channels in each attention head. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. 
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int = 1, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + add_upsample: bool = True, + resblock_updown: bool = False, + num_head_channels: int = 1, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + ) -> None: + super().__init__() + self.resblock_updown = resblock_updown + + resnets = [] + attentions = [] + + for i in range(num_res_blocks): + res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock( + spatial_dims=spatial_dims, + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + ) + ) + attentions.append( + SpatialTransformer( + spatial_dims=spatial_dims, + in_channels=out_channels, + num_attention_heads=out_channels // num_head_channels, + num_head_channels=num_head_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout=dropout_cattn, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + if resblock_updown: + self.upsampler = ResnetBlock( + spatial_dims=spatial_dims, + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + up=True, + ) + else: + self.upsampler = Upsample( + spatial_dims=spatial_dims, num_channels=out_channels, use_conv=True, out_channels=out_channels + ) + else: + self.upsampler = None + + def forward( + self, + hidden_states: torch.Tensor, + res_hidden_states_list: list[torch.Tensor], + temb: torch.Tensor, + context: torch.Tensor | None = None, + ) -> torch.Tensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_list[-1] + res_hidden_states_list = res_hidden_states_list[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=context) + + if self.upsampler is not None: + hidden_states = self.upsampler(hidden_states, temb) + + return hidden_states + + +def get_down_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_downsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + 
resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnDownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return DownBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=add_downsample, + resblock_updown=resblock_updown, + ) + + +def get_mid_block( + spatial_dims: int, + in_channels: int, + temb_channels: int, + norm_num_groups: int, + norm_eps: float, + with_conditioning: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_conditioning: + return CrossAttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return AttnMidBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + temb_channels=temb_channels, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + + +def get_up_block( + spatial_dims: int, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + num_res_blocks: int, + norm_num_groups: int, + norm_eps: float, + add_upsample: bool, + resblock_updown: bool, + with_attn: bool, + with_cross_attn: bool, + num_head_channels: int, + transformer_num_layers: int, + cross_attention_dim: int | None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, +) -> nn.Module: + if with_attn: + return AttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + use_flash_attention=use_flash_attention, + ) + elif with_cross_attn: + return CrossAttnUpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + num_head_channels=num_head_channels, + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + 
use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + else: + return UpBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + prev_output_channel=prev_output_channel, + out_channels=out_channels, + temb_channels=temb_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=add_upsample, + resblock_updown=resblock_updown, + ) + + +class CustomDiffusionModelUNet(nn.Module): + """ + Unet network with timestep embedding and attention mechanisms for conditioning based on + Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752 + and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162 + + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResnetBlock) per level. + num_channels: tuple of block output channels. + attention_levels: list of levels to add attention. + norm_num_groups: number of groups for the normalization. + norm_eps: epsilon for the normalization. + resblock_updown: if True use residual blocks for up/downsampling. + num_head_channels: number of channels in each attention head. + with_conditioning: if True add spatial transformers to perform conditioning. + transformer_num_layers: number of layers of Transformer blocks to use. + cross_attention_dim: number of context dimensions to use. + num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` + classes. + upcast_attention: if True, upcast attention operations to full precision. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int] | int = (2, 2, 2, 2), + num_channels: Sequence[int] = (32, 64, 64, 64), + attention_levels: Sequence[bool] = (False, False, True, True), + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + resblock_updown: bool = False, + num_head_channels: int | Sequence[int] = 8, + with_conditioning: bool = False, + transformer_num_layers: int = 1, + cross_attention_dim: int | None = None, + num_class_embeds: int | None = None, + upcast_attention: bool = False, + use_flash_attention: bool = False, + dropout_cattn: float = 0.0, + input_top_region_index: bool = False, + input_bottom_region_index: bool = False, + input_spacing: bool = False, + ) -> None: + super().__init__() + if with_conditioning is True and cross_attention_dim is None: + raise ValueError( + "CustomDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) " + "when using with_conditioning." + ) + if cross_attention_dim is not None and with_conditioning is False: + raise ValueError( + "CustomDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim." 
+ ) + if dropout_cattn > 1.0 or dropout_cattn < 0.0: + raise ValueError("Dropout cannot be negative or >1.0!") + + # All number of channels should be multiple of num_groups + if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels): + raise ValueError("CustomDiffusionModelUNet expects all num_channels being multiple of norm_num_groups") + + if len(num_channels) != len(attention_levels): + raise ValueError("CustomDiffusionModelUNet expects num_channels being same size of attention_levels") + + if isinstance(num_head_channels, int): + num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels)) + + if len(num_head_channels) != len(attention_levels): + raise ValueError( + "num_head_channels should have the same length as attention_levels. For the i levels without attention," + " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored." + ) + + if isinstance(num_res_blocks, int): + num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels)) + + if len(num_res_blocks) != len(num_channels): + raise ValueError( + "`num_res_blocks` should be a single integer or a tuple of integers with the same length as " + "`num_channels`." + ) + + if use_flash_attention and not has_xformers: + raise ValueError("use_flash_attention is True but xformers is not installed.") + + if use_flash_attention is True and not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU." + ) + + self.in_channels = in_channels + self.block_out_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_levels = attention_levels + self.num_head_channels = num_head_channels + self.with_conditioning = with_conditioning + + # input + self.conv_in = Convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + # time + time_embed_dim = num_channels[0] * 4 + self.time_embed = nn.Sequential( + nn.Linear(num_channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + + # class embedding + self.num_class_embeds = num_class_embeds + if num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + + self.input_top_region_index = input_top_region_index + self.input_bottom_region_index = input_bottom_region_index + self.input_spacing = input_spacing + + new_time_embed_dim = time_embed_dim + if self.input_top_region_index: + # self.top_region_index_layer = nn.Linear(4, time_embed_dim) + self.top_region_index_layer = nn.Sequential( + nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + new_time_embed_dim += time_embed_dim + if self.input_bottom_region_index: + # self.bottom_region_index_layer = nn.Linear(4, time_embed_dim) + self.bottom_region_index_layer = nn.Sequential( + nn.Linear(4, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + new_time_embed_dim += time_embed_dim + if self.input_spacing: + # self.spacing_layer = nn.Linear(3, time_embed_dim) + self.spacing_layer = nn.Sequential( + nn.Linear(3, time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim) + ) + new_time_embed_dim += time_embed_dim + + # down + self.down_blocks = nn.ModuleList([]) + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + 
is_final_block = i == len(num_channels) - 1 + + down_block = get_down_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=new_time_embed_dim, + num_res_blocks=num_res_blocks[i], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_downsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(attention_levels[i] and not with_conditioning), + with_cross_attn=(attention_levels[i] and with_conditioning), + num_head_channels=num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.down_blocks.append(down_block) + + # mid + self.middle_block = get_mid_block( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + temb_channels=new_time_embed_dim, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + with_conditioning=with_conditioning, + num_head_channels=num_head_channels[-1], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + # up + self.up_blocks = nn.ModuleList([]) + reversed_block_out_channels = list(reversed(num_channels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_head_channels = list(reversed(num_head_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)] + + is_final_block = i == len(num_channels) - 1 + + up_block = get_up_block( + spatial_dims=spatial_dims, + in_channels=input_channel, + prev_output_channel=prev_output_channel, + out_channels=output_channel, + temb_channels=new_time_embed_dim, + num_res_blocks=reversed_num_res_blocks[i] + 1, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + add_upsample=not is_final_block, + resblock_updown=resblock_updown, + with_attn=(reversed_attention_levels[i] and not with_conditioning), + with_cross_attn=(reversed_attention_levels[i] and with_conditioning), + num_head_channels=reversed_num_head_channels[i], + transformer_num_layers=transformer_num_layers, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + use_flash_attention=use_flash_attention, + dropout_cattn=dropout_cattn, + ) + + self.up_blocks.append(up_block) + + # out + self.out = nn.Sequential( + nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True), + nn.SiLU(), + zero_module( + Convolution( + spatial_dims=spatial_dims, + in_channels=num_channels[0], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ), + ) + + def forward( + self, + x: torch.Tensor, + timesteps: torch.Tensor, + context: torch.Tensor | None = None, + class_labels: torch.Tensor | None = None, + down_block_additional_residuals: tuple[torch.Tensor] | None = None, + mid_block_additional_residual: torch.Tensor | None = None, + top_region_index_tensor: torch.Tensor | None = None, + bottom_region_index_tensor: torch.Tensor | None = None, + spacing_tensor: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x: input tensor (N, C, SpatialDims). 
+ timesteps: timestep tensor (N,). + context: context tensor (N, 1, ContextDim). + class_labels: context tensor (N, ). + down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims). + mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims). + """ + # 1. time + t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0]) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embed(t_emb) + # print(f't_emb: {t_emb}; timesteps {timesteps}.') + # print(f'emb: {torch.norm(emb, p=2)}; t_emb: {torch.norm(t_emb, p=2)}') + + # 2. class + if self.num_class_embeds is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + class_emb = self.class_embedding(class_labels) + class_emb = class_emb.to(dtype=x.dtype) + emb = emb + class_emb + + # 3. input + if self.input_top_region_index: + _emb = self.top_region_index_layer(top_region_index_tensor) + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; top_region_index_tensor: {torch.norm(_emb, p=2)}') + if self.input_bottom_region_index: + _emb = self.bottom_region_index_layer(bottom_region_index_tensor) + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; bottom_region_index_tensor: {torch.norm(_emb, p=2)}') + if self.input_spacing: + _emb = self.spacing_layer(spacing_tensor) + # emb = emb + _emb.to(dtype=x.dtype) + emb = torch.cat((emb, _emb), dim=1) + # print(f'emb: {emb.size()}, {torch.norm(emb, p=2)}; spacing_tensor: {torch.norm(spacing_tensor, p=2)}') + + # 3. initial convolution + h = self.conv_in(x) + + # 4. down + if context is not None and self.with_conditioning is False: + raise ValueError("model should have with_conditioning = True if context is provided") + down_block_res_samples: list[torch.Tensor] = [h] + for downsample_block in self.down_blocks: + h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context) + for residual in res_samples: + down_block_res_samples.append(residual) + + # Additional residual conections for Controlnets + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 5. mid + h = self.middle_block(hidden_states=h, temb=emb, context=context) + + # Additional residual conections for Controlnets + if mid_block_additional_residual is not None: + h = h + mid_block_additional_residual + + # 6. up + for upsample_block in self.up_blocks: + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context) + + # 7. 
output block + h = self.out(h) + + return h diff --git a/models/maisi_ct_generative/scripts/custom_network_tp.py b/models/maisi_ct_generative/scripts/custom_network_tp.py new file mode 100644 index 00000000..1fd33fe0 --- /dev/null +++ b/models/maisi_ct_generative/scripts/custom_network_tp.py @@ -0,0 +1,1053 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Sequence + +import monai +import torch +import torch.nn as nn +import torch.nn.functional as F +from generative.networks.nets.autoencoderkl import AttentionBlock, AutoencoderKL, ResBlock + +NUM_SPLITS = 16 +# NUM_SPLITS = 32 +SPLIT_PADDING = 3 + + +class InplaceGroupNorm3D(torch.nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super(InplaceGroupNorm3D, self).__init__(num_groups, num_channels, eps, affine) + + def forward(self, input): + # print("InplaceGroupNorm3D in", input.size()) + + # # normalization + # norm = 1e1 + # input /= norm + # # print("normalization2") + + # Ensure the tensor is 5D: (n, c, d, h, w) + if len(input.shape) != 5: + raise ValueError("Expected a 5D tensor") + + n, c, d, h, w = input.shape + + # Reshape to (n, num_groups, c // num_groups, d, h, w) + input = input.view(n, self.num_groups, c // self.num_groups, d, h, w) + + # input = input.to(dtype=torch.float64) + + # # Compute mean and std dev + # mean1 = input.mean([2, 3, 4, 5], keepdim=True) + # std1 = input.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() + # mean1 = mean1.to(dtype=torch.float32) + + if False: + input = input.to(dtype=torch.float64) + mean = input.mean([2, 3, 4, 5], keepdim=True) + # std = input.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() + + input = input.to(dtype=torch.float32) + mean = mean.to(dtype=torch.float32) + # std = mean.to(dtype=torch.float32) + else: + # means, stds = [], [] + inputs = [] + for _i in range(input.size(1)): + array = input[:, _i : _i + 1, ...] 
+ array = array.to(dtype=torch.float32) + _mean = array.mean([2, 3, 4, 5], keepdim=True) + _std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_() + + # del array + # torch.cuda.empty_cache() + + _mean = _mean.to(dtype=torch.float32) + _std = _std.to(dtype=torch.float32) + + # means.append(_mean) + # stds.append(_std) + + # mean = torch.cat([means[_k] for _k in range(len(means))], dim=1) + # std = torch.cat([stds[_k] for _k in range(len(stds))], dim=1) + # input = input.to(dtype=torch.float32) + + inputs.append(array.sub_(_mean).div_(_std).to(dtype=torch.float16)) + + # Normalize features (in-place) + # input.sub_(mean).div_(std) + + del input + torch.cuda.empty_cache() + + if False: + input = torch.cat([inputs[_k] for _k in range(len(inputs))], dim=1) + else: + if max(inputs[0].size()) < 500: + input = torch.cat([inputs[_k] for _k in range(len(inputs))], dim=1) + else: + import gc + + _type = inputs[0].device.type + if _type == "cuda": + input = inputs[0].clone().to("cpu", non_blocking=True) + else: + input = inputs[0].clone() + inputs[0] = 0 + torch.cuda.empty_cache() + + for _k in range(len(inputs) - 1): + input = torch.cat((input, inputs[_k + 1].cpu()), dim=1) + inputs[_k + 1] = 0 + torch.cuda.empty_cache() + gc.collect() + # print(f'InplaceGroupNorm3D cat: {_k + 1}/{len(inputs) - 1}.') + + if _type == "cuda": + input = input.to("cuda", non_blocking=True) + + # Reshape back to original size + input = input.view(n, c, d, h, w) + + # Apply affine transformation if enabled + if self.affine: + input.mul_(self.weight.view(1, c, 1, 1, 1)).add_(self.bias.view(1, c, 1, 1, 1)) + + # input = input.to(dtype=torch.float32) + # input *= norm + # print("InplaceGroupNorm3D out", input.size()) + + return input + + +class SplitConvolutionV1(nn.Module): + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + strides: Sequence[int] | int = 1, + kernel_size: Sequence[int] | int = 3, + adn_ordering: str = "NDA", + act: tuple | str | None = "PRELU", + norm: tuple | str | None = "INSTANCE", + dropout: tuple | str | float | None = None, + dropout_dim: int | None = 1, + dilation: Sequence[int] | int = 1, + groups: int = 1, + bias: bool = True, + conv_only: bool = False, + is_transposed: bool = False, + padding: Sequence[int] | int | None = None, + output_padding: Sequence[int] | int | None = None, + ) -> None: + super(SplitConvolutionV1, self).__init__() + self.conv = monai.networks.blocks.convolutions.Convolution( + spatial_dims, + in_channels, + out_channels, + strides, + kernel_size, + adn_ordering, + act, + norm, + dropout, + dropout_dim, + dilation, + groups, + bias, + conv_only, + is_transposed, + padding, + output_padding, + ) + + self.tp_dim = 1 + self.stride = strides[self.tp_dim] if isinstance(strides, list) else strides + + def forward(self, x): + # Call parent's forward method + # x = super(SplitConvolution, self).forward(x) + + num_splits = NUM_SPLITS + # print("num_splits:", num_splits) + l = x.size(self.tp_dim + 2) + split_size = l // num_splits + + if False: + splits = [x[:, :, i * split_size : (i + 1) * split_size, :, :] for i in range(num_splits)] + else: + # padding = 1 + padding = SPLIT_PADDING + if padding % self.stride > 0: + padding = (padding // self.stride + 1) * self.stride + # print("padding:", padding) + + overlaps = [0] + [padding] * (num_splits - 1) + last_padding = x.size(self.tp_dim + 2) % split_size + + if self.tp_dim == 0: + splits = [ + x[ + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding 
if i != num_splits - 1 else last_padding), + :, + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 1: + splits = [ + x[ + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else last_padding), + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 2: + splits = [ + x[ + :, + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else last_padding), + ] + for i in range(num_splits) + ] + + # for _j in range(len(splits)): + # print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) + + del x + torch.cuda.empty_cache() + + splits_0_size = list(splits[0].size()) + # print("splits_0_size:", splits_0_size) + + # outputs = [super(SplitConvolution, self).forward(splits[i]) for i in range(num_splits)] + if False: + outputs = [self.conv(splits[i]) for i in range(num_splits)] + else: + outputs = [] + _type = splits[0].device.type + for _i in range(num_splits): + if True: + # if _type == 'cuda': + outputs.append(self.conv(splits[_i])) + else: + _t = splits[_i] + _t1 = self.conv(_t.to("cuda", non_blocking=True)) + del _t + torch.cuda.empty_cache() + _t1 = _t1.to("cpu", non_blocking=True) + outputs.append(_t1) + del _t1 + torch.cuda.empty_cache() + + splits[_i] = 0 + torch.cuda.empty_cache() + + # for _j in range(len(outputs)): + # print(f"outputs before {_j + 1}/{len(outputs)}:", outputs[_j].size()) + + del splits + torch.cuda.empty_cache() + + split_size_out = split_size + padding_s = padding + non_tp_dim = self.tp_dim + 1 if self.tp_dim < 2 else 0 + if outputs[0].size(non_tp_dim + 2) // splits_0_size[non_tp_dim + 2] == 2: + split_size_out *= 2 + padding_s *= 2 + elif splits_0_size[non_tp_dim + 2] // outputs[0].size(non_tp_dim + 2) == 2: + split_size_out = split_size_out // 2 + padding_s = padding_s // 2 + + if self.tp_dim == 0: + outputs[0] = outputs[0][:, :, :split_size_out, :, :] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] + elif self.tp_dim == 1: + # print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") + outputs[0] = outputs[0][:, :, :, :split_size_out, :] + # # print("outputs", outputs[0].size(3), f"padding_s: {padding_s // 2}, {padding_s // 2 + split_size_out}") + # outputs[0] = outputs[0][:, :, :, padding_s // 2:padding_s // 2 + split_size_out, :] + for i in range(1, num_splits): + # print("outputs", outputs[i].size(3), f"padding_s: {padding_s}, {padding_s + split_size_out}") + outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] + elif self.tp_dim == 2: + outputs[0] = outputs[0][:, :, :, :, :split_size_out] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] + + # for i in range(num_splits): + # print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) + + # if max(outputs[0].size()) < 500 or outputs[0].device.type != 'cuda': + # if True: + if max(outputs[0].size()) < 500: + # print(f'outputs[0].device.type: {outputs[0].device.type}.') + x = torch.cat(list(outputs), dim=self.tp_dim + 2) + else: + import gc + + # x = torch.randn(outputs[0].size(), dtype=outputs[0].dtype, pin_memory=True) + # x = outputs[0] + # x = x.to('cpu', non_blocking=True) + + _type = outputs[0].device.type + if _type == "cuda": + x = outputs[0].clone().to("cpu", non_blocking=True) + outputs[0] = 0 + torch.cuda.empty_cache() + for _k in range(len(outputs) - 1): + x = torch.cat((x, outputs[_k + 1].cpu()), dim=self.tp_dim + 
2) + outputs[_k + 1] = 0 + torch.cuda.empty_cache() + gc.collect() + # print(f'SplitConvolutionV1 cat: {_k + 1}/{len(outputs) - 1}.') + if _type == "cuda": + x = x.to("cuda", non_blocking=True) + + del outputs + torch.cuda.empty_cache() + + return x + + +class SplitUpsample1(nn.Module): + """ + Convolution-based upsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels to the layer. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__(self, spatial_dims: int, in_channels: int, use_convtranspose: bool) -> None: + super().__init__() + if use_convtranspose: + self.conv = SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=1, + conv_only=True, + is_transposed=True, + ) + else: + self.conv = SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.use_convtranspose = use_convtranspose + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_convtranspose: + return self.conv(x) + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # https://github.com/pytorch/pytorch/issues/86679 + # dtype = x.dtype + # if dtype == torch.bfloat16: + # x = x.to(torch.float32) + + x = F.interpolate(x, scale_factor=2.0, mode="trilinear") + torch.cuda.empty_cache() + + # If the input is bfloat16, we cast back to bfloat16 + # if dtype == torch.bfloat16: + # x = x.to(dtype) + + x = self.conv(x) + torch.cuda.empty_cache() + + return x + + +class SplitDownsample(nn.Module): + """ + Convolution-based downsampling layer. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. + """ + + def __init__(self, spatial_dims: int, in_channels: int) -> None: + super().__init__() + self.pad = (0, 1) * spatial_dims + + self.conv = SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=in_channels, + strides=2, + kernel_size=3, + padding=0, + conv_only=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = nn.functional.pad(x, self.pad, mode="constant", value=0.0) + x = self.conv(x) + return x + + +class SplitResBlock(nn.Module): + """ + Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a + residual connection between input and output. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: input channels to the layer. + norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of + channels is divisible by this number. + norm_eps: epsilon for the normalisation. + out_channels: number of output channels. 
+ """ + + def __init__( + self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + + self.norm1 = InplaceGroupNorm3D(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + # self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True) + self.conv1 = SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + self.norm2 = InplaceGroupNorm3D( + num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True + ) + # self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True) + self.conv2 = SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=self.out_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + + if self.in_channels != self.out_channels: + self.nin_shortcut = SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=self.in_channels, + out_channels=self.out_channels, + strides=1, + kernel_size=1, + padding=0, + conv_only=True, + ) + else: + self.nin_shortcut = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if True: + h = x + h = self.norm1(h) + torch.cuda.empty_cache() + + # if max(x.size()) > 500: + # h = h.to('cpu', non_blocking=True).float() + # torch.cuda.empty_cache() + + h = F.silu(h) + torch.cuda.empty_cache() + h = self.conv1(h) + torch.cuda.empty_cache() + + # if max(x.size()) > 500: + # h = h.half().to('cuda', non_blocking=True) + # torch.cuda.empty_cache() + + h = self.norm2(h) + torch.cuda.empty_cache() + + # if max(x.size()) > 500: + # h = h.to('cpu', non_blocking=True).float() + # torch.cuda.empty_cache() + + h = F.silu(h) + torch.cuda.empty_cache() + h = self.conv2(h) + torch.cuda.empty_cache() + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + torch.cuda.empty_cache() + + # if max(x.size()) > 500: + # h = h.half().to('cuda', non_blocking=True) + # x = x.half().to('cuda', non_blocking=True) + else: + h1 = self.norm1(x) + if max(h1.size()) > 500: + x = x.to("cpu", non_blocking=True).float() + torch.cuda.empty_cache() + if max(h1.size()) > 500: + h1 = h1.to("cpu", non_blocking=True).float() + torch.cuda.empty_cache() + h2 = F.silu(h1) + if max(h2.size()) > 500: + h2 = h2.half().to("cuda", non_blocking=True) + h3 = self.conv1(h2) + del h2 + torch.cuda.empty_cache() + + h4 = self.norm2(h3) + del h3 + torch.cuda.empty_cache() + if max(h4.size()) > 500: + h4 = h4.to("cpu", non_blocking=True).float() + torch.cuda.empty_cache() + h5 = F.silu(h4) + if max(h5.size()) > 500: + h5 = h5.half().to("cuda", non_blocking=True) + h6 = self.conv2(h5) + del h5 + torch.cuda.empty_cache() + + if max(h6.size()) > 500: + h6 = h6.to("cpu", non_blocking=True).float() + + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + torch.cuda.empty_cache() + + out = x + h6 + if max(h6.size()) > 500: + out = out.half().to("cuda", non_blocking=True) + + return x + h + # return out + + +class EncoderTp(nn.Module): + """ + Convolutional cascade that downsamples the image into a spatial latent space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + in_channels: number of input channels. 
+ num_channels: sequence of block output channels. + out_channels: number of channels in the bottom layer (latent space) of the autoencoder. + num_res_blocks: number of residual blocks (see ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + num_channels: Sequence[int], + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_channels = num_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + + blocks = [] + # Initial convolution + blocks.append( + SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Residual and downsampling blocks + output_channel = num_channels[0] + for i in range(len(num_channels)): + input_channel = output_channel + output_channel = num_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(self.num_res_blocks[i]): + blocks.append( + SplitResBlock( + spatial_dims=spatial_dims, + in_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=output_channel, + ) + ) + input_channel = output_channel + if attention_levels[i]: + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=input_channel, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append(SplitDownsample(spatial_dims=spatial_dims, in_channels=input_channel)) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=num_channels[-1], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=num_channels[-1], + ) + ) + # Normalise and convert to latent size + blocks.append( + InplaceGroupNorm3D(num_groups=norm_num_groups, num_channels=num_channels[-1], eps=norm_eps, affine=True) + ) + blocks.append( + SplitConvolutionV1( + spatial_dims=self.spatial_dims, + in_channels=num_channels[-1], + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for block in self.blocks: + x = block(x) + torch.cuda.empty_cache() + return x + + +class DecoderTp1(nn.Module): + """ + Convolutional 
cascade upsampling from a spatial latent space into an image space. + + Args: + spatial_dims: number of spatial dimensions (1D, 2D, 3D). + num_channels: sequence of block output channels. + in_channels: number of channels in the bottom layer (latent space) of the autoencoder. + out_channels: number of output channels. + num_res_blocks: number of residual blocks (see ResBlock) per level. + norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number. + norm_eps: epsilon for the normalization. + attention_levels: indicate which level from num_channels contain an attention block. + with_nonlocal_attn: if True use non-local attention block. + use_flash_attention: if True, use flash attention for a memory efficient attention mechanism. + use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder. + """ + + def __init__( + self, + spatial_dims: int, + num_channels: Sequence[int], + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + norm_num_groups: int, + norm_eps: float, + attention_levels: Sequence[bool], + with_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_convtranspose: bool = False, + tp_dim: int = 1, + ) -> None: + super().__init__() + self.spatial_dims = spatial_dims + self.num_channels = num_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.norm_num_groups = norm_num_groups + self.norm_eps = norm_eps + self.attention_levels = attention_levels + self.tp_dim = tp_dim + + reversed_block_out_channels = list(reversed(num_channels)) + + blocks = [] + # Initial convolution + blocks.append( + SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=reversed_block_out_channels[0], + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + # Non-local attention block + if with_nonlocal_attn is True: + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + blocks.append( + ResBlock( + spatial_dims=spatial_dims, + in_channels=reversed_block_out_channels[0], + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=reversed_block_out_channels[0], + ) + ) + + reversed_attention_levels = list(reversed(attention_levels)) + reversed_num_res_blocks = list(reversed(num_res_blocks)) + block_out_ch = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + block_in_ch = block_out_ch + block_out_ch = reversed_block_out_channels[i] + is_final_block = i == len(num_channels) - 1 + + for _ in range(reversed_num_res_blocks[i]): + blocks.append( + SplitResBlock( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + out_channels=block_out_ch, + ) + ) + block_in_ch = block_out_ch + + if reversed_attention_levels[i]: + blocks.append( + AttentionBlock( + spatial_dims=spatial_dims, + num_channels=block_in_ch, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + use_flash_attention=use_flash_attention, + ) + ) + + if not is_final_block: + blocks.append( + SplitUpsample1( + spatial_dims=spatial_dims, 
in_channels=block_in_ch, use_convtranspose=use_convtranspose + ) + ) + + blocks.append( + InplaceGroupNorm3D(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True) + ) + blocks.append( + SplitConvolutionV1( + spatial_dims=spatial_dims, + in_channels=block_in_ch, + out_channels=out_channels, + strides=1, + kernel_size=3, + padding=1, + conv_only=True, + ) + ) + + self.blocks = nn.ModuleList(blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # if False: + # for block in self.blocks: + # x = block(x) + # else: + for _i in range(len(self.blocks)): + block = self.blocks[_i] + # print(block, type(block), type(type(block))) + + if _i < len(self.blocks) - 0: + # if not isinstance(block, monai.networks.blocks.convolutions.Convolution): + x = block(x) + torch.cuda.empty_cache() + else: + # # print(block, type(block), type(type(block))) + # block = self.blocks[_i] + # # print(f"block {_i + 1}/{len(self.blocks)}") + + num_splits = NUM_SPLITS + # print("num_splits:", num_splits) + + l = x.size(self.tp_dim + 2) + split_size = l // num_splits + + if False: + splits = [x[:, :, i * split_size : (i + 1) * split_size, :, :] for i in range(num_splits)] + else: + # padding = 1 + padding = SPLIT_PADDING + # print("padding:", padding) + + overlaps = [0] + [padding] * (num_splits - 1) + if self.tp_dim == 0: + splits = [ + x[ + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else 0), + :, + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 1: + splits = [ + x[ + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else 0), + :, + ] + for i in range(num_splits) + ] + elif self.tp_dim == 2: + splits = [ + x[ + :, + :, + :, + :, + i * split_size + - overlaps[i] : (i + 1) * split_size + + (padding if i != num_splits - 1 else 0), + ] + for i in range(num_splits) + ] + + # for _j in range(len(splits)): + # print(f"splits {_j + 1}/{len(splits)}:", splits[_j].size()) + + del x + torch.cuda.empty_cache() + + outputs = [block(splits[i]) for i in range(num_splits)] + + del splits + torch.cuda.empty_cache() + + split_size_out = split_size + padding_s = padding + non_tp_dim = self.tp_dim + 1 if self.tp_dim < 2 else 0 + if outputs[0].size(non_tp_dim + 2) // splits[0].size(non_tp_dim + 2) == 2: + split_size_out *= 2 + padding_s *= 2 + # print("split_size_out:", split_size_out) + # print("padding_s:", padding_s) + + if self.tp_dim == 0: + outputs[0] = outputs[0][:, :, :split_size_out, :, :] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, padding_s : padding_s + split_size_out, :, :] + elif self.tp_dim == 1: + # print("outputs", outputs[0].size(3), f"padding_s: 0, {split_size_out}") + outputs[0] = outputs[0][:, :, :, :split_size_out, :] + # # print("outputs", outputs[0].size(3), f"padding_s: {padding_s // 2}, {padding_s // 2 + split_size_out}") + # outputs[0] = outputs[0][:, :, :, padding_s // 2:padding_s // 2 + split_size_out, :] + for i in range(1, num_splits): + # print("outputs", outputs[i].size(3), f"padding_s: {padding_s}, {padding_s + split_size_out}") + outputs[i] = outputs[i][:, :, :, padding_s : padding_s + split_size_out, :] + elif self.tp_dim == 2: + outputs[0] = outputs[0][:, :, :, :, :split_size_out] + for i in range(1, num_splits): + outputs[i] = outputs[i][:, :, :, :, padding_s : padding_s + split_size_out] + + # for i in range(num_splits): + # print(f"outputs after {i + 1}/{len(outputs)}:", outputs[i].size()) + + if max(outputs[0].size()) < 
500: + x = torch.cat(list(outputs), dim=self.tp_dim + 2) + else: + import gc + + # x = torch.randn(outputs[0].size(), dtype=outputs[0].dtype, pin_memory=True) + # x = outputs[0] + # x = x.to('cpu', non_blocking=True) + x = outputs[0].clone().to("cpu", non_blocking=True) + outputs[0] = 0 + torch.cuda.empty_cache() + for _k in range(len(outputs) - 1): + x = torch.cat((x, outputs[_k + 1].cpu()), dim=self.tp_dim + 2) + outputs[_k + 1] = 0 + torch.cuda.empty_cache() + gc.collect() + # print(f'cat: {_k + 1}/{len(outputs) - 1}.') + x = x.to("cuda", non_blocking=True) + + del outputs + torch.cuda.empty_cache() + + return x + + +class AutoencoderKlckModifiedTp(AutoencoderKL): + """ + Override encoder to make it align with original ldm codebase and support activation checkpointing. + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + num_res_blocks: Sequence[int], + num_channels: Sequence[int], + attention_levels: Sequence[bool], + latent_channels: int = 3, + norm_num_groups: int = 32, + norm_eps: float = 1e-6, + with_encoder_nonlocal_attn: bool = True, + with_decoder_nonlocal_attn: bool = True, + use_flash_attention: bool = False, + use_checkpointing: bool = False, + use_convtranspose: bool = False, + ) -> None: + super().__init__( + spatial_dims, + in_channels, + out_channels, + num_res_blocks, + num_channels, + attention_levels, + latent_channels, + norm_num_groups, + norm_eps, + with_encoder_nonlocal_attn, + with_decoder_nonlocal_attn, + use_flash_attention, + use_checkpointing, + use_convtranspose, + ) + + self.encoder = EncoderTp( + spatial_dims=spatial_dims, + in_channels=in_channels, + num_channels=num_channels, + out_channels=latent_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_encoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + ) + + # Override decoder using transposed conv + self.decoder = DecoderTp1( + spatial_dims=spatial_dims, + num_channels=num_channels, + in_channels=latent_channels, + out_channels=out_channels, + num_res_blocks=num_res_blocks, + norm_num_groups=norm_num_groups, + norm_eps=norm_eps, + attention_levels=attention_levels, + with_nonlocal_attn=with_decoder_nonlocal_attn, + use_flash_attention=use_flash_attention, + use_convtranspose=use_convtranspose, + ) diff --git a/models/maisi_ct_generative/scripts/find_masks.py b/models/maisi_ct_generative/scripts/find_masks.py new file mode 100644 index 00000000..078ae394 --- /dev/null +++ b/models/maisi_ct_generative/scripts/find_masks.py @@ -0,0 +1,120 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
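Before the mask-lookup utilities that follow, a brief note on the splitting scheme used by `SplitConvolutionV1` and the other `*Tp` blocks above: the volume is cut into `NUM_SPLITS` overlapping chunks along one spatial axis, each chunk is processed independently, and the `SPLIT_PADDING` overlap is cropped off before the chunks are concatenated again, trading redundant compute at the seams for a much smaller peak activation size. The snippet below is a minimal sketch of that idea with a plain stride-1 `torch.nn.Conv3d`; the shapes, the 4-way split along one axis, and the helper name `split_conv3d` are illustrative assumptions, not part of the bundle code, which additionally handles strided convolutions and CPU offloading.

```python
import torch
import torch.nn as nn


def split_conv3d(conv: nn.Conv3d, x: torch.Tensor, num_splits: int = 4, pad: int = 1, dim: int = 3) -> torch.Tensor:
    # Hypothetical sketch: split `x` along spatial axis `dim` with `pad` voxels of overlap,
    # run the convolution chunk by chunk, then crop the overlap and re-concatenate.
    # Assumes a stride-1, padding-1 convolution so each chunk keeps its spatial size.
    length = x.size(dim)
    split = length // num_splits
    outs = []
    for i in range(num_splits):
        start = max(i * split - pad, 0)
        stop = length if i == num_splits - 1 else (i + 1) * split + pad
        chunk = conv(x.narrow(dim, start, stop - start))
        # drop the left-hand overlap so the cropped chunks tile the full axis exactly
        left = i * split - start
        keep = split if i < num_splits - 1 else length - i * split
        outs.append(chunk.narrow(dim, left, keep))
    return torch.cat(outs, dim=dim)


conv = nn.Conv3d(1, 8, kernel_size=3, padding=1)
x = torch.randn(1, 1, 32, 64, 32)
full = conv(x)
tiled = split_conv3d(conv, x, num_splits=4, pad=1, dim=3)
assert torch.allclose(full, tiled, atol=1e-5)  # same result, smaller peak memory per chunk
```

The assertion holds because a 3x3x3, stride-1 convolution only needs one voxel of context on each side of a chunk, which is exactly what the overlap provides.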
+ + +import json +import os +import zipfile + + +def convert_body_region(body_region: list[int]): + body_region_indices = [] + + for _k in range(len(body_region)): + region = body_region[_k].lower() + + idx = None + if "head" in region: + idx = 0 + elif "chest" in region or "thorax" in region or "chest/thorax" in region: + idx = 1 + elif "abdomen" in region: + idx = 2 + elif "pelvis" in region or "lower" in region or "pelvis/lower" in region: + idx = 3 + else: + raise ValueError("Input region information is incorrect.") + + body_region_indices.append(idx) + + return body_region_indices + + +def find_masks( + body_region: str | list[str], + anatomy_list: int | list[int], + spacing: list[float], + output_size: list[int], + check_spacing_and_output_size: bool = False, + database_filepath: str = "./database.json", + mask_foldername: str = "./masks", +): + if type(body_region) is str: + body_region = [body_region] + + body_region = convert_body_region(body_region) + + if type(anatomy_list) is int: + anatomy_list = [anatomy_list] + + if not os.path.isfile(database_filepath): + raise ValueError(f"Please download {database_filepath}.") + + if not os.path.exists(mask_foldername): + zip_file_path = mask_foldername + ".zip" + + if not os.path.isfile(zip_file_path): + raise ValueError(f"Please downloaded {zip_file_path}.") + + with zipfile.ZipFile(zip_file_path, "r") as zip_ref: + print(mask_foldername) + zip_ref.extractall(path="./datasets") + print(f"Unzipped {zip_file_path} to {mask_foldername}.") + + with open(database_filepath, "r") as f: + db = json.load(f) + + candidates = [] + for _i in range(len(db)): + _item = db[_i] + if not set(anatomy_list).issubset(_item["label_list"]): + continue + + top_index = [index for index, element in enumerate(_item["top_region_index"]) if element != 0] + top_index = top_index[0] + bottom_index = [index for index, element in enumerate(_item["bottom_region_index"]) if element != 0] + bottom_index = bottom_index[0] + + flag = False + for _idx in body_region: + if _idx > bottom_index or _idx < top_index: + flag = True + + # check if candiate mask contains tumors + for tumor_label in [23, 24, 26, 27, 128]: + # we skip those mask with tumors if users do not provide tumor label in anatomy_list + if tumor_label not in anatomy_list and tumor_label in _item["label_list"]: + flag = True + + if check_spacing_and_output_size: + # check if the output_size and spacing are same as user's input + for axis in range(3): + if _item["dim"][axis] != output_size[axis] or _item["spacing"][axis] != spacing[axis]: + flag = True + + if flag is True: + continue + + candidate = {} + if "label_filename" in _item: + candidate["label"] = os.path.join(mask_foldername, _item["label_filename"]) + candidate["pseudo_label"] = os.path.join(mask_foldername, _item["pseudo_label_filename"]) + candidate["spacing"] = _item["spacing"] + candidate["dim"] = _item["dim"] + candidate["top_region_index"] = _item["top_region_index"] + candidate["bottom_region_index"] = _item["bottom_region_index"] + + candidates.append(candidate) + + if len(candidates) == 0 and not check_spacing_and_output_size: + raise ValueError("Cannot find body region with given organ list.") + + return candidates diff --git a/models/maisi_ct_generative/scripts/sample.py b/models/maisi_ct_generative/scripts/sample.py new file mode 100644 index 00000000..9616df41 --- /dev/null +++ b/models/maisi_ct_generative/scripts/sample.py @@ -0,0 +1,699 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import random +from datetime import datetime + +import monai +import torch +from generative.inferers import LatentDiffusionInferer +from monai.data import MetaTensor +from monai.inferers import sliding_window_inference +from monai.transforms import Compose, SaveImage +from monai.utils import set_determinism +from tqdm import tqdm + +from .augmentation import augmentation +from .find_masks import find_masks +from .utils import MapLabelValue, binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask + + +class ReconModel(torch.nn.Module): + def __init__(self, autoencoder, scale_factor): + super().__init__() + self.autoencoder = autoencoder + self.scale_factor = scale_factor + + def forward(self, z): + recon_pt_nda = self.autoencoder.decode_stage_2_outputs(z / self.scale_factor) + return recon_pt_nda + + +def ldm_conditional_sample_one_mask( + autoencoder, + difusion_unet, + noise_scheduler, + scale_factor, + anatomy_size, + device, + latent_shape, + label_dict_remap_json, + num_inference_steps=1000, +): + with torch.no_grad(): + with torch.cuda.amp.autocast(): + + # Generate random noise + latents = torch.randn([1] + list(latent_shape)).half().to(device) + anatomy_size = torch.FloatTensor(anatomy_size).unsqueeze(0).unsqueeze(0).half().to(device) + # synthesize masks + noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) + inferer_ddpm = LatentDiffusionInferer(noise_scheduler, scale_factor=scale_factor) + synthetic_mask = inferer_ddpm.sample( + input_noise=latents, + autoencoder_model=autoencoder, + diffusion_model=difusion_unet, + scheduler=noise_scheduler, + verbose=True, + conditioning=anatomy_size.to(device), + ) + synthetic_mask = torch.softmax(synthetic_mask, dim=1) + synthetic_mask = torch.argmax(synthetic_mask, dim=1, keepdim=True) + # mapping raw index to 132 labels + with open(label_dict_remap_json, "r") as f: + mapping_dict = json.load(f) + mapping = [v for _, v in mapping_dict.items()] + mapper = MapLabelValue( + orig_labels=[pair[0] for pair in mapping], + target_labels=[pair[1] for pair in mapping], + dtype=torch.uint8, + ) + synthetic_mask = mapper(synthetic_mask[0, ...])[None, ...].to(device) + + # post process + data = synthetic_mask.squeeze().cpu().detach().numpy() + if anatomy_size[0, 0, 5].item() != -1.0: + target_tumor_label = 23 + elif anatomy_size[0, 0, 6].item() != -1.0: + target_tumor_label = 24 + elif anatomy_size[0, 0, 7].item() != -1.0: + target_tumor_label = 26 + elif anatomy_size[0, 0, 8].item() != -1.0: + target_tumor_label = 27 + elif anatomy_size[0, 0, 9].item() != -1.0: + target_tumor_label = 128 + else: + target_tumor_label = None + + print("target_tumor_label for postprocess:", target_tumor_label) + data = general_mask_generation_post_process(data, target_tumor_label=target_tumor_label, device=device) + synthetic_mask = torch.from_numpy(data).unsqueeze(0).unsqueeze(0).to(device) + + return synthetic_mask + + +def ldm_conditional_sample_one_image( + autoencoder, + difusion_unet, + controlnet, + noise_scheduler, + 
scale_factor, + device, + comebine_label_or, + top_region_index_tensor, + bottom_region_index_tensor, + spacing_tensor, + latent_shape, + output_size, + noise_factor, + num_inference_steps=1000, +): + # CT image intensity range + a_min = -1000 + a_max = 1000 + # autoencoder output intensity range + b_min = 0.0 + b_max = 1 + + recon_model = ReconModel(autoencoder=autoencoder, scale_factor=scale_factor).to(device) + + with torch.no_grad(): + with torch.cuda.amp.autocast(): + # generate segmentation mask + comebine_label = comebine_label_or.to(device) + if ( + output_size[0] != comebine_label.shape[2] + or output_size[1] != comebine_label.shape[3] + or output_size[2] != comebine_label.shape[4] + ): + print( + "output_size is not a desired value. Need to interpolate the mask to " + "match with output_size. The result image will be very low quality." + ) + comebine_label = torch.nn.functional.interpolate(comebine_label, size=output_size, mode="nearest") + + controlnet_cond_vis = binarize_labels(comebine_label.as_tensor().long()).half() + + # Generate random noise + latents = torch.randn([1] + list(latent_shape)).half().to(device) * noise_factor + + # synthesize latents + noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps) + for t in tqdm(noise_scheduler.timesteps, ncols=110): + # Get controlnet output + down_block_res_samples, mid_block_res_sample = controlnet( + x=latents, timesteps=torch.Tensor((t,)).to(device), controlnet_cond=controlnet_cond_vis + ) + latent_model_input = latents + noise_pred = difusion_unet( + x=latent_model_input, + timesteps=torch.Tensor((t,)).to(device), + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + latents, _ = noise_scheduler.step(noise_pred, t, latents) + + # decode latents to synthesized images + synthetic_images = sliding_window_inference( + inputs=latents, + roi_size=( + min(output_size[0] // 4 // 4 * 3, 96), + min(output_size[1] // 4 // 4 * 3, 96), + min(output_size[2] // 4 // 4 * 3, 96), + ), + sw_batch_size=1, + predictor=recon_model, + mode="gaussian", + overlap=2.0 / 3.0, + sw_device=device, + device=device, + ) + + synthetic_images = torch.clip(synthetic_images, b_min, b_max).cpu() + + # post processing: + # project output to [0, 1] + synthetic_images = (synthetic_images - b_min) / (b_max - b_min) + # project output to [-1000, 1000] + synthetic_images = synthetic_images * (a_max - a_min) + a_min + # regularize background intensities + synthetic_images = crop_img_body_mask(synthetic_images, comebine_label) + + return synthetic_images, comebine_label + + +def filter_mask_with_organs(comebine_label, anatomy_list): + # final output mask file has shape of output_size, contaisn labels in anatomy_list + # it is already interpolated to target size + comebine_label = comebine_label.long() + # filter out the organs that are not in anatomy_list + for i in range(len(anatomy_list)): + organ = anatomy_list[i] + # replace it with a negative value so it will get mixed + comebine_label[comebine_label == organ] = -(i + 1) + # zero-out voxels with value not in anatomy_list + comebine_label[comebine_label > 0] = 0 + # output positive values + comebine_label = -comebine_label + return comebine_label + + +def crop_img_body_mask(synthetic_images, comebine_label): + synthetic_images[comebine_label == 0] = -1000 + return synthetic_images + + +def 
check_input(body_region, anatomy_list, label_dict_json, output_size, spacing, controllable_anatomy_size): + # check output_size and spacing format + if output_size[0] != output_size[1]: + raise ValueError(f"The first two components of output_size need to be equal, yet got {output_size}.") + if (output_size[0] not in [256, 384, 512]) or (output_size[2] not in [128, 256, 384, 512, 640, 768]): + raise ValueError( + "The output_size[0] have to be chosen from [256, 384, 512], and " + "output_size[2] have to be chosen from [128, 256, 384, 512, 640, 768], " + f"yet got {output_size}." + ) + + if spacing[0] != spacing[1]: + raise ValueError(f"The first two components of spacing need to be equal, yet got {spacing}.") + if spacing[0] < 0.5 or spacing[0] > 3.0 or spacing[2] < 0.5 or spacing[2] > 5.0: + raise ValueError( + f"spacing[0] have to be between 0.5 and 3.0 mm, spacing[2] have to be between 0.5 and 5.0 mm, yet got {spacing}." + ) + + # check controllable_anatomy_size format + if len(controllable_anatomy_size) > 10: + raise ValueError( + f"The length of list controllable_anatomy_size has to be less than 10. " + f"Yet got length equal to {len(controllable_anatomy_size)}." + ) + available_controllable_organ = ["liver", "gallbladder", "stomach", "pancreas", "colon"] + available_controllable_tumor = [ + "hepatic tumor", + "bone lesion", + "lung tumor", + "colon cancer primaries", + "pancreatic tumor", + ] + available_controllable_anatomy = available_controllable_organ + available_controllable_tumor + controllable_tumor = [] + controllable_organ = [] + for controllable_anatomy_size_pair in controllable_anatomy_size: + if controllable_anatomy_size_pair[0] not in available_controllable_anatomy: + raise ValueError( + f"The controllable_anatomy have to be chosen from " + f"{available_controllable_anatomy}, yet got " + f"{controllable_anatomy_size_pair[0]}." + ) + if controllable_anatomy_size_pair[0] in available_controllable_tumor: + controllable_tumor += [controllable_anatomy_size_pair[0]] + if controllable_anatomy_size_pair[0] in available_controllable_organ: + controllable_organ += [controllable_anatomy_size_pair[0]] + if controllable_anatomy_size_pair[1] == -1: + continue + if controllable_anatomy_size_pair[1] < 0 or controllable_anatomy_size_pair[1] > 1.0: + raise ValueError( + f"The controllable size scale have to be between 0 and 1,0, or equal to -1, " + f"yet got {controllable_anatomy_size_pair[1]}." + ) + if len(controllable_tumor + controllable_organ) != len(list(set(controllable_tumor + controllable_organ))): + raise ValueError(f"Please do not repeat controllable_anatomy. Got {controllable_tumor + controllable_organ}.") + if len(controllable_tumor) > 1: + raise ValueError(f"Only one controllable tumor is supported. Yet got {controllable_tumor}.") + + if len(controllable_anatomy_size) > 0: + print( + "controllable_anatomy_size is not empty. We will ignore body_region and " + "anatomy_list and synthesize based on controllable_anatomy_size." + ) + else: + print("controllable_anatomy_size is empty. We will synthesize based on body_region and anatomy_list.") + # check body_region format + available_body_region = ["head", "chest", "thorax", "abdomen", "pelvis", "lower"] + for region in body_region: + if region not in available_body_region: + raise ValueError( + f"The components in body_region have to be chosen from {available_body_region}, yet got {region}." 
+ ) + + # check anatomy_list format + with open(label_dict_json) as f: + label_dict = json.load(f) + for anatomy in anatomy_list: + if anatomy not in label_dict.keys(): + raise ValueError( + f"The components in anatomy_list have to be chosen from {label_dict.keys()}, yet got {anatomy}." + ) + + return + + +class LDMSampler: + def __init__( + self, + body_region, + anatomy_list, + all_mask_files_json, + all_anatomy_size_condtions_json, + all_mask_files_base_dir, + label_dict_json, + label_dict_remap_json, + autoencoder, + difusion_unet, + controlnet, + noise_scheduler, + scale_factor, + mask_generation_autoencoder, + mask_generation_difusion_unet, + mask_generation_scale_factor, + mask_generation_noise_scheduler, + device, + latent_shape, + mask_generation_latent_shape, + output_size, + output_dir, + controllable_anatomy_size, + image_output_ext=".nii.gz", + label_output_ext=".nii.gz", + quality_check_args=None, + spacing=(1, 1, 1), + num_inference_steps=None, + mask_generation_num_inference_steps=None, + random_seed=None, + ) -> None: + + if random_seed is not None: + set_determinism(seed=random_seed) + + with open(label_dict_json, "r") as f: + label_dict = json.load(f) + self.all_anatomy_size_condtions_json = all_anatomy_size_condtions_json + + # intialize variables + self.body_region = body_region + self.anatomy_list = [label_dict[organ] for organ in anatomy_list] + self.all_mask_files_json = all_mask_files_json + self.data_root = all_mask_files_base_dir + self.label_dict_remap_json = label_dict_remap_json + self.autoencoder = autoencoder + self.difusion_unet = difusion_unet + self.controlnet = controlnet + self.noise_scheduler = noise_scheduler + self.scale_factor = scale_factor + self.mask_generation_autoencoder = mask_generation_autoencoder + self.mask_generation_difusion_unet = mask_generation_difusion_unet + self.mask_generation_scale_factor = mask_generation_scale_factor + self.mask_generation_noise_scheduler = mask_generation_noise_scheduler + self.device = device + self.latent_shape = latent_shape + self.mask_generation_latent_shape = mask_generation_latent_shape + self.output_size = output_size + self.output_dir = output_dir + self.noise_factor = 1.0 + self.controllable_anatomy_size = controllable_anatomy_size + if len(self.controllable_anatomy_size): + print("controllable_anatomy_size is given, mask generation is triggered!") + # overwrite the anatomy_list by given organs in self.controllable_anatomy_size + self.anatomy_list = [label_dict[organ_and_size[0]] for organ_and_size in self.controllable_anatomy_size] + self.image_output_ext = image_output_ext + self.label_output_ext = label_output_ext + # Set the default value for number of inference steps to 1000 + self.num_inference_steps = num_inference_steps if num_inference_steps is not None else 1000 + self.mask_generation_num_inference_steps = ( + mask_generation_num_inference_steps if mask_generation_num_inference_steps is not None else 1000 + ) + + # quality check disabled for this version + self.quality_check_args = quality_check_args + + self.autoencoder.eval() + self.difusion_unet.eval() + self.controlnet.eval() + self.mask_generation_autoencoder.eval() + self.mask_generation_difusion_unet.eval() + + self.spacing = spacing + + self.val_transforms = Compose( + [ + monai.transforms.LoadImaged(keys=["pseudo_label"]), + monai.transforms.EnsureChannelFirstd(keys=["pseudo_label"]), + monai.transforms.Orientationd(keys=["pseudo_label"], axcodes="RAS"), + monai.transforms.EnsureTyped(keys=["pseudo_label"], dtype=torch.uint8), + 
monai.transforms.Lambdad(keys="top_region_index", func=lambda x: torch.FloatTensor(x)), + monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: torch.FloatTensor(x)), + monai.transforms.Lambdad(keys="spacing", func=lambda x: torch.FloatTensor(x)), + monai.transforms.Lambdad(keys="top_region_index", func=lambda x: x * 1e2), + monai.transforms.Lambdad(keys="bottom_region_index", func=lambda x: x * 1e2), + monai.transforms.Lambdad(keys="spacing", func=lambda x: x * 1e2), + ] + ) + print("LDM sampler initialized.") + + def sample_multiple_images(self, num_img): + if len(self.controllable_anatomy_size) > 0: + # we will use mask generation instead of finding candidate masks + # create a dummy selected_mask_files for placeholder + selected_mask_files = list(range(num_img)) + # prerpare organ size conditions + anatomy_size_condtion = self.prepare_anatomy_size_condtion(self.controllable_anatomy_size) + else: + need_resample = False + # find candidate mask and save to candidate_mask_files + candidate_mask_files = find_masks( + self.body_region, + self.anatomy_list, + self.spacing, + self.output_size, + True, + self.all_mask_files_json, + self.data_root, + ) + if len(candidate_mask_files) < num_img: + # if we cannot find enough masks based on the exact match of anatomy list, spacing, and output size, + # then we will try to find the closest mask in terms of spacing, and output size. + print("Resample to get desired output size and spacing") + candidate_mask_files = self.find_closest_masks(num_img) + need_resample = True + + selected_mask_files = self.select_mask(candidate_mask_files, num_img) + print(selected_mask_files) + if len(selected_mask_files) != num_img: + raise ValueError( + f"len(selected_mask_files) ({len(selected_mask_files)}) != num_img " + f"({num_img}). This should not happen. Please revisit function " + f"select_mask(self, candidate_mask_files, num_img)." 
+ ) + for item in selected_mask_files: + if len(self.controllable_anatomy_size) > 0: + # generate a synthetic mask + (comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( + self.prepare_one_mask_and_meta_info(anatomy_size_condtion) + ) + else: + # read in mask file + mask_file = item["mask_file"] + if_aug = item["if_aug"] + (comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor) = ( + self.read_mask_information(mask_file) + ) + if need_resample: + comebine_label_or = self.ensure_output_size_and_spacing(comebine_label_or) + # mask augmentation + if if_aug is True: + comebine_label_or = augmentation(comebine_label_or, self.output_size) + torch.cuda.empty_cache() + # generate image/label pairs + to_generate = True + try_time = 0 + while to_generate: + synthetic_images, synthetic_labels = self.sample_one_pair( + comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + ) + # current quality always return True + pass_quality_check = self.quality_check(synthetic_images) + if pass_quality_check or try_time > 3: + # save image/label pairs + output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz" + synthetic_images = MetaTensor(synthetic_images, meta=synthetic_labels.meta) + img_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_image", + output_ext=self.image_output_ext, + separate_folder=False, + ) + img_saver(synthetic_images[0]) + # filter out the organs that are not in anatomy_list + synthetic_labels = filter_mask_with_organs(synthetic_labels, self.anatomy_list) + label_saver = SaveImage( + output_dir=self.output_dir, + output_postfix=output_postfix + "_label", + output_ext=self.label_output_ext, + separate_folder=False, + ) + label_saver(synthetic_labels[0]) + to_generate = False + else: + print("Generated image/label pair did not pass quality check, will re-generate another pair.") + try_time += 1 + return + + def select_mask(self, candidate_mask_files, num_img): + selected_mask_files = [] + random.shuffle(candidate_mask_files) + + for n in range(num_img): + mask_file = candidate_mask_files[n % len(candidate_mask_files)] + selected_mask_files.append({"mask_file": mask_file, "if_aug": True}) + return selected_mask_files + + def sample_one_pair( + self, comebine_label_or_aug, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + ): + # generate image/label pairs + synthetic_images, synthetic_labels = ldm_conditional_sample_one_image( + autoencoder=self.autoencoder, + difusion_unet=self.difusion_unet, + controlnet=self.controlnet, + noise_scheduler=self.noise_scheduler, + scale_factor=self.scale_factor, + device=self.device, + comebine_label_or=comebine_label_or_aug, + top_region_index_tensor=top_region_index_tensor, + bottom_region_index_tensor=bottom_region_index_tensor, + spacing_tensor=spacing_tensor, + latent_shape=self.latent_shape, + output_size=self.output_size, + noise_factor=self.noise_factor, + num_inference_steps=self.num_inference_steps, + ) + return synthetic_images, synthetic_labels + + def prepare_anatomy_size_condtion(self, controllable_anatomy_size): + anatomy_size_idx = { + "gallbladder": 0, + "liver": 1, + "stomach": 2, + "pancreas": 3, + "colon": 4, + "lung tumor": 5, + "pancreatic tumor": 6, + "hepatic tumor": 7, + "colon cancer primaries": 8, + "bone lesion": 9, + } + provide_anatomy_size = [None for _ in range(10)] + print("controllable_anatomy_size:", 
controllable_anatomy_size) + for element in controllable_anatomy_size: + anatomy_name, anatomy_size = element + provide_anatomy_size[anatomy_size_idx[anatomy_name]] = anatomy_size + + with open(self.all_anatomy_size_condtions_json, "r") as f: + all_anatomy_size_condtions = json.load(f) + + # loop through the database and find closest combinations + candidate_list = [] + for anatomy_size in all_anatomy_size_condtions: + size = anatomy_size["organ_size"] + diff = 0 + for db_size, provide_size in zip(size, provide_anatomy_size): + if provide_size is None: + continue + diff += abs(provide_size - db_size) + candidate_list.append((size, diff)) + candidate_condition = sorted(candidate_list, key=lambda x: x[1])[0][0] + print("provide_anatomy_size:", provide_anatomy_size) + print("candidate_condition:", candidate_condition) + + # overwrite the anatomy size provided by users + for element in controllable_anatomy_size: + anatomy_name, anatomy_size = element + candidate_condition[anatomy_size_idx[anatomy_name]] = anatomy_size + print("final candidate_condition:", candidate_condition) + return candidate_condition + + def prepare_one_mask_and_meta_info(self, anatomy_size_condtion): + comebine_label_or = self.sample_one_mask(anatomy_size=anatomy_size_condtion) + # TODO: current mask generation model only can generate 256^3 volumes with 1.5 mm spacing. + affine = torch.zeros((4, 4)) + affine[0, 0] = 1.5 + affine[1, 1] = 1.5 + affine[2, 2] = 1.5 + affine[3, 3] = 1.0 # dummy + comebine_label_or = MetaTensor(comebine_label_or, affine=affine) + comebine_label_or = self.ensure_output_size_and_spacing(comebine_label_or) + + top_region_index, bottom_region_index = get_body_region_index_from_mask(comebine_label_or) + + spacing_tensor = torch.FloatTensor(self.spacing).unsqueeze(0).half().to(self.device) * 1e2 + top_region_index_tensor = torch.FloatTensor(top_region_index).unsqueeze(0).half().to(self.device) * 1e2 + bottom_region_index_tensor = torch.FloatTensor(bottom_region_index).unsqueeze(0).half().to(self.device) * 1e2 + + return comebine_label_or, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor + + def sample_one_mask(self, anatomy_size): + # generate one synthetic mask + synthetic_mask = ldm_conditional_sample_one_mask( + self.mask_generation_autoencoder, + self.mask_generation_difusion_unet, + self.mask_generation_noise_scheduler, + self.mask_generation_scale_factor, + anatomy_size, + self.device, + self.mask_generation_latent_shape, + label_dict_remap_json=self.label_dict_remap_json, + num_inference_steps=self.mask_generation_num_inference_steps, + ) + return synthetic_mask + + def ensure_output_size_and_spacing(self, labels, check_contains_target_labels=True): + current_spacing = [labels.affine[0, 0], labels.affine[1, 1], labels.affine[2, 2]] + current_shape = list(labels.squeeze().shape) + + need_resample = False + # check spacing + for i, j in zip(current_spacing, self.spacing): + if i != j: + need_resample = True + # check output size + for i, j in zip(current_shape, self.output_size): + if i != j: + need_resample = True + # resample to target size and spacing + if need_resample: + print("Resampling mask to target shape and sapcing") + print(f"Output size: {current_shape} -> {self.output_size}") + print(f"Sapcing: {current_spacing} -> {self.spacing}") + spacing = monai.transforms.Spacing(pixdim=tuple(self.spacing), mode="nearest") + pad = monai.transforms.SpatialPad(spatial_size=tuple(self.output_size)) + crop = monai.transforms.CenterSpatialCrop(roi_size=tuple(self.output_size)) + 
labels = crop(pad(spacing(labels.squeeze(0)))).unsqueeze(0) + contained_labels = torch.unique(labels) + if check_contains_target_labels: + # check if the resampled mask still contains those target labels + for anatomy_label in self.anatomy_list: + if anatomy_label not in contained_labels: + raise ValueError( + "Resampled mask does not contain required class labels. Please tune spacing and output size" + ) + return labels + + def read_mask_information(self, mask_file): + val_data = self.val_transforms(mask_file) + + for key in ["pseudo_label", "spacing", "top_region_index", "bottom_region_index"]: + val_data[key] = val_data[key].unsqueeze(0).to(self.device) + + return ( + val_data["pseudo_label"], + val_data["top_region_index"], + val_data["bottom_region_index"], + val_data["spacing"], + ) + + def find_closest_masks(self, num_img): + # first check the database based on anatomy list + candidates = find_masks( + self.body_region, + self.anatomy_list, + self.spacing, + self.output_size, + False, + self.all_mask_files_json, + self.data_root, + ) + + if len(candidates) < num_img: + raise ValueError(f"candidate masks are less than {num_img}).") + # loop through the database and find closest combinations + new_candidates = [] + for c in candidates: + diff = 0 + for axis in range(3): + # check diff in dim + diff += abs((c["dim"][axis] - self.output_size[axis]) / 100) + # check diff in spacing + diff += abs(c["spacing"][axis] - self.spacing[axis]) + new_candidates.append((c, diff)) + # choose top-2*num_img candidates (at least 5) + new_candidates = sorted(new_candidates, key=lambda x: x[1])[: max(2 * num_img, 5)] + final_candidates = [] + # check top-2*num_img candidates and update spacing after resampling + image_loader = monai.transforms.LoadImage(image_only=True, ensure_channel_first=True) + for c, _ in new_candidates: + label = image_loader(c["pseudo_label"]) + try: + label = self.ensure_output_size_and_spacing(label.unsqueeze(0)) + except ValueError as e: + if "Resampled mask does not contain required class labels" in str(e): + continue + else: + raise e + # get region_index after resample + top_region_index, bottom_region_index = get_body_region_index_from_mask(label) + c["top_region_index"] = top_region_index + c["bottom_region_index"] = bottom_region_index + c["spacing"] = self.spacing + c["dim"] = self.output_size + + final_candidates.append(c) + if len(final_candidates) == 0: + raise ValueError("Cannot find body region with given organ list.") + return final_candidates + + def quality_check(self, image): + # This version disabled quality check + return True diff --git a/models/maisi_ct_generative/scripts/trainer.py b/models/maisi_ct_generative/scripts/trainer.py new file mode 100644 index 00000000..e935e325 --- /dev/null +++ b/models/maisi_ct_generative/scripts/trainer.py @@ -0,0 +1,247 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
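# Illustration (editorial sketch, not part of the bundle): the candidate ranking used by
# `find_closest_masks` above. Each database mask is scored by how far its volume dimensions
# (downweighted by 1/100) and spacings are from the requested output size and spacing, and the
# best max(2 * num_img, 5) candidates are kept for the resampling check. Values are made up.
output_size = [256, 256, 256]
spacing = [1.5, 1.5, 1.5]
num_img = 2
candidates = [
    {"pseudo_label": "mask_a.nii.gz", "dim": [256, 256, 256], "spacing": [1.5, 1.5, 1.5]},
    {"pseudo_label": "mask_b.nii.gz", "dim": [512, 512, 384], "spacing": [1.0, 1.0, 2.0]},
]

scored = []
for c in candidates:
    diff = 0.0
    for axis in range(3):
        diff += abs((c["dim"][axis] - output_size[axis]) / 100)  # dimension mismatch
        diff += abs(c["spacing"][axis] - spacing[axis])          # spacing mismatch
    scored.append((c, diff))

top = sorted(scored, key=lambda x: x[1])[: max(2 * num_img, 5)]
print([c["pseudo_label"] for c, _ in top])  # mask_a ranks first, it already matches exactly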
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
+
+import torch
+import torch.nn.functional as F
+from generative.networks.schedulers import Scheduler
+from monai.config import IgniteInfo
+from monai.engines.trainer import Trainer
+from monai.engines.utils import IterationEvents, PrepareBatchExtraInput, default_metric_cmp_fn
+from monai.inferers import Inferer
+from monai.transforms import Transform
+from monai.utils import RankFilter, min_version, optional_import
+from monai.utils.enums import CommonKeys as Keys
+from torch.optim.optimizer import Optimizer
+from torch.utils.data import DataLoader
+
+from .utils import binarize_labels
+
+if TYPE_CHECKING:
+    from ignite.engine import Engine, EventEnum
+    from ignite.metrics import Metric
+else:
+    Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+    Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
+    EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
+
+__all__ = ["MAISIControlNetTrainer"]
+
+# Module-level variable for the `prepare_batch` default value
+DEFAULT_PREPARE_BATCH = PrepareBatchExtraInput(extra_keys=("dim", "spacing", "top_region_index", "bottom_region_index"))
+
+
+class MAISIControlNetTrainer(Trainer):
+    """
+    Supervised training method with image and label, inherits from ``Trainer`` and ``Workflow``.
+    Args:
+        device: an object representing the device on which to run.
+        max_epochs: the total epoch number for trainer to run.
+        train_data_loader: Ignite engine uses data_loader to run, must be Iterable or torch.DataLoader.
+        controlnet: controlnet to train in the trainer, should be regular PyTorch `torch.nn.Module`.
+        difusion_unet: the (frozen) diffusion unet used in the trainer, should be regular PyTorch `torch.nn.Module`.
+        optimizer: the optimizer associated to the controlnet, should be regular PyTorch optimizer from `torch.optim`
+            or its subclass.
+        epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
+        non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
+            with respect to the host. For other cases, this argument has no effect.
+        prepare_batch: function to parse expected data (usually `image`, `label` and the extra conditioning keys
+            `dim`, `spacing`, `top_region_index`, `bottom_region_index`) from `engine.state.batch` for every
+            iteration, for more details please refer to:
+            https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
+        iteration_update: the callable function for every iteration, expect to accept `engine`
+            and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
+            if not provided, use `self._iteration()` instead. for more details please refer to:
+            https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
+        inferer: inference method that executes the model forward pass on input data, like: SlidingWindow, etc.
+        postprocessing: execute additional transformation for the model output data.
+            Typically, several Tensor based transforms composed by `Compose`.
+        key_train_metric: compute metric when every iteration completed, and save average value to
+            engine.state.metrics when epoch completed. key_train_metric is the main metric to compare
+            and save the checkpoint into files.
+        additional_metrics: more Ignite metrics that also attach to Ignite Engine.
+ metric_cmp_fn: function to compare current key metric with previous best key metric value, + it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update + `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`. + train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like: + CheckpointHandler, StatsHandler, etc. + amp: whether to enable auto-mixed-precision training, default is False. + event_names: additional custom ignite events that will register to the engine. + new events can be a list of str or `ignite.engine.events.EventEnum`. + event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`. + for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html + #ignite.engine.engine.Engine.register_events. + decollate: whether to decollate the batch-first data to a list of data after model computation, + recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`. + default to `True`. + optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None. + more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html. + to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for + `device`, `non_blocking`. + amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast. + """ + + def __init__( + self, + device: torch.device, + max_epochs: int, + train_data_loader: Iterable | DataLoader, + controlnet: torch.nn.Module, + difusion_unet: torch.nn.Module, + optimizer: Optimizer, + loss_function: Callable, + inferer: Inferer, + noise_scheduler: Scheduler, + epoch_length: int | None = None, + non_blocking: bool = False, + prepare_batch: Callable = DEFAULT_PREPARE_BATCH, + iteration_update: Callable[[Engine, Any], Any] | None = None, + postprocessing: Transform | None = None, + key_train_metric: dict[str, Metric] | None = None, + additional_metrics: dict[str, Metric] | None = None, + metric_cmp_fn: Callable = default_metric_cmp_fn, + train_handlers: Sequence | None = None, + amp: bool = False, + event_names: list[str | EventEnum] | None = None, + event_to_attr: dict | None = None, + decollate: bool = True, + optim_set_to_none: bool = False, + to_kwargs: dict | None = None, + amp_kwargs: dict | None = None, + hyper_kwargs: dict | None = None, + ) -> None: + super().__init__( + device=device, + max_epochs=max_epochs, + data_loader=train_data_loader, + epoch_length=epoch_length, + non_blocking=non_blocking, + prepare_batch=prepare_batch, + iteration_update=iteration_update, + postprocessing=postprocessing, + key_metric=key_train_metric, + additional_metrics=additional_metrics, + metric_cmp_fn=metric_cmp_fn, + handlers=train_handlers, + amp=amp, + event_names=event_names, + event_to_attr=event_to_attr, + decollate=decollate, + to_kwargs=to_kwargs, + amp_kwargs=amp_kwargs, + ) + + self.controlnet = controlnet + self.difusion_unet = difusion_unet + self.optimizer = optimizer + self.loss_function = loss_function + self.inferer = inferer + self.optim_set_to_none = optim_set_to_none + self.hyper_kwargs = hyper_kwargs + self.noise_scheduler = noise_scheduler + self.logger.addFilter(RankFilter()) + for p in self.difusion_unet.parameters(): + p.requires_grad = False + print("freeze the parameters of the 
diffusion unet model.") + + def _iteration(self, engine, batchdata: dict[str, torch.Tensor]): + """ + Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine. + Return below items in a dictionary: + - IMAGE: image Tensor data for model input, already moved to device. + Args: + engine: `Vista3DTrainer` to execute operation for an iteration. + batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data. + Raises: + ValueError: When ``batchdata`` is None. + """ + + if batchdata is None: + raise ValueError("Must provide batch data for current iteration.") + + inputs, labels, (dim, spacing, top_region_index, bottom_region_index), _ = engine.prepare_batch( + batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs + ) + engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels} + weighted_loss_label = engine.hyper_kwargs["weighted_loss_label"] + weighted_loss = engine.hyper_kwargs["weighted_loss"] + scale_factor = engine.hyper_kwargs["scale_factor"] + # scale image embedding by the provided scale_factor + inputs = inputs * scale_factor + + def _compute_pred_loss(): + # generate random noise + noise_shape = list(inputs.shape) + noise = torch.randn(noise_shape, dtype=inputs.dtype).to(inputs.device) + + # use binary encoding to encode segmentation mask + controlnet_cond = binarize_labels(labels.as_tensor().to(torch.uint8)).float() + + # create timesteps + timesteps = torch.randint( + 0, engine.noise_scheduler.num_train_timesteps, (inputs.shape[0],), device=inputs.device + ).long() + + # Create noisy latent + noisy_latent = engine.noise_scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps) + + # Get controlnet output + down_block_res_samples, mid_block_res_sample = engine.controlnet( + x=noisy_latent, timesteps=timesteps, controlnet_cond=controlnet_cond + ) + noise_pred = engine.difusion_unet( + x=noisy_latent, + timesteps=timesteps, + top_region_index_tensor=top_region_index, + bottom_region_index_tensor=bottom_region_index, + spacing_tensor=spacing, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ) + + engine.state.output[Keys.PRED] = noise_pred + engine.fire_event(IterationEvents.FORWARD_COMPLETED) + + if weighted_loss > 1.0: + weights = torch.ones_like(inputs).to(inputs.device) + roi = torch.zeros([noise_shape[0]] + [1] + noise_shape[2:]).to(inputs.device) + interpolate_label = F.interpolate(labels, size=inputs.shape[2:], mode="nearest") + # assign larger weights for ROI (tumor) + for label in weighted_loss_label: + roi[interpolate_label == label] = 1 + weights[roi.repeat(1, inputs.shape[1], 1, 1, 1) == 1] = weighted_loss + loss = (F.l1_loss(noise_pred.float(), noise.float(), reduction="none") * weights).mean() + else: + loss = F.l1_loss(noise_pred.float(), noise.float()) + + engine.state.output[Keys.LOSS] = loss + engine.fire_event(IterationEvents.LOSS_COMPLETED) + + engine.controlnet.train() + engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none) + + if engine.amp and engine.scaler is not None: + with torch.cuda.amp.autocast(**engine.amp_kwargs): + _compute_pred_loss() + engine.scaler.scale(engine.state.output[Keys.LOSS]).backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + engine.scaler.step(engine.optimizer) + engine.scaler.update() + else: + _compute_pred_loss() + engine.state.output[Keys.LOSS].backward() + engine.fire_event(IterationEvents.BACKWARD_COMPLETED) + engine.optimizer.step() 
+ engine.fire_event(IterationEvents.MODEL_COMPLETED) + return engine.state.output diff --git a/models/maisi_ct_generative/scripts/utils.py b/models/maisi_ct_generative/scripts/utils.py new file mode 100644 index 00000000..5d939aff --- /dev/null +++ b/models/maisi_ct_generative/scripts/utils.py @@ -0,0 +1,429 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +import copy +import json +import os +import zipfile +from typing import Sequence + +import numpy as np +import skimage +import torch +import torch.nn.functional as F +from monai.config import DtypeLike, NdarrayOrTensor +from monai.utils import ( + TransformBackends, + convert_data_type, + convert_to_dst_type, + ensure_tuple_rep, + get_equivalent_dtype, +) +from scipy import stats + + +def unzip_dataset(dataset_dir): + if not os.path.exists(dataset_dir): + zip_file_path = dataset_dir + ".zip" + if not os.path.isfile(zip_file_path): + raise ValueError(f"Please downloaded {zip_file_path}.") + with zipfile.ZipFile(zip_file_path, "r") as zip_ref: + zip_ref.extractall(path=os.path.dirname(dataset_dir)) + print(f"Unzipped {zip_file_path} to {dataset_dir}.") + + +def add_data_dir2path(list_files, data_dir, fold=None): + new_list_files = copy.deepcopy(list_files) + if fold is not None: + new_list_files_train = [] + new_list_files_val = [] + for d in new_list_files: + d["image"] = os.path.join(data_dir, d["image"]) + + if "label" in d: + d["label"] = os.path.join(data_dir, d["label"]) + + if fold is not None: + if d["fold"] == fold: + new_list_files_val.append(copy.deepcopy(d)) + else: + new_list_files_train.append(copy.deepcopy(d)) + + if fold is not None: + return new_list_files_train, new_list_files_val + else: + return new_list_files, [] + + +def maisi_datafold_read(json_list, data_base_dir, fold=None): + with open(json_list, "r") as f: + filenames_train = json.load(f)["training"] + # training data + train_files, val_files = add_data_dir2path(filenames_train, data_base_dir, fold=fold) + print(f"dataset: {data_base_dir}, num_training_files: {len(train_files)}, num_val_files: {len(val_files)}") + return train_files, val_files + + +def get_index_arr(img): + return np.moveaxis( + np.moveaxis( + np.stack(np.meshgrid(np.arange(img.shape[0]), np.arange(img.shape[1]), np.arange(img.shape[2]))), 0, 3 + ), + 0, + 1, + ) + + +def supress_non_largest_components(img, target_label, default_val=0): + """As a last step, supress all non largest components""" + index_arr = get_index_arr(img) + img_mod = copy.deepcopy(img) + new_background = np.zeros(img.shape, dtype=np.bool_) + for label in target_label: + label_cc = skimage.measure.label(img == label, connectivity=3) + uv, uc = np.unique(label_cc, return_counts=True) + dominant_vals = uv[np.argsort(uc)[::-1][:2]] + if len(dominant_vals) >= 2: # Case: no predictions + new_background = np.logical_or( + new_background, + np.logical_not(np.logical_or(label_cc == dominant_vals[0], label_cc == dominant_vals[1])), + ) + + for voxel in index_arr[new_background]: + img_mod[tuple(voxel)] = default_val + diff = np.sum((img - 
img_mod) > 0) + + return img_mod, diff + + +def erode3d(input_tensor, erosion=3, value=0.0): + # Define the structuring element + erosion = ensure_tuple_rep(erosion, 3) + structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device) + + # Pad the input tensor to handle border pixels + input_padded = F.pad( + input_tensor.float().unsqueeze(0).unsqueeze(0), + (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2), + mode="constant", + value=value, + ) + + # Apply erosion operation + output = F.conv3d(input_padded, structuring_element, padding=0) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output == torch.sum(structuring_element), 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + + +def dilate3d(input_tensor, erosion=3, value=0.0): + # Define the structuring element + erosion = ensure_tuple_rep(erosion, 3) + structuring_element = torch.ones(1, 1, erosion[0], erosion[1], erosion[2]).to(input_tensor.device) + + # Pad the input tensor to handle border pixels + input_padded = F.pad( + input_tensor.float().unsqueeze(0).unsqueeze(0), + (erosion[0] // 2, erosion[0] // 2, erosion[1] // 2, erosion[1] // 2, erosion[2] // 2, erosion[2] // 2), + mode="constant", + value=value, + ) + + # Apply erosion operation + output = F.conv3d(input_padded, structuring_element, padding=0) + + # Set output values based on the minimum value within the structuring element + output = torch.where(output > 0, 1.0, 0.0) + + return output.squeeze(0).squeeze(0) + + +def organ_fill_by_closing(data, target_label, device): + mask = (data == target_label).astype(np.uint8) + mask = dilate3d(torch.from_numpy(mask).to(device), erosion=3, value=0.0) + mask = erode3d(mask, erosion=3, value=0.0) + mask = dilate3d(mask, erosion=3, value=0.0) + mask = erode3d(mask, erosion=3, value=0.0).cpu().numpy() + return mask.astype(np.bool_) + + +def organ_fill_by_removed_mask(data, target_label, remove_mask, device): + mask = (data == target_label).astype(np.uint8) + mask = dilate3d(torch.from_numpy(mask).to(device), erosion=3, value=0.0) + mask = dilate3d(mask, erosion=3, value=0.0) + roi_oragn_mask = dilate3d(mask, erosion=3, value=0.0).cpu().numpy() + return (roi_oragn_mask * remove_mask).astype(np.bool_) + + +def get_body_region_index_from_mask(input_mask): + region_indices = {} + # head and neck + region_indices["region_0"] = [22, 120] + # thorax + region_indices["region_1"] = [28, 29, 30, 31, 32] + # abdomen + region_indices["region_2"] = [1, 2, 3, 4, 5, 14] + # pelvis and lower + region_indices["region_3"] = [93, 94] + + nda = input_mask.cpu().numpy().squeeze() + unique_elements = np.lib.arraysetops.unique(nda) + unique_elements = list(unique_elements) + print(f"nda: {nda.shape} {unique_elements}.") + overlap_array = np.zeros(len(region_indices), dtype=np.uint8) + for _j in range(len(region_indices)): + overlap = any(element in region_indices[f"region_{_j}"] for element in unique_elements) + overlap_array[_j] = np.uint8(overlap) + overlap_array_indices = np.nonzero(overlap_array)[0] + top_region_index = np.eye(len(region_indices), dtype=np.uint8)[np.amin(overlap_array_indices), ...] + top_region_index = list(top_region_index) + top_region_index = [int(_k) for _k in top_region_index] + bottom_region_index = np.eye(len(region_indices), dtype=np.uint8)[np.amax(overlap_array_indices), ...] 
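# Illustration (editorial sketch, not part of the bundle): what the one-hot region indices above
# encode. With four body regions (head/neck, thorax, abdomen, pelvis/lower), a mask whose labels
# overlap regions 1 and 2 yields top = eye[1] and bottom = eye[2]. Label values are examples.
import numpy as np

region_indices = {
    "region_0": [22, 120],              # head and neck
    "region_1": [28, 29, 30, 31, 32],   # thorax
    "region_2": [1, 2, 3, 4, 5, 14],    # abdomen
    "region_3": [93, 94],               # pelvis and lower
}
unique_labels = {28, 1, 3}              # e.g. a lung lobe plus abdominal organs
overlap = [int(any(lbl in v for lbl in unique_labels)) for v in region_indices.values()]
present = np.nonzero(overlap)[0]        # -> [1, 2]
top_region_index = np.eye(4, dtype=np.uint8)[present.min()]     # [0, 1, 0, 0]
bottom_region_index = np.eye(4, dtype=np.uint8)[present.max()]  # [0, 0, 1, 0]
print(top_region_index.tolist(), bottom_region_index.tolist())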
+ bottom_region_index = list(bottom_region_index) + bottom_region_index = [int(_k) for _k in bottom_region_index] + print(f"{top_region_index} {bottom_region_index}") + return top_region_index, bottom_region_index + + +def general_mask_generation_post_process(volume_t, target_tumor_label=None, device="cuda:0"): + # assume volume_t is np array with shape (H,W,D) + hepatic_vessel = volume_t == 25 + airway = volume_t == 132 + + # ------------ refine body mask pred + body_region_mask = erode3d(torch.from_numpy((volume_t > 0)).to(device), erosion=3, value=0.0).cpu().numpy() + body_region_mask, _ = supress_non_largest_components(body_region_mask, [1]) + body_region_mask = ( + dilate3d(torch.from_numpy(body_region_mask).to(device), erosion=3, value=0.0).cpu().numpy().astype(np.uint8) + ) + volume_t = volume_t * body_region_mask + + # ------------ refine tumor pred + tumor_organ_dict = {23: 28, 24: 4, 26: 1, 27: 62, 128: 200} + for t in [23, 24, 26, 27, 128]: + if t != target_tumor_label: + volume_t[volume_t == t] = tumor_organ_dict[t] + else: + volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t + volume_t[organ_fill_by_closing(volume_t, target_label=t, device=device)] = t + # we only keep the largest connected componet for tumors except hepatic tumor and bone lesion + if target_tumor_label != 26 and target_tumor_label != 128: + volume_t, _ = supress_non_largest_components(volume_t, [target_tumor_label], default_val=200) + target_tumor = volume_t == target_tumor_label + + # ------------ remove undesired organ pred + # general post-process non-largest components suppression + # process 4 ROI organs + spleen + 2 kidney + 5 lung lobes + duodenum + inferior vena cava + oran_list = [1, 4, 10, 12, 3, 28, 29, 30, 31, 32, 5, 14, 13, 6, 7, 8, 9, 10] + if target_tumor_label != 128: + oran_list += list(range(33, 60)) # + list(range(63,87)) + data, _ = supress_non_largest_components(volume_t, oran_list, default_val=200) # 200 is body region + organ_remove_mask = (volume_t - data).astype(np.bool_) + # process intestinal system (stomach 12, duodenum 13, small bowel 19, colon 62) + intestinal_mask_ = ( + (data == 12).astype(np.uint8) + + (data == 13).astype(np.uint8) + + (data == 19).astype(np.uint8) + + (data == 62).astype(np.uint8) + ) + intestinal_mask, _ = supress_non_largest_components(intestinal_mask_, [1], default_val=0) + # process small bowel 19 + small_bowel_remove_mask = (data == 19).astype(np.uint8) - (data == 19).astype(np.uint8) * intestinal_mask + # process colon 62 + colon_remove_mask = (data == 62).astype(np.uint8) - (data == 62).astype(np.uint8) * intestinal_mask + intestinal_remove_mask = (small_bowel_remove_mask + colon_remove_mask).astype(np.bool_) + data[intestinal_remove_mask] = 200 + + # ------------ full correponding organ in removed regions + for organ_label in oran_list: + data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label + + if target_tumor_label == 23 and np.sum(target_tumor) > 0: + # speical process for cases with lung tumor + dia_lung_tumor_mask = dilate3d(torch.from_numpy((data == 23)).to(device), erosion=3, value=0.0).cpu().numpy() + tmp = ( + (data * (dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8))).astype(np.float32).flatten() + ) + tmp[tmp == 0] = float("nan") + mode = int(stats.mode(tmp.flatten(), nan_policy="omit")[0]) + if mode in [28, 29, 30, 31, 32]: + dia_lung_tumor_mask = ( + dilate3d(torch.from_numpy(dia_lung_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() + ) + 
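# Illustration (editorial sketch, not part of the bundle): the surrounding-label vote used in the
# lung-tumor branch above. The ring of voxels around the dilated tumor is reduced to its most
# frequent non-zero label with scipy.stats.mode; if that mode is a lung lobe (labels 28-32), the
# lobe is grown back into the region freed by the tumor. The numbers below are toy values.
import numpy as np
from scipy import stats

ring_labels = np.array([28.0, 28.0, 30.0, 0.0, 28.0, 0.0])  # labels bordering the tumor (toy)
ring_labels[ring_labels == 0] = np.nan                       # ignore background in the vote
mode = int(stats.mode(ring_labels, nan_policy="omit")[0])    # -> 28, i.e. backfill with that lobe
print(mode)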
lung_remove_mask = dia_lung_tumor_mask.astype(np.uint8) - (data == 23).astype(np.uint8).astype(np.uint8) + data[organ_fill_by_removed_mask(data, target_label=mode, remove_mask=lung_remove_mask, device=device)] = ( + mode + ) + dia_lung_tumor_mask = ( + dilate3d(torch.from_numpy(dia_lung_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() + ) + data[ + organ_fill_by_removed_mask( + data, target_label=23, remove_mask=dia_lung_tumor_mask * organ_remove_mask, device=device + ) + ] = 23 + for organ_label in [28, 29, 30, 31, 32]: + data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label + data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label + data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label + + if target_tumor_label == 26 and np.sum(target_tumor) > 0: + # speical process for cases with hepatic tumor + # process liver 1 + data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1 + data[organ_fill_by_removed_mask(data, target_label=1, remove_mask=intestinal_remove_mask, device=device)] = 1 + # process spleen 2 + data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3 + data[organ_fill_by_removed_mask(data, target_label=3, remove_mask=organ_remove_mask, device=device)] = 3 + dia_tumor_mask = ( + dilate3d(torch.from_numpy((data == target_tumor_label)).to(device), erosion=3, value=0.0).cpu().numpy() + ) + dia_tumor_mask = dilate3d(torch.from_numpy(dia_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() + data[ + organ_fill_by_removed_mask( + data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device + ) + ] = target_tumor_label + # refine hepatic tumor + hepatic_tumor_vessel_liver_mask_ = ( + (data == 26).astype(np.uint8) + (data == 25).astype(np.uint8) + (data == 1).astype(np.uint8) + ) + hepatic_tumor_vessel_liver_mask_ = (hepatic_tumor_vessel_liver_mask_ > 1).astype(np.uint8) + hepatic_tumor_vessel_liver_mask, _ = supress_non_largest_components( + hepatic_tumor_vessel_liver_mask_, [1], default_val=0 + ) + removed_region = (hepatic_tumor_vessel_liver_mask_ - hepatic_tumor_vessel_liver_mask).astype(np.bool_) + data[removed_region] = 200 + target_tumor = (target_tumor * hepatic_tumor_vessel_liver_mask).astype(np.bool_) + # refine liver + data[organ_fill_by_closing(data, target_label=1, device=device)] = 1 + data[organ_fill_by_closing(data, target_label=1, device=device)] = 1 + data[organ_fill_by_closing(data, target_label=1, device=device)] = 1 + + if target_tumor_label == 27 and np.sum(target_tumor) > 0: + # speical process for cases with colon tumor + dia_tumor_mask = ( + dilate3d(torch.from_numpy((data == target_tumor_label)).to(device), erosion=3, value=0.0).cpu().numpy() + ) + dia_tumor_mask = dilate3d(torch.from_numpy(dia_tumor_mask).to(device), erosion=3, value=0.0).cpu().numpy() + data[ + organ_fill_by_removed_mask( + data, target_label=target_tumor_label, remove_mask=dia_tumor_mask * organ_remove_mask, device=device + ) + ] = target_tumor_label + + if target_tumor_label == 129 and np.sum(target_tumor) > 0: + # speical process for cases with kidney tumor + for organ_label in [5, 14]: + data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label + data[organ_fill_by_closing(data, target_label=organ_label, device=device)] = organ_label + data[organ_fill_by_closing(data, target_label=organ_label, 
device=device)] = organ_label + # TODO: current model does not support hepatic vessel by size control. + # we treat it as liver for better visiaulization + print( + "Current model does not support hepatic vessel by size control, " + "so we treat generated hepatic vessel as part of liver for better visiaulization." + ) + data[hepatic_vessel] = 1 + data[airway] = 132 + if target_tumor_label is not None: + data[target_tumor] = target_tumor_label + + return data + + +class MapLabelValue: + """ + Utility to map label values to another set of values. + For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2], + [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc. + The label data must be numpy array or array-like data and the output data will be numpy array. + + """ + + backend = [TransformBackends.NUMPY, TransformBackends.TORCH] + + def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None: + """ + Args: + orig_labels: original labels that map to others. + target_labels: expected label values, 1: 1 map to the `orig_labels`. + dtype: convert the output data to dtype, default to float32. + if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend. + + """ + if len(orig_labels) != len(target_labels): + raise ValueError("orig_labels and target_labels must have the same length.") + + self.orig_labels = orig_labels + self.target_labels = target_labels + self.pair = tuple((o, t) for o, t in zip(self.orig_labels, self.target_labels) if o != t) + type_dtype = type(dtype) + if getattr(type_dtype, "__module__", "") == "torch": + self.use_numpy = False + self.dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor) + else: + self.use_numpy = True + self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray) + + def __call__(self, img: NdarrayOrTensor): + if self.use_numpy: + img_np, *_ = convert_data_type(img, np.ndarray) + _out_shape = img_np.shape + img_flat = img_np.flatten() + try: + out_flat = img_flat.astype(self.dtype) + except ValueError: + # can't copy unchanged labels as the expected dtype is not supported, must map all the label values + out_flat = np.zeros(shape=img_flat.shape, dtype=self.dtype) + for o, t in self.pair: + out_flat[img_flat == o] = t + out_t = out_flat.reshape(_out_shape) + else: + img_t, *_ = convert_data_type(img, torch.Tensor) + out_t = img_t.detach().clone().to(self.dtype) # type: ignore + for o, t in self.pair: + out_t[img_t == o] = t + out, *_ = convert_to_dst_type(src=out_t, dst=img, dtype=self.dtype) + return out + + +def load_autoencoder_ckpt(load_autoencoder_path): + checkpoint_autoencoder = torch.load(load_autoencoder_path) + new_state_dict = {} + for k, v in checkpoint_autoencoder.items(): + if "decoder" in k and "conv" in k: + new_key = ( + k.replace("conv.weight", "conv.conv.weight") + if "conv.weight" in k + else k.replace("conv.bias", "conv.conv.bias") + ) + new_state_dict[new_key] = v + elif "encoder" in k and "conv" in k: + new_key = ( + k.replace("conv.weight", "conv.conv.weight") + if "conv.weight" in k + else k.replace("conv.bias", "conv.conv.bias") + ) + new_state_dict[new_key] = v + else: + new_state_dict[k] = v + checkpoint_autoencoder = new_state_dict + return checkpoint_autoencoder + + +def binarize_labels(x, bits=8): + """ + x: the input tensor with shape (B, 1, H, W, D) + bits: the num of channel to represent the data. 
+ """ + mask = 2 ** torch.arange(bits).to(x.device, x.dtype) + return x.unsqueeze(-1).bitwise_and(mask).ne(0).byte().squeeze(1).permute(0, 4, 1, 2, 3)