
Commit

add dynamic positional bias, for extrapolating to greater sequence lengths than what is trained on, for autoregressive and bidirectional transformers, thanks to experiments from @bob80333
lucidrains committed Mar 3, 2022
1 parent b31ed8f commit fed9085
Showing 7 changed files with 93 additions and 4 deletions.
42 changes: 42 additions & 0 deletions README.md
@@ -761,6 +761,48 @@ model = TransformerWrapper(
)
```

### Dynamic Positional Bias

<img src="./images/dynamic-pos-bias.png" width="150px"></img>

This technique has its roots in vision transformers, where researchers want relative positions to generalize to larger resolutions without having to retrain the entire network. It was used in two recent papers, <a href="https://arxiv.org/abs/2108.00154">CrossFormer</a> and <a href="https://arxiv.org/abs/2111.09883">SwinV2</a>.

<a href="https://github.com/cfoster0">Charles Foster</a> first tried this for a language model, and found that it works. Later on <a href="https://github.com/bob80333">Eric Engelhart</a> produced experimental results that show the same type of extrapolation holds, even for 1d sequences.

Eric trained at a sequence length of 128 and showed that the model generalized well to 1024. In addition, he found that linear distances worked better than log distances (as used in SwinV2) for language.

Linear distances

<img src="./images/dynamic-pos-bias-linear.png" width="600px"></img>

Log distances

<img src="./images/dynamic-pos-bias-log.png" width="600px"></img>

Negative control - Sinusoidal

<img src="./images/dynamic-pos-bias-sinusoidal.png" width="600px"></img>

You can use this type of relative position bias if you wish to train at smaller sequence lengths and have the model generalize to longer ones, for both autoregressive and bidirectional models.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        dynamic_pos_bias = True,                # set this to True
        dynamic_pos_bias_log_distance = False   # whether to use log distance, as in SwinV2
    )
)
```
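
As a quick sanity check (an illustrative snippet, not part of the original README), the wrapper is then called on token ids as usual; the idea is to train on short sequences and evaluate on longer ones, up to `max_seq_len`:

```python
x = torch.randint(0, 256, (1, 1024))  # longer than the sequence lengths trained on
logits = model(x)                     # (1, 1024, 256)
```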


### ALiBi Positional Embedding

<a href="https://ofir.io/train_short_test_long.pdf">This paper</a> proposes to simply apply a static linear bias to the attention matrix. The authors show this is not only effective as a relative positional encoding, but also allows the attention net to extrapolate to greater sequences length than what it was trained on, for autoregressive language models.
Binary file added images/dynamic-pos-bias-linear.png
Binary file added images/dynamic-pos-bias-log.png
Binary file added images/dynamic-pos-bias-sinusoidal.png
Binary file added images/dynamic-pos-bias.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
  version = '0.24.1',
  version = '0.25.0',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
53 changes: 50 additions & 3 deletions x_transformers/x_transformers.py
@@ -183,6 +183,46 @@ def forward(self, qk_dots):
        bias = rearrange(values, 'i j h -> () h i j')
        return qk_dots + (bias * self.scale)

class DynamicPositionBias(nn.Module):
    def __init__(self, dim, *, heads, depth, log_distance = False, norm = False):
        super().__init__()
        assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1'
        self.log_distance = log_distance

        self.mlp = nn.ModuleList([])

        self.mlp.append(nn.Sequential(
            nn.Linear(1, dim),
            nn.LayerNorm(dim) if norm else nn.Identity(),
            nn.ReLU()
        ))

        for _ in range(depth - 1):
            self.mlp.append(nn.Sequential(
                nn.Linear(dim, dim),
                nn.LayerNorm(dim) if norm else nn.Identity(),
                nn.ReLU()
            ))

        self.mlp.append(nn.Linear(dim, heads))

    def forward(self, qk_dots):
        i, j, device, dtype = *qk_dots.shape[-2:], qk_dots.device, qk_dots.dtype

        seq_arange = torch.arange(i, device = device, dtype = dtype)
        context_arange = torch.arange(j, device = device, dtype = dtype)

        bias = rearrange(seq_arange, 'i -> i 1 1') - rearrange(context_arange, 'j -> 1 j 1')

        if self.log_distance:
            bias = torch.sign(bias) * torch.log(bias.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)

        for layer in self.mlp:
            bias = layer(bias)

        bias = rearrange(bias, 'i j h -> h i j')
        return qk_dots + bias

class AlibiPositionalBias(nn.Module):
    def __init__(self, heads, **kwargs):
        super().__init__()
@@ -691,6 +731,10 @@ def __init__(
        rel_pos_bias = False,
        rel_pos_num_buckets = 32,
        rel_pos_max_distance = 128,
        dynamic_pos_bias = False,
        dynamic_pos_bias_log_distance = False,
        dynamic_pos_bias_mlp_depth = 2,
        dynamic_pos_bias_norm = False,
        position_infused_attn = False,
        rotary_pos_emb = False,
        rotary_emb_dim = None,
@@ -712,7 +756,7 @@
    ):
        super().__init__()
        ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
        attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
        attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs)

        dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)

@@ -729,15 +773,18 @@
        assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both'
        assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'

        # relative positional bias

        self.rel_pos = None
        if rel_pos_bias:
            self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance)
        elif dynamic_pos_bias:
            self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm)
        elif alibi_pos_bias:
            alibi_num_heads = default(alibi_num_heads, heads)
            assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads'
            alibi_pos_klass = LearnedAlibiPositionalBias if alibi_learned or not causal else AlibiPositionalBias
            self.rel_pos = alibi_pos_klass(heads = alibi_num_heads, bidirectional = not causal)
        else:
            self.rel_pos = None

        assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm'
        self.pre_norm = pre_norm
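For orientation only (this snippet is not part of the commit), the new module can be exercised on its own: it maps each signed relative distance through a small MLP to one bias value per head and adds the result to the attention logits:

```python
import torch
from x_transformers.x_transformers import DynamicPositionBias  # the class added in this commit

dpb = DynamicPositionBias(dim = 128, heads = 8, depth = 2)

qk_dots = torch.randn(1, 8, 512, 512)   # (batch, heads, query len, key len) attention logits
biased = dpb(qk_dots)                    # same shape, with the learned relative position bias added
```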
