Commit 722d212

add flash attention
lucidrains committed May 9, 2023
1 parent fcc7c65 commit 722d212
Showing 5 changed files with 329 additions and 67 deletions.
44 changes: 43 additions & 1 deletion README.md
@@ -238,6 +238,40 @@ model(x)

## Features

### Flash Attention

<img src="./images/flash-attention.png" width="500px"></img>

What originally started off as <a href="https://arxiv.org/abs/2112.05682">a short paper</a> from Markus Rabe culminated in a practical fused attention CUDA kernel, named <a href="https://arxiv.org/abs/2205.14135">Flash Attention</a>, by <a href="https://tridao.me/">Tri Dao</a>.

The technique processes the attention matrix in tiles, keeping track of only the running softmax statistics (the row maximum and sum of exponentials) and the accumulated weighted sum of values. By recomputing the attention in the same tiled fashion on the backward pass, memory stays linear with respect to sequence length. This allows many recent models to reach for longer context lengths without worrying about the memory bottleneck.
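
To make the tiled accumulation concrete, here is a minimal pure-PyTorch sketch of the online-softmax idea described above. It is an illustration only, not this repository's code nor the actual CUDA kernel (which also tiles the backward pass); the `tile_size` knob and tensor shapes are assumptions for the example.

```python
import torch

def tiled_attention(q, k, v, tile_size = 128):
    # q, k, v: (batch, heads, seq_len, dim_head)
    q = q * (q.shape[-1] ** -0.5)

    # running statistics per query row: max logit, sum of exponentials, weighted value sum
    running_max = torch.full(q.shape[:-1], float('-inf'), device = q.device)
    running_sum = torch.zeros_like(running_max)
    acc = torch.zeros_like(q)

    for start in range(0, k.shape[-2], tile_size):
        k_tile = k[..., start:start + tile_size, :]
        v_tile = v[..., start:start + tile_size, :]

        logits = q @ k_tile.transpose(-2, -1)                  # (b, h, n, tile)
        new_max = torch.maximum(running_max, logits.amax(dim = -1))

        # rescale previous accumulators to the new running max, then add this tile's contribution
        correction = torch.exp(running_max - new_max)
        exp_logits = torch.exp(logits - new_max.unsqueeze(-1))

        running_sum = running_sum * correction + exp_logits.sum(dim = -1)
        acc = acc * correction.unsqueeze(-1) + exp_logits @ v_tile
        running_max = new_max

    return acc / running_sum.unsqueeze(-1)

# sanity check against naive softmax attention
q, k, v = (torch.randn(1, 8, 512, 64) for _ in range(3))
naive = torch.softmax((q * 64 ** -0.5) @ k.transpose(-2, -1), dim = -1) @ v
assert torch.allclose(tiled_attention(q, k, v), naive, atol = 1e-4)
```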

Other engineering decisions by Tri Dao led to its enormous success, chiefly minimizing HBM accesses so that both the forward and backward passes outperform naive attention. In other words, flash attention is not only more memory efficient, but faster as well, making it a necessity for training transformers.

Meta AI has recently added the ability to use <a href="https://github.com/hazyresearch/flash-attention">Tri Dao's CUDA kernel</a> through the <a href="https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html">scaled_dot_product_attention</a> function in PyTorch 2.0. (They also offer a `mem_efficient` backend, which follows the same tiled design as flash attention, but traverses the tiles differently.)
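
For reference, this is roughly what calling the fused kernel directly looks like in PyTorch 2.0, outside of this repository; `torch.backends.cuda.sdp_kernel` is the same context manager the new `x_transformers/attend.py` uses to choose between the flash, math and memory-efficient backends. The shapes, dtype and device below are illustrative assumptions (the flash backend itself requires half-precision tensors on a supported CUDA GPU).

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, dim_head), half precision on a CUDA device for the flash backend
q, k, v = (torch.randn(2, 8, 1024, 64, device = 'cuda', dtype = torch.float16) for _ in range(3))

# restrict scaled_dot_product_attention to the flash backend only, so it errors loudly if flash cannot be used
with torch.backends.cuda.sdp_kernel(enable_flash = True, enable_math = False, enable_mem_efficient = False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal = True)  # (2, 8, 1024, 64)
```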

<a href="https://ai.facebook.com/blog/large-language-model-llama-meta-ai/">Llama</a> was trained using Flash Attention. The only reason to avoid it is if you need to operate on the attention matrix directly (dynamic positional bias, talking heads, residual attention).

You can use it in this repository by setting `attn_flash` to `True` and enjoy the immediate memory savings and increase in speed.

ex.

```python
import torch
from x_transformers import TransformerWrapper, Decoder, Encoder

model = TransformerWrapper(
num_tokens = 20000,
max_seq_len = 1024,
attn_layers = Decoder(
dim = 512,
depth = 6,
heads = 8,
attn_flash = True # just set this to True if you have pytorch 2.0 installed
)
)
```

### Augmenting Self-attention with Persistent Memory

<img src="./images/all-attention.png" width="500px"></img>
@@ -913,7 +947,6 @@ model = TransformerWrapper(
)
```


### ALiBi Positional Embedding

<a href="https://ofir.io/train_short_test_long.pdf">This paper</a> proposes to simply apply a static linear bias to the attention matrix. The authors show this is not only effective as a relative positional encoding, but also allows the attention net to extrapolate to greater sequence lengths than it was trained on, for autoregressive language models.
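
A minimal sketch of the idea, not this repository's exact implementation: each head gets a fixed slope, and the attention logits receive a penalty that grows linearly with the distance between query and key. The geometric slope schedule below follows the paper and assumes the number of heads is a power of two.

```python
import torch

def alibi_bias(heads, seq_len, device = None):
    # per-head slopes from the paper: 2^-1, 2^-2, ... for power-of-two head counts
    slopes = torch.tensor([2 ** (-8 * (h + 1) / heads) for h in range(heads)], device = device)

    pos = torch.arange(seq_len, device = device)
    rel = (pos[None, :] - pos[:, None]).clamp(max = 0)   # j - i, clamped so future positions get no bias
    return slopes[:, None, None] * rel.float()           # (heads, i, j), non-positive linear penalty

# added to the pre-softmax attention logits, e.g. dots = dots + alibi_bias(heads, n, device = dots.device)
```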
@@ -1789,6 +1822,15 @@ generated = model.generate(start_emb, 17) # (17, 777)
}
```

```bibtex
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}
```

```bibtex
@inproceedings{Dehghani2023ScalingVT,
title = {Scaling Vision Transformers to 22 Billion Parameters},
Binary file added images/flash-attention.png
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
version = '1.12.1',
version = '1.12.2',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
246 changes: 246 additions & 0 deletions x_transformers/attend.py
@@ -0,0 +1,246 @@
from functools import partial

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass

from einops import rearrange

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

@dataclass
class Intermediates:
    qk_similarities: Tensor = None
    pre_softmax_attn: Tensor = None
    post_softmax_attn: Tensor = None

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

# main class

class Attend(nn.Module):
    def __init__(
        self,
        *,
        dropout = 0.,
        causal = False,
        heads = None,
        talking_heads = False,
        scale = None,
        qk_norm = False,
        flash = False,
    ):
        super().__init__()
        self.scale = scale
        self.qk_norm = qk_norm
        self.causal = causal
        self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax

        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        # talking heads

        assert not (flash and talking_heads), 'talking heads not compatible with flash attention'

        self.talking_heads = talking_heads
        if talking_heads:
            self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
            self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

        # flash attention

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def flash_attn(
        self,
        q, k, v,
        mask = None,
        attn_bias = None
    ):
        batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)

        if v.ndim == 3:
            v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention

        if self.qk_norm:
            default_scale = q.shape[-1] ** -0.5
            q = q * (default_scale / self.scale)

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        causal = self.causal

        if exists(mask):
            assert mask.ndim == 4
            mask = mask.expand(batch, heads, q_len, k_len)

            # manually handle causal mask, if another mask was given

            if causal:
                causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
                mask = mask | causal_mask
                causal = False

        # handle alibi positional bias
        # convert from bool to float

        if exists(attn_bias):
            attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, -1, -1, -1)

            # if mask given, the mask would already contain the causal mask from above logic
            # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number

            mask_value = -torch.finfo(q.dtype).max

            if exists(mask):
                attn_bias = attn_bias.masked_fill(mask, mask_value // 2)
            elif causal:
                causal_mask = torch.ones((q_len, k_len), dtype = torch.bool, device = device).triu(k_len - q_len + 1)
                attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
                causal = False

            # scaled_dot_product_attention handles attn_mask either as bool or additive bias
            # make it an additive bias here

            mask = attn_bias

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out, Intermediates()

    def forward(
        self,
        q, k, v,
        mask = None,
        attn_bias = None,
        prev_attn = None
    ):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        n, device = q.shape[-2], q.device

        scale = default(self.scale, q.shape[-1] ** -0.5)

        if self.flash:
            assert not exists(prev_attn), 'residual attention not compatible with flash attention'
            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale

        if exists(prev_attn):
            dots = dots + prev_attn

        qk_similarities = dots.clone()

        if self.talking_heads:
            dots = self.pre_softmax_talking_heads(dots)

        if exists(attn_bias):
            dots = dots + attn_bias

        dtype = dots.dtype
        pre_softmax_attn = dots.clone()

        mask_value = -torch.finfo(dots.dtype).max

        if exists(mask):
            dots = dots.masked_fill(mask, mask_value)

        if self.causal:
            i, j = dots.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
            dots = dots.masked_fill(causal_mask, mask_value)

        attn = self.attn_fn(dots, dim = -1)
        attn = attn.type(dtype)

        post_softmax_attn = attn.clone()

        attn = self.attn_dropout(attn)

        if self.talking_heads:
            attn = self.post_softmax_talking_heads(attn)

        out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

        intermediates = Intermediates(
            qk_similarities = qk_similarities,
            pre_softmax_attn = pre_softmax_attn,
            post_softmax_attn = post_softmax_attn
        )

        return out, intermediates
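
For reference, a minimal usage sketch of the new `Attend` module, not part of the commit itself; tensor shapes are illustrative assumptions, and `flash = True` requires PyTorch 2.0 or above.

```python
import torch
from x_transformers.attend import Attend

attend = Attend(causal = True, flash = True)

# (batch, heads, seq_len, dim_head)
q, k, v = (torch.randn(1, 8, 1024, 64) for _ in range(3))

out, intermediates = attend(q, k, v)  # out: (1, 8, 1024, 64); intermediates are empty on the flash path
```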