Skip to content

Commit

Permalink
utilities for loading and training cross-subject and within-subject
Browse files Browse the repository at this point in the history
  • Loading branch information
brf2 committed Apr 25, 2020
1 parent 03e5f68 commit a1f9d59
Show file tree
Hide file tree
Showing 6 changed files with 745 additions and 0 deletions.
52 changes: 52 additions & 0 deletions CNN/CXR/convert_serial_cxr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import matplotlib.pyplot as plt
import pdb as gdb
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
import nibabel as nib
from sklearn.utils import class_weight
from nibabel import processing as nip
import numpy as np
import scipy.ndimage.morphology as morph
import freesurfer.deeplearn as fsd
from freesurfer.deeplearn import utils, pprint
import freesurfer as fs
import os,socket
from netshape import *
from dipy.align.reslice import reslice
import neuron as ne
import voxelmorph as vxm
from netparms import *
from freesurfer import deeplearn as fsd
from freesurfer.deeplearn.utils import WeightsSaver, ModelSaver
import imageio, pydicom, gdcm, load_serial_cxr

bdir = '/autofs/cluster/lcnextdata1/CCDS_CXR/CXR-Serial/def_20200413'


il, sl, sn = load_serial_cxr.load_serial_cxr(bdir)

for sno, ilist in enumerate(il):
if len(ilist)>=2:
date_list = []
time_list = []
for ino, im in enumerate(ilist):
date_list.append(int(im.StudyDate))
if hasattr(im, 'SeriesTime'):
time_list.append(int(im.SeriesTime.split('.')[0]))
else:
time_list.append(int(im.StudyTime.split('.')[0]))

ind = np.array(date_list).argsort()
date_list2 = []
time_list2 = []
ilist2 = []
for i in ind.tolist():
date_list2.append(date_list[i])
time_list2.append(time_list[i])
ilist2.append(ilist[i])

for ino, im in enumerate(ilist2):
tokens = sl[sno][ind[ino]].split('/')
fname = '/'.join(tokens[0:-2]) + '/time%2.2d.mgz' % ino
fs.Volume(im.pixel_array.astype(np.float32)).write(fname)
146 changes: 146 additions & 0 deletions CNN/CXR/load_serial_cxr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import glob
import imageio
import pydicom
import gdcm
import os
import numpy as np
import freesurfer as fs
import pdb as gdb
import neuron as ne
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
import scipy.ndimage.morphology as morph

def load_serial_cxr(base_path, subject_names='CVM*', study_names='CVA*'):
subject_name_list = glob.glob(os.path.join(base_path, subject_names))
nsubjects = len(subject_name_list)
subject_list = []
image_list = []
for subject in subject_name_list:
studies = glob.glob(os.path.join(subject, study_names))
study_list = []
imlist = []
for study in studies:
scans = glob.glob(os.path.join(study, '*'))

# should only be 1 scan/study. Find the right one
scan_list = []
for scan in scans:
dnames = glob.glob(os.path.join(scan,'*.dcm'))
if len(dnames) < 1:
continue
slist = []
# more than 1 dicom in a dir means it was edge-enhanced
# or something went wrong - use the last one
for dname in dnames:
im = pydicom.dcmread(dname)
if 'SECONDARY' in im.ImageType:
im = None
continue

if im.SeriesDescription.find('AP')<0:
im = None
continue
if hasattr(im, 'DerivationDescription'):
if im.DerivationDescription.find('CATH') >= 0:
im = None
continue
slist.append(im)

im = slist[-1]
if im is not None:
scan_list.append(im)
if len(scan_list) > 0:
im = scan_list[-1]

# could check im.SeriesTime to pick last one
if im is not None:
imlist.append(im)
study_list.append(dname)
image_list.append(imlist)
subject_list.append(study_list)
return image_list, subject_list, subject_name_list

def load_timepoints(bdir, target_shape, tp_name='time??.mgz', dthresh=-1, ndilations=0):
il, sl, sn = load_serial_cxr(bdir)

