models.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_weights(m):
    """Xavier-initialize linear weights and uniformly initialize biases."""
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.uniform_(m.bias)
class SpanEmbedder(nn.Module):
    def __init__(self, config, device):
        super(SpanEmbedder, self).__init__()
        self.bert_hidden_size = config.bert_hidden_size
        self.with_width_embedding = config.with_mention_width
        self.use_head_attention = config.with_head_attention
        self.device = device
        self.dropout = config.dropout
        # Zero vector used to pad spans shorter than the longest span in the batch.
        self.padded_vector = torch.zeros(self.bert_hidden_size, device=device)
        # Scores each token in a span; the scores are softmax-normalized into
        # attention weights over the span's tokens.
        self.self_attention_layer = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.bert_hidden_size, config.hidden_layer),
            nn.ReLU(),
            nn.Linear(config.hidden_layer, 1)
        )
        self.self_attention_layer.apply(init_weights)
        # Learned embedding for the mention width, clamped to buckets 0-4.
        self.width_feature = nn.Embedding(5, config.embedding_dimension)

    def pad_continous_embeddings(self, continuous_embeddings):
        """Pad each span's variable-length token embeddings to a common length
        and return the padded tensor together with a binary mask."""
        max_length = max(len(v) for v in continuous_embeddings)
        padded_tokens_embeddings = torch.stack(
            [torch.cat((emb, self.padded_vector.repeat(max_length - len(emb), 1)))
             for emb in continuous_embeddings]
        )
        masks = torch.stack(
            [torch.cat(
                (torch.ones(len(emb), device=self.device),
                 torch.zeros(max_length - len(emb), device=self.device)))
             for emb in continuous_embeddings]
        )
        return padded_tokens_embeddings, masks
    def forward(self, start_end, continuous_embeddings, width):
        vector = start_end
        if self.use_head_attention:
            padded_tokens_embeddings, masks = self.pad_continous_embeddings(continuous_embeddings)
            attention_scores = self.self_attention_layer(padded_tokens_embeddings).squeeze(-1)
            # Zero out the scores of padded positions, then push them to a large
            # negative value so they receive ~0 weight after the softmax.
            attention_scores *= masks
            attention_scores = torch.where(attention_scores != 0, attention_scores,
                                           torch.tensor(-9e9, device=self.device))
            attention_scores = F.softmax(attention_scores, dim=1)
            # Attention-weighted sum of the span's token embeddings (head vector).
            weighted_sum = (attention_scores.unsqueeze(-1) * padded_tokens_embeddings).sum(dim=1)
            vector = torch.cat((vector, weighted_sum), dim=1)
        if self.with_width_embedding:
            width = torch.clamp(width, max=4)
            width_embedding = self.width_feature(width)
            vector = torch.cat((vector, width_embedding), dim=1)

        return vector
class SpanScorer(nn.Module):
    def __init__(self, config):
        super(SpanScorer, self).__init__()
        # Span representation: start and end token vectors plus the attention-weighted
        # head vector (3 * hidden size), optionally followed by the width embedding.
        self.input_layer = config.bert_hidden_size * 3
        if config.with_mention_width:
            self.input_layer += config.embedding_dimension
        self.mlp = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.input_layer, config.hidden_layer),
            nn.ReLU(),
            nn.Linear(config.hidden_layer, 1)
        )
        self.mlp.apply(init_weights)

    def forward(self, span_embedding):
        # Returns a single mention score per span embedding.
        return self.mlp(span_embedding)
class SimplePairWiseClassifier(nn.Module):
    def __init__(self, config):
        super(SimplePairWiseClassifier, self).__init__()
        # Size of a single span embedding, depending on which features are enabled.
        self.input_layer = config.bert_hidden_size * 3 if config.with_head_attention else config.bert_hidden_size * 2
        if config.with_mention_width:
            self.input_layer += config.embedding_dimension
        # The classifier sees [first, second, first * second], hence three copies.
        self.input_layer *= 3
        self.hidden_layer = config.hidden_layer
        self.pairwise_mlp = nn.Sequential(
            nn.Dropout(config.dropout),
            nn.Linear(self.input_layer, self.hidden_layer),
            nn.ReLU(),
            nn.Linear(self.hidden_layer, self.hidden_layer),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(self.hidden_layer, 1),
        )
        self.pairwise_mlp.apply(init_weights)

    def forward(self, first, second):
        # Score a pair of span embeddings by concatenating both vectors and their
        # element-wise product.
        return self.pairwise_mlp(torch.cat((first, second, first * second), dim=1))
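

# Minimal usage sketch (not part of the original module): the config values and
# tensor shapes below are assumptions chosen for illustration, based only on how
# the modules above read their config (bert_hidden_size, hidden_layer,
# embedding_dimension, dropout, with_mention_width, with_head_attention).
if __name__ == '__main__':
    from types import SimpleNamespace

    config = SimpleNamespace(bert_hidden_size=768, hidden_layer=1024,
                             embedding_dimension=20, dropout=0.3,
                             with_mention_width=True, with_head_attention=True)
    device = torch.device('cpu')

    embedder = SpanEmbedder(config, device)
    scorer = SpanScorer(config)
    pairwise = SimplePairWiseClassifier(config)

    # Two toy spans: concatenated start/end token vectors, per-token embeddings,
    # and span widths (in tokens).
    start_end = torch.randn(2, config.bert_hidden_size * 2)
    continuous_embeddings = [torch.randn(3, config.bert_hidden_size),
                             torch.randn(5, config.bert_hidden_size)]
    width = torch.tensor([3, 5])

    span_emb = embedder(start_end, continuous_embeddings, width)  # (2, 3 * 768 + 20)
    mention_scores = scorer(span_emb)                             # (2, 1)
    pair_scores = pairwise(span_emb[0:1], span_emb[1:2])          # (1, 1)
    print(span_emb.shape, mention_scores.shape, pair_scores.shape)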