forked from freesurfer/freesurfer
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
utilities for loading and training cross-subject and within-subject
- Loading branch information
Showing
6 changed files
with
745 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import matplotlib.pyplot as plt | ||
import pdb as gdb | ||
import tensorflow as tf | ||
from tensorflow import keras | ||
from tensorflow.keras import backend as K | ||
import nibabel as nib | ||
from sklearn.utils import class_weight | ||
from nibabel import processing as nip | ||
import numpy as np | ||
import scipy.ndimage.morphology as morph | ||
import freesurfer.deeplearn as fsd | ||
from freesurfer.deeplearn import utils, pprint | ||
import freesurfer as fs | ||
import os,socket | ||
from netshape import * | ||
from dipy.align.reslice import reslice | ||
import neuron as ne | ||
import voxelmorph as vxm | ||
from netparms import * | ||
from freesurfer import deeplearn as fsd | ||
from freesurfer.deeplearn.utils import WeightsSaver, ModelSaver | ||
import imageio, pydicom, gdcm, load_serial_cxr | ||
|
||
# Root of the serial CXR study tree to convert.
bdir = '/autofs/cluster/lcnextdata1/CCDS_CXR/CXR-Serial/def_20200413'


il, sl, sn = load_serial_cxr.load_serial_cxr(bdir)

# For every subject with a serial study, write each timepoint image out as
# an mgz volume named time00.mgz, time01.mgz, ... in chronological order.
for sno, ilist in enumerate(il):
    # a serial study needs at least two timepoints
    if len(ilist) < 2:
        continue

    dates = []
    times = []
    for im in ilist:
        dates.append(int(im.StudyDate))
        # prefer SeriesTime when present, otherwise fall back to StudyTime
        tstamp = im.SeriesTime if hasattr(im, 'SeriesTime') else im.StudyTime
        times.append(int(tstamp.split('.')[0]))

    # order the acquisitions chronologically by study date
    order = np.array(dates).argsort()
    sorted_images = [ilist[i] for i in order.tolist()]

    for ino, im in enumerate(sorted_images):
        # the output volume lives two directory levels above the dicom file
        tokens = sl[sno][order[ino]].split('/')
        fname = '/'.join(tokens[0:-2]) + '/time%2.2d.mgz' % ino
        fs.Volume(im.pixel_array.astype(np.float32)).write(fname)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import glob | ||
