#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Author: lionel
import argparse
import collections
import numpy as np
import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument('--train_file', type=str, default='biyu_model/train2.csv')
parser.add_argument('--test_size', type=float, default=0.1)
FLAGS, unparsed = parser.parse_known_args()

# Load the data and shuffle it randomly.
def load_data(file_name, sep=' ', sep1=',', isCharacter=False):
    label_list = []
    features_list = []
    with tf.gfile.GFile(file_name, 'r') as f:
        for line in f:
            fields = line.strip().split(sep)
            if len(fields) != 2:  # skip malformed lines
                continue
            label, features = fields
            label_list.append(label)
            if isCharacter:
                # Character-level: split the feature string into single characters.
                features_list.append(list(features))
            else:
                # Word-level: features are separated by sep1.
                features_list.append(features.split(sep1))
    # Shuffle labels and features with the same permutation so pairs stay aligned.
    indices = np.random.permutation(np.arange(len(features_list)))
    label_list = np.array(label_list)[indices]
    features_list = np.array(features_list)[indices]
    return label_list, features_list
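
# A minimal split sketch (assumption: FLAGS.test_size, declared above but
# unused in the original file, is meant for a train/test split of the
# already-shuffled arrays returned by load_data; this helper is not part
# of the original script).
def split_train_test(labels, features, test_size=FLAGS.test_size):
    split = int(len(labels) * (1 - test_size))
    return (labels[:split], features[:split]), (labels[split:], features[split:])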

# Build the word -> id mapping (vocabulary) and the label -> id mapping.
def build_word_dic(words_list, label_list, vocab_size=5000):
    word_dic = dict()
    word_dic['pad'] = 0  # padding token
    word_dic['unk'] = 1  # out-of-vocabulary token
    all_words = []
    for words in words_list:
        all_words.extend(words)
    # Keep the vocab_size - 2 most frequent words ('pad' and 'unk' take two slots).
    counter = collections.Counter(all_words).most_common(vocab_size - 2)
    words, _ = list(zip(*counter))
    for word in words:
        word_dic[word] = len(word_dic)
    label_set = set(label_list)
    label_dic = dict()
    for label in label_set:
        label_dic[label] = len(label_dic)
    return words, word_dic, label_set, label_dic
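
# Minimal end-to-end sketch (assumption: FLAGS.train_file points at a file in
# the "<label><sep><features>" format that load_data expects; the sample
# mapping below is illustrative, not part of the original script).
if __name__ == '__main__':
    labels, features = load_data(FLAGS.train_file)
    words, word_dic, label_set, label_dic = build_word_dic(features, labels)
    # Map the first sample to ids; unseen words fall back to 'unk'.
    sample_ids = [word_dic.get(w, word_dic['unk']) for w in features[0]]
    print(sample_ids, label_dic[labels[0]])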