switch to rotary embeddings, as they did in the paper
lucidrains committed Jun 15, 2023
1 parent 3eafa96 commit 810d77a
Showing 4 changed files with 44 additions and 69 deletions.
26 changes: 3 additions & 23 deletions MEGABYTE_pytorch/attend.py
@@ -89,34 +89,19 @@ def flash_attn(self, q, k, v, mask = None, attn_bias = None):
 
         config = self.cuda_config if is_cuda else self.cpu_config
 
-        causal = self.causal
-
-        # handle attention bias
-
-        if exists(attn_bias):
-            mask_value = -torch.finfo(q.dtype).max // 2
-            causal_mask = self.get_mask(q_len, k_len, device)
-            attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
-
-            if exists(mask):
-                attn_bias = attn_bias.masked_fill(~mask, mask_value)
-
-            mask = attn_bias
-            causal = False
-
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
         with torch.backends.cuda.sdp_kernel(**config._asdict()):
             out = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask = mask,
                 dropout_p = self.dropout if self.training else 0.,
-                is_causal = causal
+                is_causal = self.causal
             )
 
         return out
 
-    def forward(self, q, k, v, mask = None, attn_bias = None):
+    def forward(self, q, k, v, mask = None):
         """
         einstein notation
         b - batch
@@ -132,17 +117,12 @@ def forward(self, q, k, v, mask = None, attn_bias = None):
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
 
         if self.flash:
-            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+            return self.flash_attn(q, k, v, mask = mask)
 
         # similarity
 
         sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
 
-        # attention bias
-
-        if exists(attn_bias):
-            sim = sim + attn_bias
-
         # causal mask
 
         if self.causal:
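With the ALiBi bias gone, the flash branch no longer has to fold an additive bias into the attention mask and can hand causality straight to the fused kernel. Below is a minimal sketch of the simplified path, assuming PyTorch >= 2.0; the helper name is made up for illustration, and the repo's `Attend` module additionally handles padding masks, dropout and kernel configs.

```python
import torch
import torch.nn.functional as F

def flash_causal_attention(q, k, v):
    # q, k, v: (batch, heads, seq_len, dim_head)
    # is_causal = True lets the fused kernel build the causal mask itself,
    # which is what made the hand-rolled attn_bias / causal_mask code removable
    return F.scaled_dot_product_attention(q, k, v, is_causal = True)

q = k = v = torch.randn(1, 8, 1024, 64)
out = flash_causal_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 1024, 64])
```

Rotary embeddings rotate the queries and keys before attention is computed, so they compose with the fused kernel without needing a bias tensor at all.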
65 changes: 29 additions & 36 deletions MEGABYTE_pytorch/megabyte.py
@@ -66,40 +66,30 @@ def token_shift(t):
     t_shift = F.pad(t_shift, (0, 0, 1, -1))
     return torch.cat((t, t_shift), dim = -1)
 
-# positional bias
+# rotary positional embedding
 
-class Alibi(nn.Module):
-    def __init__(self, heads, **kwargs):
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, theta = 10000):
         super().__init__()
-        self.heads = heads
-        slopes = torch.Tensor(self._get_slopes(heads))
-        slopes = rearrange(slopes, 'h -> h 1 1')
-        self.register_buffer('slopes', slopes, persistent = False)
-        self.register_buffer('bias', None, persistent = False)
-
-    @staticmethod
-    def _get_slopes(heads):
-        def get_slopes_power_of_2(n):
-            start = (2**(-2**-(math.log2(n)-3)))
-            ratio = start
-            return [start*ratio**i for i in range(n)]
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
 
-        if math.log2(heads).is_integer():
-            return get_slopes_power_of_2(heads)
+    @property
+    def device(self):
+        return next(self.buffers()).device
 
-        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
-        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+    def forward(self, seq_len):
+        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.cat((freqs, freqs), dim = -1)
+        return freqs
 
-    def forward(self, i, j, device):
-        if exists(self.bias) and self.bias.shape[-1] >= j:
-            return self.bias[..., :j]
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
 
-        bias = torch.arange(j, device = device)
-        bias = rearrange(bias, 'j -> 1 1 j')
-        bias = bias * self.slopes
-
-        self.register_buffer('bias', bias, persistent = False)
-        return self.bias
+def apply_rotary_pos_emb(pos, t):
+    return t * pos.cos() + rotate_half(t) * pos.sin()
 
 # norm
 
@@ -152,14 +142,17 @@ def __init__(
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x, attn_bias = None):
+    def forward(self, x, rotary_emb = None):
         h, device = self.heads, x.device
 
         x = self.norm(x)
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
 
-        out = self.attend(q, k, v, attn_bias = attn_bias)
+        if exists(rotary_emb):
+            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))
+
+        out = self.attend(q, k, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -175,11 +168,11 @@ def __init__(
         attn_dropout = 0.,
         ff_dropout = 0.,
         ff_mult = 4,
-        rel_pos_bias = True,
+        rel_pos = True,
         flash_attn = False
     ):
         super().__init__()
-        self.alibi = Alibi(heads = heads) if rel_pos_bias else None
+        self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None
         self.layers = nn.ModuleList([])
 
         for _ in range(layers):
@@ -192,10 +185,10 @@ def __init__(
 
     def forward(self, x):
        n = x.shape[-2]
-        attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None
+        rotary_emb = self.rotary_emb(n) if exists(self.rotary_emb) else None
 
        for attn, ff in self.layers:
-            x = attn(token_shift(x), attn_bias = attn_bias) + x
+            x = attn(token_shift(x), rotary_emb = rotary_emb) + x
            x = ff(token_shift(x)) + x
 
        return self.norm(x)
@@ -218,7 +211,7 @@ def __init__(
         ff_mult = 4,
         ff_dropout = 0.,
         pad_id = 0,
-        rel_pos_bias = False,
+        rel_pos = False,
         pos_emb = False,
         flash_attn = False
     ):
@@ -264,7 +257,7 @@ def __init__(
                 attn_dropout = attn_dropout,
                 ff_dropout = ff_dropout,
                 ff_mult = ff_mult,
-                rel_pos_bias = rel_pos_bias,
+                rel_pos = rel_pos,
                 flash_attn = flash_attn
             ))
 
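The rotary embedding replaces the additive ALiBi bias with a position-dependent rotation of the query and key vectors. The following self-contained sketch reproduces the helpers added above in simplified form (the `device` property is dropped for brevity, and the tensor shapes are made up for illustration):

```python
import torch
from torch import nn

class RotaryEmbedding(nn.Module):
    # same construction as the class added above: one inverse frequency per channel pair
    def __init__(self, dim, theta = 10000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        t = torch.arange(seq_len).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)  # (seq_len, dim / 2)
        return torch.cat((freqs, freqs), dim = -1)              # (seq_len, dim)

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(pos, t):
    # rotate every (x1, x2) channel pair of t by its position-dependent angle
    return t * pos.cos() + rotate_half(t) * pos.sin()

# made-up shapes, for illustration only
batch, heads, seq_len, dim_head = 2, 8, 16, 64

rotary = RotaryEmbedding(dim_head)
pos = rotary(seq_len)                                   # (seq_len, dim_head)

q = torch.randn(batch, heads, seq_len, dim_head)
k = torch.randn(batch, heads, seq_len, dim_head)
q, k = map(lambda t: apply_rotary_pos_emb(pos, t), (q, k))

print(q.shape, k.shape)  # torch.Size([2, 8, 16, 64]) torch.Size([2, 8, 16, 64])
```

Because each channel pair is rotated by an angle proportional to its position, the dot product between a rotated query and key depends only on their relative offset, which is the property the RoFormer paper cited below relies on.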
20 changes: 11 additions & 9 deletions README.md
@@ -89,15 +89,6 @@ $ python train.py
 }
 ```
 
-```bibtex
-@misc{press2021ALiBi,
-    title = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
-    author = {Ofir Press and Noah A. Smith and Mike Lewis},
-    year = {2021},
-    url = {https://ofir.io/train_short_test_long.pdf}
-}
-```
-
 ```bibtex
 @software{peng_bo_2021_5196578,
     author = {PENG Bo},
@@ -120,3 +111,14 @@
     volume = {abs/2305.19466}
 }
 ```
+
+```bibtex
+@misc{su2021roformer,
+    title = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
+    author = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
+    year = {2021},
+    eprint = {2104.09864},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CL}
+}
+```
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.1.6',
+  version = '0.1.7',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
