Skip to content

Commit

Permalink
Update main function of bundle application (#1585)
Browse files Browse the repository at this point in the history
Signed-off-by: Andres <[email protected]>
  • Loading branch information
diazandr3s authored Nov 3, 2023
1 parent 5e9732e commit 5ef153f
Showing 1 changed file with 55 additions and 28 deletions.
83 changes: 55 additions & 28 deletions sample-apps/monaibundle/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,12 +133,11 @@ def init_scoring_methods(self) -> Dict[str, ScoringMethod]:

def main():
import argparse
import shutil
from pathlib import Path

from monailabel.config import settings
from monailabel.utils.others.generic import device_list, file_ext

settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD = False
settings.MONAI_LABEL_DATASTORE_FILE_EXT = ["*.png", "*.jpg", "*.jpeg", ".nii", ".nii.gz"]
os.putenv("MASTER_ADDR", "127.0.0.1")
os.putenv("MASTER_PORT", "1234")

Expand All @@ -154,43 +153,71 @@ def main():

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--studies", default=studies)
parser.add_argument("-m", "--model", default="wholeBody_ct_segmentation")
parser.add_argument("-t", "--test", default="infer", choices=("train", "infer", "batch_infer"))
args = parser.parse_args()

app_dir = os.path.dirname(__file__)
studies = args.studies
conf = {
"models": args.model,
"preload": "false",
}

app = MyApp(app_dir, studies, conf)

# Infer
if args.test == "infer":
sample = app.next_sample(request={"strategy": "first"})
image_id = sample["id"]
image_path = sample["path"]

# Run on all devices
for device in device_list():
res = app.infer(request={"model": args.model, "image": image_id, "device": device})
label = res["file"]
label_json = res["params"]
test_dir = os.path.join(args.studies, "test_labels")
os.makedirs(test_dir, exist_ok=True)

label_file = os.path.join(test_dir, image_id + file_ext(image_path))
shutil.move(label, label_file)

print(label_json)
print(f"++++ Image File: {image_path}")
print(f"++++ Label File: {label_file}")
break
return

# Batch Infer
if args.test == "batch_infer":
app.batch_infer(
request={
"model": args.model,
"multi_gpu": False,
"save_label": True,
"label_tag": "original",
"max_workers": 1,
"max_batch_size": 0,
}
)
return

app = MyApp(app_dir, studies, {"preload": "false", "models": "spleen_deepedit_annotation"})
# train(app)
infer(app)


def infer(app):
import json
import shutil

res = app.infer(
request={
"model": "spleen_deepedit_annotation",
"image": "image",
}
)

print(json.dumps(res, indent=2))
shutil.move(res["label"], os.path.join(app.studies, "test"))
logger.info("All Done!")


def train(app):
# Train
app.train(
request={
"model": "spleen_deepedit_annotation",
"max_epochs": 2,
"model": args.model,
"max_epochs": 10,
"dataset": "Dataset", # PersistentDataset, CacheDataset
"train_batch_size": 1,
"val_batch_size": 1,
"multi_gpu": False,
"val_split": 0.1,
"val_interval": 1,
},
)


if __name__ == "__main__":
# export PYTHONPATH=~/Projects/MONAILabel:`pwd`
# python main.py
main()

0 comments on commit 5ef153f

Please sign in to comment.