Commit b31ed8f

add l2norm embeddings (part of fixnorm), from Transformer Without Tears paper

lucidrains committed Feb 13, 2022
1 parent 75d723b commit b31ed8f
Showing 3 changed files with 51 additions and 11 deletions.
README.md — 20 changes: 19 additions & 1 deletion
@@ -241,7 +241,25 @@ model = TransformerWrapper(
         dim = 512,
         depth = 6,
         heads = 8,
-        use_scalenorm = True # set to true to use for all layers
+        use_scalenorm = True # set to True to use for all layers
     )
 )
 ```
+
+You can also use the l2 normalized embeddings proposed as part of `fixnorm`. I have found it leads to improved convergence when paired with small initialization (proposed by <a href="https://github.com/BlinkDL">BlinkDL</a>). The small initialization will be taken care of as long as `l2norm_embed` is set to `True`.
+
+```python
+import torch
+from x_transformers import TransformerWrapper, Decoder, Encoder
+
+model = TransformerWrapper(
+    num_tokens = 20000,
+    max_seq_len = 1024,
+    l2norm_embed = True, # set this to True for l2 normalized embedding + small init
+    attn_layers = Decoder(
+        dim = 512,
+        depth = 6,
+        heads = 8
+    )
+)
+```
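Note: the new `TokenEmbedding` and `AbsolutePositionalEmbedding` code in x_transformers.py below calls an `l2norm` helper that is not shown in the visible hunks. A minimal sketch of what such a helper presumably does (unit-normalize along the feature dimension) is:

```python
import torch
import torch.nn.functional as F

def l2norm(t):
    # scale each vector along the last (feature) dimension to unit L2 norm
    return F.normalize(t, p = 2, dim = -1)
```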
setup.py — 2 changes: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.23.0',
+  version = '0.24.1',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py — 40 changes: 31 additions & 9 deletions
@@ -102,19 +102,33 @@ class ReluSquared(nn.Module):
     def forward(self, x):
         return F.relu(x) ** 2
 
+# embedding
+
+class TokenEmbedding(nn.Module):
+    def __init__(self, dim, num_tokens, l2norm_embed = False):
+        super().__init__()
+        self.l2norm_embed = l2norm_embed
+        self.emb = nn.Embedding(num_tokens, dim)
+
+    def forward(self, x):
+        token_emb = self.emb(x)
+        return l2norm(token_emb) if self.l2norm_embed else token_emb
+
 # positional embeddings
 
 class AbsolutePositionalEmbedding(nn.Module):
-    def __init__(self, dim, max_seq_len):
+    def __init__(self, dim, max_seq_len, l2norm_embed = False):
         super().__init__()
-        self.scale = dim ** -0.5
+        self.scale = dim ** -0.5 if not l2norm_embed else 1.
+        self.l2norm_embed = l2norm_embed
         self.emb = nn.Embedding(max_seq_len, dim)
 
     def forward(self, x):
         n = torch.arange(x.shape[1], device = x.device)
         pos_emb = self.emb(n)
         pos_emb = rearrange(pos_emb, 'n d -> () n d')
-        return pos_emb * self.scale
+        pos_emb = pos_emb * self.scale
+        return l2norm(pos_emb) if self.l2norm_embed else pos_emb
 
 class FixedPositionalEmbedding(nn.Module):
     def __init__(self, dim):
@@ -985,7 +999,8 @@ def __init__(
         emb_dropout = 0.,
         num_memory_tokens = None,
         tie_embedding = False,
-        use_pos_emb = True
+        use_pos_emb = True,
+        l2norm_embed = False
     ):
         super().__init__()
         assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
@@ -997,8 +1012,10 @@ def __init__(
         self.max_mem_len = max_mem_len
         self.shift_mem_down = shift_mem_down
 
-        self.token_emb = nn.Embedding(num_tokens, emb_dim)
-        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+        self.l2norm_embed = l2norm_embed
+        self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed)
+        self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
+
         self.emb_dropout = nn.Dropout(emb_dropout)
 
         self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
@@ -1016,7 +1033,12 @@ def __init__(
             self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
 
     def init_(self):
-        nn.init.kaiming_normal_(self.token_emb.weight)
+        if self.l2norm_embed:
+            nn.init.normal_(self.token_emb.emb.weight, std = 1e-5)
+            nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5)
+            return
+
+        nn.init.kaiming_normal_(self.token_emb.emb.weight)
 
     def forward(
         self,
@@ -1029,8 +1051,8 @@ def forward(
         **kwargs
     ):
         b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
-        x = self.token_emb(x)
-        x = x + self.pos_emb(x)
+
+        x = self.token_emb(x) + self.pos_emb(x)
         x = self.emb_dropout(x)
 
         x = self.project_emb(x)
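As a quick standalone illustration (not code from this commit) of how the small initialization in `init_` and the l2 normalization interact: with `std = 1e-5` the raw embedding vectors are nearly zero-length, but after normalization every looked-up embedding has unit norm, so what gets learned is mainly the direction of each embedding.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

emb = nn.Embedding(100, 512)
nn.init.normal_(emb.weight, std = 1e-5)    # small init, as in init_() above

tokens = torch.randint(0, 100, (2, 16))
raw = emb(tokens)                          # tiny vectors: norm ~ 1e-5 * sqrt(512), about 2e-4
unit = F.normalize(raw, p = 2, dim = -1)   # l2norm: every embedding now has norm 1

print(raw.norm(dim = -1).mean())           # ~ 2e-4
print(unit.norm(dim = -1).mean())          # ~ 1.0
```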
