Skip to content

Commit

Permalink
Add option to save model as TorchScript for inference in Slicer
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed May 23, 2023
1 parent dcd588f commit 953323c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
14 changes: 13 additions & 1 deletion UltrasoundSegmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ def parse_args():
parser.add_argument("--val-data-folder", type=str)
parser.add_argument("--output-dir", type=str)
parser.add_argument("--config-file", type=str, default="train_config.yaml")
parser.add_argument("--save-torchscript", action="store_true")
parser.add_argument("--save-ckpt-freq", type=int, default=0)
parser.add_argument("--wandb-project-name", type=str, default="aigt_ultrasound_segmentation")
parser.add_argument("--wandb-exp-name", type=str)
parser.add_argument("--log-level", type=str, default="INFO")
parser.add_argument("--save-log", action="store_true")
parser.add_argument("--save-ckpt-freq", type=int, default=0)
try:
return parser.parse_args()
except SystemExit as err:
Expand Down Expand Up @@ -361,6 +362,17 @@ def main(args):
model_path = os.path.join(run_dir, "model.pt")
torch.save(model.state_dict(), model_path)
logging.info(f"Saved model to {model_path}.")

# Save model as TorchScript
if args.save_torchscript:
ts_model_path = os.path.join(run_dir, "model_traced.pt")
model = model.to("cpu")
model.eval()
example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"])
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save(ts_model_path)
logging.info(f"Saved traced model to {ts_model_path}.")

run.finish()


Expand Down
3 changes: 2 additions & 1 deletion UltrasoundSegmentation/train_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

model_name: "monai_unet"
loss_function: "monai_dice"
image_size: 128
in_channels: !!int 1
out_channels: !!int 2
num_epochs: !!int 10
num_epochs: !!int 40
batch_size: !!int 32
learning_rate: !!float 0.004
learning_rate_decay_factor: !!float 0.5
Expand Down

0 comments on commit 953323c

Please sign in to comment.