Commit

update so one can apply rmsnorm on the queries and keys, due to new S…
lucidrains committed Feb 13, 2023
1 parent ef3e24f commit a41bfb9
Showing 4 changed files with 59 additions and 8 deletions.
36 changes: 35 additions & 1 deletion README.md
@@ -1064,7 +1064,7 @@ model(x)

The last change is a layernorm right after the outward projection in attention. This is actually identical to the sandwich norm proposed by the CogView paper, so you can use this by simply setting `sandwich_norm = True`, although it would also add it to the feedforward layer.
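For reference, a minimal usage sketch in the same style as the repository's other examples (the `sandwich_norm` keyword is the one referenced above; the hyperparameters are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        sandwich_norm = True  # adds the post-branch layernorm to both attention and feedforward blocks
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```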

- ### Grouped Query-Key L2 Normalization
+ ### Cosine Sim Attention

<img src="./images/cosine-sim-attention.png" width="400px"></img>

@@ -1122,6 +1122,32 @@ x = torch.randint(0, 20000, (1, 1024))
model(x)
```

<img src="./images/qknorm-analysis.png" width="450px"></img>

Update: Google Brain has proven out cosine sim attention in <a href="https://arxiv.org/abs/2302.05442">a 22B parameter model</a>. Their paper includes analysis showing that the normalization resulted not only in extra stability, but also in better final results (due to less need to adjust the learning rate when increasing parameter count).

We are nearing the point of wiping out a source of transformer training instability with one simple intervention, in my opinion. The only slight difference from the paper is that they still keep a learned scale across the feature dimension (through their use of rmsnorm). Not sure how critical this is, but just to make sure we don't miss anything, I will include it here. You can use this by setting `qk_norm_dim_scale = True`.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 12,
heads = 8,
attn_qk_norm = True,
attn_qk_norm_dim_scale = True # set this to True, in addition to `attn_qk_norm = True`
)
)

x = torch.randint(0, 256, (1, 1024))
model(x)
```
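For intuition, here is a small standalone sketch of the computation being toggled above. This is not the library's internal code, just an illustration: queries and keys are l2-normalized along the head dimension and then multiplied by a learned per-feature gain (in the diff further down, these gains are initialized to `qk_norm_scale ** 0.5`), so no extra fixed softmax scale is needed.

```python
import torch
import torch.nn.functional as F

# illustrative only - names and shapes here are assumptions, not the library's internals
dim_head, heads, seq = 64, 8, 1024
q = torch.randn(1, heads, seq, dim_head)
k = torch.randn(1, heads, seq, dim_head)

# learned per-feature gains, playing the role of the rmsnorm gamma
q_scale = torch.nn.Parameter(torch.ones(dim_head))
k_scale = torch.nn.Parameter(torch.ones(dim_head))

# l2 normalize queries and keys along the feature dimension, then apply the learned gains
q = F.normalize(q, dim = -1) * q_scale
k = F.normalize(k, dim = -1) * k_scale

# attention logits - the scaling is effectively learned through the gains above
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
attn = sim.softmax(dim = -1)
```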

### Turning off absolute positional embedding

A number of papers have hinted that causal transformers (`Decoder`) can learn absolute positions in the absence of added embeddings of any sort. This was recently thoroughly investigated <a href="https://arxiv.org/abs/2203.16634">here</a>. You can turn off the absolute positional embedding by setting `use_abs_pos_emb = False` in the `TransformerWrapper`.
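A minimal sketch of turning it off, following the same pattern as the examples above (hyperparameters are illustrative):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    use_abs_pos_emb = False,  # no absolute positional embedding is added
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```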
@@ -1740,4 +1766,12 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

```bibtex
@inproceedings{Dehghani2023ScalingVT,
    title   = {Scaling Vision Transformers to 22 Billion Parameters},
    author  = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Pavetić and Dustin Tran and Thomas Kipf and Mario Lučić and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
    year    = {2023}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
Binary file added images/qknorm-analysis.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
- version = '1.8.2',
+ version = '1.9.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
29 changes: 23 additions & 6 deletions x_transformers/x_transformers.py
@@ -577,7 +577,7 @@ def __init__(
    def forward(self, x):
        return self.ff(x)

- # attention.
+ # attention. it is all we need

class Attention(nn.Module):
    def __init__(
@@ -598,6 +598,7 @@ def __init__(
        qk_norm = False,
        qk_norm_groups = 1,
        qk_norm_scale = 10,
        qk_norm_dim_scale = False,
        one_kv_head = False,
        shared_kv = False,
        value_dim_head = None,
@@ -645,6 +646,14 @@ def __init__(
        self.qk_norm_groups = qk_norm_groups
        self.qk_norm_scale = qk_norm_scale

        # whether to use the rmsnorm (equivalent to cosine sim attention with learned scale on feature dimension) - https://arxiv.org/abs/2302.05442
        self.qk_norm_dim_scale = qk_norm_dim_scale

        if qk_norm and qk_norm_dim_scale:
            self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head) * qk_norm_scale ** 0.5)
            self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head) * qk_norm_scale ** 0.5)
            self.qk_norm_scale = 1

        assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups'
        assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)'

@@ -713,6 +722,15 @@ def forward(
        if not self.one_kv_head:
            k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r))

        if self.qk_norm:
            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
            q, k = map(qk_l2norm, (q, k))
            scale = self.qk_norm_scale

            if self.qk_norm_dim_scale:
                q = q * self.qk_norm_q_scale
                k = k * self.qk_norm_k_scale

        if exists(rotary_pos_emb) and not has_context:
            freqs, xpos_scale = rotary_pos_emb
            l = freqs.shape[-1]
@@ -727,17 +745,16 @@ def forward(

        if self.num_mem_kv > 0:
            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))

            if self.qk_norm:
                mem_k = l2norm(mem_k)

            k = torch.cat((mem_k, k), dim = -2)
            v = torch.cat((mem_v, v), dim = -2)

            if exists(input_mask):
                input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True)

        if self.qk_norm:
            qk_l2norm = partial(l2norm, groups = self.qk_norm_groups)
            q, k = map(qk_l2norm, (q, k))
            scale = self.qk_norm_scale

        kv_einsum_eq = 'b h j d' if not self.one_kv_head else 'b j d'

        dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
