forked from moon-hotel/TransformerTranslation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
152 lines (137 loc) · 7.64 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from copy import deepcopy
from config.config import Config
from model.TranslationModel import TranslationModel
from utils.data_helpers import LoadEnglishGermanDataset, my_tokenizer
import torch
import time
import os
import logging
class CustomSchedule(object):
    """Noam learning-rate schedule ("Attention Is All You Need", section 5.3).

    Each call to :meth:`step` computes, for the current step number ``s``
    (starting at 1)::

        lr = d_model ** -0.5 * min(s ** -0.5, s * warmup_steps ** -1.5)

    writes ``lr`` into every param group of ``optimizer`` (if one was given),
    advances the step counter, and returns ``lr`` as a Python float.

    :param d_model: model width used to scale the learning rate.
    :param warmup_steps: number of linear warm-up steps (must be > 0).
    :param optimizer: torch optimizer whose ``param_groups`` receive the lr;
        when ``None`` the schedule only computes and returns the value.
    """

    def __init__(self, d_model, warmup_steps=4000, optimizer=None):
        super(CustomSchedule, self).__init__()
        # Stored as a plain float so the computed lr is a float rather than a
        # 0-d tensor: a tensor lr is not accepted by every torch.optim code path
        # (e.g. foreach Adam without capturable=True in some releases).
        self.d_model = float(d_model)
        self.warmup_steps = warmup_steps
        self.steps = 1.
        self.optimizer = optimizer

    def step(self):
        """Compute the lr for the current step, apply it, advance, and return it."""
        arg1 = self.steps ** -0.5                            # decay branch
        arg2 = self.steps * (self.warmup_steps ** -1.5)      # linear warm-up branch
        self.steps += 1.
        lr = (self.d_model ** -0.5) * min(arg1, arg2)
        # Without an optimizer there is nothing to update; just report the value.
        if self.optimizer is not None:
            for group in self.optimizer.param_groups:
                group['lr'] = lr
        return lr
def accuracy(logits, y_true, PAD_IDX):
    """Token-level accuracy of argmax predictions, ignoring padding positions.

    :param logits: scores of shape [tgt_len, batch_size, tgt_vocab_size]
    :param y_true: gold token ids of shape [tgt_len, batch_size]
    :param PAD_IDX: token id that marks padding; those positions are excluded
    :return: tuple ``(acc, correct, total)`` where ``correct`` is the number of
        non-pad positions whose argmax prediction equals the gold id, ``total``
        is the number of non-pad positions, and ``acc = correct / total``
        (0.0 when ``total`` is 0, i.e. the target is entirely padding).
    """
    # Counting is order-independent, so no transpose to batch-major is needed;
    # both tensors are flattened in the same (tgt_len, batch_size) order.
    y_pred = logits.argmax(dim=-1).reshape(-1)
    y_true = y_true.reshape(-1)
    mask = y_true.ne(PAD_IDX)  # True at real tokens, False at padding
    correct = (y_pred.eq(y_true) & mask).sum().item()
    total = mask.sum().item()
    if total == 0:
        # All-padding target: avoid ZeroDivisionError.
        return 0.0, correct, total
    return float(correct) / total, correct, total
