-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess.py
44 lines (38 loc) · 1.46 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import collections
import os
import random
import time
import re
from tqdm import tqdm
import torch
import torchtext.vocab as Vocab
import torch.utils.data as Data
import torch.nn.functional as F
from dataset import vocab
from RNNnet import BiRNN
# Hyperparameters passed positionally to BiRNN below (presumably embedding
# width, hidden size and layer count -- confirm against RNNnet.BiRNN).
embed_size, num_hiddens, num_layers = 100, 100, 2
net = BiRNN(vocab, embed_size, num_hiddens, num_layers)

# Directory handed to torchtext as the cache location for the GloVe files.
# NOTE(review): machine-specific absolute Windows path; adjust per machine.
cache_dir = 'C:\\Users\\46562\\Desktop\\Rnn datasets'

# Request vectors whose width is embed_size rather than repeating the literal
# 100: the pretrained matrix is later copied into net.embedding.weight, which
# only works when both widths agree.
glove_vocab = Vocab.GloVe(name='6B', dim=embed_size, cache=cache_dir)
def load_pretrained_embedding(words, pretrained_vocab):
    """Build an embedding matrix for ``words`` from pretrained vectors.

    Args:
        words: Sequence of tokens given in index-to-string (itos) order;
            the token at position ``i`` fills row ``i`` of the result.
        pretrained_vocab: Object exposing ``stoi`` (token -> row index)
            and ``vectors`` (indexable by row; each row has a ``shape``).

    Returns:
        A tensor of shape ``(len(words), dim)``, where ``dim`` is the
        length of the first pretrained vector. Rows for tokens missing
        from ``pretrained_vocab.stoi`` are left as zeros.
    """
    vectors = pretrained_vocab.vectors
    dim = vectors[0].shape[0]
    table = torch.zeros(len(words), dim)

    missing = 0  # number of out-of-vocabulary tokens
    for row, token in enumerate(words):
        try:
            # Copy the pretrained vector for this token into its row.
            table[row, :] = vectors[pretrained_vocab.stoi[token]]
        except KeyError:
            # Token unknown to the pretrained vocabulary: keep the zero row.
            missing += 1

    if missing:
        print("There are %d oov words." % missing)
    return table
# Overwrite the embedding layer's weights in place with the pretrained
# vectors looked up for every token in vocab.itos.
_pretrained_weights = load_pretrained_embedding(vocab.itos, glove_vocab)
net.embedding.weight.data.copy_(_pretrained_weights)
# The weights are loaded pretrained, so they are not updated during training.
net.embedding.weight.requires_grad = False