-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfor_uta.py
33 lines (28 loc) · 1018 Bytes
/
for_uta.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import os

from src.TrainDoc2VecModel import TrainDoc2VecModel
model_trainer = TrainDoc2VecModel()
aggregate_list = [
{"vs": 15, "e": 30},
{"vs": 20, "e": 40},
{"vs": 10, "e": 50},
{"vs": 15, "e": 80}
]
news20_list = [
{"vs": 15, "e": 30},
{"vs": 20, "e": 40},
{"vs": 10, "e": 50},
{"vs": 15, "e": 80}
]
for aggregate in aggregate_list:
file = open("./csv_models_for_uta/aggregate_vs_{}_epochs_{}.csv".format(aggregate["vs"], aggregate["e"]), "w")
model = model_trainer.get_aggregate_model(aggregate["vs"], aggregate["e"])
for vec in model.docvecs.vectors_docs:
file.write(','.join(map(str, vec)))
file.write('\n')
file.close()
for news20 in news20_list:
file = open("./csv_models_for_uta/news20_vs_{}_epochs_{}.csv".format(news20["vs"], news20["e"]), "w")
model = model_trainer.get_20news_model(news20["vs"], news20["e"])
for vec in model.docvecs.vectors_docs:
file.write(','.join(map(str, vec)))
file.write('\n')
file.close()