Skip to content

Commit

Permalink
saving predictions code added
Browse files Browse the repository at this point in the history
  • Loading branch information
faizan1234567 committed Jan 26, 2025
1 parent 91a6c7f commit b1ffb1d
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 66 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ thesisv1/
download_synapse.py
.synapseCache/ # synapse dataset
RawData/
segres_net_runs/
segres_net_runs/
paper_figs/
4 changes: 3 additions & 1 deletion conf/dataset/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,6 @@ colab: "/gdrive/MyDrive/BraTS2023/"
sines_pc: /drive/faizanai.rrl/datasets
dataset_folder: # specify your dataset path here if any of the above does not relate to you
version: brats2023
type: brats # brats
type: brats # brats
patient_id: "E:/Brats23 Data/Dataset/BraTS23_mapped/dataset/brats2023"
save_dir: paper_figs/exp1/modalities
232 changes: 171 additions & 61 deletions show.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,32 @@
import os
from pathlib import Path
import cv2
import random

import torch
from brats import get_datasets
from utils.visualizer import visualize_abnormal_area, get_labelled_image, visualize_data_gif
from utils.general import visualize_data_sample
from monai.inferers import sliding_window_inference

from monai.networks.nets import SwinUNETR
from monai.transforms import (
AsDiscrete,
Activations,
)

from monai.networks.nets import SwinUNETR, SegResNet, VNet, BasicUNetPlusPlus, AttentionUnet, DynUNet, UNETR
from networks.models.ResUNetpp.model import ResUnetPlusPlus
from networks.models.UNet.model import UNet3D
from networks.models.UX_Net.network_backbone import UXNET
from networks.models.nnformer.nnFormer_tumor import nnFormer
try:
from thesis.models.SegUXNet.model import SegUXNet
except ModuleNotFoundError:
print('model not available, please train with other models')
# sys.exit(1)

from functools import partial

# Logger
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,75 +123,164 @@ def overlay_mask(path, slice=75, type = "T1ce", save_path="media/results",
plt.show()


def predict_mask(dataset_folder,
target_size = (155, 240, 240),
dataset_version= "brats2023",
mode = "train"):
def save_slice_75_with_modality_name(patient_path, output_dir):
"""
Extracts slice 75 from each 3D NIfTI file in the patient's data directory and saves it as a PNG image,
using only the modality name extracted from the file name.
Args:
patient_path (str): Path to the patient's data folder (contains MRI modalities and segmentation mask files).
output_dir (str): Directory to save the extracted slice images as PNG.
"""
os.makedirs(output_dir, exist_ok=True) # Ensure the output directory exists

slice_index = 75

# Iterate through all NIfTI files in the directory
for file_name in os.listdir(patient_path):
if file_name.endswith(".nii") or file_name.endswith(".nii.gz"): # Process only NIfTI files
file_path = os.path.join(patient_path, file_name)

# Extract the modality part of the name (e.g., "t2w" from "BraTS-GLI-00008-001-t2w.nii.gz")
modality_name = file_name.split('-')[-1].split('.')[0]

# Load the modality
img = nib.load(file_path)
data = img.get_fdata()
data = np.rot90(data, k=-1)
# Check if slice 75 exists in the data
if slice_index < data.shape[2]:
slice_data = data[:, :, slice_index]

# Normalize the slice for better visualization (optional)
slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data) + 1e-8) * 255.0
slice_data = slice_data.astype(np.uint8)
slice_data = slice_data[40:220, 50:190]

# Save as PNG with the modality name
output_path = os.path.join(output_dir, f"{modality_name}_slice_{slice_index}.png")
plt.imsave(output_path, slice_data, cmap='gray')
print(f"Saved slice 75 of {modality_name} to {output_path}")
else:
print(f"Slice {slice_index} does not exist in {file_name}")



def process_mask(dataset_folder,
mode = "train",
save_dir = None):
random.seed(77)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset = get_datasets(dataset_folder=dataset_folder,
mode=mode, target_size=target_size,
version=dataset_version)
loader = torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False, num_workers=12,
pin_memory=True)

