diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml
new file mode 100644
index 0000000..5f38eed
--- /dev/null
+++ b/.github/workflows/python-publish.yml
@@ -0,0 +1,38 @@
+
+
+# This workflow will upload a Python Package using Twine when a release is created
+# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+jobs:
+  deploy:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: '3.x'
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install build
+    - name: Build package
+      run: python -m build
+    - name: Publish package
+      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
+      with:
+        user: __token__
+        password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/MEGABYTE_pytorch/__init__.py b/MEGABYTE_pytorch/__init__.py
new file mode 100644
index 0000000..fb6b1eb
--- /dev/null
+++ b/MEGABYTE_pytorch/__init__.py
@@ -0,0 +1 @@
+from MEGABYTE_pytorch.megabyte import MEGABYTE
diff --git a/MEGABYTE_pytorch/megabyte.py b/MEGABYTE_pytorch/megabyte.py
new file mode 100644
index 0000000..a132cb5
--- /dev/null
+++ b/MEGABYTE_pytorch/megabyte.py
@@ -0,0 +1,353 @@
+import math
+import functools
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+
+from einops_exts import rearrange_with_anon_dims
+from einops import rearrange, reduce, repeat
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def default(val, d):
+    return val if exists(val) else d
+
+def remainder_to_mult(num, mult):
+    return (mult - num % mult) % mult
+
+def cast_tuple(t, length = 1):
+    return t if isinstance(t, tuple) else ((t,) * length)
+
+def reduce_mult(nums):
+    return functools.reduce(lambda x, y: x * y, nums, 1)
+
+# tensor helpers
+
+def log(t, eps = 1e-20):
+    return torch.log(t.clamp(min = eps))
+
+def gumbel_noise(t):
+    noise = torch.zeros_like(t).uniform_(0, 1)
+    return -log(-log(noise))
+
+def gumbel_sample(t, temperature = 1., dim = -1):
+    return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
+
+def top_k(logits, thres = 0.5):
+    num_logits = logits.shape[-1]
+    k = max(int((1 - thres) * num_logits), 1)
+    val, ind = torch.topk(logits, k)
+    probs = torch.full_like(logits, float('-inf'))
+    probs.scatter_(1, ind, val)
+    return probs
+
+# positional bias
+
+class Alibi(nn.Module):
+    def __init__(self, heads, **kwargs):
+        super().__init__()
+        self.heads = heads
+        slopes = torch.Tensor(self._get_slopes(heads))
+        slopes = rearrange(slopes, 'h -> h 1 1')
+        self.register_buffer('slopes', slopes, persistent = False)
+        self.register_buffer('bias', None, persistent = False)
+
+    @staticmethod
+    def _get_slopes(heads):
+        def get_slopes_power_of_2(n):
+            start = (2**(-2**-(math.log2(n)-3)))
+            ratio = start
+            return [start*ratio**i for i in range(n)]
+
+        if math.log2(heads).is_integer():
+            return get_slopes_power_of_2(heads)
+
+        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
+        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+
+    def forward(self, i, j, device):
+        if exists(self.bias) and self.bias.shape[-1] >= j:
+            return self.bias[..., :j]
+
+        bias = torch.arange(j, device = device)
+        bias = rearrange(bias, 'j -> 1 1 j')
+        bias = bias * self.slopes
+
+        self.register_buffer('bias', bias, persistent = False)
+        return self.bias
+
+# norm
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-8):
+        super().__init__()
+        self.scale = dim ** -0.5
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(dim))
+
+    def forward(self, x):
+        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+        return x / norm.clamp(min = self.eps) * self.g
+
+# helper classes
+
+def FeedForward(*, dim, mult = 4, dropout = 0.):
+    return nn.Sequential(
+        RMSNorm(dim),
+        nn.Linear(dim, dim * mult),
+        nn.GELU(),
+        nn.Dropout(dropout),
+        nn.Linear(dim * mult, dim)
+    )
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        dim_head = 64,
+        heads = 8,
+        dropout = 0.
+    ):
+        super().__init__()
+        self.scale = dim_head ** -0.5
+        self.heads = heads
+        inner_dim = dim_head * heads
+
+        self.dropout = nn.Dropout(dropout)
+        self.norm = RMSNorm(dim)
+        self.to_q = nn.Linear(dim, inner_dim, bias = False)
+        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+    def forward(self, x, attn_bias = None):
+        h, device = self.heads, x.device
+
+        x = self.norm(x)
+        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
+        q = rearrange(q, 'b n (h d) -> b h n d', h = h)
+
+        q = q * self.scale
+        sim = einsum('b h i d, b j d -> b h i j', q, k)
+
+        if exists(attn_bias):
+            sim = sim + attn_bias
+
+        i, j = sim.shape[-2:]
+        mask_value = -torch.finfo(sim.dtype).max
+        mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
+        sim = sim.masked_fill(mask, mask_value)
+
+        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
+        attn = sim.softmax(dim = -1)
+        attn = self.dropout(attn)
+
+        out = einsum('b h i j, b j d -> b h i d', attn, v)
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        layers,
+        dim_head = 64,
+        heads = 8,
+        attn_dropout = 0.,
+        ff_dropout = 0.,
+        ff_mult = 4,
+        rel_pos_bias = True
+    ):
+        super().__init__()
+        self.alibi = Alibi(heads = heads) if rel_pos_bias else None
+        self.layers = nn.ModuleList([])
+
+        for _ in range(layers):
+            self.layers.append(nn.ModuleList([
+                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
+                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
+            ]))
+
+        self.norm = RMSNorm(dim)
+
+    def forward(self, x):
+        n = x.shape[-2]
+        attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None
+
+        for attn, ff in self.layers:
+            x = attn(x, attn_bias = attn_bias) + x
+            x = ff(x) + x
+
+        return self.norm(x)
+
+# main class
+
+class MEGABYTE(nn.Module):
+    def __init__(
+        self,
+        *,
+        num_tokens,
+        dim,
+        depth,
+        max_seq_len,
+        dim_head = 64,
+        heads = 8,
+        attn_dropout = 0.,
+        ff_mult = 4,
+        ff_dropout = 0.,
+        pad_id = 0,
+        rel_pos_bias = True
+    ):
+        super().__init__()
+
+        # simplified configuration for each stage of the hierarchy
+        # depth = (2, 2, 4) would translate to a depth of 2 at the first stage, 2 at the second, and 4 at the third
+        # max_seq_len = (16, 8, 4) would translate to a max sequence length of 16 at the first stage, 8 at the second, and 4 at the last
+
+        assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
+        assert len(depth) == len(max_seq_len)
+
+        self.stages = len(depth)
+
+        self.token_emb = nn.Embedding(num_tokens, dim)
+        self.start_tokens = nn.Parameter(torch.randn(dim))
+
+        self.max_seq_len = max_seq_len
+        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, dim) for seq_len in max_seq_len])
+
+        self.transformers = nn.ModuleList([])
+
+        for stage_depth in depth:
+            self.transformers.append(Transformer(
+                dim = dim,
+                layers = stage_depth,
+                dim_head = dim_head,
+                heads = heads,
+                attn_dropout = attn_dropout,
+                ff_dropout = ff_dropout,
+                ff_mult = ff_mult,
+                rel_pos_bias = rel_pos_bias
+            ))
+
+        self.to_logits = nn.Linear(dim, num_tokens)
+        self.pad_id = pad_id
+
+    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
+        total_seq_len = reduce_mult(self.max_seq_len)
+        device = next(self.parameters()).device
+
+        if not exists(prime):
+            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)
+
+        seq = prime
+
+        for _ in range(total_seq_len - seq.shape[-1]):
+            logits = self.forward(seq)[:, -1]
+            logits = top_k(logits, thres = filter_thres)
+            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
+            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)
+
+        return rearrange_with_anon_dims(seq, 'b (...d) -> b ...d', d = self.max_seq_len)
+
+    def forward_empty(self, batch_size):
+        # take care of special case
+        # where you sample from input of 0 (start token only)
+
+        tokens = repeat(self.start_tokens, 'd -> b 1 d', b = batch_size)
+
+        for transformer in self.transformers:
+            tokens = transformer(tokens)
+
+        return self.to_logits(tokens)
+
+    def forward(self, ids, return_loss = False):
+        assert ids.ndim in {2, self.stages + 1}
+        flattened_dims = ids.ndim == 2
+        ids_orig_ndim = ids.ndim
+
+        if ids.numel() == 0:
+            return self.forward_empty(ids.shape[0])
+
+        if flattened_dims:
+            # allow for ids to be given in the shape of (batch, seq)
+            # in which case it will be auto-padded to the next multiple of the product of the local stage sequence lengths
+            seq_len = ids.shape[-1]
+            multiple_of = reduce_mult(self.max_seq_len[1:])
+            padding = remainder_to_mult(seq_len, multiple_of)
+            ids = F.pad(ids, (0, padding), value = self.pad_id)
+            ids = rearrange_with_anon_dims(ids, 'b (l ...d) -> b l ...d', d = self.max_seq_len[1:])
+
+        b, *prec_dims, device = *ids.shape, ids.device
+
+        # check some dimensions
+
+        assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than or equal to the first tuple element of max_seq_len (like any autoregressive transformer)'
+        assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'
+
+        # get token embeddings
+
+        tokens = self.token_emb(ids)
+
+        # get tokens for all hierarchical stages, reducing by appropriate dimensions
+        # and adding the absolute positional embeddings
+
+        tokens_at_stages = []
+        reduced_tokens = tokens
+
+        for ind, pos_emb in zip(range(len(prec_dims)), reversed(self.pos_embs)):
+            is_first = ind == 0
+
+            if not is_first:
+                reduced_tokens = reduce(reduced_tokens, 'b ... r d -> b ... d', 'sum')
+
+            positions = pos_emb(torch.arange(reduced_tokens.shape[-2], device = device))
+            tokens_with_position = reduced_tokens + positions
+            tokens_at_stages.insert(0, tokens_with_position)
+
+        # get start tokens and append to the coarsest stage
+
+        start_tokens = repeat(self.start_tokens, 'f -> b 1 f', b = b)
+
+        # run the transformer at each stage, from coarsest to finest; each coarser stage's outputs become the start tokens of the next finer stage
+
+        for ind, (stage_tokens, transformer) in enumerate(zip(tokens_at_stages, self.transformers)):
+            is_last = ind == (self.stages - 1)
+
+            stage_tokens = torch.cat((
+                start_tokens,
+                stage_tokens,
+            ), dim = -2)
+
+            *prec_dims, _, _ = stage_tokens.shape
+
+            stage_tokens = rearrange(stage_tokens, '... n d -> (...) n d')
+            attended = transformer(stage_tokens)
+            attended = rearrange_with_anon_dims(attended, '(...b) n d -> ...b n d', b = prec_dims)
+
+            start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
+
+        logits = self.to_logits(attended)
+
+        logits = logits[..., 1:, :]
+
+        if not return_loss:
+
+            if flattened_dims:
+                logits = rearrange(logits, 'b ... n -> b (...) n')
+                logits = logits[:, :seq_len]
+
+            return logits
+
+        preds = rearrange(logits, 'b ... c -> b c (...)')
+        labels = rearrange(ids, 'b ... -> b (...)')
+
+        loss = F.cross_entropy(
+            preds[..., :-1],
+            labels[..., 1:],
+            ignore_index = self.pad_id
+        )
+        return loss
\ No newline at end of file
diff --git a/README.md b/README.md
index e8354bb..6866ca4 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,45 @@
-## MEGABYTE-pytorch (wip)
+## MEGABYTE-pytorch
 
 Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch
 
+## Install
+
+```bash
+$ pip install MEGABYTE-pytorch
+```
+
+## Usage
+
+```python
+import torch
+from MEGABYTE_pytorch import MEGABYTE
+
+model = MEGABYTE(
+    num_tokens = 16000,       # number of tokens, in the paper they had a codebook size of 16k
+    dim = 512,                # transformer model dimension
+    max_seq_len = (1024, 4),  # sequence length for global and then local
+    depth = (6, 4),           # number of layers for global and then local
+    dim_head = 64,            # dimension per head
+    heads = 8,                # number of attention heads
+)
+
+x = torch.randint(0, 16000, (1, 1024, 4))
+
+loss = model(x, return_loss = True)
+loss.backward()
+
+# then after much training
+
+logits = model(x)
+
+# and sample from the logits accordingly
+# or you can use the generate function
+
+sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
+```
+
 ## Citations
 
 ```bibtex
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..d777bc4
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,30 @@
+from setuptools import setup, find_packages
+
+setup(
+  name = 'MEGABYTE-pytorch',
+  packages = find_packages(),
+  version = '0.0.1',
+  license='MIT',
+  description = 'MEGABYTE - Pytorch',
+  long_description_content_type = 'text/markdown',
+  author = 'Phil Wang',
+  author_email = 'lucidrains@gmail.com',
+  url = 'https://github.com/lucidrains/MEGABYTE-pytorch',
+  keywords = [
+    'artificial intelligence',
+    'attention mechanism',
+    'transformers'
+  ],
+  install_requires=[
+    'einops>=0.6.1',
+    'einops-exts',
+    'torch>=1.10'
+  ],
+  classifiers=[
+    'Development Status :: 4 - Beta',
+    'Intended Audience :: Developers',
+    'Topic :: Scientific/Engineering :: Artificial Intelligence',
+    'License :: OSI Approved :: MIT License',
+    'Programming Language :: Python :: 3.6',
+  ],
+)
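The README leaves "sample from the logits accordingly" open-ended. A minimal sketch of doing so by hand, assuming the flat `(batch, seq)` input form; the temperature of 0.9 and the shapes here are illustrative, not prescribed by the repository:

```python
import torch
from MEGABYTE_pytorch import MEGABYTE

model = MEGABYTE(num_tokens = 16000, dim = 512, max_seq_len = (1024, 4), depth = (6, 4))

x = torch.randint(0, 16000, (1, 128))     # flat (batch, seq) input, auto-padded internally

logits = model(x)                         # (1, 128, 16000), trimmed back to the input length
last = logits[:, -1] / 0.9                # logits at position t predict the token at t + 1
probs = last.softmax(dim = -1)
next_token = torch.multinomial(probs, 1)  # (1, 1)
x = torch.cat((x, next_token), dim = -1)  # append and repeat for more tokens
```

This is essentially one step of `model.generate`, which additionally applies top-k filtering (the `filter_thres` argument) and Gumbel sampling via the module's own helpers.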
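The flat-input path in `forward` pads the sequence up to a multiple of the product of the local stage lengths before folding it into the hierarchy. A worked example of that arithmetic, reusing the `remainder_to_mult` helper defined in `megabyte.py`; the lengths are illustrative:

```python
def remainder_to_mult(num, mult):
    # how much padding brings num up to the next multiple of mult
    return (mult - num % mult) % mult

max_seq_len = (1024, 4)  # global stage: up to 1024 patches; local stage: 4 tokens per patch
multiple_of = 4          # product of max_seq_len[1:]

seq_len = 10
padding = remainder_to_mult(seq_len, multiple_of)
assert padding == 2      # a (1, 10) input is padded to (1, 12) with pad_id

# (1, 12) is then reshaped to (1, 3, 4): 3 patches of 4 tokens each,
# and only the patch count (3) must be <= max_seq_len[0] (1024)
```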
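And a minimal training-loop sketch around `return_loss = True`; the byte-level vocabulary of 256, the Adam optimizer, the learning rate, and the random stand-in data are all assumptions for illustration, not taken from the repository:

```python
import torch
from MEGABYTE_pytorch import MEGABYTE

model = MEGABYTE(num_tokens = 256, dim = 512, max_seq_len = (128, 4), depth = (4, 2))
optim = torch.optim.Adam(model.parameters(), lr = 3e-4)  # assumed optimizer and learning rate

for _ in range(100):
    data = torch.randint(0, 256, (4, 128 * 4))  # stand-in for real byte sequences
    loss = model(data, return_loss = True)      # cross entropy over the flattened hierarchy
    loss.backward()
    optim.step()
    optim.zero_grad()
```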