-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain_dot.py
71 lines (51 loc) · 2.06 KB
/
train_dot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import logging
import gensim
from gensim import corpora
from gensim.corpora import BleiCorpus
from pymongo import MongoClient
from deeplearning_settings import GlobalSettings
class Corpus(object):
    """Stream bag-of-words vectors for the page documents behind a cursor.

    Every document is expected to hold a ``"words"`` token list, which is
    converted with ``pages_dictionary.doc2bow``. The cursor is rewound at
    the start of each pass, so the corpus can be iterated repeatedly even
    if the same cursor was already consumed elsewhere.
    """

    def __init__(self, cursor, pages_dictionary, corpus_path):
        """
        :param cursor: iterable of page documents that also provides
            ``rewind()`` (``main`` passes a pymongo ``find()`` cursor).
        :param pages_dictionary: dictionary whose ``doc2bow`` vectorizes pages.
        :param corpus_path: destination file used by :meth:`serialize`.
        """
        self.cursor = cursor
        self.pages_dictionary = pages_dictionary
        self.corpus_path = corpus_path

    def __iter__(self):
        # Generator: the rewind runs on the first next(), then each page is
        # vectorized lazily so the full corpus is never held in memory.
        self.cursor.rewind()
        for document in self.cursor:
            tokens = document["words"]
            yield self.pages_dictionary.doc2bow(tokens)

    def serialize(self):
        """Write this corpus to ``corpus_path`` via ``BleiCorpus.serialize``.

        ``pages_dictionary`` is handed over as ``id2word``.

        :returns: ``self``, so calls can be chained.
        """
        target = self.corpus_path
        BleiCorpus.serialize(target, self, id2word=self.pages_dictionary)
        return self
class Dictionary(object):
    """Build, prune and persist a gensim vocabulary from page documents.

    Every document yielded by the cursor is expected to hold a ``"words"``
    token list.
    """

    def __init__(self, cursor, dictionary_path):
        """
        :param cursor: iterable of page documents that also provides
            ``rewind()`` (``main`` passes a pymongo ``find()`` cursor).
        :param dictionary_path: file the built dictionary is saved to.
        """
        self.cursor = cursor
        self.dictionary_path = dictionary_path

    def build(self, keep_n=10000, no_below=5, no_above=0.5):
        """Build the vocabulary, prune it, save it and return it.

        The defaults reproduce the previous hard-coded behaviour:
        ``keep_n=10000`` plus gensim's own ``filter_extremes`` defaults for
        ``no_below``/``no_above``.

        :param keep_n: keep at most this many most-frequent tokens after the
            document-frequency filters.
        :param no_below: drop tokens that occur in fewer than this many pages.
        :param no_above: drop tokens that occur in more than this fraction of
            pages.
        :returns: the pruned ``gensim.corpora.Dictionary``.
        """
        # The same cursor may already have been iterated; start from the top.
        self.cursor.rewind()
        dictionary = corpora.Dictionary(page["words"] for page in self.cursor)
        dictionary.filter_extremes(no_below=no_below, no_above=no_above, keep_n=keep_n)
        # Renumber token ids so they are contiguous after pruning.
        dictionary.compactify()
        dictionary.save(self.dictionary_path)
        return dictionary
class Train(object):
    """Train and persist a gensim LDA model from a serialized Blei corpus."""

    @staticmethod
    def run(lda_model_path, corpus_path, num_topics, id2word, **lda_kwargs):
        """Train an LDA model on the LDA-C corpus at ``corpus_path`` and save it.

        :param lda_model_path: file the trained model is saved to.
        :param corpus_path: path of a corpus in Blei's LDA-C format
            (as written by ``Corpus.serialize``).
        :param num_topics: number of latent topics to learn.
        :param id2word: dictionary mapping token ids to words.
        :param lda_kwargs: optional extra keyword arguments forwarded verbatim
            to ``gensim.models.LdaModel`` (e.g. ``passes``, ``random_state``);
            omitting them keeps the previous behaviour.
        :returns: the trained ``LdaModel``.
        """
        corpus = corpora.BleiCorpus(corpus_path)
        lda = gensim.models.LdaModel(corpus, num_topics=num_topics, id2word=id2word, **lda_kwargs)
        lda.save(lda_model_path)
        return lda
def main():
    """Build the dictionary, serialize the corpus and train the LDA model.

    Page documents are read from the MongoDB collection named in
    ``GlobalSettings``; all artifacts are written under ``models/``.
    """
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO)
    dictionary_path = "models/dictionary.dict"
    corpus_path = "models/corpus.lda-c"
    lda_num_topics = 100
    # NOTE(review): the file name says 50 topics but 100 are trained. The path
    # is kept as-is because other code may load the model by this name —
    # confirm which value is intended and make them agree.
    lda_model_path = "models/lda_model_50_topics.lda"

    client = MongoClient(GlobalSettings.MONGO_URI)
    try:
        corpus_collection = client[GlobalSettings.DATABASE_DOT][GlobalSettings.COLLECTION_CORPUS]
        # Only the "words" field is read by Dictionary/Corpus, so don't
        # transfer whole documents from the server.
        corpus_cursor = corpus_collection.find({}, {"words": True})
        # Both consumers rewind this single cursor before iterating it.
        dictionary = Dictionary(corpus_cursor, dictionary_path).build()
        Corpus(corpus_cursor, dictionary, corpus_path).serialize()
    finally:
        # Release the connection before the long, DB-independent training step.
        client.close()
    Train.run(lda_model_path, corpus_path, lda_num_topics, dictionary)


if __name__ == '__main__':
    main()