# get a random sample for the train or test data
for data in loader:
patients_data = os.path.join(dataset_folder, mode)
ids = os.listdir(patients_data)
max_ids = len(ids)
rand_number = random.randint(0, max_ids)
rand_id = ids[rand_number]
rand_id_path = os.path.join(patients_data, rand_id)
save_slice_75_with_modality_name(rand_id_path, save_dir)



# @hydra.main(config_name='configs', config_path= 'conf', version_base=None)
# def show_result(cfg: DictConfig, args:argparse.Namespace):
# """
# Visualize labelled brain scan on a patient case, three options are available
# 1 - create brain scan slices and label them
# 2 - create a .gif format file to visualize part of brain (labelled)
# 3 - visualize a scan with it's label in a subplot format
# """

# # Load data
# dataset = get_datasets(cfg.dataset.dataset_folder, "test")
# data_loader = torch.utils.data.DataLoader(dataset,
# batch_size=1,
# shuffle=False, num_workers=8,
# pin_memory=True)

# # batch of data
# batch = next(iter(data_loader))
# image, label = batch["image"], batch['label']
# logger.info('visualizing an image with label')

# # Visualize
# if args.get_abnormal_area:
# visualize_abnormal_area(image, label)
# elif args.visualize_data_gif:
# labelled_img = get_labelled_image(image, label)
# visualize_data_gif(labelled_img)
# elif args.visualize_data_sample:
# visualize_data_sample(args.test_patient_path,
# cfg.paths.test_patient)
# else:
# logger.info('No option selected')
# sys.exit()

def save_test_label(args, patient_id, predict):
data_path = get_brats_folder(mode="test")
ref_img = sitk.ReadImage(os.path.join(data_path, f"{patient_id}/{patient_id}_t1.nii.gz"))
label_nii = sitk.GetImageFromArray(predict)
label_nii.CopyInformation(ref_img)
sitk.WriteImage(label_nii, os.path.join(args.label_folder, f"{patient_id}.nii.gz"))


def inference(model, input, batch_size, overlap):
def _compute(input):
return sliding_window_inference(inputs=input, roi_size=(128, 128, 128), sw_batch_size=batch_size, predictor=model, overlap=overlap)
return _compute(input)

def test(args, mode, data_loader, model):
for i, data in enumerate(data_loader):
patient_id = data["patient_id"][0]
inputs = data["image"]
targets = data["label"].to(device)
pad_list = data["pad_list"].to(device)
inputs = inputs.to(device)
pad_list = data["pad_list"]
nonzero_indexes = data["nonzero_indexes"]
inputs = inputs.cuda()
model.cuda()

with torch.no_grad():
# Perform inference without TTA (just a single pass)
predict = torch.sigmoid(inference(model, inputs, batch_size=2, overlap=0.6))

# Post-processing (crop out padding)
predict = predict[:, :, pad_list[-4]:predict.shape[2]-pad_list[-3], pad_list[-6]:predict.shape[3]-pad_list[-5], pad_list[-8]:predict.shape[4]-pad_list[-7]]
predict = (predict > 0.5).squeeze() # Binarize the output (optional thresholding)

@hydra.main(config_name='configs', config_path= 'conf', version_base=None)
def show_result(cfg: DictConfig, args:argparse.Namespace):
"""
Visualize labelled brain scan on a patient case, three options are available
1 - create brain scan slices and label them
2 - create a .gif format file to visualize part of brain (labelled)
3 - visualize a scan with it's label in a subplot format
"""

# Load data
dataset = get_datasets(cfg.dataset.dataset_folder, "test")
data_loader = torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False, num_workers=8,
pin_memory=True)

# batch of data
batch = next(iter(data_loader))
image, label = batch["image"], batch['label']
logger.info('visualizing an image with label')

