diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
new file mode 100644
index 0000000..5f38eed
--- /dev/null
+++ b/.github/workflows/python-publish.yml
@@ -0,0 +1,38 @@
+
+
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+ release:
+ types: [published]
+
+jobs:
+ deploy:
+
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: '3.x'
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install build
+ - name: Build package
+ run: python -m build
+ - name: Publish package
+ uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+ with:
+ user: __token__
+ password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/MEGABYTE_pytorch/__init__.py b/MEGABYTE_pytorch/__init__.py
new file mode 100644
index 0000000..fb6b1eb
--- /dev/null
+++ b/MEGABYTE_pytorch/__init__.py
@@ -0,0 +1 @@
+from MEGABYTE_pytorch.megabyte import MEGABYTE
diff --git a/MEGABYTE_pytorch/megabyte.py b/MEGABYTE_pytorch/megabyte.py
new file mode 100644
index 0000000..a132cb5
--- /dev/null
+++ b/MEGABYTE_pytorch/megabyte.py
@@ -0,0 +1,353 @@
+import math
+import functools
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+
+from einops_exts import rearrange_with_anon_dims
+from einops import rearrange, reduce, repeat
+
+# helpers
+
+def exists(val):
+ return val is not None
+
+def default(val, d):
+ return val if exists(val) else d
+
+def remainder_to_mult(num, mult):
+ return (mult - num % mult) % mult
+
+def cast_tuple(t, length = 1):
+ return t if isinstance(t, tuple) else ((t,) * length)
+
+def reduce_mult(nums):
+ return functools.reduce(lambda x, y: x * y, nums, 1)
+
+# tensor helpers
+
+def log(t, eps = 1e-20):
+ return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+ noise = torch.zeros_like(t).uniform_(0, 1)
+ return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., dim = -1):
+ return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+
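+# top-k filtering: keep only the largest (1 - thres) fraction of logits and set the rest to -inf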
+def top_k(logits, thres = 0.5):
+ num_logits = logits.shape[-1]
+ k = max(int((1 - thres) * num_logits), 1)
+ val, ind = torch.topk(logits, k)
+ probs = torch.full_like(logits, float('-inf'))
+ probs.scatter_(1, ind, val)
+ return probs
+
+# positional bias
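+# ALiBi-style relative positional bias: each head adds a linearly scaled bias over key positions,
+# which amounts to a distance penalty under the causal softmax (per-query constant offsets cancel)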
+
+class Alibi(nn.Module):
+ def __init__(self, heads, **kwargs):
+ super().__init__()
+ self.heads = heads
+ slopes = torch.Tensor(self._get_slopes(heads))
+ slopes = rearrange(slopes, 'h -> h 1 1')
+ self.register_buffer('slopes', slopes, persistent = False)
+ self.register_buffer('bias', None, persistent = False)
+
+ @staticmethod
+ def _get_slopes(heads):
+ def get_slopes_power_of_2(n):
+ start = (2**(-2**-(math.log2(n)-3)))
+ ratio = start
+ return [start*ratio**i for i in range(n)]
+
+ if math.log2(heads).is_integer():
+ return get_slopes_power_of_2(heads)
+
+ closest_power_of_2 = 2 ** math.floor(math.log2(heads))
+ return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+
+ def forward(self, i, j, device):
+ if exists(self.bias) and self.bias.shape[-1] >= j:
+ return self.bias[..., :j]
+
+ bias = torch.arange(j, device = device)
+ bias = rearrange(bias, 'j -> 1 1 j')
+ bias = bias * self.slopes
+
+ self.register_buffer('bias', bias, persistent = False)
+ return self.bias
+
+# norm
+
+class RMSNorm(nn.Module):
+ def __init__(self, dim, eps = 1e-8):
+ super().__init__()
+ self.scale = dim ** -0.5
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(dim))
+
+ def forward(self, x):
+ norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+ return x / norm.clamp(min = self.eps) * self.g
+
+# helper classes
+
+def FeedForward(*, dim, mult = 4, dropout = 0.):
+ return nn.Sequential(
+ RMSNorm(dim),
+ nn.Linear(dim, dim * mult),
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(dim * mult, dim)
+ )
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ dim_head = 64,
+ heads = 8,
+ dropout = 0.
+ ):
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.dropout = nn.Dropout(dropout)
+ self.norm = RMSNorm(dim)
+ self.to_q = nn.Linear(dim, inner_dim, bias = False)
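+        # a single key / value head is shared across all query heads (multi-query attention)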
+ self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
+ self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+ def forward(self, x, attn_bias = None):
+ h, device = self.heads, x.device
+
+ x = self.norm(x)
+ q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+
+ q = q * self.scale
+ sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+ if exists(attn_bias):
+ sim = sim + attn_bias
+
+ i, j = sim.shape[-2:]
+ mask_value = -torch.finfo(sim.dtype).max
+ mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
+ sim = sim.masked_fill(mask, mask_value)
+
+ sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+ attn = sim.softmax(dim = -1)
+ attn = self.dropout(attn)
+
+ out = einsum('b h i j, b j d -> b h i d', attn, v)
+ out = rearrange(out, 'b h n d -> b n (h d)')
+ return self.to_out(out)
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ *,
+ dim,
+ layers,
+ dim_head = 64,
+ heads = 8,
+ attn_dropout = 0.,
+ ff_dropout = 0.,
+ ff_mult = 4,
+ rel_pos_bias = True
+ ):
+ super().__init__()
+ self.alibi = Alibi(heads = heads) if rel_pos_bias else None
+ self.layers = nn.ModuleList([])
+
+ for _ in range(layers):
+ self.layers.append(nn.ModuleList([
+ Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+ FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+ ]))
+
+ self.norm = RMSNorm(dim)
+
+ def forward(self, x):
+ n = x.shape[-2]
+ attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None
+
+ for attn, ff in self.layers:
+ x = attn(x, attn_bias = attn_bias) + x
+ x = ff(x) + x
+
+ return self.norm(x)
+
+# main class
+
+class MEGABYTE(nn.Module):
+ def __init__(
+ self,
+ *,
+ num_tokens,
+ dim,
+ depth,
+ max_seq_len,
+ dim_head = 64,
+ heads = 8,
+ attn_dropout = 0.,
+ ff_mult = 4,
+ ff_dropout = 0.,
+ pad_id = 0,
+ rel_pos_bias = True
+ ):
+ super().__init__()
+
+        # simplified configuration for each stage of the hierarchy
+        # depth = (2, 2, 4) translates to a depth of 2 at the first stage, 2 at the second, and 4 at the third
+        # max_seq_len = (16, 8, 4) translates to a maximum sequence length of 16 at the first stage, 8 at the second, and 4 at the last
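+        # the total flattened context length is the product of the per-stage lengths,
+        # e.g. (16, 8, 4) covers 16 * 8 * 4 = 512 tokens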
+
+ assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
+ assert len(depth) == len(max_seq_len)
+
+ self.stages = len(depth)
+
+ self.token_emb = nn.Embedding(num_tokens, dim)
+ self.start_tokens = nn.Parameter(torch.randn(dim))
+
+ self.max_seq_len = max_seq_len
+ self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, dim) for seq_len in max_seq_len])
+
+ self.transformers = nn.ModuleList([])
+
+ for stage_depth in depth:
+ self.transformers.append(Transformer(
+ dim = dim,
+ layers = stage_depth,
+ dim_head = dim_head,
+ heads = heads,
+ attn_dropout = attn_dropout,
+ ff_dropout = ff_dropout,
+ ff_mult = ff_mult,
+ rel_pos_bias = rel_pos_bias
+ ))
+
+ self.to_logits = nn.Linear(dim, num_tokens)
+ self.pad_id = pad_id
+
+ def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
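+        # sample the full flattened sequence autoregressively, one token at a time
+        # (the whole hierarchy is re-run at every step), then reshape to the staged layout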
+ total_seq_len = reduce_mult(self.max_seq_len)
+ device = next(self.parameters()).device
+
+ if not exists(prime):
+ prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)
+
+ seq = prime
+
+ for _ in range(total_seq_len - seq.shape[-1]):
+ logits = self.forward(seq)[:, -1]
+ logits = top_k(logits, thres = filter_thres)
+ sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
+ seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)
+
+ return rearrange_with_anon_dims(seq, 'b (...d) -> b ...d', d = self.max_seq_len)
+
+ def forward_empty(self, batch_size):
+        # handle the special case of an empty input,
+        # where sampling starts from the start token alone
+
+ tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)
+
+ for transformer in self.transformers:
+ tokens = transformer(tokens)
+
+ return self.to_logits(tokens)
+
+ def forward(self, ids, return_loss = False):
+ assert ids.ndim in {2, self.stages + 1}
+ flattened_dims = ids.ndim == 2
+ ids_orig_ndim = ids.ndim
+
+ if ids.numel() == 0:
+ return self.forward_empty(ids.shape[0])
+
+ if flattened_dims:
+            # allow ids to be given in the shape of (batch, seq),
+            # in which case they are auto-padded to a multiple of the product of the local (non-global) sequence lengths
+ seq_len = ids.shape[-1]
+ multiple_of = reduce_mult(self.max_seq_len[1:])
+ padding = remainder_to_mult(seq_len, multiple_of)
+ ids = F.pad(ids, (0, padding), value = self.pad_id)
+ ids = rearrange_with_anon_dims(ids, 'b (l ...d) -> b l ...d', d = self.max_seq_len[1:])
+
+ b, *prec_dims, device = *ids.shape, ids.device
+
+ # check some dimensions
+
+        assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than or equal to the first tuple element of max_seq_len (like any autoregressive transformer)'
+ assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'
+
+ # get token embeddings
+
+ tokens = self.token_emb(ids)
+
+        # build the tokens for each hierarchical stage by summing out the finer (inner) dimensions
+        # and adding that stage's absolute positional embeddings; stages are stored coarsest-first
+
+ tokens_at_stages = []
+ reduced_tokens = tokens
+
+        for ind, pos_emb in enumerate(reversed(self.pos_embs)):
+ is_first = ind == 0
+
+ if not is_first:
+ reduced_tokens = reduce(reduced_tokens, 'b ... r d -> b ... d', 'sum')
+
+ positions = pos_emb(torch.arange(reduced_tokens.shape[-2], device = device))
+ tokens_with_position = reduced_tokens + positions
+ tokens_at_stages.insert(0, tokens_with_position)
+
+        # get start tokens, to be prepended to the coarsest stage
+
+ start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)
+
+        # run each stage's tokens through its transformer, prepending the current start tokens;
+        # each output position then becomes the start token for the corresponding sequence at the next, finer stage
+
+ for ind, (stage_tokens, transformer) in enumerate(zip(tokens_at_stages, self.transformers)):
+ is_last = ind == (self.stages - 1)
+
+ stage_tokens = torch.cat((
+ start_tokens,
+ stage_tokens,
+ ), dim = -2)
+
+ *prec_dims, _, _ = stage_tokens.shape
+
+ stage_tokens = rearrange(stage_tokens, '... n d -> (...) n d')
+ attended = transformer(stage_tokens)
+ attended = rearrange_with_anon_dims(attended, '(...b) n d -> ...b n d', b = prec_dims)
+
+ start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
+
+ logits = self.to_logits(attended)
+
+ logits = logits[..., 1:, :]
+
+ if not return_loss:
+
+ if flattened_dims:
+ logits = rearrange(logits, 'b ... n -> b (...) n')
+ logits = logits[:, :seq_len]
+
+ return logits
+
+ preds = rearrange(logits, 'b ... c -> b c (...)')
+ labels = rearrange(ids, 'b ... -> b (...)')
+
+ loss = F.cross_entropy(
+ preds[..., :-1],
+ labels[..., 1:],
+ ignore_index = self.pad_id
+ )
+ return loss
\ No newline at end of file
diff --git a/README.md b/README.md
index e8354bb..6866ca4 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,45 @@
-## MEGABYTE-pytorch (wip)
+## MEGABYTE-pytorch
Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch
+## Install
+
+```bash
+$ pip install MEGABYTE-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from MEGABYTE_pytorch import MEGABYTE
+
+model = MEGABYTE(
+ num_tokens = 16000, # number of tokens, in the paper they had a codebook size of 16k
+ dim = 512, # transformer model dimension
+ max_seq_len = (1024, 4), # sequence length for global and then local
+ depth = (6, 4), # number of layers for global and then local
+ dim_head = 64, # dimension per head
+ heads = 8, # number of attention heads
+)
+
+x = torch.randint(0, 16000, (1, 1024, 4))
+
+loss = model(x, return_loss = True)
+loss.backward()
+
+# then after much training
+
+logits = model(x)
+
+# and sample from the logits accordingly
+# or you can use the generate function
+
+sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
+```
+
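+The forward pass also accepts ids as a flat `(batch, seq)` tensor, padding it internally to a multiple of the product of the local sequence lengths before folding it into the staged layout. A minimal sketch, reusing the `model` above:
+
+```python
+flat = torch.randint(0, 16000, (1, 4096))  # 1024 * 4 tokens, flattened
+
+loss = model(flat, return_loss = True)
+loss.backward()
+
+logits = model(flat)  # (1, 4096, 16000), trimmed back to the input length
+```
+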
## Citations
```bibtex
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..d777bc4
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,30 @@
+from setuptools import setup, find_packages
+
+setup(
+ name = 'MEGABYTE-pytorch',
+ packages = find_packages(),
+ version = '0.0.1',
+ license='MIT',
+ description = 'MEGABYTE - Pytorch',
+ long_description_content_type = 'text/markdown',
+ author = 'Phil Wang',
+ author_email = 'lucidrains@gmail.com',
+ url = 'https://github.com/lucidrains/MEGABYTE-pytorch',
+ keywords = [
+ 'artificial intelligence',
+ 'attention mechanism',
+ 'transformers'
+ ],
+ install_requires=[
+ 'einops>=0.6.1',
+ 'einops-exts',
+ 'torch>=1.10'
+ ],
+ classifiers=[
+ 'Development Status :: 4 - Beta',
+ 'Intended Audience :: Developers',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ 'License :: OSI Approved :: MIT License',
+ 'Programming Language :: Python :: 3.6',
+ ],
+)