-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdecoder.py
281 lines (244 loc) · 11.2 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
import math
import unittest
import random
from typing import Optional
import numpy as np
import torch
from torch import nn
from torch.nn.init import xavier_uniform_
from multi_head_attention import MultiHeadAttention
from positional_encodings import SinusoidEncoding
from utils import construct_future_mask
class TransformerDecoder(nn.Module):
    """
    A stack of Transformer decoder blocks followed by a (bias-free) projection onto the vocabulary.

    Token ids are embedded, scaled by sqrt(hidden_dim), combined with sinusoidal positional encodings,
    passed through dropout and ``num_layers`` decoder blocks, and finally projected to unnormalized logits.
    """

    def __init__(
        self,
        embedding: torch.nn.Embedding,
        hidden_dim: int,
        ff_dim: int,
        num_heads: int,
        num_layers: int,
        vocab_size: int,
        dropout_p: float,
        tie_output_to_embedding: Optional[bool] = True,
    ):
        """
        :param embedding: Token embedding module. Its weight is reused as the output projection when
            ``tie_output_to_embedding`` is true, in which case it must have shape (vocab_size, hidden_dim).
        :param hidden_dim: Model dimensionality E.
        :param ff_dim: Inner dimensionality of each block's feed-forward sub-layer.
        :param num_heads: Number of attention heads per attention sub-layer.
        :param num_layers: Number of stacked decoder blocks.
        :param vocab_size: Output vocabulary size V.
        :param dropout_p: Dropout probability used after the input embeddings and inside every block.
        :param tie_output_to_embedding: If true, the output layer shares its weight with ``embedding``.
        :raises ValueError: If tying is requested but the embedding weight is not (vocab_size, hidden_dim).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.embed = embedding
        self.positional_encoding = SinusoidEncoding(hidden_dim)
        # Honour the configured probability (this was previously hard-coded to 0.1).
        self.dropout = nn.Dropout(p=dropout_p)
        self.decoder_blocks = nn.ModuleList(
            [
                TransformerDecoderBlock(hidden_dim, ff_dim, num_heads, dropout_p)
                for _ in range(num_layers)
            ]
        )
        self.output_layer = nn.Linear(hidden_dim, vocab_size, bias=False)
        # Note: a linear layer multiplies the input with the transpose of its weight matrix, so the
        # (V, E) embedding matrix can be used directly without transposing it here.
        if tie_output_to_embedding:
            if self.embed.weight.shape != self.output_layer.weight.shape:
                raise ValueError(
                    "Cannot tie output layer to embedding: embedding weight has shape "
                    f"{tuple(self.embed.weight.shape)}, expected ({vocab_size}, {hidden_dim})."
                )
            # Share the *same* Parameter object. Wrapping it in a new nn.Parameter would only share
            # storage: device/dtype moves (.to(), .cuda()) rebind each Parameter's data separately and
            # silently untie them, and the optimizer would see two independent parameters.
            self.output_layer.weight = self.embed.weight

    def _reset_parameters(self):
        """ Perform xavier weight initialization on every parameter with more than one dimension."""
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)

    def forward(
        self,
        input_tokens: torch.IntTensor,
        encoder_hidden_states: torch.Tensor,
        src_padding_mask: Optional[torch.BoolTensor] = None,
        future_mask: Optional[torch.BoolTensor] = None,
    ):
        """
        Performs one decoder forward pass given encoder hidden states, the decoder input tokens and attention masks.
        N = batch size
        S = source sequence length
        T = target sequence length
        E = embedding dimensionality
        V = vocabulary size
        :param input_tokens: Decoder input tokens. Shape: (N, T)
        :param encoder_hidden_states: The encoder's final (contextualized) token embeddings. Shape: (N, S, E)
        :param src_padding_mask: An attention mask to ignore pad-tokens in the source input. Shape (N, S)
        :param future_mask: An attention mask to ignore future-tokens in the target input. Shape (T, T)
        :return: Unnormalized logits over the vocabulary for every token in the batch. Shape (N, T, V)
        """
        # (batch_size, sequence_length, hidden_dim); embeddings are scaled by sqrt(E) before adding positions
        x = self.embed(input_tokens) * math.sqrt(self.hidden_dim)
        x = self.positional_encoding(x)
        x = self.dropout(x)
        for decoder_block in self.decoder_blocks:
            x = decoder_block(x, encoder_hidden_states, src_padding_mask, future_mask)
        # (batch_size, sequence_length, vocab_size)
        logits = self.output_layer(x)
        return logits
class TransformerDecoderBlock(nn.Module):
    """
    A single post-LayerNorm decoder block: masked self-attention, cross-attention over the encoder's
    hidden states, and a position-wise feed-forward network. Each sub-layer's output goes through
    dropout, is added to its input (residual connection) and is then layer-normalized.
    """

    def __init__(self, hidden_dim: int, ff_dim: int, num_heads: int, dropout_p: float):
        """
        :param hidden_dim: Model dimensionality E.
        :param ff_dim: Inner dimensionality of the feed-forward sub-layer.
        :param num_heads: Number of attention heads per attention sub-layer.
        :param dropout_p: Dropout probability applied to each sub-layer's output.
        """
        super().__init__()
        # Keep this construction order: it fixes the parameter registration order and therefore the
        # order in which parameters are visited by initialization code such as xavier resets.
        self.cross_mha = MultiHeadAttention(hidden_dim, num_heads)
        self.self_mha = MultiHeadAttention(hidden_dim, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(hidden_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, hidden_dim),
        )
        self.dropout1 = nn.Dropout(p=dropout_p)
        self.dropout2 = nn.Dropout(p=dropout_p)
        self.dropout3 = nn.Dropout(p=dropout_p)
        self.layer_norm1 = nn.LayerNorm(hidden_dim)
        self.layer_norm2 = nn.LayerNorm(hidden_dim)
        self.layer_norm3 = nn.LayerNorm(hidden_dim)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
        src_padding_mask: Optional[torch.BoolTensor] = None,
        future_mask: Optional[torch.BoolTensor] = None,
    ):
        """
        Performs one decoder *block* forward pass given final encoder hidden states, the previous block's output, and
        attention masks.
        N = batch size
        S = source sequence length
        T = target sequence length
        E = embedding dimensionality
        V = vocabulary size
        :param x: Previous decoder block's output. Shape: (N, T, E)
        :param encoder_hidden_states: The encoder's final (contextualized) token embeddings. Shape: (N, S, E)
        :param src_padding_mask: An attention mask to ignore pad-tokens in the source input. Shape (N, S)
        :param future_mask: An attention mask to ignore future-tokens in the target input. Shape (T, T)
        :return: Updated, contextualized token embeddings. Shape (N, T, E)
        """
        # Sub-modules are invoked via __call__ (not .forward) so registered hooks are honoured.

        # Self attention (with future masking during training)
        output = self.dropout1(self.self_mha(x, future_mask=future_mask))
        x = self.layer_norm1(x + output)

        # Cross or encoder-decoder attention
        output = self.dropout2(
            self.cross_mha(
                x,
                encoder_hidden_states=encoder_hidden_states,
                src_padding_mask=src_padding_mask,
            )
        )
        x = self.layer_norm2(x + output)

        # Feed forward layers
        output = self.dropout3(self.feed_forward(x))
        x = self.layer_norm3(x + output)
        return x
class TestTransformerDecoder(unittest.TestCase):
    """Simulates greedy autoregressive decoding with seeded, randomly initialized decoders."""

    def _greedy_decode(
        self,
        num_layers: int,
        tie_output_to_embedding: bool,
        bos_token_id: int,
        num_steps: int = 3,
    ) -> torch.Tensor:
        """
        Build a seeded, xavier-initialized decoder in evaluation mode and run ``num_steps`` greedy decoding
        steps against fake encoder hidden states. At every step, check the output shape and that no position
        puts all of its softmax probability mass on a single token.

        :param num_layers: Number of decoder blocks.
        :param tie_output_to_embedding: Whether the output projection is tied to the embedding.
        :param bos_token_id: Token id every sequence starts with.
        :param num_steps: Number of decoding steps (forward passes) to run.
        :return: The final decoder input, i.e. the BOS token followed by every prediction. Shape (N, num_steps + 1)
        """
        # Seeding must happen before any tensor/module creation so the random draws are reproducible.
        seed = 0
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        batch_size = 2
        src_seq_len = 10
        hidden_dim = 512
        vocab_size = 2000
        num_heads = 8
        with torch.no_grad():
            # Prepare fake encoder hidden states and padding masks
            encoder_output = torch.randn((batch_size, src_seq_len, hidden_dim))
            src_padding_mask = torch.BoolTensor(
                [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]
            )

            # Initialize the decoder, perform xavier init and set to evaluation mode
            decoder = TransformerDecoder(
                embedding=torch.nn.Embedding(vocab_size, hidden_dim),
                hidden_dim=hidden_dim,
                ff_dim=2048,
                num_heads=num_heads,
                num_layers=num_layers,
                dropout_p=0.1,
                vocab_size=vocab_size,
                tie_output_to_embedding=tie_output_to_embedding,
            )
            decoder._reset_parameters()
            decoder.eval()

            # Start from the BOS token, perform a decoding step, take the argmax over the last token's
            # logits and iteratively feed the input + prediction back in. A single token needs no future mask.
            decoder_input = torch.IntTensor([[bos_token_id], [bos_token_id]])
            future_mask = None
            for i in range(num_steps):
                decoder_output = decoder(
                    decoder_input,
                    encoder_output,
                    src_padding_mask=src_padding_mask,
                    future_mask=future_mask,
                )
                predicted_tokens = torch.argmax(
                    decoder_output[:, -1, :], dim=-1
                ).unsqueeze(1)
                decoder_input = torch.cat((decoder_input, predicted_tokens), dim=-1)
                future_mask = construct_future_mask(decoder_input.shape[1])

                self.assertEqual(decoder_output.shape, (batch_size, i + 1, vocab_size))
                # softmax entropy should not be 0: check probabilities, not the raw logits
                probabilities = torch.softmax(decoder_output, dim=-1)
                self.assertFalse(torch.any(probabilities == 1).item())
        return decoder_input

    def test_one_layer_transformer_decoder_inference(self):
        """
        Test three forward passes, simulating three greedy decoding inference steps, with a single layer
        whose output projection is tied to the embedding matrix.
        """
        bos_token_id = 1
        decoder_input = self._greedy_decode(
            num_layers=1, tie_output_to_embedding=True, bos_token_id=bos_token_id
        )
        # With only one decoder layer the predicted tokens will always be the input token ids. This happens
        # only when the final linear transformation is tied to the (transpose of the) embedding matrix.
        # The input embedding is barely transformed due to the residual connections, so its final
        # "contextualized" embedding has the highest dot product with its own original embedding vector in
        # the pre-softmax weight matrix (i.e. the embedding matrix), because the two are still very similar.
        # This can be avoided by 1) scaling up the memory states, probably because this adds enough random
        # noise through cross-attention for the contextualized embedding to diverge from the input embedding,
        # 2) increasing the number of layers, again adding more and more "noise", or 3) removing the last
        # residual connection after the feed forward layers. In practice, however, this is not an issue:
        # training will take care of it.
        self.assertTrue(torch.all(decoder_input == bos_token_id).item())

    def test_multi_layer_transformer_decoder_inference(self):
        """
        Test three forward passes, simulating three greedy decoding inference steps, with six layers and
        an untied output projection.
        """
        bos_token_id = 10
        decoder_input = self._greedy_decode(
            num_layers=6, tie_output_to_embedding=False, bos_token_id=bos_token_id
        )
        # With enough layers and an independent output projection, predictions diverge from the BOS token.
        self.assertFalse(torch.all(decoder_input == bos_token_id).item())
if __name__ == "__main__":
    # Run this module's unittest test cases when the file is executed as a script.
    unittest.main()