add position infused attention, from DETR and Shortformer papers
lucidrains committed Dec 31, 2020
1 parent c886a5d commit 7d46572
Showing 4 changed files with 79 additions and 13 deletions.
32 changes: 30 additions & 2 deletions README.md
@@ -478,7 +478,7 @@ T5 is one of the most successful encoder / decoder transformer architectures tra

```python
import torch
-from x_transformers import TransformerWrapper, Decoder, Encoder
+from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
num_tokens = 20000,
@@ -492,6 +492,34 @@ model = TransformerWrapper(
)
```

### Position Infused Attention

<img src="./images/pia.png" width="500px"></img>

https://arxiv.org/abs/2005.12872

https://ofir.io/shortformer.pdf

In these two papers, the authors independently figured out a new technique where fixed sinusoidal positional embeddings are injected into the input prior to the query and key projections at every layer, leading to "position infused" attention, while leaving the actual tokens (values) uncolored by positional embedding. The Shortformer paper uses this property to cache the tokens for a simplified recurrent type of transformer that outperformed Transformer-XL.

I have tested this and found that it produces better results than plain absolute positional encoding, even in the absence of recurrence. However, I have found that the T5 relative positional bias (which is also injected into every layer and shares the same properties as PIA) performs even better. So, given the option, you should just go with T5's `rel_pos_bias` above.

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 6,
heads = 8,
position_infused_attn = True # turns on position infused attention
)
)
```
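
To make the mechanics concrete, here is a minimal single-head sketch of what position infused attention does inside a layer. It is an illustration only, not the library's actual implementation: `sinusoidal_emb` and `PIASelfAttention` are made-up names for this example, and multi-head logic, masking, and dropout are all omitted.

```python
import torch
import torch.nn as nn

def sinusoidal_emb(n, dim):
    # fixed (non-learned) sinusoidal position embedding, shape (1, n, dim)
    inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(n).float()
    sinusoid = torch.einsum('i,j->ij', t, inv_freq)
    return torch.cat((sinusoid.sin(), sinusoid.cos()), dim = -1)[None, :, :]

class PIASelfAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_k = nn.Linear(dim, dim, bias = False)
        self.to_v = nn.Linear(dim, dim, bias = False)

    def forward(self, x):
        b, n, d = x.shape
        pos = sinusoidal_emb(n, d)
        q = self.to_q(x + pos)  # positions infused into queries
        k = self.to_k(x + pos)  # ... and keys
        v = self.to_v(x)        # values stay position-free
        attn = (q @ k.transpose(-1, -2) * d ** -0.5).softmax(dim = -1)
        return attn @ v

x = torch.randn(2, 16, 512)
out = PIASelfAttention(512)(x)  # (2, 16, 512)
```

Because positions are never baked into the values, the cached token representations are not tied to absolute offsets, which is what lets Shortformer reuse them across segments and simply infuse fresh positions into the queries and keys on the next pass.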

### Residual Attention

<img src="./images/residual_attn.png" width="500px"></img>
@@ -557,7 +585,7 @@ To be explained and documented
- [x] ~~wrapper for processing images - Vision Transformer~~
- [x] ~~macaron layers - 'Multi-particle Dynamic System' paper~~
- [x] ~~residual attention - Realformer paper~~
-- [ ] position infused attention - Shortformer paper
+- [x] ~~position infused attention - Shortformer paper~~
- [ ] reversibility - Reformer
- [ ] recurrence - Transformer-XL
- [ ] gated transformer-xl - gates at residuals, from stabilizing Transformers for RL paper
Binary file added images/pia.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
-version = '0.4.4',
+version = '0.5.0',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
58 changes: 48 additions & 10 deletions x_transformers/x_transformers.py
@@ -20,6 +20,11 @@ def default(val, d):
return val
return d() if isfunction(d) else d

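# returns a function that ignores its arguments and always returns val; always(0) serves below as a no-op positional embedding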
def always(val):
def inner(*args, **kwargs):
return val
return inner

def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max

@@ -50,6 +55,31 @@ def groupby_prefix_and_trim(prefix, d):

# positional embeddings

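# learned per-position embedding, added once to the token embeddings at the input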
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
self.init_()

def init_(self):
nn.init.normal_(self.emb.weight, std = 0.02)

def forward(self, x):
t = torch.arange(x.shape[1], device = x.device)
return self.emb(t)

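# parameter-free sinusoidal embedding, cheap enough to re-inject at every layer's query / key projection when position_infused_attn is on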
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)

def forward(self, x):
t = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
sinusoid_inp = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]

class RelativePositionBias(nn.Module):
def __init__(self, causal = False, num_buckets = 32, max_distance = 128, heads = 8):
super().__init__()
@@ -180,7 +210,8 @@ def __init__(

inner_dim = dim_head * heads
self.to_q = nn.Linear(dim, inner_dim, bias = False)
-self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
+self.to_k = nn.Linear(dim, inner_dim, bias = False)
+self.to_v = nn.Linear(dim, inner_dim, bias = False)
self.dropout = nn.Dropout(dropout)

# talking heads
@@ -205,14 +236,15 @@ def __init__(
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)

-def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None, prev_attn = None):
+def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None, pia_emb = 0, prev_attn = None):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
kv_input = default(context, x)

-q_ = self.to_q(x)
-kv = self.to_kv(kv_input).chunk(2, dim = -1)
+q_ = self.to_q(x + pia_emb)
+k = self.to_k(kv_input + pia_emb)
+v = self.to_v(kv_input)

-q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q_, *kv))
+q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q_, k, v))

input_mask = None
if any(map(exists, (mask, context_mask))):
@@ -284,6 +316,7 @@ def __init__(
use_scalenorm = False,
use_rezero = False,
rel_pos_bias = False,
position_infused_attn = False,
custom_layers = None,
sandwich_coef = None,
residual_attn = False,
@@ -295,6 +328,9 @@
super().__init__()
self.dim = dim
self.layers = nn.ModuleList([])

self.has_pos_emb = position_infused_attn or rel_pos_bias
self.pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else always(0)
self.rel_pos = RelativePositionBias(causal = causal) if rel_pos_bias else None

self.pre_norm = pre_norm and not residual_attn
@@ -354,14 +390,16 @@ def forward(self, x, context = None, mask = None, context_mask = None):
prev_attn = None
prev_cross_attn = None

pos_emb = self.pos_emb(x)

for ind, (layer_type, (norm, block)) in enumerate(zip(self.layer_types, self.layers)):
is_last = ind == (len(self.layers) - 1)

if self.pre_norm:
x = norm(x)

if layer_type == 'a':
-out, pre_attn = block(x, mask = mask, rel_pos = self.rel_pos, prev_attn = prev_attn)
+out, pre_attn = block(x, mask = mask, pia_emb = pos_emb, rel_pos = self.rel_pos, prev_attn = prev_attn)
elif layer_type == 'c':
out, pre_attn = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
elif layer_type == 'f':
@@ -450,13 +488,14 @@ def __init__(
attn_layers,
emb_dropout = 0.,
num_memory_tokens = None,
-tie_embedding = True
+tie_embedding = True,
+use_pos_emb = True
):
super().__init__()
dim = attn_layers.dim
self.max_seq_len = max_seq_len
self.token_emb = nn.Embedding(num_tokens, dim)
-self.pos_emb = nn.Embedding(max_seq_len, dim)
+self.pos_emb = AbsolutePositionalEmbedding(dim, max_seq_len) if (use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)

self.attn_layers = attn_layers
@@ -479,12 +518,11 @@ def __init__(

def init_(self):
nn.init.normal_(self.token_emb.weight, std = 0.02)
-nn.init.normal_(self.pos_emb.weight, std = 0.02)

def forward(self, x, return_embeddings = False, mask = None, **kwargs):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
-x += self.pos_emb(torch.arange(n, device = device))
+x += self.pos_emb(x)
x = self.emb_dropout(x)

if num_mem > 0:
Expand Down
