From 363e16f7a61baa5e600971cac5fbf2cfc50bf233 Mon Sep 17 00:00:00 2001 From: Chris Yeung Date: Sat, 27 May 2023 18:48:06 -0400 Subject: [PATCH] Save image size as metadata along with TorchScript model --- UltrasoundSegmentation/train.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index 9027805..879c992 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -14,6 +14,7 @@ import torch import os import sys +import json import yaml import wandb import numpy as np @@ -370,7 +371,9 @@ def main(args): model.eval() example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"]) traced_script_module = torch.jit.trace(model, example_input) - traced_script_module.save(ts_model_path) + d = {"shape": example_input.shape} + extra_files = {"config.json": json.dumps(d)} + traced_script_module.save(ts_model_path, _extra_files=extra_files) logging.info(f"Saved traced model to {ts_model_path}.") run.finish()