Skip to content

Commit

Permalink
added file
Browse files Browse the repository at this point in the history
  • Loading branch information
vaibhav shukla authored and vaibhav shukla committed Mar 12, 2019
1 parent 37863cd commit f613cc9
Show file tree
Hide file tree
Showing 28 changed files with 3,058 additions and 0 deletions.
Binary file added archi.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added dae.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file added data_loader/__init__.py
Empty file.
337 changes: 337 additions & 0 deletions data_loader/data_loader_18.py

Large diffs are not rendered by default.

168 changes: 168 additions & 0 deletions data_loader/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import os
import numpy as np
import math
import random
import cv2 as cv
import nibabel as nib
import torch

# in: volume path
# out: volume data in array
def readVol(volpath):
return nib.load(volpath).get_data()

# in: volume array
# out: comprise to uint8, put 0 where number<0
def to_uint8(vol):
vol=vol.astype(np.float)
vol[vol<0]=0
return ((vol-vol.min())*255.0/vol.max()).astype(np.uint8)

# in: volume array
# out: comprise to uint8, put 0 where number<800
def IR_to_uint8(vol):
vol=vol.astype(np.float)
vol[vol<0]=0
return ((vol-800)*255.0/vol.max()).astype(np.uint8)

# in: volume array
# out: hist equalized volume arrray
def histeq(vol):
for slice_index in range(vol.shape[2]):
vol[:,:,slice_index]=cv.equalizeHist(vol[:,:,slice_index])
return vol

# in: volume array
# out: preprocessed array
def preprocessed(vol):
for slice_index in range(vol.shape[2]):
cur_slice=vol[:,:,slice_index]
sob_x=cv.Sobel(cur_slice,cv.CV_16S,1,0)
sob_y=cv.Sobel(cur_slice,cv.CV_16S,0,1)
absX=cv.convertScaleAbs(sob_x)
absY=cv.convertScaleAbs(sob_y)
sob=cv.addWeighted(absX,0.5,absY,0.5,0)
dst=cur_slice+0.5*sob
vol[:,:,slice_index]=dst
return vol

# in: index of slice, stack number, slice number
# out: which slice should be stacked
def get_stackindex(slice_index, stack_num, slice_num):
assert stack_num%2==1, 'stack numbers must be odd!'
query_list=[0]*stack_num
for stack_index in range(stack_num):
query_list[stack_index]=(slice_index+(stack_index-int(stack_num/2)))%slice_num
return query_list

# in: volume array, stack number
# out: stacked img in list
def get_stacked(vol,stack_num):
stack_list=[]
stacked_slice=np.zeros((vol.shape[0],vol.shape[1],stack_num),np.uint8)
for slice_index in range(vol.shape[2]):
query_list=get_stackindex(slice_index,stack_num,vol.shape[2])
for index_query_list,query_list_content in enumerate(query_list):
stacked_slice[:,:,index_query_list]=vol[:,:,query_list_content].transpose()
stack_list.append(stacked_slice.copy())
return stack_list

# in: stacked img, rotate angle
# out: rotated imgs
def rotate(stack_list,angle,interp):
for stack_list_index,stacked in enumerate(stack_list):
raws,cols=stacked.shape[0:2]
M=cv.getRotationMatrix2D(((cols-1)/2.0,(raws-1)/2.0),angle,1)
stack_list[stack_list_index]=cv.warpAffine(stacked,M,(cols,raws),flags=interp)
return stack_list

# in: T1 volume, foreground threshold, margin pixel numbers
# out: which region should be cropped
def calc_crop_region(stack_list_T1,thre,pix):
crop_region=[]
for stack_list_index,stacked in enumerate(stack_list_T1):
_,threimg=cv.threshold(stacked[:,:,int(stacked.shape[2]/2)].copy(),thre,255,cv.THRESH_TOZERO)
pix_index=np.where(threimg>0)
if not pix_index[0].size==0:
y_min,y_max=min(pix_index[0]),max(pix_index[0])
x_min,x_max=min(pix_index[1]),max(pix_index[1])
else:
y_min,y_max=pix,pix
x_min,x_max=pix,pix
y_min=(y_min<=pix)and(0)or(y_min)
y_max=(y_max>=stacked.shape[0]-1-pix)and(stacked.shape[0]-1)or(y_max)
x_min=(x_min<=pix)and(0)or(x_min)
x_max=(x_max>=stacked.shape[1]-1-pix)and(stacked.shape[1]-1)or(x_max)
crop_region.append([y_min,y_max,x_min,x_max])
return crop_region