vol_list = []
seg_list = []
dtrans_list = []
for sno, ilist in enumerate(il):
if len(ilist)>=2:
date_list = []
time_list = []
for ino, im in enumerate(ilist):
date_list.append(int(im.StudyDate))
if hasattr(im, 'SeriesTime'):
time_list.append(int(im.SeriesTime.split('.')[0]))
else:
time_list.append(int(im.StudyTime.split('.')[0]))

# sort input time points by acquistion date
ind = np.array(date_list).argsort()
date_list2 = []
time_list2 = []
ilist2 = []
for i in ind.tolist():
date_list2.append(date_list[i])
time_list2.append(time_list[i])
ilist2.append(ilist[i])

vlist = []
slist = []
dlist = []
bad = False
for ino, im in enumerate(ilist2):
tokens = sl[sno][ind[ino]].split('/')
fname = '/'.join(tokens[0:-2]) + '/time%2.2d.mgz' % ino
vol = fs.Image.read(fname)
zoomx = target_shape[0]/vol.shape[0]
zoomy = target_shape[1]/vol.shape[1]
vol.data = zoom(vol.data,(zoomx, zoomy),order=1)
vlist.append(vol)

fname = '/'.join(tokens[0:-2]) + '/time%2.2d.seg.mgz' % ino
if os.path.exists(fname) == False:
print('%s missing' % fname)
dvol = None
svol = None
bad = True
else:
svol = fs.Image.read(fname)

svol.data = zoom(svol.data,(zoomx, zoomy),order=0)
u = np.unique(svol.data)

# dilate input labels if specified by caller
if ndilations > 0:
dil_vol = np.zeros(svol.shape)
for l in list(u):
if l == 0:
continue
tmp = morph.binary_dilation(svol.data==l, iterations=ndilations)
dil_vol = dil_vol + l*tmp
svol.data = dil_vol

# build multiframe distance transform volume
dframes = []
for l in list(u):
if l == 0:
continue
dtrans = ndimage.distance_transform_edt(np.logical_not(svol.data == l))
if dthresh >= 0:
dtrans[dtrans>dthresh] = dthresh
dframes.append(dtrans)
dvol = np.transpose(np.array(dframes), (1,2,0))

slist.append(svol)
dlist.append(dvol)
if bad == True:
continue
vol_list.append(vlist)
seg_list.append(slist)
dtrans_list.append(dlist)
return vol_list, seg_list, dtrans_list, il, sl, sn


31 changes: 31 additions & 0 deletions CNN/CXR/netparms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import numpy as np

wt_prefix = 'unet.cxr.mri'
wt_fname = wt_prefix + '.h5'

mri_unet_nfeatures = 30
mri_unet_depth = 4
nlabels = 2
mri_unet_feat_mult = 1.25
mri_unet_convs_per_level=8
dec_nf = [32, 32, 32, 32, 32, 16, 16]
enc_nf = [16, 32, 32, 32]

nlscale=.5
dec_nf = [int(32*nlscale), int(32*nlscale), int(32*nlscale), int(32*nlscale), int(32*nlscale), int(16*nlscale), int(16*nlscale)]
enc_nf = [int(16*nlscale), int(32*nlscale), int(32*nlscale), int(32*nlscale)]

fscale = 1
dec_nf_base = [32, 32, 32, 32, 16, 16]
enc_nf_base = [16, 32, 32]
enc_nf = [int(element * fscale) for element in enc_nf_base]
dec_nf = [int(element * fscale) for element in dec_nf_base]

enc_nf_affine = []
feature_scale = 1
for element in enc_nf:
enc_nf_affine.append(element*feature_scale)
feature_scale *= 2


#enc_nf_affine = [2*16, 4*32, 8*32, 8*32]
11 changes: 11 additions & 0 deletions CNN/CXR/netshape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

net_nfilters = 20
net_depth = 5
net_filters_per_level=2
use_prior=0
debug_net = 0
use_class_net = True
use_class_net = False
epochs = 128
affine_enc_features = [16, 32, 32, 32]
affine_input_patch_shape = (128,128,128)
Loading

0 comments on commit a1f9d59

Please sign in to comment.