use pack and unpack from einops 0.6
lucidrains committed May 15, 2023
1 parent 804ad50 commit 9b8050d
Showing 2 changed files with 15 additions and 11 deletions.
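
For reference, the einops 0.6 pack/unpack API this commit adopts in place of einops_exts.rearrange_with_anon_dims collapses any number of leading axes into a single one and restores them later. A minimal sketch, with made-up shapes:

import torch
from einops import pack, unpack

x = torch.randn(2, 3, 5, 64)              # any number of leading axes, then (n, d)

# '*' absorbs the leading axes into one; ps records their original shapes
packed, ps = pack([x], '* n d')           # packed.shape == (6, 5, 64)

# unpack uses ps to restore exactly the axes that were collapsed
restored = unpack(packed, ps, '* n d')[0]
assert restored.shape == x.shape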
MEGABYTE_pytorch/megabyte.py: 23 changes (14 additions, 9 deletions)
@@ -4,8 +4,7 @@
 import torch.nn.functional as F
 from torch import nn, einsum
 
-from einops_exts import rearrange_with_anon_dims
-from einops import rearrange, reduce, repeat
+from einops import rearrange, reduce, repeat, pack, unpack
 
 # helpers
 
@@ -15,6 +14,12 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
 def remainder_to_mult(num, mult):
     return (mult - num % mult) % mult
 
@@ -141,7 +146,6 @@ def forward(self, x, attn_bias = None):
             mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
             sim = sim.masked_fill(mask, mask_value)
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
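An aside on the line removed above: softmax is invariant to subtracting a per-row constant, and PyTorch's softmax already shifts by the row max internally for numerical stability, so dropping the explicit subtraction should not change the output. A quick sanity check with made-up logits:

import torch

sim = torch.randn(2, 4, 8) * 50                      # deliberately large logits
shifted = sim - sim.amax(dim = -1, keepdim = True)   # the subtraction removed above

assert torch.allclose(sim.softmax(dim = -1), shifted.softmax(dim = -1), atol = 1e-6)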
@@ -244,14 +248,15 @@ def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_b
             prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)
 
         seq = prime
+        batch = seq.shape[0]
 
         for _ in range(total_seq_len - seq.shape[-1]):
             logits = self.forward(seq)[:, -1]
             logits = top_k(logits, thres = filter_thres)
             sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
             seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)
 
-        return rearrange_with_anon_dims(seq, 'b (...d) -> b ...d', d = self.max_seq_len)
+        return seq.reshape(batch, *self.max_seq_len)
 
     def forward_empty(self, batch_size):
         # take care of special case
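Since generate now ends in a plain reshape, the equivalence is easy to see: the flat sampled sequence has length equal to the product of the per-stage lengths, so one axis per stage can be recovered directly. A tiny illustration with invented stage lengths:

import math
import torch

max_seq_len = (4, 8)                                                    # invented stage lengths
batch = 2

seq = torch.zeros(batch, math.prod(max_seq_len), dtype = torch.long)   # flat (b, 32) token ids
nested = seq.reshape(batch, *max_seq_len)                               # (b, 4, 8), one axis per stage
assert nested.shape == (batch, *max_seq_len)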
@@ -265,6 +270,8 @@ def forward_empty(self, batch_size):
         return self.to_logits(tokens)
 
     def forward(self, ids, return_loss = False):
+        batch = ids.shape[0]
+
         assert ids.ndim in {2, self.stages + 1}
         flattened_dims = ids.ndim == 2
         ids_orig_ndim = ids.ndim
@@ -279,7 +286,7 @@ def forward(self, ids, return_loss = False):
             multiple_of = reduce_mult(self.max_seq_len[1:])
             padding = remainder_to_mult(seq_len, multiple_of)
             ids = F.pad(ids, (0, padding), value = self.pad_id)
-            ids = rearrange_with_anon_dims(ids, 'b (l ...d) -> b l ...d', d = self.max_seq_len[1:])
+            ids = ids.reshape(batch, -1, *self.max_seq_len[1:])
 
         b, *prec_dims, device = *ids.shape, ids.device
 
@@ -322,11 +329,9 @@ def forward(self, ids, return_loss = False):
                 stage_tokens,
             ), dim = -2)
 
-            *prec_dims, _, _ = stage_tokens.shape
-
-            stage_tokens = rearrange(stage_tokens, '... n d -> (...) n d')
+            stage_tokens, ps = pack_one(stage_tokens, '* n d')
             attended = transformer(stage_tokens)
-            attended = rearrange_with_anon_dims(attended, '(...b) n d -> ...b n d', b = prec_dims)
+            attended = unpack_one(attended, ps, '* n d')
 
             start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
 
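The hunk above is the core of the change: all leading batch/stage axes are packed into one batch axis before the per-stage transformer and unpacked afterwards, replacing the anonymous-dims rearranges. A minimal sketch of that pattern, with an nn.Identity stand-in and invented shapes rather than the model's real modules:

import torch
from torch import nn
from einops import pack, unpack

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

transformer = nn.Identity()                     # stand-in for the per-stage transformer
stage_tokens = torch.randn(2, 3, 5, 17, 512)    # arbitrary leading axes, then (n, d)

flat, ps = pack_one(stage_tokens, '* n d')      # (2 * 3 * 5, 17, 512): one flat batch
attended = transformer(flat)                    # the inner module only ever sees 3-D input
attended = unpack_one(attended, ps, '* n d')    # back to (2, 3, 5, 17, 512)

assert attended.shape == stage_tokens.shape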
setup.py: 3 changes (1 addition, 2 deletions)
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
@@ -17,7 +17,6 @@
   ],
   install_requires=[
     'einops>=0.6.1',
-    'einops-exts',
     'torch>=1.10'
   ],
   classifiers=[