# in: crop region for each slice, how many slices in a stack
# out: max region in a stacked img
def calc_max_region_list(region_list,stack_num):
max_region_list=[]
for region_list_index in range(len(region_list)):
y_min_list,y_max_list,x_min_list,x_max_list=[],[],[],[]
for stack_index in range(stack_num):
query_list=get_stackindex(region_list_index,stack_num,len(region_list))
region=region_list[query_list[stack_index]]
y_min_list.append(region[0])
y_max_list.append(region[1])
x_min_list.append(region[2])
x_max_list.append(region[3])
max_region_list.append([min(y_min_list),max(y_max_list),min(x_min_list),max(x_max_list)])
return max_region_list

# in: size, devider
# out: padded size which can be devide by devider
def calc_ceil_pad(x,devider):
return math.ceil(x/float(devider))*devider

# in: stack img list, maxed region list
# out: cropped img list
def crop(stack_list,region_list):
cropped_list=[]
for stack_list_index,stacked in enumerate(stack_list):
y_min,y_max,x_min,x_max=region_list[stack_list_index]
cropped=np.zeros((calc_ceil_pad(y_max-y_min,16),calc_ceil_pad(x_max-x_min,16),stacked.shape[2]),np.uint8)
cropped[0:y_max-y_min,0:x_max-x_min,:]=stacked[y_min:y_max,x_min:x_max,:]
cropped_list.append(cropped.copy())
return cropped_list

# in: stack lbl list, dilate iteration
# out: stack edge list
def get_edge(stack_list,kernel_size=(3,3),sigmaX=0):
edge_list=[]
for stacked in stack_list:
edges=np.zeros((stacked.shape[0],stacked.shape[1],stacked.shape[2]),np.uint8)
for slice_index in range(stacked.shape[2]):
edges[:,:,slice_index]=cv.Canny(stacked[:,:,slice_index],1,1)
edges[:,:,slice_index]=cv.GaussianBlur(edges[:,:,slice_index],kernel_size,sigmaX)
edge_list.append(edges)
return edge_list





if __name__=='__main__':
T1_path='../../data/training/1/pre/reg_T1.nii.gz'
vol=to_uint8(readVol(T1_path))
print(vol.shape)
print('vol[100,100,20]= ', vol[100,100,20])
histeqed=histeq(vol)
print('vol[100,100,20]= ', vol[100,100,20])
print('query list: ', get_stackindex(1,5,histeqed.shape[2]))
stack_list=get_stacked(histeqed,5)
print(len(stack_list))
print(stack_list[0].shape)
angle=random.uniform(-15,15)
print('angle= ', angle)
rotated=rotate(stack_list,angle)
print(len(rotated))
region=calc_crop_region(rotated,50,5)
max_region=calc_max_region_list(region,5)
print(region)
print(max_region)
cropped=crop(rotated,max_region)
for i in range(48):
print(cropped[i].shape)
192 changes: 192 additions & 0 deletions evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# -*- coding: utf-8 -*-

import difflib
import numpy as np
import os
import SimpleITK as sitk
import scipy.spatial

# Set the path to the source data (e.g. the training data for self-testing)
# and the output directory of that subject
testDir = 'evaluation' # For example: '/input/2'
participantDir = 'evaluation' # For example: '/output/2'


labels = {1: 'Cortical gray matter',
2: 'Basal ganglia',
3: 'White matter',
4: 'White matter lesions',
5: 'Cerebrospinal fluid in the extracerebral space',
6: 'Ventricles',
7: 'Cerebellum',
8: 'Brain stem',
# The two labels below are ignored:
#9: 'Infarction',
#10: 'Other',
}


def do():
"""Main function"""
resultFilename = getResultFilename(participantDir)

testImage, resultImage = getImages(os.path.join(testDir, 'segm.nii.gz'), resultFilename)

dsc = getDSC(testImage, resultImage)
h95 = getHausdorff(testImage, resultImage)
vs = getVS(testImage, resultImage)

print('Dice', dsc, '(higher is better, max=1)')
print('HD', h95, 'mm', '(lower is better, min=0)')
print('VS', vs, '(higher is better, max=1)')



def getResultFilename(participantDir):
"""Find the filename of the result image.
This should be result.nii.gz or result.nii. If these files are not present,
it tries to find the closest filename."""
files = os.listdir(participantDir)

if not files:
raise Exception("No results in "+ participantDir)

resultFilename = None
if 'result.nii.gz' in files:
resultFilename = os.path.join(participantDir, 'result.nii.gz')
elif 'result.nii' in files:
resultFilename = os.path.join(participantDir, 'result.nii')
else:
# Find the filename that is closest to 'result.nii.gz'
maxRatio = -1
for f in files:
currentRatio = difflib.SequenceMatcher(a = f, b = 'result.nii.gz').ratio()

if currentRatio > maxRatio:
resultFilename = os.path.join(participantDir, f)
maxRatio = currentRatio

return resultFilename


def getImages(testFilename, resultFilename):
"""Return the test and result images, thresholded and pathology masked."""
testImage = sitk.ReadImage(testFilename)
resultImage = sitk.ReadImage(resultFilename)

