forked from walterreade/pyensemble
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathensemble_predict.py
executable file
·141 lines (103 loc) · 4.34 KB
/
ensemble_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python
# Author: David C. Lambert [dcl -at- panix -dot- com]
# Copyright(c) 2013
# License: Simple BSD
"""
==========================================================
Prediction utility for trained EnsembleSelectionClassifier
==========================================================
Get predictions from a trained EnsembleSelectionClassifier or
EnsembleSelectionRegressor, given a data file in svmlight format.
Can output predicted classes or probabilities from the full
ensemble or just the best model.
Expects to find a trained ensemble in the sqlite db specified.
usage: ensemble_predict.py [-h] [-T {Regression,Classification}] [-s {best,ens}] [-p] [-n_features NFEAT] db_file data_file
Get EnsembleSelectionClassifier or EnsembleSelectionRegressor predictions
positional arguments:
db_file sqlite db file containing model
data_file testing data in svm format
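optional arguments:
-h, --help show this help message and exit
-s {best,ens} choose source of prediction ["best", "ens"] (default: ens)
-p predict probabilities
-T {Regression,Classification}
type of method to implement (regression or classification)
(default: Regression)
-n_features NFEAT
number of features used in training, so that the sparse svmlight
test data is loaded with the same width as the training data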
"""
from __future__ import print_function
import numpy as np
from argparse import ArgumentParser
from sklearn.datasets import load_svmlight_file
from ensemble import EnsembleSelectionClassifier, EnsembleSelectionRegressor
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse.  ``None`` (the default) parses ``sys.argv[1:]``,
        exactly as before this parameter existed.

    Returns
    -------
    argparse.Namespace
        Attributes:

        * ``meth`` -- list of method names (``nargs='+'``); callers use
          only ``meth[0]``.  Defaults to ``['Regression']``.
        * ``db_file`` / ``data_file`` -- positional paths.
        * ``pred_src`` -- ``'best'`` or ``'ens'`` (default ``'ens'``).
        * ``return_probs`` -- ``True`` when ``-p`` is given.
        * ``nfeat`` -- ``int`` from ``-n_features``, or ``None`` when the
          option is not given.
    """
    desc = 'Get EnsembleSelectionClassifier predictions'
    parser = ArgumentParser(description=desc)

    method_choices = ['Regression', 'Classification']
    dflt_fmt = '(default: %(default)s)'
    help_fmt = 'method of estimation %s' % dflt_fmt
    # nargs='+' makes ``meth`` a list; downstream code reads ``meth[0]``.
    parser.add_argument('-T', dest='meth', nargs='+', choices=method_choices,
                        help=help_fmt, default=['Regression'])

    parser.add_argument('db_file', help='sqlite db file containing model')
    parser.add_argument('data_file', help='testing data in svm format')

    help_fmt = 'choose source of prediction ["best", "ens"] (default "ens")'
    parser.add_argument('-s', dest='pred_src',
                        choices=['best', 'ens'],
                        help=help_fmt, default='ens')

    parser.add_argument('-p', dest='return_probs',
                        action='store_true', default=False,
                        help='predict probabilities')

    # type=int so the value can go straight to load_svmlight_file's
    # n_features; without it argparse hands over a string, which sklearn
    # does not accept there.  None (rather than False) is the "not given"
    # value that load_svmlight_file itself uses.
    parser.add_argument('-n_features', dest='nfeat', type=int,
                        default=None,
                        help='number of features from training in testing set..fix problem with '
                        'svmlight import due to sparsity')
    return parser.parse_args(argv)
def predictMan(res):
    """Load the test data, run the stored ensemble on it and print predictions.

    One line is printed per sample: space-separated values formatted with
    ``%.5f`` when ``res.return_probs`` is set, otherwise ``str()`` of the
    single prediction.

    Parameters
    ----------
    res : argparse.Namespace
        Options as produced by ``parse_args``.  Reads ``data_file``,
        ``db_file``, ``nfeat``, ``meth``, ``pred_src`` and ``return_probs``.

    Returns
    -------
    preds
        What was printed: the result of ``predict_proba`` (or
        ``best_model_predict_proba`` when ``pred_src == 'best'``), replaced
        by its per-row argmax for classification without ``return_probs``.

    Raises
    ------
    ValueError
        If ``res.meth[0]`` is neither 'Classification' nor 'Regression'.
    """
    import sys

    estimators = {
        'Classification': EnsembleSelectionClassifier,
        'Regression': EnsembleSelectionRegressor,
    }
    method = res.meth[0]
    # Validate the method before doing any (potentially slow) file loading.
    if method not in estimators:
        msg = "Invalid method passed (-T does not conform to ['Regression','Classification']"
        raise ValueError(msg)

    # The feature-count hint may arrive as a string (argparse without
    # type=int) or as a False/None "not given" default.  Coerce it so that
    # load_svmlight_file receives an int instead of rejecting the value.
    n_features = int(res.nfeat) if res.nfeat else None
    if n_features is None:
        X, _ = load_svmlight_file(res.data_file)
    else:
        try:
            X, _ = load_svmlight_file(res.data_file, n_features=n_features)
        except ValueError as err:
            # Best effort, as before: if the hint cannot be honoured (e.g. the
            # file holds more features than requested), reload without it.
            # Report the fallback instead of hiding it.
            print('warning: ignoring -n_features %d: %s' % (n_features, err),
                  file=sys.stderr)
            X, _ = load_svmlight_file(res.data_file)
    X = X.toarray()

    ens = estimators[method](db_file=res.db_file, models=None)

    if res.pred_src == 'best':
        preds = ens.best_model_predict_proba(X)
    else:
        preds = ens.predict_proba(X)

    if method == 'Classification' and not res.return_probs:
        # NOTE(review): argmax yields the column index of the most probable
        # class.  That equals the class label only if the labels are 0..k-1;
        # confirm against how the ensemble was trained.
        preds = np.argmax(preds, axis=1)

    for p in preds:
        if res.return_probs:
            print(" ".join(["%.5f" % v for v in p]))
        else:
            print(str(p))
    return preds
if __name__ == '__main__':
    # Script entry point: parse the command line, then print predictions.
    predictMan(parse_args())