Commit

add ResiDual paper

lucidrains committed May 9, 2023
1 parent 326ccf4 commit 80bd4d9
Showing 4 changed files with 56 additions and 9 deletions.
37 changes: 36 additions & 1 deletion README.md
@@ -1045,6 +1045,31 @@ x = torch.randint(0, 20000, (1, 1024))
model(x)
```

### ResiDual

<img src="./images/resi_dual.png" width="400px"/>

<a href="https://arxiv.org/abs/2304.14802">This Microsoft paper</a> proposes yet another normalization configuration, combining both pre- and post-layernorm. They claim this hybridization reduces representation collapse (a known issue for pre-layernorm as depth increases) while avoiding the training instability and vanishing gradients that afflict post-layernorm. Initial experiments on my end show it to work no worse than pre-layernorm or sandwich norm. More study is needed by the public to see whether this is actually a winning technique.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        resi_dual = True  # set this to True
    )
)

x = torch.randint(0, 20000, (1, 1024))
model(x)
```
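
The gist of the scheme: the network keeps two residual streams. The main stream is post-layernormed after every block, while a dual stream accumulates the raw, unnormalized block outputs and receives a single layernorm before being added back at the very end. Below is a minimal sketch of that idea (an illustration only, not the library's internal code; `resi_dual_forward`, `blocks`, `norms`, and `final_norm` are hypothetical stand-ins):

```python
import torch
from torch import nn

def resi_dual_forward(x, blocks, norms, final_norm):
    outer = x                        # dual stream: raw, unnormalized residual
    for block, norm in zip(blocks, norms):
        out = block(x)               # attention or feedforward branch
        outer = outer + out          # dual stream accumulates the raw output
        x = norm(x + out)            # main stream is post-layernormed every layer
    return x + final_norm(outer)     # one layernorm on the dual stream at the end

# toy usage with stand-in blocks
dim, depth = 512, 2
blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
norms = nn.ModuleList([nn.LayerNorm(dim) for _ in range(depth)])
out = resi_dual_forward(torch.randn(1, 1024, dim), blocks, norms, nn.LayerNorm(dim))  # (1, 1024, 512)
```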

### Normformer

<img src="./images/normformer.png" width="400px"/>
@@ -1831,10 +1856,20 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

```bibtex
@article{Xie2023ResiDualTW,
title = {ResiDual: Transformer with Dual Residual Connections},
author = {Shufang Xie and Huishuai Zhang and Junliang Guo and Xu Tan and Jiang Bian and Hany Hassan Awadalla and Arul Menezes and Tao Qin and Rui Yan},
journal = {ArXiv},
year = {2023},
volume = {abs/2304.14802}
}
```

```bibtex
@inproceedings{Dehghani2023ScalingVT,
title = {Scaling Vision Transformers to 22 Billion Parameters},
author={Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Paveti'c and Dustin Tran and Thomas Kipf and Mario Luvci'c and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
author = {Mostafa Dehghani and Josip Djolonga and Basil Mustafa and Piotr Padlewski and Jonathan Heek and Justin Gilmer and Andreas Steiner and Mathilde Caron and Robert Geirhos and Ibrahim M. Alabdulmohsin and Rodolphe Jenatton and Lucas Beyer and Michael Tschannen and Anurag Arnab and Xiao Wang and Carlos Riquelme and Matthias Minderer and Joan Puigcerver and Utku Evci and Manoj Kumar and Sjoerd van Steenkiste and Gamaleldin F. Elsayed and Aravindh Mahendran and Fisher Yu and Avital Oliver and Fantine Huot and Jasmijn Bastings and Mark Collier and Alexey A. Gritsenko and Vighnesh Birodkar and Cristina Nader Vasconcelos and Yi Tay and Thomas Mensink and Alexander Kolesnikov and Filip Paveti'c and Dustin Tran and Thomas Kipf and Mario Luvci'c and Xiaohua Zhai and Daniel Keysers and Jeremiah Harmsen and Neil Houlsby},
year = {2023}
}
```
Binary file added images/resi_dual.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
  version = '1.12.3',
  version = '1.14.0',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
26 changes: 19 additions & 7 deletions x_transformers/x_transformers.py
Expand Up @@ -895,6 +895,7 @@ def __init__(
        deepnorm = False,
        shift_tokens = 0,
        sandwich_norm = False,
        resi_dual = False,
        zero_init_branch_output = False,
        layer_dropout = 0.,
        cross_attn_tokens_dropout = 0.,
@@ -944,13 +945,16 @@ def __init__(

        if deepnorm:
            assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
            pre_norm = sandwich_norm = False
            pre_norm = sandwich_norm = resi_dual = False
            scale_residual = True
            scale_residual_constant = (2 * depth) ** 0.25

        assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both'
        assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
        assert not (not pre_norm and resi_dual), 'resiDual cannot be used when not using prenorm'
        self.pre_norm = pre_norm
        self.sandwich_norm = sandwich_norm
        self.resi_dual = resi_dual

        self.residual_attn = residual_attn
        self.cross_residual_attn = cross_residual_attn
@@ -1037,7 +1041,7 @@ def __init__(

            pre_branch_norm = norm_fn() if pre_norm else None
            post_branch_norm = norm_fn() if sandwich_norm else None
            post_main_norm = norm_fn() if not pre_norm and not is_last_layer else None
            post_main_norm = norm_fn() if (resi_dual or not pre_norm) and not is_last_layer else None

            norms = nn.ModuleList([
                pre_branch_norm,
@@ -1080,6 +1084,8 @@ def forward(
            max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)))
            rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device)

        outer_residual = x

        for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)):
            is_last = ind == (len(self.layers) - 1)

@@ -1095,12 +1101,12 @@
                if self.training and self.cross_attn_tokens_dropout > 0.:
                    context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout)

            residual = x
            inner_residual = x

            pre_branch_norm, post_branch_norm, post_main_norm = norm
            pre_norm, post_branch_norm, post_main_norm = norm

            if exists(pre_branch_norm):
                x = pre_branch_norm(x)
            if exists(pre_norm) and not self.resi_dual:
                x = pre_norm(x)

            if layer_type == 'a':
                out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
@@ -1109,10 +1115,13 @@
            elif layer_type == 'f':
                out = block(x)

            if self.resi_dual:
                outer_residual = residual_fn(out, outer_residual)

            if exists(post_branch_norm):
                out = post_branch_norm(out)

            x = residual_fn(out, residual)
            x = residual_fn(out, inner_residual)

            if layer_type in ('a', 'c') and return_hiddens:
                intermediates.append(inter)
@@ -1125,6 +1134,9 @@
            if exists(post_main_norm):
                x = post_main_norm(x)

        if self.resi_dual:
            x = x + pre_norm(outer_residual)

        if return_hiddens:
            intermediates = LayerIntermediates(
                hiddens = hiddens,

0 comments on commit 80bd4d9
