diff --git a/MEGABYTE_pytorch/megabyte.py b/MEGABYTE_pytorch/megabyte.py
index a132cb5..37923a2 100644
--- a/MEGABYTE_pytorch/megabyte.py
+++ b/MEGABYTE_pytorch/megabyte.py
@@ -4,8 +4,7 @@
 import torch.nn.functional as F
 from torch import nn, einsum
 
-from einops_exts import rearrange_with_anon_dims
-from einops import rearrange, reduce, repeat
+from einops import rearrange, reduce, repeat, pack, unpack
 
 # helpers
 
@@ -15,6 +14,12 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def pack_one(t, pattern):
+    return pack([t], pattern)
+
+def unpack_one(t, ps, pattern):
+    return unpack(t, ps, pattern)[0]
+
 def remainder_to_mult(num, mult):
     return (mult - num % mult) % mult
 
@@ -141,7 +146,6 @@ def forward(self, x, attn_bias = None):
         mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
         sim = sim.masked_fill(mask, mask_value)
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
         attn = sim.softmax(dim = -1)
         attn = self.dropout(attn)
 
@@ -244,6 +248,7 @@ def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_b
             prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)
 
         seq = prime
+        batch = seq.shape[0]
 
         for _ in range(total_seq_len - seq.shape[-1]):
             logits = self.forward(seq)[:, -1]
@@ -251,7 +256,7 @@ def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_b
             sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
             seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)
 
-        return rearrange_with_anon_dims(seq, 'b (...d) -> b ...d', d = self.max_seq_len)
+        return seq.reshape(batch, *self.max_seq_len)
 
     def forward_empty(self, batch_size):
         # take care of special case
@@ -265,6 +270,8 @@ def forward_empty(self, batch_size):
         return self.to_logits(tokens)
 
     def forward(self, ids, return_loss = False):
+        batch = ids.shape[0]
+
         assert ids.ndim in {2, self.stages + 1}
         flattened_dims = ids.ndim == 2
         ids_orig_ndim = ids.ndim
@@ -279,7 +286,7 @@ def forward(self, ids, return_loss = False):
             multiple_of = reduce_mult(self.max_seq_len[1:])
             padding = remainder_to_mult(seq_len, multiple_of)
             ids = F.pad(ids, (0, padding), value = self.pad_id)
-            ids = rearrange_with_anon_dims(ids, 'b (l ...d) -> b l ...d', d = self.max_seq_len[1:])
+            ids = ids.reshape(batch, -1, *self.max_seq_len[1:])
 
         b, *prec_dims, device = *ids.shape, ids.device
 
@@ -322,11 +329,9 @@ def forward(self, ids, return_loss = False):
                 stage_tokens,
             ), dim = -2)
 
-            *prec_dims, _, _ = stage_tokens.shape
-
-            stage_tokens = rearrange(stage_tokens, '... n d -> (...) n d')
+            stage_tokens, ps = pack_one(stage_tokens, '* n d')
             attended = transformer(stage_tokens)
-            attended = rearrange_with_anon_dims(attended, '(...b) n d -> ...b n d', b = prec_dims)
+            attended = unpack_one(attended, ps, '* n d')
 
             start_tokens = rearrange(attended[..., :-1, :], '... n d -> ... n 1 d')
 
diff --git a/setup.py b/setup.py
index d777bc4..c63474d 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.0.1',
+  version = '0.0.2',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
@@ -17,7 +17,6 @@
   ],
   install_requires=[
     'einops>=0.6.1',
-    'einops-exts',
     'torch>=1.10'
   ],
   classifiers=[
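Note (not part of the patch): the diff swaps einops-exts' rearrange_with_anon_dims for einops' built-in pack/unpack (available since einops 0.6, hence the einops>=0.6.1 pin), wrapped as the new pack_one/unpack_one helpers. A minimal sketch of that pattern, with made-up tensor shapes purely for illustration:

# illustrative sketch only -- shapes are hypothetical, not taken from the patch
import torch
from einops import pack, unpack

x = torch.randn(2, 3, 4, 5, 16)            # arbitrary leading dims, then (n, d)
packed, ps = pack([x], '* n d')            # collapses leading dims -> (24, 5, 16); ps records (2, 3, 4)
restored = unpack(packed, ps, '* n d')[0]  # restores the original (2, 3, 4, 5, 16)
assert restored.shape == x.shape

pack returns the packed tensor plus the packed shapes, so unpack can restore the leading dimensions without the caller tracking them by hand, which is why the *prec_dims bookkeeping around the per-stage transformer call could be deleted.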