def train_model(config):
    """Train (or resume training) a TranslationModel on the configured corpora.

    Steps:
      1. Build the dataset helper with ``LoadEnglishGermanDataset`` and split the
         corpora into train / validation / test iterators.
      2. Construct ``TranslationModel`` with the German vocabulary as source and
         the English vocabulary as target. If ``<config.model_save_dir>/model.pkl``
         exists, load its parameters and continue training from them.
      3. Run ``config.epochs`` epochs of teacher-forced training with Adam; the
         learning rate is set by ``CustomSchedule`` before every update.
      4. After every even-numbered (0-based) epoch and after the final epoch,
         log validation accuracy and overwrite ``model.pkl`` with the current
         ``state_dict``.

    :param config: project configuration object. Attributes read here:
        ``train_corpus_file_paths``, ``val_corpus_file_paths``,
        ``test_corpus_file_paths``, ``batch_size``, ``min_freq``, ``d_model``,
        ``num_head``, ``num_encoder_layers``, ``num_decoder_layers``,
        ``dim_feedforward``, ``dropout``, ``model_save_dir``, ``device``,
        ``beta1``, ``beta2``, ``epsilon`` and ``epochs``.
    """
    logging.info("############载入数据集############")
    data_loader = LoadEnglishGermanDataset(config.train_corpus_file_paths,
                                           batch_size=config.batch_size,
                                           tokenizer=my_tokenizer,
                                           min_freq=config.min_freq)
    logging.info("############划分数据集############")
    # The test split is not used during training.
    train_iter, valid_iter, _test_iter = \
        data_loader.load_train_val_test_data(config.train_corpus_file_paths,
                                             config.val_corpus_file_paths,
                                             config.test_corpus_file_paths)
    logging.info("############初始化模型############")
    translation_model = TranslationModel(src_vocab_size=len(data_loader.de_vocab),
                                         tgt_vocab_size=len(data_loader.en_vocab),
                                         d_model=config.d_model,
                                         nhead=config.num_head,
                                         num_encoder_layers=config.num_encoder_layers,
                                         num_decoder_layers=config.num_decoder_layers,
                                         dim_feedforward=config.dim_feedforward,
                                         dropout=config.dropout)
    model_save_path = os.path.join(config.model_save_dir, 'model.pkl')
    if os.path.exists(model_save_path):
        # Load onto CPU first so a checkpoint written from a GPU run can also be
        # loaded on a CPU-only machine; the model is moved to config.device below.
        loaded_paras = torch.load(model_save_path, map_location='cpu')
        translation_model.load_state_dict(loaded_paras)
        logging.info("#### 成功载入已有模型,进行追加训练...")
    translation_model = translation_model.to(config.device)
    # Padding positions in the target do not contribute to the loss.
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=data_loader.PAD_IDX)
    # lr=0. is only a placeholder: CustomSchedule overwrites it before each update.
    optimizer = torch.optim.Adam(translation_model.parameters(),
                                 lr=0.,
                                 betas=(config.beta1, config.beta2), eps=config.epsilon)
    lr_scheduler = CustomSchedule(config.d_model, optimizer=optimizer)
    translation_model.train()
    num_batches = len(train_iter)
    for epoch in range(config.epochs):
        losses = 0
        start_time = time.time()
        for idx, (src, tgt) in enumerate(train_iter):
            src = src.to(config.device)  # [src_len, batch_size]
            tgt = tgt.to(config.device)
            # Teacher forcing: the decoder sees the target without its last token
            # and is trained to predict the target shifted left by one.
            tgt_input = tgt[:-1, :]  # decoder input, [tgt_len, batch_size]
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask \
                = data_loader.create_mask(src, tgt_input, config.device)
            logits = translation_model(
                src=src,  # encoder token ids, [src_len, batch_size]
                tgt=tgt_input,  # decoder token ids, [tgt_len, batch_size]
                # Encoder attention mask (per the original author, effectively
                # unused by the encoder).
                src_mask=src_mask,
                # Decoder attention mask hiding positions after the current one,
                # [tgt_len, tgt_len].
                tgt_mask=tgt_mask,
                src_key_padding_mask=src_padding_mask,  # masks padding in the encoder input
                tgt_key_padding_mask=tgt_padding_mask,  # masks padding in the decoder input
                memory_key_padding_mask=src_padding_mask)  # masks padding in encoder memory
            # logits: [tgt_len, batch_size, tgt_vocab_size]
            optimizer.zero_grad()
            tgt_out = tgt[1:, :]  # gold decoder outputs, [tgt_len, batch_size]
            # [tgt_len*batch_size, tgt_vocab_size] vs [tgt_len*batch_size]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            loss.backward()
            # Set this update's learning rate before applying the update.
            lr_scheduler.step()
            optimizer.step()
            losses += loss.item()
            acc, _, _ = accuracy(logits, tgt_out, data_loader.PAD_IDX)
            msg = f"Epoch: {epoch}, Batch[{idx}/{num_batches}], Train loss :{loss.item():.3f}, Train acc: {acc}"
            logging.info(msg)
        end_time = time.time()
        train_loss = losses / num_batches
        msg = f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s"
        logging.info(msg)
        # Also checkpoint after the final epoch; otherwise, when config.epochs is
        # even, the last epoch's training would never be saved.
        is_last_epoch = epoch == config.epochs - 1
        if epoch % 2 == 0 or is_last_epoch:
            acc = evaluate(config, valid_iter, translation_model, data_loader)
            logging.info(f"Accuracy on validation{acc:.3f}")
            # torch.save serializes synchronously, so no defensive copy is needed.
            torch.save(translation_model.state_dict(), model_save_path)
def evaluate(config, valid_iter, model, data_loader):
    """Return token-level accuracy of ``model`` over ``valid_iter``, ignoring padding.

    Runs teacher-forced forward passes under ``torch.no_grad()`` with the model
    in eval mode and accumulates the correct / total non-pad token counts
    reported by ``accuracy`` across all batches. The model's previous
    train/eval mode is restored afterwards, even if an exception is raised.

    :param config: object providing ``device``.
    :param valid_iter: iterable of ``(src, tgt)`` token-id tensor pairs.
    :param model: translation model, called with the same keyword arguments as
        in training.
    :param data_loader: object providing ``create_mask`` and ``PAD_IDX``.
    :return: correct / total as a float; 0.0 if there were no non-pad target tokens.
    """
    was_training = model.training
    model.eval()
    correct, totals = 0, 0
    try:
        with torch.no_grad():
            for src, tgt in valid_iter:
                src = src.to(config.device)
                tgt = tgt.to(config.device)
                tgt_input = tgt[:-1, :]  # decoder input: target without its last token
                src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = \
                    data_loader.create_mask(src, tgt_input, device=config.device)
                logits = model(src=src,  # encoder token ids
                               tgt=tgt_input,  # decoder token ids
                               # Encoder attention mask (per the original author,
                               # effectively unused by the encoder).
                               src_mask=src_mask,
                               # Hides positions after the current one in the decoder.
                               tgt_mask=tgt_mask,
                               src_key_padding_mask=src_padding_mask,  # encoder input padding
                               tgt_key_padding_mask=tgt_padding_mask,  # decoder input padding
                               memory_key_padding_mask=src_padding_mask)  # encoder memory padding
                tgt_out = tgt[1:, :]  # gold outputs, [tgt_len, batch_size]
                _, c, t = accuracy(logits, tgt_out, data_loader.PAD_IDX)
                correct += c
                totals += t
    finally:
        # Restore whatever mode the caller had, instead of forcing train mode.
        model.train(was_training)
    if totals == 0:
        # Empty iterator or all-padding targets: avoid ZeroDivisionError.
        return 0.0
    return float(correct) / totals
if __name__ == '__main__':
    # Script entry point: build a Config with no arguments and run training with it.
    train_model(Config())