import imageio | ||
import pydicom | ||
import gdcm | ||
import os | ||
import numpy as np | ||
import freesurfer as fs | ||
import pdb as gdb | ||
import neuron as ne | ||
from scipy import ndimage | ||
from scipy.ndimage.interpolation import zoom | ||
import scipy.ndimage.morphology as morph | ||
|
||
def load_serial_cxr(base_path, subject_names='CVM*', study_names='CVA*'):
    """Walk a serial chest-X-ray dicom tree and pick one image per study.

    Expected layout: base_path/<subject>/<study>/<scan>/*.dcm. Secondary
    captures, non-AP series and CATH-derived images are filtered out.

    Parameters
    ----------
    base_path : str
        Root directory containing one subdirectory per subject.
    subject_names : str
        Glob pattern matching subject directories (default 'CVM*').
    study_names : str
        Glob pattern matching study directories (default 'CVA*').

    Returns
    -------
    image_list : list
        One inner list of pydicom datasets per subject.
    subject_list : list
        One inner list per subject with the dicom filename chosen per study.
    subject_name_list : list
        The subject directories that were found.
    """
    subject_name_list = glob.glob(os.path.join(base_path, subject_names))
    subject_list = []
    image_list = []
    for subject in subject_name_list:
        studies = glob.glob(os.path.join(subject, study_names))
        study_list = []
        imlist = []
        for study in studies:
            scans = glob.glob(os.path.join(study, '*'))

            # should only be 1 scan/study. Find the right one.
            # im must be initialized here: if no scan dir contains dicoms
            # the original code hit a NameError on the 'im is not None' test.
            im = None
            im_fname = None
            scan_list = []
            for scan in scans:
                dnames = glob.glob(os.path.join(scan, '*.dcm'))
                if len(dnames) < 1:
                    continue
                slist = []
                # more than 1 dicom in a dir means it was edge-enhanced
                # or something went wrong - use the last one
                for dname in dnames:
                    im = pydicom.dcmread(dname)
                    if 'SECONDARY' in im.ImageType:
                        im = None
                        continue

                    if im.SeriesDescription.find('AP') < 0:
                        im = None
                        continue
                    if hasattr(im, 'DerivationDescription'):
                        if im.DerivationDescription.find('CATH') >= 0:
                            im = None
                            continue
                    # keep the filename paired with the dataset so the
                    # caller gets the path of the image actually selected
                    slist.append((im, dname))

                # guard against every dicom in this dir being filtered out
                # (the original indexed slist[-1] unconditionally, raising
                # IndexError on an empty list)
                if len(slist) > 0:
                    scan_list.append(slist[-1])

            if len(scan_list) > 0:
                # could check im.SeriesTime to pick last one
                im, im_fname = scan_list[-1]

            if im is not None:
                imlist.append(im)
                # append the filename that corresponds to the selected image
                # (the original appended the stale loop variable 'dname',
                # which could name a different file)
                study_list.append(im_fname)
        image_list.append(imlist)
        subject_list.append(study_list)
    return image_list, subject_list, subject_name_list
|
||
def load_timepoints(bdir, target_shape, tp_name='time??.mgz', dthresh=-1, ndilations=0):
    """Load serial CXR timepoint images, segmentations and distance transforms.

    Parameters
    ----------
    bdir : str
        Base directory handed to load_serial_cxr.
    target_shape : sequence
        In-plane (rows, cols) shape every image is resampled to.
    tp_name : str
        Timepoint filename pattern.  NOTE(review): currently unused in the
        body -- the 'time%2.2d.mgz' pattern is hard-coded below.
    dthresh : float
        If >= 0, distance-transform values are clipped at this threshold.
    ndilations : int
        If > 0, each nonzero segmentation label is binary-dilated this many
        iterations before the distance transforms are built.

    Returns
    -------
    vol_list, seg_list, dtrans_list
        Per-subject lists of resampled images, segmentations and multiframe
        distance-transform arrays.  Subjects with any missing segmentation
        file are skipped entirely.
    il, sl, sn
        The raw outputs of load_serial_cxr, passed through for the caller.
    """
    il, sl, sn = load_serial_cxr(bdir)

    vol_list = []
    seg_list = []
    dtrans_list = []
    for sno, ilist in enumerate(il):
        # only subjects with at least two timepoints are usable serially
        if len(ilist)>=2:
            date_list = []
            time_list = []
            for ino, im in enumerate(ilist):
                date_list.append(int(im.StudyDate))
                # prefer SeriesTime when present; fall back to StudyTime
                if hasattr(im, 'SeriesTime'):
                    time_list.append(int(im.SeriesTime.split('.')[0]))
                else:
                    time_list.append(int(im.StudyTime.split('.')[0]))

            # sort input time points by acquistion date
            ind = np.array(date_list).argsort()
            date_list2 = []
            time_list2 = []
            ilist2 = []
            for i in ind.tolist():
                date_list2.append(date_list[i])
                time_list2.append(time_list[i])
                ilist2.append(ilist[i])

            vlist = []
            slist = []
            dlist = []
            bad = False  # set when any timepoint lacks a segmentation file
            for ino, im in enumerate(ilist2):
                # the mgz volume lives two directory levels above the dicom
                tokens = sl[sno][ind[ino]].split('/')
                fname = '/'.join(tokens[0:-2]) + '/time%2.2d.mgz' % ino
                vol = fs.Image.read(fname)
                # linear resampling of the intensity image to target_shape
                zoomx = target_shape[0]/vol.shape[0]
                zoomy = target_shape[1]/vol.shape[1]
                vol.data = zoom(vol.data,(zoomx, zoomy),order=1)
                vlist.append(vol)

                fname = '/'.join(tokens[0:-2]) + '/time%2.2d.seg.mgz' % ino
                if os.path.exists(fname) == False:
                    print('%s missing' % fname)
                    dvol = None
                    svol = None
                    bad = True
                else:
                    svol = fs.Image.read(fname)

                    # nearest-neighbor (order=0) so label values survive
                    svol.data = zoom(svol.data,(zoomx, zoomy),order=0)
                    u = np.unique(svol.data)

                    # dilate input labels if specified by caller
                    if ndilations > 0:
                        dil_vol = np.zeros(svol.shape)
                        for l in list(u):
                            if l == 0:
                                continue
                            tmp = morph.binary_dilation(svol.data==l, iterations=ndilations)
                            # NOTE(review): additive composition assumes the
                            # dilated label masks do not overlap -- confirm
                            dil_vol = dil_vol + l*tmp
                        svol.data = dil_vol

                    # build multiframe distance transform volume
                    dframes = []
                    for l in list(u):
                        if l == 0:
                            continue
                        # distance from every voxel to the nearest voxel of
                        # label l (EDT of the complement of the label mask)
                        dtrans = ndimage.distance_transform_edt(np.logical_not(svol.data == l))
                        if dthresh >= 0:
                            # clip so distant voxels don't dominate training
                            dtrans[dtrans>dthresh] = dthresh
                        dframes.append(dtrans)
                    dvol = np.transpose(np.array(dframes), (1,2,0))

                slist.append(svol)
                dlist.append(dvol)
            # drop the whole subject if any segmentation was missing
            if bad == True:
                continue
            vol_list.append(vlist)
            seg_list.append(slist)
            dtrans_list.append(dlist)
    return vol_list, seg_list, dtrans_list, il, sl, sn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np

