diff --git a/UltrasoundSegmentation/train.py b/UltrasoundSegmentation/train.py index 9027805..879c992 100644 --- a/UltrasoundSegmentation/train.py +++ b/UltrasoundSegmentation/train.py @@ -14,6 +14,7 @@ import torch import os import sys +import json import yaml import wandb import numpy as np @@ -370,7 +371,9 @@ def main(args): model.eval() example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"]) traced_script_module = torch.jit.trace(model, example_input) - traced_script_module.save(ts_model_path) + d = {"shape": example_input.shape} + extra_files = {"config.json": json.dumps(d)} + traced_script_module.save(ts_model_path, _extra_files=extra_files) logging.info(f"Saved traced model to {ts_model_path}.") run.finish()