# Check for equality
assert testImage.GetSize() == resultImage.GetSize()

# Get meta data from the test-image, needed for some sitk methods that check this
resultImage.CopyInformation(testImage)

# Remove pathology from the test and result images, since we don't evaluate on that
pathologyImage = sitk.BinaryThreshold(testImage, 9, 11, 0, 1) # pathology == 9 or 10

maskedTestImage = sitk.Mask(testImage, pathologyImage) # tissue == 1 -- 8
maskedResultImage = sitk.Mask(resultImage, pathologyImage)

# Force integer
if not 'integer' in maskedResultImage.GetPixelIDTypeAsString():
maskedResultImage = sitk.Cast(maskedResultImage, sitk.sitkUInt8)

return maskedTestImage, maskedResultImage


def getDSC(testImage, resultImage):
"""Compute the Dice Similarity Coefficient."""
dsc = dict()
for k in labels.keys():
testArray = sitk.GetArrayFromImage(sitk.BinaryThreshold( testImage, k, k, 1, 0)).flatten()
resultArray = sitk.GetArrayFromImage(sitk.BinaryThreshold(resultImage, k, k, 1, 0)).flatten()

# similarity = 1.0 - dissimilarity
# scipy.spatial.distance.dice raises a ZeroDivisionError if both arrays contain only zeros.
try:
dsc[k] = 1.0 - scipy.spatial.distance.dice(testArray, resultArray)
except ZeroDivisionError:
dsc[k] = None

return dsc


def getHausdorff(testImage, resultImage):
"""Compute the 95% Hausdorff distance."""
hd = dict()
for k in labels.keys():
lTestImage = sitk.BinaryThreshold( testImage, k, k, 1, 0)
lResultImage = sitk.BinaryThreshold(resultImage, k, k, 1, 0)

# Hausdorff distance is only defined when something is detected
statistics = sitk.StatisticsImageFilter()
statistics.Execute(lTestImage)
lTestSum = statistics.GetSum()
statistics.Execute(lResultImage)
lResultSum = statistics.GetSum()
if lTestSum == 0 or lResultSum == 0:
hd[k] = None
continue

# Edge detection is done by ORIGINAL - ERODED, keeping the outer boundaries of lesions. Erosion is performed in 2D
eTestImage = sitk.BinaryErode(lTestImage, (1,1,0))
eResultImage = sitk.BinaryErode(lResultImage, (1,1,0))

hTestImage = sitk.Subtract(lTestImage, eTestImage)
hResultImage = sitk.Subtract(lResultImage, eResultImage)

hTestArray = sitk.GetArrayFromImage(hTestImage)
hResultArray = sitk.GetArrayFromImage(hResultImage)

# Convert voxel location to world coordinates. Use the coordinate system of the test image
# np.nonzero = elements of the boundary in numpy order (zyx)
# np.flipud = elements in xyz order
# np.transpose = create tuples (x,y,z)
# testImage.TransformIndexToPhysicalPoint converts (xyz) to world coordinates (in mm)
# (Simple)ITK does not accept all Numpy arrays; therefore we need to convert the coordinate tuples into a Python list before passing them to TransformIndexToPhysicalPoint().
testCoordinates = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hTestArray) ))]
resultCoordinates = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hResultArray) ))]

# Use a kd-tree for fast spatial search
def getDistancesFromAtoB(a, b):
kdTree = scipy.spatial.KDTree(a, leafsize=100)
return kdTree.query(b, k=1, eps=0, p=2)[0]

# Compute distances from test to result and vice versa.
dTestToResult = getDistancesFromAtoB(testCoordinates, resultCoordinates)
dResultToTest = getDistancesFromAtoB(resultCoordinates, testCoordinates)
hd[k] = max(np.percentile(dTestToResult, 95), np.percentile(dResultToTest, 95))

return hd


def getVS(testImage, resultImage):
"""Volume similarity.
VS = 1 - abs(A - B) / (A + B)
A = ground truth in ML
B = participant segmentation in ML
"""
# Compute statistics of both images
testStatistics = sitk.StatisticsImageFilter()
resultStatistics = sitk.StatisticsImageFilter()

vs = dict()
for k in labels.keys():
testStatistics.Execute(sitk.BinaryThreshold(testImage, k, k, 1, 0))
resultStatistics.Execute(sitk.BinaryThreshold(resultImage, k, k, 1, 0))

numerator = abs(testStatistics.GetSum() - resultStatistics.GetSum())
denominator = testStatistics.GetSum() + resultStatistics.GetSum()

if denominator > 0:
vs[k] = 1 - float(numerator) / denominator
else:
vs[k] = None

return vs


if __name__ == "__main__":
do()
Loading

0 comments on commit f613cc9

Please sign in to comment.