diff --git a/archi.jpg b/archi.jpg new file mode 100644 index 0000000..a70b149 Binary files /dev/null and b/archi.jpg differ diff --git a/dae.png b/dae.png new file mode 100644 index 0000000..d40ec69 Binary files /dev/null and b/dae.png differ diff --git a/data_loader/__init__.py b/data_loader/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/data_loader/data_loader_18.py b/data_loader/data_loader_18.py new file mode 100755 index 0000000..43788c5 --- /dev/null +++ b/data_loader/data_loader_18.py @@ -0,0 +1,337 @@ +import os +import torch +import numpy as np +import math +import random +import cv2 as cv +import nibabel as nib +import torch +from torch.utils import data +import torchvision.transforms as transforms +import matplotlib.pyplot as plt +import pandas as pd + +from data_loader.preprocess import readVol,to_uint8,IR_to_uint8,histeq,preprocessed,get_stacked,rotate,calc_crop_region,calc_max_region_list,crop,get_edge + +class MR18loader_CV(data.Dataset): + def __init__(self,root='../../data/',val_num=5,is_val=False, + is_transform=False,is_flip=False,is_rotate=False,is_crop=False,is_histeq=False,forest=5): + self.root=root + self.val_num=val_num + self.is_val=is_val + self.is_transform=is_transform + self.is_flip=is_flip + self.is_rotate=is_rotate + self.is_crop=is_crop + self.is_histeq=is_histeq + self.forest=forest + self.n_classes=11 + # Back: Background + # GM: Cortical GM(red), Basal ganglia(green) + # WM: WM(yellow), WM lesions(blue) + # CSF: CSF(pink), Ventricles(light blue) + # Back: Cerebellum(white), Brainstem(dark red) + self.color=np.asarray([[0,0,0],[0,0,255],[0,255,0],[0,255,255],[255,0,0],\ + [255,0,255],[255,255,0],[255,255,255],[0,0,128],[0,128,0],[128,0,0]]).astype(np.uint8) + # Back , CSF , GM , WM + self.label_test=[0,2,2,3,3,1,1,0,0] + # nii paths + self.T1path=[self.root+'training/'+name+'/pre/reg_T1.nii.gz' for name in ['1','4','5','7','14','070','148']] + self.IRpath=[self.root+'training/'+name+'/pre/IR.nii.gz' for name in ['1','4','5','7','14','070','148']] + self.T2path=[self.root+'training/'+name+'/pre/FLAIR.nii.gz' for name in ['1','4','5','7','14','070','148']] + self.lblpath=[self.root+'training/'+name+'/segm.nii.gz' for name in ['1','4','5','7','14','070','148']] + + # val path + self.val_T1path=self.T1path[self.val_num-1] + self.val_IRpath=self.IRpath[self.val_num-1] + self.val_T2path=self.T2path[self.val_num-1] + self.val_lblpath=self.lblpath[self.val_num-1] + # train path + self.train_T1path=[temp for temp in self.T1path if temp not in [self.val_T1path]] + self.train_IRpath=[temp for temp in self.IRpath if temp not in [self.val_IRpath]] + self.train_T2path=[temp for temp in self.T2path if temp not in [self.val_T2path]] + self.train_lblpath=[temp for temp in self.lblpath if temp not in [self.val_lblpath]] + + if self.is_val==False: + print('training data') + T1_nii=[to_uint8(readVol(path)) for path in self.train_T1path] + IR_nii=[IR_to_uint8(readVol(path)) for path in self.train_IRpath] + T2_nii=[to_uint8(readVol(path)) for path in self.train_T2path] + lbl_nii=[readVol(path) for path in self.train_lblpath] + + if self.is_flip: + vol_num=len(T1_nii) + for nums in range(vol_num): + T1_nii.append(np.array([cv.flip(slice_,1) for slice_ in T1_nii[nums]])) + IR_nii.append(np.array([cv.flip(slice_,1) for slice_ in IR_nii[nums]])) + T2_nii.append(np.array([cv.flip(slice_,1) for slice_ in T2_nii[nums]])) + lbl_nii.append(np.array([cv.flip(slice_,1) for slice_ in lbl_nii[nums]])) + + if self.is_histeq: + print('hist equalizing......') + 
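+            # Note (added): cv.equalizeHist works on single-channel uint8 images,
+            # which is why readVol output is converted via to_uint8/IR_to_uint8 first.
+            # Only T1 is equalized below; the IR and T2 comprehensions rebuild the
+            # lists unchanged, so no equalization is applied to those modalities.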
T1_nii=[histeq(vol) for vol in T1_nii] + IR_nii=[vol for vol in IR_nii] + T2_nii=[vol for vol in T2_nii] + + print('get stacking......') + T1_stack_lists=[get_stacked(vol,self.forest) for vol in T1_nii] + IR_stack_lists=[get_stacked(vol,self.forest) for vol in IR_nii] + T2_stack_lists=[get_stacked(vol,self.forest) for vol in T2_nii] + lbl_stack_lists=[get_stacked(vol,self.forest) for vol in lbl_nii] + + if self.is_rotate: + print('rotating......') + angle_list=[5,-5,10,-10,15,-15] + sample_num=len(T1_stack_lists) + for angle in angle_list: + for sample_index in range(sample_num): + T1_stack_lists.append(rotate(T1_stack_lists[sample_index],angle,interp=cv.INTER_CUBIC).copy()) + IR_stack_lists.append(rotate(IR_stack_lists[sample_index],angle,interp=cv.INTER_CUBIC).copy()) + T2_stack_lists.append(rotate(T2_stack_lists[sample_index],angle,interp=cv.INTER_CUBIC).copy()) + lbl_stack_lists.append(rotate(lbl_stack_lists[sample_index],angle,interp=cv.INTER_NEAREST).copy()) + + if self.is_crop: + print('cropping......') + region_lists=[calc_max_region_list(calc_crop_region(T1_stack_list,50,5),self.forest) for T1_stack_list in T1_stack_lists] + self.region_lists=region_lists + T1_stack_lists=[crop(stack_list,region_lists[list_index]) for list_index,stack_list in enumerate(T1_stack_lists)] + IR_stack_lists=[crop(stack_list,region_lists[list_index]) for list_index,stack_list in enumerate(IR_stack_lists)] + T2_stack_lists=[crop(stack_list,region_lists[list_index]) for list_index,stack_list in enumerate(T2_stack_lists)] + lbl_stack_lists=[crop(stack_list,region_lists[list_index]) for list_index,stack_list in enumerate(lbl_stack_lists)] + ''' + print('len=',len(T1_stack_lists)) + T1_path_list=[] + IR_path_list=[] + T2_path_list=[] + lbl_path_list=[] + range_list=[] + name=['1','4','5','7','14','070','148'] + f_n=['n','f'] + ang=['0','5','-5','10','-10','15','-15'] + save_path='../../../../data/' + for sam_i,sample in enumerate(T1_stack_lists): + for img_j,img in enumerate(sample): + T1_path_list.append('imgs/'+'T1/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j)) + path=save_path+'imgs/'+'T1/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j) + cv.imwrite(path,img) + for sam_i,sample in enumerate(IR_stack_lists): + for img_j,img in enumerate(sample): + IR_path_list.append('imgs/'+'IR/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j)) + path=save_path+'imgs/'+'IR/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j) + cv.imwrite(path,img) + for sam_i,sample in enumerate(T2_stack_lists): + for img_j,img in enumerate(sample): + T2_path_list.append('imgs/'+'T2/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j)) + path=save_path+'imgs/'+'T2/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j) + cv.imwrite(path,img) + for sam_i,sample in enumerate(lbl_stack_lists): + for img_j,img in enumerate(sample): + lbl_path_list.append('lbls/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j)) + path=save_path+'lbls/'+'{}_{}_{}_{}.png'.format(name[sam_i%7],f_n[(int(sam_i/7))%2],ang[int(sam_i/14)],img_j) + print(img.shape) + cv.imwrite(path,img) + for sam_i,sample in enumerate(region_lists): + for img_j,img in enumerate(sample): + range_list.append(img) + range_array=np.array(range_list) + y_min_list=range_array[:,0] + y_max_list=range_array[:,1] + 
x_min_list=range_array[:,2] + x_max_list=range_array[:,3] + df=pd.DataFrame({ 'T1':T1_path_list,'IR':IR_path_list,'T2':T2_path_list,'lbl':lbl_path_list, + 'y_min':y_min_list,'y_max':y_max_list,'x_min':x_min_list,'x_max':x_max_list}) + print(df) + df.to_csv("index.csv") + ''' + # get means + T1mean,IRmean,T2mean=0.0,0.0,0.0 + for samples in T1_stack_lists: + for stacks in samples: + T1mean=T1mean+np.mean(stacks) + T1mean=T1mean/(len(T1_stack_lists)*len(T1_stack_lists[0])) + print('T1 mean = ',T1mean) + self.T1mean=T1mean + for samples in IR_stack_lists: + for stacks in samples: + IRmean=IRmean+np.mean(stacks) + IRmean=IRmean/(len(IR_stack_lists)*len(IR_stack_lists[0])) + print('IR mean = ',IRmean) + self.IRmean=IRmean + for samples in T2_stack_lists: + for stacks in samples: + T2mean=T2mean+np.mean(stacks) + T2mean=T2mean/(len(T2_stack_lists)*len(T2_stack_lists[0])) + print('T2 mean = ',T2mean) + self.T2mean=T2mean + + # get edegs + print('getting edges') + edge_stack_lists=[] + for samples in lbl_stack_lists: + edge_stack_lists.append(get_edge(samples)) + + # transform + if self.is_transform: + print('transforming') + for sample_index in range(len(T1_stack_lists)): + for stack_index in range(len(T1_stack_lists[0])): + T1_stack_lists[sample_index][stack_index], \ + IR_stack_lists[sample_index][stack_index], \ + T2_stack_lists[sample_index][stack_index], \ + lbl_stack_lists[sample_index][stack_index], \ + edge_stack_lists[sample_index][stack_index]=\ + self.transform( \ + T1_stack_lists[sample_index][stack_index], \ + IR_stack_lists[sample_index][stack_index], \ + T2_stack_lists[sample_index][stack_index], \ + lbl_stack_lists[sample_index][stack_index], \ + edge_stack_lists[sample_index][stack_index]) + + else: + print('validating data') + T1_nii=to_uint8(readVol(self.val_T1path)) + IR_nii=IR_to_uint8(readVol(self.val_IRpath)) + T2_nii=to_uint8(readVol(self.val_T2path)) + lbl_nii=readVol(self.val_lblpath) + + if self.is_histeq: + print('hist equalizing......') + T1_nii=histeq(T1_nii) + IR_nii=IR_nii + T1_nii=T1_nii + + print('get stacking......') + T1_stack_lists=get_stacked(T1_nii,self.forest) + IR_stack_lists=get_stacked(IR_nii,self.forest) + T2_stack_lists=get_stacked(T2_nii,self.forest) + lbl_stack_lists=get_stacked(lbl_nii,self.forest) + + if self.is_crop: + print('cropping......') + region_lists=calc_max_region_list(calc_crop_region(T1_stack_lists,50,5),self.forest) + self.region_lists=region_lists + T1_stack_lists=crop(T1_stack_lists,region_lists) + IR_stack_lists=crop(IR_stack_lists,region_lists) + T2_stack_lists=crop(T2_stack_lists,region_lists) + lbl_stack_lists=crop(lbl_stack_lists,region_lists) + + # get means + T1mean,IRmean,T2mean=0.0,0.0,0.0 + for stacks in T1_stack_lists: + T1mean=T1mean+np.mean(stacks) + T1mean=T1mean/(len(T1_stack_lists)) + print('T1 mean = ',T1mean) + self.T1mean=T1mean + for stacks in IR_stack_lists: + IRmean=IRmean+np.mean(stacks) + IRmean=IRmean/(len(IR_stack_lists)) + print('IR mean = ',IRmean) + self.IRmean=IRmean + for stacks in T2_stack_lists: + T2mean=T2mean+np.mean(stacks) + T2mean=T2mean/(len(T2_stack_lists)) + print('T2 mean = ',T2mean) + self.T2mean=T2mean + + # get edges + print('getting edges') + edge_stack_lists=get_edge(lbl_stack_lists) + + # transform + if self.is_transform: + print('transforming') + for stack_index in range(len(T1_stack_lists)): + T1_stack_lists[stack_index], \ + IR_stack_lists[stack_index], \ + T2_stack_lists[stack_index], \ + lbl_stack_lists[stack_index], \ + edge_stack_lists[stack_index]=\ + self.transform( \ + 
T1_stack_lists[stack_index], \ + IR_stack_lists[stack_index], \ + T2_stack_lists[stack_index], \ + lbl_stack_lists[stack_index], \ + edge_stack_lists[stack_index]) + + # data ready + self.T1_stack_lists=T1_stack_lists + self.IR_stack_lists=IR_stack_lists + self.T2_stack_lists=T2_stack_lists + self.lbl_stack_lists=lbl_stack_lists + self.edge_stack_lists=edge_stack_lists + + + def __len__(self): + return (self.is_val)and(48)or(48*6*7*2) + def __getitem__(self,index): + # get train or validation data + if self.is_val==False: + set_index=range(len(self.T1_stack_lists)) + img_index=range(len(self.T1_stack_lists[0])) + return \ + self.region_lists[set_index[int(index/48)]][img_index[int(index%48)]], \ + self.T1_stack_lists[set_index[int(index/48)]][img_index[int(index%48)]],\ + self.IR_stack_lists[set_index[int(index/48)]][img_index[int(index%48)]],\ + self.T2_stack_lists[set_index[int(index/48)]][img_index[int(index%48)]],\ + self.lbl_stack_lists[set_index[int(index/48)]][img_index[int(index%48)]] + #self.edge_stack_lists[set_index[int(index/48)]][img_index[int(index%48)]] + + else: + img_index=range(len(self.T1_stack_lists)) + return \ + self.region_lists[img_index[int(index)]], \ + self.T1_stack_lists[img_index[int(index)]], \ + self.IR_stack_lists[img_index[int(index)]], \ + self.T2_stack_lists[img_index[int(index)]], \ + self.lbl_stack_lists[img_index[int(index)]] + #self.edge_stack_lists[img_index[int(index)]] + + + + + def transform(self,imgT1,imgIR,imgT2,lbl,edge): + imgT1=torch.from_numpy((imgT1.transpose(2,0,1).astype(np.float)-self.T1mean)/255.0).float() + imgIR=torch.from_numpy((imgIR.transpose(2,0,1).astype(np.float)-self.IRmean)/255.0).float() + imgT2=torch.from_numpy((imgT2.transpose(2,0,1).astype(np.float)-self.T2mean)/255.0).float() + lbl=torch.from_numpy(lbl.transpose(2,0,1)).long() + edge=torch.from_numpy(edge.transpose(2,0,1)/255).float() + return imgT1,imgIR,imgT2,lbl,edge + def decode_segmap(self,label_mask): + r,g,b=label_mask.copy(),label_mask.copy(),label_mask.copy() + for ll in range(0,self.n_classes): + r[label_mask==ll]=self.color[ll,2] + g[label_mask==ll]=self.color[ll,1] + b[label_mask==ll]=self.color[ll,0] + rgb=np.zeros((label_mask.shape[0],label_mask.shape[1],3)) + rgb[:,:,0],rgb[:,:,1],rgb[:,:,2]=r,g,b + return rgb + def lbl_totest(self,pred): + pred_test=np.zeros((pred.shape[0],pred.shape[1]),np.uint8) + for ll in range(9): + pred_test[pred==ll]=self.label_test[ll] + return pred_test + +if __name__=='__main__': + path='../../../../data/' + MRloader=MR18loader_CV(root=path,val_num=7,is_val=False,is_transform=True,is_flip=True,is_rotate=True,is_crop=True,is_histeq=True,forest=3) + loader=data.DataLoader(MRloader, batch_size=1, num_workers=1, shuffle=True) + for i,(regions,T1s,IRs,T2s,lbls) in enumerate(MRloader): + print(i) + #print(T1s.shape) + #print(regions) + #print(lbls.min()) + #print(lbls.max()) + #cv.imwrite(str(i)+'.png',T1s[:,:,1]) + #print(region) + #print(imgT1.shape) + #print(imgIR.shape) + #print(imgT2.shape) + #print(lbl.shape) + + #print('[{},{},{},{}]'.format(imgT1[0,2,40,40],imgIR[0,2,40,40],imgT2[0,2,40,40],lbl[0,2,40,40])) + + #cv.imwrite('T1-'+str(i)+'.png',imgT1[2]) + #cv.imwrite('IR-'+str(i)+'.png',imgIR[2]) + #cv.imwrite('T2-'+str(i)+'.png',imgT2[2]) + #cv.imwrite('lbl-'+str(i)+'.png',MRloader.decode_segmap(lbl[2])) + diff --git a/data_loader/preprocess.py b/data_loader/preprocess.py new file mode 100755 index 0000000..960c75e --- /dev/null +++ b/data_loader/preprocess.py @@ -0,0 +1,168 @@ +import os +import numpy as np +import math 
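+# Overview (added): the helpers below implement the slice-wise preprocessing used by
+# data_loader_18.py. A typical call sequence on a MRBrainS-style T1 volume:
+#   vol    = to_uint8(readVol(path))             # load .nii volume, rescale to uint8
+#   vol    = histeq(vol)                         # per-slice histogram equalization
+#   stacks = get_stacked(vol, 5)                 # bundle each slice with its neighbors
+#   stacks = rotate(stacks, 10, cv.INTER_CUBIC)  # optional rotation augmentation
+#   region = calc_max_region_list(calc_crop_region(stacks, 50, 5), 5)
+#   crops  = crop(stacks, region)                # crop to brain bbox, pad to /16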
+import random
+import cv2 as cv
+import nibabel as nib
+import torch
+
+# in: volume path
+# out: volume data in array
+def readVol(volpath):
+    return nib.load(volpath).get_data()
+
+# in: volume array
+# out: compressed to uint8, 0 where value<0
+def to_uint8(vol):
+    vol=vol.astype(np.float)
+    vol[vol<0]=0
+    return ((vol-vol.min())*255.0/vol.max()).astype(np.uint8)
+
+# in: volume array
+# out: compressed to uint8, 0 where value<800
+def IR_to_uint8(vol):
+    vol=vol.astype(np.float)
+    vol[vol<800]=800   # clip below 800 so the subtraction cannot wrap around in uint8
+    return ((vol-800)*255.0/vol.max()).astype(np.uint8)
+
+# in: volume array
+# out: hist equalized volume array
+def histeq(vol):
+    for slice_index in range(vol.shape[2]):
+        vol[:,:,slice_index]=cv.equalizeHist(vol[:,:,slice_index])
+    return vol
+
+# in: volume array
+# out: preprocessed array
+def preprocessed(vol):
+    for slice_index in range(vol.shape[2]):
+        cur_slice=vol[:,:,slice_index]
+        sob_x=cv.Sobel(cur_slice,cv.CV_16S,1,0)
+        sob_y=cv.Sobel(cur_slice,cv.CV_16S,0,1)
+        absX=cv.convertScaleAbs(sob_x)
+        absY=cv.convertScaleAbs(sob_y)
+        sob=cv.addWeighted(absX,0.5,absY,0.5,0)
+        dst=cur_slice+0.5*sob
+        vol[:,:,slice_index]=dst
+    return vol
+
+# in: index of slice, stack number, slice number
+# out: which slices should be stacked
+def get_stackindex(slice_index, stack_num, slice_num):
+    assert stack_num%2==1, 'stack numbers must be odd!'
+    query_list=[0]*stack_num
+    for stack_index in range(stack_num):
+        query_list[stack_index]=(slice_index+(stack_index-int(stack_num/2)))%slice_num
+    return query_list
+
+# in: volume array, stack number
+# out: stacked imgs in a list
+def get_stacked(vol,stack_num):
+    stack_list=[]
+    stacked_slice=np.zeros((vol.shape[0],vol.shape[1],stack_num),np.uint8)
+    for slice_index in range(vol.shape[2]):
+        query_list=get_stackindex(slice_index,stack_num,vol.shape[2])
+        for index_query_list,query_list_content in enumerate(query_list):
+            stacked_slice[:,:,index_query_list]=vol[:,:,query_list_content].transpose()
+        stack_list.append(stacked_slice.copy())
+    return stack_list
+
+# in: stacked img list, rotate angle, interpolation flag
+# out: rotated imgs
+def rotate(stack_list,angle,interp):
+    for stack_list_index,stacked in enumerate(stack_list):
+        raws,cols=stacked.shape[0:2]
+        M=cv.getRotationMatrix2D(((cols-1)/2.0,(raws-1)/2.0),angle,1)
+        stack_list[stack_list_index]=cv.warpAffine(stacked,M,(cols,raws),flags=interp)
+    return stack_list
+
+# in: T1 volume, foreground threshold, margin pixel numbers
+# out: which region should be cropped
+def calc_crop_region(stack_list_T1,thre,pix):
+    crop_region=[]
+    for stack_list_index,stacked in enumerate(stack_list_T1):
+        _,threimg=cv.threshold(stacked[:,:,int(stacked.shape[2]/2)].copy(),thre,255,cv.THRESH_TOZERO)
+        pix_index=np.where(threimg>0)
+        if not pix_index[0].size==0:
+            y_min,y_max=min(pix_index[0]),max(pix_index[0])
+            x_min,x_max=min(pix_index[1]),max(pix_index[1])
+        else:
+            y_min,y_max=pix,pix
+            x_min,x_max=pix,pix
+        y_min=(y_min<=pix)and(0)or(y_min)
+        y_max=(y_max>=stacked.shape[0]-1-pix)and(stacked.shape[0]-1)or(y_max)
+        x_min=(x_min<=pix)and(0)or(x_min)
+        x_max=(x_max>=stacked.shape[1]-1-pix)and(stacked.shape[1]-1)or(x_max)
+        crop_region.append([y_min,y_max,x_min,x_max])
+    return crop_region
+
+# in: crop region for each slice, how many slices in a stack
+# out: max region in a stacked img
+def calc_max_region_list(region_list,stack_num):
+    max_region_list=[]
+    for region_list_index in range(len(region_list)):
+        y_min_list,y_max_list,x_min_list,x_max_list=[],[],[],[]
+        for stack_index in range(stack_num):
+            query_list=get_stackindex(region_list_index,stack_num,len(region_list))
+            region=region_list[query_list[stack_index]]
+            y_min_list.append(region[0])
+            y_max_list.append(region[1])
+            x_min_list.append(region[2])
+            x_max_list.append(region[3])
+        max_region_list.append([min(y_min_list),max(y_max_list),min(x_min_list),max(x_max_list)])
+    return max_region_list
+
+# in: size, divider
+# out: padded size which can be divided by divider
+def calc_ceil_pad(x,divider):
+    return math.ceil(x/float(divider))*divider
+
+# in: stacked img list, maxed region list
+# out: cropped img list
+def crop(stack_list,region_list):
+    cropped_list=[]
+    for stack_list_index,stacked in enumerate(stack_list):
+        y_min,y_max,x_min,x_max=region_list[stack_list_index]
+        cropped=np.zeros((calc_ceil_pad(y_max-y_min,16),calc_ceil_pad(x_max-x_min,16),stacked.shape[2]),np.uint8)
+        cropped[0:y_max-y_min,0:x_max-x_min,:]=stacked[y_min:y_max,x_min:x_max,:]
+        cropped_list.append(cropped.copy())
+    return cropped_list
+
+# in: stacked lbl list, Gaussian blur kernel size and sigmaX
+# out: blurred Canny edge list
+def get_edge(stack_list,kernel_size=(3,3),sigmaX=0):
+    edge_list=[]
+    for stacked in stack_list:
+        edges=np.zeros((stacked.shape[0],stacked.shape[1],stacked.shape[2]),np.uint8)
+        for slice_index in range(stacked.shape[2]):
+            edges[:,:,slice_index]=cv.Canny(stacked[:,:,slice_index],1,1)
+            edges[:,:,slice_index]=cv.GaussianBlur(edges[:,:,slice_index],kernel_size,sigmaX)
+        edge_list.append(edges)
+    return edge_list
+
+
+if __name__=='__main__':
+    T1_path='../../data/training/1/pre/reg_T1.nii.gz'
+    vol=to_uint8(readVol(T1_path))
+    print(vol.shape)
+    print('vol[100,100,20]= ', vol[100,100,20])
+    histeqed=histeq(vol)
+    print('vol[100,100,20]= ', vol[100,100,20])
+    print('query list: ', get_stackindex(1,5,histeqed.shape[2]))
+    stack_list=get_stacked(histeqed,5)
+    print(len(stack_list))
+    print(stack_list[0].shape)
+    angle=random.uniform(-15,15)
+    print('angle= ', angle)
+    rotated=rotate(stack_list,angle,cv.INTER_CUBIC)   # rotate() requires an interpolation flag
+    print(len(rotated))
+    region=calc_crop_region(rotated,50,5)
+    max_region=calc_max_region_list(region,5)
+    print(region)
+    print(max_region)
+    cropped=crop(rotated,max_region)
+    for i in range(48):
+        print(cropped[i].shape)
diff --git a/evaluation.py b/evaluation.py
new file mode 100755
index 0000000..36f3b4d
--- /dev/null
+++ b/evaluation.py
@@ -0,0 +1,192 @@
+# -*- coding: utf-8 -*-
+
+import difflib
+import numpy as np
+import os
+import SimpleITK as sitk
+import scipy.spatial
+
+# Set the path to the source data (e.g. the training data for self-testing)
+# and the output directory of that subject
+testDir = 'evaluation'          # For example: '/input/2'
+participantDir = 'evaluation'   # For example: '/output/2'
+
+
+labels = {1: 'Cortical gray matter',
+          2: 'Basal ganglia',
+          3: 'White matter',
+          4: 'White matter lesions',
+          5: 'Cerebrospinal fluid in the extracerebral space',
+          6: 'Ventricles',
+          7: 'Cerebellum',
+          8: 'Brain stem',
+          # The two labels below are ignored:
+          #9: 'Infarction',
+          #10: 'Other',
+          }
+
+
+def do():
+    """Main function"""
+    resultFilename = getResultFilename(participantDir)
+
+    testImage, resultImage = getImages(os.path.join(testDir, 'segm.nii.gz'), resultFilename)
+
+    dsc = getDSC(testImage, resultImage)
+    h95 = getHausdorff(testImage, resultImage)
+    vs = getVS(testImage, resultImage)
+
+    print('Dice', dsc, '(higher is better, max=1)')
+    print('HD', h95, 'mm', '(lower is better, min=0)')
+    print('VS', vs, '(higher is better, max=1)')
+
+
+
+def getResultFilename(participantDir):
+    """Find the filename of the result image.
+ + This should be result.nii.gz or result.nii. If these files are not present, + it tries to find the closest filename.""" + files = os.listdir(participantDir) + + if not files: + raise Exception("No results in "+ participantDir) + + resultFilename = None + if 'result.nii.gz' in files: + resultFilename = os.path.join(participantDir, 'result.nii.gz') + elif 'result.nii' in files: + resultFilename = os.path.join(participantDir, 'result.nii') + else: + # Find the filename that is closest to 'result.nii.gz' + maxRatio = -1 + for f in files: + currentRatio = difflib.SequenceMatcher(a = f, b = 'result.nii.gz').ratio() + + if currentRatio > maxRatio: + resultFilename = os.path.join(participantDir, f) + maxRatio = currentRatio + + return resultFilename + + +def getImages(testFilename, resultFilename): + """Return the test and result images, thresholded and pathology masked.""" + testImage = sitk.ReadImage(testFilename) + resultImage = sitk.ReadImage(resultFilename) + + # Check for equality + assert testImage.GetSize() == resultImage.GetSize() + + # Get meta data from the test-image, needed for some sitk methods that check this + resultImage.CopyInformation(testImage) + + # Remove pathology from the test and result images, since we don't evaluate on that + pathologyImage = sitk.BinaryThreshold(testImage, 9, 11, 0, 1) # pathology == 9 or 10 + + maskedTestImage = sitk.Mask(testImage, pathologyImage) # tissue == 1 -- 8 + maskedResultImage = sitk.Mask(resultImage, pathologyImage) + + # Force integer + if not 'integer' in maskedResultImage.GetPixelIDTypeAsString(): + maskedResultImage = sitk.Cast(maskedResultImage, sitk.sitkUInt8) + + return maskedTestImage, maskedResultImage + + +def getDSC(testImage, resultImage): + """Compute the Dice Similarity Coefficient.""" + dsc = dict() + for k in labels.keys(): + testArray = sitk.GetArrayFromImage(sitk.BinaryThreshold( testImage, k, k, 1, 0)).flatten() + resultArray = sitk.GetArrayFromImage(sitk.BinaryThreshold(resultImage, k, k, 1, 0)).flatten() + + # similarity = 1.0 - dissimilarity + # scipy.spatial.distance.dice raises a ZeroDivisionError if both arrays contain only zeros. + try: + dsc[k] = 1.0 - scipy.spatial.distance.dice(testArray, resultArray) + except ZeroDivisionError: + dsc[k] = None + + return dsc + + +def getHausdorff(testImage, resultImage): + """Compute the 95% Hausdorff distance.""" + hd = dict() + for k in labels.keys(): + lTestImage = sitk.BinaryThreshold( testImage, k, k, 1, 0) + lResultImage = sitk.BinaryThreshold(resultImage, k, k, 1, 0) + + # Hausdorff distance is only defined when something is detected + statistics = sitk.StatisticsImageFilter() + statistics.Execute(lTestImage) + lTestSum = statistics.GetSum() + statistics.Execute(lResultImage) + lResultSum = statistics.GetSum() + if lTestSum == 0 or lResultSum == 0: + hd[k] = None + continue + + # Edge detection is done by ORIGINAL - ERODED, keeping the outer boundaries of lesions. Erosion is performed in 2D + eTestImage = sitk.BinaryErode(lTestImage, (1,1,0)) + eResultImage = sitk.BinaryErode(lResultImage, (1,1,0)) + + hTestImage = sitk.Subtract(lTestImage, eTestImage) + hResultImage = sitk.Subtract(lResultImage, eResultImage) + + hTestArray = sitk.GetArrayFromImage(hTestImage) + hResultArray = sitk.GetArrayFromImage(hResultImage) + + # Convert voxel location to world coordinates. 
Use the coordinate system of the test image + # np.nonzero = elements of the boundary in numpy order (zyx) + # np.flipud = elements in xyz order + # np.transpose = create tuples (x,y,z) + # testImage.TransformIndexToPhysicalPoint converts (xyz) to world coordinates (in mm) + # (Simple)ITK does not accept all Numpy arrays; therefore we need to convert the coordinate tuples into a Python list before passing them to TransformIndexToPhysicalPoint(). + testCoordinates = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hTestArray) ))] + resultCoordinates = [testImage.TransformIndexToPhysicalPoint(x.tolist()) for x in np.transpose( np.flipud( np.nonzero(hResultArray) ))] + + # Use a kd-tree for fast spatial search + def getDistancesFromAtoB(a, b): + kdTree = scipy.spatial.KDTree(a, leafsize=100) + return kdTree.query(b, k=1, eps=0, p=2)[0] + + # Compute distances from test to result and vice versa. + dTestToResult = getDistancesFromAtoB(testCoordinates, resultCoordinates) + dResultToTest = getDistancesFromAtoB(resultCoordinates, testCoordinates) + hd[k] = max(np.percentile(dTestToResult, 95), np.percentile(dResultToTest, 95)) + + return hd + + +def getVS(testImage, resultImage): + """Volume similarity. + + VS = 1 - abs(A - B) / (A + B) + + A = ground truth in ML + B = participant segmentation in ML + """ + # Compute statistics of both images + testStatistics = sitk.StatisticsImageFilter() + resultStatistics = sitk.StatisticsImageFilter() + + vs = dict() + for k in labels.keys(): + testStatistics.Execute(sitk.BinaryThreshold(testImage, k, k, 1, 0)) + resultStatistics.Execute(sitk.BinaryThreshold(resultImage, k, k, 1, 0)) + + numerator = abs(testStatistics.GetSum() - resultStatistics.GetSum()) + denominator = testStatistics.GetSum() + resultStatistics.GetSum() + + if denominator > 0: + vs[k] = 1 - float(numerator) / denominator + else: + vs[k] = None + + return vs + + +if __name__ == "__main__": + do() diff --git a/loss.py b/loss.py new file mode 100755 index 0000000..4406026 --- /dev/null +++ b/loss.py @@ -0,0 +1,163 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Function, Variable + +def cross_entropy2d(input, target, weight=None, size_average=True): + + n, c, h, w = input.size() + nt, ct, ht, wt = target.size() + ''' + # Handle inconsistent size between input and target + if h > ht and w > wt: # upsample labels + target = target.unsequeeze(1) + target = F.upsample(target, size=(h, w), mode='nearest') + target = target.sequeeze(1) + elif h < ht and w < wt: # upsample images + input = F.upsample(input, size=(ht, wt), mode='bilinear') + elif h != ht and w != wt: + raise Exception("Only support upsampling") + ''' + log_p = F.log_softmax(input, dim=1) + log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) + log_p = log_p[target.contiguous().view(-1, 1).repeat(1, c) >= 0] + log_p = log_p.view(-1, c) + + mask = target >= 0 + target = target[mask] + loss = F.nll_loss(log_p, target, ignore_index=250, + weight=weight, size_average=False) + if size_average: + loss /= mask.data.sum().float() + return loss + +def loss_ce_t(input,target): + #input=F.sigmoid(input) + target_bin=Variable(torch.zeros(1,11,target.shape[2],target.shape[3]).cuda().scatter_(1,target,1)) + return F.binary_cross_entropy_with_logits(input,target_bin) + +def dice_loss(input, target): + target_bin=Variable(torch.zeros(target.shape[0],11,target.shape[2],target.shape[3]).cuda().scatter_(1,target,1)) + 
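+    # (added) scatter_ above one-hot encodes the (N,1,H,W) integer label map into
+    # 11 channels; the soft Dice below is computed on the flattened tensors:
+    #   loss = 1 - (2*sum(p*t) + s) / (sum(p) + sum(t) + s),  s = 1 avoids 0/0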
smooth = 1. + iflat = input.view(-1) + tflat = target_bin.view(-1) + intersection = (iflat * tflat).sum() + return 1 - ((2. * intersection + smooth) / + (iflat.sum() + tflat.sum() + smooth)) + +def weighted_loss(input,target,weight,size_average=True): + n,c,h,w=input.size() + target_bin=Variable(torch.zeros(n,c,h,w).cuda()).scatter_(1,target,1) + target_bin=target_bin.transpose(1,2).transpose(2,3).contiguous().view(n*h*w,c).float() + + # NHWC + input=F.softmax(input,dim=1).transpose(1,2).transpose(2,3).contiguous().view(n*h*w,c) + input=input[target_bin>=0] + input=input.view(-1,c) + weight=weight.transpose(1,2).transpose(2,3).contiguous() + weight=weight.view(n*h*w,1).repeat(1,c) + ''' + mask=target>=0 + target=target[mask] + target_bin=np.zeros((n*h*w,c),np.float) + for i,term in enumerate(target): + target_bin[i,int(term)]=1 + target_bin=torch.from_numpy(target_bin).float() + target_bin=Variable(target_bin.cuda()) + ''' + loss=F.binary_cross_entropy(input,target_bin,weight=weight,size_average=False) + if size_average: + loss/=(target_bin>=0).data.sum().float()/c + return loss + +def bce2d_hed(input, target): + """ + Binary Cross Entropy 2-Dimension loss + """ + n, c, h, w = input.size() + # assert(max(target) == 1) + log_p = input.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1) + target_t = target.transpose(1, 2).transpose(2, 3).contiguous().view(1, -1).float().cuda() + target_trans = target_t.clone() + pos_index = (target_t >0) + neg_index = (target_t ==0) + target_trans[pos_index] = 1 + target_trans[neg_index] = 0 + pos_index = pos_index.data.cpu().numpy().astype(bool) + neg_index = neg_index.data.cpu().numpy().astype(bool) + weight = torch.Tensor(log_p.size()).fill_(0) + weight = weight.numpy() + pos_num = pos_index.sum() + neg_num = neg_index.sum() + sum_num = pos_num + neg_num + weight[pos_index] = neg_num*1.0 / sum_num + weight[neg_index] = pos_num*1.0 / sum_num + + weight = torch.from_numpy(weight) + weight = weight.cuda() + loss = F.binary_cross_entropy(log_p, target_t, weight, size_average=True) + return loss + +def bootstrapped_cross_entropy2d(input, target, K, weight=None, size_average=True): + + batch_size = input.size()[0] + + def _bootstrap_xentropy_single(input, target, K, weight=None, size_average=True): + n, c, h, w = input.size() + log_p = F.log_softmax(input, dim=1) + log_p = log_p.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c) + log_p = log_p[target.view(n * h * w, 1).repeat(1, c) >= 0] + log_p = log_p.view(-1, c) + + mask = target >= 0 + target = target[mask] + loss = F.nll_loss(log_p, target, weight=weight, ignore_index=250, + reduce=False, size_average=False) + topk_loss, _ = loss.topk(K) + reduced_topk_loss = topk_loss.sum() / K + + return reduced_topk_loss + + loss = 0.0 + # Bootstrap from each image not entire batch + for i in range(batch_size): + loss += _bootstrap_xentropy_single(input=torch.unsqueeze(input[i], 0), + target=torch.unsqueeze(target[i], 0), + K=K, + weight=weight, + size_average=size_average) + return loss / float(batch_size) + +# another implimentation for dice loss +import torch +from torch.autograd import Function, Variable +class DiceCoeff(Function): + """Dice coeff for individual examples""" + def forward(self, input, target): + self.save_for_backward(input, target) + self.inter = torch.dot(input.view(-1), target.view(-1)) + 0.0001 + self.union = torch.sum(input) + torch.sum(target) + 0.0001 + t = 2 * self.inter.float() / self.union.float() + return t + # This function has only a single output, so it gets only one 
gradient
+    def backward(self, grad_output):
+        input, target = self.saved_variables
+        grad_input = grad_target = None
+        if self.needs_input_grad[0]:
+            # d/dx_j [ 2I/U ] = 2*(t_j*U - I) / U**2,  I = inter, U = union
+            grad_input = grad_output * 2 * (target * self.union - self.inter) \
+                         / (self.union * self.union)
+        if self.needs_input_grad[1]:
+            grad_target = None
+        return grad_input, grad_target
+def dice_coeff(input, target):
+    """Dice coeff for batches"""
+    target_bin=Variable(torch.zeros(1,11,target.shape[2],target.shape[3]).cuda().scatter_(1,target,1).float())
+    if input.is_cuda:
+        s = torch.FloatTensor(1).cuda().zero_()
+    else:
+        s = torch.FloatTensor(1).zero_()
+    for i, c in enumerate(zip(input, target_bin)):
+        s = s + DiceCoeff().forward(c[0], c[1])
+    return s / (i + 1)
diff --git a/metrics.py b/metrics.py
new file mode 100755
index 0000000..6d8acae
--- /dev/null
+++ b/metrics.py
@@ -0,0 +1,36 @@
+# Adapted from score written by wkentaro
+# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
+
+import numpy as np
+
+class runningScore(object):
+    def __init__(self, n_classes):
+        self.n_classes = n_classes
+        self.confusion_matrix = np.zeros((n_classes, n_classes))
+    def _fast_hist(self, label_true, label_pred, n_class):
+        mask = (label_true >= 0) & (label_true < n_class)
+        hist = np.bincount(n_class*label_true[mask].astype(int)+label_pred[mask], minlength=n_class**2).reshape(n_class, n_class)
+        return hist
+    def update(self, label_trues, label_preds):
+        for lt, lp in zip(label_trues, label_preds):
+            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
+    def get_scores(self):
+        hist = self.confusion_matrix
+        acc = np.diag(hist).sum() / hist.sum()
+        acc_cls = np.diag(hist) / hist.sum(axis=1)
+        acc_cls = np.nanmean(acc_cls)
+        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
+        dice=np.divide(np.multiply(iu,2),np.add(iu,1))   # per-class Dice from IoU: D = 2*IoU/(1+IoU)
+        mean_iu = np.nanmean(iu[1:9])
+        mean_dice=(mean_iu*2)/(mean_iu+1)   # Dice corresponding to the mean IoU, not the mean of per-class Dice
+        freq = hist.sum(axis=1) / hist.sum()
+        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
+        cls_iu = dict(zip(range(self.n_classes), iu))
+
+        return {#'Overall Acc: \t': acc,
+                #'Mean Acc : \t': acc_cls,
+                #'FreqW Acc : \t': fwavacc,
+                'Dice : \t': dice,
+                'Mean Dice : \t': mean_dice,}, cls_iu
+    def reset(self):
+        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
diff --git a/models/PAN.py b/models/PAN.py
new file mode 100755
index 0000000..1b96a70
--- /dev/null
+++ b/models/PAN.py
@@ -0,0 +1,176 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as models
+from torch.autograd import Variable
+from models.resnet import resnet50
+
+LAYER_THICK=16
+
+class VGG16(nn.Module):
+    def __init__(self):
+        super(VGG16, self).__init__()
+        self.conv_block1 = nn.Sequential(
+            nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block2 = nn.Sequential(
+            nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block3 = nn.Sequential(
+            nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block4 = nn.Sequential(
+            nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block5 = nn.Sequential(
+            nn.Conv2d(512, 512, 3,
padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True) + def forward(self,x): + conv1 = self.conv_block1(x) + conv2 = self.conv_block2(self.pool(conv1)) + conv3 = self.conv_block3(self.pool(conv2)) + conv4 = self.conv_block4(self.pool(conv3)) + conv5 = self.conv_block5(self.pool(conv4)) + return conv1,conv2,conv3,conv4,conv5 + def init_vgg16_params(self, vgg16, copy_fc8=True): + blocks = [self.conv_block1, + self.conv_block2, + self.conv_block3, + self.conv_block4, + self.conv_block5] + + ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] + features = list(vgg16.features.children()) + for idx, conv_block in enumerate(blocks): + for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + + +class FPN(nn.Module): + def __init__(self): + super(FPN, self).__init__() + #self.C5_conv = nn.Conv2d(in_channels=512, out_channels=LAYER_THICK, kernel_size=1, bias=False) + #self.C4_conv = nn.Conv2d(in_channels=512, out_channels=LAYER_THICK, kernel_size=1, bias=False) + #self.C3_conv = nn.Conv2d(in_channels=256, out_channels=LAYER_THICK, kernel_size=1, bias=False) + #self.C2_conv = nn.Conv2d(in_channels=128, out_channels=LAYER_THICK, kernel_size=1, bias=False) + + self.C5_conv = nn.Conv2d(in_channels=2048, out_channels=LAYER_THICK, kernel_size=1, bias=False) + self.C4_conv = nn.Conv2d(in_channels=1024, out_channels=LAYER_THICK, kernel_size=1, bias=False) + self.C3_conv = nn.Conv2d(in_channels=512, out_channels=LAYER_THICK, kernel_size=1, bias=False) + self.C2_conv = nn.Conv2d(in_channels=256, out_channels=LAYER_THICK, kernel_size=1, bias=False) + + #config = "in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False" + #self.P2_conv = nn.Conv2d(*config) + #self.P3_conv = nn.Conv2d(*config) + #self.P4_conv = nn.Conv2d(*config) + + self.P2_conv = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=1, padding=1, bias=False) + self.P3_conv = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=1, padding=1, bias=False) + self.P4_conv = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=1, padding=1, bias=False) + + + def forward(self, C_vector): + C1, C2, C3, C4, C5 = C_vector + _, _, c2_height, c2_width = C2.size() + _, _, c3_height, c3_width = C3.size() + _, _, c4_height, c4_width = C4.size() + + P5 = self.C5_conv(C5) + + P4 = F.upsample(P5, size=(c4_height, c4_width), mode='bilinear') + self.C4_conv(C4) + P4 = self.P4_conv(P4) + + P3 = F.upsample(P4, size=(c3_height, c3_width), mode='bilinear') + self.C3_conv(C3) + P3 = self.P3_conv(P3) + + P2 = F.upsample(P3, size=(c2_height, c2_width), mode='bilinear') + self.C2_conv(C2) + P2 = self.P2_conv(P2) + + return P2, P3, P4, P5 + +class generateN(nn.Module): + def __init__(self): + super(generateN, self).__init__() + + config = "in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, bias=False" + self.N2_conv = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=2, padding=1, bias=False) + self.N3_conv = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=2, padding=1, bias=False) + 
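+        # (added) Bottom-up path augmentation as in PANet: each stride-2 conv above
+        # downsamples the previous N level, which is then fused with the matching P
+        # level by addition and smoothed by the stride-1 convs below, e.g.
+        # N3 = N2_conv2(P3 + N2_conv(N2)) in forward().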
self.N4_conv = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=2, padding=1, bias=False) + + + config = "in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False" + self.N2_conv2 = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=1, padding=1, bias=False) + self.N3_conv2 = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=1, padding=1, bias=False) + self.N4_conv2 = nn.Conv2d(in_channels=LAYER_THICK, out_channels=LAYER_THICK, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, P_vec): + [P2, P3, P4, P5] = P_vec + N2 = P2 + + N3 = self.N2_conv2(P3 + self.N2_conv(N2)) + N4 = self.N3_conv2(P4 + self.N3_conv(N3)) + N5 = self.N4_conv2(P5 + self.N4_conv(N4)) + + return N2, N3, N4, N5 + + +class PAN(nn.Module): + def __init__(self, training=False): + super(PAN, self).__init__() + #self.vgg = VGG16() + #vgg16=models.vgg16(pretrained=True) + #self.vgg.init_vgg16_params(vgg16) + self.res=resnet50() + self.fpn = FPN() + self.generateN = generateN() + self.training = training + + def forward(self, x): + """ + x : input image + """ + #C1, C2, C3, C4, C5 = self.vgg(x) + C1, C2, C3, C4, C5 = self.res(x) + P2, P3, P4, P5 = self.fpn([C1, C2, C3, C4, C5]) + N2, N3, N4, N5 = self.generateN([P2, P3, P4, P5]) + return N2,N3,N4,N5 + +class PAN_seg(nn.Module): + def __init__(self,n_classes=9): + super(PAN_seg,self).__init__() + self.n_classes=n_classes + self.PAN=PAN() + self.deconv2=nn.ConvTranspose2d(LAYER_THICK,LAYER_THICK,kernel_size=2,stride=2) + self.deconv3=nn.ConvTranspose2d(LAYER_THICK,LAYER_THICK,kernel_size=4,stride=4) + self.deconv4=nn.ConvTranspose2d(LAYER_THICK,LAYER_THICK,kernel_size=8,stride=8) + self.deconv5=nn.ConvTranspose2d(LAYER_THICK,LAYER_THICK,kernel_size=16,stride=16) + self.score=nn.Sequential( + nn.Conv2d(4*LAYER_THICK,self.n_classes,1), + nn.Dropout(0.5,) + ) + def forward(self,x): + conv2,conv3,conv4,conv5=self.PAN(x) + deconv2=self.deconv2(conv2) + deconv3=self.deconv3(conv3) + deconv4=self.deconv4(conv4) + deconv5=self.deconv5(conv5) + cat=torch.cat([deconv2,deconv3,deconv4,deconv5],1) + score=self.score(cat) + return score + +if __name__ == '__main__': + x=torch.Tensor(4,3,16,16) + x=Variable(x) + print(x.shape) + model=PAN_seg() + y=model(x) + print(y.shape) + diff --git a/models/VGG16.py b/models/VGG16.py new file mode 100755 index 0000000..8015102 --- /dev/null +++ b/models/VGG16.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torchvision.models as models + +class VGG16(nn.Module): + def __init__(self): + super(VGG16, self).__init__() + self.conv_block1 = nn.Sequential( + nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block2 = nn.Sequential( + nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block3 = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block4 = nn.Sequential( + nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block5 = nn.Sequential( + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + 
nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),)
+        self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True)
+    def forward(self,x):
+        conv1 = self.conv_block1(x)
+        conv2 = self.conv_block2(self.pool(conv1))
+        conv3 = self.conv_block3(self.pool(conv2))
+        conv4 = self.conv_block4(self.pool(conv3))
+        conv5 = self.conv_block5(self.pool(conv4))
+        return conv1,conv2,conv3,conv4,conv5
+    def init_vgg16_params(self, vgg16, copy_fc8=True):
+        blocks = [self.conv_block1,
+                  self.conv_block2,
+                  self.conv_block3,
+                  self.conv_block4,
+                  self.conv_block5]
+
+        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
+        features = list(vgg16.features.children())
+        for idx, conv_block in enumerate(blocks):
+            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
+                if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
+                    assert l1.weight.size() == l2.weight.size()
+                    assert l1.bias.size() == l2.bias.size()
+                    l2.weight.data = l1.weight.data
+                    l2.bias.data = l1.bias.data
+
+
+# Note (added): despite its name, this class applies no dilated convolutions yet;
+# it is currently an exact copy of VGG16 above.
+class VGG16_dilated(nn.Module):
+    def __init__(self):
+        super(VGG16_dilated, self).__init__()
+        self.conv_block1 = nn.Sequential(
+            nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block2 = nn.Sequential(
+            nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block3 = nn.Sequential(
+            nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block4 = nn.Sequential(
+            nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv_block5 = nn.Sequential(
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),
+            nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),)
+        self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True)
+    def forward(self,x):
+        conv1 = self.conv_block1(x)
+        conv2 = self.conv_block2(self.pool(conv1))
+        conv3 = self.conv_block3(self.pool(conv2))
+        conv4 = self.conv_block4(self.pool(conv3))
+        conv5 = self.conv_block5(self.pool(conv4))
+        return conv1,conv2,conv3,conv4,conv5
+    def init_vgg16_params(self, vgg16, copy_fc8=True):
+        blocks = [self.conv_block1,
+                  self.conv_block2,
+                  self.conv_block3,
+                  self.conv_block4,
+                  self.conv_block5]
+
+        ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
+        features = list(vgg16.features.children())
+        for idx, conv_block in enumerate(blocks):
+            for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block):
+                if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
+                    assert l1.weight.size() == l2.weight.size()
+                    assert l1.bias.size() == l2.bias.size()
+                    l2.weight.data = l1.weight.data
+                    l2.bias.data = l1.bias.data
+
+
+class VGG16_rcf(nn.Module):
+    def __init__(self):
+        super(VGG16_rcf,self).__init__()
+        self.conv1_1=nn.Sequential(nn.Conv2d(3  , 64 , 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv1_2=nn.Sequential(nn.Conv2d(64 , 64 , 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv2_1=nn.Sequential(nn.Conv2d(64 , 128, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv2_2=nn.Sequential(nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),)
+        self.conv3_1=nn.Sequential(nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True),)
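+        # (added) Unlike VGG16 above, every conv layer is wrapped in its own
+        # nn.Sequential so forward() can expose all 13 intermediate feature maps,
+        # in the style of RCF (Richer Convolutional Features) side outputs.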
self.conv3_2=nn.Sequential(nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) + self.conv3_3=nn.Sequential(nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) + self.conv4_1=nn.Sequential(nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv4_2=nn.Sequential(nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv4_3=nn.Sequential(nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv5_1=nn.Sequential(nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv5_2=nn.Sequential(nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv5_3=nn.Sequential(nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True) + def forward(self, x): + conv1_1=self.conv1_1(x) + conv1_2=self.conv1_2(conv1_1) + conv2_1=self.conv2_1(self.pool(conv1_2)) + conv2_2=self.conv2_2(conv2_1) + conv3_1=self.conv3_1(self.pool(conv2_2)) + conv3_2=self.conv3_2(conv3_1) + conv3_3=self.conv3_3(conv3_2) + conv4_1=self.conv4_1(self.pool(conv3_3)) + conv4_2=self.conv4_2(conv4_1) + conv4_3=self.conv4_3(conv4_2) + conv5_1=self.conv5_1(self.pool(conv4_3)) + conv5_2=self.conv5_2(conv5_1) + conv5_3=self.conv5_3(conv5_2) + return conv1_1,conv1_2,conv2_1,conv2_2,conv3_1,conv3_2,conv3_3,conv4_1,conv4_2,conv4_3,conv5_1,conv5_2,conv5_3 + def init_vgg16_params(self,vgg16=models.vgg16(pretrained=True),copy_fc8=True): + convs=[ self.conv1_1,self.conv1_2, + self.conv2_1,self.conv2_2, + self.conv3_1,self.conv3_2,self.conv3_3, + self.conv4_1,self.conv4_2,self.conv4_3, + self.conv5_1,self.conv5_2,self.conv5_3] + features=list(vgg16.features.children()) + ranges=[0,2,5,7,10,12,14,17,19,21,24,26,28] + for idx,conv in enumerate(convs): + l1=features[ranges[idx]] + l2=conv[0] + if isinstance(l1,nn.Conv2d) and isinstance(l2,nn.Conv2d): + assert l1.weight.size()==l2.weight.size() + assert l1.bias.size()==l2.bias.size() + l2.weight.data=l1.weight.data + l2.bias.data=l1.bias.data + + diff --git a/models/__init__.py b/models/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/models/densenet.py b/models/densenet.py new file mode 100755 index 0000000..af3be04 --- /dev/null +++ b/models/densenet.py @@ -0,0 +1,173 @@ +import torch + +import torch.nn as nn +import torch.optim as optim + +import torch.nn.functional as F +from torch.autograd import Variable + +import torchvision.datasets as dset +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +import torchvision.models as models + +import sys +import math + +class Bottleneck(nn.Module): + def __init__(self, nChannels, growthRate): + super(Bottleneck, self).__init__() + interChannels = 4*growthRate + self.bn1 = nn.BatchNorm2d(nChannels) + self.conv1 = nn.Conv2d(nChannels, interChannels, kernel_size=1, + bias=False) + self.bn2 = nn.BatchNorm2d(interChannels) + self.conv2 = nn.Conv2d(interChannels, growthRate, kernel_size=3, + padding=1, bias=False) + def forward(self, x): + out = self.conv1(F.relu(self.bn1(x))) + out = self.conv2(F.relu(self.bn2(out))) + out = torch.cat((x, out), 1) + return out + +class SingleLayer(nn.Module): + def __init__(self, nChannels, growthRate): + super(SingleLayer, self).__init__() + self.bn1 = nn.BatchNorm2d(nChannels) + self.conv1 = nn.Conv2d(nChannels, growthRate, kernel_size=3, + padding=1, bias=False) + def forward(self, x): + out = self.conv1(F.relu(self.bn1(x))) + out = torch.cat((x, out), 1) + return out + +class Transition(nn.Module): + def __init__(self, nChannels, nOutChannels): + 
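+        # (added) Unlike the standard DenseNet transition (1x1 conv + 2x2 avg-pool),
+        # this variant keeps full spatial resolution for dense prediction: a 3x3
+        # conv with dilation=2 widens the receptive field instead, and the pooling
+        # is disabled (see the commented-out F.avg_pool2d in forward()).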
super(Transition, self).__init__() + self.bn1 = nn.BatchNorm2d(nChannels) + self.conv1 = nn.Conv2d(nChannels, nOutChannels, kernel_size=3, dilation=2, padding=2, + bias=False) + def forward(self, x): + out = self.conv1(F.relu(self.bn1(x))) + #out = F.avg_pool2d(out, 2) + return out + + +class DenseNet(nn.Module): + def __init__(self, growthRate=[16,16,16,16], nDenseBlocks=[4,4,4,4], reduction=[0.7,0.7,0.7,0.7], n_classes=11, bottleneck=False): + super(DenseNet, self).__init__() + + nChannels = 2*growthRate[0] + self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1, + bias=False) + # dense 1 + self.dense1 = self._make_dense(nChannels, growthRate[0], nDenseBlocks[0], bottleneck) + nChannels += nDenseBlocks[0]*growthRate[0] + nOutChannels = int(math.floor(nChannels*reduction[0])) + self.trans1 = Transition(nChannels, nOutChannels) + # dense 2 + nChannels = nOutChannels + self.dense2 = self._make_dense(nChannels, growthRate[1], nDenseBlocks[1], bottleneck) + nChannels += nDenseBlocks[1]*growthRate[1] + nOutChannels = int(math.floor(nChannels*reduction[1])) + self.trans2 = Transition(nChannels, nOutChannels) + # dense 3 + nChannels = nOutChannels + self.dense3 = self._make_dense(nChannels, growthRate[2], nDenseBlocks[2], bottleneck) + nChannels += nDenseBlocks[2]*growthRate[2] + nOutChannels = int(math.floor(nChannels*reduction[2])) + self.trans3 = Transition(nChannels, nOutChannels) + # dense 4 + nChannels = nOutChannels + self.dense4 = self._make_dense(nChannels, growthRate[3], nDenseBlocks[3], bottleneck) + nChannels += nDenseBlocks[3]*growthRate[3] + nOutChannels = int(math.floor(nChannels*reduction[3])) + self.trans4 = Transition(nChannels, nOutChannels) + self.score=nn.Sequential( + nn.BatchNorm2d(nOutChannels), + nn.Conv2d(nOutChannels,n_classes,1), + #nn.Dropout(0.5), + ) + + #nChannels = nOutChannels + #self.dense3 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) + #nChannels += nDenseBlocks*growthRate + + #self.bn1 = nn.BatchNorm2d(nChannels) + #self.fc = nn.Linear(nChannels, nClasses) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
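+                # (added) He/Kaiming initialization for ReLU networks:
+                # std = sqrt(2/n), n = k_h * k_w * out_channels (fan-out)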
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + + def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): + layers = [] + for i in range(int(nDenseBlocks)): + if bottleneck: + layers.append(Bottleneck(nChannels, growthRate)) + else: + layers.append(SingleLayer(nChannels, growthRate)) + nChannels += growthRate + return nn.Sequential(*layers) + + def forward(self, x): + pre_conv=self.conv1(x) + dense1=self.trans1(self.dense1(pre_conv)) + dense2=self.trans2(self.dense2(dense1)) + dense3=self.trans3(self.dense3(dense2)) + dense4=self.trans4(self.dense4(dense3)) + score=self.score(dense4) + return score + +class DenseNetSeg(nn.Module): + def __init__(self, growthRate=12, depth=16, reduction=0.5, nClasses=11, bottleneck=False): + super(DenseNetSeg, self).__init__() + nDenseBlocks = depth + if bottleneck: + nDenseBlocks //= 2 + nChannels = 2*growthRate + self.conv1 = nn.Conv2d(3, nChannels, kernel_size=3, padding=1,bias=False) + self.dense1 = self._make_dense(nChannels, growthRate, nDenseBlocks, bottleneck) + self.score=nn.Sequential( + nn.Conv2d(nChannels+growthRate*nDenseBlocks,nClasses,1), + #nn.Dropout(0.5), + ) + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.bias.data.zero_() + def _make_dense(self, nChannels, growthRate, nDenseBlocks, bottleneck): + layers = [] + for i in range(int(nDenseBlocks)): + if bottleneck: + layers.append(Bottleneck(nChannels, growthRate)) + else: + layers.append(SingleLayer(nChannels, growthRate)) + nChannels += growthRate + return nn.Sequential(*layers) + def forward(self,x): + pre_conv=self.conv1(x) + dense=self.dense1(pre_conv) + score=self.score(dense) + return score + + + + +if __name__=='__main__': + x=Variable(torch.Tensor(4,3,256,256)) + model=DenseNetSeg(growthRate=12, depth=16, reduction=0.5, nClasses=11, bottleneck=False) + y=model(x) + print(y.shape) diff --git a/models/fcn_xu.py b/models/fcn_xu.py new file mode 100755 index 0000000..5e8b5e6 --- /dev/null +++ b/models/fcn_xu.py @@ -0,0 +1,426 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +import torchvision.models as models + +class fcn_xu(nn.Module): + def __init__(self,n_classes=9): + super(fcn_xu, self).__init__() + self.n_classes = n_classes + + self.conv_block1 = nn.Sequential( + nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block2 = nn.Sequential( + nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block3 = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block4 = nn.Sequential( + nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block5 = nn.Sequential( + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + + 
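+        # (added) Hypercolumn-style head: each VGG stage is reduced to 64 channels
+        # by the conv*_16 layers below, upsampled back to input resolution by the
+        # transposed convs, concatenated, and scored by a 1x1 conv. Note conv1_16
+        # is computed in forward() but not concatenated; only stages 2-5 are fused.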
self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.conv1_16=nn.Conv2d(64, 64, 3, padding=1) + self.conv2_16=nn.Conv2d(128, 64, 3, padding=1) + self.conv3_16=nn.Conv2d(256, 64, 3, padding=1) + self.conv4_16=nn.Conv2d(512, 64, 3, padding=1) + self.conv5_16=nn.Conv2d(512, 64, 3, padding=1) + + self.up_conv2_16 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) + self.up_conv3_16 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=4) + self.up_conv4_16 = nn.ConvTranspose2d(64, 64, kernel_size=8, stride=8) + self.up_conv5_16 = nn.ConvTranspose2d(64, 64, kernel_size=16, stride=16) + + self.score=nn.Sequential( + nn.Conv2d(4*64,self.n_classes,1), + #nn.Dropout(0.5), + ) + + def forward(self, x): + conv1 = self.conv_block1(x) + conv2 = self.conv_block2(self.pool(conv1)) + conv3 = self.conv_block3(self.pool(conv2)) + conv4 = self.conv_block4(self.pool(conv3)) + conv5 = self.conv_block5(self.pool(conv4)) + + conv1_16=self.conv1_16(conv1) + up_conv2_16=self.up_conv2_16(self.conv2_16(conv2)) + up_conv3_16=self.up_conv3_16(self.conv3_16(conv3)) + up_conv4_16=self.up_conv4_16(self.conv4_16(conv4)) + up_conv5_16=self.up_conv5_16(self.conv5_16(conv5)) + + concat_1_to_5=torch.cat([up_conv2_16,up_conv3_16,up_conv4_16,up_conv5_16], 1) + score=self.score(concat_1_to_5) + return score + + def init_vgg16_params(self, vgg16, copy_fc8=True): + blocks = [self.conv_block1, + self.conv_block2, + self.conv_block3, + self.conv_block4, + self.conv_block5] + + ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] + features = list(vgg16.features.children()) + for idx, conv_block in enumerate(blocks): + for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + + +class fcn_xu_19(nn.Module): + def __init__(self,n_classes=9): + super(fcn_xu_19, self).__init__() + self.n_classes = n_classes + + self.conv_block1 = nn.Sequential( + nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(64, 64, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block2 = nn.Sequential( + nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block3 = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block4 = nn.Sequential( + nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block5 = nn.Sequential( + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + + self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.conv1_16=nn.Conv2d(64, 64, 3, padding=1) + self.conv2_16=nn.Conv2d(128, 64, 3, padding=1) + self.conv3_16=nn.Conv2d(256, 64, 3, padding=1) + self.conv4_16=nn.Conv2d(512, 64, 3, padding=1) + self.conv5_16=nn.Conv2d(512, 64, 3, padding=1) + + self.up_conv2_16 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) + self.up_conv3_16 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=4) + self.up_conv4_16 = nn.ConvTranspose2d(64, 64, kernel_size=8, stride=8) + self.up_conv5_16 = nn.ConvTranspose2d(64, 64, kernel_size=16, stride=16) + + 
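+        # (added) Each ConvTranspose2d above uses kernel_size == stride equal to the
+        # cumulative pooling factor of its stage, so every branch is restored exactly
+        # to the input resolution before concatenation. fcn_xu_19 mirrors fcn_xu;
+        # its init_vgg16_params additionally copies the conv biases.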
self.score=nn.Sequential( + nn.Conv2d(4*64,self.n_classes,1), + #nn.Dropout(0.5), + ) + + def forward(self, x): + conv1 = self.conv_block1(x) + conv2 = self.conv_block2(self.pool(conv1)) + conv3 = self.conv_block3(self.pool(conv2)) + conv4 = self.conv_block4(self.pool(conv3)) + conv5 = self.conv_block5(self.pool(conv4)) + + conv1_16=self.conv1_16(conv1) + up_conv2_16=self.up_conv2_16(self.conv2_16(conv2)) + up_conv3_16=self.up_conv3_16(self.conv3_16(conv3)) + up_conv4_16=self.up_conv4_16(self.conv4_16(conv4)) + up_conv5_16=self.up_conv5_16(self.conv5_16(conv5)) + + concat_1_to_5=torch.cat([up_conv2_16,up_conv3_16,up_conv4_16,up_conv5_16], 1) + score=self.score(concat_1_to_5) + return score + + def init_vgg16_params(self, vgg16, copy_fc8=True): + blocks = [self.conv_block1, + self.conv_block2, + self.conv_block3, + self.conv_block4, + self.conv_block5] + + ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] + features = list(vgg16.features.children()) + for idx, conv_block in enumerate(blocks): + for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + + +class fcn_nopool(nn.Module): + def __init__(self,n_classes=9): + super(fcn_nopool, self).__init__() + self.n_classes = n_classes + + self.pre_conv=nn.Sequential(nn.Conv2d(3,3,1),nn.ReLU(inplace=True),) + + self.conv_block1 = nn.Sequential( + nn.Conv2d(3, 64, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(64, 128, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block2 = nn.Sequential( + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block3 = nn.Sequential( + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block4 = nn.Sequential( + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block5 = nn.Sequential( + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, padding=1),nn.ReLU(inplace=True),) + + self.conv1_16=nn.Conv2d(64, 64, 3, padding=1) + self.conv2_16=nn.Conv2d(128, 64, 3, padding=1) + self.conv3_16=nn.Conv2d(256, 64, 3, padding=1) + self.conv4_16=nn.Conv2d(512, 64, 3, padding=1) + self.conv5_16=nn.Conv2d(512, 64, 3, padding=1) + + self.score=nn.Sequential( + nn.Conv2d(4*128,self.n_classes,1), + #nn.Dropout(0.5), + ) + + def forward(self, x): + #x=self.pre_conv(x) + conv1 = self.conv_block1(x) + conv2 = self.conv_block2(conv1) + conv3 = self.conv_block3(conv2) + conv4 = self.conv_block4(conv3) + #conv5 = self.conv_block5(conv4) + + #conv1_16=self.conv1_16(conv1) + #conv2_16=self.conv2_16(conv2) + #conv3_16=self.conv3_16(conv3) + #conv4_16=self.conv4_16(conv4) + #conv5_16=self.conv5_16(conv5) + + concat_1_to_4=torch.cat([conv1,conv2,conv3,conv4], 1) + score=self.score(concat_1_to_4) + return score + + def init_vgg16_params(self, vgg16, copy_fc8=True): + blocks = [self.conv_block1, + self.conv_block2, + self.conv_block3, + self.conv_block4, + self.conv_block5] + + ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] + features = list(vgg16.features.children()) 
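# Standalone sketch (not part of this file): the ranges used by
# init_vgg16_params slice vgg16.features so that each slice holds exactly one
# VGG16 conv block (its Conv2d/ReLU pairs, stopping before the MaxPool2d),
# which is what lets it be zipped against conv_block1..conv_block5.
import torch.nn as nn
import torchvision.models as models

feats = list(models.vgg16().features.children())   # random weights suffice
ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]]
n_convs = [sum(isinstance(l, nn.Conv2d) for l in feats[a:b]) for a, b in ranges]
assert n_convs == [2, 2, 3, 3, 3]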
+ for idx, conv_block in enumerate(blocks): + for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + + + +class fcn_xu_dilated(nn.Module): + def __init__(self,n_classes=9): + super(fcn_xu_dilated, self).__init__() + self.n_classes = n_classes + + self.conv_block1 = nn.Sequential( + nn.Conv2d(3, 64, 3, dilation=1, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(64, 64, 3, dilation=2, padding=2),nn.ReLU(inplace=True),) + self.conv_block2 = nn.Sequential( + nn.Conv2d(64, 128, 3, dilation=1, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, dilation=2, padding=2),nn.ReLU(inplace=True),) + self.conv_block3 = nn.Sequential( + nn.Conv2d(128, 256, 3, dilation=1, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, dilation=2, padding=2),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, dilation=3, padding=3),nn.ReLU(inplace=True),) + self.conv_block4 = nn.Sequential( + nn.Conv2d(256, 512, 3, dilation=1, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, dilation=2, padding=2),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, dilation=3, padding=3),nn.ReLU(inplace=True),) + self.conv_block5 = nn.Sequential( + nn.Conv2d(512, 512, 3, dilation=1, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, dilation=2, padding=2),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, dilation=3, padding=3),nn.ReLU(inplace=True),) + + self.conv1_16=nn.Conv2d(64, 16, 3, padding=1) + self.conv2_16=nn.Conv2d(128, 16, 3, padding=1) + self.conv3_16=nn.Conv2d(256, 16, 3, padding=1) + self.conv4_16=nn.Conv2d(512, 16, 3, padding=1) + self.conv5_16=nn.Conv2d(512, 16, 3, padding=1) + + self.score=nn.Sequential( + nn.Conv2d(5*16,self.n_classes,1), + nn.Dropout(0.5), + ) + + def forward(self, x): + conv1 = self.conv_block1(x) + conv2 = self.conv_block2(conv1) + conv3 = self.conv_block3(conv2) + conv4 = self.conv_block4(conv3) + conv5 = self.conv_block5(conv4) + + conv1_16=self.conv1_16(conv1) + conv2_16=self.conv2_16(conv2) + conv3_16=self.conv3_16(conv3) + conv4_16=self.conv4_16(conv4) + conv5_16=self.conv5_16(conv5) + + concat_1_to_4=torch.cat([conv1_16,conv2_16,conv3_16,conv4_16,conv5_16], 1) + score=self.score(concat_1_to_4) + return score + + def init_vgg16_params(self, vgg16, copy_fc8=True): + blocks = [self.conv_block1, + self.conv_block2, + self.conv_block3, + self.conv_block4, + self.conv_block5] + + ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] + features = list(vgg16.features.children()) + + for idx, conv_block in enumerate(blocks): + for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + + +class fcn_mul(nn.Module): + def __init__(self,in_channels=3,n_mod=3,n_feature=16,n_classes=11): + super(fcn_mul,self).__init__() + self.in_channels=in_channels + self.n_mod=n_mod + self.n_feature=n_feature + self.n_classes=n_classes + + self.conv_block1 = nn.Sequential( + nn.Conv2d(self.in_channels, 64, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(64 , 64 , 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block2 = nn.Sequential( + nn.Conv2d(64 , 128, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(128, 128, 3, 
padding=1),nn.ReLU(inplace=True),) + self.conv_block3 = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(256, 256, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block4 = nn.Sequential( + nn.Conv2d(256, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + self.conv_block5 = nn.Sequential( + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True), + nn.Conv2d(512, 512, 3, padding=1),nn.ReLU(inplace=True),) + + self.pool=nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.conv1_c=nn.Conv2d(64 *self.n_mod,self.n_feature,3,padding=1) + self.conv2_c=nn.Conv2d(128*self.n_mod,self.n_feature,3,padding=1) + self.conv3_c=nn.Conv2d(256*self.n_mod,self.n_feature,3,padding=1) + self.conv4_c=nn.Conv2d(512*self.n_mod,self.n_feature,3,padding=1) + self.conv5_c=nn.Conv2d(512*self.n_mod,self.n_feature,3,padding=1) + + self.deconv2=nn.ConvTranspose2d(self.n_feature,self.n_feature,kernel_size=2 ,stride=2 ) + self.deconv3=nn.ConvTranspose2d(self.n_feature,self.n_feature,kernel_size=4 ,stride=4 ) + self.deconv4=nn.ConvTranspose2d(self.n_feature,self.n_feature,kernel_size=8 ,stride=8 ) + self.deconv5=nn.ConvTranspose2d(self.n_feature,self.n_feature,kernel_size=16,stride=16) + + #self.dilation=[1,3,5,8,16] + #self.atrous1=nn.Sequential(nn.Conv2d(4*16,4*16,kernel_size=3,dilation=self.dilation[0],padding=self.dilation[0]),) + #self.atrous2=nn.Sequential(nn.Conv2d(4*16,4*16,kernel_size=3,dilation=self.dilation[1],padding=self.dilation[1]),) + #self.atrous3=nn.Sequential(nn.Conv2d(4*16,4*16,kernel_size=3,dilation=self.dilation[2],padding=self.dilation[2]),) + #self.atrous4=nn.Sequential(nn.Conv2d(4*16,4*16,kernel_size=3,dilation=self.dilation[3],padding=self.dilation[3]),) + #self.atrous5=nn.Sequential(nn.Conv2d(4*16,4*16,kernel_size=3,dilation=self.dilation[4],padding=self.dilation[4]),) + + self.score=nn.Sequential( + nn.Conv2d(4*self.n_feature,self.n_classes,1), + #nn.Dropout(0.5), + ) + + def forward(self,T1,IR,T2): + T1_conv1=self.conv_block1(T1) + T1_conv2=self.conv_block2(self.pool(T1_conv1)) + T1_conv3=self.conv_block3(self.pool(T1_conv2)) + T1_conv4=self.conv_block4(self.pool(T1_conv3)) + #T1_conv5=self.conv_block5(self.pool(T1_conv4)) + + IR_conv1=self.conv_block1(IR) + IR_conv2=self.conv_block2(self.pool(IR_conv1)) + IR_conv3=self.conv_block3(self.pool(IR_conv2)) + IR_conv4=self.conv_block4(self.pool(IR_conv3)) + #IR_conv5=self.conv_block5(self.pool(IR_conv4)) + + T2_conv1=self.conv_block1(T2) + T2_conv2=self.conv_block2(self.pool(T2_conv1)) + T2_conv3=self.conv_block3(self.pool(T2_conv2)) + T2_conv4=self.conv_block4(self.pool(T2_conv3)) + #T2_conv5=self.conv_block5(self.pool(T2_conv4)) + + conv1_c=self.conv1_c(torch.cat([T1_conv1,IR_conv1,T2_conv1],1)) + conv2_c=self.conv2_c(torch.cat([T1_conv2,IR_conv2,T2_conv2],1)) + conv3_c=self.conv3_c(torch.cat([T1_conv3,IR_conv3,T2_conv3],1)) + conv4_c=self.conv4_c(torch.cat([T1_conv4,IR_conv4,T2_conv4],1)) + #conv5_c=self.conv5_c(torch.cat([T1_conv5,IR_conv5,T2_conv5],1)) + + deconv2=self.deconv2(conv2_c) + deconv3=self.deconv3(conv3_c) + deconv4=self.deconv4(conv4_c) + #deconv5=self.deconv5(conv5_c) + + cat=torch.cat([conv1_c,deconv2,deconv3,deconv4],1) + #atrous1=self.atrous1(cat) + #atrous2=self.atrous2(cat) + #atrous3=self.atrous3(cat) + #atrous4=self.atrous4(cat) + #atrous5=self.atrous5(cat) + 
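# Standalone sketch (not part of this file): the fusion pattern of fcn_mul.
# The T1/IR/T2 streams share one VGG encoder; at each stage their features are
# concatenated along channels and squeezed to n_feature maps by conv*_c, so the
# score head always sees 4*n_feature channels regardless of n_mod.
import torch
import torch.nn as nn

n_mod, n_feature = 3, 16
fuse4 = nn.Conv2d(512 * n_mod, n_feature, 3, padding=1)   # mirrors conv4_c
out = fuse4(torch.randn(1, 512 * n_mod, 30, 30))
assert out.shape == (1, n_feature, 30, 30)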
#aspp=torch.cat([atrous1,atrous2,atrous3,atrous4,atrous5],1) + + score=self.score(cat) + return score + + def init_vgg16_params(self, vgg16, copy_fc8=True): + blocks = [self.conv_block1, + self.conv_block2, + self.conv_block3, + self.conv_block4, + self.conv_block5] + + ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] + features = list(vgg16.features.children()) + + for idx, conv_block in enumerate(blocks): + for l1, l2 in zip(features[ranges[idx][0]:ranges[idx][1]], conv_block): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data + + +if __name__=='__main__': + x,y,z=torch.Tensor(4,3,256,256),torch.Tensor(4,3,256,256),torch.Tensor(4,3,256,256) + x,y,z=Variable(x),Variable(y),Variable(z) + print(x.shape) + model=fcn_mul(n_classes=11) + vgg16=models.vgg16(pretrained=True) + model.init_vgg16_params(vgg16) + r=model(x,y,z) + print(r.shape) + diff --git a/models/resnet.py b/models/resnet.py new file mode 100755 index 0000000..04da62a --- /dev/null +++ b/models/resnet.py @@ -0,0 +1,185 @@ +import torch +import torch.nn as nn +import math +from torch.autograd import Variable +import torch.utils.model_zoo as model_zoo + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + +#--modified from +# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if 
isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + #x = self.maxpool(x) + C1 = x + x = self.layer1(x) + C2 = x + x = self.layer2(x) + C3 = x + x = self.layer3(x) + C4 = x + x = self.layer4(x) + C5 = x + return C1, C2, C3, C4, C5 + + +def resnet50(pretrained=True): + model = ResNet(Bottleneck, [3, 4, 6, 3]) + if pretrained==True: + state_dict=model_zoo.load_url(model_urls['resnet50']) + del state_dict['fc.weight'] + del state_dict['fc.bias'] + model.load_state_dict(state_dict) + return model + +def resnet101(pretrained=True): + model = ResNet(Bottleneck, [3, 4, 23, 3]) + if pretrained==True: + state_dict=model_zoo.load_url(model_urls['resnet101']) + del state_dict['fc.weight'] + del state_dict['fc.bias'] + model.load_state_dict(state_dict) + return model + +def resnet152(pretrained=True): + model = ResNet(Bottleneck, [3, 8, 36, 3]) + if pretrained==True: + state_dict=model_zoo.load_url(model_urls['resnet152']) + del state_dict['fc.weight'] + del state_dict['fc.bias'] + model.load_state_dict(state_dict) + return model + +class FCN_res(nn.Module): + def __init__(self,n_classes=11,pretrained=True): + super(FCN_res,self).__init__() + self.n_classes=n_classes + self.res=resnet152(pretrained=True) + self.conv1_16=nn.Conv2d(64, 64, 3, padding=1) + self.conv2_16=nn.Conv2d(256, 64, 3, padding=1) + self.conv3_16=nn.Conv2d(512, 64, 3, padding=1) + self.conv4_16=nn.Conv2d(1024, 64, 3, padding=1) + self.conv5_16=nn.Conv2d(2048, 64, 3, padding=1) + + self.up_conv1_16 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) + self.up_conv2_16 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2) + self.up_conv3_16 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=4) + self.up_conv4_16 = nn.ConvTranspose2d(64, 64, kernel_size=8, stride=8) + self.up_conv5_16 = nn.ConvTranspose2d(64, 64, kernel_size=16, stride=16) + + self.score=nn.Sequential( + nn.Conv2d(5*64,self.n_classes,1), + #nn.Dropout(0.5), + ) + def forward(self,x): + C1,C2,C3,C4,C5=self.res(x) + + up_conv1_16=self.up_conv1_16(self.conv1_16(C1)) + up_conv2_16=self.up_conv2_16(self.conv2_16(C2)) + up_conv3_16=self.up_conv3_16(self.conv3_16(C3)) + up_conv4_16=self.up_conv4_16(self.conv4_16(C4)) + up_conv5_16=self.up_conv5_16(self.conv5_16(C5)) + + concat_1_to_5=torch.cat([up_conv1_16,up_conv2_16,up_conv3_16,up_conv4_16,up_conv5_16], 1) + score=self.score(concat_1_to_5) + return score + +if __name__=='__main__': + model=FCN_res() + x=Variable(torch.zeros([1,3,48,48]).float()) + print(x.shape) + y=model(x) + print(y.shape) diff --git a/models/segnet.py b/models/segnet.py new file mode 100755 index 0000000..e46b761 --- /dev/null +++ b/models/segnet.py @@ -0,0 +1,218 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from torch.autograd import 
Variable + +class conv2DBatchNorm(nn.Module): + def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, with_bn=True): + super(conv2DBatchNorm, self).__init__() + + if dilation > 1: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=dilation) + + else: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=1) + + + if with_bn: + self.cb_unit = nn.Sequential(conv_mod, + nn.BatchNorm2d(int(n_filters)),) + else: + self.cb_unit = nn.Sequential(conv_mod,) + + def forward(self, inputs): + outputs = self.cb_unit(inputs) + return outputs + + +class deconv2DBatchNorm(nn.Module): + def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): + super(deconv2DBatchNorm, self).__init__() + + self.dcb_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias), + nn.BatchNorm2d(int(n_filters)), + ) + + def forward(self, inputs): + outputs = self.dcb_unit(inputs) + return outputs + + +class conv2DBatchNormRelu(nn.Module): + def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1, with_bn=True): + super(conv2DBatchNormRelu, self).__init__() + + if dilation > 1: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=dilation) + + else: + conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias, dilation=1) + + if with_bn: + self.cbr_unit = nn.Sequential(conv_mod, + nn.BatchNorm2d(int(n_filters)), + nn.ReLU(inplace=True),) + else: + self.cbr_unit = nn.Sequential(conv_mod, + nn.ReLU(inplace=True),) + + def forward(self, inputs): + outputs = self.cbr_unit(inputs) + return outputs + + +class deconv2DBatchNormRelu(nn.Module): + def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True): + super(deconv2DBatchNormRelu, self).__init__() + + self.dcbr_unit = nn.Sequential(nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size, + padding=padding, stride=stride, bias=bias), + nn.BatchNorm2d(int(n_filters)), + nn.ReLU(inplace=True),) + + def forward(self, inputs): + outputs = self.dcbr_unit(inputs) + return outputs + +class segnetDown2(nn.Module): + def __init__(self, in_size, out_size): + super(segnetDown2, self).__init__() + self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) + self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) + self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) + + def forward(self, inputs): + outputs = self.conv1(inputs) + outputs = self.conv2(outputs) + unpooled_shape = outputs.size() + outputs, indices = self.maxpool_with_argmax(outputs) + return outputs, indices, unpooled_shape + + +class segnetDown3(nn.Module): + def __init__(self, in_size, out_size): + super(segnetDown3, self).__init__() + self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) + self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) + self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1) + self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True) + + def forward(self, inputs): + outputs = self.conv1(inputs) + outputs = self.conv2(outputs) + outputs = self.conv3(outputs) + unpooled_shape = outputs.size() + outputs, indices = self.maxpool_with_argmax(outputs) + return 
outputs, indices, unpooled_shape + + +class segnetUp2(nn.Module): + def __init__(self, in_size, out_size): + super(segnetUp2, self).__init__() + self.unpool = nn.MaxUnpool2d(2, 2) + self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) + self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) + + def forward(self, inputs, indices, output_shape): + outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape) + outputs = self.conv1(outputs) + outputs = self.conv2(outputs) + return outputs + + +class segnetUp3(nn.Module): + def __init__(self, in_size, out_size): + super(segnetUp3, self).__init__() + self.unpool = nn.MaxUnpool2d(2, 2) + self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) + self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1) + self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1) + + def forward(self, inputs, indices, output_shape): + outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape) + outputs = self.conv1(outputs) + outputs = self.conv2(outputs) + outputs = self.conv3(outputs) + return outputs + + +class segnet(nn.Module): + + def __init__(self, n_classes=21, in_channels=3, is_unpooling=True): + super(segnet, self).__init__() + + self.in_channels = in_channels + self.is_unpooling = is_unpooling + + self.down1 = segnetDown2(self.in_channels, 64) + self.down2 = segnetDown2(64, 128) + self.down3 = segnetDown3(128, 256) + self.down4 = segnetDown3(256, 512) + self.down5 = segnetDown3(512, 512) + + self.up5 = segnetUp3(512, 512) + self.up4 = segnetUp3(512, 256) + self.up3 = segnetUp3(256, 128) + self.up2 = segnetUp2(128, 64) + self.up1 = segnetUp2(64, n_classes) + + def forward(self, inputs): + + down1, indices_1, unpool_shape1 = self.down1(inputs) + down2, indices_2, unpool_shape2 = self.down2(down1) + down3, indices_3, unpool_shape3 = self.down3(down2) + down4, indices_4, unpool_shape4 = self.down4(down3) + down5, indices_5, unpool_shape5 = self.down5(down4) + + up5 = self.up5(down5, indices_5, unpool_shape5) + up4 = self.up4(up5, indices_4, unpool_shape4) + up3 = self.up3(up4, indices_3, unpool_shape3) + up2 = self.up2(up3, indices_2, unpool_shape2) + up1 = self.up1(up2, indices_1, unpool_shape1) + + return up1 + + + def init_vgg16_params(self, vgg16): + blocks = [self.down1, + self.down2, + self.down3, + self.down4, + self.down5] + + ranges = [[0, 4], [5, 9], [10, 16], [17, 23], [24, 29]] + features = list(vgg16.features.children()) + + vgg_layers = [] + for _layer in features: + if isinstance(_layer, nn.Conv2d): + vgg_layers.append(_layer) + + merged_layers = [] + for idx, conv_block in enumerate(blocks): + if idx < 2: + units = [conv_block.conv1.cbr_unit, + conv_block.conv2.cbr_unit] + else: + units = [conv_block.conv1.cbr_unit, + conv_block.conv2.cbr_unit, + conv_block.conv3.cbr_unit] + for _unit in units: + for _layer in _unit: + if isinstance(_layer, nn.Conv2d): + merged_layers.append(_layer) + + assert len(vgg_layers) == len(merged_layers) + + for l1, l2 in zip(vgg_layers, merged_layers): + if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d): + assert l1.weight.size() == l2.weight.size() + assert l1.bias.size() == l2.bias.size() + l2.weight.data = l1.weight.data + l2.bias.data = l1.bias.data diff --git a/models/tiramisu/__init__.py b/models/tiramisu/__init__.py new file mode 100755 index 0000000..e69de29 diff --git a/models/tiramisu/layers.py b/models/tiramisu/layers.py new file mode 100755 index 0000000..14cb3e6 --- /dev/null +++ b/models/tiramisu/layers.py @@ -0,0 +1,86 @@ +import 
torch +import torch.nn as nn + + +class DenseLayer(nn.Sequential): + def __init__(self, in_channels, growth_rate): + super().__init__() + self.add_module('norm', nn.BatchNorm2d(in_channels)) + self.add_module('relu', nn.ReLU(True)) + self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3, + stride=1, padding=1, bias=True)) + self.add_module('drop', nn.Dropout2d(0.2)) + + def forward(self, x): + return super().forward(x) + + +class DenseBlock(nn.Module): + def __init__(self, in_channels, growth_rate, n_layers, upsample=False): + super().__init__() + self.upsample = upsample + self.layers = nn.ModuleList([DenseLayer( + in_channels + i*growth_rate, growth_rate) + for i in range(n_layers)]) + + def forward(self, x): + if self.upsample: + new_features = [] + #we pass all previous activations into each dense layer normally + #But we only store each dense layer's output in the new_features array + for layer in self.layers: + out = layer(x) + x = torch.cat([x, out], 1) + new_features.append(out) + return torch.cat(new_features,1) + else: + for layer in self.layers: + out = layer(x) + x = torch.cat([x, out], 1) # 1 = channel axis + return x + + +class TransitionDown(nn.Sequential): + def __init__(self, in_channels): + super().__init__() + self.add_module('norm', nn.BatchNorm2d(num_features=in_channels)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module('conv', nn.Conv2d(in_channels, in_channels, + kernel_size=1, stride=1, + padding=0, bias=True)) + self.add_module('drop', nn.Dropout2d(0.2)) + self.add_module('maxpool', nn.MaxPool2d(2)) + + def forward(self, x): + return super().forward(x) + + +class TransitionUp(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.convTrans = nn.ConvTranspose2d( + in_channels=in_channels, out_channels=out_channels, + kernel_size=3, stride=2, padding=0, bias=True) + + def forward(self, x, skip): + out = self.convTrans(x) + out = center_crop(out, skip.size(2), skip.size(3)) + out = torch.cat([out, skip], 1) + return out + + +class Bottleneck(nn.Sequential): + def __init__(self, in_channels, growth_rate, n_layers): + super().__init__() + self.add_module('bottleneck', DenseBlock( + in_channels, growth_rate, n_layers, upsample=True)) + + def forward(self, x): + return super().forward(x) + + +def center_crop(layer, max_height, max_width): + _, _, h, w = layer.size() + xy1 = (w - max_width) // 2 + xy2 = (h - max_height) // 2 + return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)] diff --git a/models/tiramisu/tiramisu.py b/models/tiramisu/tiramisu.py new file mode 100755 index 0000000..7173a9f --- /dev/null +++ b/models/tiramisu/tiramisu.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn + +from models.tiramisu.layers import * + + +class FCDenseNet(nn.Module): + def __init__(self, in_channels=3, down_blocks=(5,5,5,5,5), + up_blocks=(5,5,5,5,5), bottleneck_layers=5, + growth_rate=16, out_chans_first_conv=48, n_classes=12): + super().__init__() + self.down_blocks = down_blocks + self.up_blocks = up_blocks + cur_channels_count = 0 + skip_connection_channel_counts = [] + + ## First Convolution ## + + self.add_module('firstconv', nn.Conv2d(in_channels=in_channels, + out_channels=out_chans_first_conv, kernel_size=3, + stride=1, padding=1, bias=True)) + cur_channels_count = out_chans_first_conv + + ##################### + # Downsampling path # + ##################### + + self.denseBlocksDown = nn.ModuleList([]) + self.transDownBlocks = nn.ModuleList([]) + for i in range(len(down_blocks)): + 
self.denseBlocksDown.append( + DenseBlock(cur_channels_count, growth_rate, down_blocks[i])) + cur_channels_count += (growth_rate*down_blocks[i]) + skip_connection_channel_counts.insert(0,cur_channels_count) + self.transDownBlocks.append(TransitionDown(cur_channels_count)) + + ##################### + # Bottleneck # + ##################### + + self.add_module('bottleneck',Bottleneck(cur_channels_count, + growth_rate, bottleneck_layers)) + prev_block_channels = growth_rate*bottleneck_layers + cur_channels_count += prev_block_channels + + ####################### + # Upsampling path # + ####################### + + self.transUpBlocks = nn.ModuleList([]) + self.denseBlocksUp = nn.ModuleList([]) + for i in range(len(up_blocks)-1): + self.transUpBlocks.append(TransitionUp(prev_block_channels, prev_block_channels)) + cur_channels_count = prev_block_channels + skip_connection_channel_counts[i] + + self.denseBlocksUp.append(DenseBlock( + cur_channels_count, growth_rate, up_blocks[i], + upsample=True)) + prev_block_channels = growth_rate*up_blocks[i] + cur_channels_count += prev_block_channels + + ## Final DenseBlock ## + + self.transUpBlocks.append(TransitionUp( + prev_block_channels, prev_block_channels)) + cur_channels_count = prev_block_channels + skip_connection_channel_counts[-1] + + self.denseBlocksUp.append(DenseBlock( + cur_channels_count, growth_rate, up_blocks[-1], + upsample=False)) + cur_channels_count += growth_rate*up_blocks[-1] + + ## Softmax ## + + self.finalConv = nn.Conv2d(in_channels=cur_channels_count, + out_channels=n_classes, kernel_size=1, stride=1, + padding=0, bias=True) + self.softmax = nn.LogSoftmax(dim=1) + + def forward(self, x): + out = self.firstconv(x) + + skip_connections = [] + for i in range(len(self.down_blocks)): + out = self.denseBlocksDown[i](out) + skip_connections.append(out) + out = self.transDownBlocks[i](out) + + out = self.bottleneck(out) + for i in range(len(self.up_blocks)): + skip = skip_connections.pop() + out = self.transUpBlocks[i](out, skip) + out = self.denseBlocksUp[i](out) + + out = self.finalConv(out) + out = self.softmax(out) + return out + + +def FCDenseNet57(n_classes): + return FCDenseNet( + in_channels=3, down_blocks=(4, 4, 4, 4, 4), + up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4, + growth_rate=12, out_chans_first_conv=48, n_classes=n_classes) + + +def FCDenseNet67(n_classes): + return FCDenseNet( + in_channels=3, down_blocks=(5, 5, 5, 5, 5), + up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5, + growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) + + +def FCDenseNet103(n_classes): + return FCDenseNet( + in_channels=3, down_blocks=(4,5,7,10,12), + up_blocks=(12,10,7,5,4), bottleneck_layers=15, + growth_rate=16, out_chans_first_conv=48, n_classes=n_classes) diff --git a/models/unet.py b/models/unet.py new file mode 100755 index 0000000..12816e5 --- /dev/null +++ b/models/unet.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class unetConv2(nn.Module): + def __init__(self, in_size, out_size, is_batchnorm): + super(unetConv2, self).__init__() + + if is_batchnorm: + self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1), + nn.BatchNorm2d(out_size), + nn.ReLU(),) + self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 1), + nn.BatchNorm2d(out_size), + nn.ReLU(),) + else: + self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 1), + nn.ReLU(),) + self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 1), + nn.ReLU(),) + def forward(self, inputs): + outputs = 
self.conv1(inputs) + outputs = self.conv2(outputs) + return outputs + + +class unetUp(nn.Module): + def __init__(self, in_size, out_size, is_deconv): + super(unetUp, self).__init__() + self.conv = unetConv2(in_size, out_size, True) + if is_deconv: + self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2) + else: + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + + def forward(self, inputs1, inputs2): + outputs2 = self.up(inputs2) + offset = outputs2.size()[2] - inputs1.size()[2] + padding = 2 * [offset // 2, offset // 2] + outputs1 = F.pad(inputs1, padding) + return self.conv(torch.cat([outputs1, outputs2], 1)) + +class unet(nn.Module): + + def __init__(self, feature_scale=1, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True): + super(unet, self).__init__() + self.is_deconv = is_deconv + self.in_channels = in_channels + self.is_batchnorm = is_batchnorm + self.feature_scale = feature_scale + + filters = [64, 128, 256, 512, 1024] + filters = [int(x / self.feature_scale) for x in filters] + + # downsampling + self.conv1 = unetConv2(self.in_channels, filters[0], self.is_batchnorm) + self.maxpool1 = nn.MaxPool2d(kernel_size=2) + + self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm) + self.maxpool2 = nn.MaxPool2d(kernel_size=2) + + self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm) + self.maxpool3 = nn.MaxPool2d(kernel_size=2) + + self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm) + self.maxpool4 = nn.MaxPool2d(kernel_size=2) + + self.center = unetConv2(filters[3], filters[4], self.is_batchnorm) + + # upsampling + self.up_concat4 = unetUp(filters[4], filters[3], self.is_deconv) + self.up_concat3 = unetUp(filters[3], filters[2], self.is_deconv) + self.up_concat2 = unetUp(filters[2], filters[1], self.is_deconv) + self.up_concat1 = unetUp(filters[1], filters[0], self.is_deconv) + + # final conv (without any concat) + self.final = nn.Conv2d(filters[0], n_classes, 1) + + def forward(self, inputs): + conv1 = self.conv1(inputs) + maxpool1 = self.maxpool1(conv1) + + conv2 = self.conv2(maxpool1) + maxpool2 = self.maxpool2(conv2) + + conv3 = self.conv3(maxpool2) + maxpool3 = self.maxpool3(conv3) + + conv4 = self.conv4(maxpool3) + maxpool4 = self.maxpool4(conv4) + + center = self.center(maxpool4) + up4 = self.up_concat4(conv4, center) + up3 = self.up_concat3(conv3, up4) + up2 = self.up_concat2(conv2, up3) + up1 = self.up_concat1(conv1, up2) + + final = self.final(up1) + return final diff --git a/pred1.jpg b/pred1.jpg new file mode 100644 index 0000000..848c94e Binary files /dev/null and b/pred1.jpg differ diff --git a/pred2.jpg b/pred2.jpg new file mode 100644 index 0000000..e4b86b5 Binary files /dev/null and b/pred2.jpg differ diff --git a/pres.jpg b/pres.jpg new file mode 100644 index 0000000..e61f10f Binary files /dev/null and b/pres.jpg differ diff --git a/test.py b/test.py new file mode 100755 index 0000000..4723e70 --- /dev/null +++ b/test.py @@ -0,0 +1,180 @@ +import numpy as np +import nibabel as nib +import cv2 as cv +import torch +from torch.utils import data +from torchvision.transforms import transforms +from data_loader.preprocess import readVol,to_uint8,IR_to_uint8,histeq,preprocessed,get_stacked,rotate,calc_crop_region,calc_max_region_list,crop,get_edge + +import os +import argparse +from torch.autograd import Variable +from models.fcn_xu import fcn_mul + +class MR18loader_test(data.Dataset): + def __init__(self,T1_path,IR_path,T2_path,is_transform,is_crop,is_hist,forest): + self.T1_path=T1_path + 
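# Standalone sketch, and an assumption: get_stacked is defined in
# data_loader/preprocess.py, which this diff does not show. Judging from its
# uses below, a typical 2.5D implementation windows `forest` neighbouring
# slices into the channel axis, clamping indices at the volume edges:
import numpy as np

def stack_slices(vol, forest=3):                     # vol: (n_slices, H, W)
    half = forest // 2
    idx = np.arange(vol.shape[0])[:, None] + np.arange(-half, half + 1)[None, :]
    idx = np.clip(idx, 0, vol.shape[0] - 1)          # repeat edge slices
    return vol[idx].transpose(0, 2, 3, 1)            # (n_slices, H, W, forest)

demo = np.zeros((48, 240, 240), np.uint8)
assert stack_slices(demo).shape == (48, 240, 240, 3)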
+        self.IR_path=IR_path
+        self.T2_path=T2_path
+        self.is_transform=is_transform
+        self.is_crop=is_crop
+        self.is_hist=is_hist
+        self.forest=forest
+        self.n_classes=11
+        self.T1mean=0.0
+        self.IRmean=0.0
+        self.T2mean=0.0
+        # read data
+        T1_nii=to_uint8(readVol(self.T1_path))
+        IR_nii=IR_to_uint8(readVol(self.IR_path))
+        T2_nii=to_uint8(readVol(self.T2_path))
+        # histogram equalization (T1 only)
+        if self.is_hist:
+            T1_nii=histeq(T1_nii)
+        # stack neighbouring slices
+        T1_stack_list=get_stacked(T1_nii,self.forest)
+        IR_stack_list=get_stacked(IR_nii,self.forest)
+        T2_stack_list=get_stacked(T2_nii,self.forest)
+        # crop to the brain region computed on T1
+        if self.is_crop:
+            region_list=calc_max_region_list(calc_crop_region(T1_stack_list,50,5),self.forest)
+            self.region_list=region_list
+            T1_stack_list=crop(T1_stack_list,region_list)
+            IR_stack_list=crop(IR_stack_list,region_list)
+            T2_stack_list=crop(T2_stack_list,region_list)
+        # per-modality mean intensity, used for normalization in transform()
+        T1mean,IRmean,T2mean=0.0,0.0,0.0
+        for samples in T1_stack_list:
+            for stacks in samples:
+                T1mean=T1mean+np.mean(stacks)
+        self.T1mean=T1mean/(len(T1_stack_list)*len(T1_stack_list[0]))
+        for samples in IR_stack_list:
+            for stacks in samples:
+                IRmean=IRmean+np.mean(stacks)
+        self.IRmean=IRmean/(len(IR_stack_list)*len(IR_stack_list[0]))
+        for samples in T2_stack_list:
+            for stacks in samples:
+                T2mean=T2mean+np.mean(stacks)
+        self.T2mean=T2mean/(len(T2_stack_list)*len(T2_stack_list[0]))
+
+        # transform
+        if self.is_transform:
+            for stack_index in range(len(T1_stack_list)):
+                T1_stack_list[stack_index], \
+                IR_stack_list[stack_index], \
+                T2_stack_list[stack_index]= \
+                self.transform( \
+                T1_stack_list[stack_index], \
+                IR_stack_list[stack_index], \
+                T2_stack_list[stack_index])
+
+        # data ready
+        self.T1_stack_list=T1_stack_list
+        self.IR_stack_list=IR_stack_list
+        self.T2_stack_list=T2_stack_list
+
+    def __len__(self):
+        # one item per axial slice; MRBrainS18 volumes hold 48 slices
+        return 48
+    def __getitem__(self,index):
+        return self.region_list[index],self.T1_stack_list[index],self.IR_stack_list[index],self.T2_stack_list[index]
+
+    def transform(self,imgT1,imgIR,imgT2):
+        # HWC -> CHW, subtract mean, scale; float32 (np.float is no longer valid NumPy)
+        imgT1=torch.from_numpy((imgT1.transpose(2,0,1).astype(np.float32)-self.T1mean)/255.0).float()
+        imgIR=torch.from_numpy((imgIR.transpose(2,0,1).astype(np.float32)-self.IRmean)/255.0).float()
+        imgT2=torch.from_numpy((imgT2.transpose(2,0,1).astype(np.float32)-self.T2mean)/255.0).float()
+        return imgT1,imgIR,imgT2
+
+
+def test(args):
+    os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu_id)
+    # io vols
+    srcvol=nib.load(args.T1_path)
+    outvol=np.zeros((240,240,48),np.uint8)
+    # data loader
+    loader=MR18loader_test(T1_path=args.T1_path,IR_path=args.IR_path,T2_path=args.T2_path,is_transform=True,is_crop=True,is_hist=True,forest=3)
+    testloader=data.DataLoader(loader,batch_size=1,num_workers=1,shuffle=False)
+    # model setup: one model per cross-validation fold, ensembled by voting
+    n_classes = loader.n_classes
+    model_1=fcn_mul(n_classes=n_classes)
+    model_2=fcn_mul(n_classes=n_classes)
+    model_3=fcn_mul(n_classes=n_classes)
+    model_4=fcn_mul(n_classes=n_classes)
+    model_5=fcn_mul(n_classes=n_classes)
+    model_6=fcn_mul(n_classes=n_classes)
+    model_7=fcn_mul(n_classes=n_classes)
+    model_1.cuda()
+    model_2.cuda()
+    model_3.cuda()
+    model_4.cuda()
+    model_5.cuda()
+    model_6.cuda()
+    model_7.cuda()
+    state_1 = torch.load(args.model_path_1)['model_state']
+    state_2 = torch.load(args.model_path_2)['model_state']
+    state_3 = torch.load(args.model_path_3)['model_state']
+    state_4 = torch.load(args.model_path_4)['model_state']
+    state_5 = torch.load(args.model_path_5)['model_state']
+    state_6 = torch.load(args.model_path_6)['model_state']
+    state_7 = torch.load(args.model_path_7)['model_state']
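# Standalone sketch (not part of this file): the per-pixel loop in test() below
# takes argmax-of-bincount at every (y, x). The same 7-model majority vote can
# be written as one vectorized pass over a stacked prediction array:
import numpy as np

preds = np.stack([np.random.randint(0, 11, (240, 240)) for _ in range(7)])
vote = np.apply_along_axis(lambda v: np.argmax(np.bincount(v)), 0, preds)
assert vote.shape == (240, 240)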
model_1.load_state_dict(state_1) + model_2.load_state_dict(state_2) + model_3.load_state_dict(state_3) + model_4.load_state_dict(state_4) + model_5.load_state_dict(state_5) + model_6.load_state_dict(state_6) + model_7.load_state_dict(state_7) + model_1.eval() + model_2.eval() + model_3.eval() + model_4.eval() + model_5.eval() + model_6.eval() + model_7.eval() + # test + for i_t,(regions_t,T1s_t,IRs_t,T2s_t) in enumerate(testloader): + T1s_t,IRs_t,T2s_t=Variable(T1s_t.cuda()),Variable(IRs_t.cuda()),Variable(T2s_t.cuda()) + with torch.no_grad(): + out_1=model_1(T1s_t,IRs_t,T2s_t)[0,:,:,:] + out_2=model_2(T1s_t,IRs_t,T2s_t)[0,:,:,:] + out_3=model_3(T1s_t,IRs_t,T2s_t)[0,:,:,:] + out_4=model_4(T1s_t,IRs_t,T2s_t)[0,:,:,:] + out_5=model_5(T1s_t,IRs_t,T2s_t)[0,:,:,:] + out_6=model_6(T1s_t,IRs_t,T2s_t)[0,:,:,:] + out_7=model_7(T1s_t,IRs_t,T2s_t)[0,:,:,:] + pred_1 = out_1.data.max(0)[1].cpu().numpy() + pred_2 = out_2.data.max(0)[1].cpu().numpy() + pred_3 = out_3.data.max(0)[1].cpu().numpy() + pred_4 = out_4.data.max(0)[1].cpu().numpy() + pred_5 = out_5.data.max(0)[1].cpu().numpy() + pred_6 = out_6.data.max(0)[1].cpu().numpy() + pred_7 = out_7.data.max(0)[1].cpu().numpy() + h,w=pred_1.shape[0],pred_1.shape[1] + pred=np.zeros((h,w),np.uint8) + # vote in 7 results + for y in range(h): + for x in range(w): + pred_list=np.array([pred_1[y,x],pred_2[y,x],pred_3[y,x],pred_4[y,x],pred_5[y,x],pred_6[y,x],pred_7[y,x]]) + pred[y,x]=np.argmax(np.bincount(pred_list)) + # padding to 240x240 + pred_pad=np.zeros((240,240),np.uint8) + pred_pad[regions_t[0]:regions_t[1],regions_t[2]:regions_t[3]]=pred[0:regions_t[1]-regions_t[0],0:regions_t[3]-regions_t[2]] + outvol[:,:,i_t]=pred_pad.transpose() + # write nii.gz + nib.Nifti1Image(outvol, srcvol.affine, srcvol.header).to_filename(args.outpath) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Params') + parser.add_argument('--gpu_id', nargs='?', type=int, default=-1,help='GPU id, -1 for cpu') + parser.add_argument('--T1_path',nargs='?',type=str,default='') + parser.add_argument('--IR_path',nargs='?',type=str,default='') + parser.add_argument('--T2_path',nargs='?',type=str,default='') + parser.add_argument('--outpath',nargs='?',type=str,default='') + parser.add_argument('--model_path_1', nargs='?', type=str, default='./CV-models/FCN_MR13_val1.pkl') + parser.add_argument('--model_path_2', nargs='?', type=str, default='./CV-models/FCN_MR13_val2.pkl') + parser.add_argument('--model_path_3', nargs='?', type=str, default='./CV-models/FCN_MR13_val3.pkl') + parser.add_argument('--model_path_4', nargs='?', type=str, default='./CV-models/FCN_MR13_val4.pkl') + parser.add_argument('--model_path_5', nargs='?', type=str, default='./CV-models/FCN_MR13_val5.pkl') + parser.add_argument('--model_path_6', nargs='?', type=str, default='./CV-models/FCN_MR13_val5.pkl') + parser.add_argument('--model_path_7', nargs='?', type=str, default='./CV-models/FCN_MR13_val5.pkl') + args = parser.parse_args() + test(args) diff --git a/train.py b/train.py new file mode 100755 index 0000000..e98013a --- /dev/null +++ b/train.py @@ -0,0 +1,286 @@ +import os +import time +import gc +import cv2 as cv +import argparse +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +plt.ion() + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +from torch.autograd import Variable +from torch.utils import data + +from models.fcn_xu import fcn_mul +from data_loader.data_loader_18 import 
MR18loader_CV +from metrics import runningScore +from loss import cross_entropy2d,loss_ce_t, weighted_loss, dice_loss,dice_coeff, bce2d_hed + +from models.fcn_xu import fcn_xu,fcn_xu_19,fcn_nopool,fcn_xu_dilated +from models.unet import unet +from models.PAN import PAN_seg +from models.resnet import FCN_res + +from models.segnet import segnet +from models.densenet import DenseNet,DenseNetSeg +from models.tiramisu import tiramisu + +def adjust_learning_rate(optimizer, epoch): + lr = args.lr * (0.1 ** (epoch // 5)) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def train(args): + os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu_id) + #torch.manual_seed(1337) + print(args) + # setup dataloader + t_loader=MR18loader_CV(root=args.data_path,val_num=args.val_num,is_val=False,is_transform=True,is_flip=True,is_rotate=True,is_crop=True,is_histeq=True,forest=args.num_forest) + v_loader=MR18loader_CV(root=args.data_path,val_num=args.val_num,is_val=True,is_transform=True,is_flip=False,is_rotate=False,is_crop=True,is_histeq=True,forest=args.num_forest) + n_classes = t_loader.n_classes + trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, num_workers=1, shuffle=True) + valloader = data.DataLoader(v_loader, batch_size=1, num_workers=1,shuffle=True) + # setup Metrics + running_metrics_single = runningScore(n_classes) + running_metrics_single_test = runningScore(4) + # setup Model + model=fcn_mul(n_classes=n_classes) + vgg16 = models.vgg16(pretrained=True) + model.init_vgg16_params(vgg16) + model.cuda() + # setup optimizer and loss + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.99, weight_decay=5e-4) + loss_ce = cross_entropy2d + #loss_ce_weight = weighted_loss + #loss_dc = dice_loss + #loss_hed= bce2d_hed + # resume + best_iou=-100.0 + if args.resume is not None: + if os.path.isfile(args.resume): + print("Loading model and optimizer from checkpoint '{}'".format(args.resume)) + checkpoint = torch.load(args.resume) + best_iou=checkpoint['best_iou'] + model.load_state_dict(checkpoint['model_state']) + optimizer.load_state_dict(checkpoint['optimizer_state']) + print("Loaded checkpoint '{}' (epoch {}), best_iou={}" + .format(args.resume, checkpoint['epoch'],best_iou)) + else: + best_iou=-100.0 + print("No checkpoint found at '{}'".format(args.resume)) + # visualization + t = [] + loss_seg_list=[] + loss_hed_list=[] + Dice_mean=[] + Dice_CSF=[] + Dice_GM=[] + Dice_WM=[] + t_pre=time.time() + print('training prepared, cost {} seconds\n\n'.format(t_pre-t_begin)) + for epoch in range(args.n_epoch): + t.append(epoch+1) + model.train() + adjust_learning_rate(optimizer,epoch) + #loss_sum=0.0 + loss_epoch=0.0 + t_epoch=time.time() + for i_train, (regions,T1s,IRs,T2s,lbls) in enumerate(trainloader): + T1s=Variable(T1s.cuda()) + IRs,T2s=Variable(IRs.cuda()),Variable(T2s.cuda()) + lbls=Variable(lbls.cuda()[:,int(args.num_forest/2),:,:].unsqueeze(1)) + #edges=Variable(edges.cuda()[:,int(args.num_forest/2),:,:].unsqueeze(1)) + optimizer.zero_grad() + outputs=model(T1s,IRs,T2s) + seg_out=F.log_softmax(outputs,dim=1) + max_prob,_=torch.max(seg_out,dim=1) + max_prob=-max_prob.detach().unsqueeze(1) + loss_seg_value=loss_ce(input=outputs,target=lbls) + #+0.5*loss_dc(input=outputs,target=lbls) + #+0.5*loss_ce_weight(input=outputs,target=lbls,weight=max_prob)\ + #+0.5*loss_ce_weight(input=outputs,target=lbls,weight=edges)\ + #loss_hed_value=loss_hed(input=outputs[1],target=edges) + #+0.5*loss_hed(input=outputs[2],target=edges) \ + 
#+0.5*loss_hed(input=outputs[3],target=edges) \ + #+0.5*loss_hed(input=outputs[4],target=edges) \ + #+0.5*loss_hed(input=outputs[5],target=edges) + loss=loss_seg_value + #loss=loss_seg_value+loss_hed_value + # loss average + #loss_sum+=loss + #if (i_train+1)%args.loss_avg==0: + # loss_sum/=args.loss_avg + # loss_sum.backward() + # optimizer.step() + # loss_sum=0.0 + loss.backward() + optimizer.step() + loss_epoch+=loss.item() + # visualization + if i_train==40: + ax1=plt.subplot(241) + ax1.imshow(T1s[0,1,:,:].data.cpu().numpy(),cmap ='gray') + ax1.set_title('train_img') + ax1.axis('off') + ax2=plt.subplot(242) + ax2.imshow(t_loader.decode_segmap(lbls[0,0,:,:].data.cpu().numpy()).astype(np.uint8)) + ax2.set_title('train_label') + ax2.axis('off') + ax3=plt.subplot(243) + model.eval() + train_show=model(T1s,IRs,T2s) + ax3.imshow(t_loader.decode_segmap(train_show[0].data.max(0)[1].cpu().numpy()).astype(np.uint8)) + ax3.set_title('train_predict') + ax3.axis('off') + ax4=plt.subplot(244) + ax4.imshow(max_prob[0,0].cpu().numpy()) + ax4.set_title('uncertainty') + ax4.axis('off') + model.train() + loss_epoch/=i_train + loss_seg_list.append(loss_epoch) + loss_hed_list.append(0) + t_train=time.time() + print('epoch: ',epoch+1) + print('--------------------------------Training--------------------------------') + print('average loss in this epoch: ',loss_epoch) + print('final loss in this epoch: ',loss.data.item()) + print('cost {} seconds up to now'.format(t_train-t_begin)) + print('cost {} seconds in this train epoch'.format(t_train-t_epoch)) + + model.eval() + for i_val, (regions_val,T1s_val,IRs_val,T2s_val,lbls_val) in enumerate(valloader): + T1s_val=Variable(T1s_val.cuda()) + IRs_val,T2s_val=Variable(IRs_val.cuda()),Variable(T2s_val.cuda()) + with torch.no_grad(): + outputs_single=model(T1s_val,IRs_val,T2s_val)[0,:,:,:] + # get predict + pred_single=outputs_single.data.max(0)[1].cpu().numpy() + # pad to 240 + pred_pad=np.zeros((240,240),np.uint8) + pred_pad[regions_val[0]:regions_val[1],regions_val[2]:regions_val[3]]= \ + pred_single[0:regions_val[1]-regions_val[0],0:regions_val[3]-regions_val[2]] + # convert to 3 classes + pred_single_test=np.zeros((240,240),np.uint8) + pred_single_test=v_loader.lbl_totest(pred_pad) + # get gt + gt = lbls_val[0][int(args.num_forest/2)].numpy() + # pad to 240 + gt_pad=np.zeros((240,240),np.uint8) + gt_pad[regions_val[0]:regions_val[1],regions_val[2]:regions_val[3]]= \ + gt[0:regions_val[1]-regions_val[0],0:regions_val[3]-regions_val[2]] + # convert to 3 classes + gt_test=np.zeros((240,240),np.uint8) + gt_test=v_loader.lbl_totest(gt_pad) + # metrics update + running_metrics_single.update(gt_pad, pred_pad) + running_metrics_single_test.update(gt_test, pred_single_test) + # visualization + if i_val==40: + ax5=plt.subplot(245) + ax5.imshow((T1s_val[0,int(args.num_forest/2),:,:].data.cpu().numpy()*255+t_loader.T1mean).astype(np.uint8),cmap ='gray') + ax5.set_title('src_img') + ax5.axis('off') + ax6=plt.subplot(246) + ax6.imshow(t_loader.decode_segmap(gt).astype(np.uint8)) + ax6.set_title('gt') + ax6.axis('off') + ax7=plt.subplot(247) + ax7.imshow(t_loader.decode_segmap(pred_single).astype(np.uint8)) + ax7.set_title('pred_single') + ax7.axis('off') + ax8=plt.subplot(248) + ax8.imshow(pred_single_test[regions_val[0]:regions_val[1],regions_val[2]:regions_val[3]].astype(np.uint8)) + ax8.set_title('pred_single_test') + ax8.axis('off') + plt.tight_layout() + plt.subplots_adjust(wspace=.1,hspace=.3) + 
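# Standalone sketch, and an assumption: runningScore lives in metrics.py, which
# this diff does not include. The per-class Dice it reports (tracked below as
# Dice_CSF/GM/WM) can be derived from a confusion matrix as follows:
import numpy as np

def dice_per_class(conf):          # conf[i, j]: gt class i predicted as class j
    inter = np.diag(conf).astype(float)
    return 2.0 * inter / (conf.sum(axis=0) + conf.sum(axis=1) + 1e-8)

conf = np.array([[50, 2], [3, 45]])   # toy 2-class confusion matrix
print(dice_per_class(conf))           # ~[0.952, 0.947]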
plt.savefig('./fig_out/val_{}_out_{}.png'.format(str(args.val_num),epoch+1)) + # compute dice coefficients during validation + score_single, class_iou_single = running_metrics_single.get_scores() + score_single_test, class_iou_single_test = running_metrics_single_test.get_scores() + Dice_mean.append(score_single['Mean Dice : \t']) + Dice_CSF.append(score_single_test['Dice : \t'][1]) + Dice_GM.append(score_single_test['Dice : \t'][2]) + Dice_WM.append(score_single_test['Dice : \t'][3]) + print('--------------------------------All tissues--------------------------------') + print('Back: Background,') + print('GM: Cortical GM(red), Basal ganglia(green),') + print('WM: WM(yellow), WM lesions(blue),') + print('CSF: CSF(pink), Ventricles(light blue),') + print('Back: Cerebellum(white), Brainstem(dark red)') + print('single predict: ') + for k, v in score_single.items(): + print(k, v) + print('--------------------------------Only tests--------------------------------') + print('tissue : Back , CSF , GM , WM') + print('single predict: ') + for k, v in score_single_test.items(): + print(k, v) + t_test=time.time() + print('cost {} seconds up to now'.format(t_test-t_begin)) + print('cost {} seconds in this validation epoch'.format(t_test-t_train)) + # save model at best validation metrics + if score_single['Mean Dice : \t'] >= best_iou: + best_iou = score_single['Mean Dice : \t'] + state = {'epoch': epoch+1, + 'model_state': model.state_dict(), + 'optimizer_state' : optimizer.state_dict(), + 'best_iou':best_iou} + torch.save(state, "val_{}_best.pkl".format(str(args.val_num))) + print('model saved!!!') + # save model every 10 epochs + if (epoch+1)%10==0: + state = {'epoch': epoch+1, + 'model_state': model.state_dict(), + 'optimizer_state' : optimizer.state_dict(), + 'score':score_single} + torch.save(state, "val_{}_e_{}.pkl".format(str(args.val_num),epoch+1)) + # plot curve + ax1=plt.subplot(211) + ax1.plot(t,loss_seg_list,'g') + ax1.plot(t,loss_hed_list,'r') + ax1.set_title('train loss') + ax2=plt.subplot(212) + ax2.plot(t,Dice_mean,'k') + ax2.plot(t,Dice_CSF,'r') + ax2.plot(t,Dice_GM,'g') + ax2.plot(t,Dice_WM,'b') + ax2.set_title('validate Dice, R/G/B for CSF/GM/WM') + plt.tight_layout() + plt.subplots_adjust(wspace=0,hspace=.3) + plt.savefig('./fig_out/val_{}_curve.png'.format(str(args.val_num))) + # metric reset + running_metrics_single.reset() + running_metrics_single_test.reset() + print('\n\n') + +if __name__ == '__main__': + t_begin=time.time() + parser = argparse.ArgumentParser(description='Hyperparams') + parser.add_argument('--gpu_id', nargs='?', type=int, default=-1, + help='GPU id, -1 for cpu') + parser.add_argument('--data_path', nargs='?', type=str, default='/home/canpi/canpi/MRBrainS18/data/', + help='dataset path') + parser.add_argument('--val_num', nargs='?', type=int, default=1, + help='which set is left for validation') + + parser.add_argument('--n_epoch', nargs='?', type=int, default=20, + help='# of the epochs') + parser.add_argument('--batch_size', nargs='?', type=int, default=1, + help='Batch Size') + parser.add_argument('--num_forest', nargs='?', type=int, default=3, + help='number of stacked slice') + #parser.add_argument('--loss_avg', nargs='?', type=int, default=1, + # help='loss average') + parser.add_argument('--lr', nargs='?', type=float, default=1e-3, + help='Learning Rate') + parser.add_argument('--resume', nargs='?', type=str, default=None, + help='Path to previous saved model to restart from') + args = parser.parse_args() + train(args) diff --git a/validate.py 
b/validate.py new file mode 100755 index 0000000..1f831b4
--- /dev/null
+++ b/validate.py
@@ -0,0 +1,65 @@
+import os
+import torch
+import argparse
+import numpy as np
+import torch.nn as nn
+import cv2 as cv
+import nibabel as nib
+import torch.nn.functional as F
+import torchvision.models as models
+
+from torch.autograd import Variable
+from torch.utils import data
+from tqdm import tqdm
+
+from models.fcn_xu import fcn_mul
+from data_loader.data_loader_18 import MR18loader_CV
+
+
+def validate(args):
+    os.environ["CUDA_VISIBLE_DEVICES"]=str(args.gpu_id)
+    print(args)
+    #torch.manual_seed(1337)
+    # dataset root (must be set before the header is read below)
+    data_path='../../data/'
+    # get nii header
+    srcvol=nib.load(data_path+'training/14/pre/reg_T1.nii.gz')
+    # setup dataloader
+    v_loader=MR18loader_CV(root=data_path,val_num=args.val_num,is_val=True,is_transform=True,is_rotate=False,is_crop=True,is_histeq=True,forest=args.num_forest)
+    n_classes=v_loader.n_classes
+    valloader=data.DataLoader(v_loader,batch_size=1,num_workers=1,shuffle=False)
+    # setup model
+    model=fcn_mul(n_classes=n_classes)
+    model.cuda()
+    state = torch.load(args.model_path)['model_state']
+    model.load_state_dict(state)
+    model.eval()
+    # start predict
+    pred_out=np.zeros((240,240,48),np.uint8)
+    for i_val,(regions,T1s,IRs,T2s,lbls) in tqdm(enumerate(valloader)):
+        print(regions)
+        T1s,IRs,T2s=Variable(T1s.cuda()),Variable(IRs.cuda()),Variable(T2s.cuda())
+        with torch.no_grad():
+            output_slice=model(T1s,IRs,T2s)[0,:,:,:]
+        pred_slice=output_slice.data.max(0)[1].cpu().numpy()
+        pred_out[regions[0]:regions[1],regions[2]:regions[3],i_val]= \
+            pred_slice[0:regions[1]-regions[0],0:regions[3]-regions[2]]
+        pred_out[:,:,i_val]=pred_out[:,:,i_val].transpose()
+    nib.Nifti1Image(pred_out,srcvol.affine,srcvol.header).to_filename('evaluation/result.nii')
+    print('predicted')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Params')
+    parser.add_argument('--gpu_id', nargs='?', type=int, default=-1,
+                        help='GPU id, -1 for cpu')
+    parser.add_argument('--model_path', nargs='?', type=str, default='FCN_MR13_best_model.pkl',
+                        help='Path to the saved model')
+    parser.add_argument('--val_num', nargs='?', type=int, default=1,
+                        help='which sample is held out for validation')
+    parser.add_argument('--num_forest', nargs='?', type=int, default=3,
+                        help='how many slices to stack')
+    args = parser.parse_args()
+    validate(args)
+
+
diff --git a/vgga.jpg b/vgga.jpg new file mode 100644 index 0000000..69a7df2 Binary files /dev/null and b/vgga.jpg differ
diff --git a/why1.jpg b/why1.jpg new file mode 100644 index 0000000..066f960 Binary files /dev/null and b/why1.jpg differ
diff --git a/why2.jpg b/why2.jpg new file mode 100644 index 0000000..eaad88a Binary files /dev/null and b/why2.jpg differ
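# Standalone sketch with example values: the crop-then-pad round trip shared by
# validate.py and test.py. Each slice is predicted on its cropped region,
# written back onto a fixed 240x240 canvas, and transposed before the stack of
# 48 slices is saved as a NIfTI volume.
import numpy as np

region = (40, 200, 30, 210)   # y_min, y_max, x_min, x_max (illustrative only)
pred = np.random.randint(0, 11, (region[1] - region[0], region[3] - region[2]))
canvas = np.zeros((240, 240), np.uint8)
canvas[region[0]:region[1], region[2]:region[3]] = pred
outvol = np.zeros((240, 240, 48), np.uint8)
outvol[:, :, 0] = canvas.transpose()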