Skip to content

Commit

Permalink
Added unit test to test the CXR sampling, which works now. Modified v…
Browse files Browse the repository at this point in the history
…erify_bundle to pass the check for model.pt, since the requirement for two models (autoencoder and diffusion_model) makes sense for them to keep their specific names.

Modification of inference.json to add dummy attributes to pass the ConfigWorkflow check.
Modification of large_files.yml so that models are .pt and not .pth.
  • Loading branch information
Virginia committed Jan 13, 2025
1 parent 5f95a10 commit 529e78b
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 2 deletions.
45 changes: 45 additions & 0 deletions ci/unit_tests/test_cxr_image_synthesis_latent_diffusion_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import unittest
from parameterized import parameterized
import numpy as np
from monai.bundle import ConfigWorkflow
from utils import check_workflow

TEST_CASE_1 = [ # inference
{
"bundle_root": "models/cxr_image_synthesis_latent_diffusion_model",
"prompt": "Big right-sided pleural effusion. Normal left lung.",
"guidance_scale": 7.0,
}
]

class TestCXRLatentDiffusionInference(unittest.TestCase):
@parameterized.expand([TEST_CASE_1])
def test_inference(self, params):
bundle_root = params["bundle_root"]
inference_file = os.path.join(bundle_root, "configs/inference.json")
trainer = ConfigWorkflow(
workflow_type="inference",
config_file=inference_file,
logging_file=os.path.join(bundle_root, "configs/logging.conf"),
meta_file=os.path.join(bundle_root, "configs/metadata.json"),
**params,
)
check_workflow(trainer, check_properties=True)

if __name__ == "__main__":
loader = unittest.TestLoader()
unittest.main(testLoader=loader)
2 changes: 2 additions & 0 deletions ci/verify_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def _get_weights_names(bundle: str):
if bundle == "pediatric_abdominal_ct_segmentation":
# skip test for this bundle's ts file
return "dynunet_FT.pt", None
if bundle == "cxr_image_synthesis_latent_diffusion_model":
return "autoencoder.pt", None
return "model.pt", "model.ts"


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
"$from transformers import CLIPTokenizer"
],
"bundle_root": ".",
"dataset_dir": "",
"dataset": "",
"evaluator": "",
"inferer": "",
"load_old": 1,
"model_dir": "$@bundle_root + '/models'",
"output_dir": "$@bundle_root + '/output'",
Expand Down Expand Up @@ -44,6 +48,7 @@
"with_encoder_nonlocal_attn": false,
"with_decoder_nonlocal_attn": false
},
"network_def": "@autoencoder_def",
"load_autoencoder_path": "$@model_dir + '/autoencoder.pth'",
"load_autoencoder_func": "$@autoencoder_def.load_old_state_dict if bool(@load_old) else @autoencoder_def.load_state_dict",
"load_autoencoder": "$@load_autoencoder_func(torch.load(@load_autoencoder_path))",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
large_files:
- path: "models/autoencoder.pth"
- path: "models/autoencoder.pt"
url: "https://drive.google.com/uc?export=download&id=1paDN1m-Q_Oy8d_BanPkRTi3RlNB_Sv_h"
hash_val: "7f579cb789597db7bb5de1488f54bc6c"
hash_type: "md5"
- path: "models/diffusion_model.pth"
- path: "models/diffusion_model.pt"
url: "https://drive.google.com/uc?export=download&id=1CjcmiPu5_QWr-f7wDJsXrCCcVeczneGT"
hash_val: "c3fd4c8e38cd1d7250a8903cca935823"
hash_type: "md5"

0 comments on commit 529e78b

Please sign in to comment.