diff --git a/models/spleen_ct_segmentation_real_time/scripts/inference.py b/models/spleen_ct_segmentation_real_time/scripts/inference.py index e00b3d7f..7ae61e1d 100644 --- a/models/spleen_ct_segmentation_real_time/scripts/inference.py +++ b/models/spleen_ct_segmentation_real_time/scripts/inference.py @@ -51,6 +51,8 @@ class InferenceWorkflow(PythonicWorkflow): workflow.dataflow.update(input_loader({"image": "/workspace/Data/Task09_Spleen/imagesTr/spleen_38.nii.gz"})) workflow.run() + # get output + output = workflow.dataflow[CommonKeys.PRED] """ def __init__(self, workflow_type: str = "inference", properties_path: str = "./properties.json"): @@ -94,8 +96,9 @@ def initialize(self): def run(self): data = self.dataset[0] inputs = data[CommonKeys.IMAGE].unsqueeze(0).to(self.device) - # define sliding window size and batch size for windows inference - data[CommonKeys.PRED] = self.inferer(inputs, self.net) + self.net.eval() + with torch.no_grad(): + data[CommonKeys.PRED] = self.inferer(inputs, self.net) self.dataflow.update({CommonKeys.PRED: self.postprocessing(data)[CommonKeys.PRED]}) def finalize(self):