rotary embedding done in full prec
lucidrains committed Sep 7, 2024
1 parent 501482a commit c17dd1e
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions MEGABYTE_pytorch/megabyte.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 from torch import nn, einsum
+from torch.amp import autocast
 
 from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange
@@ -80,6 +81,7 @@ def __init__(self, dim, theta = 10000):
     def device(self):
         return next(self.buffers()).device
 
+    @autocast('cuda', enabled = False)
     def forward(self, seq_len):
         t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
         freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
@@ -90,6 +92,7 @@ def rotate_half(x):
     x1, x2 = x.chunk(2, dim=-1)
     return torch.cat((-x2, x1), dim=-1)
 
+@autocast('cuda', enabled = False)
 def apply_rotary_pos_emb(pos, t):
     return t * pos.cos() + rotate_half(t) * pos.sin()

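Why the change matters: under mixed-precision training, autocast would otherwise run the rotary position math (the cos/sin of large position indices) in fp16/bf16, where large angles lose precision and degrade attention quality at long context. The @autocast('cuda', enabled = False) decorator turns autocasting off inside these two functions, so they compute in the dtype of their inputs. Below is a minimal standalone sketch of the effect; the decorated functions match the patch, but the demo harness (tensor shapes, the fp16 region) is illustrative rather than part of the commit, and assumes a CUDA device and a PyTorch version with the device-generic torch.amp.autocast.

import torch
from torch.amp import autocast

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

@autocast('cuda', enabled = False)  # same decorator as the patched function
def apply_rotary_pos_emb(pos, t):
    return t * pos.cos() + rotate_half(t) * pos.sin()

if torch.cuda.is_available():
    pos = torch.randn(1024, 64, device = 'cuda')   # fp32 rotary angles
    t = torch.randn(1024, 64, device = 'cuda')     # fp32 activations
    with autocast('cuda', dtype = torch.float16):  # caller runs under mixed precision
        out = apply_rotary_pos_emb(pos, t)
    assert out.dtype == torch.float32              # rotary math stayed in full precision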
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.3.1',
+  version = '0.3.2',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',