-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathclassifyber.py
executable file
·190 lines (159 loc) · 7.62 KB
/
classifyber.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#!/usr/bin/env python
""" Classification of multiple bundles from multiple examples.
"""
import os
import sys
import argparse
import os.path
import numpy as np
import time
import ntpath
import nibabel as nib
import pickle
from utils import compute_kdtree_and_dr_tractogram, compute_superset, streamlines_idx, save_trk
from dipy.tracking.distances import bundles_distances_mam, bundles_distances_mdf
from dipy.tracking.streamline import set_number_of_points
from collections import OrderedDict
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
from dipy.segment.clustering import QuickBundles
from subsampling import compute_subset
from features_mni import compute_feature_matrix
#global configuration parameters
num_prototypes = 100
distance_func = bundles_distances_mdf
nb_points = 20
cw = {0:1, 1:3}
max_iter = 1000
ds_factor = 1
def compute_X_y_train(subjID, tract_name, moving_tractogram_fname, example_fname):
"""Compute X_train and y_train.
"""
moving_tractogram = nib.streamlines.load(moving_tractogram_fname)
moving_tractogram = moving_tractogram.streamlines
print("Compute kdt and prototypes of %s" %moving_tractogram_fname)
kdt, prototypes = compute_kdtree_and_dr_tractogram(moving_tractogram, num_prototypes=num_prototypes,
distance_func=distance_func, nb_points=nb_points)
tract = nib.streamlines.load(example_fname)
tract = tract.streamlines
print("Computing the superset of %s" %example_fname)
superset_idx = compute_superset(tract, kdt, prototypes, k=2000, distance_func=distance_func, nb_points=nb_points)
if ds_factor<1:
print("Downsampling the superset of a factor %s" %ds_factor)
superset_idx = np.random.choice(superset_idx, int(len(superset_idx)*ds_factor), replace=False)
superset = moving_tractogram[superset_idx]
exID = ntpath.basename(moving_tractogram_fname)[4:10]
print("Computing X_train.")
X_train = compute_feature_matrix(superset, exID, subjID, tract_name, distance_func=distance_func, nb_points=nb_points)
print("Computing y_train.")
y_train = np.zeros(len(superset))
tract_idx = streamlines_idx(tract, kdt, prototypes, distance_func=distance_func, nb_points=nb_points)
if ds_factor<1:
tract_idx = np.intersect1d(superset_idx, tract_idx)
correspondent_idx = np.array([np.where(superset_idx==idx) for idx in tract_idx])
y_train[correspondent_idx] = 1
return X_train, y_train
def compute_union_superset_idx(kdt, prototypes, ex_dir_tract, distance_func=bundles_distances_mam, nb_points=20):
"""Compute a superset in a subject starting from the tracts of other subjects.
"""
union_superset_idx = []
examples = os.listdir(ex_dir_tract)
examples.sort()
ne = len(examples)
th = np.min([ne, 10]) #maximum 10 subjects
for i in range(th):
example_fname = '%s/%s' %(ex_dir_tract, examples[i])
tract = nib.streamlines.load(example_fname)
tract = tract.streamlines
superset_idx_test = compute_superset(tract, kdt, prototypes, k=2000, distance_func=distance_func, nb_points=nb_points)
union_superset_idx = np.concatenate((union_superset_idx, superset_idx_test))
print("Total size superset: %s" %len(union_superset_idx))
union_superset_idx = list(OrderedDict.fromkeys(union_superset_idx)) #removes duplicates
union_superset_idx = np.array(union_superset_idx, dtype=int)
print("Size reducted superset: %s" %len(union_superset_idx))
return union_superset_idx
def classifyber(moving_tractograms_dir, static_tractogram_fname, ex_dir_tract):
"""Code for classification from multiple examples.
"""
subjID = ntpath.basename(static_tractogram_fname)[4:10]
tract_name = ntpath.basename(ex_dir_tract)
moving_tractograms = os.listdir(moving_tractograms_dir)
moving_tractograms.sort()
examples = os.listdir(ex_dir_tract)
examples.sort()
nt = len(moving_tractograms)
ne = len(examples)
assert(nt == ne)
X_train = np.array([])
y_train = np.array([])
print("Computing training set using %i examples." %ne)
for i in range(nt):
moving_tractogram_fname = '%s/%s' %(moving_tractograms_dir, moving_tractograms[i])
example_fname = '%s/%s' %(ex_dir_tract, examples[i])
X_tmp, y_tmp = compute_X_y_train(subjID, tract_name, moving_tractogram_fname, example_fname)
X_train = np.vstack([X_train, X_tmp]) if X_train.size else X_tmp
y_train = np.hstack([y_train, y_tmp]) if y_train.size else y_tmp
print(X_train.shape)
print("Computing X_test.")
static_tractogram = nib.streamlines.load(static_tractogram_fname)
static_tractogram = static_tractogram.streamlines
print("Compute kdt and prototypes of %s" %static_tractogram_fname)
kdt, prototypes = compute_kdtree_and_dr_tractogram(static_tractogram, num_prototypes=num_prototypes,
distance_func=distance_func, nb_points=nb_points)
print("Computing the test superset...")
union_superset_idx = compute_union_superset_idx(kdt, prototypes, ex_dir_tract, distance_func=distance_func, nb_points=nb_points)
static_superset = static_tractogram[union_superset_idx]
X_test = compute_feature_matrix(static_superset, subjID, subjID, tract_name, distance_func=distance_func, nb_points=nb_points)
del kdt, static_superset
print("Normalize X_train and X_test.")
scaler = StandardScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)
print("Classification.")
clf = LogisticRegression(class_weight=cw, random_state=42, solver='sag', max_iter=max_iter)
t0=time.time()
clf.fit(X_train, y_train)
print("---->Time to fit X_train of size (%s, %s) = %s seconds" %(X_train.shape[0], X_train.shape[1], time.time()-t0))
t1=time.time()
y_pred = clf.predict(X_test)
y_pred_proba = clf.predict_proba(X_test)
print("---->Time to predict X_test of size (%s, %s) = %s seconds" %(X_test.shape[0], X_test.shape[1], time.time()-t1))
estimated_tract_idx = np.where(y_pred>0)[0]
estimated_tract = static_tractogram[union_superset_idx[estimated_tract_idx]]
#clf_fname = 'clf_%s' %tract_name
#pickle.dump(clf, open(clf_fname, 'w'), protocol=pickle.HIGHEST_PROTOCOL)
#scaler_fname = 'scaler_%s' %tract_name
#pickle.dump(scaler, open(scaler_fname, 'w'), protocol=pickle.HIGHEST_PROTOCOL)
np.save('estimated_idx_%s.npy' %tract_name, union_superset_idx[estimated_tract_idx])
return estimated_tract
if __name__ == '__main__':
np.random.seed(0)
parser = argparse.ArgumentParser()
parser.add_argument('-moving_dir', nargs='?', const=1, default='',
help='The moving tractograms directory')
parser.add_argument('-static', nargs='?', const=1, default='',
help='The static tractogram filename')
parser.add_argument('-ex_dir', nargs='?', const=1, default='',
help='The examples (moving) bundle directory')
parser.add_argument('-list', nargs='?', const=1, default='',
help='The tract name list file .txt')
parser.add_argument('-out_dir', nargs='?', const=1, default='default',
help='The output directory')
args = parser.parse_args()
t0=time.time()
#os.chdir('/N/u/gberto/Karst/classifiber_neuroimage/code')
with open(args.list) as f:
tract_name_list = f.read().splitlines()
for tract_name in tract_name_list:
t1=time.time()
print("Classification of tract: %s" %tract_name)
ex_dir_tract = '%s/%s' %(args.ex_dir, tract_name)
estimated_tract = classifyber(args.moving_dir, args.static, ex_dir_tract)
print("Time to compute classification of tract %s = %i minutes" %(tract_name, (time.time()-t1)/60))
out_fname = '%s/%s.trk' %(args.out_dir, tract_name)
save_trk(estimated_tract, out_fname)
print("Tract saved in %s" %out_fname)
print("Total time elapsed for the classification of all the tracts = %i minutes" %((time.time()-t0)/60))
sys.exit()