# Visualize
if args.get_abnormal_area:
visualize_abnormal_area(image, label)
elif args.visualize_data_gif:
labelled_img = get_labelled_image(image, label)
visualize_data_gif(labelled_img)
elif args.visualize_data_sample:
visualize_data_sample(args.test_patient_path,
cfg.paths.test_patient)
else:
logger.info('No option selected')
sys.exit()
# Reconstruct the prediction
full_predict = np.zeros((155, 240, 240))
predict = reconstruct_label(predict)
full_predict[slice(*nonzero_indexes[0]), slice(*nonzero_indexes[1]), slice(*nonzero_indexes[2])] = predict

# Save the model prediction
save_test_label(args, mode, patient_id, full_predict)

def reconstruct_label(image):
if type(image) == torch.Tensor:
image = image.cpu().numpy()
c1, c2, c3 = image[0], image[1], image[2]
image = (c3 > 0).astype(np.uint8)
image[(c2 == False)*(c3 == True)] = 2
image[(c1 == True)*(c3 == True)] = 4
return image


@hydra.main(config_name='configs', config_path= 'conf', version_base=None)
def main(cfg: DictConfig):
process_mask(dataset_folder=cfg.dataset.patient_id,
mode='train',
save_dir =cfg.dataset.save_dir)
print('Done!!!')

if __name__ == "__main__":
main()
# Visualize
parser = argparse.ArgumentParser()
parser.add_argument("--type", choices=["show-abnormal-image", "get-gif", "show-case"], default="get-gif",
help="visulization options")
parser.add_argument("--scan_path", default= "media/Brats18_2013_21_1", type = str,
help = "path to patient MRI scan")
parser.add_argument("--modality", default= "T1ce", type = str,
help = "type of modality type for analysis and visualization")
# parser = argparse.ArgumentParser()
# parser.add_argument("--scan_path", default= "media/Brats18_2013_21_1", type = str,
# help = "path to patient MRI scan")
# parser.add_argument("--modality", default= "T1ce", type = str,
# help = "type of modality type for analysis and visualization")

args = parser.parse_args()
# args = parser.parse_args()
# show_result(args)
overlay_mask(path=args.scan_path, slice=75, type="T1ce",
save_path="media/qualitative_results", img_name="ground_truth")
print('Done!!!')
# overlay_mask(path=args.scan_path, slice=75, type="T1ce",
# save_path="media/qualitative_results", img_name="ground_truth")

10 changes: 7 additions & 3 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from monai.handlers.utils import from_engine
from monai.metrics import DiceMetric
from utils.general import load_pretrained_model
from utils.all_utils import save_seg_csv, cal_confuse, cal_dice
from utils.all_utils import save_seg_csv, cal_confuse, cal_dice, save_test_label
from brats import get_datasets
from utils.meter import AverageMeter

Expand Down Expand Up @@ -105,6 +105,7 @@ def test(args, data_loader, model):
inputs = data["image"]
targets = data["label"].cuda()
pad_list = data["pad_list"]
nonzero_indexes = data["nonzero_indexes"]
inputs = inputs.cuda()
model.cuda()
with torch.no_grad():
Expand Down Expand Up @@ -136,7 +137,12 @@ def test(args, data_loader, model):
et_hd=et_hd, tc_hd=tc_hd, wt_hd=wt_hd,
et_sens=et_sens, tc_sens=tc_sens, wt_sens=wt_sens,
et_spec=et_spec, tc_spec=tc_spec, wt_spec=wt_spec))
full_predict = np.zeros((155, 240, 240))
predict = reconstruct_label(predict)
full_predict[slice(*nonzero_indexes[0]), slice(*nonzero_indexes[1]), slice(*nonzero_indexes[2])] = predict
save_test_label(args, patient_id, full_predict)
save_seg_csv(metrics_dict, args)



@hydra.main(config_name='configs', config_path= 'conf', version_base=None)
Expand Down Expand Up @@ -279,8 +285,6 @@ def main(cfg: DictConfig):
batch_size=batch_size,
shuffle=False, num_workers=workers,
pin_memory=True)


print("start test")
test(cfg, test_loader, model)

Expand Down

0 comments on commit b1ffb1d

Please sign in to comment.