forked from zenRRan/Sentiment-Analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
51 lines (38 loc) · 1.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python
# encoding: utf-8
"""
@version: python3.6
@author: 'zenRRan'
@license: Apache Licence
@contact: [email protected]
@software: PyCharm
@file: main.py
@time: 2018/10/7 15:49
"""
import argparse
import os

import torch

import utils.opts as opts
from utils.trainer import Trainer
def load_sst(data_dir, name):
    """Load the serialized artifact ``<data_dir>/<name>.sst`` with ``torch.load``.

    Args:
        data_dir: Directory holding the preprocessed ``*.sst`` files.
        name: File stem, e.g. ``'train'`` or ``'vocab'``.

    Returns:
        Whatever object was saved into the file.

    NOTE(review): ``torch.load`` unpickles arbitrary Python objects, so only
    point ``data_dir`` at trusted, locally produced files.
    """
    path = os.path.join(data_dir, name + '.sst')
    try:
        # These files presumably hold pickled Python objects (feature lists,
        # vocab instances — TODO confirm), not just tensors. torch >= 2.6
        # defaults to weights_only=True and would refuse them, so opt out.
        return torch.load(path, weights_only=False)
    except TypeError:
        # Older torch has no ``weights_only`` keyword (it gets forwarded to
        # the unpickler and rejected); fall back to the legacy call.
        return torch.load(path)


def main():
    """Parse training options, load the preprocessed data, and run training."""
    # Build the CLI from the project's option definitions. Parse into `args`
    # (not `opts`) so the imported `utils.opts` module is not shadowed.
    parser = argparse.ArgumentParser('Train opts')
    parser = opts.trainer_opts(parser)
    args = parser.parse_args()

    # Cap the number of threads torch uses for CPU intra-op parallelism.
    torch.set_num_threads(args.thread)

    # Dataset splits.
    train_features_list = load_sst(args.data_dir, 'train')
    dev_features_list = load_sst(args.data_dir, 'dev')
    test_features_list = load_sst(args.data_dir, 'test')
    # Word-level and char-level vocabularies, then label / relation vocabs.
    vocab = load_sst(args.data_dir, 'vocab')
    char_vocab = load_sst(args.data_dir, 'char_vocab')
    label_vocab = load_sst(args.data_dir, 'label_vocab')
    rel_vocab = load_sst(args.data_dir, 'rel_vocab')

    train_dev_test = (train_features_list, dev_features_list, test_features_list)
    # Trainer takes the word and char vocabularies bundled as a (word, char) pair.
    trainer = Trainer(train_dev_test, args, (vocab, char_vocab), label_vocab,
                      rel_vocab=rel_vocab)
    trainer.train()


if __name__ == '__main__':
    main()