diff --git a/zetta_utils/convnet/utils.py b/zetta_utils/convnet/utils.py
index 3433e00f7..56f49884d 100644
--- a/zetta_utils/convnet/utils.py
+++ b/zetta_utils/convnet/utils.py
@@ -123,5 +123,5 @@ def load_and_run_model(path, data_in, device=None, use_cache=True): # pragma: n
     with torch.inference_mode():  # uses less memory when used with JITs
         with torch.autocast(device_type=autocast_device):
             output = model(tensor_ops.convert.to_torch(data_in, device=device))
-    output = tensor_ops.convert.astype(output, reference=data_in)
+    output = tensor_ops.convert.astype(output, reference=data_in, cast=True)
     return output