-
Notifications
You must be signed in to change notification settings - Fork 146
/
Copy pathmodels.py
384 lines (330 loc) · 13.8 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
from logging import getLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
from codegen_sources.model.src.model.transformer import (
TransformerModel,
create_position_ids_from_input_ids,
)
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.models.roberta.modeling_roberta import RobertaClassificationHead
# Module logger; ``name=None`` selects the root logger (same as ``getLogger()``).
logger = getLogger(name=None)
class ModelConfig:
    """HuggingFace-style configuration built from a checkpoint's ``params``.

    Copies the entries of ``params`` needed to build a ``TransformerModel`` and
    exposes the HF-style aliases (``hidden_size``, ``hidden_dropout_prob``,
    ``num_attention_heads``, ...) that downstream task heads read.
    """

    def __init__(self, params, num_labels=None, dico=None, **kwargs) -> None:
        """Build the configuration.

        Args:
            params: subscriptable mapping of training parameters (e.g. the
                ``"params"`` entry of a checkpoint). Required keys are read
                with ``params[key]``.
            num_labels: number of output labels for task heads.
            dico: vocabulary; when ``None``, a placeholder list of
                ``n_words`` empty strings is used.
            **kwargs: forwarded to ``super().__init__``.
        """
        super().__init__(**kwargs)
        self.num_labels = num_labels
        self.n_langs = params["n_langs"]
        self.n_words = params["n_words"]
        # A dictionary is needed by the transformer constructor but not used
        # here, so a placeholder of the right size is enough.
        # TODO -> do it cleaner way for opensourcing
        self.dico = [""] * self.n_words if dico is None else dico
        self.vocab_size = self.n_words
        self.eos_index = params["eos_index"]
        self.pad_index = params["pad_index"]
        self.id2lang = params["id2lang"]
        self.lang2id = params["lang2id"]
        self.emb_dim_encoder = params["emb_dim_encoder"]
        self.emb_dim_decoder = params["emb_dim_decoder"]
        self.n_heads = params["n_heads"]
        self.n_layers_encoder = params["n_layers_encoder"]
        self.n_layers_decoder = params["n_layers_decoder"]
        self.dropout = params["dropout"]
        self.attention_dropout = params["attention_dropout"]
        self.sinusoidal_embeddings = params["sinusoidal_embeddings"]
        self.spans_emb_encoder = False
        self.gelu_activation = params["gelu_activation"]
        self.share_inout_emb = params["share_inout_emb"]
        # ``params`` is normally a plain dict, so these optional keys must be
        # looked up by key; ``getattr`` alone would always return the default.
        self.roberta_mode = (
            self._optional_param(params, "roberta_mode", False)
            or self._optional_param(params, "tokenization_mode", "") == "roberta"
        )
        # needed for some of the tasks
        self.hidden_size = self.emb_dim_encoder
        self.hidden_dropout_prob = self.dropout
        self.num_attention_heads = self.n_heads
        self.torchscript = False

    @staticmethod
    def _optional_param(params, key, default):
        """Return ``params[key]``, else attribute ``key``, else ``default``.

        Supports both mapping-like params and namespace-like objects.
        """
        try:
            return params[key]
        except (KeyError, TypeError):
            return getattr(params, key, default)

    @classmethod
    def from_pretrained(
        cls, config_path, cache_dir=None, num_labels=None, finetuning_task=None
    ):
        """Build a config from the ``"params"`` entry of a saved checkpoint.

        ``cache_dir`` and ``finetuning_task`` are accepted for HF API
        compatibility and are ignored.

        Raises:
            AssertionError: if the file does not exist or has no ``"params"``.
        """
        assert os.path.exists(
            config_path
        ), f"cannot reload config : cannot find {config_path}"
        print(config_path)
        # Only the params are read, so the checkpoint's tensors can stay on CPU
        # (this also allows loading GPU-saved checkpoints on CPU-only hosts).
        reloaded = torch.load(config_path, map_location="cpu")
        assert (
            "params" in reloaded.keys()
        ), f"params not found in the file {config_path}"
        return cls(reloaded["params"], num_labels)
