-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
vaibhav shukla
authored and
vaibhav shukla
committed
Mar 12, 2019
1 parent
37863cd
commit f613cc9
Showing
28 changed files
with
3,058 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import os | ||
import numpy as np | ||
import math | ||
import random | ||
import cv2 as cv | ||
import nibabel as nib | ||
import torch | ||
|
||
# in: volume path | ||
# out: volume data in array | ||
def readVol(volpath): | ||
return nib.load(volpath).get_data() | ||
|
||
# in: volume array | ||
# out: comprise to uint8, put 0 where number<0 | ||
def to_uint8(vol): | ||
vol=vol.astype(np.float) | ||
vol[vol<0]=0 | ||
return ((vol-vol.min())*255.0/vol.max()).astype(np.uint8) | ||
|
||
# in: volume array | ||
# out: comprise to uint8, put 0 where number<800 | ||
def IR_to_uint8(vol): | ||
vol=vol.astype(np.float) | ||
vol[vol<0]=0 | ||
return ((vol-800)*255.0/vol.max()).astype(np.uint8) | ||
|
||
# in: volume array | ||
# out: hist equalized volume arrray | ||
def histeq(vol): | ||
for slice_index in range(vol.shape[2]): | ||
vol[:,:,slice_index]=cv.equalizeHist(vol[:,:,slice_index]) | ||
return vol | ||
|
||
# in: volume array | ||
# out: preprocessed array | ||
def preprocessed(vol): | ||
for slice_index in range(vol.shape[2]): | ||
cur_slice=vol[:,:,slice_index] | ||
sob_x=cv.Sobel(cur_slice,cv.CV_16S,1,0) | ||
sob_y=cv.Sobel(cur_slice,cv.CV_16S,0,1) | ||
absX=cv.convertScaleAbs(sob_x) | ||
absY=cv.convertScaleAbs(sob_y) | ||
sob=cv.addWeighted(absX,0.5,absY,0.5,0) | ||
dst=cur_slice+0.5*sob | ||
vol[:,:,slice_index]=dst | ||
return vol | ||
|
||
# in: index of slice, stack number, slice number | ||
# out: which slice should be stacked | ||
def get_stackindex(slice_index, stack_num, slice_num): | ||
assert stack_num%2==1, 'stack numbers must be odd!' | ||
query_list=[0]*stack_num | ||
for stack_index in range(stack_num): | ||
query_list[stack_index]=(slice_index+(stack_index-int(stack_num/2)))%slice_num | ||
return query_list | ||
|
||
# in: volume array, stack number | ||
# out: stacked img in list | ||
def get_stacked(vol,stack_num): | ||
stack_list=[] | ||
stacked_slice=np.zeros((vol.shape[0],vol.shape[1],stack_num),np.uint8) | ||
for slice_index in range(vol.shape[2]): | ||
query_list=get_stackindex(slice_index,stack_num,vol.shape[2]) | ||
for index_query_list,query_list_content in enumerate(query_list): | ||
stacked_slice[:,:,index_query_list]=vol[:,:,query_list_content].transpose() | ||
stack_list.append(stacked_slice.copy()) | ||
return stack_list | ||
|
||
# in: stacked img, rotate angle | ||
# out: rotated imgs | ||
def rotate(stack_list,angle,interp): | ||
for stack_list_index,stacked in enumerate(stack_list): | ||
raws,cols=stacked.shape[0:2] | ||
M=cv.getRotationMatrix2D(((cols-1)/2.0,(raws-1)/2.0),angle,1) | ||
stack_list[stack_list_index]=cv.warpAffine(stacked,M,(cols,raws),flags=interp) | ||
return stack_list | ||
|
||
# in: T1 volume, foreground threshold, margin pixel numbers | ||
# out: which region should be cropped | ||
def calc_crop_region(stack_list_T1,thre,pix): | ||
crop_region=[] | ||
for stack_list_index,stacked in enumerate(stack_list_T1): | ||
_,threimg=cv.threshold(stacked[:,:,int(stacked.shape[2]/2)].copy(),thre,255,cv.THRESH_TOZERO) | ||
pix_index=np.where(threimg>0) | ||
if not pix_index[0].size==0: | ||
y_min,y_max=min(pix_index[0]),max(pix_index[0]) | ||
x_min,x_max=min(pix_index[1]),max(pix_index[1]) | ||
else: | ||
y_min,y_max=pix,pix | ||
x_min,x_max=pix,pix | ||
y_min=(y_min<=pix)and(0)or(y_min) | ||
y_max=(y_max>=stacked.shape[0]-1-pix)and(stacked.shape[0]-1)or(y_max) | ||
x_min=(x_min<=pix)and(0)or(x_min) | ||
x_max=(x_max>=stacked.shape[1]-1-pix)and(stacked.shape[1]-1)or(x_max) | ||
crop_region.append([y_min,y_max,x_min,x_max]) | ||
return crop_region | ||
|
||
# in: crop region for each slice, how many slices in a stack | ||
# out: max region in a stacked img | ||
def calc_max_region_list(region_list,stack_num): | ||
max_region_list=[] | ||
for region_list_index in range(len(region_list)): | ||
y_min_list,y_max_list,x_min_list,x_max_list=[],[],[],[] | ||
for stack_index in range(stack_num): | ||
query_list=get_stackindex(region_list_index,stack_num,len(region_list)) | ||
region=region_list[query_list[stack_index]] | ||
y_min_list.append(region[0]) | ||
y_max_list.append(region[1]) | ||
x_min_list.append(region[2]) | ||
x_max_list.append(region[3]) | ||
max_region_list.append([min(y_min_list),max(y_max_list),min(x_min_list),max(x_max_list)]) | ||
return max_region_list | ||
|
||
# in: size, devider | ||
# out: padded size which can be devide by devider | ||
def calc_ceil_pad(x,devider): | ||
return math.ceil(x/float(devider))*devider | ||
|
||
# in: stack img list, maxed region list | ||
# out: cropped img list | ||
def crop(stack_list,region_list): | ||
cropped_list=[] | ||
for stack_list_index,stacked in enumerate(stack_list): | ||
y_min,y_max,x_min,x_max=region_list[stack_list_index] | ||
cropped=np.zeros((calc_ceil_pad(y_max-y_min,16),calc_ceil_pad(x_max-x_min,16),stacked.shape[2]),np.uint8) | ||
cropped[0:y_max-y_min,0:x_max-x_min,:]=stacked[y_min:y_max,x_min:x_max,:] | ||
cropped_list.append(cropped.copy()) | ||
return cropped_list | ||
|
||
# in: stack lbl list, dilate iteration | ||
# out: stack edge list | ||
def get_edge(stack_list,kernel_size=(3,3),sigmaX=0): | ||
edge_list=[] | ||
for stacked in stack_list: | ||
edges=np.zeros((stacked.shape[0],stacked.shape[1],stacked.shape[2]),np.uint8) | ||
for slice_index in range(stacked.shape[2]): | ||
edges[:,:,slice_index]=cv.Canny(stacked[:,:,slice_index],1,1) | ||
edges[:,:,slice_index]=cv.GaussianBlur(edges[:,:,slice_index],kernel_size,sigmaX) | ||
edge_list.append(edges) | ||
return edge_list | ||
|
||
|
||
|
||
|
||
|
||
if __name__=='__main__': | ||
T1_path='../../data/training/1/pre/reg_T1.nii.gz' | ||
vol=to_uint8(readVol(T1_path)) | ||
print(vol.shape) | ||
print('vol[100,100,20]= ', vol[100,100,20]) | ||
histeqed=histeq(vol) | ||
print('vol[100,100,20]= ', vol[100,100,20]) | ||
print('query list: ', get_stackindex(1,5,histeqed.shape[2])) | ||
stack_list=get_stacked(histeqed,5) | ||
print(len(stack_list)) | ||
print(stack_list[0].shape) | ||
angle=random.uniform(-15,15) | ||
print('angle= ', angle) | ||
rotated=rotate(stack_list,angle) | ||
print(len(rotated)) | ||
region=calc_crop_region(rotated,50,5) | ||
max_region=calc_max_region_list(region,5) | ||
print(region) | ||
print(max_region) | ||
cropped=crop(rotated,max_region) | ||
for i in range(48): | ||
print(cropped[i].shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import difflib | ||
import numpy as np | ||
import os | ||
import SimpleITK as sitk | ||
import scipy.spatial | ||
|
||
# Set the path to the source data (e.g. the training data for self-testing) | ||
# and the output directory of that subject | ||
testDir = 'evaluation' # For example: '/input/2' | ||
participantDir = 'evaluation' # For example: '/output/2' | ||
|
||
|
||
labels = {1: 'Cortical gray matter', | ||
2: 'Basal ganglia', | ||
3: 'White matter', | ||
4: 'White matter lesions', | ||
5: 'Cerebrospinal fluid in the extracerebral space', | ||
6: 'Ventricles', | ||
7: 'Cerebellum', | ||
8: 'Brain stem', | ||
# The two labels below are ignored: | ||
#9: 'Infarction', | ||
#10: 'Other', | ||
} | ||
|
||
|
||
def do(): | ||
"""Main function""" | ||
resultFilename = getResultFilename(participantDir) | ||
|
||
testImage, resultImage = getImages(os.path.join(testDir, 'segm.nii.gz'), resultFilename) | ||
|
||
dsc = getDSC(testImage, resultImage) | ||
h95 = getHausdorff(testImage, resultImage) | ||
vs = getVS(testImage, resultImage) | ||
|
||
print('Dice', dsc, '(higher is better, max=1)') | ||
print('HD', h95, 'mm', '(lower is better, min=0)') | ||
print('VS', vs, '(higher is better, max=1)') | ||
|
||
|
||
|
||
def getResultFilename(participantDir): | ||
"""Find the filename of the result image. | ||
This should be result.nii.gz or result.nii. If these files are not present, | ||
it tries to find the closest filename.""" | ||
files = os.listdir(participantDir) | ||
|
||
if not files: | ||
raise Exception("No results in "+ participantDir) | ||
|
||
resultFilename = None | ||
if 'result.nii.gz' in files: | ||
resultFilename = os.path.join(participantDir, 'result.nii.gz') | ||
elif 'result.nii' in files: | ||
resultFilename = os.path.join(participantDir, 'result.nii') | ||
else: | ||
# Find the filename that is closest to 'result.nii.gz' | ||
maxRatio = -1 | ||
for f in files: | ||
currentRatio = difflib.SequenceMatcher(a = f, b = 'result.nii.gz').ratio() | ||
|
||
if currentRatio > maxRatio: | ||
resultFilename = os.path.join(participantDir, f) | ||
maxRatio = currentRatio | ||
|
||
return resultFilename | ||
|
||
|
||
def getImages(testFilename, resultFilename): | ||
"""Return the test and result images, thresholded and pathology masked.""" | ||
testImage = sitk.ReadImage(testFilename) | ||
resultImage = sitk.ReadImage(resultFilename) | ||
|
||
# Check for equality | ||
assert testImage.GetSize() == resultImage.GetSize() | ||
|
||
# Get meta data from the test-image, needed for some sitk methods that check this | ||
resultImage.CopyInformation(testImage) | ||
|
||
# Remove pathology from the test and result images, since we don't evaluate on that | ||
pathologyImage = sitk.BinaryThreshold(testImage, 9, 11, 0, 1) # pathology == 9 or 10 | ||
|
||
maskedTestImage = sitk.Mask(testImage, pathologyImage) # tissue == 1 -- 8 | ||
maskedResultImage = sitk.Mask(resultImage, pathologyImage) | ||
|
||
# Force integer | ||
if not 'integer' in maskedResultImage.GetPixelIDTypeAsString(): | ||
maskedResultImage = sitk.Cast(maskedResultImage, sitk.sitkUInt8) | ||
|
||
return maskedTestImage, maskedResultImage | ||
|
||
|
||
def getDSC(testImage, resultImage): | ||
"""Compute the Dice Similarity Coefficient.""" | ||
dsc = dict() | ||
for k in labels.keys(): | ||
testArray = sitk.GetArrayFromImage(sitk.BinaryThreshold( testImage, k, k, 1, 0)).flatten() | ||
resultArray = sitk.GetArrayFromImage(sitk.BinaryThreshold(resultImage, k, k, 1, 0)).flatten() | ||
|
||
# similarity = 1.0 - dissimilarity | ||
# scipy.spatial.distance.dice raises a ZeroDivisionError if both arrays contain only zeros. | ||
try: | ||
dsc[k] = 1.0 - scipy.spatial.distance.dice(testArray, resultArray) | ||
except ZeroDivisionError: | ||
dsc[k] = None | ||
|
||
return dsc | ||
|
||
|
||
def getHausdorff(testImage, resultImage): | ||
"""Compute the 95% Hausdorff distance.""" | ||
hd = dict() | ||
for k in labels.keys(): | ||
lTestImage = sitk.BinaryThreshold( testImage, k, k, 1, 0) | ||
lResultImage = sitk.BinaryThreshold(resultImage, k, k, 1, 0) | ||
|
||
# Hausdorff distance is only defined when something is detected | ||
statistics = sitk.StatisticsImageFilter() | ||
statistics.Execute(lTestImage) | ||
lTestSum = statistics.GetSum() | ||
statistics.Execute(lResultImage) | ||
lResultSum = statistics.GetSum() | ||
if lTestSum == 0 or lResultSum == 0: | ||
hd[k] = None | ||
continue | ||
|
||
# Edge detection is done by ORIGINAL - ERODED, keeping the outer boundaries of lesions. Erosion is performed in 2D | ||
eTestImage = sitk.BinaryErode(lTestImage, (1,1,0)) | ||
eResultImage = sitk.BinaryErode(lResultImage, (1,1,0)) | ||
|
||
hTestImage = sitk.Subtract(lTestImage, eTestImage) | ||
hResultImage = sitk.Subtract(lResultImage, eResultImage) | ||
|
||
hTestArray = sitk.GetArrayFromImage(hTestImage) | ||
hResultArray = sitk.GetArrayFromImage(hResultImage) | ||
|
||
# Convert voxel location to world coordinates. Use the coordinate system of the test image | ||
# np.nonzero = elements of the boundary in numpy order (zyx) | ||
# np.flipud = elements in xyz order | ||
# np.transpose = create tuples (x,y,z) | ||
# testImage.TransformIndexToPhysicalPoint converts (xyz) to world coordinates (in mm) | ||
# (Simple)ITK does not accept all Numpy arrays; therefore we need to convert the coordinate tuples into a Python list before passing them to TransformIndexToPhysicalPoint(). | ||
testCoordinates = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hTestArray) ))] | ||
resultCoordinates = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hResultArray) ))] | ||
|
||
# Use a kd-tree for fast spatial search | ||
def getDistancesFromAtoB(a, b): | ||
kdTree = scipy.spatial.KDTree(a, leafsize=100) | ||
return kdTree.query(b, k=1, eps=0, p=2)[0] | ||
|
||
# Compute distances from test to result and vice versa. | ||
dTestToResult = getDistancesFromAtoB(testCoordinates, resultCoordinates) | ||
dResultToTest = getDistancesFromAtoB(resultCoordinates, testCoordinates) | ||
hd[k] = max(np.percentile(dTestToResult, 95), np.percentile(dResultToTest, 95)) | ||
|
||
return hd | ||
|
||
|
||
def getVS(testImage, resultImage): | ||
"""Volume similarity. | ||
VS = 1 - abs(A - B) / (A + B) | ||
A = ground truth in ML | ||
B = participant segmentation in ML | ||
""" | ||
# Compute statistics of both images | ||
testStatistics = sitk.StatisticsImageFilter() | ||
resultStatistics = sitk.StatisticsImageFilter() | ||
|
||
vs = dict() | ||
for k in labels.keys(): | ||
testStatistics.Execute(sitk.BinaryThreshold(testImage, k, k, 1, 0)) | ||
resultStatistics.Execute(sitk.BinaryThreshold(resultImage, k, k, 1, 0)) | ||
|
||
numerator = abs(testStatistics.GetSum() - resultStatistics.GetSum()) | ||
denominator = testStatistics.GetSum() + resultStatistics.GetSum() | ||
|
||
if denominator > 0: | ||
vs[k] = 1 - float(numerator) / denominator | ||
else: | ||
vs[k] = None | ||
|
||
return vs | ||
|
||
|
||
if __name__ == "__main__": | ||
do() |
Oops, something went wrong.