diff --git a/ci/unit_tests/test_cxr_image_synthesis_latent_diffusion_model.py b/ci/unit_tests/test_cxr_image_synthesis_latent_diffusion_model.py new file mode 100644 index 00000000..44610eff --- /dev/null +++ b/ci/unit_tests/test_cxr_image_synthesis_latent_diffusion_model.py @@ -0,0 +1,45 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import tempfile +import unittest +from parameterized import parameterized +import numpy as np +from monai.bundle import ConfigWorkflow +from utils import check_workflow + +TEST_CASE_1 = [ # inference + { + "bundle_root": "models/cxr_image_synthesis_latent_diffusion_model", + "prompt": "Big right-sided pleural effusion. Normal left lung.", + "guidance_scale": 7.0, + } +] + +class TestCXRLatentDiffusionInference(unittest.TestCase): + @parameterized.expand([TEST_CASE_1]) + def test_inference(self, params): + bundle_root = params["bundle_root"] + inference_file = os.path.join(bundle_root, "configs/inference.json") + trainer = ConfigWorkflow( + workflow_type="inference", + config_file=inference_file, + logging_file=os.path.join(bundle_root, "configs/logging.conf"), + meta_file=os.path.join(bundle_root, "configs/metadata.json"), + **params, + ) + check_workflow(trainer, check_properties=True) + +if __name__ == "__main__": + loader = unittest.TestLoader() + unittest.main(testLoader=loader) diff --git a/ci/verify_bundle.py b/ci/verify_bundle.py index 845611fa..7beca3e5 100644 --- a/ci/verify_bundle.py +++ b/ci/verify_bundle.py @@ -54,6 +54,8 @@ def _get_weights_names(bundle: str): if bundle == "pediatric_abdominal_ct_segmentation": # skip test for this bundle's ts file return "dynunet_FT.pt", None + if bundle == "cxr_image_synthesis_latent_diffusion_model": + return "autoencoder.pt", None return "model.pt", "model.ts" diff --git a/models/cxr_image_synthesis_latent_diffusion_model/configs/inference.json b/models/cxr_image_synthesis_latent_diffusion_model/configs/inference.json index 6578e3ea..b36315fb 100644 --- a/models/cxr_image_synthesis_latent_diffusion_model/configs/inference.json +++ b/models/cxr_image_synthesis_latent_diffusion_model/configs/inference.json @@ -7,6 +7,10 @@ "$from transformers import CLIPTokenizer" ], "bundle_root": ".", + "dataset_dir": "", + "dataset": "", + "evaluator": "", + "inferer": "", "load_old": 1, "model_dir": "$@bundle_root + '/models'", "output_dir": "$@bundle_root + '/output'", @@ -44,6 +48,7 @@ "with_encoder_nonlocal_attn": false, "with_decoder_nonlocal_attn": false }, + "network_def": "@autoencoder_def", "load_autoencoder_path": "$@model_dir + '/autoencoder.pth'", "load_autoencoder_func": "$@autoencoder_def.load_old_state_dict if bool(@load_old) else @autoencoder_def.load_state_dict", "load_autoencoder": "$@load_autoencoder_func(torch.load(@load_autoencoder_path))", diff --git a/models/cxr_image_synthesis_latent_diffusion_model/large_files.yml b/models/cxr_image_synthesis_latent_diffusion_model/large_files.yml index 673824cb..7691a875 100644 --- a/models/cxr_image_synthesis_latent_diffusion_model/large_files.yml +++ b/models/cxr_image_synthesis_latent_diffusion_model/large_files.yml @@ -1,9 +1,9 @@ large_files: - - path: "models/autoencoder.pth" + - path: "models/autoencoder.pt" url: "https://drive.google.com/uc?export=download&id=1paDN1m-Q_Oy8d_BanPkRTi3RlNB_Sv_h" hash_val: "7f579cb789597db7bb5de1488f54bc6c" hash_type: "md5" - - path: "models/diffusion_model.pth" + - path: "models/diffusion_model.pt" url: "https://drive.google.com/uc?export=download&id=1CjcmiPu5_QWr-f7wDJsXrCCcVeczneGT" hash_val: "c3fd4c8e38cd1d7250a8903cca935823" hash_type: "md5"