add forgetful causal mask implementation to AutoregressiveWrapper
lucidrains committed Nov 2, 2022
1 parent d056af8 commit 595a474
Showing 6 changed files with 110 additions and 41 deletions.
50 changes: 48 additions & 2 deletions README.md
@@ -1135,6 +1135,43 @@ x = torch.randint(0, 20000, (1, 1024))
model(x)
```

### Forgetful Causal Mask

<a href="https://arxiv.org/abs/2210.13432">This paper</a> shows convincingly that masking (as in masked language modeling) can be combined with autoregressive training, leading to significantly better results, notably stronger zero-shot performance.

You can use this by setting `mask_prob` on the `AutoregressiveWrapper` class:


```python
import torch
from x_transformers import TransformerWrapper, Decoder, AutoregressiveWrapper

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 12,
heads = 8
)
)

model = AutoregressiveWrapper(
model,
mask_prob = 0.15 # in paper, they use 15%, same as BERT
).cuda()

# mock data

x = torch.randint(0, 20000, (1, 1024)).cuda()

# derive cross entropy loss, masking all taken care of

loss = model(x)
loss.backward()
```
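
Note that `mask_prob` only kicks in during the training forward pass; sampling with `.generate` ignores it. A quick continuation of the example above (the prompt slice and generation length here are just illustrative):

```python
# sampling is unaffected by mask_prob
prompt = x[:, :1]
sampled = model.generate(prompt, 1024) # (1, 1024)
```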


## Miscellaneous

Cross Attention
@@ -1175,9 +1212,8 @@ model = ContinuousTransformerWrapper(
)

x = torch.randn((1, 1024, 32))
mask = torch.ones(1, 1024).bool()

model(x, mask = mask) # (1, 1024, 100)
model(x) # (1, 1024, 100)
```

You can also easily train a transformer that accepts continuous values autoregressively, following the same scheme used successfully in <a href="https://arxiv.org/abs/2112.05329">this paper</a>
@@ -1630,4 +1666,14 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

```bibtex
@article{Liu2022FCMFC,
title = {FCM: Forgetful Causal Masking Makes Causal Language Models Better Zero-Shot Learners},
author = {Hao Liu and Xinyang Geng and Lisa Lee and Igor Mordatch and Sergey Levine and Sharan Narang and P. Abbeel},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.13432}
}
```

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
Binary file added images/fcm.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.1.0',
version = '1.2.1',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
69 changes: 47 additions & 22 deletions x_transformers/autoregressive_wrapper.py
@@ -33,23 +33,44 @@ def top_k(logits, thres = 0.9):
def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02):
probs = F.softmax(logits, dim=-1)
limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio
logits[probs < limit] = -float("Inf")
logits[probs < limit] = float('-inf')
logits[probs >= limit] = 1
return logits

# autoregressive wrapper class

class AutoregressiveWrapper(nn.Module):
def __init__(self, net, ignore_index = -100, pad_value = 0):
def __init__(
self,
net,
ignore_index = -100,
pad_value = 0,
mask_prob = 0.
):
super().__init__()
self.pad_value = pad_value
self.ignore_index = ignore_index

self.net = net
self.max_seq_len = net.max_seq_len

# paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432
assert mask_prob < 1.
self.mask_prob = mask_prob

@torch.no_grad()
def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., filter_logits_fn = top_k, filter_thres = 0.9, min_p_pow=2.0, min_p_ratio=0.02, **kwargs):
def generate(
self,
start_tokens,
seq_len,
eos_token = None,
temperature = 1.,
filter_logits_fn = top_k,
filter_thres = 0.9,
min_p_pow = 2.0,
min_p_ratio = 0.02,
**kwargs
):
device = start_tokens.device
was_training = self.net.training
num_dims = len(start_tokens.shape)
@@ -61,16 +82,11 @@ def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., fi

self.net.eval()
out = start_tokens
mask = kwargs.pop('mask', None)

if mask is None:
mask = torch.full_like(out, True, dtype=torch.bool, device=out.device)

for _ in range(seq_len):
x = out[:, -self.max_seq_len:]
mask = mask[:, -self.max_seq_len:]

logits = self.net(x, mask=mask, **kwargs)[:, -1, :]
logits = self.net(x, **kwargs)[:, -1]

if filter_logits_fn in {top_k, top_p}:
filtered_logits = filter_logits_fn(logits, thres = filter_thres)
@@ -83,7 +99,6 @@ def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., fi
sample = torch.multinomial(probs, 1)

out = torch.cat((out, sample), dim=-1)
mask = F.pad(mask, (0, 1), value=True)

if exists(eos_token):
is_eos_tokens = (out == eos_token)
@@ -104,16 +119,26 @@ def generate(self, start_tokens, seq_len, eos_token = None, temperature = 1., fi
return out

def forward(self, x, **kwargs):
xi = x[:, :-1]
xo = x[:, 1:]

# help auto-solve a frequent area of confusion around input masks in auto-regressive
# if user supplies a mask that is only off by one from the source sequence, resolve it for them
mask = kwargs.get('mask', None)
if mask is not None and mask.shape[1] == x.shape[1]:
mask = mask[:, :-1]
kwargs['mask'] = mask

out = self.net(xi, **kwargs)
loss = F.cross_entropy(out.transpose(1, 2), xo, ignore_index = self.ignore_index)
seq, ignore_index = x.shape[1], self.ignore_index

inp, target = x[:, :-1], x[:, 1:]

if self.mask_prob > 0.:
rand = torch.randn(inp.shape, device = x.device)
rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out
num_mask = min(int(seq * self.mask_prob), seq - 1)
indices = rand.topk(num_mask, dim = -1).indices
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()
kwargs.update(context_mask = mask)

out = self.net(inp, **kwargs)

out = out.transpose(1, 2)

loss = F.cross_entropy(
out,
target,
ignore_index = ignore_index
)

return loss
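
For readers skimming the diff, here is a minimal standalone sketch of the forgetful causal mask construction in `forward` above, with toy values for batch size, sequence length and `mask_prob` (all made up for illustration):

```python
import torch

# toy values, for illustration only
batch, seq, mask_prob = 2, 8, 0.15

inp = torch.randint(0, 20000, (batch, seq))   # stand-in for x[:, :-1]

rand = torch.randn(inp.shape)
rand[:, 0] = -torch.finfo(rand.dtype).max     # first token is never masked out
num_mask = min(int(seq * mask_prob), seq - 1) # never mask the entire sequence
indices = rand.topk(num_mask, dim = -1).indices
mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool()

# True = position may be attended to, False = "forgotten" for this training step
# the wrapper then forwards this as `context_mask` to the underlying network
print(mask)
```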
28 changes: 13 additions & 15 deletions x_transformers/continuous_autoregressive_wrapper.py
@@ -2,6 +2,9 @@
from torch import nn
import torch.nn.functional as F

def exists(val):
return val is not None

class ContinuousAutoregressiveWrapper(nn.Module):
def __init__(self, net, ignore_index = -100, pad_value = 0):
super().__init__()
@@ -23,18 +26,12 @@ def generate(self, start_tokens, seq_len, **kwargs):

self.net.eval()
out = start_tokens
mask = kwargs.pop('mask', None)

if mask is None:
mask = torch.full((b, t), True, dtype = torch.bool, device = device)

for _ in range(seq_len):
x = out[:, -self.max_seq_len:]
mask = mask[:, -self.max_seq_len:]

last = self.net(x, mask = mask, **kwargs)[:, -1:, :]
last = self.net(x, **kwargs)[:, -1:]
out = torch.cat((out, last), dim = -2)
mask = F.pad(mask, (0, 1), value=True)

out = out[:, t:]

@@ -45,16 +42,17 @@ def generate(self, start_tokens, seq_len, **kwargs):
return out

def forward(self, x, **kwargs):
xi = x[:, :-1]
xo = x[:, 1:]
inp, target = x[:, :-1], x[:, 1:]

# help auto-solve a frequent area of confusion around input masks in auto-regressive
# if user supplies a mask that is only off by one from the source sequence, resolve it for them
mask = kwargs.get('mask', None)
if mask is not None and mask.shape[1] == x.shape[1]:
if exists(mask) and mask.shape[1] == x.shape[1]:
mask = mask[:, :-1]
kwargs['mask'] = mask

out = self.net(xi, **kwargs)
loss = F.mse_loss(out, xo)
return loss
out = self.net(inp, **kwargs)
loss = F.mse_loss(out, target, reduction = 'none')

if exists(mask):
loss = loss[mask]

return loss.mean()
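
Similarly, a small sketch of the new element-wise MSE reduction above, with made-up shapes, showing how masked-out (e.g. padded) positions are dropped before averaging:

```python
import torch
import torch.nn.functional as F

# made-up shapes: (batch, seq, dim)
pred   = torch.randn(2, 7, 32)   # net output for x[:, :-1]
target = torch.randn(2, 7, 32)   # x[:, 1:]
mask   = torch.tensor([[True] * 7, [True] * 4 + [False] * 3])  # second row padded

loss = F.mse_loss(pred, target, reduction = 'none')  # (2, 7, 32)
loss = loss[mask]                                    # keep only valid positions -> (11, 32)
print(loss.mean())
```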
2 changes: 1 addition & 1 deletion x_transformers/x_transformers.py
@@ -985,7 +985,7 @@ def forward(
x = pre_branch_norm(x)

if layer_type == 'a':
out, inter = block(x, mask = mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
out, inter = block(x, mask = mask, context_mask = context_mask, attn_mask = attn_mask, sinusoidal_emb = self.pia_pos_emb, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem)
elif layer_type == 'c':
out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn)
elif layer_type == 'f':
