diff --git a/MEGABYTE_pytorch/attend.py b/MEGABYTE_pytorch/attend.py
index ddfebb8..200e352 100644
--- a/MEGABYTE_pytorch/attend.py
+++ b/MEGABYTE_pytorch/attend.py
@@ -89,21 +89,6 @@ def flash_attn(self, q, k, v, mask = None, attn_bias = None):
 
         config = self.cuda_config if is_cuda else self.cpu_config
 
-        causal = self.causal
-
-        # handle attention bias
-
-        if exists(attn_bias):
-            mask_value = -torch.finfo(q.dtype).max // 2
-            causal_mask = self.get_mask(q_len, k_len, device)
-            attn_bias = attn_bias.masked_fill(causal_mask, mask_value)
-
-            if exists(mask):
-                attn_bias = attn_bias.masked_fill(~mask, mask_value)
-
-            mask = attn_bias
-            causal = False
-
         # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
 
         with torch.backends.cuda.sdp_kernel(**config._asdict()):
@@ -111,12 +96,12 @@ def flash_attn(self, q, k, v, mask = None, attn_bias = None):
                 q, k, v,
                 attn_mask = mask,
                 dropout_p = self.dropout if self.training else 0.,
-                is_causal = causal
+                is_causal = self.causal
             )
 
         return out
 
-    def forward(self, q, k, v, mask = None, attn_bias = None):
+    def forward(self, q, k, v, mask = None):
         """
         einstein notation
         b - batch
@@ -132,17 +117,12 @@ def forward(self, q, k, v, mask = None, attn_bias = None):
         kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
 
         if self.flash:
-            return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
+            return self.flash_attn(q, k, v, mask = mask)
 
         # similarity
 
         sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale
 
-        # attention bias
-
-        if exists(attn_bias):
-            sim = sim + attn_bias
-
         # causal mask
 
         if self.causal:
diff --git a/MEGABYTE_pytorch/megabyte.py b/MEGABYTE_pytorch/megabyte.py
index 2a989d7..687c0a0 100644
--- a/MEGABYTE_pytorch/megabyte.py
+++ b/MEGABYTE_pytorch/megabyte.py
@@ -66,40 +66,30 @@ def token_shift(t):
     t_shift = F.pad(t_shift, (0, 0, 1, -1))
     return torch.cat((t, t_shift), dim = -1)
 
-# positional bias
+# rotary positional embedding
 
-class Alibi(nn.Module):
-    def __init__(self, heads, **kwargs):
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim, theta = 10000):
         super().__init__()
-        self.heads = heads
-        slopes = torch.Tensor(self._get_slopes(heads))
-        slopes = rearrange(slopes, 'h -> h 1 1')
-        self.register_buffer('slopes', slopes, persistent = False)
-        self.register_buffer('bias', None, persistent = False)
-
-    @staticmethod
-    def _get_slopes(heads):
-        def get_slopes_power_of_2(n):
-            start = (2**(-2**-(math.log2(n)-3)))
-            ratio = start
-            return [start*ratio**i for i in range(n)]
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
 
-        if math.log2(heads).is_integer():
-            return get_slopes_power_of_2(heads)
+    @property
+    def device(self):
+        return next(self.buffers()).device
 
-        closest_power_of_2 = 2 ** math.floor(math.log2(heads))
-        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]
+    def forward(self, seq_len):
+        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
+        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
+        freqs = torch.cat((freqs, freqs), dim = -1)
+        return freqs
 
-    def forward(self, i, j, device):
-        if exists(self.bias) and self.bias.shape[-1] >= j:
-            return self.bias[..., :j]
+def rotate_half(x):
+    x1, x2 = x.chunk(2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
 
-        bias = torch.arange(j, device = device)
-        bias = rearrange(bias, 'j -> 1 1 j')
-        bias = bias * self.slopes
-
-        self.register_buffer('bias', bias, persistent = False)
-        return self.bias
+def apply_rotary_pos_emb(pos, t):
+    return t * pos.cos() + rotate_half(t) * pos.sin()
 
 # norm
 
@@ -152,14 +142,17 @@ def __init__(
         self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
-    def forward(self, x, attn_bias = None):
+    def forward(self, x, rotary_emb = None):
         h, device = self.heads, x.device
 
         x = self.norm(x)
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
         q = rearrange(q, 'b n (h d) -> b h n d', h = h)
 
-        out = self.attend(q, k, v, attn_bias = attn_bias)
+        if exists(rotary_emb):
+            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))
+
+        out = self.attend(q, k, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)
@@ -175,11 +168,11 @@ def __init__(
         attn_dropout = 0.,
         ff_dropout = 0.,
         ff_mult = 4,
-        rel_pos_bias = True,
+        rel_pos = True,
         flash_attn = False
     ):
         super().__init__()
-        self.alibi = Alibi(heads = heads) if rel_pos_bias else None
+        self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None
         self.layers = nn.ModuleList([])
 
         for _ in range(layers):
@@ -192,10 +185,10 @@ def __init__(
     def forward(self, x):
         n = x.shape[-2]
 
-        attn_bias = self.alibi(n, n, device = x.device) if exists(self.alibi) else None
+        rotary_emb = self.rotary_emb(n) if exists(self.rotary_emb) else None
 
         for attn, ff in self.layers:
-            x = attn(token_shift(x), attn_bias = attn_bias) + x
+            x = attn(token_shift(x), rotary_emb = rotary_emb) + x
             x = ff(token_shift(x)) + x
 
         return self.norm(x)
@@ -218,7 +211,7 @@ def __init__(
         ff_mult = 4,
         ff_dropout = 0.,
         pad_id = 0,
-        rel_pos_bias = False,
+        rel_pos = False,
         pos_emb = False,
         flash_attn = False
     ):
@@ -264,7 +257,7 @@ def __init__(
                 attn_dropout = attn_dropout,
                 ff_dropout = ff_dropout,
                 ff_mult = ff_mult,
-                rel_pos_bias = rel_pos_bias,
+                rel_pos = rel_pos,
                 flash_attn = flash_attn
             ))
 
diff --git a/README.md b/README.md
index a746f06..e6924d0 100644
--- a/README.md
+++ b/README.md
@@ -89,15 +89,6 @@ $ python train.py
 }
 ```
 
-```bibtex
-@misc{press2021ALiBi,
-    title  = {Train Short, Test Long: Attention with Linear Biases Enable Input Length Extrapolation},
-    author = {Ofir Press and Noah A. Smith and Mike Lewis},
-    year   = {2021},
-    url    = {https://ofir.io/train_short_test_long.pdf}
-}
-```
-
 ```bibtex
 @software{peng_bo_2021_5196578,
   author       = {PENG Bo},
@@ -120,3 +111,14 @@ $ python train.py
   volume  = {abs/2305.19466}
 }
 ```
+
+```bibtex
+@misc{su2021roformer,
+    title         = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
+    author        = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
+    year          = {2021},
+    eprint        = {2104.09864},
+    archivePrefix = {arXiv},
+    primaryClass  = {cs.CL}
+}
+```
diff --git a/setup.py b/setup.py
index 72cf936..7df5f37 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.1.6',
+  version = '0.1.7',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
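For reference, here is a minimal, self-contained sketch of the rotary positional embedding path this patch swaps in for ALiBi. The `RotaryEmbedding`, `rotate_half`, and `apply_rotary_pos_emb` definitions mirror the megabyte.py hunk above (the `device` property is simplified away here); the batch size, head count, sequence length, and head dimension are illustrative assumptions only.

```python
# Standalone sketch of the rotary path introduced by this patch.
# The three definitions mirror the megabyte.py hunk; the tensors below are dummies.
import torch
from torch import nn

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        # inverse frequencies over every other channel, as in the patch
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, seq_len):
        # angles per (position, frequency), duplicated across both halves of the head dim
        t = torch.arange(seq_len).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)   # (seq_len, dim / 2)
        return torch.cat((freqs, freqs), dim = -1)               # (seq_len, dim)

def rotate_half(x):
    # pair channel i with channel i + dim / 2, then rotate the pair
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()

# illustrative shapes: batch 2, 8 heads, sequence length 16, head dimension 64
dim_head, seq_len = 64, 16
rotary = RotaryEmbedding(dim_head)
pos = rotary(seq_len)                      # (16, 64), broadcasts over batch and heads

q = torch.randn(2, 8, seq_len, dim_head)
k = torch.randn(2, 8, seq_len, dim_head)
q, k = map(lambda t: apply_rotary_pos_emb(pos, t), (q, k))

print(q.shape, k.shape)                    # torch.Size([2, 8, 16, 64]) for both
```

Because rotary position is folded multiplicatively into the queries and keys rather than added to the attention logits, the `attn_bias` plumbing in attend.py becomes unnecessary, which is what the first three hunks remove.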