class Pooler(nn.Module):
    """Reduce a sequence of hidden states to one vector per example.

    The vector is ``tanh(dense(h0))``, where ``h0`` is the hidden state at
    position 0 of the sequence.
    """

    def __init__(self, config) -> None:
        super().__init__()
        width = config.emb_dim_encoder
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Pool ``hidden_states`` by taking its first token.

        ``hidden_states`` is indexed as ``[:, 0]``; assumed batch-first
        ``(bs, slen, dim)`` — TODO confirm with callers.
        """
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
class Model(nn.Module):
    """Wrap a ``TransformerModel`` behind a HuggingFace-like encoder interface.

    ``forward`` takes batch-first token ids and returns
    ``(sequence_output, pooled_output, None)``.
    """

    def __init__(self, config, lang, is_encoder, add_pooling_layer=True) -> None:
        """
        Args:
            config: model configuration; ``config.dico`` is passed to the
                transformer constructor.
            lang: language tag of the inputs (e.g. ``"java_obfuscated"``).
            is_encoder: build the transformer as an encoder (else decoder).
            add_pooling_layer: when false, ``forward`` returns ``None`` as the
                pooled output.
        """
        super().__init__()
        self.is_encoder = is_encoder
        self.transformer = TransformerModel(config, config.dico, self.is_encoder, True)
        self.lang = lang
        # embeddings is useful for finetuning in encoder-decoder archi.
        # It wraps the transformer's own embedding modules (shared, not copied).
        self.embeddings = ModelEmbeddings(
            self.transformer.roberta_mode,
            self.transformer.lang2id,
            self.transformer.pad_index,
            self.lang,
            self.transformer.embeddings,
            self.transformer.position_embeddings,
            self.transformer.lang_embeddings,
            self.transformer.layer_norm_emb,
            self.transformer.dropout,
        )
        self.pooler = Pooler(config) if add_pooling_layer else None

    def reload_model(self, model_path):
        """Load transformer weights from the checkpoint at ``model_path``.

        Uses the ``"encoder"``/``"decoder"`` entry (per ``is_encoder``) and
        falls back to ``"model"``; if neither exists, prints a message and
        leaves the weights unchanged. A ``module.`` prefix on every key
        (typically added by DataParallel/DDP wrappers) is stripped.

        Raises:
            AssertionError: if ``model_path`` does not exist.
        """
        assert os.path.exists(
            model_path
        ), f"cannot reload model : cannot find {model_path}"
        # load_state_dict copies values into the existing parameters, so the
        # checkpoint can be materialized on CPU regardless of where it was saved.
        reloaded = torch.load(model_path, map_location="cpu")
        model_type = "encoder" if self.is_encoder else "decoder"
        if model_type in reloaded:
            state_dict = reloaded[model_type]
        elif "model" in reloaded:
            state_dict = reloaded["model"]
        else:
            print(
                f"cannot find {model_type} nor model in the file {model_path}, do not reload model"
            )
            return
        prefix = "module."
        if all(k.startswith(prefix) for k in state_dict):
            state_dict = {k[len(prefix):]: v for k, v in state_dict.items()}
        self.transformer.load_state_dict(state_dict, strict=True)

    def forward(self, input_ids, attention_mask):
        """Encode a batch of token ids.

        Args:
            input_ids: batch-first token ids; padding is identified by
                ``self.transformer.pad_index``.
            attention_mask: ignored; only present to match the HF interface.

        Returns:
            ``(output, pooled_output, None)`` where ``output`` is
            ``(bs, slen, dim)`` and ``pooled_output`` is ``(bs, dim)``, or
            ``None`` when the model has no pooling layer.
        """
        del attention_mask  # not used, only here to match HF interface
        bs = input_ids.shape[0]
        lengths = (input_ids != self.transformer.pad_index).sum(dim=1).long()
        # The transformer is called on sequence-first tensors.
        input_ids = input_ids.transpose(0, 1)
        lang_id = (
            self.transformer.lang2id[self.lang]
            if self.lang in self.transformer.lang2id
            else self.transformer.lang2id[self.lang.split("_")[0]]
        )
        langs = torch.full_like(input_ids, lang_id)
        output = self.transformer(
            "fwd", x=input_ids, lengths=lengths, langs=langs, causal=False
        ).transpose(0, 1)
        assert output.shape[0] == bs and output.shape[2] == self.transformer.dim
        if self.pooler is None:
            return output, None, None
        pooled_output = self.pooler(output)
        assert (
            pooled_output.shape[0] == bs
            and pooled_output.shape[1] == self.transformer.dim
        )
        return output, pooled_output, None
class ModelJava(Model):
    """``Model`` preset for the ``"java_obfuscated"`` language tag."""

    def __init__(self, config, is_encoder) -> None:
        super().__init__(config=config, lang="java_obfuscated", is_encoder=is_encoder)

    @classmethod
    def from_pretrained(
        cls, model_path, from_tf=None, config=None, cache_dir=None, is_encoder=True
    ):
        """Build the model from ``config`` and load weights from ``model_path``.

        ``from_tf`` and ``cache_dir`` are accepted for HF API compatibility
        and are ignored.
        """
        model = cls(config, is_encoder)
        model.reload_model(model_path)
        return model
class ModelJavaFunc(Model):
    """``Model`` preset for the ``"java_obfuscated_func"`` language tag."""

    def __init__(self, config, is_encoder) -> None:
        super().__init__(
            config=config, lang="java_obfuscated_func", is_encoder=is_encoder
        )

    @classmethod
    def from_pretrained(
        cls, model_path, from_tf=None, config=None, cache_dir=None, is_encoder=True
    ):
        """Build the model from ``config`` and load weights from ``model_path``.

        ``from_tf`` and ``cache_dir`` are accepted for HF API compatibility
        and are ignored.
        """
        model = cls(config, is_encoder)
        model.reload_model(model_path)
        return model
class ModelPython(Model):
    """``Model`` preset for the ``"python_obfuscated"`` language tag."""

    def __init__(self, config, is_encoder) -> None:
        super().__init__(config=config, lang="python_obfuscated", is_encoder=is_encoder)

    @classmethod
    def from_pretrained(
        cls, model_path, from_tf=None, config=None, cache_dir=None, is_encoder=True
    ):
        """Build the model from ``config`` and load weights from ``model_path``.

        ``from_tf`` and ``cache_dir`` are accepted for HF API compatibility
        and are ignored.
        """
        model = cls(config, is_encoder)
        model.reload_model(model_path)
        return model
class ModelPythonFunc(Model):
    """``Model`` preset for the ``"python_obfuscated_func"`` language tag."""

    def __init__(self, config, is_encoder) -> None:
        super().__init__(
            config=config, lang="python_obfuscated_func", is_encoder=is_encoder
        )

    @classmethod
    def from_pretrained(
        cls, model_path, from_tf=None, config=None, cache_dir=None, is_encoder=True
    ):
        """Build the model from ``config`` and load weights from ``model_path``.

        ``from_tf`` and ``cache_dir`` are accepted for HF API compatibility
        and are ignored.
        """
        model = cls(config, is_encoder)
        model.reload_model(model_path)
        return model
class ModelEmbeddings(nn.Module):
    """Standalone embedding layer over batch-first token ids.

    Computes word + position + language embeddings, applies layer norm, then
    dropout. The given modules are stored as-is (no copies are made).
    """

    def __init__(
        self,
        roberta_mode,
        lang2id,
        pad_index,
        lang,
        word_embeddings,
        position_embeddings,
        lang_embeddings,
        layer_norm_emb,
        dropout,
    ):
        """
        Args:
            roberta_mode: if true, positions come from
                ``create_position_ids_from_input_ids(input_ids, pad_index)``;
                otherwise they are ``0 .. slen-1``.
            lang2id: mapping from language name to language id.
            pad_index: padding token id.
            lang: language tag of the inputs; if absent from ``lang2id``, the
                part before the first ``"_"`` is looked up instead.
            word_embeddings: token embedding module (its ``embedding_dim`` is
                the output width).
            position_embeddings: position embedding module.
            lang_embeddings: language embedding module.
            layer_norm_emb: normalization applied to the summed embeddings.
            dropout: dropout probability, applied only in training mode.
        """
        super().__init__()
        self.lang2id = lang2id
        self.lang = lang
        self.pad_index = pad_index
        self.word_embeddings = word_embeddings
        self.position_embeddings = position_embeddings
        self.lang_embeddings = lang_embeddings
        self.layer_norm_emb = layer_norm_emb
        self.dropout = dropout
        self.roberta_mode = roberta_mode

    def forward(self, input_ids):
        """Embed ``input_ids`` of shape ``(bs, slen)`` into ``(bs, slen, dim)``."""
        bs, slen = input_ids.size()
        lang_id = (
            self.lang2id[self.lang]
            if self.lang in self.lang2id
            else self.lang2id[self.lang.split("_")[0]]
        )
        langs = torch.full_like(input_ids, lang_id)
        if self.roberta_mode:
            positions = create_position_ids_from_input_ids(input_ids, self.pad_index)
        else:
            positions = torch.arange(
                slen, dtype=torch.long, device=input_ids.device
            ).unsqueeze(0)
        tensor = self.word_embeddings(input_ids)
        tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        tensor = tensor + self.lang_embeddings(langs)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        expected = (bs, slen, self.word_embeddings.embedding_dim)
        # The size goes in the assertion message itself rather than being
        # printed to stdout with an empty AssertionError.
        assert (
            tensor.size() == expected
        ), f"unexpected embedding size {tuple(tensor.size())}, expected {expected}"
        return tensor
class ClassificationHead(nn.Module):
    """Head for sentence-level classification tasks.

    Classifies from the features of the first token (``<s>``, the equivalent
    of ``[CLS]``): dropout -> dense -> tanh -> dropout -> output projection.
    """

    def __init__(self, config) -> None:
        super().__init__()
        hidden = config.hidden_size
        self.dense = nn.Linear(hidden, hidden)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden, config.num_labels)

    def forward(self, features, **kwargs):
        """Return logits of width ``num_labels`` for each example in ``features``."""
        cls_features = self.dropout(features[:, 0, :])
        hidden = torch.tanh(self.dense(cls_features))
        return self.out_proj(self.dropout(hidden))
class ModelForSequenceClassification(nn.Module):
    """Sequence classification/regression: a ``Model`` encoder plus a RoBERTa head."""

    def __init__(self, config, lang) -> None:
        """
        Args:
            config: model configuration; ``config.num_labels`` sets the head
                width. ``config.use_return_dict`` is set to ``False`` when
                missing (note: this mutates ``config``).
            lang: language tag passed to ``Model``.
        """
        super().__init__()
        self.num_labels = config.num_labels
        self.config = config
        if not hasattr(config, "use_return_dict"):
            config.use_return_dict = False
        self.model = Model(config, lang, True)
        self.classifier = RobertaClassificationHead(config)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Only ``input_ids``, ``attention_mask`` (ignored by ``Model``), ``labels``
        and ``return_dict`` are used; the other arguments exist for HF interface
        compatibility.

        Returns:
            A tuple ``([loss,] logits, ...)`` when ``return_dict`` is false,
            otherwise a ``SequenceClassifierOutput``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        outputs = self.model(input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)
        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression; MSELoss needs floating-point targets
                # of the same dtype as the predictions.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        # ``self.model`` may return a plain tuple (no ``hidden_states`` /
        # ``attentions`` attributes), so read them defensively.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
        )

    def save_pretrained(self, save_directory):
        """
        Save the model's ``state_dict`` to ``save_directory/pytorch_model.bin``.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
                If the path is an existing file, a message is printed and
                nothing is saved.

        NOTE(review): the file is a flat ``state_dict`` (keys such as
        ``model.transformer...``); ``Model.reload_model`` looks for
        ``"encoder"``/``"decoder"``/``"model"`` entries instead — confirm this
        file can actually be reloaded through the ``from_pretrained`` helpers.
        The model configuration is not saved.
        """
        WEIGHTS_NAME = "pytorch_model.bin"
        if os.path.isfile(save_directory):
            print(f"{save_directory} is a file, not a directory. Cannot save model.")
            return
        os.makedirs(save_directory, exist_ok=True)
        # Unwrap a ``.module`` wrapper (e.g. DataParallel) if present.
        model_to_save = self.module if hasattr(self, "module") else self
        state_dict = model_to_save.state_dict()
        output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
        torch.save(state_dict, output_model_file)
class ModelForSequenceClassificationPython(ModelForSequenceClassification):
    """Sequence classifier preset for the ``"python_obfuscated"`` language tag."""

    def __init__(self, config) -> None:
        super().__init__(config=config, lang="python_obfuscated")

    @classmethod
    def from_pretrained(cls, model_path, from_tf=None, config=None, cache_dir=None):
        """Build the classifier from ``config`` and load encoder weights.

        Only the encoder (``self.model``) is reloaded from ``model_path``; the
        classification head keeps its fresh initialization. ``from_tf`` and
        ``cache_dir`` are accepted for HF API compatibility and are ignored.
        """
        model = cls(config)
        model.model.reload_model(model_path)
        return model
class ModelForSequenceClassificationJava(ModelForSequenceClassification):
    """Sequence classifier preset for the ``"java_obfuscated"`` language tag."""

    def __init__(self, config) -> None:
        super().__init__(config=config, lang="java_obfuscated")

    @classmethod
    def from_pretrained(cls, model_path, from_tf=None, config=None, cache_dir=None):
        """Build the classifier from ``config`` and load encoder weights.

        Only the encoder (``self.model``) is reloaded from ``model_path``; the
        classification head keeps its fresh initialization. ``from_tf`` and
        ``cache_dir`` are accepted for HF API compatibility and are ignored.
        """
        model = cls(config)
        model.model.reload_model(model_path)
        return model