Skip to content

Commit

Permalink
Remove files from src, fix utils to run everything on same device
Browse files Browse the repository at this point in the history
  • Loading branch information
srmsoumya committed Jul 25, 2024
1 parent 73171dd commit 1f2fcc9
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 248 deletions.
80 changes: 0 additions & 80 deletions src/benchmark_encoder.py

This file was deleted.

73 changes: 0 additions & 73 deletions src/export.py

This file was deleted.

86 changes: 0 additions & 86 deletions src/test_encoder.py

This file was deleted.

16 changes: 7 additions & 9 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature**omega)
omega = omega.to(y.device)

y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
Expand All @@ -25,27 +24,26 @@ def posemb_sincos_2d_with_gsd(
y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"

omega = torch.arange(dim // 4, device=gsd.device) / (dim // 4 - 1)
gsd = gsd.to(x.device)
omega = torch.arange(dim // 4) / (dim // 4 - 1)
omega = 1.0 / (temperature ** (2 * omega / dim)) * (gsd / 1.0) # Adjusted for g
omega = omega.to(y.device)

y = y.flatten()[:, None] * omega[None, :]
x = x.flatten()[:, None] * omega[None, :]
pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
return pe.type(dtype)


def posemb_sincos_1d(pos, dim, temperature: int = 10000, dtype=torch.float32):
def posemb_sincos_1d(waves, dim, temperature: int = 10000, dtype=torch.float32):
assert (
dim % 2 == 0
), "Feature dimension must be a multiple of 2 for sincos embedding"
pos = torch.arange(pos) if isinstance(pos, int) else pos
waves = torch.arange(waves) if isinstance(waves, int) else waves

omega = torch.arange(dim // 2) / (dim // 2 - 1)
omega = torch.arange(dim // 2, device=waves.device) / (dim // 2 - 1)
omega = 1.0 / (temperature**omega)
omega = omega.to(pos.device)

scaled_pos = pos[:, None] * omega[None, :]
pe = torch.cat((scaled_pos.sin(), scaled_pos.cos()), dim=1)
scaled_waves = waves[:, None] * omega[None, :]
pe = torch.cat((scaled_waves.sin(), scaled_waves.cos()), dim=1)

return pe.type(dtype)

0 comments on commit 1f2fcc9

Please sign in to comment.