-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain.py
42 lines (27 loc) · 1.12 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from dec import DeepEmbeddingClustering
import numpy as np
import pickle
import config
import utils
import os
def load_embeddings(data_path):
    """Load the pickled ``(embedding_vector, label)`` pair stored at *data_path*.

    The file is opened in a ``with`` block so the handle is closed as soon as
    unpickling finishes, including when ``pickle.load`` raises.

    Args:
        data_path: Path of the pickle file to read.

    Returns:
        A tuple ``(embedding_vector, label)`` holding the two unpickled items.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the unpickled sequence does not hold exactly two items.
        TypeError: If the unpickled object cannot be unpacked at all.
    """
    # NOTE(security): unpickling can execute arbitrary code. Only point this
    # at embedding files produced by this project, never at untrusted input.
    with open(data_path, "rb") as handle:
        embedding_vector, label = pickle.load(handle)
    return embedding_vector, label


def main():
    """Run the clustering pipeline on a previously saved embedding file.

    Reads the command-line configuration, loads the pickled embeddings and
    labels, normalizes the embeddings, runs the k-means accuracy check from
    ``utils``, then initializes and runs ``DeepEmbeddingClustering``.

    If the embedding file does not exist, a hint is printed and nothing else
    is done.
    """
    args = config.get_args()

    # Embedding vector data: <save_embedding_vector>/<dataset>.
    data_path = os.path.join(args.save_embedding_vector, args.dataset)
    if not os.path.isfile(data_path):
        print("embedding patent document first!")
        return

    embedding_vector, label = load_embeddings(data_path)

    # Normalize the document embeddings before any clustering step.
    normalized_doc_embeddings = utils.normalization_vector(embedding_vector)

    # Check accuracy using k-means on the normalized embeddings.
    utils.check_kmean_accuracy(normalized_doc_embeddings, label)

    # One cluster per distinct label value.
    dec = DeepEmbeddingClustering(args, n_clusters=len(np.unique(label)))

    # Greedy layer-wise auto-encoder pretraining (per the original comment;
    # the implementation lives in the ``dec`` module).
    dec.initialize(
        args,
        normalized_doc_embeddings,
        finetune_iters=args.finetune_iters,
        layerwise_pretrain_iters=args.layerwise_pretrain_iters,
    )

    # Update the z space of the patent document vectors.
    dec.cluster(
        args,
        x_data=normalized_doc_embeddings,
        y_data=label,
        test=args.task,
    )


if __name__ == "__main__":
    main()