"""
Code Author: Olivia Lee
Proposed Approach: Train a separate neural network to predict the word embedding given
the definition, using a dictionary of common words as the input and the word embeddings
already in the model as the output.
This approach finetunes the pretrained RoBERTa Masked Language Model (bidirectional).
"""
import json
import torch
from torch.utils.data import random_split, DataLoader
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from transformers import RobertaForMaskedLM, RobertaTokenizer
from dataset import JSonDataset
# https://github.com/huggingface/transformers/issues/1458
# RoBERTa
# roberta-large to avoid truncating?
roberta_pt_model = RobertaForMaskedLM.from_pretrained('roberta-base', output_hidden_states=True, is_decoder=False) # or any other checkpoint
word_embeddings = roberta_pt_model.get_input_embeddings() # Word Token Embeddings # roberta_pt_model.embeddings.word_embeddings.weight
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
roberta_pt_model.resize_token_embeddings(len(tokenizer))
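

# The sketch below is illustrative only and is not called anywhere in this
# script: it shows how a definition string is mapped to a predicted embedding
# (the last layer's first-token hidden state), mirroring what train() fits and
# learn_urban() applies. The helper name is an assumption made for this example.
def _example_predict_embedding(definition):
    encoded = tokenizer(definition, padding='max_length', truncation=True, max_length=512, return_tensors='pt')
    with torch.no_grad():
        out = roberta_pt_model(input_ids=encoded['input_ids'], attention_mask=encoded['attention_mask'])
    # hidden_states[-1] has shape (1, sequence_length, 768); keep the <s> token's vector
    return out['hidden_states'][-1][:, 0, :]  # shape (1, 768)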

def split_data(dataset):
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    # val_size, test_size = int(0.1 * len(dataset)), len(dataset) - train_size - val_size
    train_set, val_set = random_split(dataset, [train_size, val_size])
    train_dl = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=2, pin_memory=True)
    val_dl = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)
    return train_dl, val_dl

def train(device, timestamp, tb_writer, lr=0.00003, eps=3, batch_size=16):  # eps = number of epochs; optimizer steps are accumulated over batch_size samples
    common_data = JSonDataset('datasets/dict_wn.json', 'roberta', tokenizer, word_embeddings)
    train_dl, val_dl = split_data(common_data)
    model = roberta_pt_model
    model.to(device)
    loss_fn = torch.nn.MSELoss()  # alternative: torch.nn.CosineEmbeddingLoss
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_vloss, trained_model_path = float('inf'), None
    for ep in range(eps):
        print('EPOCH {}:'.format(ep + 1))
        # One pass through data
        model.train(True)
        running_loss, avg_loss = 0.0, 0.0
        for i, data in enumerate(train_dl):
            input, label = data  # input = tokenized+padded defn, label = ground truth pretrained embedding
            input['input_ids'] = input['input_ids'].squeeze(dim=1).to(device)
            input['attention_mask'] = input['attention_mask'].to(device)
            label = label.to(device)
            outputs = model(input_ids=input['input_ids'], attention_mask=input['attention_mask'])  # odict_keys(['logits', 'past_key_values', 'hidden_states'])
            # outputs['hidden_states'] is a tuple of torch.FloatTensor, each of shape (batch_size, sequence_length, hidden_size)
            last_hidden_state = (outputs['hidden_states'][-1].squeeze())[0].unsqueeze(dim=0)
            # Sometimes the last hidden state has size [1] and sometimes label.size() is [1, 1, 768]; the cause is unclear, so mismatched samples are skipped below
            if last_hidden_state.size() == torch.Size([1]):
                continue
            elif label.size() == torch.Size([1, 1, 768]):
                label = label.squeeze(dim=0)
            if last_hidden_state.size() != label.size():  # torch.Size([1, 768])
                continue  # skip instead of asserting so as not to crash training
            loss = loss_fn(last_hidden_state, label)
            loss.backward()
            if (i + 1) % batch_size == 0:  # Sub-batching: step the optimizer every batch_size samples
                optimizer.step()
                optimizer.zero_grad()
            # Logging
            running_loss += loss.item()
            if i % 100 == 99:
                avg_loss = running_loss / 100  # average loss over the last 100 samples
                print('  batch {} loss: {}'.format(i + 1, avg_loss))
                tb_x = ep * len(train_dl) + i + 1
                tb_writer.add_scalar('Loss/train', avg_loss, tb_x)
                running_loss = 0.
        # One pass over the validation set
        model.train(False)
        running_vloss, val_count = 0.0, 0
        with torch.no_grad():
            for i, vdata in enumerate(val_dl):
                vinputs, vlabels = vdata
                vinputs['input_ids'] = vinputs['input_ids'].squeeze(dim=1).to(device)
                vinputs['attention_mask'] = vinputs['attention_mask'].to(device)
                vlabels = vlabels.to(device)
                voutputs = model(input_ids=vinputs['input_ids'], attention_mask=vinputs['attention_mask'])
                vlast_hidden_state = (voutputs['hidden_states'][-1].squeeze())[0].unsqueeze(dim=0)
                if vlast_hidden_state.size() == torch.Size([1]):
                    continue
                elif vlabels.size() == torch.Size([1, 1, 768]):
                    vlabels = vlabels.squeeze(dim=0)
                if vlast_hidden_state.size() != vlabels.size():  # torch.Size([1, 768])
                    continue  # skip instead of asserting so as not to crash validation
                vloss = loss_fn(vlast_hidden_state, vlabels)
                running_vloss += vloss
                val_count += 1
        avg_vloss = running_vloss / val_count
        print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
        # Log the running loss averaged per batch
        # for both training and validation
        tb_writer.add_scalars('Training vs. Validation Loss',
                              { 'Training' : avg_loss, 'Validation' : avg_vloss },
                              ep + 1)
        tb_writer.flush()
        # Track best performance, and save the model's state
        if avg_vloss < best_vloss:
            best_vloss = avg_vloss
            model_path = 'roberta_model_{}_{}'.format(timestamp, ep)
            torch.save(model.state_dict(), model_path)
            trained_model_path = model_path
    return trained_model_path

def learn_urban(device, trained_model_path):
    print(trained_model_path)
    model = roberta_pt_model
    model.load_state_dict(torch.load(trained_model_path))
    model.to(device)
    model.eval()
    with torch.no_grad():
        counter = 0
        with open('datasets/urban_common_100up.json', "r") as f:
            data = json.load(f)
            for entry in data:
                word, defn, upv, downv = entry['lowercase_word'], entry['definition'].lower(), int(entry["thumbs_up"]), int(entry["thumbs_down"])
                # data has been preprocessed
                # if (len(word.split(' ')) > 1) or (downv > upv) or (upv < 1000): continue # skip phrases, words with more downvotes than upvotes, or too few upvotes
                # if len(tokenizer(word, return_tensors='pt')['input_ids'][0]) == 1: continue # skip words that are common but in UD (naive test)
                # input is the tokenized + padded definition
                input = tokenizer(defn, padding='max_length', truncation=True, max_length=512, return_tensors="pt")
                input['input_ids'] = input['input_ids'].squeeze(dim=1).to(device)
                input['attention_mask'] = input['attention_mask'].to(device)
                outputs = model(input_ids=input['input_ids'], attention_mask=input['attention_mask'])
                # predicted word embedding = first-token hidden state of the last layer
                last_hidden_state = (outputs['hidden_states'][-1].squeeze())[0].unsqueeze(dim=0)
                if tokenizer.add_tokens(word) == 0:
                    continue  # word already in the vocabulary; skip so the last added embedding is not overwritten
                model.resize_token_embeddings(len(tokenizer))
                model.get_input_embeddings().weight.data[-1] = last_hidden_state
                counter += 1
    torch.save(model.state_dict(), 'roberta_final_model')
    with open('roberta_tokenizer_vocab.json', 'w') as fp:
        json.dump(tokenizer.get_vocab(), fp)
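

# Illustrative sketch, not part of the original pipeline: how the weights saved by
# learn_urban() could be reloaded later. The function name and the explicit
# vocab_size argument are assumptions for this example; vocab_size must equal the
# extended tokenizer's vocabulary size, otherwise load_state_dict() will fail on
# the resized embedding matrix.
def _example_reload_final_model(vocab_size):
    reloaded = RobertaForMaskedLM.from_pretrained('roberta-base', output_hidden_states=True)
    reloaded.resize_token_embeddings(vocab_size)  # match the extended vocabulary
    reloaded.load_state_dict(torch.load('roberta_final_model'))
    return reloaded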

def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    writer = SummaryWriter('runs/roberta_trainer_{}'.format(timestamp))
    # PHASE 1: Train the model on a dictionary of common words to learn the relationship between definitions and embeddings
    trained_model_path = train(device, timestamp, writer)
    # PHASE 2: Add new word embeddings to RoBERTa given the Urban Dictionary definitions
    learn_urban(device, trained_model_path)


if __name__ == '__main__':
    main()