# weight-file naming for the CXR/MRI U-net
wt_prefix = 'unet.cxr.mri'
wt_fname = wt_prefix + '.h5'

# U-net architecture hyper-parameters
mri_unet_nfeatures = 30
mri_unet_depth = 4
nlabels = 2
mri_unet_feat_mult = 1.25
mri_unet_convs_per_level = 8

# historical feature-scaling knob; kept because other modules may import it
# (the nlscale-scaled dec_nf/enc_nf variants the original built here were
# immediately overwritten below and have been removed as dead code)
nlscale = .5

# encoder/decoder feature counts, scaled from a base configuration.
# These are the only effective definitions -- the original file assigned
# dec_nf/enc_nf three times in a row and only this last pair survived.
fscale = 1
dec_nf_base = [32, 32, 32, 32, 16, 16]
enc_nf_base = [16, 32, 32]
enc_nf = [int(element * fscale) for element in enc_nf_base]
dec_nf = [int(element * fscale) for element in dec_nf_base]

# affine-encoder feature counts: double the multiplier at each level
enc_nf_affine = []
feature_scale = 1
for element in enc_nf:
    enc_nf_affine.append(element*feature_scale)
    feature_scale *= 2


#enc_nf_affine = [2*16, 4*32, 8*32, 8*32]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
|
||
# network topology
net_nfilters = 20
net_depth = 5
net_filters_per_level = 2

# feature flags (0/False disables)
use_prior = 0
debug_net = 0
# whether to train the auxiliary classification net.  The original assigned
# True and immediately overwrote it with False; only the effective value is
# kept (flip to True to re-enable the class net).
use_class_net = False

# training schedule
epochs = 128

# affine registration network configuration
affine_enc_features = [16, 32, 32, 32]
affine_input_patch_shape = (128, 128, 128)
Oops, something went wrong.