trainer.py
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from model import SkipGramEmbeddings
from sgns_loss import SGNSLoss
from tqdm import tqdm
from datasets.pypi_lang import PyPILangDataset
from datasets.world_order import WorldOrderDataset
from torch.utils.tensorboard import SummaryWriter


class Trainer:

    def __init__(self, args):
        # Load data
        self.args = args
        self.writer = SummaryWriter(log_dir='./experiments/', flush_secs=3)
        # self.dataset = PyPILangDataset(args, examples_path='data/pypi_examples.pth', dict_path='data/pypi_dict.pth')
        self.dataset = WorldOrderDataset(args)
        self.vocab_size = len(self.dataset.dictionary)
        print("Finished loading dataset")
        self.dataloader = DataLoader(self.dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers)
        self.model = SkipGramEmbeddings(self.vocab_size, args.embedding_len).to(args.device)
        self.optim = optim.Adam(self.model.parameters(), lr=args.lr)
        self.sgns = SGNSLoss(self.dataset, self.model.word_embeds, self.args.device)
        # Add graph to tensorboard
        # self.writer.add_graph(self.model, next(iter(self.dataloader))[0])

    def train(self):
        print('Training on device: {}'.format(self.args.device))
        # Log embeddings before training as a baseline
        print('\nRandom embeddings:')
        for word in self.dataset.queries:
            print(f'word: {word} neighbors: {self.model.nearest_neighbors(word, self.dataset.dictionary)}')

        for epoch in range(self.args.epochs):
            print(f'Beginning epoch: {epoch + 1}/{self.args.epochs}')
            running_loss = 0.0  # testing_loss = 0.0
            global_step = epoch * len(self.dataloader)
            num_examples = 0
            for i, data in enumerate(tqdm(self.dataloader)):
                # Unpack data
                center, context = data
                center, context = center.to(self.args.device), context.to(self.args.device)
                # Remove accumulated gradients
                self.optim.zero_grad()
                # Get center/context embedding vectors
                center_embed, context_embed = self.model(center, context)
                # Calc loss: skip-gram negative sampling (SGNS)
                loss = self.sgns(center_embed, context_embed)
                # Backprop and update
                loss.backward()
                self.optim.step()
                # Keep track of loss
                running_loss += loss.item()
                global_step += 1
                num_examples += center.size(0)  # Last batch's size may not equal args.batch_size
                # TESTING LOSS
                # testing_loss += test_loss
                # Log at step
                # if global_step % self.args.log_step == 0:
                #     norm = (i + 1) * num_examples
                #     self.log_step(epoch, global_step, running_loss/norm, center, context)
            norm = (i + 1) * num_examples
            # self.log_and_save_epoch(epoch, running_loss / norm)
            self.log_step(epoch, global_step, running_loss / norm)  # , testing_loss / norm
            print('\nGRAD:', np.sum(self.model.word_embeds.weight.grad.clone().detach().cpu().numpy()))
        self.writer.close()

    def log_and_save_epoch(self, epoch, loss):
        # Visualize word embeddings in TensorBoard's projector
        self.writer.add_embedding(
            self.model.word_embeds.weight,
            global_step=epoch,
            tag=f'we_epoch_{epoch}',
        )
        # Save checkpoint
        print('Beginning to save checkpoint')
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optim.state_dict(),
            'loss': loss,
        }, f'epoch_{epoch}_ckpt.pth')
        print('Finished saving checkpoint')

    def log_step(self, epoch, global_step, loss):
        print('#############################################')
        print(f'EPOCH: {epoch} | STEP: {global_step} | LOSS {loss}')  # | TEST LOSS {test_loss}
        print('#############################################')
        # self.writer.add_scalar('train_loss', loss, global_step)
        # Log embeddings!
        print('\nLearned embeddings:')
        for word in self.dataset.queries:
            print(f'word: {word} neighbors: {self.model.nearest_neighbors(word, self.dataset.dictionary)}')
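

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical entry point showing how this Trainer might be driven.
# The argument names (batch_size, workers, embedding_len, lr, epochs, device)
# are the attributes this file reads off `args`; the defaults below are
# illustrative assumptions, not values taken from the project.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Train skip-gram (SGNS) word embeddings')
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--embedding_len', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    args = parser.parse_args()

    Trainer(args).train()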