This repository has been archived by the owner on Apr 24, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathpredictor.py
210 lines (146 loc) · 6.38 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
#!/usr/bin/python
# coding: utf-8
from PyQt5 import QtSql, QtCore
import os
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction import text
from sklearn.svm import LinearSVC
import datetime
# Personal
from log import MyLog
import functions
class Predictor(QtCore.QThread):
    """Thread predicting the percentage match of each article, based on
    its abstract.

    A text classifier (word counts -> TF-IDF -> linear SVM) is trained on
    the articles already read (``new=0`` in the ``papers`` table): liked
    articles form class 0, read-but-not-liked articles form class 1.
    ``run`` then scores every article and stores the result in
    ``papers.percentage_match``.
    """

    def __init__(self, logger, to_read_list, bdd=None):
        """Prepare the predictor.

        logger: logging object, used through its debug/error/critical
            methods.
        to_read_list: ids of the articles in the waiting list; read but
            not-liked articles found in it are left out of the training.
        bdd: an existing QSqlDatabase connection. When None, a QSQLITE
            connection to "fichiers.sqlite" is created and opened.
        """
        QtCore.QThread.__init__(self)
        self.to_read_list = to_read_list
        self.x_train = []
        self.y_train = []
        self.classifier = None
        # Logger first, so that a failure to open the database is reported
        self.l = logger
        if bdd is None:
            self.bdd = QtSql.QSqlDatabase.addDatabase("QSQLITE")
            self.bdd.setDatabaseName("fichiers.sqlite")
            if not self.bdd.open():
                self.l.error("Could not open the database fichiers.sqlite")
        else:
            self.bdd = bdd
        self.getStopWords()
        # To check if initializePipeline completed
        self.initiated = False
        # To check if match percentages were calculated
        self.calculated_something = False

    def getStopWords(self):
        """Set self.stop_words: scikit-learn's English stop words plus the
        personal ones read from config/stop_words.txt (one per line)."""
        resource_dir, _ = functions.getRightDirs()
        with open(os.path.join(resource_dir, 'config/stop_words.txt'), 'r',
                  encoding='utf-8') as config:
            my_additional_stop_words = [word.rstrip() for word in config]
        self.stop_words = text.ENGLISH_STOP_WORDS.union(my_additional_stop_words)

    def initializePipeline(self):
        """Train the classifier on the already-read articles (new=0).

        Liked articles (liked == 1) are class 0. Read-but-not-liked
        articles are class 1, except those still in the waiting list,
        which are ignored. Articles whose abstract is not a string or is
        'Empty' are skipped. On success self.initiated becomes True;
        otherwise an error is logged and None is returned.
        """
        start_time = datetime.datetime.now()

        # Start from scratch: calling this method again must neither
        # duplicate the training samples nor keep an outdated classifier
        # flagged as ready if the new training fails
        self.initiated = False
        self.x_train = []
        self.y_train = []
        # Built once: membership is tested for every read article
        waiting_ids = set(self.to_read_list)

        query = QtSql.QSqlQuery(self.bdd)
        query.exec_("SELECT * FROM papers WHERE new=0")
        while query.next():
            record = query.record()
            abstract = record.value('topic_simple')
            # Do not use missing or 'Empty' abstracts
            if not isinstance(abstract, str) or abstract == 'Empty':
                continue
            liked = record.value('liked')
            if isinstance(liked, int) and liked == 1:
                category = 0
            elif record.value('id') not in waiting_ids:
                category = 1
            else:
                # Read and not liked, but still in the waiting list:
                # do not count it as a disliked article
                continue
            self.x_train.append(abstract)
            self.y_train.append(category)

        # Both classes are required to train a meaningful classifier
        if 0 not in self.y_train or 1 not in self.y_train:
            self.l.error("Not enough data yet to feed the classifier")
            return None

        self.classifier = Pipeline([
            # Recent scikit-learn versions only accept a list (not a
            # frozenset) for stop_words
            ('vectorizer', CountVectorizer(stop_words=sorted(self.stop_words))),
            ('tfidf', TfidfTransformer()),
            ('clf', LinearSVC())])
        try:
            self.classifier.fit(self.x_train, self.y_train)
        except ValueError:
            self.l.error("Not enough data yet to train the classifier")
            return None

        elapsed_time = datetime.datetime.now() - start_time
        self.l.debug("Training classifier in {0}".format(elapsed_time))
        self.initiated = True

    def run(self):
        """Calculate and store the match percentage of each article.

        Does nothing unless initializePipeline succeeded. Every article
        with a usable topic_simple text is scored, the scores are mapped
        linearly onto [0, 100] (best match = 100), and the results are
        written to papers.percentage_match in one transaction. Sets
        self.calculated_something once done.

        NOTE(review): Qt documents that a QSqlDatabase connection may only
        be used from the thread that created it; when this runs through
        start(), self.bdd is used from the worker thread — confirm this is
        safe with the driver in use.
        """
        if not self.initiated:
            self.l.debug("NOT starting calculations, not initiated")
            return

        self.l.debug("Starting calculations of match percentages")
        start_time = datetime.datetime.now()

        list_id, x_test = self._fetchAbstracts()
        if not x_test:
            self.l.debug("No article to calculate match percentages for")
            return

        try:
            scores = self.classifier.decision_function(x_test)
            self.l.debug("Classifier predicted proba in {}".
                         format(datetime.datetime.now() - start_time))
            diff_time = datetime.datetime.now()
            list_percentages = self._normalizeScores(scores)
            self.l.debug("Classifier normalized proba in {}".
                         format(datetime.datetime.now() - diff_time))
        except AttributeError:
            self.l.error("Not enough data yet to predict probability")
            return
        except Exception as e:
            self.l.error("predictor: {}".format(e), exc_info=True)
            return

        diff_time = datetime.datetime.now()
        if not self._writePercentages(list_id, list_percentages):
            self.l.critical("Percentages match not correctly written in db")
        else:
            self.l.debug("Percentages written to db in {}".
                         format(datetime.datetime.now() - diff_time))
        self.l.debug("Done calculating match percentages in {}".
                     format(datetime.datetime.now() - start_time))
        self.calculated_something = True

    def _fetchAbstracts(self):
        """Return (ids, texts) of the articles that can be scored.

        topic_simple also contains the title of the article, so the
        calculations cover both the topic and the title. Rows whose
        topic_simple is not a string (e.g. None) are skipped: the
        vectorizer cannot process them, and their stored percentage is
        left unchanged.
        """
        query = QtSql.QSqlQuery(self.bdd)
        query.exec_("SELECT id, topic_simple FROM papers")
        list_id = []
        x_test = []
        while query.next():
            record = query.record()
            abstract = record.value('topic_simple')
            if not isinstance(abstract, str):
                continue
            x_test.append(abstract)
            list_id.append(record.value('id'))
        return list_id, x_test

    @staticmethod
    def _normalizeScores(scores):
        """Map decision scores linearly onto percentages in [0, 100].

        The classifier's positive class is 1 (not liked), so the lowest
        score is the best match and gets 100, the highest gets 0.
        When all scores are equal there is nothing to rank and every
        article gets 100, instead of the NaN a 0/0 division would give.
        See: http://stackoverflow.com/questions/929103/convert-a-number-range-to-another-range-maintaining-ratio
        """
        maximum = max(scores)
        minimum = min(scores)
        spread = maximum - minimum
        if spread == 0:
            return [100.0] * len(scores)
        return [100 - (score - minimum) * 100 / spread for score in scores]

    def _writePercentages(self, list_id, list_percentages):
        """Write the percentages to papers.percentage_match in a single
        transaction. Return True if the transaction was committed."""
        self.bdd.transaction()
        query = QtSql.QSqlQuery(self.bdd)
        query.prepare("UPDATE papers SET percentage_match = ? WHERE id = ?")
        for id_bdd, percentage in zip(list_id, list_percentages):
            # Convert the percentage to a float, because the number is
            # probably a type used by numpy. MANDATORY
            query.addBindValue(float(percentage))
            query.addBindValue(id_bdd)
            query.exec_()
        if self.bdd.commit():
            return True
        # Do not leave the failed transaction open on the connection
        self.bdd.rollback()
        return False
if __name__ == "__main__":
    logger = MyLog("test.log")
    predictor = Predictor(logger, [])
    predictor.initializePipeline()
    # run() is the method computing the percentages; calling it directly
    # performs the work synchronously, without starting the thread
    predictor.run()