Skip to content

Commit

Permalink
add token shift from rwkv
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 17, 2023
1 parent 6a684da commit 2355831
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
11 changes: 9 additions & 2 deletions MEGABYTE_pytorch/megabyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ def top_k(logits, thres = 0.5):
probs.scatter_(1, ind, val)
return probs

# token shift, from Peng et al. of RWKV

def token_shift(t):
    """Shift half of the last-dimension features by one step along the second-to-last axis.

    The input is split into two chunks along its last dimension. The first
    chunk is passed through untouched; the second chunk is padded with one
    zero-filled slot at the front of the second-to-last axis and has its final
    slot dropped, so position i receives the values that were at position
    i - 1. The two chunks are then concatenated back along the last dimension,
    giving an output with the same shape as the input.

    NOTE(review): presumably t is (batch, seq, dim) and the second-to-last
    axis is the sequence axis - confirm against callers.
    """
    kept, shifted = torch.chunk(t, 2, dim = -1)

    # pad spec is read from the last dimension backwards: (0, 0) leaves the
    # feature axis alone, (1, -1) prepends one zero row and trims the last row
    shifted = F.pad(shifted, (0, 0, 1, -1))

    return torch.cat((kept, shifted), dim = -1)

# positional bias

class Alibi(nn.Module):
Expand Down Expand Up @@ -184,8 +191,8 @@ def forward(self, x):
attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None

for attn, ff in self.layers:
x = attn(x, attn_bias = attn_bias) + x
x = ff(x) + x
x = attn(token_shift(x), attn_bias = attn_bias) + x
x = ff(token_shift(x)) + x

return self.norm(x)

Expand Down
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,16 @@ $ python train.py
url = {https://ofir.io/train_short_test_long.pdf}
}
```

```bibtex
@software{peng_bo_2021_5196578,
author = {PENG Bo},
title = {BlinkDL/RWKV-LM: 0.01},
month = {aug},
year = {2021},
publisher = {Zenodo},
version = {0.01},
doi = {10.5281/zenodo.5196578},
url = {https://doi.org/10.5281/zenodo.5196578}
}
```
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'MEGABYTE-pytorch',
packages = find_packages(),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'MEGABYTE - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down

0 comments on commit 2355831

Please sign in to comment.