add deepnorm, successfully used in a 130B parameter model out of Tsinghua
lucidrains committed Aug 7, 2022
1 parent 1adb2ea commit 8774657
Showing 4 changed files with 87 additions and 4 deletions.
37 changes: 37 additions & 0 deletions README.md
@@ -660,6 +660,33 @@ model = XTransformer(
)
```

### Deepnorm

<img src="./images/deepnorm.png" width="450px"></img>

It is well known that post-normalization transformers are prone to training instability, which prompted the move to <a href="https://arxiv.org/abs/2002.04745">pre-normalization</a> in recent years, even though pre-normalization gives up some performance.

This <a href="https://arxiv.org/abs/2203.00555">paper</a> out of Microsoft Research proposes a simple fix for post-normalization instability: scale the residual stream and use an appropriately scaled initialization. The authors show they can train a one-thousand-layer transformer without stability issues and achieve better results than with pre-normalization.

This was recently validated in a <a href="https://keg.cs.tsinghua.edu.cn/glm-130b/">130B GLM model</a> out of Tsinghua.


```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        deepnorm = True,  # set this to True to use deepnorm post-normalization configuration
        dim = 512,
        depth = 6,
        heads = 8
    )
)
```
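
Under the hood, deepnorm keeps the post-norm arrangement but scales the residual stream by a constant `alpha` and shrinks the initialization of the feedforward, value and output projections by a gain `beta`, both derived from depth. Below is a minimal illustrative sketch of the decoder-only update rule; the `DeepNormResidual` and `deepnorm_constants` names are for exposition only and are not part of this library's API.

```python
import torch
from torch import nn

def deepnorm_constants(depth):
    # decoder-only constants from the DeepNet paper: alpha = (2N)^(1/4), beta = (8N)^(-1/4)
    return (2 * depth) ** 0.25, (8 * depth) ** -0.25

class DeepNormResidual(nn.Module):
    # post-norm residual block: x <- LayerNorm(alpha * x + sublayer(x))
    def __init__(self, dim, sublayer, alpha):
        super().__init__()
        self.sublayer = sublayer
        self.alpha = alpha
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(self.alpha * x + self.sublayer(x))

alpha, beta = deepnorm_constants(6)   # depth 6 -> alpha ≈ 1.86, beta ≈ 0.38
block = DeepNormResidual(512, nn.Linear(512, 512), alpha)
out = block(torch.randn(2, 1024, 512))  # (2, 1024, 512)
```

In this library the same constants show up in the diff below as `scale_residual_constant = (2 * depth) ** 0.25` and the `(8 * depth) ** -0.25` gain passed to `deepnorm_init`.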

### Transformer-XL recurrence

You can also do Transformer-XL recurrence, by simply passing in a `max_mem_len` in the `TransformerWrapper` class, and then making sure your `Decoder` has `rel_pos_bias` set to `True`.
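
A minimal sketch of what that looks like, assuming the `return_mems` and `mems` keyword names from the library's Transformer-XL example (these may differ across versions):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 512,
    max_mem_len = 512,         # how many past tokens to retain as memories
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        rel_pos_bias = True    # relative position bias is needed for recurrence
    )
)

seg1 = torch.randint(0, 20000, (1, 512))
seg2 = torch.randint(0, 20000, (1, 512))

logits1, mems = model(seg1, return_mems = True)
logits2, _    = model(seg2, mems = mems, return_mems = True)  # second segment attends to memories of the first
```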
@@ -1543,4 +1570,14 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

```bibtex
@article{Wang2022DeepNetST,
    title   = {DeepNet: Scaling Transformers to 1,000 Layers},
    author  = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.00555}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
Binary file added images/deepnorm.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '0.32.3',
version = '0.33.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
52 changes: 49 additions & 3 deletions x_transformers/x_transformers.py
@@ -98,6 +98,24 @@ def groupby_prefix_and_trim(prefix, d):
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs

# initializations

def deepnorm_init(
    transformer,
    beta,
    module_name_match_list = ['.ff.', '.to_v', '.to_out']
):
    for name, module in transformer.named_modules():
        if type(module) != nn.Linear:
            continue

        needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list))
        gain = beta if needs_beta_gain else 1
        nn.init.xavier_normal_(module.weight.data, gain = gain)

        if exists(module.bias):
            nn.init.constant_(module.bias.data, 0)
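
# illustrative note: for a depth-6 decoder-only model the gain passed in as `beta`
# would be (8 * 6) ** -0.25 ≈ 0.38, so the feedforward, value and attention output
# projections start noticeably smaller than default Xavier init, while every other
# Linear keeps a gain of 1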

# activations

class ReluSquared(nn.Module):
@@ -472,7 +490,7 @@ def __init__(
activation
) if not glu else GLU(dim, inner_dim, activation)

self.net = nn.Sequential(
self.ff = nn.Sequential(
project_in,
nn.LayerNorm(inner_dim) if post_act_ln else nn.Identity(),
nn.Dropout(dropout),
@@ -481,10 +499,10 @@

# init last linear layer to 0
if zero_init_output:
init_zero_(self.net[-1])
init_zero_(self.ff[-1])

def forward(self, x):
return self.net(x)
return self.ff(x)
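
# illustrative note: the attribute rename from `net` to `ff` lets the '.ff.' substring
# in deepnorm_init's module_name_match_list pick out these feedforward projections by module name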

# attention.

@@ -764,6 +782,7 @@ def __init__(
gate_residual = False,
scale_residual = False,
scale_residual_constant = 1.,
deepnorm = False,
shift_tokens = 0,
sandwich_norm = False,
zero_init_branch_output = False,
@@ -801,6 +820,14 @@ def __init__(
alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned else AlibiPositionalBias
self.rel_pos = alibi_pos_klass(heads = alibi_num_heads)

# determine deepnorm and residual scale

if deepnorm:
assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings'
pre_norm = sandwich_norm = False
scale_residual = True
scale_residual_constant = (2 * depth) ** 0.25

assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
self.pre_norm = pre_norm
self.sandwich_norm = sandwich_norm
@@ -896,6 +923,10 @@ def __init__(
residual
]))

if deepnorm:
init_gain = (8 * depth) ** -0.25
deepnorm_init(self, init_gain)

def forward(
self,
x,
@@ -1237,6 +1268,7 @@ def __init__(
tie_token_emb = False,
ignore_index=-100,
pad_value=0,
deepnorm = False,
**kwargs
):
super().__init__()
@@ -1251,6 +1283,16 @@
dec_transformer_kwargs = pick_and_pop(['num_tokens', 'max_seq_len'], dec_kwargs)
dec_transformer_kwargs['emb_dropout'] = dec_kwargs.pop('emb_dropout', 0)

if deepnorm:
enc_kwargs['scale_residual'] = True
dec_kwargs['scale_residual'] = True

enc_depth = enc_kwargs['depth']
dec_depth = dec_kwargs['depth']

enc_kwargs['scale_residual_constant'] = 0.81 * ((enc_depth ** 4) * dec_depth) ** .0625
dec_kwargs['scale_residual_constant'] = (3 * dec_depth) ** 0.25
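
# illustrative arithmetic: with enc_depth = dec_depth = 6 the encoder residual scale
# works out to 0.81 * ((6 ** 4) * 6) ** 0.0625 ≈ 1.42 and the decoder residual scale
# to (3 * 6) ** 0.25 ≈ 2.06, the encoder-decoder deepnorm constants from the DeepNet paper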

self.encoder = TransformerWrapper(
**enc_transformer_kwargs,
attn_layers = Encoder(dim = dim, **enc_kwargs)
@@ -1261,6 +1303,10 @@
attn_layers = Decoder(dim = dim, cross_attend = True, **dec_kwargs)
)

if deepnorm:
deepnorm_init(self.encoder, 0.87 * ((enc_depth ** 4) * dec_depth) ** -0.0625)
deepnorm_init(self.decoder, (12 * dec_depth) ** -0.25)

if tie_token_emb:
self.decoder.token_emb = self.encoder.token_emb

