From b31ed8f104ef502b7a1d0c23a5b64207967122a1 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 13 Feb 2022 13:23:48 -0800
Subject: [PATCH] add l2norm embeddings (part of fixnorm), from Transformers
 Without Tears paper

---
 README.md                        | 20 +++++++++++++++++-
 setup.py                         |  2 +-
 x_transformers/x_transformers.py | 40 +++++++++++++++++++++++++-------
 3 files changed, 51 insertions(+), 11 deletions(-)

diff --git a/README.md b/README.md
index ee3c00cf..4c178483 100644
--- a/README.md
+++ b/README.md
@@ -241,7 +241,25 @@ model = TransformerWrapper(
         dim = 512,
         depth = 6,
         heads = 8,
-        use_scalenorm = True # set to true to use for all layers
+        use_scalenorm = True # set to True to use for all layers
     )
 )
 ```
+
+You can also use the l2-normalized embeddings proposed as part of `fixnorm` in the Transformers Without Tears paper. I have found that they lead to improved convergence when paired with a small embedding initialization (proposed by BlinkDL). The small initialization is taken care of automatically as long as `l2norm_embed` is set to `True`.
+
+```python
+import torch
+from x_transformers import TransformerWrapper, Decoder, Encoder
+
+model = TransformerWrapper(
+    num_tokens = 20000,
+    max_seq_len = 1024,
+    l2norm_embed = True,    # set this to True for l2-normalized embedding + small init
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8
+    )
+)
+```
diff --git a/setup.py b/setup.py
index 0bd814d7..970455de 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.23.0',
+  version = '0.24.1',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 1c7df9d5..e5f1dd20 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -102,19 +102,33 @@ class ReluSquared(nn.Module):
     def forward(self, x):
         return F.relu(x) ** 2
 
+# embedding
+
+class TokenEmbedding(nn.Module):
+    def __init__(self, dim, num_tokens, l2norm_embed = False):
+        super().__init__()
+        self.l2norm_embed = l2norm_embed
+        self.emb = nn.Embedding(num_tokens, dim)
+
+    def forward(self, x):
+        token_emb = self.emb(x)
+        return l2norm(token_emb) if self.l2norm_embed else token_emb
+
 # positional embeddings
 
 class AbsolutePositionalEmbedding(nn.Module):
-    def __init__(self, dim, max_seq_len):
+    def __init__(self, dim, max_seq_len, l2norm_embed = False):
         super().__init__()
-        self.scale = dim ** -0.5
+        self.scale = dim ** -0.5 if not l2norm_embed else 1.
+        self.l2norm_embed = l2norm_embed
         self.emb = nn.Embedding(max_seq_len, dim)
 
     def forward(self, x):
         n = torch.arange(x.shape[1], device = x.device)
         pos_emb = self.emb(n)
         pos_emb = rearrange(pos_emb, 'n d -> () n d')
-        return pos_emb * self.scale
+        pos_emb = pos_emb * self.scale
+        return l2norm(pos_emb) if self.l2norm_embed else pos_emb
 
 class FixedPositionalEmbedding(nn.Module):
     def __init__(self, dim):
@@ -985,7 +999,8 @@ def __init__(
         emb_dropout = 0.,
         num_memory_tokens = None,
         tie_embedding = False,
-        use_pos_emb = True
+        use_pos_emb = True,
+        l2norm_embed = False
     ):
         super().__init__()
         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -997,8 +1012,10 @@ def __init__(
         self.max_mem_len = max_mem_len
         self.shift_mem_down = shift_mem_down
 
-        self.token_emb = nn.Embedding(num_tokens, emb_dim)
-        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+        self.l2norm_embed = l2norm_embed
+        self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
+        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+
         self.emb_dropout = nn.Dropout(emb_dropout)
 
         self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1016,7 +1033,12 @@ def __init__(
             self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
 
     def init_(self):
-        nn.init.kaiming_normal_(self.token_emb.weight)
+        if self.l2norm_embed:
+            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
+            nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
+            return
+
+        nn.init.kaiming_normal_(self.token_emb.emb.weight)
 
     def forward(
         self,
@@ -1029,8 +1051,8 @@ def forward(
         **kwargs
     ):
         b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
-        x = self.token_emb(x)
-        x = x + self.pos_emb(x)
+
+        x = self.token_emb(x) + self.pos_emb(x)
         x = self.emb_dropout(x)
 
         x = self.project_emb(x)
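
Note: the new `TokenEmbedding` and the updated `AbsolutePositionalEmbedding` call an `l2norm` helper that is not included in this patch. Below is a minimal sketch of what such a helper could look like (an assumption for illustration, not the patch's actual definition), normalizing along the last (feature) dimension with `F.normalize`:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # hypothetical helper: rescale each embedding vector to unit L2 norm
    # along the last (feature) dimension
    return F.normalize(t, p = 2, dim = -1)

# quick sanity check: every normalized vector has unit norm
emb = torch.randn(2, 8, 512)
assert torch.allclose(l2norm(emb).norm(dim = -1), torch.ones(2, 8), atol = 1e-5)
```

With `l2norm_embed = True` the patch also fixes the positional embedding `scale` to `1.`; multiplying by a positive constant before L2 normalization does not change the result, so the usual `dim ** -0.5` scaling is dropped in that case.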