-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
71 lines (53 loc) · 1.97 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# -*- coding: utf-8 -*-
"""
@Time : 2020/10/20 3:26 pm
@Auth : maker
@Email:[email protected]
"""
import argparse
import torch
import numpy as np
from torch import nn, optim
from torch.utils.data import DataLoader
from model import Model
from dataset import Dataset
def train(dataset, model, args):
model.train()
dataloader = DataLoader(
dataset,
batch_size=args.batch_size,
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(args.max_epochs):
state_h, state_c = model.init_state(args.sequence_length)
for batch, (x, y) in enumerate(dataloader):
optimizer.zero_grad()
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
loss = criterion(y_pred.transpose(1, 2), y)
state_h = state_h.detach()
state_c = state_c.detach()
loss.backward()
optimizer.step()
print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })
def predict(dataset, model, text, next_words=100):
words = text.split(' ')
model.eval()
state_h, state_c = model.init_state(len(words))
for i in range(0, next_words):
x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
y_pred, (state_h, state_c) = model(x, (state_h, state_c))
last_word_logits = y_pred[0][-1]
p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
word_index = np.random.choice(len(last_word_logits), p=p)
words.append(dataset.index_to_word[word_index])
return words
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)
args = parser.parse_args()
dataset = Dataset(args)
model = Model(dataset)
train(dataset, model, args)
print(predict(dataset, model, text='Knock knock. Whos there?'))