train.py
import argparse
import os
import time
import numpy as np
import torch
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from dataset import SequencePairDataset
from model.encoder_decoder import EncoderDecoder
from evaluate import evaluate
from utils import to_np, trim_seqs
from tensorboardX import SummaryWriter
from tqdm import tqdm
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction


def train(encoder_decoder: EncoderDecoder,
          train_data_loader: DataLoader,
          model_name,
          val_data_loader: DataLoader,
          keep_prob,
          teacher_forcing_schedule,
          lr,
          max_length):
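    # keep_prob is the probability of keeping an activation in the dropout step
    # (1.0 disables dropout); teacher_forcing_schedule gives, for each epoch,
    # the fraction of batches trained with teacher forcing.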
    global_step = 0
    loss_function = torch.nn.NLLLoss(ignore_index=0)
    optimizer = optim.Adam(encoder_decoder.parameters(), lr=lr)
    model_path = './model/' + model_name + '/'
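
    # One pass over the training data per entry in the schedule; teacher_forcing
    # is the fraction of batches trained with teacher forcing in this epoch.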
    for epoch, teacher_forcing in enumerate(teacher_forcing_schedule):
        print('epoch %i' % epoch, flush=True)

        for batch_idx, (input_idxs, target_idxs, input_tokens, target_tokens) in enumerate(tqdm(train_data_loader)):
            # input_idxs and target_idxs have dim (batch_size x max_len)
            # they are NOT sorted by length
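            # sort the batch by decreasing input length (presumably so the
            # encoder can pack the padded sequences) and trim padding columns
            # beyond the longest input in the batch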
            lengths = (input_idxs != 0).long().sum(dim=1)
            sorted_lengths, order = torch.sort(lengths, descending=True)

            input_variable = Variable(input_idxs[order, :][:, :max(lengths)])
            target_variable = Variable(target_idxs[order, :])

            optimizer.zero_grad()
            output_log_probs, output_seqs = encoder_decoder(input_variable,
                                                            list(sorted_lengths),
                                                            targets=target_variable,
                                                            keep_prob=keep_prob,
                                                            teacher_forcing=teacher_forcing)

            batch_size = input_variable.shape[0]
            flattened_outputs = output_log_probs.view(batch_size * max_length, -1)

            batch_loss = loss_function(flattened_outputs, target_variable.contiguous().view(-1))
            batch_loss.backward()
            optimizer.step()
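
            # score this batch with corpus-level BLEU: trim_seqs appears to cut
            # each predicted sequence at the end-of-sequence/padding token, and
            # references keep only the non-padding target ids; smoothing avoids
            # zero scores on short sequences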
            batch_outputs = trim_seqs(output_seqs)
            batch_targets = [[list(seq[seq > 0])] for seq in list(to_np(target_variable))]
            batch_bleu_score = corpus_bleu(batch_targets, batch_outputs, smoothing_function=SmoothingFunction().method1)
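
            # log sample responses for two fixed prompts, more frequently early
            # in training, so generation quality can be tracked in TensorBoard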
            if global_step < 10 or (global_step % 10 == 0 and global_step < 100) or (global_step % 100 == 0 and epoch < 2):
                input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('schedule', output_string, global_step=global_step)

                input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
                output_string = encoder_decoder.get_response(input_string)
                writer.add_text('cancel', output_string, global_step=global_step)

            if global_step % 100 == 0:
                writer.add_scalar('train_batch_loss', batch_loss, global_step)
                writer.add_scalar('train_batch_bleu_score', batch_bleu_score, global_step)

                for tag, value in encoder_decoder.named_parameters():
                    tag = tag.replace('.', '/')
                    writer.add_histogram('weights/' + tag, value, global_step, bins='doane')
                    writer.add_histogram('grads/' + tag, to_np(value.grad), global_step, bins='doane')

            global_step += 1
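
        # end of epoch: evaluate on the validation set, log embeddings and
        # sample responses, and checkpoint the model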
        val_loss, val_bleu_score = evaluate(encoder_decoder, val_data_loader)

        writer.add_scalar('val_loss', val_loss, global_step=global_step)
        writer.add_scalar('val_bleu_score', val_bleu_score, global_step=global_step)

        encoder_embeddings = encoder_decoder.encoder.embedding.weight.data
        encoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(encoder_embeddings, metadata=encoder_vocab, global_step=0, tag='encoder_embeddings')

        decoder_embeddings = encoder_decoder.decoder.embedding.weight.data
        decoder_vocab = encoder_decoder.lang.tok_to_idx.keys()
        writer.add_embedding(decoder_embeddings, metadata=decoder_vocab, global_step=0, tag='decoder_embeddings')

        input_string = "Amy, Please schedule a meeting with Marcos on Tuesday April 3rd. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('schedule', output_string, global_step=global_step)

        input_string = "Amy, Please cancel this meeting. Adam Kleczewski"
        output_string = encoder_decoder.get_response(input_string)
        writer.add_text('cancel', output_string, global_step=global_step)

        print('val loss: %.5f, val BLEU score: %.5f' % (val_loss, val_bleu_score), flush=True)

        torch.save(encoder_decoder, "%s%s_%i.pt" % (model_path, model_name, epoch))

        print('-' * 100, flush=True)


def main(model_name, use_cuda, batch_size, teacher_forcing_schedule, keep_prob, val_size, lr, decoder_type, vocab_limit, hidden_size, embedding_size, max_length, seed=42):
    model_path = './model/' + model_name + '/'

    # TODO: Change logging to reflect loaded parameters
    print("training %s with use_cuda=%s, batch_size=%i" % (model_name, use_cuda, batch_size), flush=True)
    print("teacher_forcing_schedule=", teacher_forcing_schedule, flush=True)
    print("keep_prob=%f, val_size=%f, lr=%f, decoder_type=%s, vocab_limit=%i, hidden_size=%i, embedding_size=%i, max_length=%i, seed=%i" % (keep_prob, val_size, lr, decoder_type, vocab_limit, hidden_size, embedding_size, max_length, seed), flush=True)
    if os.path.isdir(model_path):
        print("loading encoder and decoder from model_path", flush=True)
        encoder_decoder = torch.load(model_path + model_name + '.pt')

        print("creating training and validation datasets with saved languages", flush=True)
        train_dataset = SequencePairDataset(lang=encoder_decoder.lang,
                                            use_cuda=use_cuda,
                                            is_val=False,
                                            val_size=val_size,
                                            use_extended_vocab=(encoder_decoder.decoder_type == 'copy'))

        val_dataset = SequencePairDataset(lang=encoder_decoder.lang,
                                          use_cuda=use_cuda,
                                          is_val=True,
                                          val_size=val_size,
                                          use_extended_vocab=(encoder_decoder.decoder_type == 'copy'))

    else:
        os.mkdir(model_path)

        print("creating training and validation datasets", flush=True)
        train_dataset = SequencePairDataset(vocab_limit=vocab_limit,
                                            use_cuda=use_cuda,
                                            is_val=False,
                                            val_size=val_size,
                                            seed=seed,
                                            use_extended_vocab=(decoder_type == 'copy'))

        val_dataset = SequencePairDataset(lang=train_dataset.lang,
                                          use_cuda=use_cuda,
                                          is_val=True,
                                          val_size=val_size,
                                          seed=seed,
                                          use_extended_vocab=(decoder_type == 'copy'))

        print("creating encoder-decoder model", flush=True)
        encoder_decoder = EncoderDecoder(train_dataset.lang,
                                         max_length,
                                         embedding_size,
                                         hidden_size,
                                         decoder_type)

        torch.save(encoder_decoder, model_path + '/%s.pt' % model_name)

    if use_cuda:
        encoder_decoder = encoder_decoder.cuda()
    else:
        encoder_decoder = encoder_decoder.cpu()

    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_data_loader = DataLoader(val_dataset, batch_size=batch_size)
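
    # note: max_length is read back from the decoder rather than taken from the
    # CLI argument, presumably so it matches a model loaded from disk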
    train(encoder_decoder,
          train_data_loader,
          model_name,
          val_data_loader,
          keep_prob,
          teacher_forcing_schedule,
          lr,
          encoder_decoder.decoder.max_length)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Parse training parameters')
    parser.add_argument('--model_name', type=str,
                        help='the name of a subdirectory of ./model/ that '
                             'contains encoder and decoder model files')
    parser.add_argument('--epochs', type=int, default=50,
                        help='the number of epochs to train')
    parser.add_argument('--use_cuda', action='store_true',
                        help='flag indicating that cuda will be used')
    parser.add_argument('--batch_size', type=int, default=100,
                        help='number of examples in a batch')
    parser.add_argument('--teacher_forcing_fraction', type=float, default=0.5,
                        help='fraction of batches that will use teacher forcing during training')
    parser.add_argument('--scheduled_teacher_forcing', action='store_true',
                        help='Linearly decrease the teacher forcing fraction '
                             'from 1.0 to 0.0 over the specified number of epochs')
    parser.add_argument('--keep_prob', type=float, default=1.0,
                        help='Probability of keeping an element in the dropout step.')
    parser.add_argument('--val_size', type=float, default=0.1,
                        help='fraction of data to use for validation')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate.')
    parser.add_argument('--decoder_type', type=str, default='copy',
                        help="Allowed values 'copy' or 'attn'")
    parser.add_argument('--vocab_limit', type=int, default=5000,
                        help='When creating a new Language object the vocab '
                             'will be truncated to the most frequently '
                             'occurring words in the training dataset.')
    parser.add_argument('--hidden_size', type=int, default=256,
                        help='The number of RNN units in the encoder. 2x this '
                             'number of RNN units will be used in the decoder')
    parser.add_argument('--embedding_size', type=int, default=128,
                        help='Embedding size used in both encoder and decoder')
    parser.add_argument('--max_length', type=int, default=200,
                        help='Sequences will be padded or truncated to this size.')

    args = parser.parse_args()

    writer = SummaryWriter('./logs/%s_%s' % (args.model_name, str(int(time.time()))))
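    # `writer` is created at module scope and referenced directly inside train().

    # Build the per-epoch teacher forcing schedule: either a linear decay from
    # 1.0 toward 0.0 over the run, or a constant fraction for every epoch.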
    if args.scheduled_teacher_forcing:
        schedule = np.arange(1.0, 0.0, -1.0 / args.epochs)
    else:
        schedule = np.ones(args.epochs) * args.teacher_forcing_fraction

    main(args.model_name, args.use_cuda, args.batch_size, schedule, args.keep_prob, args.val_size, args.lr,
         args.decoder_type, args.vocab_limit, args.hidden_size, args.embedding_size, args.max_length)
    # main(str(int(time.time())), args.use_cuda, args.batch_size, schedule, args.keep_prob, args.val_size, args.lr, args.decoder_type, args.vocab_limit, args.hidden_size, args.embedding_size, args.max_length)