Commit

refactor attn for real this time back to timm code
soniajoseph committed Feb 8, 2024
1 parent 7261ce4 commit 61afe9e
Showing 4 changed files with 463 additions and 325 deletions.
236 changes: 134 additions & 102 deletions src/vit_prisma/models/layers/attention.py
@@ -1,3 +1,7 @@
"""
Attention code from timm model
"""

import torch.nn as nn
import torch

@@ -33,51 +37,54 @@ def __init__(

self.cfg = cfg

# Initialize parameters
self.W_Q = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_model,
self.cfg.d_head,
dtype = self.cfg.dtype
)
)
self.W_K = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_model,
self.cfg.d_head,
dtype = self.cfg.dtype
)
)
self.W_V = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_model,
self.cfg.d_head,
dtype = self.cfg.dtype
)
)
self.W_O = nn.Parameter(
torch.empty(
self.cfg.n_heads,
self.cfg.d_head,
self.cfg.d_model,
dtype = self.cfg.dtype
)
)

# Initialize biases
self.b_Q = nn.Parameter(
torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
)
self.b_K = nn.Parameter(
torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
)
self.b_V = nn.Parameter(
torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
)
self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))
self.qkv = nn.Linear(self.cfg.d_model, self.cfg.d_model * 3, bias=True)
self.proj = nn.Linear(self.cfg.d_model, self.cfg.d_model)

# # Initialize parameters
# self.W_Q = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_model,
# self.cfg.d_head,
# dtype = self.cfg.dtype
# )
# )
# self.W_K = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_model,
# self.cfg.d_head,
# dtype = self.cfg.dtype
# )
# )
# self.W_V = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_model,
# self.cfg.d_head,
# dtype = self.cfg.dtype
# )
# )
# self.W_O = nn.Parameter(
# torch.empty(
# self.cfg.n_heads,
# self.cfg.d_head,
# self.cfg.d_model,
# dtype = self.cfg.dtype
# )
# )

# # Initialize biases
# self.b_Q = nn.Parameter(
# torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
# )
# self.b_K = nn.Parameter(
# torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
# )
# self.b_V = nn.Parameter(
# torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype)
# )
# self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype))


# Add hook points
@@ -87,7 +94,9 @@ def __init__(
self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos]
self.hook_result = HookPoint() # [batch, pos, head_index, d_model]
# self.hook_result = HookPoint() # [batch, pos, head_index, d_model]

self.hook_qkv = HookPoint() # [batch, pos, 3, head_index, d_head]

self.layer_id = layer_id
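
Note on the change above: the separate per-head W_Q/W_K/W_V/W_O parameter tensors are replaced by a single fused qkv projection plus an output projection, following timm's Attention module. The sketch below shows how a fused nn.Linear weight maps back onto per-head matrices, assuming the usual d_model = n_heads * d_head layout; the sizes and names are illustrative and are not taken from the commit itself.

import torch
import torch.nn as nn

d_model, n_heads = 768, 12                   # assumed ViT-Base-like sizes
d_head = d_model // n_heads

qkv = nn.Linear(d_model, 3 * d_model, bias=True)

# Fold the fused weight back into per-head matrices of shape [n_heads, d_model, d_head].
w = qkv.weight.detach().reshape(3, n_heads, d_head, d_model)
b = qkv.bias.detach().reshape(3, n_heads, d_head)
W_Q, b_Q = w[0].transpose(-1, -2), b[0]

x = torch.randn(2, 5, d_model)                                   # [batch, pos, d_model]
q_fused = qkv(x).reshape(2, 5, 3, n_heads, d_head)[:, :, 0]      # [batch, pos, head, d_head]
q_split = torch.einsum("bpd,hde->bphe", x, W_Q) + b_Q            # per-head projection, as in the old code path
assert torch.allclose(q_fused, q_split, atol=1e-5)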

@@ -126,62 +135,80 @@ def QK(self) -> FactoredMatrix:

def forward(
self,
query_input: Union[
Float[torch.Tensor, "batch pos d_model"],
Float[torch.Tensor, "batch pos head_index d_model"],
],
key_input: Union[
x: Union[
Float[torch.Tensor, "batch pos d_model"],
Float[torch.Tensor, "batch pos head_index d_model"],
],
value_input: Union[
Float[torch.Tensor, "batch pos d_model"],
Float[torch.Tensor, "batch pos head_index d_model"],
# Float[torch.Tensor, "batch pos head_index d_model"],
]
) -> Float[torch.Tensor, "batch pos d_model"]:

q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)
# Calculate QKV
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.cfg.n_heads, self.cfg.d_head).permute(2, 0, 3, 1, 4)

qkv = self.hook_qkv(qkv)

q, k, v = qkv.unbind(0)

q = self.hook_q(q)
k = self.hook_k(k)
v = self.hook_v(v)

# q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input)

attn_scores = self.calculate_attn_scores(q, k)
# Attention Scores
q = q * self.scale
attn_scores = q @ k.transpose(-2, -1)
attn_scores = self.hook_attn_scores(attn_scores)

pattern = F.softmax(attn_scores, dim=-1) # where do I do normalization?
pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern)
pattern = self.hook_pattern(pattern)

pattern = pattern.to(self.cfg.dtype)
z = self.calculate_z_scores(v, pattern)

if not self.cfg.use_attn_result:
out = (
(
einsum(
"batch pos head_index d_head, \
head_index d_head d_model -> \
batch pos d_model",
z,
self.W_O,
)
)
+ self.b_O
)
else:
# Explicitly calculate the attention result so it can be accessed by a hook.
# Off by default to not eat through GPU memory.
result = self.hook_result(
einsum(
"batch pos head_index d_head, \
head_index d_head d_model -> \
batch pos head_index d_model",
z,
self.W_O,
)
)
out = (
einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum")
+ self.b_O
)
return out
# Attention Pattern
attn_pattern = attn_scores.softmax(dim=-1)
attn_pattern = torch.where(torch.isnan(attn_pattern), torch.zeros_like(attn_pattern), attn_pattern)
attn_pattern = self.hook_pattern(attn_pattern)
attn_pattern = attn_pattern.to(self.cfg.dtype)

# Value matrix
z = attn_pattern @ v
z = self.hook_z(z)

# Output projection layer
z = z.transpose(1,2).reshape(B,N,C)
output = self.proj(z)
# output = self.hook_result(output)

return output

# z = self.calculate_z_scores(v, pattern)

# if not self.cfg.use_attn_result:
# out = (
# (
# einsum(
# "batch pos head_index d_head, \
# head_index d_head d_model -> \
# batch pos d_model",
# z,
# self.W_O,
# )
# )
# + self.b_O
# )
# else:
# # Explicitly calculate the attention result so it can be accessed by a hook.
# # Off by default to not eat through GPU memory.
# result = self.hook_result(
# einsum(
# "batch pos head_index d_head, \
# head_index d_head d_model -> \
# batch pos head_index d_model",
# z,
# self.W_O,
# )
# )
# out = (
# einops.reduce(result, "batch pos head_index d_model -> batch position d_model", "sum")
# + self.b_O
# )
# return out
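
The new forward path above can be summarized by the following standalone sketch of timm-style multi-head self-attention, with the hook points omitted. Names and sizes here are illustrative assumptions, not the exact vit_prisma module.

import torch
import torch.nn as nn

def timm_attention(x, qkv, proj, n_heads):
    B, N, C = x.shape                                # [batch, pos, d_model]
    d_head = C // n_heads
    scale = d_head ** -0.5
    # [B, N, 3*C] -> [3, B, heads, N, d_head], then split into q, k, v
    q, k, v = qkv(x).reshape(B, N, 3, n_heads, d_head).permute(2, 0, 3, 1, 4).unbind(0)
    attn = (q * scale) @ k.transpose(-2, -1)         # [B, heads, N, N] attention scores
    attn = attn.softmax(dim=-1)                      # attention pattern
    z = attn @ v                                     # [B, heads, N, d_head]
    z = z.transpose(1, 2).reshape(B, N, C)           # back to [B, N, d_model]
    return proj(z)                                   # output projection

# Usage with random weights:
C, H = 64, 4
attn_out = timm_attention(torch.randn(2, 10, C), nn.Linear(C, 3 * C), nn.Linear(C, C), H)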

def calculate_qkv_matrices(
self,
@@ -213,16 +240,15 @@ def calculate_qkv_matrices(
else:
qkv_einops_string = "batch pos d_model"


q = self.hook_q(
einsum(
q = einsum(
f"{qkv_einops_string}, head_index d_model d_head \
-> batch pos head_index d_head",
query_input,
self.W_Q,
)
+ self.b_Q
) # [batch, pos, head_index, d_head]
) + self.b_Q
# [batch, pos, head_index, d_head]
q = self.hook_q(q.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm

k = self.hook_k(
einsum(
f"{qkv_einops_string}, head_index d_model d_head \
@@ -232,6 +258,8 @@ def calculate_qkv_matrices(
)
+ self.b_K
) # [batch, pos, head_index, d_head]
k = self.hook_k(k.transpose(-1,-2)) # [batch, pos, d_head, head_index]; to match timm

v = self.hook_v(
einsum(
f"{qkv_einops_string}, head_index d_model d_head \
@@ -241,6 +269,7 @@ def calculate_qkv_matrices(
)
+ self.b_V
) # [batch, pos, head_index, d_head]
v = self.hook_v(v.transpose(-1,-2))
return q, k, v

def calculate_attn_scores(
@@ -255,17 +284,20 @@ def calculate_attn_scores(
"""
q = q * self.scale
attn_scores = einsum(
"batch query_pos head_index d_head, batch key_pos head_index d_head -> batch head_index query_pos key_pos",
"batch query_pos d_head head_index, batch key_pos d_head -> batch head_index query_pos key_pos",
q,
k,
)

return attn_scores

def calculate_z_scores(
self,
v: Float[torch.Tensor, "batch key_pos head_index d_head"],
pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
) -> Float[torch.Tensor, "batch query_pos head_index d_head"]:


z = self.hook_z(
einsum(
"batch key_pos head_index d_head, \
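
The hook_qkv, hook_q, hook_k, hook_v, hook_attn_scores, hook_pattern, and hook_z calls threaded through the code above are TransformerLens-style hook points. As an assumption about their behavior (not taken from this repository), a HookPoint is essentially an identity module that gives each intermediate activation a stable name so forward hooks can read or edit it:

import torch
import torch.nn as nn

class HookPoint(nn.Module):
    """Identity module: gives an activation a stable name so hooks can observe or edit it."""
    def forward(self, x):
        return x

# Usage sketch: cache an attention pattern as it flows through the hook point.
cache = {}
hook_pattern = HookPoint()
hook_pattern.register_forward_hook(lambda module, inputs, output: cache.update(pattern=output.detach()))
pattern = hook_pattern(torch.softmax(torch.randn(1, 12, 197, 197), dim=-1))
assert "pattern" in cache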
34 changes: 16 additions & 18 deletions src/vit_prisma/models/layers/transformer_block.py
@@ -85,35 +85,33 @@ def forward(

resid_pre = self.hook_resid_pre(resid_pre)

if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
# We're adding a head dimension
attn_in = add_head_dimension(resid_pre, self.cfg.n_heads, clone_tensor=False)
else:
attn_in = resid_pre
# if self.cfg.use_attn_in or self.cfg.use_split_qkv_input:
# # We're adding a head dimension
# attn_in = add_head_dimension(resid_pre, self.cfg.n_heads, clone_tensor=False)
# else:
attn_in = resid_pre

if self.cfg.use_attn_in:
attn_in = self.hook_attn_in(attn_in.clone())


if self.cfg.use_split_qkv_input:
query_input = self.hook_q_input(attn_in.clone())
key_input = self.hook_k_input(attn_in.clone())
value_input = self.hook_v_input(attn_in.clone())
else:
query_input = attn_in
key_input = attn_in
value_input = attn_in
# if self.cfg.use_split_qkv_input:
# query_input = self.hook_q_input(attn_in.clone())
# key_input = self.hook_k_input(attn_in.clone())
# value_input = self.hook_v_input(attn_in.clone())
# else:
# query_input = attn_in
# key_input = attn_in
# value_input = attn_in

# attn_out = self.attn(
# query_input = query_input,
# key_input = key_input,
# value_input = value_input,
# )


# Update: moving to timm setup for accuracy; add function to split qkv later.
attn_out = self.attn(
query_input = self.ln1(query_input),
key_input = self.ln1(key_input),
value_input = self.ln1(value_input),
self.ln1(attn_in)
)

attn_out = self.hook_attn_out(

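In transformer_block.py the split q/k/v plumbing is commented out and the block now feeds a single LayerNormed residual stream into the attention module. Below is a hedged sketch of the resulting wiring for the attention half of the block; the module names are illustrative and the MLP half is omitted.

import torch
import torch.nn as nn

class AttnBlockSketch(nn.Module):
    def __init__(self, d_model, attn_module):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = attn_module              # e.g. the timm-style attention sketched above

    def forward(self, resid_pre):
        attn_in = resid_pre                  # no per-head split of q/k/v inputs any more
        attn_out = self.attn(self.ln1(attn_in))
        return resid_pre + attn_out          # residual stream update

# Shape check with a stand-in attention module:
block = AttnBlockSketch(d_model=768, attn_module=nn.Identity())
out = block(torch.randn(2, 197, 768))        # [batch, pos, d_model]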