Skip to content

Commit

Permalink
Save image size as metadata along with TorchScript model
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed May 27, 2023
1 parent 953323c commit 363e16f
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion UltrasoundSegmentation/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import torch
import os
import sys
import json
import yaml
import wandb
import numpy as np
Expand Down Expand Up @@ -370,7 +371,9 @@ def main(args):
model.eval()
example_input = torch.rand(1, config["in_channels"], config["image_size"], config["image_size"])
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save(ts_model_path)
d = {"shape": example_input.shape}
extra_files = {"config.json": json.dumps(d)}
traced_script_module.save(ts_model_path, _extra_files=extra_files)
logging.info(f"Saved traced model to {ts_model_path}.")

run.finish()
Expand Down

0 comments on commit 363e16f

Please sign in to comment.