
Commit

add some Normformer proposed changes
lucidrains committed Oct 20, 2021
1 parent 5f37fc0 commit a17ebb3
Showing 4 changed files with 77 additions and 2 deletions.
63 changes: 63 additions & 0 deletions README.md
@@ -812,6 +812,58 @@ model = TransformerWrapper(
)
```

### Normformer

<img src="./images/normformer.png" width="400px"/>

This <a href="https://openreview.net/forum?id=GMYWzWztDx5">paper</a> uncovers an issue with pre-norm transformers where gradient magnitudes become mismatched between the earlier and later layers. They propose 4 changes, of which I will be offering 3.

The first change is a learned per-head scaling of the attention output, applied right after aggregating the values. My experiments show a slight improvement in convergence.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 6,
heads = 8,
attn_head_scale = True # set this to True
)
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```

The second change is an extra layernorm right after the activation in the feedforward. I have also verified a slight improvement, at the cost of extra compute.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 6,
heads = 8,
ff_post_act_ln = True # set this to True
)
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```

One of the other two changes is a layernorm right after the output projection in attention. This is actually identical to the sandwich norm proposed by the CogView paper, so you can use it by simply setting `sandwich_norm = True`, although note that this also adds a post-norm to the feedforward layer.
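
A sketch of that setting, following the same pattern as the examples above (only the `sandwich_norm` flag differs from the earlier snippets):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        sandwich_norm = True # set this to True - post-norms both the attention and feedforward branches
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```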

Finally, I have tried the parameterized scaling of the residual branch in the feedforward pre-norm block, but noticed some slight instability, so I will hold off on adding that feature until I have investigated it a bit more.
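
For illustration only, here is a minimal sketch of the idea being described (a learned scale on the residual branch of a pre-norm feedforward block). It is written as standalone PyTorch rather than as a library feature, and the module and parameter names are made up for the example:

```python
import torch
from torch import nn

class ScaledResidualFeedForward(nn.Module):
    # pre-norm feedforward block whose residual branch is scaled by a learned parameter
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )
        # initialized to ones, so training starts out equivalent to the usual pre-norm block
        self.residual_scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return x + self.net(self.norm(x)) * self.residual_scale
```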

## Miscellaneous

Cross Attention
@@ -1209,4 +1261,15 @@ model(x, mask = mask) # (1, 1024, 100)
}
```

```bibtex
@inproceedings{anonymous2022normformer,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Anonymous},
booktitle = {Submitted to The Tenth International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=GMYWzWztDx5},
note = {under review}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
Binary file added images/normformer.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '0.20.1',
version = '0.20.2',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
14 changes: 13 additions & 1 deletion x_transformers/x_transformers.py
@@ -353,6 +353,7 @@ def __init__(
mult = 4,
glu = False,
relu_squared = False,
post_act_ln = False,
dropout = 0.,
zero_init_output = False
):
@@ -368,6 +369,7 @@ def __init__(

self.net = nn.Sequential(
project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
@@ -390,6 +392,7 @@ def __init__(
causal = False,
mask = None,
talking_heads = False,
head_scale = False,
collab_heads = False,
collab_compression = .3,
sparse_topk = None,
@@ -433,6 +436,11 @@ def __init__(
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

# head scaling
self.head_scale = head_scale
if head_scale:
self.head_scale_params = nn.Parameter(torch.ones(1, heads, 1, 1))

# explicit topk sparse attention
self.sparse_topk = sparse_topk

@@ -466,7 +474,7 @@ def forward(
prev_attn = None,
mem = None
):
b, n, _, h, talking_heads, collab_heads, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, x.device, exists(context)
b, n, _, h, talking_heads, collab_heads, head_scale, device, has_context = *x.shape, self.heads, self.talking_heads, self.collab_heads, self.head_scale, x.device, exists(context)
kv_input = default(context, x)

q_input = x
@@ -569,6 +577,10 @@ def forward(
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()

out = einsum('b h i j, b h j d -> b h i d', attn, v)

if head_scale:
out = out * self.head_scale_params

out = rearrange(out, 'b h n d -> b n (h d)')

if exists(self.to_v_gate):
