diff --git a/radionets/evaluation/train_inspection.py b/radionets/evaluation/train_inspection.py index a47d00fd..5066573d 100644 --- a/radionets/evaluation/train_inspection.py +++ b/radionets/evaluation/train_inspection.py @@ -111,7 +111,7 @@ def get_prediction(conf, mode="test"): images["pred"] = pred images["indices"] = indices - if images["pred"].shape[-1] == 128: + if images["pred"].shape[-2] < images["pred"].shape[-1]: images = apply_symmetry(images) return images diff --git a/radionets/evaluation/utils.py b/radionets/evaluation/utils.py index ff04d520..e61b915b 100644 --- a/radionets/evaluation/utils.py +++ b/radionets/evaluation/utils.py @@ -301,8 +301,6 @@ def get_images(test_ds, num_images, rand=False, indices=None): img_test = test_ds[indices][0] img_true = test_ds[indices][1] - img_test = img_test[:, :, :65, :] - img_true = img_true[:, :, :65, :] return img_test, img_true, indices else: mean = test_ds[indices][0]